浏览代码

Merge pull request #4844 from Unity-Technologies/develop-att-network-integration

Integrate attention to networkbody
/bullet-hell-barracuda-test-1.3.1
GitHub 4 年前
当前提交
212ebfb9
共有 10 个文件被更改,包括 343 次插入136 次删除
  1. 2
      ml-agents-envs/mlagents_envs/base_env.py
  2. 5
      ml-agents-envs/mlagents_envs/rpc_utils.py
  3. 2
      ml-agents/mlagents/trainers/tests/dummy_config.py
  4. 9
      ml-agents/mlagents/trainers/tests/simple_test_envs.py
  5. 71
      ml-agents/mlagents/trainers/tests/torch/test_attention.py
  6. 38
      ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py
  7. 161
      ml-agents/mlagents/trainers/torch/attention.py
  8. 43
      ml-agents/mlagents/trainers/torch/model_serialization.py
  9. 84
      ml-agents/mlagents/trainers/torch/networks.py
  10. 64
      ml-agents/mlagents/trainers/torch/utils.py

2
ml-agents-envs/mlagents_envs/base_env.py


spaces for a group of Agents under the same behavior.
- observation_specs is a List of ObservationSpec NamedTuple containing
information about the information of the Agent's observations such as their shapes.
The order of the SensorSpec is the same as the order of the observations of an
The order of the ObservationSpec is the same as the order of the observations of an
agent.
- action_spec is an ActionSpec NamedTuple.
"""

5
ml-agents-envs/mlagents_envs/rpc_utils.py


observation_specs.append(
ObservationSpec(
tuple(obs.shape),
tuple(DimensionProperty(dim) for dim in obs.dimension_properties),
tuple(DimensionProperty(dim) for dim in obs.dimension_properties)
if len(obs.dimension_properties) > 0
else (DimensionProperty.UNSPECIFIED,) * len(obs.shape),
# proto from communicator < v1.3 does not set action spec, use deprecated fields instead
if (
brain_param_proto.action_spec.num_continuous_actions == 0

2
ml-agents/mlagents/trainers/tests/dummy_config.py


obs_specs: List[ObservationSpec] = []
for shape in shapes:
dim_prop = (DimensionProperty.UNSPECIFIED,) * len(shape)
if len(shape) == 2:
dim_prop = (DimensionProperty.VARIABLE_SIZE, DimensionProperty.NONE)
spec = ObservationSpec(shape, dim_prop, ObservationType.DEFAULT)
obs_specs.append(spec)
return obs_specs

9
ml-agents/mlagents/trainers/tests/simple_test_envs.py


OBS_SIZE = 1
VIS_OBS_SIZE = (20, 20, 3)
VAR_LEN_SIZE = (10, 5)
STEP_SIZE = 0.2
TIME_PENALTY = 0.01

step_size=STEP_SIZE,
num_visual=0,
num_vector=1,
num_var_len=0,
var_len_obs_size=VAR_LEN_SIZE,
self.num_var_len = num_var_len
self.var_len_obs_size = var_len_obs_size
continuous_action_size, discrete_action_size = action_sizes
discrete_tuple = tuple(2 for _ in range(discrete_action_size))
action_spec = ActionSpec(continuous_action_size, discrete_tuple)

obs_shape.append((self.vec_obs_size,))
for _ in range(self.num_visual):
obs_shape.append(self.vis_obs_size)
for _ in range(self.num_var_len):
obs_shape.append(self.var_len_obs_size)
obs_spec = create_observation_specs_with_shapes(obs_shape)
return obs_spec

obs.append(np.ones((1, self.vec_obs_size), dtype=np.float32) * value)
for _ in range(self.num_visual):
obs.append(np.ones((1,) + self.vis_obs_size, dtype=np.float32) * value)
for _ in range(self.num_var_len):
obs.append(np.ones((1,) + self.var_len_obs_size, dtype=np.float32) * value)
return obs
@property

71
ml-agents/mlagents/trainers/tests/torch/test_attention.py


import pytest
from mlagents.torch_utils import torch
import numpy as np

MultiHeadAttention,
EntityEmbeddings,
EntityEmbedding,
get_zero_entities_mask,
)

input_1 = generate_input_helper(masking_pattern_1)
input_2 = generate_input_helper(masking_pattern_2)
masks = EntityEmbeddings.get_masks([input_1, input_2])
masks = get_zero_entities_mask([input_1, input_2])
assert len(masks) == 2
masks_1 = masks[0]
masks_2 = masks[1]

assert masks_2[0, 1] == 0 if i % 2 == 0 else 1
@pytest.mark.parametrize("mask_value", [0, 1])
def test_all_masking(mask_value):
# We make sure that a mask of all zeros or all ones will not trigger an error
np.random.seed(1336)
torch.manual_seed(1336)
size, n_k, = 3, 5
embedding_size = 64
entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
entity_embeddings.add_self_embedding(size)
transformer = ResidualSelfAttention(embedding_size, n_k)
l_layer = linear_layer(embedding_size, size)
optimizer = torch.optim.Adam(
list(entity_embeddings.parameters())
+ list(transformer.parameters())
+ list(l_layer.parameters()),
lr=0.001,
weight_decay=1e-6,
)
batch_size = 20
for _ in range(5):
center = torch.rand((batch_size, size))
key = torch.rand((batch_size, n_k, size))
with torch.no_grad():
# create the target : The key closest to the query in euclidean distance
distance = torch.sum(
(center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
)
argmin = torch.argmin(distance, dim=1)
target = []
for i in range(batch_size):
target += [key[i, argmin[i], :]]
target = torch.stack(target, dim=0)
target = target.detach()
embeddings = entity_embeddings(center, key)
masks = [torch.ones_like(key[:, :, 0]) * mask_value]
prediction = transformer.forward(embeddings, masks)
prediction = l_layer(prediction)
prediction = prediction.reshape((batch_size, size))
error = torch.mean((prediction - target) ** 2, dim=1)
error = torch.mean(error) / 2
optimizer.zero_grad()
error.backward()
optimizer.step()
entity_embeddings = EntityEmbeddings(size, [size], embedding_size, [n_k])
transformer = ResidualSelfAttention(embedding_size, [n_k])
entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
entity_embeddings.add_self_embedding(size)
transformer = ResidualSelfAttention(embedding_size, n_k)
l_layer = linear_layer(embedding_size, size)
optimizer = torch.optim.Adam(
list(entity_embeddings.parameters())

target = torch.stack(target, dim=0)
target = target.detach()
embeddings = entity_embeddings(center, [key])
masks = EntityEmbeddings.get_masks([key])
embeddings = entity_embeddings(center, key)
masks = get_zero_entities_mask([key])
prediction = transformer.forward(embeddings, masks)
prediction = l_layer(prediction)
prediction = prediction.reshape((batch_size, size))

n_k = 5
size = n_k + 1
embedding_size = 64
entity_embeddings = EntityEmbeddings(
size, [size], embedding_size, [n_k], concat_self=False
)
entity_embedding = EntityEmbedding(size, n_k, embedding_size) # no self
list(entity_embeddings.parameters())
list(entity_embedding.parameters())
+ list(transformer.parameters())
+ list(l_layer.parameters()),
lr=0.001,

sliced_oh = onehots[:, : num + 1]
inp = torch.cat([inp, sliced_oh], dim=2)
embeddings = entity_embeddings(inp, [inp])
masks = EntityEmbeddings.get_masks([inp])
embeddings = entity_embedding(inp, inp)
masks = get_zero_entities_mask([inp])
prediction = transformer(embeddings, masks)
prediction = l_layer(prediction)
ce = loss(prediction, argmin)

38
ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py


check_environment_trains(env, {BRAIN_NAME: config})
@pytest.mark.parametrize("action_sizes", [(0, 1), (1, 0)])
@pytest.mark.parametrize("num_var_len", [1, 2])
@pytest.mark.parametrize("num_vector", [0, 1])
@pytest.mark.parametrize("num_vis", [0, 1])
def test_var_len_obs_ppo(num_vis, num_vector, num_var_len, action_sizes):
env = SimpleEnvironment(
[BRAIN_NAME],
action_sizes=action_sizes,
num_visual=num_vis,
num_vector=num_vector,
num_var_len=num_var_len,
step_size=0.2,
)
new_hyperparams = attr.evolve(
PPO_TORCH_CONFIG.hyperparameters, learning_rate=3.0e-4
)
config = attr.evolve(PPO_TORCH_CONFIG, hyperparameters=new_hyperparams)
check_environment_trains(env, {BRAIN_NAME: config})
@pytest.mark.parametrize("num_visual", [1, 2])
@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn", "match3"])
def test_visual_advanced_ppo(vis_encode_type, num_visual):

[BRAIN_NAME],
action_sizes=action_sizes,
num_visual=num_visual,
num_vector=0,
step_size=0.2,
)
new_hyperparams = attr.evolve(
SAC_TORCH_CONFIG.hyperparameters, batch_size=16, learning_rate=3e-4
)
config = attr.evolve(SAC_TORCH_CONFIG, hyperparameters=new_hyperparams)
check_environment_trains(env, {BRAIN_NAME: config})
@pytest.mark.parametrize("action_sizes", [(0, 1), (1, 0)])
@pytest.mark.parametrize("num_var_len", [1, 2])
def test_var_len_obs_sac(num_var_len, action_sizes):
env = SimpleEnvironment(
[BRAIN_NAME],
action_sizes=action_sizes,
num_visual=0,
num_var_len=num_var_len,
num_vector=0,
step_size=0.2,
)

161
ml-agents/mlagents/trainers/torch/attention.py


from mlagents.trainers.exception import UnityTrainerException
class MultiHeadAttention(torch.nn.Module):
def get_zero_entities_mask(observations: List[torch.Tensor]) -> List[torch.Tensor]:
Multi Head Attention module. We do not use the regular Torch implementation since
Barracuda does not support some operators it uses.
Takes as input to the forward method 3 tensors:
- query: of dimensions (batch_size, number_of_queries, embedding_size)
- key: of dimensions (batch_size, number_of_keys, embedding_size)
- value: of dimensions (batch_size, number_of_keys, embedding_size)
The forward method will return 2 tensors:
- The output: (batch_size, number_of_queries, embedding_size)
- The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
layer to mask the padding observations.
with torch.no_grad():
# Generate the masking tensors for each entities tensor (mask only if all zeros)
key_masks: List[torch.Tensor] = [
(torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations
]
return key_masks
class MultiHeadAttention(torch.nn.Module):
"""
Multi Head Attention module. We do not use the regular Torch implementation since
Barracuda does not support some operators it uses.
Takes as input to the forward method 3 tensors:
- query: of dimensions (batch_size, number_of_queries, embedding_size)
- key: of dimensions (batch_size, number_of_keys, embedding_size)
- value: of dimensions (batch_size, number_of_keys, embedding_size)
The forward method will return 2 tensors:
- The output: (batch_size, number_of_queries, embedding_size)
- The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
:param embedding_size: The size of the embeddings that will be generated (should be
dividable by the num_heads)
:param total_max_elements: The maximum total number of entities that can be passed to
the module
:param num_heads: The number of heads of the attention module
"""
super().__init__()
self.n_heads = num_heads
self.head_size: int = embedding_size // self.n_heads

return value_attention, att
class EntityEmbeddings(torch.nn.Module):
class EntityEmbedding(torch.nn.Module):
"""
A module used to embed entities before passing them to a self-attention block.
Used in conjunction with ResidualSelfAttention to encode information about a self

def __init__(
self,
x_self_size: int,
entity_sizes: List[int],
entity_size: int,
entity_num_max_elements: Optional[int],
entity_num_max_elements: Optional[List[int]] = None,
concat_self: bool = True,
Constructs an EntityEmbeddings module.
Constructs an EntityEmbedding module.
:param entity_sizes: List of sizes for other entities. Should be of length
equivalent to the number of entities.
:param embedding_size: Embedding size for entity encoders.
:param entity_num_max_elements: Maximum elements in an entity, None for unrestricted.
:param entity_size: Size of other entities.
:param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted.
:param concat_self: Whether to concatenate x_self to entites. Set True for ego-centric
:param embedding_size: Embedding size for the entity encoder.
:param concat_self: Whether to concatenate x_self to entities. Set True for ego-centric
self.self_size: int = x_self_size
self.entity_sizes: List[int] = entity_sizes
self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes)
self.self_size: int = 0
self.entity_size: int = entity_size
self.entity_num_max_elements: int = -1
self.embedding_size = embedding_size
# Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
self.self_ent_encoder = LinearEncoder(
self.entity_size,
1,
self.embedding_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / self.embedding_size) ** 0.5,
)
self.concat_self: bool = concat_self
# If not concatenating self, input to encoder is just entity size
if not concat_self:
self.self_size = 0
# Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
self.ent_encoders = torch.nn.ModuleList(
[
LinearEncoder(
self.self_size + ent_size,
1,
embedding_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
for ent_size in self.entity_sizes
]
def add_self_embedding(self, size: int) -> None:
self.self_size = size
self.self_ent_encoder = LinearEncoder(
self.self_size + self.entity_size,
1,
self.embedding_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / self.embedding_size) ** 0.5,
self.embedding_norm = LayerNorm()
def forward(
self, x_self: torch.Tensor, entities: List[torch.Tensor]
) -> Tuple[torch.Tensor, int]:
if self.concat_self:
def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
if self.self_size > 0:
num_entities = self.entity_num_max_elements
if num_entities < 0:
if exporting_to_onnx.is_exporting():
raise UnityTrainerException(
"Trying to export an attention mechanism that doesn't have a set max \
number of elements."
)
num_entities = entities.shape[1]
expanded_self = x_self.reshape(-1, 1, self.self_size)
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
self_and_ent: List[torch.Tensor] = []
for num_entities, ent in zip(self.entity_num_max_elements, entities):
if num_entities < 0:
if exporting_to_onnx.is_exporting():
raise UnityTrainerException(
"Trying to export an attention mechanism that doesn't have a set max \
number of elements."
)
num_entities = ent.shape[1]
expanded_self = x_self.reshape(-1, 1, self.self_size)
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
else:
self_and_ent = entities
# Encode and concatenate entites
encoded_entities = torch.cat(
[ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
dim=1,
)
encoded_entities = self.embedding_norm(encoded_entities)
entities = torch.cat([expanded_self, entities], dim=2)
# Encode entities
encoded_entities = self.self_ent_encoder(entities)
@staticmethod
def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
"""
Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
layer to mask the padding observations.
"""
with torch.no_grad():
# Generate the masking tensors for each entities tensor (mask only if all zeros)
key_masks: List[torch.Tensor] = [
(torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations
]
return key_masks
with an EntityEmbeddings module, to apply multi head self attention to encode information
with an EntityEmbedding module, to apply multi head self attention to encode information
about a "Self" and a list of relevant "Entities".
"""

self,
embedding_size: int,
entity_num_max_elements: Optional[List[int]] = None,
entity_num_max_elements: Optional[int] = None,
num_heads: int = 4,
):
"""

super().__init__()
self.max_num_ent: Optional[int] = None
if entity_num_max_elements is not None:
_entity_num_max_elements = entity_num_max_elements
self.max_num_ent = sum(_entity_num_max_elements)
self.max_num_ent = entity_num_max_elements
self.attention = MultiHeadAttention(
num_heads=num_heads, embedding_size=embedding_size

kernel_init=Initialization.Normal,
kernel_gain=(0.125 / embedding_size) ** 0.5,
)
self.embedding_norm = LayerNorm()
inp = self.embedding_norm(inp)
# Feed to self attention
query = self.fc_q(inp) # (b, n_q, emb)
key = self.fc_k(inp) # (b, n_k, emb)

43
ml-agents/mlagents/trainers/torch/model_serialization.py


# Any multi-dimentional input should follow that otherwise will
# cause problem to barracuda import.
self.policy = policy
observation_specs = self.policy.behavior_spec.observation_specs
for sens_spec in self.policy.behavior_spec.observation_specs:
if len(sens_spec.shape) == 1:
vec_obs_size += sens_spec.shape[0]
for obs_spec in observation_specs:
if len(obs_spec.shape) == 1:
vec_obs_size += obs_spec.shape[0]
1
for sens_spec in self.policy.behavior_spec.observation_specs
if len(sens_spec.shape) == 3
1 for obs_spec in observation_specs if len(obs_spec.shape) == 3
# (It's NHWC in self.policy.behavior_spec.observation_specs.shape)
# (It's NHWC in observation_specs.shape)
for obs_spec in self.policy.behavior_spec.observation_specs
for obs_spec in observation_specs
dummy_var_len_obs = [
torch.zeros(batch_dim + [obs_spec.shape[0], obs_spec.shape[1]])
for obs_spec in observation_specs
if len(obs_spec.shape) == 2
]
dummy_masks = torch.ones(
batch_dim + [sum(self.policy.behavior_spec.action_spec.discrete_branches)]
)

self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories)
self.dummy_input = (
dummy_vec_obs,
dummy_vis_obs,
dummy_var_len_obs,
dummy_masks,
dummy_memories,
)
self.input_names = ["vector_observation"]
for i in range(num_vis_obs):
self.input_names.append(f"visual_observation_{i}")
for i, obs_spec in enumerate(observation_specs):
if len(obs_spec.shape) == 2:
self.input_names.append(f"obs_{i}")
self.input_names += ["action_masks", "memories"]
self.input_names = (
["vector_observation"]
+ [f"visual_observation_{i}" for i in range(num_vis_obs)]
+ ["action_masks", "memories"]
)
self.dynamic_axes = {name: {0: "batch"} for name in self.input_names}
self.output_names = ["version_number", "memory_size"]

84
ml-agents/mlagents/trainers/torch/networks.py


from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import LSTM, LinearEncoder
from mlagents.trainers.torch.layers import LSTM, LinearEncoder, Initialization
from mlagents.trainers.torch.attention import (
EntityEmbedding,
ResidualSelfAttention,
get_zero_entities_mask,
)
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]

normalize=self.normalize,
)
total_enc_size = sum(self.embedding_sizes) + encoded_act_size
entity_num_max: int = 0
var_processors = [p for p in self.processors if isinstance(p, EntityEmbedding)]
for processor in var_processors:
entity_max: int = processor.entity_num_max_elements
# Only adds entity max if it was known at construction
if entity_max > 0:
entity_num_max += entity_max
if len(var_processors) > 0:
if sum(self.embedding_sizes):
self.x_self_encoder = LinearEncoder(
sum(self.embedding_sizes),
1,
self.h_size,
kernel_init=Initialization.Normal,
kernel_gain=(0.125 / self.h_size) ** 0.5,
)
self.rsa = ResidualSelfAttention(self.h_size, entity_num_max)
total_enc_size = sum(self.embedding_sizes) + self.h_size
else:
total_enc_size = sum(self.embedding_sizes)
total_enc_size += encoded_act_size
self.linear_encoder = LinearEncoder(
total_enc_size, network_settings.num_layers, self.h_size
)

sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
encodes = []
var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []
obs_input = inputs[idx]
processed_obs = processor(obs_input)
encodes.append(processed_obs)
if not isinstance(processor, EntityEmbedding):
# The input can be encoded without having to process other inputs
obs_input = inputs[idx]
processed_obs = processor(obs_input)
encodes.append(processed_obs)
else:
var_len_processor_inputs.append((processor, inputs[idx]))
if len(encodes) != 0:
encoded_self = torch.cat(encodes, dim=1)
input_exist = True
else:
input_exist = False
if len(var_len_processor_inputs) > 0:
# Some inputs need to be processed with a variable length encoder
masks = get_zero_entities_mask([p_i[1] for p_i in var_len_processor_inputs])
embeddings: List[torch.Tensor] = []
processed_self = self.x_self_encoder(encoded_self) if input_exist else None
for processor, var_len_input in var_len_processor_inputs:
embeddings.append(processor(processed_self, var_len_input))
qkv = torch.cat(embeddings, dim=1)
attention_embedding = self.rsa(qkv, masks)
if not input_exist:
encoded_self = torch.cat([attention_embedding], dim=1)
input_exist = True
else:
encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)
if len(encodes) == 0:
raise Exception("No valid inputs to network.")
if not input_exist:
raise Exception(
"The trainer was unable to process any of the provided inputs. "
"Make sure the trained agents has at least one sensor attached to them."
)
# Constants don't work in Barracuda
inputs = torch.cat(encodes + [actions], dim=-1)
else:
inputs = torch.cat(encodes, dim=-1)
encoding = self.linear_encoder(inputs)
encoded_self = torch.cat([encoded_self, actions], dim=1)
encoding = self.linear_encoder(encoded_self)
if self.use_lstm:
# Resize to (batch, sequence length, encoding size)

self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
var_len_inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
) -> Tuple[Union[int, torch.Tensor], ...]:

self,
vec_inputs: List[torch.Tensor],
vis_inputs: List[torch.Tensor],
var_len_inputs: List[torch.Tensor],
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
) -> Tuple[Union[int, torch.Tensor], ...]:

start = 0
end = 0
vis_index = 0
var_len_index = 0
for i, enc in enumerate(self.network_body.processors):
if isinstance(enc, VectorInput):
# This is a vec_obs

start = end
else:
elif isinstance(enc, EntityEmbedding):
inputs.append(var_len_inputs[var_len_index])
var_len_index += 1
else: # visual input
# End of code to convert the vec and vis obs into a list of inputs for the network
encoding, memories_out = self.network_body(
inputs, memories=memories, sequence_length=1

64
ml-agents/mlagents/trainers/torch/utils.py


VectorInput,
)
from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.torch.attention import EntityEmbedding
from mlagents_envs.base_env import ObservationSpec
from mlagents_envs.base_env import ObservationSpec, DimensionProperty
class ModelUtils:

EncoderType.NATURE_CNN: 36,
EncoderType.RESNET: 15,
}
VALID_VISUAL_PROP = frozenset(
[
(
DimensionProperty.TRANSLATIONAL_EQUIVARIANCE,
DimensionProperty.TRANSLATIONAL_EQUIVARIANCE,
DimensionProperty.NONE,
),
(DimensionProperty.UNSPECIFIED,) * 3,
]
)
VALID_VECTOR_PROP = frozenset(
[(DimensionProperty.NONE,), (DimensionProperty.UNSPECIFIED,)]
)
VALID_VAR_LEN_PROP = frozenset(
[(DimensionProperty.VARIABLE_SIZE, DimensionProperty.NONE)]
)
@staticmethod
def update_learning_rate(optim: torch.optim.Optimizer, lr: float) -> None:

@staticmethod
def get_encoder_for_obs(
shape: Tuple[int, ...],
obs_spec: ObservationSpec,
normalize: bool,
h_size: int,
vis_encode_type: EncoderType,

:param h_size: Number of hidden units per layer.
:param vis_encode_type: Type of visual encoder to use.
"""
if len(shape) == 1:
# Case rank 1 tensor
return (VectorInput(shape[0], normalize), shape[0])
if len(shape) == 3:
ModelUtils._check_resolution_for_encoder(
shape[0], shape[1], vis_encode_type
)
shape = obs_spec.shape
dim_prop = obs_spec.dimension_property
# VISUAL
if dim_prop in ModelUtils.VALID_VISUAL_PROP:
raise UnityTrainerException(f"Unsupported shape of {shape} for observation")
# VECTOR
if dim_prop in ModelUtils.VALID_VECTOR_PROP:
return (VectorInput(shape[0], normalize), shape[0])
# VARIABLE LENGTH
if dim_prop in ModelUtils.VALID_VAR_LEN_PROP:
return (
EntityEmbedding(
entity_size=shape[1],
entity_num_max_elements=shape[0],
embedding_size=h_size,
),
0,
)
# OTHER
raise UnityTrainerException(f"Unsupported Sensor with specs {obs_spec}")
@staticmethod
def create_input_processors(

:param unnormalized_inputs: Vector inputs that should not be normalized, and added to the vector
obs.
:param normalize: Normalize all vector inputs.
:return: Tuple of visual encoders and vector encoders each as a list.
:return: Tuple of :
- ModuleList of the encoders
- A list of embedding sizes (0 if the input requires to be processed with a variable length
observation encoder)
obs_spec.shape, normalize, h_size, vis_encode_type
obs_spec, normalize, h_size, vis_encode_type
x_self_size = sum(embedding_sizes) # The size of the "self" embedding
if x_self_size > 0:
for enc in encoders:
if isinstance(enc, EntityEmbedding):
enc.add_self_embedding(h_size)
return (nn.ModuleList(encoders), embedding_sizes)
@staticmethod

正在加载...
取消
保存