浏览代码

[refactor] Remove critic pass during inference (#4743)

/MLA-1734-demo-provider
GitHub 4 年前
当前提交
ad5f878c
共有 3 个文件被更改,包括 76 次插入19 次删除
  1. 2
      ml-agents/mlagents/trainers/policy/torch_policy.py
  2. 2
      ml-agents/mlagents/trainers/tests/torch/test_hybrid.py
  3. 91
      ml-agents/mlagents/trainers/torch/networks.py

2
ml-agents/mlagents/trainers/policy/torch_policy.py


:param seq_len: Sequence length when using RNN.
:return: Tuple of AgentAction, ActionLogProbs, entropies, and output memories.
"""
actions, log_probs, entropies, _, memories = self.actor_critic.get_action_stats_and_value(
actions, log_probs, entropies, memories = self.actor_critic.get_action_stats(
vec_obs, vis_obs, masks, memories, seq_len
)
return (actions, log_probs, entropies, memories)

2
ml-agents/mlagents/trainers/tests/torch/test_hybrid.py


SAC_TORCH_CONFIG,
hyperparameters=new_hyperparams,
network_settings=new_networksettings,
max_steps=4000,
max_steps=3500,
)
check_environment_trains(env, {BRAIN_NAME: config})

91
ml-agents/mlagents/trainers/torch/networks.py


"""
pass
def get_action_stats(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]:
"""
Returns sampled actions.
If memory is enabled, return the memories as well.
:param vec_inputs: A List of vector inputs as tensors.
:param vis_inputs: A List of visual inputs as tensors.
:param masks: If using discrete actions, a Tensor of action masks.
:param memories: If using memory, a Tensor of initial memories.
:param sequence_length: If using memory, the sequence length.
:return: A Tuple of AgentAction, ActionLogProbs, entropies, and memories.
Memories will be None if not using memory.
"""
pass
@abc.abstractmethod
def forward(
self,

AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
]:
"""
Returns distributions, from which actions can be sampled, and value estimates.
Returns sampled actions and value estimates.
If memory is enabled, return the memories as well.
:param vec_inputs: A List of vector inputs as tensors.
:param vis_inputs: A List of visual inputs as tensors.

def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
self.network_body.update_normalization(vector_obs)
def get_action_stats(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]:
encoding, memories = self.network_body(
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
)
action, log_probs, entropies = self.action_model(encoding, masks)
return action, log_probs, entropies, memories
def forward(
self,
vec_inputs: List[torch.Tensor],

def memory_size(self) -> int:
return self.network_body.memory_size + self.critic.memory_size
def _get_actor_critic_mem(
self, memories: Optional[torch.Tensor] = None
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
if self.use_lstm and memories is not None:
# Use only the back half of memories for critic and actor
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
else:
critic_mem = None
actor_mem = None
return actor_mem, critic_mem
def critic_pass(
self,
vec_inputs: List[torch.Tensor],

) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
actor_mem, critic_mem = None, None
if self.use_lstm:
# Use only the back half of memories for critic
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
actor_mem, critic_mem = self._get_actor_critic_mem(memories)
value_outputs, critic_mem_out = self.critic(
vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
)

memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]:
if self.use_lstm:
# Use only the back half of memories for critic and actor
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
else:
critic_mem = None
actor_mem = None
actor_mem, critic_mem = self._get_actor_critic_mem(memories)
encoding, actor_mem_outs = self.network_body(
vec_inputs, vis_inputs, memories=actor_mem, sequence_length=sequence_length
)

return log_probs, entropies, value_outputs
def get_action_stats(
self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]:
actor_mem, critic_mem = self._get_actor_critic_mem(memories)
action, log_probs, entropies, actor_mem_out = super().get_action_stats(
vec_inputs,
vis_inputs,
masks=masks,
memories=actor_mem,
sequence_length=sequence_length,
)
if critic_mem is not None:
# Make memories with the actor mem unchanged
memories_out = torch.cat([actor_mem_out, critic_mem], dim=-1)
else:
memories_out = None
return action, log_probs, entropies, memories_out
def get_action_stats_and_value(
self,
vec_inputs: List[torch.Tensor],

) -> Tuple[
AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
]:
if self.use_lstm:
# Use only the back half of memories for critic and actor
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
else:
critic_mem = None
actor_mem = None
actor_mem, critic_mem = self._get_actor_critic_mem(memories)
encoding, actor_mem_outs = self.network_body(
vec_inputs, vis_inputs, memories=actor_mem, sequence_length=sequence_length
)

正在加载...
取消
保存