
Add Q function with attention

/develop/centralizedcritic/counterfact
Ervin Teng, 4 years ago
Commit
5d7345a6
2 files changed, 55 insertions(+), 32 deletions(-)
  1. ml-agents/mlagents/trainers/torch/agent_action.py (6 changed lines)
  2. ml-agents/mlagents/trainers/torch/networks.py (81 changed lines)

ml-agents/mlagents/trainers/torch/agent_action.py (6 changed lines)


             discrete_tensor[..., i] for i in range(discrete_tensor.shape[-1])
         ]
         return AgentAction(continuous, discrete)

+    def to_flat(self, discrete_branches: List[int]) -> torch.Tensor:
+        discrete_oh = ModelUtils.actions_to_onehot(
+            self.discrete_tensor, discrete_branches
+        )
+        return torch.cat([self.continuous_tensor, discrete_oh], dim=-1)
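For reference, the flat vector produced by to_flat is the continuous action followed by the one-hot encoding of each discrete branch. A minimal standalone sketch of that layout, using plain torch.nn.functional.one_hot in place of ModelUtils.actions_to_onehot (the batch size, action values, and branch sizes below are made-up examples):

import torch
import torch.nn.functional as F

continuous = torch.tensor([[0.5, -0.2]])   # (batch, continuous_size)
discrete = torch.tensor([[1, 0]])          # (batch, num_discrete_branches)
discrete_branches = [3, 2]                 # options per discrete branch

# One-hot encode each discrete branch, then append to the continuous actions
one_hots = [
    F.one_hot(discrete[..., i], num_classes=branch_size).float()
    for i, branch_size in enumerate(discrete_branches)
]
flat = torch.cat([continuous] + one_hots, dim=-1)
print(flat.shape)  # torch.Size([1, 7]): 2 continuous + 3 + 2 one-hot slots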

ml-agents/mlagents/trainers/torch/networks.py (81 changed lines)


     def __init__(
         self,
         sensor_specs: List[SensorSpec],
         network_settings: NetworkSettings,
-        encoded_act_size: int = 0,
+        action_spec: ActionSpec,
         num_obs_heads: int = 1,
     ):
         super().__init__()

             network_settings.vis_encode_type,
             normalize=self.normalize,
         )
+        self.action_spec = action_spec
+        obs_only_ent_size = sum(_input_size)
+        q_ent_size = (
+            sum(_input_size)
+            + sum(self.action_spec.discrete_branches)
+            + self.action_spec.continuous_size
+        )
-            sum(_input_size), [sum(_input_size)], self.h_size
+            0, [obs_only_ent_size, q_ent_size], self.h_size, concat_self=False
-        total_enc_size = encoder_input_size + encoded_act_size
-            total_enc_size, network_settings.num_layers, self.h_size
+            encoder_input_size, network_settings.num_layers, self.h_size
         )
         if self.use_lstm:
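To make those entity sizes concrete: the observation-only entities are just the concatenated observation encodings, while the Q entities also carry the flattened action, so they grow by the continuous action size plus one slot per discrete option. A small sketch with assumed numbers (the observation sizes and action spec here are illustrative only):

# Hypothetical sizes, for illustration only
_input_size = [64, 32]        # encoded size of each observation stream
continuous_size = 2           # continuous action dimensions
discrete_branches = [3, 2]    # options per discrete branch

obs_only_ent_size = sum(_input_size)          # 96: width of an obs-only entity
q_ent_size = (
    sum(_input_size)
    + sum(discrete_branches)
    + continuous_size
)                                             # 96 + 5 + 2 = 103: width of a Q entity
print(obs_only_ent_size, q_ent_size)          # 96 103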

                 if isinstance(n1, VectorInput) and isinstance(n2, VectorInput):
                     n1.copy_normalization(n2)

+    def _get_masks_from_nans(
+        self, obs_tensors: List[List[torch.Tensor]]
+    ) -> torch.Tensor:
+        """
+        Get attention masks by grabbing an arbitrary obs across all the agents.
+        Since these are raw obs, the padded values are still NaN.
+        """
+        only_first_obs = [_all_obs[0] for _all_obs in obs_tensors]
+        obs_for_mask = torch.stack(only_first_obs, dim=1)
+        # Get the mask from nans
+        attn_mask = torch.any(obs_for_mask.isnan(), dim=2).type(torch.FloatTensor)
+        return attn_mask
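The masking convention is that absent (padded) agents are filled with NaN in the raw observations, and the attention mask is 1.0 exactly where an entity should be ignored. A runnable sketch of the same torch.any/isnan logic; the two-agent batch below is an assumed example, not trainer data:

import torch

# First observation stream for two agents, batch size 1;
# the second agent is padding, so its values are NaN
agent_obs = [
    torch.tensor([[0.1, 0.2, 0.3]]),
    torch.tensor([[float("nan")] * 3]),
]
obs_for_mask = torch.stack(agent_obs, dim=1)   # (batch, num_agents, obs_size)
attn_mask = torch.any(obs_for_mask.isnan(), dim=2).type(torch.FloatTensor)
print(attn_mask)                               # tensor([[0., 1.]]): mask out agent 2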
-        all_net_inputs: List[List[torch.Tensor]],
-        actions: Optional[torch.Tensor] = None,
+        value_inputs: List[List[torch.Tensor]],
+        q_inputs: List[List[torch.Tensor]],
+        q_actions: List[AgentAction],
         concat_encoded_obs = []
-        x_self = None
-        self_encodes = []
-        inputs = all_net_inputs[0]
-        for idx, processor in enumerate(self.processors):
-            obs_input = inputs[idx]
-            processed_obs = processor(obs_input)
-            self_encodes.append(processed_obs)
-        x_self = torch.cat(self_encodes, dim=-1)
-        # Get attention masks by grabbing an arbitrary obs across all the agents
-        # Since these are raw obs, the padded values are still NaN
-        only_first_obs = [_all_obs[0] for _all_obs in all_net_inputs]
-        obs_for_mask = torch.stack(only_first_obs, dim=1)
-        # Get the mask from nans
-        attn_mask = torch.any(obs_for_mask.isnan(), dim=2).type(torch.FloatTensor)
-        # Get the self encoding separately, but keep it in the entities
+        concat_enc_q_obs = []
+        for inputs, actions in zip(q_inputs, q_actions):
+            encodes = []
+            for idx, processor in enumerate(self.processors):
+                obs_input = inputs[idx]
+                obs_input[obs_input.isnan()] = 0.0  # Remove NaNs
+                processed_obs = processor(obs_input)
+                encodes.append(processed_obs)
+            cat_encodes = [
+                torch.cat(encodes, dim=-1),
+                actions.to_flat(self.action_spec.discrete_branches),
+            ]
+            concat_enc_q_obs.append(torch.cat(cat_encodes, dim=-1))
+        q_input_concat = torch.stack(concat_enc_q_obs, dim=1)
-        concat_encoded_obs = [x_self]
-        for inputs in all_net_inputs[1:]:
+        concat_encoded_obs = []
+        for inputs in value_inputs:
             encodes = []
             for idx, processor in enumerate(self.processors):
                 obs_input = inputs[idx]

             concat_encoded_obs.append(torch.cat(encodes, dim=-1))
-        concat_entites = torch.stack(concat_encoded_obs, dim=1)
+        value_input_concat = torch.stack(concat_encoded_obs, dim=1)
-        encoded_entity = self.entity_encoder(x_self, [concat_entites])
-        encoded_state = self.self_attn(encoded_entity, [attn_mask])
+        # Get the mask from nans
+        value_masks = self._get_masks_from_nans(value_inputs)
+        q_masks = self._get_masks_from_nans(q_inputs)
+        encoded_entity = self.entity_encoder(None, [value_input_concat, q_input_concat])
+        encoded_state = self.self_attn(encoded_entity, [value_masks, q_masks])
         # Constants don't work in Barracuda
-        if actions is not None:
-            inputs = torch.cat([encoded_state, actions], dim=-1)
-        else:
-            inputs = encoded_state
+        inputs = encoded_state
         encoding = self.linear_encoder(inputs)
         if self.use_lstm:
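Put together, the new forward pass builds two groups of entities for the attention stack: value entities encoded from observations only, and Q entities encoded from observations concatenated with each agent's flattened action, each group carrying its own NaN-derived mask. A schematic of that dataflow; the linear layers and masked mean pooling below are placeholder stand-ins for whatever self.entity_encoder, self.self_attn, and self.linear_encoder actually do, and all sizes are assumed:

import torch
import torch.nn as nn

obs_size, act_flat_size, h_size = 8, 5, 16
value_ent_enc = nn.Linear(obs_size, h_size)                 # stand-in obs-only entity encoder
q_ent_enc = nn.Linear(obs_size + act_flat_size, h_size)     # stand-in obs+action entity encoder
linear_encoder = nn.Linear(h_size, h_size)                  # stand-in body encoder

batch, n_agents = 4, 3
value_entities = torch.randn(batch, n_agents, obs_size)                 # obs-only entities
q_entities = torch.randn(batch, n_agents, obs_size + act_flat_size)     # obs + flattened action
value_masks = torch.zeros(batch, n_agents)                              # 1.0 = ignore entity
q_masks = torch.zeros(batch, n_agents)

# Encode both entity groups, then pool across entities using the masks
encoded = torch.cat([value_ent_enc(value_entities), q_ent_enc(q_entities)], dim=1)
masks = torch.cat([value_masks, q_masks], dim=1).unsqueeze(-1)          # (batch, 2*n_agents, 1)
encoded_state = (encoded * (1 - masks)).sum(dim=1) / (1 - masks).sum(dim=1).clamp(min=1)
encoding = linear_encoder(encoded_state)
print(encoding.shape)                                                   # torch.Size([4, 16])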
