浏览代码

clean ups

/develop/action-slice
Andrew Cohen 4 年前
当前提交
5d517c5e
共有 3 个文件被更改,包括 17 次插入和 16 次删除
  1. 4
      ml-agents/mlagents/trainers/coma/optimizer_torch.py
  2. 2
      ml-agents/mlagents/trainers/tests/mock_brain.py
  3. 27
      ml-agents/mlagents/trainers/torch/networks.py

4
ml-agents/mlagents/trainers/coma/optimizer_torch.py


encoding, memories = self.network_body(
obs_only=obs,
obs=None,
actions=None,
obs=[],
actions=[],
memories=memories,
sequence_length=sequence_length,
)

2
ml-agents/mlagents/trainers/tests/mock_brain.py


behavior_spec: BehaviorSpec,
memory_size: int = 10,
exclude_key_list: List[str] = None,
num_other_agents_in_group: int = 0,
) -> AgentBuffer:
trajectory = make_fake_trajectory(
length,

num_other_agents_in_group=num_other_agents_in_group,
)
buffer = trajectory.to_agentbuffer()
# If a key_list was given, remove those keys

27
ml-agents/mlagents/trainers/torch/networks.py


self,
obs_only: List[List[torch.Tensor]],
obs: List[List[torch.Tensor]],
actions: Optional[List[AgentAction]],
actions: List[AgentAction],
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:

concat_f_inp = []
if actions is not None:
for inputs, action in zip(obs, actions):
encodes = []
for idx, processor in enumerate(self.processors):
obs_input = inputs[idx]
obs_input[obs_input.isnan()] = 0.0 # Remove NaNs
processed_obs = processor(obs_input)
encodes.append(processed_obs)
cat_encodes = [
torch.cat(encodes, dim=-1),
action.to_flat(self.action_spec.discrete_branches),
]
concat_f_inp.append(torch.cat(cat_encodes, dim=1))
for inputs, action in zip(obs, actions):
encodes = []
for idx, processor in enumerate(self.processors):
obs_input = inputs[idx]
obs_input[obs_input.isnan()] = 0.0 # Remove NaNs
processed_obs = processor(obs_input)
encodes.append(processed_obs)
cat_encodes = [
torch.cat(encodes, dim=-1),
action.to_flat(self.action_spec.discrete_branches),
]
concat_f_inp.append(torch.cat(cat_encodes, dim=1))
if concat_f_inp:
f_inp = torch.stack(concat_f_inp, dim=1)

正在加载...
取消
保存