浏览代码

add slice function to agent action

/develop/action-slice
Andrew Cohen 4 年前
当前提交
0afe5f24
共有 2 个文件被更改,包括 12 次插入、15 次删除
  1. 4
      ml-agents/mlagents/trainers/coma/optimizer_torch.py
  2. 23
      ml-agents/mlagents/trainers/torch/agent_action.py

4
ml-agents/mlagents/trainers/coma/optimizer_torch.py


first_seq_obs = _obs[0:first_seq_len]
seq_obs.append(first_seq_obs)
team_seq_obs.append(seq_obs)
_act = team_action[0:first_seq_len]
_act = team_action.slice(0, first_seq_len)
team_seq_act.append(_act)
# For the first sequence, the initial memory should be the one at the

first_seq_obs = _obs[start:end]
seq_obs.append(first_seq_obs)
team_seq_obs.append(seq_obs)
_act = team_action[start:end]
_act = team_action.slice(start, end)
team_seq_act.append(_act)
all_seq_obs = self_seq_obs + team_seq_obs

23
ml-agents/mlagents/trainers/torch/agent_action.py


continuous_tensor: torch.Tensor
discrete_list: Optional[List[torch.Tensor]]
def __getitem__(self, index):
    """Index into this action.

    A slice index produces a new AgentAction whose continuous tensor and
    each discrete branch tensor are sliced the same way; any other index
    is delegated to the parent class (presumably tuple/NamedTuple field
    access — TODO confirm against the class definition).
    """
    if not isinstance(index, slice):
        # Non-slice indexing keeps the original parent-class semantics.
        return super().__getitem__(index)
    cont = None
    if self.continuous_tensor is not None:
        cont = self.continuous_tensor[index]
    disc = []
    if self.discrete_list is not None and len(self.discrete_list) > 0:
        disc = [branch[index] for branch in self.discrete_list]
    return AgentAction(cont, disc)
@property
def discrete_tensor(self) -> torch.Tensor:
"""

return torch.stack(self.discrete_list, dim=-1)
else:
return torch.empty(0)
def slice(self, start, end):
    """Return a new AgentAction restricted to rows [start:end).

    The continuous tensor (if any) and every discrete branch tensor are
    sliced along their first dimension; a missing continuous tensor stays
    None and a missing/empty discrete list yields an empty list.
    """
    cont = None
    if self.continuous_tensor is not None:
        cont = self.continuous_tensor[start:end]
    disc = []
    if self.discrete_list is not None and len(self.discrete_list) > 0:
        disc = [branch[start:end] for branch in self.discrete_list]
    return AgentAction(cont, disc)
def to_action_tuple(self, clip: bool = False) -> ActionTuple:
"""

正在加载...
取消
保存