        h_size: int,
        vis_encode_type: EncoderType,
        normalize: bool = False,
    ) -> Tuple[nn.ModuleList, nn.ModuleList, List[int], int]:
        """
        Creates visual and vector encoders, along with their normalizers.
        :param observation_specs: List of ObservationSpec that represent the observation dimensions.
        :param h_size: Number of hidden units per layer.
        :param vis_encode_type: Type of visual encoder to use.
        :param normalize: Whether to normalize vector observation inputs.
        :return: Tuple of:
         - ModuleList of the fixed-size observation encoders
         - ModuleList of the variable-length (EntityEmbedding) encoders
         - A list of embedding sizes of the fixed-size encoders
         - The total maximum number of variable-length entities across all observations
        """
        encoders: List[nn.Module] = []
        var_encoders: List[nn.Module] = []
        embedding_sizes: List[int] = []
        var_len_indices: List[int] = []
        entity_num_max_elements = 0
        for idx, obs_spec in enumerate(observation_specs):
            if len(obs_spec.shape) == 2:
                # Rank-2 (num_entities, entity_size) observations are assumed to be
                # variable-length; they are encoded with an EntityEmbedding below.
                var_len_indices.append(idx)
            else:
                encoder, embedding_size = ModelUtils.get_encoder_for_obs(
                    obs_spec, normalize, h_size, vis_encode_type
                )
                encoders.append(encoder)
                embedding_sizes.append(embedding_size)

        x_self_size = sum(embedding_sizes)  # The size of the "self" embedding
        for idx in var_len_indices:
            entity_max: int = observation_specs[idx].shape[0]
            var_encoders.append(
                EntityEmbedding(
                    x_self_size, observation_specs[idx].shape[1], entity_max, h_size
                )
            )
            entity_num_max_elements += entity_max
        return (
            nn.ModuleList(encoders),
            nn.ModuleList(var_encoders),
            embedding_sizes,
            entity_num_max_elements,
        )
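
    # Caller-side sketch: how a network body might consume the tuple returned above.
    # Purely illustrative; the method name in the call and the attribute names on the
    # caller are assumptions, not taken from this module.
    #
    #   encoders, var_encoders, embedding_sizes, max_entities = (
    #       ModelUtils.create_input_processors(  # hypothetical call site
    #           observation_specs, h_size, vis_encode_type, normalize=normalize
    #       )
    #   )
    #   self.processors = encoders          # fixed-size observation encoders
    #   self.var_processors = var_encoders  # EntityEmbedding modules for rank-2 obs
    #   self.total_enc_size = sum(embedding_sizes)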

    @staticmethod
    def list_to_tensor(