
addressing comments

/develop/singular-embeddings
vincentpierre, 4 years ago
Current commit: 2bf6737f
4 files changed, 61 insertions and 51 deletions
  1. ml-agents/mlagents/trainers/torch/encoders.py (8 changes)
  2. ml-agents/mlagents/trainers/torch/model_serialization.py (37 changes)
  3. ml-agents/mlagents/trainers/torch/networks.py (25 changes)
  4. ml-agents/mlagents/trainers/torch/utils.py (42 changes)

ml-agents/mlagents/trainers/torch/encoders.py (8 changes)

 from mlagents.trainers.torch.model_serialization import exporting_to_onnx


+class Identity(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        return inputs
+
+
 class Normalizer(nn.Module):
     def __init__(self, vec_obs_size: int):
         super().__init__()
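
Note: the commit swaps the old None placeholders for this pass-through Identity module, so every observation slot holds a real nn.Module and downstream code can detect the variable-length path by type rather than by None checks. A minimal sketch of that pattern (toy encoder and shapes, not the repository's processor setup):

    from typing import List

    import torch
    from torch import nn


    class Identity(nn.Module):
        """A no-op encoder that forwards its input unchanged."""

        def forward(self, inputs: torch.Tensor) -> torch.Tensor:
            return inputs


    # Hypothetical processor list: a real encoder for a rank-1 slot and an
    # Identity placeholder for a variable-length slot.
    processors = nn.ModuleList([nn.Linear(4, 8), Identity()])
    observations: List[torch.Tensor] = [torch.ones(2, 4), torch.ones(2, 3, 5)]

    encoded: List[torch.Tensor] = []
    for processor, obs in zip(processors, observations):
        # Mirrors the networks.py check below: only non-Identity processors
        # encode their input directly; Identity slots take the attention path.
        if not isinstance(processor, Identity):
            encoded.append(processor(obs))
    print([tuple(e.shape) for e in encoded])  # [(2, 8)]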

ml-agents/mlagents/trainers/torch/model_serialization.py (37 changes)


         # Any multi-dimensional input should follow that, otherwise it will
         # cause problems for the Barracuda import.
         self.policy = policy
+        observation_specs = self.policy.behavior_spec.observation_specs
-        for sens_spec in self.policy.behavior_spec.observation_specs:
-            if len(sens_spec.shape) == 1:
-                vec_obs_size += sens_spec.shape[0]
+        for obs_spec in observation_specs:
+            if len(obs_spec.shape) == 1:
+                vec_obs_size += obs_spec.shape[0]

-            1
-            for sens_spec in self.policy.behavior_spec.observation_specs
-            if len(sens_spec.shape) == 3
+            1 for obs_spec in observation_specs if len(obs_spec.shape) == 3

-        # (It's NHWC in self.policy.behavior_spec.observation_specs.shape)
+        # (It's NHWC in observation_specs.shape)

-            for obs_spec in self.policy.behavior_spec.observation_specs
+            for obs_spec in observation_specs

-            for obs_spec in self.policy.behavior_spec.observation_specs
+            for obs_spec in observation_specs
             if len(obs_spec.shape) == 2
         ]

             dummy_memories,
         )
-        self.input_names = (
-            ["vector_observation"]
-            + [f"visual_observation_{i}" for i in range(num_vis_obs)]
-            + [
-                f"obs_{i}"
-                for i, sens_spec in enumerate(
-                    self.policy.behavior_spec.observation_specs
-                )
-                if len(sens_spec.shape) == 2
-            ]
-            + ["action_masks", "memories"]
-        )
+        self.input_names = ["vector_observation"]
+        for i in range(num_vis_obs):
+            self.input_names.append(f"visual_observation_{i}")
+        for i, obs_spec in enumerate(observation_specs):
+            if len(obs_spec.shape) == 2:
+                self.input_names.append(f"obs_{i}")
+        self.input_names += ["action_masks", "memories"]
         self.dynamic_axes = {name: {0: "batch"} for name in self.input_names}
         self.output_names = ["version_number", "memory_size"]
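
Note: the rewritten block builds the ONNX input names with plain loops over the aliased observation_specs. A toy illustration of the naming scheme it produces (hypothetical shapes, not real sensor specs): rank-1 shapes are pooled into the single vector_observation input, rank-3 shapes become visual_observation_<i>, and rank-2 shapes become obs_<i>:

    from typing import List, Tuple

    observation_shapes: List[Tuple[int, ...]] = [(8,), (84, 84, 3), (20, 6), (4,)]

    # Rank-3 observations are the visual ones.
    num_vis_obs = sum(1 for shape in observation_shapes if len(shape) == 3)

    input_names = ["vector_observation"]
    for i in range(num_vis_obs):
        input_names.append(f"visual_observation_{i}")
    for i, shape in enumerate(observation_shapes):
        if len(shape) == 2:
            input_names.append(f"obs_{i}")
    input_names += ["action_masks", "memories"]

    print(input_names)
    # ['vector_observation', 'visual_observation_0', 'obs_2', 'action_masks', 'memories']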

ml-agents/mlagents/trainers/torch/networks.py (25 changes)


 from mlagents.trainers.torch.utils import ModelUtils
 from mlagents.trainers.torch.decoders import ValueHeads
 from mlagents.trainers.torch.layers import LSTM, LinearEncoder, Initialization
-from mlagents.trainers.torch.encoders import VectorInput
+from mlagents.trainers.torch.encoders import VectorInput, Identity
 from mlagents.trainers.buffer import AgentBuffer
 from mlagents.trainers.trajectory import ObsUtil
 from mlagents.trainers.torch.attention import (

         var_len_inputs = []  # The list of variable length inputs
         for idx, processor in enumerate(self.processors):
-            if processor is not None:
+            if not isinstance(processor, Identity):
                 # The input can be encoded without having to process other inputs
                 obs_input = inputs[idx]
                 processed_obs = processor(obs_input)

         # Some inputs need to be processed with a variable length encoder
         masks = get_zero_entities_mask(var_len_inputs)
         embeddings: List[torch.Tensor] = []
-        if input_exist:
-            processed_self = self.x_self_encoder(encoded_self)
-            for var_len_input, var_len_processor in zip(
-                var_len_inputs, self.var_processors
-            ):
-                embeddings.append(var_len_processor(processed_self, var_len_input))
-        else:
-            for var_len_input, var_len_processor in zip(
-                var_len_inputs, self.var_processors
-            ):
-                embeddings.append(var_len_processor(None, var_len_input))
+        processed_self = self.x_self_encoder(encoded_self) if input_exist else None
+        for var_len_input, var_len_processor in zip(
+            var_len_inputs, self.var_processors
+        ):
+            embeddings.append(var_len_processor(processed_self, var_len_input))
         qkv = torch.cat(embeddings, dim=1)
         attention_embedding = self.rsa(qkv, masks)
         if not input_exist:

             encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)
         if not input_exist:
-            raise Exception("No valid inputs to network.")
+            raise Exception(
+                "The trainer was unable to process any of the provided inputs. "
+                "Make sure the trained agents have at least one sensor attached to them."
+            )

         # Constants don't work in Barracuda
         if actions is not None:
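
Note: the refactor computes the optional "self" encoding once with a conditional expression, so a single loop replaces the two duplicated branches. A condensed, self-contained sketch of the same shape (the toy processor below is a hypothetical stand-in for EntityEmbedding):

    from typing import List, Optional

    import torch
    from torch import nn


    class ToyVarLenProcessor(nn.Module):
        """Hypothetical stand-in for EntityEmbedding: takes an optional self embedding."""

        def forward(
            self, self_emb: Optional[torch.Tensor], entities: torch.Tensor
        ) -> torch.Tensor:
            if self_emb is None:
                return entities
            # Broadcast the self embedding over the entity axis and concatenate.
            expanded = self_emb.unsqueeze(1).expand(-1, entities.shape[1], -1)
            return torch.cat([expanded, entities], dim=2)


    x_self_encoder = nn.Linear(4, 4)
    var_processors = nn.ModuleList([ToyVarLenProcessor(), ToyVarLenProcessor()])
    var_len_inputs = [torch.ones(2, 3, 4), torch.ones(2, 5, 4)]
    encoded_self = torch.ones(2, 4)
    input_exist = True

    # The single path that replaces the duplicated branches: compute the
    # optional self encoding once, then run one loop for both cases.
    processed_self = x_self_encoder(encoded_self) if input_exist else None
    embeddings: List[torch.Tensor] = []
    for var_len_input, var_len_processor in zip(var_len_inputs, var_processors):
        embeddings.append(var_len_processor(processed_self, var_len_input))
    print([tuple(e.shape) for e in embeddings])  # [(2, 3, 8), (2, 5, 8)]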

ml-agents/mlagents/trainers/torch/utils.py (42 changes)


     NatureVisualEncoder,
     SmallVisualEncoder,
     VectorInput,
+    Identity,
 )
 from mlagents.trainers.settings import EncoderType, ScheduleType
 from mlagents.trainers.torch.attention import EntityEmbedding

         EncoderType.NATURE_CNN: 36,
         EncoderType.RESNET: 15,
     }
+    VALID_VISUAL_PROP = frozenset(
+        [
+            (
+                DimensionProperty.TRANSLATIONAL_EQUIVARIANCE,
+                DimensionProperty.TRANSLATIONAL_EQUIVARIANCE,
+                DimensionProperty.NONE,
+            ),
+            (DimensionProperty.UNSPECIFIED,) * 3,
+        ]
+    )
+    VALID_VECTOR_PROP = frozenset(
+        [(DimensionProperty.NONE,), (DimensionProperty.UNSPECIFIED,)]
+    )
+    VALID_VAR_LEN_PROP = frozenset(
+        [(DimensionProperty.VARIABLE_SIZE, DimensionProperty.NONE)]
+    )

     @staticmethod
     def update_learning_rate(optim: torch.optim.Optimizer, lr: float) -> None:

         dim_prop = obs_spec.dimension_property
         # VISUAL
-        valid_visual = (
-            DimensionProperty.TRANSLATIONAL_EQUIVARIANCE,
-            DimensionProperty.TRANSLATIONAL_EQUIVARIANCE,
-            DimensionProperty.NONE,
-        )
-        valid_visual_unspecified = (DimensionProperty.UNSPECIFIED,) * 3
-        if dim_prop == valid_visual or dim_prop == valid_visual_unspecified:
+        if dim_prop in ModelUtils.VALID_VISUAL_PROP:

-        valid_vector = (DimensionProperty.NONE,)
-        valid_vector_unspecified = (DimensionProperty.UNSPECIFIED,)
-        if dim_prop == valid_vector or dim_prop == valid_vector_unspecified:
+        if dim_prop in ModelUtils.VALID_VECTOR_PROP:

-        valid_var_len = (DimensionProperty.VARIABLE_SIZE, DimensionProperty.NONE)
-        if dim_prop == valid_var_len:
-            # None means the residual self attention must be used
-            return (None, 0)
+        if dim_prop in ModelUtils.VALID_VAR_LEN_PROP:
+            return (Identity, 0)
         # OTHER
         raise UnityTrainerException(f"Unsupported Sensor with specs {obs_spec}")

             obs.
         :param normalize: Normalize all vector inputs.
         :return: Tuple of :
-        - ModuleList of the encoders (None if the input requires to be processed with a variable length
+        - ModuleList of the encoders (Identity if the input requires to be processed with a variable length
         observation encoder)
         - A list of embedding sizes (0 if the input requires to be processed with a variable length
         observation encoder)

             )
             encoders.append(encoder)
             embedding_sizes.append(embedding_size)
-            if encoder is None:
+            if encoder is Identity:
                 var_len_indices.append(idx)
         x_self_size = sum(embedding_sizes)  # The size of the "self" embedding
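
Note: hoisting the valid dimension-property tuples into class-level frozensets turns each chain of tuple-equality tests into one membership test. A small standalone sketch of that pattern, using a toy enum in place of the real DimensionProperty:

    import enum
    from typing import Tuple


    class DimProp(enum.Enum):
        """Toy stand-in for the real DimensionProperty enum."""

        NONE = 0
        UNSPECIFIED = 1
        TRANSLATIONAL_EQUIVARIANCE = 2
        VARIABLE_SIZE = 3


    # Valid property tuples hoisted into frozensets, as in the hunk above.
    VALID_VISUAL = frozenset(
        [
            (
                DimProp.TRANSLATIONAL_EQUIVARIANCE,
                DimProp.TRANSLATIONAL_EQUIVARIANCE,
                DimProp.NONE,
            ),
            (DimProp.UNSPECIFIED,) * 3,
        ]
    )
    VALID_VAR_LEN = frozenset([(DimProp.VARIABLE_SIZE, DimProp.NONE)])


    def classify(dim_prop: Tuple[DimProp, ...]) -> str:
        # One membership test per category instead of chained == comparisons.
        if dim_prop in VALID_VISUAL:
            return "visual"
        if dim_prop in VALID_VAR_LEN:
            return "variable-length"
        return "unsupported"


    print(classify((DimProp.UNSPECIFIED,) * 3))             # visual
    print(classify((DimProp.VARIABLE_SIZE, DimProp.NONE)))  # variable-length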
