浏览代码

Merge pull request #5158 from Unity-Technologies/v2-staging-rebase-2-fix-networks

Fixing networks.py for the merge
/goal-conditioning/sensors-3-pytest-fix
GitHub 4 年前
当前提交
ea2b2f20
共有 1 个文件被更改,包括 0 次插入22 次删除
  1. 22
      ml-agents/mlagents/trainers/torch/networks.py

22
ml-agents/mlagents/trainers/torch/networks.py


At this moment, torch.onnx.export() doesn't accept None as tensor to be exported,
so the size of return tuple varies with action spec.
"""
# This code will convert the vec and vis obs into a list of inputs for the network
concatenated_vec_obs = vec_inputs[0]
inputs = []
start = 0
end = 0
vis_index = 0
var_len_index = 0
for i, enc in enumerate(self.network_body.observation_encoder.processors):
if isinstance(enc, VectorInput):
# This is a vec_obs
vec_size = self.network_body.observation_encoder.embedding_sizes[i]
end = start + vec_size
inputs.append(concatenated_vec_obs[:, start:end])
start = end
elif isinstance(enc, EntityEmbedding):
inputs.append(var_len_inputs[var_len_index])
var_len_index += 1
else: # visual input
inputs.append(vis_inputs[vis_index])
vis_index += 1
# End of code to convert the vec and vis obs into a list of inputs for the network
encoding, memories_out = self.network_body(
inputs, memories=memories, sequence_length=1
)

正在加载...
取消
保存