浏览代码

move entity max computation to network body

/develop/singular-embeddings
Andrew Cohen 4 年前
当前提交
5caaef52
共有 2 个文件被更改,包括 12 次插入12 次删除
  1. 8
      ml-agents/mlagents/trainers/torch/networks.py
  2. 16
      ml-agents/mlagents/trainers/torch/utils.py

8
ml-agents/mlagents/trainers/torch/networks.py


else 0
)
self.processors, self.var_processors, self.embedding_sizes, entity_num_max = ModelUtils.create_input_processors(
self.processors, self.var_processors, self.embedding_sizes = ModelUtils.create_input_processors(
observation_specs,
self.h_size,
network_settings.vis_encode_type,

entity_num_max: int = 0
for var_processor in self.var_processors:
entity_max: int = var_processor.entity_num_max_elements
# Only adds entity max if it was known at construction
if entity_max > 0:
entity_num_max += entity_max
if len(self.var_processors) > 0:
self.embedding_norm = LayerNorm()
self.rsa = ResidualSelfAttention(self.h_size, entity_num_max)

16
ml-agents/mlagents/trainers/torch/utils.py


h_size: int,
vis_encode_type: EncoderType,
normalize: bool = False,
) -> Tuple[nn.ModuleList, nn.ModuleList, List[int], int]:
) -> Tuple[nn.ModuleList, nn.ModuleList, List[int]]:
"""
Creates visual and vector encoders, along with their normalizers.
:param observation_specs: List of ObservationSpec that represent the observation dimensions.

var_encoders: List[nn.Module] = []
embedding_sizes: List[int] = []
var_len_indices: List[int] = []
entity_num_max_elements = 0
for idx, obs_spec in enumerate(observation_specs):
encoder, embedding_size = ModelUtils.get_encoder_for_obs(
obs_spec, normalize, h_size, vis_encode_type

x_self_size = sum(embedding_sizes) # The size of the "self" embedding
for idx in var_len_indices:
entity_max: int = obs_spec[idx].shape[0]
EntityEmbedding(x_self_size, obs_spec[idx].shape[1], entity_max, h_size)
EntityEmbedding(
x_self_size, obs_spec[idx].shape[1], obs_spec[idx].shape[0], h_size
)
entity_num_max_elements += entity_max
return (
nn.ModuleList(encoders),
nn.ModuleList(var_encoders),
embedding_sizes,
entity_num_max_elements,
)
return (nn.ModuleList(encoders), nn.ModuleList(var_encoders), embedding_sizes)
@staticmethod
def list_to_tensor(

正在加载...
取消
保存