浏览代码

Fix the attention module embedding size

/develop/fix-attn-embedding
vincentpierre 3 年前
当前提交
51adab1c
共有 3 个文件被更改,包括 31 次插入16 次删除
  1. 2
      ml-agents/mlagents/trainers/tests/torch/test_utils.py
  2. 31
      ml-agents/mlagents/trainers/torch/networks.py
  3. 14
      ml-agents/mlagents/trainers/torch/utils.py

2
ml-agents/mlagents/trainers/tests/torch/test_utils.py


h_size = 128
obs_spec = create_observation_specs_with_shapes(obs_shapes)
encoders, embedding_sizes = ModelUtils.create_input_processors(
obs_spec, h_size, encoder_type, normalize
obs_spec, h_size, encoder_type, h_size, normalize
)
total_output = sum(embedding_sizes)
vec_enc = []

31
ml-agents/mlagents/trainers/torch/networks.py


class ObservationEncoder(nn.Module):
ATTENTION_EMBEDDING_SIZE = 128 # The embedding size of attention is fixed
def __init__(
self,
observation_specs: List[ObservationSpec],

"""
super().__init__()
self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
observation_specs, h_size, vis_encode_type, normalize=normalize
observation_specs,
h_size,
vis_encode_type,
self.ATTENTION_EMBEDDING_SIZE,
normalize=normalize,
self.processors, self.embedding_sizes, h_size
self.processors, self.embedding_sizes, self.ATTENTION_EMBEDDING_SIZE
total_enc_size = sum(self.embedding_sizes) + h_size
total_enc_size = sum(self.embedding_sizes) + self.ATTENTION_EMBEDDING_SIZE
else:
total_enc_size = sum(self.embedding_sizes)
self.normalize = normalize

class MultiAgentNetworkBody(torch.nn.Module):
ATTENTION_EMBEDDING_SIZE = 128
"""
A network body that uses a self attention layer to handle state
and action input from a potentially variable number of agents that

+ sum(self.action_spec.discrete_branches)
+ self.action_spec.continuous_size
)
self.obs_encoder = EntityEmbedding(obs_only_ent_size, None, self.h_size)
self.obs_action_encoder = EntityEmbedding(q_ent_size, None, self.h_size)
self.obs_encoder = EntityEmbedding(
obs_only_ent_size, None, self.ATTENTION_EMBEDDING_SIZE
)
self.obs_action_encoder = EntityEmbedding(
q_ent_size, None, self.ATTENTION_EMBEDDING_SIZE
)
self.self_attn = ResidualSelfAttention(self.h_size)
self.self_attn = ResidualSelfAttention(self.ATTENTION_EMBEDDING_SIZE)
self.h_size,
self.ATTENTION_EMBEDDING_SIZE,
network_settings.num_layers,
self.h_size,
kernel_gain=(0.125 / self.h_size) ** 0.5,

no_nan_obs = []
for obs in single_agent_obs:
new_obs = obs.clone()
new_obs[
attention_mask.bool()[:, i_agent], ::
] = 0.0 # Remoove NaNs fast
new_obs[attention_mask.bool()[:, i_agent], ::] = 0.0 # Remove NaNs fast
no_nan_obs.append(new_obs)
obs_with_no_nans.append(no_nan_obs)
return obs_with_no_nans

14
ml-agents/mlagents/trainers/torch/utils.py


obs_spec: ObservationSpec,
normalize: bool,
h_size: int,
attention_embedding_size: int,
vis_encode_type: EncoderType,
) -> Tuple[nn.Module, int]:
"""

:param h_size: Number of hidden units per layer.
:param h_size: Number of hidden units per layer excluding attention layers.
:param attention_embedding_size: Number of hidden units per attention layer.
:param vis_encode_type: Type of visual encoder to use.
"""
shape = obs_spec.shape

EntityEmbedding(
entity_size=shape[1],
entity_num_max_elements=shape[0],
embedding_size=h_size,
embedding_size=attention_embedding_size,
),
0,
)

observation_specs: List[ObservationSpec],
h_size: int,
vis_encode_type: EncoderType,
attention_embedding_size: int,
normalize: bool = False,
) -> Tuple[nn.ModuleList, List[int]]:
"""

conditioning network on other values (e.g. actions for a Q function)
:param h_size: Number of hidden units per layer.
:param h_size: Number of hidden units per layer excluding attention layers.
:param attention_embedding_size: Number of hidden units per attention layer.
:param vis_encode_type: Type of visual encoder to use.
:param unnormalized_inputs: Vector inputs that should not be normalized, and added to the vector
obs.

embedding_sizes: List[int] = []
for obs_spec in observation_specs:
encoder, embedding_size = ModelUtils.get_encoder_for_obs(
obs_spec, normalize, h_size, vis_encode_type
obs_spec, normalize, h_size, attention_embedding_size, vis_encode_type
)
encoders.append(encoder)
embedding_sizes.append(embedding_size)

for enc in encoders:
if isinstance(enc, EntityEmbedding):
enc.add_self_embedding(h_size)
enc.add_self_embedding(attention_embedding_size)
return (nn.ModuleList(encoders), embedding_sizes)
@staticmethod

正在加载...
取消
保存