浏览代码

Use NaNs to get masks for attention

/develop/cc-teammanager
Ervin Teng 4 年前
当前提交
b2c498de
共有 2 个文件被更改,包括 7 次插入6 次删除
  1. 9
      ml-agents/mlagents/trainers/torch/networks.py
  2. 4
      ml-agents/mlagents/trainers/trajectory.py

9
ml-agents/mlagents/trainers/torch/networks.py


x_self = torch.cat(self_encodes, dim=-1)
# Get attention masks by grabbing an arbitrary obs across all the agents
# Since these are raw obs, the padded values are still 0
# Since these are raw obs, the padded values are still NaN
# Get the mask from nans
attn_mask = torch.any(obs_for_mask.isnan(), dim=2).type(torch.FloatTensor)
# Get the self encoding separately, but keep it in the entities
concat_encoded_obs = [x_self]

obs_input = inputs[idx]
obs_input[obs_input.isnan()] = 0.0 # Remove NaNs
processed_obs = processor(obs_input)
encodes.append(processed_obs)
concat_encoded_obs.append(torch.cat(encodes, dim=-1))

encoded_entity = self.entity_encoder(x_self, [concat_entites])
encoded_state = self.self_attn(
encoded_entity, EntityEmbeddings.get_masks([obs_for_mask])
)
encoded_state = self.self_attn(encoded_entity, [attn_mask])
if len(concat_encoded_obs) == 0:
raise Exception("No valid inputs to network.")

4
ml-agents/mlagents/trainers/trajectory.py


Convert an AgentBufferField of List of obs, where one of the dimension is time and the other is number (e.g.
in the case of a variable number of critic observations) to a List of obs, where time is in the batch dimension
of the obs, and the List is the variable number of agents. For cases where there are varying number of agents,
pad the non-existent agents with 0.
pad the non-existent agents with NaN.
"""
# Find the first observation. This should be USUALLY O(1)
obs_shape = None

map(
lambda x: np.asanyarray(x),
itertools.zip_longest(
*agent_buffer_field, fillvalue=np.zeros(obs_shape)
*agent_buffer_field, fillvalue=np.full(obs_shape, np.nan)
),
)
)

正在加载...
取消
保存