
[coma2] Add support for variable length obs in COMA2 (#5038)

* Make group extrinsic part of extrinsic

* Fix test and init

* Fix tests and bug

* Add baseline loss to TensorBoard

* Add support for variable len obs in COMA2

* Remove weird merge artifact

* Make agent action run

* Fix __getitem__ replace with slice

* Revert "Fix __getitem__ replace with slice"

This reverts commit 87a2c9d9a9342a7d2be4e9f620d1294a5c3bf22c.

* Revert "Make agent action run"

This reverts commit 59531f3746c58d62cf52f58a88e27a3e428e8946.
/develop/action-slice
GitHub, 4 years ago
Current commit: 6ae8ea1e
3 files changed, 147 insertions and 83 deletions
  1. ml-agents/mlagents/trainers/coma/optimizer_torch.py (9 changes)
  2. ml-agents/mlagents/trainers/torch/networks.py (124 changes)
  3. ml-agents/mlagents/trainers/torch/utils.py (97 changes)

ml-agents/mlagents/trainers/coma/optimizer_torch.py (9 changes)


all_next_value_mem: Optional[AgentBufferField] = None
all_next_baseline_mem: Optional[AgentBufferField] = None
if self.policy.use_recurrent:
value_estimates, baseline_estimates, all_next_value_mem, all_next_baseline_mem, next_value_mem, next_baseline_mem = self._evaluate_by_sequence_team(
(
value_estimates,
baseline_estimates,
all_next_value_mem,
all_next_baseline_mem,
next_value_mem,
next_baseline_mem,
) = self._evaluate_by_sequence_team(
current_obs, team_obs, team_actions, _init_value_mem, _init_baseline_mem
)
else:
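
A minimal sketch (not from this commit; names are placeholders) of the assignment style used above: wrapping the left-hand side of a long multi-value assignment in parentheses lets it span several lines without changing behavior, which keeps the six return values of _evaluate_by_sequence_team readable.

def stub_evaluate():
    # Stand-in for _evaluate_by_sequence_team: six return values.
    return 1, 2, None, None, None, None

(
    value_estimates,
    baseline_estimates,
    all_next_value_mem,
    all_next_baseline_mem,
    next_value_mem,
    next_baseline_mem,
) = stub_evaluate()
print(value_estimates, baseline_estimates)  # 1 2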

ml-agents/mlagents/trainers/torch/networks.py (124 changes)


from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import LSTM, LinearEncoder, Initialization
from mlagents.trainers.torch.layers import LSTM, LinearEncoder
from mlagents.trainers.torch.attention import (
EntityEmbedding,
ResidualSelfAttention,
get_zero_entities_mask,
)
from mlagents.trainers.torch.attention import EntityEmbedding, ResidualSelfAttention
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]

normalize=self.normalize,
)
entity_num_max: int = 0
var_processors = [p for p in self.processors if isinstance(p, EntityEmbedding)]
for processor in var_processors:
entity_max: int = processor.entity_num_max_elements
# Only adds entity max if it was known at construction
if entity_max > 0:
entity_num_max += entity_max
if len(var_processors) > 0:
if sum(self.embedding_sizes):
self.x_self_encoder = LinearEncoder(
sum(self.embedding_sizes),
1,
self.h_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / self.h_size) ** 0.5,
)
self.rsa = ResidualSelfAttention(self.h_size, entity_num_max)
self.rsa, self.x_self_encoder = ModelUtils.create_residual_self_attention(
self.processors, self.embedding_sizes, self.h_size
)
if self.rsa is not None:
total_enc_size = sum(self.embedding_sizes) + self.h_size
else:
total_enc_size = sum(self.embedding_sizes)
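
A sketch of the size bookkeeping above (numbers are hypothetical): the ResidualSelfAttention always emits h_size features, so the total encoding size is the sum of the fixed-size embeddings plus h_size whenever a variable-length observation (and hence an RSA) is present.

embedding_sizes = [8, 4]  # embeddings of the fixed-size observations
h_size = 64               # width of the ResidualSelfAttention output
rsa_present = True        # i.e. self.rsa is not None
total_enc_size = sum(embedding_sizes) + (h_size if rsa_present else 0)
print(total_enc_size)  # 76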

memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
encodes = []
var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []
for idx, processor in enumerate(self.processors):
if not isinstance(processor, EntityEmbedding):
# The input can be encoded without having to process other inputs
obs_input = inputs[idx]
processed_obs = processor(obs_input)
encodes.append(processed_obs)
else:
var_len_processor_inputs.append((processor, inputs[idx]))
if len(encodes) != 0:
encoded_self = torch.cat(encodes, dim=1)
input_exist = True
else:
input_exist = False
if len(var_len_processor_inputs) > 0:
# Some inputs need to be processed with a variable length encoder
masks = get_zero_entities_mask([p_i[1] for p_i in var_len_processor_inputs])
embeddings: List[torch.Tensor] = []
processed_self = self.x_self_encoder(encoded_self) if input_exist else None
for processor, var_len_input in var_len_processor_inputs:
embeddings.append(processor(processed_self, var_len_input))
qkv = torch.cat(embeddings, dim=1)
attention_embedding = self.rsa(qkv, masks)
if not input_exist:
encoded_self = torch.cat([attention_embedding], dim=1)
input_exist = True
else:
encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)
if not input_exist:
raise Exception(
"The trainer was unable to process any of the provided inputs. "
"Make sure the trained agents has at least one sensor attached to them."
)
encoded_self = ModelUtils.encode_observations(
inputs, self.processors, self.rsa, self.x_self_encoder
)
if actions is not None:
encoded_self = torch.cat([encoded_self, actions], dim=1)
encoding = self.linear_encoder(encoded_self)
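
A sketch (hypothetical sizes, not from this commit) of the tail of forward() above: the observation encoding returned by ModelUtils.encode_observations is concatenated with the flattened actions on the feature dimension before being passed to the body's linear encoder.

import torch

encoded_self = torch.rand(4, 76)  # stand-in for ModelUtils.encode_observations output
actions = torch.rand(4, 2)        # stand-in for already-flattened actions
encoding_input = torch.cat([encoded_self, actions], dim=1)
print(encoding_input.shape)  # torch.Size([4, 78])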

normalize=self.normalize,
)
self.action_spec = action_spec
# This RSA and input are for variable length obs, not for multi-agent.
(
self.input_rsa,
self.input_x_self_encoder,
) = ModelUtils.create_residual_self_attention(
self.processors, _input_size, self.h_size
)
if self.input_rsa is not None:
_input_size.append(self.h_size)
# Modules for self-attention
# Modules for multi-agent self-attention
obs_only_ent_size = sum(_input_size)
q_ent_size = (
sum(_input_size)

attn_mask = only_first_obs_flat.isnan().type(torch.FloatTensor)
return attn_mask
def _remove_nans_from_obs(
self, all_obs: List[List[torch.Tensor]], attention_mask: torch.Tensor
) -> None:
"""
Helper function to remove NaNs from observations using an attention mask.
"""
for i_agent, single_agent_obs in enumerate(all_obs):
for obs in single_agent_obs:
obs[
attention_mask.type(torch.BoolTensor)[:, i_agent], ::
] = 0.0  # Remove NaNs fast
def forward(
self,
obs_only: List[List[torch.Tensor]],

concat_f_inp = []
if obs:
obs_attn_mask = self._get_masks_from_nans(obs)
for i_agent, (inputs, action) in enumerate(zip(obs, actions)):
encodes = []
for idx, processor in enumerate(self.processors):
obs_input = inputs[idx]
obs_input[
obs_attn_mask.type(torch.BoolTensor)[:, i_agent], ::
] = 0.0  # Remove NaNs fast
processed_obs = processor(obs_input)
encodes.append(processed_obs)
self._remove_nans_from_obs(obs, obs_attn_mask)
for inputs, action in zip(obs, actions):
encoded = ModelUtils.encode_observations(
inputs, self.processors, self.input_rsa, self.input_x_self_encoder
)
torch.cat(encodes, dim=-1),
encoded,
action.to_flat(self.action_spec.discrete_branches),
]
concat_f_inp.append(torch.cat(cat_encodes, dim=1))
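
A self-contained sketch (plain PyTorch, hypothetical shapes, not from this commit) of the masking pattern above: observations of absent agents arrive as NaN, _get_masks_from_nans turns them into a (batch, n_agents) attention mask, and _remove_nans_from_obs zeroes those rows in place so NaNs never reach the encoders.

import torch

batch, n_agents, obs_size = 4, 2, 3
all_obs = [[torch.rand(batch, obs_size)] for _ in range(n_agents)]
all_obs[1][0][0] = float("nan")  # agent 1 has no data for the first batch entry

# 1.0 where the agent's observation is NaN for that batch entry, 0.0 otherwise.
attn_mask = torch.stack(
    [agent_obs[0][:, 0].isnan() for agent_obs in all_obs], dim=1
).type(torch.FloatTensor)

# Zero out the masked rows in place, mirroring _remove_nans_from_obs.
for i_agent, agent_obs in enumerate(all_obs):
    for obs in agent_obs:
        obs[attn_mask.type(torch.BoolTensor)[:, i_agent], :] = 0.0

assert not any(torch.isnan(o).any() for agent_obs in all_obs for o in agent_obs)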

concat_encoded_obs = []
if obs_only:
obs_only_attn_mask = self._get_masks_from_nans(obs_only)
for i_agent, inputs in enumerate(obs_only):
encodes = []
for idx, processor in enumerate(self.processors):
obs_input = inputs[idx]
obs_input[
obs_only_attn_mask.type(torch.BoolTensor)[:, i_agent], ::
] = 0.0  # Remove NaNs fast
processed_obs = processor(obs_input)
encodes.append(processed_obs)
concat_encoded_obs.append(torch.cat(encodes, dim=-1))
self._remove_nans_from_obs(obs_only, obs_only_attn_mask)
for inputs in obs_only:
encoded = ModelUtils.encode_observations(
inputs, self.processors, self.input_rsa, self.input_x_self_encoder
)
concat_encoded_obs.append(encoded)
g_inp = torch.stack(concat_encoded_obs, dim=1)
self_attn_masks.append(obs_only_attn_mask)
self_attn_inputs.append(self.obs_encoder(None, g_inp))
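
A sketch (hypothetical sizes, not from this commit) of the stacking step above: every agent's observation list collapses to a single vector via ModelUtils.encode_observations, and stacking those vectors on dim=1 yields the (batch, n_agents, enc_size) entity tensor consumed by the multi-agent obs_encoder and RSA.

import torch

batch, n_agents, enc_size = 4, 3, 16
concat_encoded_obs = [torch.rand(batch, enc_size) for _ in range(n_agents)]
g_inp = torch.stack(concat_encoded_obs, dim=1)
print(g_inp.shape)  # torch.Size([4, 3, 16])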

ml-agents/mlagents/trainers/torch/utils.py (97 changes)


from typing import List, Optional, Tuple
from mlagents.torch_utils import torch, nn
from mlagents.trainers.torch.layers import LinearEncoder, Initialization
import numpy as np
from mlagents.trainers.torch.encoders import (

VectorInput,
)
from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.torch.attention import EntityEmbedding
from mlagents.trainers.torch.attention import (
EntityEmbedding,
ResidualSelfAttention,
get_zero_entities_mask,
)
from mlagents.trainers.exception import UnityTrainerException
from mlagents_envs.base_env import ObservationSpec, DimensionProperty

alpha=tau,
out=target_param.data,
)
@staticmethod
def create_residual_self_attention(
input_processors: nn.ModuleList, embedding_sizes: List[int], hidden_size: int
) -> Tuple[Optional[ResidualSelfAttention], Optional[LinearEncoder]]:
"""
Creates an RSA if there are variable length observations found in the input processors.
:param input_processors: A ModuleList of input processors as returned by the function
create_input_processors().
:param embedding_sizes: A List of embedding sizes as returned by create_input_processors().
:param hidden_size: The hidden size to use for the RSA.
:returns: A Tuple of the RSA itself and a self encoder.
Returns None for the RSA and encoder if no var len inputs are detected.
"""
rsa, x_self_encoder = None, None
entity_num_max: int = 0
var_processors = [p for p in input_processors if isinstance(p, EntityEmbedding)]
for processor in var_processors:
entity_max: int = processor.entity_num_max_elements
# Only adds entity max if it was known at construction
if entity_max > 0:
entity_num_max += entity_max
if len(var_processors) > 0:
if sum(embedding_sizes):
x_self_encoder = LinearEncoder(
sum(embedding_sizes),
1,
hidden_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / hidden_size) ** 0.5,
)
rsa = ResidualSelfAttention(hidden_size, entity_num_max)
return rsa, x_self_encoder
@staticmethod
def encode_observations(
inputs: List[torch.Tensor],
processors: nn.ModuleList,
rsa: Optional[ResidualSelfAttention],
x_self_encoder: Optional[LinearEncoder],
) -> torch.Tensor:
"""
Helper method to encode observations using a list of processors and an RSA.
:param inputs: List of Tensors corresponding to a set of obs.
:param processors: a ModuleList of the input processors to be applied to these obs.
:param rsa: Optionally, an RSA to use for variable length obs.
:param x_self_encoder: Optionally, an encoder to use for x_self (in this case, the non-variable inputs).
"""
encodes = []
var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []
for idx, processor in enumerate(processors):
if not isinstance(processor, EntityEmbedding):
# The input can be encoded without having to process other inputs
obs_input = inputs[idx]
processed_obs = processor(obs_input)
encodes.append(processed_obs)
else:
var_len_processor_inputs.append((processor, inputs[idx]))
if len(encodes) != 0:
encoded_self = torch.cat(encodes, dim=1)
input_exist = True
else:
input_exist = False
if len(var_len_processor_inputs) > 0 and rsa is not None:
# Some inputs need to be processed with a variable length encoder
masks = get_zero_entities_mask([p_i[1] for p_i in var_len_processor_inputs])
embeddings: List[torch.Tensor] = []
processed_self = (
x_self_encoder(encoded_self)
if input_exist and x_self_encoder is not None
else None
)
for processor, var_len_input in var_len_processor_inputs:
embeddings.append(processor(processed_self, var_len_input))
qkv = torch.cat(embeddings, dim=1)
attention_embedding = rsa(qkv, masks)
if not input_exist:
encoded_self = torch.cat([attention_embedding], dim=1)
input_exist = True
else:
encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)
if not input_exist:
raise UnityTrainerException(
"The trainer was unable to process any of the provided inputs. "
"Make sure the trained agents has at least one sensor attached to them."
)
return encoded_self
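
A minimal usage sketch of the two new helpers (hedged: VectorInput(input_size) and the sizes below are assumptions, not taken from this diff). With only fixed-size vector observations, create_residual_self_attention returns (None, None) and encode_observations degenerates to a plain concatenation of the processed inputs.

from mlagents.torch_utils import torch, nn
from mlagents.trainers.torch.encoders import VectorInput
from mlagents.trainers.torch.utils import ModelUtils

obs_sizes = [8, 4]
processors = nn.ModuleList([VectorInput(size) for size in obs_sizes])
embedding_sizes = obs_sizes  # VectorInput keeps the input width unchanged

rsa, x_self_encoder = ModelUtils.create_residual_self_attention(
    processors, embedding_sizes, 64
)
assert rsa is None and x_self_encoder is None  # no EntityEmbedding in the processors

obs = [torch.rand(5, size) for size in obs_sizes]
encoded = ModelUtils.encode_observations(obs, processors, rsa, x_self_encoder)
print(encoded.shape)  # expected: torch.Size([5, 12])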