浏览代码

Merge branch 'develop-add-fire-checkpoint' of https://github.com/Unity-Technologies/ml-agents into develop-add-fire-checkpoint

/develop/add-fire/ckpt-2
Ruo-Ping Dong 4 年前
当前提交
0f148209
共有 1 个文件被更改,包括 15 次插入和 16 次删除
  1. 31
      ml-agents/mlagents/trainers/torch/networks.py

31
ml-agents/mlagents/trainers/torch/networks.py


memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
vec_encodes = []
encodes = []
for idx, encoder in enumerate(self.vector_encoders):
vec_input = vec_inputs[idx]
if actions is not None:

vec_encodes.append(hidden)
encodes.append(hidden)
vis_encodes = []
vis_encodes.append(hidden)
encodes.append(hidden)
if len(vec_encodes) > 0 and len(vis_encodes) > 0:
vec_encodes_tensor = torch.stack(vec_encodes, dim=-1).sum(dim=-1)
vis_encodes_tensor = torch.stack(vis_encodes, dim=-1).sum(dim=-1)
encoding = torch.stack(
[vec_encodes_tensor, vis_encodes_tensor], dim=-1
).sum(dim=-1)
elif len(vec_encodes) > 0:
encoding = torch.stack(vec_encodes, dim=-1).sum(dim=-1)
elif len(vis_encodes) > 0:
encoding = torch.stack(vis_encodes, dim=-1).sum(dim=-1)
else:
if len(encodes) == 0:
# Constants don't work in Barracuda
encoding = encodes[0]
if len(encodes) > 1:
for _enc in encodes[1:]:
encoding += _enc
if self.use_lstm:
encoding = encoding.view([sequence_length, -1, self.h_size])

)
action_list = self.sample_action(dists)
sampled_actions = torch.stack(action_list, dim=-1)
if self.act_type == ActionType.CONTINUOUS:
log_probs = dists[0].log_prob(sampled_actions)
else:
log_probs = dists[0].all_log_prob()
dists[0].pdf(sampled_actions),
log_probs,
self.version_number,
self.memory_size,
self.is_continuous_int,

正在加载...
取消
保存