|
|
|
|
|
|
sensor_specs: List[SensorSpec], |
|
|
|
network_settings: NetworkSettings, |
|
|
|
action_spec: ActionSpec, |
|
|
|
num_obs_heads: int = 1, |
|
|
|
): |
|
|
|
super().__init__() |
|
|
|
self.normalize = network_settings.normalize |
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
stream_names: List[str], |
|
|
|
observation_shapes: List[Tuple[int, ...]], |
|
|
|
observation_shapes: List[SensorSpec], |
|
|
|
encoded_act_size: int = 0, |
|
|
|
action_spec: ActionSpec, |
|
|
|
observation_shapes, network_settings, encoded_act_size=encoded_act_size |
|
|
|
observation_shapes, network_settings, action_spec=action_spec |
|
|
|
) |
|
|
|
if network_settings.memory is not None: |
|
|
|
encoding_size = network_settings.memory.memory_size // 2 |
|
|
|
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
inputs: List[List[torch.Tensor]], |
|
|
|
actions: Optional[torch.Tensor] = None, |
|
|
|
value_inputs: List[List[torch.Tensor]], |
|
|
|
q_inputs: List[List[torch.Tensor]], |
|
|
|
q_actions: List[AgentAction] = None, |
|
|
|
inputs, actions, memories, sequence_length |
|
|
|
value_inputs, q_inputs, q_actions, memories, sequence_length |
|
|
|
) |
|
|
|
output = self.value_heads(encoding) |
|
|
|
return output, memories |
|
|
|
|
|
|
) |
|
|
|
self.stream_names = stream_names |
|
|
|
self.critic = CentralizedValueNetwork( |
|
|
|
stream_names, sensor_specs, network_settings |
|
|
|
stream_names, sensor_specs, network_settings, action_spec=action_spec |
|
|
|
) |
|
|
|
|
|
|
|
@property |
|
|
|
|
|
|
def critic_pass( |
|
|
|
self, |
|
|
|
inputs: List[torch.Tensor], |
|
|
|
actions: AgentAction, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
critic_obs: List[List[torch.Tensor]] = None, |
|
|
|
|
|
|
if critic_obs is not None and critic_obs: |
|
|
|
all_net_inputs.extend(critic_obs) |
|
|
|
mar_value_outputs, _ = self.critic( |
|
|
|
critic_obs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
else: |
|
|
|
mar_value_outputs = None |
|
|
|
mar_value_outputs, _ = self.critic( |
|
|
|
all_net_inputs, [], [], memories=critic_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
all_net_inputs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
critic_obs, [inputs], [actions], memories=critic_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
if mar_value_outputs is None: |
|
|
|
mar_value_outputs = value_outputs |
|
|
|
|
|
|
all_net_inputs = [inputs] |
|
|
|
if critic_obs is not None and critic_obs: |
|
|
|
all_net_inputs.extend(critic_obs) |
|
|
|
mar_value_outputs, _ = self.critic( |
|
|
|
critic_obs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
critic_obs = [] |
|
|
|
mar_value_outputs, _ = self.critic( |
|
|
|
all_net_inputs, [], [], memories=critic_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
|
|
|
|
all_net_inputs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
[inputs], critic_obs, actions, memories=critic_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
|
|
|
|
return log_probs, entropies, value_outputs, mar_value_outputs |
|
|
|