
Action slice (#5047)

* add slice function to agent action

* add type/docstring to slice

* add test
/develop/coma2/fixgroup
GitHub · 4 years ago
Current commit: d2635e58
3 changed files with 27 additions and 15 deletions
  1. ml-agents/mlagents/trainers/coma/optimizer_torch.py (4 changes)
  2. ml-agents/mlagents/trainers/tests/torch/test_agent_action.py (11 changes)
  3. ml-agents/mlagents/trainers/torch/agent_action.py (27 changes)

ml-agents/mlagents/trainers/coma/optimizer_torch.py (4 changes)


 first_seq_obs = _obs[0:first_seq_len]
 seq_obs.append(first_seq_obs)
 team_seq_obs.append(seq_obs)
-_act = team_action[0:first_seq_len]
+_act = team_action.slice(0, first_seq_len)
 team_seq_act.append(_act)
 # For the first sequence, the initial memory should be the one at the

 first_seq_obs = _obs[start:end]
 seq_obs.append(first_seq_obs)
 team_seq_obs.append(seq_obs)
-_act = team_action[start:end]
+_act = team_action.slice(start, end)
 team_seq_act.append(_act)
 all_seq_obs = self_seq_obs + team_seq_obs
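
The change above swaps raw tensor indexing on team_action for the new AgentAction.slice call, so the continuous tensor and every discrete branch are cut with the same bounds as the observation slice. A minimal sketch of that sequence-splitting pattern, assuming the AgentAction class from the file below and a hypothetical batch of 6 steps with sequence length 3:

import torch
from mlagents.trainers.torch.agent_action import AgentAction

# Hypothetical teammate batch: 6 steps, 2 continuous dims, 1 discrete branch.
team_action = AgentAction(torch.rand(6, 2), [torch.randint(0, 3, (6,))])
seq_len = 3

team_seq_act = []
for start in range(0, 6, seq_len):
    end = start + seq_len
    # slice() cuts the continuous tensor and each discrete branch together,
    # keeping the action sequence aligned with the sliced observations.
    team_seq_act.append(team_action.slice(start, end))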

ml-agents/mlagents/trainers/tests/torch/test_agent_action.py (11 changes)


     assert (agent_1_act.discrete_tensor[3:] == 0).all()

+def test_slice():
+    # Both continuous and discrete
+    aa = AgentAction(
+        torch.tensor([[1.0], [1.0], [1.0]]),
+        [torch.tensor([2, 1, 0]), torch.tensor([1, 2, 0])],
+    )
+    saa = aa.slice(0, 2)
+    assert saa.continuous_tensor.shape == (2, 1)
+    assert saa.discrete_tensor.shape == (2, 2)

 def test_to_flat():
     # Both continuous and discrete
     aa = AgentAction(
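
As the new test implies, slice(0, 2) keeps the first two rows of every component, and the discrete_tensor property then stacks the sliced branches along the last dimension. A quick check with the same constructor arguments as the test (the comments show what this snippet prints):

import torch
from mlagents.trainers.torch.agent_action import AgentAction

aa = AgentAction(
    torch.tensor([[1.0], [1.0], [1.0]]),
    [torch.tensor([2, 1, 0]), torch.tensor([1, 2, 0])],
)
saa = aa.slice(0, 2)
print(saa.continuous_tensor)  # tensor([[1.], [1.]])     -> shape (2, 1)
print(saa.discrete_tensor)    # tensor([[2, 1], [1, 2]]) -> shape (2, 2)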

ml-agents/mlagents/trainers/torch/agent_action.py (27 changes)


     continuous_tensor: torch.Tensor
     discrete_list: Optional[List[torch.Tensor]]

-    def __getitem__(self, index):
-        if isinstance(index, slice):
-            _cont = None
-            _disc_list = []
-            if self.continuous_tensor is not None:
-                _cont = self.continuous_tensor.__getitem__(index)
-            if self.discrete_list is not None and len(self.discrete_list) > 0:
-                for _disc in self.discrete_list:
-                    _disc_list.append(_disc.__getitem__(index))
-            return AgentAction(_cont, _disc_list)
-        else:
-            return super().__getitem__(index)

     @property
     def discrete_tensor(self) -> torch.Tensor:
         """

             return torch.stack(self.discrete_list, dim=-1)
         else:
             return torch.empty(0)

+    def slice(self, start: int, end: int) -> "AgentAction":
+        """
+        Returns an AgentAction with the continuous and discrete tensors sliced
+        from index start to index end.
+        """
+        _cont = None
+        _disc_list = []
+        if self.continuous_tensor is not None:
+            _cont = self.continuous_tensor[start:end]
+        if self.discrete_list is not None and len(self.discrete_list) > 0:
+            for _disc in self.discrete_list:
+                _disc_list.append(_disc[start:end])
+        return AgentAction(_cont, _disc_list)

     def to_action_tuple(self, clip: bool = False) -> ActionTuple:
         """
