浏览代码

addressing comments and adding the changes to rpc_utils

/develop/singular-embeddings
vincentpierre 4 年前
当前提交
efa5a164
共有 2 个文件被更改,包括 16 次插入26 次删除
  1. 29
      ml-agents-envs/mlagents_envs/rpc_utils.py
  2. 13
      ml-agents/mlagents/trainers/torch/networks.py

29
ml-agents-envs/mlagents_envs/rpc_utils.py


:param agent_info: protobuf object.
:return: BehaviorSpec object.
"""
observation_shape = [tuple(obs.shape) for obs in agent_info.observations]
dim_props = [
tuple(DimensionProperty(dim) for dim in obs.dimension_properties)
for obs in agent_info.observations
]
dim_props = [
dim_prop
if len(dim_prop) > 0
else (DimensionProperty.UNSPECIFIED,) * len(observation_shape[idx])
for idx, dim_prop in enumerate(dim_props)
]
obs_types = [
ObservationType(obs.observation_type) for obs in agent_info.observations
]
observation_specs = [
ObservationSpec(obs_shape, dim_p, obs_type)
for obs_shape, dim_p, obs_type in zip(observation_shape, dim_props, obs_types)
]
observation_specs = []
for obs in agent_info.observations:
observation_specs.append(
ObservationSpec(
tuple(obs.shape),
tuple(DimensionProperty(dim) for dim in obs.dimension_properties)
if len(obs.dimension_properties) > 0
else (DimensionProperty.UNSPECIFIED,) * len(obs.shape),
ObservationType(obs.observation_type),
)
)
# proto from communicator < v1.3 does not set action spec, use deprecated fields instead
if (

13
ml-agents/mlagents/trainers/torch/networks.py


sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
encodes = []
var_len_inputs = [] # The list of variable length inputs
var_len_processors = [
p for p in self.processors if isinstance(p, EntityEmbedding)
]
var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []
for idx, processor in enumerate(self.processors):
if not isinstance(processor, EntityEmbedding):

encodes.append(processed_obs)
else:
var_len_inputs.append(inputs[idx])
var_len_processor_inputs.append((processor, inputs[idx]))
if len(var_len_inputs) > 0:
if len(var_len_processor_inputs) > 0:
masks = get_zero_entities_mask(var_len_inputs)
masks = get_zero_entities_mask([p_i[1] for p_i in var_len_processor_inputs])
for var_len_input, processor in zip(var_len_inputs, var_len_processors):
for processor, var_len_input in var_len_processor_inputs:
embeddings.append(processor(processed_self, var_len_input))
qkv = torch.cat(embeddings, dim=1)
attention_embedding = self.rsa(qkv, masks)

正在加载...
取消
保存