浏览代码

resolve conflicts

/develop/action-spec-gym
Andrew Cohen 4 年前
当前提交
ae920478
共有 2 个文件被更改,包括 4 次插入 和 3 次删除
  1. 4
      ml-agents/mlagents/trainers/tests/torch/test_policy.py
  2. 3
      ml-agents/mlagents/trainers/torch/utils.py

4
ml-agents/mlagents/trainers/tests/torch/test_policy.py


seq_len=policy.sequence_length,
)
assert log_probs.shape == (64, policy.behavior_spec.action_spec.size)
- assert entropy.shape == (64, policy.behavior_spec.action_spec.size)
+ assert entropy.shape == (64,)
for val in values.values():
assert val.shape == (64,)

)
else:
assert log_probs.shape == (64, policy.behavior_spec.action_spec.continuous_size)
- assert entropies.shape == (64, policy.behavior_spec.action_spec.size)
+ assert entropies.shape == (64,)
if rnn:
assert memories.shape == (1, 1, policy.m_size)

3
ml-agents/mlagents/trainers/torch/utils.py


all_probs = None
else:
all_probs = torch.cat(all_probs_list, dim=-1)
- return log_probs, entropies, all_probs
+ entropy_sum = torch.sum(entropies, dim=1)
+ return log_probs, entropy_sum, all_probs
@staticmethod
def masked_mean(tensor: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:

正在加载...
取消
保存