浏览代码

Add more comments

/develop/coma2/samenet
Ervin Teng 4 年前
当前提交
13fca55f
共有 1 个文件被更改,包括 24 次插入4 次删除
  1. 28
      ml-agents/mlagents/trainers/torch/agent_action.py

28
ml-agents/mlagents/trainers/torch/agent_action.py


def _group_from_buffer(
buff: AgentBuffer, cont_action_key: BufferKey, disc_action_key: BufferKey
) -> List["AgentAction"]:
"""
Extracts continuous and discrete groupmate actions, as specified by BufferKey, and
returns a List of AgentActions that correspond to the groupmate's actions. List will
be of length equal to the maximum number of groupmates in the buffer. Any spots where
there are less agents than maximum, the actions will be padded with 0's.
"""
continuous_tensors: List[torch.Tensor] = []
discrete_tensors: List[torch.Tensor] = []
if cont_action_key in buff:

@staticmethod
def group_from_buffer(buff: AgentBuffer) -> List["AgentAction"]:
"""
A static method that accesses continuous and discrete action fields in an AgentBuffer
and constructs the corresponding AgentAction from the retrieved np arrays.
A static method that accesses next group continuous and discrete action fields in an AgentBuffer
and constructs a padded List of AgentActions that represent the group agent actions.
The List is of length equal to max number of groupmate agents in the buffer, and the AgentBuffer iss
of the same length as the buffer. Empty spots (e.g. when agents die) are padded with 0.
:param buff: AgentBuffer of a batch or trajectory
:return: List of groupmate's AgentActions
"""
return AgentAction._group_from_buffer(
buff, BufferKey.GROUP_CONTINUOUS_ACTION, BufferKey.GROUP_DISCRETE_ACTION

def group_from_buffer_next(buff: AgentBuffer) -> List["AgentAction"]:
"""
A static method that accesses next continuous and discrete action fields in an AgentBuffer
and constructs the corresponding AgentAction from the retrieved np arrays.
A static method that accesses next group continuous and discrete action fields in an AgentBuffer
and constructs a padded List of AgentActions that represent the next group agent actions.
The List is of length equal to max number of groupmate agents in the buffer, and the AgentBuffer iss
of the same length as the buffer. Empty spots (e.g. when agents die) are padded with 0.
:param buff: AgentBuffer of a batch or trajectory
:return: List of groupmate's AgentActions
"""
return AgentAction._group_from_buffer(
buff, BufferKey.GROUP_NEXT_CONT_ACTION, BufferKey.GROUP_NEXT_DISC_ACTION

"""
Flatten this AgentAction into a single torch Tensor of dimension (batch, num_continuous + num_one_hot_discrete).
Discrete actions are converted into one-hot and concatenated with continuous actions.
:param discrete_branches: List of sizes for discrete actions.
:return: Tensor of flattened actions.
"""
discrete_oh = ModelUtils.actions_to_onehot(
self.discrete_tensor, discrete_branches
)
正在加载...
取消
保存