import copy
from typing import Dict, List, Optional, Tuple

import numpy as np

from mlagents.torch_utils import torch, default_device
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.torch.action_log_probs import ActionLogProbs
from mlagents_envs.timers import timed |
from mlagents.trainers.settings import TrainerSettings |
from mlagents.trainers.torch.networks import SimpleActor, GlobalSteps
from mlagents.trainers.torch.utils import ModelUtils |
from mlagents.trainers.buffer import AgentBuffer |
|
|
|
|
|
|
self.global_step = (
    GlobalSteps()
)  # could be much simpler if TorchPolicy is nn.Module
self.grads = None |
|
|
|
|
|
|
|
reward_signal_configs = trainer_settings.reward_signals
reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]
|
|
|
ac_class = SimpleActor
self.actor = ac_class(
    stream_names=reward_signal_names,
    conditional_sigma=self.condition_sigma_on_obs,
    tanh_squash=tanh_squash,
)
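# Note: the policy now constructs a single SimpleActor rather than a shared or
# separate actor-critic. Value estimation is expected to live outside the policy
# (on the optimizer/trainer side in ML-Agents' design), so no critic network is
# built here.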
|
|
|
|
|
|
# m_size needed for training is determined by the network, not the trainer settings.
self.m_size = self.actor.memory_size

self.actor.to(default_device())
self._clip_action = not tanh_squash |
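# With tanh squashing disabled, sampled continuous actions are clipped before
# being handed to the environment instead of being bounded by a tanh squash.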
|
|
|
|
|
|
|
@property |
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
if self.normalize: |
|
|
|
    self.actor.update_normalization(buffer)
|
|
|
|
|
|
|
@timed
def sample_actions(
|
|
|
|
|
|
:param seq_len: Sequence length when using RNN.
:return: Tuple of AgentAction, ActionLogProbs, entropies, and output memories.
"""
|
|
|
actions, log_probs, entropies, memories = self.actor.get_action_and_stats(
    obs, masks, memories, seq_len
)
return (actions, log_probs, entropies, memories)
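# Rough usage sketch (hypothetical names): during rollout collection the trainer
# would call something like
#     actions, log_probs, entropies, memories = policy.sample_actions(obs, masks, memories, seq_len)
# and store the sampled actions; during an update the stored actions are
# re-scored via the stats call below rather than re-sampled.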
|
|
|
|
|
|
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
seq_len: int = 1,
) -> Tuple[ActionLogProbs, torch.Tensor]:
log_probs, entropies = self.actor.get_stats(
    obs, actions, masks, memories, seq_len
)
return log_probs, entropies
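# Unlike the old actor-critic path, this no longer returns value estimates;
# callers that need values are expected to obtain them from a critic that
# lives outside the policy.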
|
|
|
|
|
|
|
@timed
def evaluate(
|
|
|
|
|
|
return ActionInfo(
    action=run_out.get("action"),
    env_action=run_out.get("env_action"),
    value=run_out.get("value"),
    outputs=run_out,
    agent_ids=list(decision_requests.agent_id),
)
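# With the critic removed from the policy, run_out may no longer carry a
# "value" entry; using .get() keeps the lookup safe (it simply returns None).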
|
|
|
|
|
|
return self.get_current_step() |
|
|
|
|
|
|
|
def load_weights(self, values: List[np.ndarray]) -> None:
    self.actor.load_state_dict(values)
|
|
|
return copy.deepcopy(self.actor.state_dict())
|
|
|
return {"Policy": self.actor, "global_step": self.global_step}
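# This mapping is how the policy exposes its trainable modules and the step
# counter to the checkpointing code, so their state dicts can be saved and
# restored as part of a checkpoint.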