浏览代码

Add team methods to AgentAction

/develop/coma2/samenet
Ervin Teng 4 年前
当前提交
3d2171c4
共有 1 个文件被更改,包括 81 次插入1 次删除
  1. 82
      ml-agents/mlagents/trainers/torch/agent_action.py

82
ml-agents/mlagents/trainers/torch/agent_action.py


from typing import List, Optional, NamedTuple
import itertools
import numpy as np
from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents.trainers.buffer import AgentBuffer, BufferKey, AgentBufferField
from mlagents.trainers.torch.utils import ModelUtils
from mlagents_envs.base_env import ActionTuple

discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
]
return AgentAction(continuous, discrete)
@staticmethod
def _padded_time_to_batch(
agent_buffer_field: AgentBufferField, dtype: torch.dtype = torch.float32
) -> List[torch.Tensor]:
"""
Pad actions and convert to tensor. Note that data is padded by 0's, not NaNs
as the observations are.
"""
action_shape = None
for _action in agent_buffer_field:
if _action:
action_shape = _action[0].shape
break
# If there were no critic obs at all
if action_shape is None:
return []
new_list = list(
map(
lambda x: ModelUtils.list_to_tensor(x, dtype=dtype),
itertools.zip_longest(
*agent_buffer_field, fillvalue=np.full(action_shape, 0)
),
)
)
return new_list
@staticmethod
def _group_from_buffer(
buff: AgentBuffer, cont_action_key: BufferKey, disc_action_key: BufferKey
) -> List["AgentAction"]:
continuous_tensors: List[torch.Tensor] = []
discrete_tensors: List[torch.Tensor] = [] # type: ignore
if cont_action_key in buff:
continuous_tensors = AgentAction._padded_time_to_batch(
buff[cont_action_key]
)
if disc_action_key in buff:
discrete_tensors = AgentAction._padded_time_to_batch(
buff[disc_action_key], dtype=torch.long
)
actions_list = []
for _cont, _disc in itertools.zip_longest(
continuous_tensors, discrete_tensors, fillvalue=None
):
if _disc is not None:
_disc = [_disc[..., i] for i in range(_disc.shape[-1])]
actions_list.append(AgentAction(_cont, _disc))
return actions_list
@staticmethod
def group_from_buffer(buff: AgentBuffer) -> List["AgentAction"]:
"""
A static method that accesses continuous and discrete action fields in an AgentBuffer
and constructs the corresponding AgentAction from the retrieved np arrays.
"""
return AgentAction._group_from_buffer(
buff, BufferKey.GROUP_CONTINUOUS_ACTION, BufferKey.GROUP_DISCRETE_ACTION
)
@staticmethod
def group_from_buffer_next(buff: AgentBuffer) -> List["AgentAction"]:
"""
A static method that accesses next continuous and discrete action fields in an AgentBuffer
and constructs the corresponding AgentAction from the retrieved np arrays.
"""
return AgentAction._group_from_buffer(
buff, BufferKey.GROUP_NEXT_CONT_ACTION, BufferKey.GROUP_NEXT_DISC_ACTION
)
def to_flat(self, discrete_branches: List[int]) -> torch.Tensor:
discrete_oh = ModelUtils.actions_to_onehot(
self.discrete_tensor, discrete_branches
)
discrete_oh = torch.cat(discrete_oh, dim=1)
return torch.cat([self.continuous_tensor, discrete_oh], dim=-1)
正在加载...
取消
保存