|
|
|
|
|
|
sensor_specs: List[SensorSpec], |
|
|
|
network_settings: NetworkSettings, |
|
|
|
action_spec: ActionSpec, |
|
|
|
baseline: bool = False |
|
|
|
): |
|
|
|
super().__init__() |
|
|
|
self.normalize = network_settings.normalize |
|
|
|
|
|
|
self.m_size = (
    network_settings.memory.memory_size
    if network_settings.memory is not None
    else 0
)
|
|
|
self.is_baseline = baseline
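# ModelUtils.create_input_processors builds one encoder per sensor spec and returns
# the per-encoder output sizes in _input_size, which are used below to size the
# attention entities.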
|
|
|
|
|
|
|
|
|
|
|
self.processors, _input_size = ModelUtils.create_input_processors( |
|
|
|
sensor_specs, |
|
|
|
self.h_size, |
|
|
|
|
|
|
|
|
|
|
# Modules for self-attention |
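# Each entity fed to the self-attention represents one agent: an observation-only
# entity is that agent's concatenated observation encodings, while a Q-path entity
# also includes the agent's flattened discrete and continuous actions.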
|
|
|
obs_only_ent_size = sum(_input_size) |
|
|
|
q_ent_size = ( |
|
|
|
sum(_input_size) |
|
|
|
+ sum(self.action_spec.discrete_branches) |
|
|
|
+ self.action_spec.continuous_size |
|
|
|
) |
|
|
|
|
|
|
self.obs_action_encoder = EntityEmbedding( |
|
|
|
0, q_ent_size, None, self.h_size, concat_self=False |
|
|
|
) |
|
|
|
|
|
|
|
self.self_encoder = None |
|
|
|
|
|
|
|
if baseline: |
|
|
|
self.self_encoder = LinearEncoder( |
|
|
|
obs_only_ent_size, 1, self.h_size |
|
|
|
) |
|
|
|
|
|
|
|
self.obs_encoder = EntityEmbedding( |
|
|
|
self.h_size, obs_only_ent_size, None, self.h_size, concat_self=True |
|
|
|
) |
|
|
|
|
|
|
|
else: |
|
|
|
self.obs_encoder = EntityEmbedding( |
|
|
|
0, obs_only_ent_size, None, self.h_size, concat_self=False |
|
|
|
) |
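# With baseline=True the agent's own observations get a dedicated encoder
# (self_encoder) and every other entity is embedded together with that self
# encoding (concat_self=True); otherwise all entities are embedded uniformly.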
|
|
|
|
|
|
|
self.linear_encoder = LinearEncoder( |
|
|
|
encoder_input_size, network_settings.num_layers, self.h_size |
|
|
|
|
|
|
attn_mask = torch.any(obs_for_mask.isnan(), dim=2).type(torch.FloatTensor) |
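# Observation slots belonging to padded / absent agents are filled with NaN, so a
# NaN anywhere along the feature axis marks an entity the self-attention should ignore.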
|
|
|
return attn_mask |
|
|
|
|
|
|
|
def q_net( |
|
|
|
self, |
|
|
|
obs: List[List[torch.Tensor]], |
|
|
|
actions: List[AgentAction], |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
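# Builds one (observations + actions) entity per agent, derives the attention mask
# from NaN-padded entries, and runs the entities through self-attention to produce
# the Q-network encoding.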
|
|
|
|
|
|
|
self_attn_masks = [] |
|
|
|
concat_f_inp = [] |
|
|
|
for inputs, action in zip(obs, actions): |
|
|
|
encodes = [] |
|
|
|
for idx, processor in enumerate(self.processors): |
|
|
|
obs_input = inputs[idx] |
|
|
|
obs_input[obs_input.isnan()] = 0.0 # Remove NaNs |
|
|
|
processed_obs = processor(obs_input) |
|
|
|
encodes.append(processed_obs) |
|
|
|
cat_encodes = [ |
|
|
|
torch.cat(encodes, dim=-1), |
|
|
|
action.to_flat(self.action_spec.discrete_branches), |
|
|
|
] |
|
|
|
concat_f_inp.append(torch.cat(cat_encodes, dim=1)) |
|
|
|
|
|
|
|
f_inp = torch.stack(concat_f_inp, dim=1) |
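# One (observation + action) entity per agent, stacked along dim=1.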
|
|
|
self_attn_masks.append(self._get_masks_from_nans(obs)) |
|
|
|
encoding, memories = self.forward( |
|
|
|
f_inp, |
|
|
|
None,
None,
|
|
|
self_attn_masks, |
|
|
|
memories=memories, |
|
|
|
sequence_length=sequence_length, |
|
|
|
) |
|
|
|
return encoding, memories |
|
|
|
|
|
|
|
actions: List[AgentAction], |
|
|
|
|
|
|
|
f_inp = None |
|
|
|
concat_f_inp = [] |
|
|
|
for inputs, action in zip(obs, actions): |
|
|
|
encodes = [] |
|
|
|
for idx, processor in enumerate(self.processors): |
|
|
|
obs_input = inputs[idx] |
|
|
|
obs_input[obs_input.isnan()] = 0.0 # Remove NaNs |
|
|
|
processed_obs = processor(obs_input) |
|
|
|
encodes.append(processed_obs) |
|
|
|
cat_encodes = [ |
|
|
|
torch.cat(encodes, dim=-1), |
|
|
|
action.to_flat(self.action_spec.discrete_branches), |
|
|
|
] |
|
|
|
concat_f_inp.append(torch.cat(cat_encodes, dim=1)) |
|
|
|
|
|
|
|
if concat_f_inp: |
|
|
|
f_inp = torch.stack(concat_f_inp, dim=1) |
|
|
|
concat_encoded_obs = [] |
|
|
|
g_inp = None |
|
|
|
if len(obs) > 0: |
|
|
|
for inputs in obs: |
|
|
|
encodes = [] |
|
|
|
for idx, processor in enumerate(self.processors): |
|
|
|
obs_input = inputs[idx] |
|
|
|
obs_input[obs_input.isnan()] = 0.0 # Remove NaNs |
|
|
|
processed_obs = processor(obs_input) |
|
|
|
encodes.append(processed_obs) |
|
|
|
concat_encoded_obs.append(torch.cat(encodes, dim=-1)) |
|
|
|
g_inp = torch.stack(concat_encoded_obs, dim=1) |
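# Observation-only entities, stacked along the same entity axis as the
# observation-and-action entities above.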
|
|
|
self_encodes = []
for idx, processor in enumerate(self.processors):
    obs_input = self_obs[idx]
    obs_input[obs_input.isnan()] = 0.0  # Remove NaNs
    processed_obs = processor(obs_input)
    self_encodes.append(processed_obs)
|
|
|
|
|
|
|
encoded_self = self.self_encoder(torch.cat(self_encodes, dim=-1)) |
|
|
|
self_attn_masks.append(self._get_masks_from_nans([self_obs])) |
|
|
|
encoding, memories = self.forward(
    f_inp, encoded_self, g_inp, self_attn_masks,
    memories=memories, sequence_length=sequence_length,
)
|
|
|
|
|
|
|
return encoding, memories |
|
|
|
|
|
|
|
def value( |
|
|
|
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
f_enc: torch.Tensor, |
|
|
|
encoded_self: torch.Tensor, |
|
|
|
g_enc: torch.Tensor, |
|
|
|
self_attn_masks: List[torch.Tensor], |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
|
|
|
self_attn_inputs = [] |
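# Gather entity encodings for self-attention. Which encoders contribute depends on
# whether obs+action entities (f_enc), obs-only entities (g_enc) and a separate self
# encoding (encoded_self) were provided by the caller.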
|
|
|
|
|
|
|
if f_enc is not None:
    self_attn_inputs.append(self.obs_action_encoder(None, f_enc))
if g_enc is not None:
    if self.is_baseline:
        # Baseline path: condition every entity on this agent's own encoding.
        self_attn_inputs.append(self.obs_encoder(encoded_self, g_enc))
    else:
        self_attn_inputs.append(self.obs_encoder(None, g_enc))

if self_attn_inputs:
    encoded_entity = torch.cat(self_attn_inputs, dim=1)
    encoded_state = self.self_attn(encoded_entity, self_attn_masks)
    inputs = encoded_state
else:
    # Nothing to attend over: fall back to the agent's own encoding.
    inputs = encoded_self

encoding = self.linear_encoder(inputs)
|
|
|
|
|
|
|
if self.use_lstm: |
|
|
|
|
|
|
network_settings: NetworkSettings, |
|
|
|
action_spec: ActionSpec, |
|
|
|
outputs_per_stream: int = 1, |
|
|
|
baseline: bool = False |
|
|
|
observation_shapes, network_settings, action_spec=action_spec |
|
|
|
observation_shapes, network_settings, action_spec=action_spec, baseline=baseline |
|
|
|
) |
|
|
|
if network_settings.memory is not None: |
|
|
|
encoding_size = network_settings.memory.memory_size // 2 |
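# The stored memory holds both the LSTM hidden and cell state, so the encoding fed
# to the value heads is half of the configured memory size.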
|
|
|
|
|
|
|
|
|
|
def q_net( |
|
|
|
self, |
|
|
|
obs: List[List[torch.Tensor]], |
|
|
|
actions: List[AgentAction], |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
|
encoding, memories = self.network_body.q_net( |
|
|
|
obs, actions, memories, sequence_length |
|
|
|
) |
|
|
|
output = self.value_heads(encoding) |
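# value_heads produces one estimate per reward stream.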
|
|
|
return output, memories |
|
|
|
|
|
|
|
def value( |
|
|
|
self, |
|
|
|
obs: List[List[torch.Tensor]], |
|
|
|
|
|
|
self, |
|
|
|
self_obs: List[List[torch.Tensor]], |
|
|
|
obs: List[List[torch.Tensor]], |
|
|
|
actions: List[AgentAction], |
|
|
|
self_obs, obs, actions, memories, sequence_length |
|
|
|
self_obs, obs, memories, sequence_length |
|
|
|
) |
|
|
|
output = self.value_heads(encoding) |
|
|
|
return output, memories |
|
|
|
|
|
|
value_inputs: List[List[torch.Tensor]], |
|
|
|
q_inputs: List[List[torch.Tensor]], |
|
|
|
q_actions: List[AgentAction] = None, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
|
|
|
self.critic = CentralizedValueNetwork( |
|
|
|
stream_names, sensor_specs, network_settings, action_spec=action_spec |
|
|
|
) |
|
|
|
self.target = CentralizedValueNetwork( |
|
|
|
stream_names, sensor_specs, network_settings, action_spec=action_spec
)
|
|
|
self.baseline_critic = CentralizedValueNetwork( |
|
|
|
stream_names, sensor_specs, network_settings, action_spec=action_spec, baseline=True |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
actor_mem = None |
|
|
|
return actor_mem, critic_mem |
|
|
|
|
|
|
|
def target_critic_value( |
|
|
|
self, |
|
|
|
inputs: List[torch.Tensor], |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
team_obs: List[List[torch.Tensor]] = None, |
|
|
|
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
actor_mem, critic_mem = self._get_actor_critic_mem(memories) |
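# Memories are laid out as [actor | critic]; only the critic half is consumed here
# and the actor half is passed back out unchanged.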
|
|
|
|
|
|
|
all_obs = [inputs] |
|
|
|
if team_obs is not None and team_obs: |
|
|
|
all_obs.extend(team_obs) |
|
|
|
|
|
|
|
value_outputs, critic_mem_out = self.target.value(
|
|
|
all_obs, |
|
|
|
memories=critic_mem, |
|
|
|
sequence_length=sequence_length, |
|
|
|
) |
|
|
|
|
|
|
|
# if mar_value_outputs is None: |
|
|
|
# mar_value_outputs = value_outputs |
|
|
|
|
|
|
|
if actor_mem is not None: |
|
|
|
# Make memories with the actor mem unchanged |
|
|
|
memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1) |
|
|
|
else: |
|
|
|
memories_out = None |
|
|
|
return value_outputs, memories_out |
|
|
|
|
|
|
|
def critic_value( |
|
|
|
self, |
|
|
|
inputs: List[torch.Tensor], |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
team_obs: List[List[torch.Tensor]] = None, |
|
|
|
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
actor_mem, critic_mem = self._get_actor_critic_mem(memories) |
|
|
|
|
|
|
|
all_obs = [inputs] |
|
|
|
if team_obs is not None and team_obs: |
|
|
|
all_obs.extend(team_obs) |
|
|
|
|
|
|
|
value_outputs, critic_mem_out = self.critic.value(
|
|
|
all_obs, |
|
|
|
memories=critic_mem, |
|
|
|
sequence_length=sequence_length, |
|
|
|
) |
|
|
|
|
|
|
|
# if mar_value_outputs is None: |
|
|
|
# mar_value_outputs = value_outputs |
|
|
|
|
|
|
|
if actor_mem is not None: |
|
|
|
# Make memories with the actor mem unchanged |
|
|
|
memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1) |
|
|
|
else: |
|
|
|
memories_out = None |
|
|
|
return value_outputs, memories_out |
|
|
|
|
|
|
|
def target_critic_pass( |
|
|
|
self, |
|
|
|
inputs: List[torch.Tensor], |
|
|
|
actions: AgentAction, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
team_obs: List[List[torch.Tensor]] = None, |
|
|
|
team_act: List[AgentAction] = None, |
|
|
|
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
actor_mem, critic_mem = self._get_actor_critic_mem(memories) |
|
|
|
|
|
|
|
all_obs = [inputs] |
|
|
|
if team_obs is not None and team_obs: |
|
|
|
all_obs.extend(team_obs) |
|
|
|
all_acts = [actions] |
|
|
|
if team_act is not None and team_act: |
|
|
|
all_acts.extend(team_act) |
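# The Q estimate below conditions on every agent's observations and actions, while
# the baseline conditions on this agent's observations but only the teammates'
# actions.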
|
|
|
|
|
|
|
baseline_outputs, _ = self.target.baseline( |
|
|
|
inputs, |
|
|
|
team_obs, |
|
|
|
team_act, |
|
|
|
memories=critic_mem, |
|
|
|
sequence_length=sequence_length, |
|
|
|
) |
|
|
|
|
|
|
|
value_outputs, critic_mem_out = self.target.q_net( |
|
|
|
all_obs, all_acts, memories=critic_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
|
|
|
|
# if mar_value_outputs is None: |
|
|
|
# mar_value_outputs = value_outputs |
|
|
|
|
|
|
|
|
|
|
def critic_pass( |
|
|
|
self, |
|
|
|
inputs: List[torch.Tensor], |
|
|
|
actions: AgentAction, |
|
|
|
team_act: List[AgentAction] = None, |
|
|
|
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
actor_mem, critic_mem = self._get_actor_critic_mem(memories) |
|
|
|
|
|
|
|
|
|
|
all_acts = [actions] |
|
|
|
if team_act is not None and team_act: |
|
|
|
all_acts.extend(team_act) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
baseline_outputs, _ = self.baseline_critic.baseline( |
|
|
|
team_act, |
|
|
|
value_outputs, critic_mem_out = self.critic.q_net( |
|
|
|
all_obs, all_acts, memories=critic_mem, sequence_length=sequence_length |
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
# if mar_value_outputs is None: |
|
|
|
|
|
|
) |
|
|
|
log_probs, entropies = self.action_model.evaluate(encoding, masks, actions) |
|
|
|
|
|
|
|
q_outputs, baseline_outputs, _ = self.critic_pass( |
|
|
|
|
|
|
actions, |
|
|
|
team_act=team_act, |
|
|
|
value_outputs, _ = self.target_critic_value(
    inputs, memories=critic_mem, sequence_length=sequence_length, team_obs=team_obs
)
|
|
|
return log_probs, entropies, q_outputs, baseline_outputs, value_outputs |
|
|
|
|
|
|
|
|
|
|
def get_action_stats( |
|
|
|
self, |
|
|
|