|
|
|
|
|
|
from typing import List, Optional, NamedTuple |
|
|
|
import itertools |
|
|
|
import numpy as np |
|
|
|
from mlagents.trainers.buffer import AgentBuffer, BufferKey |
|
|
|
from mlagents.trainers.buffer import AgentBuffer, BufferKey, AgentBufferField |
|
|
|
from mlagents.trainers.torch.utils import ModelUtils |
|
|
|
from mlagents_envs.base_env import ActionTuple |
|
|
|
|
|
|
|
|
|
|
discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1]) |
|
|
|
] |
|
|
|
return AgentAction(continuous, discrete) |
|
|
|
|
|
|
|
@staticmethod
def _padded_time_to_batch(
    agent_buffer_field: AgentBufferField, dtype: torch.dtype = torch.float32
) -> List[torch.Tensor]:
    """
    Regroup a field of per-entry action lists into one tensor per position.

    The entries of the field are zipped together position by position, so the
    i-th returned tensor is built from the i-th action of every entry. Entries
    shorter than the longest one are filled with arrays of zeros (unlike
    observations, the padding value here is 0 rather than NaN).

    NOTE(review): presumably each entry is one time step and each position is
    one group member - confirm against the code that fills the buffer.

    :param agent_buffer_field: Iterable whose entries are sequences of action
        arrays; entries may be empty.
    :param dtype: dtype forwarded to ModelUtils.list_to_tensor.
    :return: One tensor per position, or an empty list when no entry holds
        any action.
    """
    # The first non-empty entry tells us the shape of a single action, which
    # is what the zero padding has to look like.
    first_filled = next((entry for entry in agent_buffer_field if entry), None)
    if first_filled is None:
        # Every entry was empty: there are no actions to convert.
        return []
    zero_action = np.full(first_filled[0].shape, 0)

    return [
        ModelUtils.list_to_tensor(actions_at_position, dtype=dtype)
        for actions_at_position in itertools.zip_longest(
            *agent_buffer_field, fillvalue=zero_action
        )
    ]
|
|
|
|
|
|
|
@staticmethod
def _group_from_buffer(
    buff: AgentBuffer, cont_action_key: BufferKey, disc_action_key: BufferKey
) -> List["AgentAction"]:
    """
    Build a list of AgentAction from two action fields of a buffer.

    Each field that is present in the buffer is converted with
    _padded_time_to_batch (discrete actions as torch.long); a missing field
    contributes no tensors. The two tensor lists are then paired up position
    by position, with None standing in wherever one list is shorter.

    :param buff: Buffer to read the action fields from.
    :param cont_action_key: Key of the continuous action field.
    :param disc_action_key: Key of the discrete action field.
    :return: One AgentAction per paired position.
    """
    continuous_tensors: List[torch.Tensor] = (
        AgentAction._padded_time_to_batch(buff[cont_action_key])
        if cont_action_key in buff
        else []
    )
    discrete_tensors: List[torch.Tensor] = (  # type: ignore
        AgentAction._padded_time_to_batch(buff[disc_action_key], dtype=torch.long)
        if disc_action_key in buff
        else []
    )

    def _split_last_dim(stacked):
        # Slice the tensor along its last dimension into a list of tensors;
        # None (no discrete data for this position) is passed through.
        if stacked is None:
            return None
        return [stacked[..., idx] for idx in range(stacked.shape[-1])]

    return [
        AgentAction(cont_part, _split_last_dim(disc_part))
        for cont_part, disc_part in itertools.zip_longest(
            continuous_tensors, discrete_tensors, fillvalue=None
        )
    ]
|
|
|
|
|
|
|
@staticmethod
def group_from_buffer(buff: AgentBuffer) -> List["AgentAction"]:
    """
    Build the list of group AgentActions stored in an AgentBuffer.

    Reads the GROUP_CONTINUOUS_ACTION and GROUP_DISCRETE_ACTION fields and
    delegates the conversion to _group_from_buffer.

    :param buff: Buffer holding the group action fields.
    :return: The AgentActions produced by _group_from_buffer for those fields.
    """
    return AgentAction._group_from_buffer(
        buff,
        cont_action_key=BufferKey.GROUP_CONTINUOUS_ACTION,
        disc_action_key=BufferKey.GROUP_DISCRETE_ACTION,
    )
|
|
|
|
|
|
|
@staticmethod
def group_from_buffer_next(buff: AgentBuffer) -> List["AgentAction"]:
    """
    Build the list of next-step group AgentActions stored in an AgentBuffer.

    Reads the GROUP_NEXT_CONT_ACTION and GROUP_NEXT_DISC_ACTION fields and
    delegates the conversion to _group_from_buffer.

    :param buff: Buffer holding the next group action fields.
    :return: The AgentActions produced by _group_from_buffer for those fields.
    """
    return AgentAction._group_from_buffer(
        buff,
        cont_action_key=BufferKey.GROUP_NEXT_CONT_ACTION,
        disc_action_key=BufferKey.GROUP_NEXT_DISC_ACTION,
    )
|
|
|
|
|
|
|
def to_flat(self, discrete_branches: List[int]) -> torch.Tensor:
    """
    Flatten this action into a single tensor.

    The discrete actions are converted to one-hot tensors by
    ModelUtils.actions_to_onehot, concatenated along dim 1, and appended to
    the continuous actions along the last dimension.

    :param discrete_branches: Branch description forwarded to
        ModelUtils.actions_to_onehot (presumably the size of each discrete
        branch - confirm against that helper).
    :return: The continuous actions followed by the one-hot discrete actions.
        When the helper yields no one-hot branches, the continuous tensor
        itself is returned (not a copy).
    """
    one_hot_branches = ModelUtils.actions_to_onehot(
        self.discrete_tensor, discrete_branches
    )
    # torch.cat raises on an empty sequence, so an action without any
    # discrete branch used to crash here; there is nothing to append in that
    # case and the continuous part already is the flat representation.
    if len(one_hot_branches) == 0:
        return self.continuous_tensor
    discrete_oh = torch.cat(one_hot_branches, dim=1)
    return torch.cat([self.continuous_tensor, discrete_oh], dim=-1)