浏览代码

Refactor multi input network slightly

/develop/centralizedcritic/mm
Ervin Teng 4 年前
当前提交
bf7195f1
共有 1 个文件被更改,包括 23 次插入和 18 次删除
  1. 41
      ml-agents/mlagents/trainers/torch/networks.py

41
ml-agents/mlagents/trainers/torch/networks.py


return encoding, memories
# NOTE: this class will be replaced with a multi-head attention when the time comes
class MultiInputNetworkBody(nn.Module):
def __init__(
self,

def forward(
self,
all_net_inputs: List[List[torch.Tensor]],
self_inputs: List[torch.Tensor],
aux_inputs: Optional[List[List[torch.Tensor]]] = None,
inputs = all_net_inputs[0]
obs_input = inputs[idx]
obs_input = self_inputs[idx]
all_net_inputs = [self_inputs]
if aux_inputs is not None:
all_net_inputs.extend(aux_inputs)
# Get attention masks by grabbing an arbitrary obs across all the agents
# Since these are raw obs, the padded values are still NaN
only_first_obs = [_all_obs[0] for _all_obs in all_net_inputs]

def forward(
self,
inputs: List[List[torch.Tensor]],
inputs: List[torch.Tensor],
aux_inputs: Optional[List[List[torch.Tensor]]] = None,
inputs, actions, memories, sequence_length
inputs, actions, memories, sequence_length, aux_inputs
)
output = self.value_heads(encoding)
return output, memories

masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
critic_obs: Optional[List[List[torch.Tensor]]] = None,
) -> Tuple[
AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
]:

critic_obs: List[List[torch.Tensor]] = None,
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
actor_mem, critic_mem = self._get_actor_critic_mem(memories)
all_net_inputs = [inputs]
if critic_obs is not None:
all_net_inputs.extend(critic_obs)
all_net_inputs, memories=critic_mem, sequence_length=sequence_length
inputs,
memories=critic_mem,
aux_inputs=critic_obs,
sequence_length=sequence_length,
)
if actor_mem is not None:
# Make memories with the actor mem unchanged

inputs, memories=actor_mem, sequence_length=sequence_length
)
log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
all_net_inputs = [inputs]
if critic_obs is not None:
all_net_inputs.extend(critic_obs)
all_net_inputs, memories=critic_mem, sequence_length=sequence_length
inputs,
memories=critic_mem,
aux_inputs=critic_obs,
sequence_length=sequence_length,
)
return log_probs, entropies, value_outputs

inputs, memories=actor_mem, sequence_length=sequence_length
)
action, log_probs, entropies = self.action_model(encoding, masks)
all_net_inputs = [inputs]
if critic_obs is not None:
all_net_inputs.extend(critic_obs)
all_net_inputs, memories=critic_mem, sequence_length=sequence_length
inputs,
memories=critic_mem,
aux_inputs=critic_obs,
sequence_length=sequence_length,
)
if self.use_lstm:
mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)

正在加载...
取消
保存