
separate forward into q_net and baseline

/develop/coma-noact
Andrew Cohen, 4 years ago
Current commit
44088167
1 file changed, 125 insertions(+), 38 deletions(-)
ml-agents/mlagents/trainers/torch/networks.py  (163 changed lines)

    attn_mask = torch.any(obs_for_mask.isnan(), dim=2).type(torch.FloatTensor)
    return attn_mask

def q_net(
    self,
    obs: List[List[torch.Tensor]],
    actions: List[AgentAction],
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    self_attn_masks = []
    concat_f_inp = []
    for inputs, action in zip(obs, actions):
        encodes = []
        for idx, processor in enumerate(self.processors):
            obs_input = inputs[idx]
            obs_input[obs_input.isnan()] = 0.0  # Remove NaNs
            processed_obs = processor(obs_input)
            encodes.append(processed_obs)
        # Concatenate this agent's encoded observations with its flattened action
        cat_encodes = [
            torch.cat(encodes, dim=-1),
            action.to_flat(self.action_spec.discrete_branches),
        ]
        concat_f_inp.append(torch.cat(cat_encodes, dim=1))

    f_inp = torch.stack(concat_f_inp, dim=1)
    self_attn_masks.append(self._get_masks_from_nans(obs))
    encoding, memories = self.forward(
        f_inp,
        None,
        self_attn_masks,
        memories=memories,
        sequence_length=sequence_length,
    )
    return encoding, memories
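For context on the NaN handling above: absent agents are padded with NaN observations, the mask for the residual self-attention is derived from those NaNs, and the NaNs are then zeroed before the observation processors run. A minimal, self-contained illustration of that pattern (toy shapes and names, not the ml-agents API):

import torch

batch, max_agents, obs_size = 2, 3, 4
obs = torch.randn(batch, max_agents, obs_size)
obs[0, 2] = float("nan")  # agent slot 2 is absent in the first sample

# 1.0 marks a missing (all-NaN) entity, 0.0 a real one -- same rule as
# _get_masks_from_nans above
attn_mask = torch.any(obs.isnan(), dim=2).float()
print(attn_mask)  # tensor([[0., 0., 1.], [0., 0., 0.]])

# Zero the NaNs before they reach the observation processors
obs[obs.isnan()] = 0.0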
def baseline(
    self,
    self_obs: List[List[torch.Tensor]],
    obs: List[List[torch.Tensor]],
    actions: List[AgentAction],
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    self_attn_masks = []
    f_inp = None
    concat_f_inp = []
    for inputs, action in zip(obs, actions):
        encodes = []
        for idx, processor in enumerate(self.processors):
            obs_input = inputs[idx]
            obs_input[obs_input.isnan()] = 0.0  # Remove NaNs
            processed_obs = processor(obs_input)
            encodes.append(processed_obs)
        cat_encodes = [
            torch.cat(encodes, dim=-1),
            action.to_flat(self.action_spec.discrete_branches),
        ]
        concat_f_inp.append(torch.cat(cat_encodes, dim=1))

    if concat_f_inp:
        f_inp = torch.stack(concat_f_inp, dim=1)
        self_attn_masks.append(self._get_masks_from_nans(obs))

    # Get the self encoding separately, but keep it in the entities
    concat_encoded_obs = []
    for inputs in self_obs:
        encodes = []
        for idx, processor in enumerate(self.processors):
            obs_input = inputs[idx]
            obs_input[obs_input.isnan()] = 0.0  # Remove NaNs
            processed_obs = processor(obs_input)
            encodes.append(processed_obs)
        concat_encoded_obs.append(torch.cat(encodes, dim=-1))
    g_inp = torch.stack(concat_encoded_obs, dim=1)
    # Get the mask from nans
    self_attn_masks.append(self._get_masks_from_nans(self_obs))

    encoding, memories = self.forward(
        f_inp,
        g_inp,
        self_attn_masks,
        memories=memories,
        sequence_length=sequence_length,
    )
    return encoding, memories
def forward(
    self,
    f_enc: torch.Tensor,
    g_enc: torch.Tensor,
    self_attn_masks: List[torch.Tensor],
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    self_attn_inputs = []
    if f_enc is not None:
        self_attn_inputs.append(self.obs_action_encoder(None, f_enc))
    if g_enc is not None:
        self_attn_inputs.append(self.obs_encoder(None, g_enc))
    encoded_entity = torch.cat(self_attn_inputs, dim=1)
    encoded_state = self.self_attn(encoded_entity, self_attn_masks)
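Taken together, the three methods above implement the split named in the commit message: q_net encodes every agent's observation together with its action, baseline additionally encodes the acting agent's observation without its action, and both delegate to the shared forward, which runs the entity encoders and self-attention. Below is a minimal, runnable sketch of that structure, using toy linear encoders and mean-pooling in place of the ml-agents entity-attention modules; all names and shapes are illustrative assumptions, not the project's API.

from typing import List, Optional
import torch
import torch.nn as nn

class ToyCentralizedCritic(nn.Module):
    def __init__(self, obs_size: int, act_size: int, hidden: int = 32):
        super().__init__()
        # f: encodes (obs, action) pairs; g: encodes an observation alone
        self.obs_action_encoder = nn.Linear(obs_size + act_size, hidden)
        self.obs_encoder = nn.Linear(obs_size, hidden)
        self.head = nn.Linear(hidden, 1)

    def forward(
        self, f_enc: Optional[torch.Tensor], g_enc: Optional[torch.Tensor]
    ) -> torch.Tensor:
        # Analogue of the shared forward(): gather entity encodings, then pool
        entities = []
        if f_enc is not None:
            entities.append(self.obs_action_encoder(f_enc))
        if g_enc is not None:
            entities.append(self.obs_encoder(g_enc))
        pooled = torch.cat(entities, dim=1).mean(dim=1)  # stand-in for self-attention
        return self.head(pooled)

    def q_net(self, obs: List[torch.Tensor], acts: List[torch.Tensor]) -> torch.Tensor:
        # Q(s, a): every agent contributes its observation AND its action
        f_inp = torch.stack(
            [torch.cat([o, a], dim=-1) for o, a in zip(obs, acts)], dim=1
        )
        return self.forward(f_inp, None)

    def baseline(
        self, self_obs: torch.Tensor, obs: List[torch.Tensor], acts: List[torch.Tensor]
    ) -> torch.Tensor:
        # Baseline: the acting agent's action is left out, so its observation
        # enters alone while teammates keep their (obs, action) pairs
        f_inp = torch.stack(
            [torch.cat([o, a], dim=-1) for o, a in zip(obs, acts)], dim=1
        )
        g_inp = self_obs.unsqueeze(1)
        return self.forward(f_inp, g_inp)

critic = ToyCentralizedCritic(obs_size=4, act_size=2)
obs = [torch.randn(8, 4) for _ in range(3)]     # 3 agents, batch of 8
acts = [torch.randn(8, 2) for _ in range(3)]
q = critic.q_net(obs, acts)                     # uses every agent's action
b = critic.baseline(obs[0], obs[1:], acts[1:])  # agent 0's own action left out
print(q.shape, b.shape)                         # torch.Size([8, 1]) twice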

    else:
        encoding_size = network_settings.hidden_units
    self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)
def q_net(
    self,
    obs: List[List[torch.Tensor]],
    actions: List[AgentAction],
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    encoding, memories = self.network_body.q_net(
        obs, actions, memories, sequence_length
    )
    output = self.value_heads(encoding)
    return output, memories

def baseline(
    self,
    self_obs: List[List[torch.Tensor]],
    obs: List[List[torch.Tensor]],
    actions: List[AgentAction],
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    encoding, memories = self.network_body.baseline(
        self_obs, obs, actions, memories, sequence_length
    )
    output = self.value_heads(encoding)
    return output, memories
def forward(
    self,
    inputs: List[torch.Tensor],
    actions: AgentAction,
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
    team_obs: List[List[torch.Tensor]] = None,
    team_act: List[AgentAction] = None,
):
    # actor_mem / critic_mem are split out of `memories` earlier in this
    # method; that part of the hunk is not shown here.
    all_obs = [inputs]
    if team_obs is not None and team_obs:
        all_obs.extend(team_obs)
    all_acts = [actions]
    if team_act is not None and team_act:
        all_acts.extend(team_act)

    # Baseline: this agent's observations without its own action, plus the
    # teammates' observations and actions
    baseline_outputs, _ = self.critic.baseline(
        inputs,
        team_obs,
        team_act,
        memories=critic_mem,
        sequence_length=sequence_length,
    )

    # Q-value: every agent's observations and actions, self first
    value_outputs, critic_mem_out = self.critic.q_net(
        all_obs, all_acts, memories=critic_mem, sequence_length=sequence_length
    )
    # if mar_value_outputs is None:
    #     mar_value_outputs = value_outputs

    if actor_mem is not None:
        # Make memories with the actor mem unchanged
        memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
    else:
        memories_out = None
    return value_outputs, baseline_outputs, memories_out
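To make the argument routing above concrete: q_net receives the joint observation and action lists with the acting agent first, while baseline receives the acting agent's observations separately from the teammates' observation and action lists, with the agent's own action deliberately omitted. A tiny sketch of that assembly (illustrative names only, not project code):

import torch

inputs = [torch.randn(8, 4)]                           # acting agent's observations
team_obs = [[torch.randn(8, 4)], [torch.randn(8, 4)]]  # two teammates' observations

all_obs = [inputs]
if team_obs:                       # self first, then teammates -> fed to q_net
    all_obs.extend(team_obs)

print(len(all_obs))                # 3: the acting agent plus two teammates
# baseline, by contrast, is given (inputs, team_obs, team_act): the acting
# agent's observations are passed separately and its own action is omitted.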
def get_stats_and_value(
    self,
