浏览代码

Fix to-flat and add tests

/develop/coma2/samenet
Ervin Teng 4 年前
当前提交
56d4c1f9
共有 2 个文件被更改,包括 24 次插入6 次删除
  1. 11
      ml-agents/mlagents/trainers/tests/torch/test_agent_action.py
  2. 19
      ml-agents/mlagents/trainers/torch/agent_action.py

11
ml-agents/mlagents/trainers/tests/torch/test_agent_action.py


def test_to_flat():
# Both continuous and discrete
aa = AgentAction(
torch.tensor([[1.0, 1.0, 1.0]]), [torch.tensor([2]), torch.tensor([1])]
)

).all()
# Just continuous
aa = AgentAction(torch.tensor([[1.0, 1.0, 1.0]]), None)
flattened_actions = aa.to_flat([])
assert torch.eq(flattened_actions, torch.tensor([1, 1, 1])).all()
# Just discrete
aa = AgentAction(torch.tensor([]), [torch.tensor([2]), torch.tensor([1])])
flattened_actions = aa.to_flat([3, 3])
assert torch.eq(flattened_actions, torch.tensor([0, 0, 1, 0, 1, 0])).all()

19
ml-agents/mlagents/trainers/torch/agent_action.py


discrete_list: Optional[List[torch.Tensor]]
@property
def discrete_tensor(self):
def discrete_tensor(self) -> torch.Tensor:
return torch.stack(self.discrete_list, dim=-1)
if self.discrete_list is not None and len(self.discrete_list) > 0:
return torch.stack(self.discrete_list, dim=-1)
else:
return torch.empty(0)
def to_action_tuple(self, clip: bool = False) -> ActionTuple:
"""

:param discrete_branches: List of sizes for discrete actions.
:return: Tensor of flattened actions.
"""
discrete_oh = ModelUtils.actions_to_onehot(
self.discrete_tensor, discrete_branches
)
discrete_oh = torch.cat(discrete_oh, dim=1)
# if there are any discrete actions, create one-hot
if self.discrete_list is not None and self.discrete_list:
discrete_oh = ModelUtils.actions_to_onehot(
self.discrete_tensor, discrete_branches
)
discrete_oh = torch.cat(discrete_oh, dim=1)
else:
discrete_oh = torch.empty(0)
return torch.cat([self.continuous_tensor, discrete_oh], dim=-1)
正在加载...
取消
保存