浏览代码

add action_out to dist

/develop/actions-out
Andrew Cohen 4 年前
当前提交
85602279
共有 3 个文件被更改，包括 8 次插入和 10 次删除
  1. 6
      ml-agents/mlagents/trainers/tests/torch/test_networks.py
  2. 6
      ml-agents/mlagents/trainers/torch/distributions.py
  3. 6
      ml-agents/mlagents/trainers/torch/networks.py

6
ml-agents/mlagents/trainers/tests/torch/test_networks.py


[sample_obs], [], masks=masks
)
for act in actions:
# This is different from above for ONNX export
if action_type == ActionType.CONTINUOUS:
assert act.shape == (act_size[0], 1)
else:
assert act.shape == tuple(act_size)
assert act.shape == tuple(act_size)
assert mem_size == 0
assert is_cont == int(action_type == ActionType.CONTINUOUS)

6
ml-agents/mlagents/trainers/torch/distributions.py


def entropy(self):
    """Per-dimension differential entropy of this Gaussian.

    For N(mu, sigma) the entropy is 0.5 * log(2 * pi * e * sigma^2).
    The previous code used sigma instead of sigma^2, which misstates
    the entropy whenever std != 1. EPSILON (module-level constant)
    keeps the log finite when std is zero.
    """
    return 0.5 * torch.log(2 * math.pi * math.e * self.std ** 2 + EPSILON)
def action_out(self):
    """Value exposed as this distribution's action output: a fresh
    sample drawn from the distribution (used by the network's export
    path, which calls dist.action_out())."""
    drawn = self.sample()
    return drawn
class TanhGaussianDistInstance(GaussianDistInstance):
def __init__(self, mean, std):

def entropy(self):
    """Entropy of the categorical distribution: -sum(p * log p) over
    the last dimension.

    EPSILON (module-level constant) guards torch.log against p == 0,
    which would otherwise produce 0 * -inf = NaN in the sum.
    """
    return -torch.sum(self.probs * torch.log(self.probs + EPSILON), dim=-1)
def action_out(self):
    """Value exposed as this distribution's action output: the log
    probabilities of all action branches (used by the network's export
    path, which calls dist.action_out())."""
    log_probs = self.all_log_prob()
    return log_probs
class GaussianDistribution(nn.Module):

6
ml-agents/mlagents/trainers/torch/networks.py


Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
"""
dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
if self.act_type == ActionType.CONTINUOUS:
action_list = self.sample_action(dists)
action_out = torch.stack(action_list, dim=-1)
else:
action_out = torch.cat([dist.all_log_prob() for dist in dists], dim=1)
action_out = torch.cat([dist.action_out() for dist in dists], dim=1)
return (
action_out,
self.version_number,

正在加载...
取消
保存