from typing import List
from mlagents.torch_utils import torch

from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.torch.agent_action import AgentAction
from mlagents.trainers.torch.utils import ModelUtils


class ActionFlattener:
    def __init__(self, action_spec: ActionSpec):
        """
        A helper that creates the flattened form of an AgentAction object.
        The flattened form is the continuous action concatenated with the
        concatenated one-hot encodings of the discrete actions.
        :param action_spec: An ActionSpec that describes the action space dimensions
        """
        self._specs = action_spec

    @property
    def flattened_size(self) -> int:
        """
        The flattened size is the continuous size plus the sum of the branch sizes,
        since each discrete action is encoded as a one-hot vector.
        """
        return self._specs.continuous_size + sum(self._specs.discrete_branches)

    def forward(self, action: AgentAction) -> torch.Tensor:
        """
        Returns a tensor corresponding to the flattened action.
        :param action: An AgentAction object
        """
        action_list: List[torch.Tensor] = []
        # Continuous actions are used as-is.
        if self._specs.continuous_size > 0:
            action_list.append(action.continuous_tensor)
        # Each discrete branch becomes a one-hot vector; the per-branch
        # vectors are joined along the feature dimension.
        if self._specs.discrete_size > 0:
            flat_discrete = torch.cat(
                ModelUtils.actions_to_onehot(
                    torch.as_tensor(action.discrete_tensor, dtype=torch.long),
                    self._specs.discrete_branches,
                ),
                dim=1,
            )
            action_list.append(flat_discrete)
        return torch.cat(action_list, dim=1)
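

# Illustrative sketch only, not part of ML-Agents: a pure-Python picture of the
# layout that ActionFlattener.forward builds, shown for a single agent rather
# than a batch. The helper name `flatten_example` and the sample values below
# are assumptions made up for this example; the real class works on batched
# torch tensors and relies on ModelUtils.actions_to_onehot.
def flatten_example(continuous, discrete, branches):
    # Start with the continuous values, unchanged.
    flat = [float(value) for value in continuous]
    # Append one one-hot vector per discrete branch, in branch order.
    for index, size in zip(discrete, branches):
        flat.extend(1.0 if i == index else 0.0 for i in range(size))
    return flat


if __name__ == "__main__":
    # Two continuous values plus branches of sizes 3 and 2 give
    # 2 + 3 + 2 = 7 entries, matching the flattened_size formula.
    print(flatten_example([0.5, -1.0], [2, 0], [3, 2]))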