浏览代码

Some more fixes

/develop/centralizedcritic
Ervin Teng 4 年前
当前提交
d02a1033
共有 2 个文件被更改，包括 13 次插入和 6 次删除
  1. 7
      ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
  2. 12
      ml-agents/mlagents/trainers/torch/networks.py

7
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py


critic_obs = [
ModelUtils.list_to_tensor_list(_agent_obs) for _agent_obs in critic_obs_np
]
ModelUtils.list_to_tensor_list(_obs) for _obs in next_critic_obs
ModelUtils.list_to_tensor_list(_list_obs) for _list_obs in next_critic_obs
]
# Expand dimensions of next critic obs
next_critic_obs = [
[_obs.unsqueeze(0) for _obs in _list_obs] for _list_obs in next_critic_obs
]
memory = torch.zeros([1, 1, self.policy.m_size])

12
ml-agents/mlagents/trainers/torch/networks.py


self.normalize = network_settings.normalize
self.use_lstm = network_settings.memory is not None
# Scale network depending on num agents
self.h_size = network_settings.hidden_units * num_obs_heads
self.h_size = network_settings.hidden_units
self.m_size = (
network_settings.memory.memory_size
if network_settings.memory is not None

normalize=self.normalize,
)
self.processors.append(_proc)
encoder_input_size += _input_size
encoder_input_size += sum(_input_size)
total_enc_size = encoder_input_size + encoded_act_size
self.linear_encoder = LinearEncoder(

if network_settings.memory is not None:
encoding_size = network_settings.memory.memory_size // 2
else:
encoding_size = network_settings.hidden_units * num_agents
encoding_size = network_settings.hidden_units
self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream)
def forward(

critic_obs: Optional[List[List[torch.Tensor]]] = None,
) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]:
encoding, memories = self.network_body(
inputs, memories=memories, sequence_length=sequence_length
inputs, memories=memories, sequence_length=sequence_length,
)
log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
value_outputs = self.value_heads(encoding)

if critic_obs is not None:
all_net_inputs.extend(critic_obs)
value_outputs, critic_mem_outs = self.critic(
inputs, memories=critic_mem, sequence_length=sequence_length
all_net_inputs,
memories=critic_mem,
sequence_length=sequence_length,
)
return log_probs, entropies, value_outputs

正在加载...
取消
保存