浏览代码

Support multiple goals in networkbody

/goal-conditioning/new
Arthur Juliani 4 年前
当前提交
1d106816
共有 1 个文件被更改,包括 13 次插入和 12 次删除
  1. 25
      ml-agents/mlagents/trainers/torch/networks.py

25
ml-agents/mlagents/trainers/torch/networks.py


memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
encodes = []
goal_signal = None
obs_encodes = []
goal_encodes = []
for idx, processor in enumerate(self.processors):
obs_input = inputs[idx]
processed_obs = processor(obs_input)

):
encodes.append(processed_obs)
obs_encodes.append(processed_obs)
if goal_signal is not None:
raise Exception("TODO : Cannot currently handle more than one goal")
goal_signal = processed_obs
goal_encodes.append(processed_obs)
else:
raise Exception("TODO : Something other than a goal or observation was passed to the agent.")
if len(encodes) == 0:
if len(obs_encodes) == 0:
inputs = torch.cat(encodes + [actions], dim=-1)
obs_inputs = torch.cat(obs_encodes + [actions], dim=-1)
inputs = torch.cat(encodes, dim=-1)
obs_inputs = torch.cat(obs_encodes, dim=-1)
if goal_signal is None:
encoding = self.linear_encoder(inputs)
if len(goal_encodes) == 0:
encoding = self.linear_encoder(obs_inputs)
encoding = self.linear_encoder(inputs, goal_signal)
goal_inputs = torch.cat(goal_encodes, dim=-1)
encoding = self.linear_encoder(obs_inputs, goal_inputs)
if self.use_lstm:
# Resize to (batch, sequence length, encoding size)

正在加载...
取消
保存