The Unity Machine Learning Agents Toolkit (ML-Agents) is an open-source project that enables games and simulations to serve as environments for training intelligent agents.
from typing import List, Optional, NamedTuple, Dict
import itertools
from mlagents.torch_utils import torch
import numpy as np
from mlagents.trainers.torch.utils import ModelUtils
from mlagents_envs.base_env import ActionTuple
from mlagents.trainers.buffer import AgentBuffer


class AgentAction(NamedTuple):
    """
    A NamedTuple containing the tensor for continuous actions and list of tensors for
    discrete actions. Utility functions provide numpy <=> tensor conversions to be
    sent as actions to the environment manager as well as used by the optimizers.
    :param continuous_tensor: Torch tensor corresponding to continuous actions
    :param discrete_list: List of Torch tensors each corresponding to discrete actions
    """

    continuous_tensor: torch.Tensor
    discrete_list: Optional[List[torch.Tensor]]

    @property
    def discrete_tensor(self):
        """
        Returns the discrete action list as a stacked tensor.
        """
        return torch.stack(self.discrete_list, dim=-1)
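
    # Example (illustrative, not from the original source): with two discrete
    # branches over a batch of 4, discrete_list holds two (4,) tensors and
    # discrete_tensor stacks them into shape (4, 2):
    #   action = AgentAction(None, [torch.zeros(4), torch.ones(4)])
    #   action.discrete_tensor.shape  # torch.Size([4, 2])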

    def to_action_tuple(self, clip: bool = False) -> ActionTuple:
        """
        Returns an ActionTuple. If clip is True, continuous actions are clamped
        to [-3, 3] and rescaled to [-1, 1].
        """
        action_tuple = ActionTuple()
        if self.continuous_tensor is not None:
            _continuous_tensor = self.continuous_tensor
            if clip:
                _continuous_tensor = torch.clamp(_continuous_tensor, -3, 3) / 3
            continuous = ModelUtils.to_numpy(_continuous_tensor)
            action_tuple.add_continuous(continuous)
        if self.discrete_list is not None:
            # discrete_tensor is (batch, 1, num_branches) here; drop the
            # singleton step dimension before sending to the environment.
            discrete = ModelUtils.to_numpy(self.discrete_tensor[:, 0, :])
            action_tuple.add_discrete(discrete)
        return action_tuple
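
    # Example (hedged sketch): converting a policy output for the environment;
    # assumes a batch of 2 agents, 3 continuous dims, and two discrete tensors
    # carrying the singleton step dimension that to_action_tuple expects:
    #   action = AgentAction(
    #       torch.randn(2, 3), [torch.zeros(2, 1, dtype=torch.long)] * 2
    #   )
    #   tup = action.to_action_tuple(clip=True)  # tup.discrete has shape (2, 2)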

    @staticmethod
    def _padded_time_to_batch(
        agent_buffer_field: AgentBuffer.AgentBufferField,
        dtype: torch.dtype = torch.float32,
    ) -> List[torch.Tensor]:
        """
        Pad actions and convert to tensor. Note that data is padded by 0's, not NaNs
        as the observations are.
        """
        action_shape = None
        for _action in agent_buffer_field:
            # An entry can be empty if no group agents were present at that
            # step; use the first non-empty entry to determine the action shape.
            if _action:
                action_shape = _action[0].shape
                break
        # If there were no actions at all in the batch (no group agents),
        # return an empty list.
        if action_shape is None:
            return []
        new_list = list(
            map(
                lambda x: ModelUtils.list_to_tensor(x, dtype=dtype),
                itertools.zip_longest(
                    *agent_buffer_field, fillvalue=np.full(action_shape, 0)
                ),
            )
        )
        return new_list
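
    # Example (illustrative): given a field of 3 steps in which step 0 saw two
    # group agents and steps 1-2 saw only one, zip_longest transposes the
    # step-major lists into agent-major ones, zero-filling the absent agent,
    # so this returns two tensors of shape (3, action_dim).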

    @staticmethod
    def from_dict(buff: Dict[str, np.ndarray]) -> "AgentAction":
        """
        A static method that accesses continuous and discrete action fields in an AgentBuffer
        and constructs the corresponding AgentAction from the retrieved np arrays.
        """
        continuous: torch.Tensor = None
        discrete: List[torch.Tensor] = None  # type: ignore
        if "continuous_action" in buff:
            continuous = ModelUtils.list_to_tensor(buff["continuous_action"])
        if "discrete_action" in buff:
            discrete_tensor = ModelUtils.list_to_tensor(
                buff["discrete_action"], dtype=torch.long
            )
            discrete = [
                discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
            ]
        return AgentAction(continuous, discrete)
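
    # Example (hedged sketch): the keys below are the ones from_dict reads;
    # shapes are illustrative (batch of 5, 3 continuous dims, 2 discrete branches):
    #   buff = {"continuous_action": np.zeros((5, 3)),
    #           "discrete_action": np.zeros((5, 2))}
    #   action = AgentAction.from_dict(buff)
    #   len(action.discrete_list)  # 2, one (5,) tensor per branch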

    @staticmethod
    def from_dict_next(buff: Dict[str, np.ndarray]) -> "AgentAction":
        """
        A static method that accesses *next* continuous and discrete action fields
        in an AgentBuffer and constructs the corresponding AgentAction from the
        retrieved np arrays.
        """
        continuous: torch.Tensor = None
        discrete: List[torch.Tensor] = None  # type: ignore
        if "next_continuous_action" in buff:
            continuous = ModelUtils.list_to_tensor(buff["next_continuous_action"])
        if "next_discrete_action" in buff:
            discrete_tensor = ModelUtils.list_to_tensor(
                buff["next_discrete_action"], dtype=torch.long
            )
            discrete = [
                discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
            ]
        return AgentAction(continuous, discrete)

    @staticmethod
    def _from_team_dict(
        buff: Dict[str, np.ndarray], cont_action_key: str, disc_action_key: str
    ) -> List["AgentAction"]:
        continuous_tensors: List[torch.Tensor] = []
        discrete_tensors: List[torch.Tensor] = []  # type: ignore
        if cont_action_key in buff:
            continuous_tensors = AgentAction._padded_time_to_batch(
                buff[cont_action_key]
            )
        if disc_action_key in buff:
            discrete_tensors = AgentAction._padded_time_to_batch(
                buff[disc_action_key], dtype=torch.long
            )
        actions_list = []
        for _cont, _disc in itertools.zip_longest(
            continuous_tensors, discrete_tensors, fillvalue=None
        ):
            if _disc is not None:
                # Split the stacked discrete tensor into one tensor per branch.
                _disc = [_disc[..., i] for i in range(_disc.shape[-1])]
            actions_list.append(AgentAction(_cont, _disc))
        return actions_list
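
    # Example (illustrative): for a group of 3 teammates recorded over 100
    # steps, _padded_time_to_batch yields one (100, action_dim) tensor per
    # teammate (zero-padded at steps where a teammate was absent), so
    # _from_team_dict returns a list of 3 AgentAction instances.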

    @staticmethod
    def from_team_dict(buff: Dict[str, np.ndarray]) -> List["AgentAction"]:
        """
        A static method that accesses continuous and discrete action fields of
        teammates in an AgentBuffer and constructs a list of AgentAction from
        the retrieved np arrays, one per teammate.
        """
        return AgentAction._from_team_dict(
            buff, "team_continuous_action", "team_discrete_action"
        )

    @staticmethod
    def from_team_dict_next(buff: Dict[str, np.ndarray]) -> List["AgentAction"]:
        """
        A static method that accesses next continuous and discrete action fields
        of teammates in an AgentBuffer and constructs a list of AgentAction from
        the retrieved np arrays, one per teammate.
        """
        return AgentAction._from_team_dict(
            buff, "team_next_continuous_action", "team_next_discrete_action"
        )

    def to_flat(self, discrete_branches: List[int]) -> torch.Tensor:
        """
        Flatten this AgentAction into a single torch Tensor, one-hot encoding
        each discrete branch and concatenating with the continuous actions.
        """
        discrete_oh = ModelUtils.actions_to_onehot(
            self.discrete_tensor, discrete_branches
        )
        discrete_oh = torch.cat(discrete_oh, dim=1)
        return torch.cat([self.continuous_tensor, discrete_oh], dim=-1)
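

# ---------------------------------------------------------------------------
# Hedged usage sketch, not part of the upstream module: exercises the two
# conversion paths above under assumed shapes (batch of 2, 3 continuous dims,
# discrete branches of sizes 2 and 3). Run the file directly to sanity-check.
if __name__ == "__main__":
    _batch = 2
    _cont = torch.randn(_batch, 3)

    # to_action_tuple path: each discrete tensor keeps a singleton step dim,
    # matching the [:, 0, :] indexing in to_action_tuple.
    _step = AgentAction(_cont, [torch.zeros(_batch, 1, dtype=torch.long)] * 2)
    _tup = _step.to_action_tuple(clip=True)
    print(_tup.continuous.shape, _tup.discrete.shape)  # (2, 3) (2, 2)

    # to_flat path: discrete tensors are flat per-branch indices of shape
    # (batch,), so the result is 3 continuous + 2 + 3 one-hot = 8 dims wide.
    _flat = AgentAction(
        _cont,
        [torch.zeros(_batch, dtype=torch.long), torch.ones(_batch, dtype=torch.long)],
    )
    print(_flat.to_flat([2, 3]).shape)  # torch.Size([2, 8])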