浏览代码

Faster NaN masking, fix masking for visual obs (#5015)

* Fix getting the attention mask from visual observations; large performance improvement for big observations.

* Bug fix

* Fix typo
/develop/action-slice
GitHub 4 年前
当前提交
c9c7e3d0
共有 1 个文件被更改,包括 35 次插入27 次删除
  1. 62
      ml-agents/mlagents/trainers/torch/networks.py

62
ml-agents/mlagents/trainers/torch/networks.py


Since these are raw obs, the padded values are still NaN
"""
only_first_obs = [_all_obs[0] for _all_obs in obs_tensors]
# flatten for correct dimensions with visual obs
# Just get the first element in each obs regardless of its dimension. This will speed up
# searching for NaNs.
# Get the mask from NaNs
attn_mask = only_first_obs_flat.isnan().type(torch.FloatTensor)
return attn_mask

self_attn_masks = []
self_attn_inputs = []
concat_f_inp = []
for inputs, action in zip(obs, actions):
encodes = []
for idx, processor in enumerate(self.processors):
obs_input = inputs[idx]
obs_input[obs_input.isnan()] = 0.0 # Remove NaNs
processed_obs = processor(obs_input)
encodes.append(processed_obs)
cat_encodes = [
torch.cat(encodes, dim=-1),
action.to_flat(self.action_spec.discrete_branches),
]
concat_f_inp.append(torch.cat(cat_encodes, dim=1))
if concat_f_inp:
if obs:
obs_attn_mask = self._get_masks_from_nans(obs)
for i_agent, (inputs, action) in enumerate(zip(obs, actions)):
encodes = []
for idx, processor in enumerate(self.processors):
obs_input = inputs[idx]
obs_input[
obs_attn_mask.type(torch.BoolTensor)[:, i_agent], ::
] = 0.0 # Remoove NaNs fast
processed_obs = processor(obs_input)
encodes.append(processed_obs)
cat_encodes = [
torch.cat(encodes, dim=-1),
action.to_flat(self.action_spec.discrete_branches),
]
concat_f_inp.append(torch.cat(cat_encodes, dim=1))
self_attn_masks.append(self._get_masks_from_nans(obs))
self_attn_masks.append(obs_attn_mask)
for inputs in obs_only:
encodes = []
for idx, processor in enumerate(self.processors):
obs_input = inputs[idx]
obs_input[obs_input.isnan()] = 0.0 # Remove NaNs
processed_obs = processor(obs_input)
encodes.append(processed_obs)
concat_encoded_obs.append(torch.cat(encodes, dim=-1))
g_inp = torch.stack(concat_encoded_obs, dim=1)
self_attn_masks.append(self._get_masks_from_nans(obs_only))
self_attn_inputs.append(self.obs_encoder(None, g_inp))
if obs_only:
obs_only_attn_mask = self._get_masks_from_nans(obs_only)
for i_agent, inputs in enumerate(obs_only):
encodes = []
for idx, processor in enumerate(self.processors):
obs_input = inputs[idx]
obs_input[
obs_only_attn_mask.type(torch.BoolTensor)[:, i_agent], ::
] = 0.0 # Remoove NaNs fast
processed_obs = processor(obs_input)
encodes.append(processed_obs)
concat_encoded_obs.append(torch.cat(encodes, dim=-1))
g_inp = torch.stack(concat_encoded_obs, dim=1)
self_attn_masks.append(obs_only_attn_mask)
self_attn_inputs.append(self.obs_encoder(None, g_inp))
encoded_entity = torch.cat(self_attn_inputs, dim=1)
encoded_state = self.self_attn(encoded_entity, self_attn_masks)

正在加载...
取消
保存