|
|
|
|
|
|
return encoding, memories |
|
|
|
|
|
|
|
|
|
|
|
class ValueNetwork(nn.Module): |
|
|
|
class Critic(abc.ABC):
    """Interface for networks that produce value estimates per reward stream."""

    @abc.abstractmethod
    def update_normalization(self, buffer: AgentBuffer) -> None:
        """
        Update the normalization using the contents of ``buffer``.
        :param buffer: AgentBuffer holding the data the normalization is
            updated from (presumably the vector observations -- confirm
            against the concrete implementations).
        """

    def critic_pass(
        self,
        inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """
        Get value outputs for the given obs.
        NOTE: this method is not decorated as abstract; the base version has
        an empty body and therefore returns None. Implementations override it.
        :param inputs: List of inputs as tensors.
        :param memories: Tensor of memories, if using memory. Otherwise, None.
        :param sequence_length: Sequence length passed along with the inputs
            (defaults to 1).
        :returns: Tuple of (Dict of reward stream to output tensor for values,
            output memories tensor).
        """
|
|
|
|
|
|
|
|
|
|
|
class ValueNetwork(nn.Module, Critic): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
stream_names: List[str], |
|
|
|
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
""" |
|
|
|
Get value outputs for the given obs. |
|
|
|
:param inputs: List of inputs as tensors. |
|
|
|
:param memories: Tensor of memories, if using memory. Otherwise, None. |
|
|
|
:returns: Dict of reward stream to output tensor for values. |
|
|
|
""" |
|
|
|
value_outputs, critic_mem_out = self.forward( |
|
|
|
inputs, memories=memories, sequence_length=sequence_length |
|
|
|
) |
|
|
|
|
|
|
self.act_size_vector_deprecated, |
|
|
|
] |
|
|
|
return tuple(export_out) |
|
|
|
|
|
|
|
|
|
|
|
class SharedActorCritic(SimpleActor, Critic):
    """
    Actor and critic in one module: the value heads are computed from the
    encoding produced by the actor's own ``network_body``.
    """

    def __init__(
        self,
        observation_specs: List[ObservationSpec],
        network_settings: NetworkSettings,
        action_spec: ActionSpec,
        stream_names: List[str],
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        """
        Build the actor via the parent initializer, then attach value heads.
        :param observation_specs: Observation specs forwarded to the parent.
        :param network_settings: Network settings forwarded to the parent;
            ``network_settings.memory`` decides whether ``use_lstm`` is set.
        :param action_spec: Action spec forwarded to the parent.
        :param stream_names: Names of the reward streams, one value head each.
        :param conditional_sigma: Forwarded to the parent initializer.
        :param tanh_squash: Forwarded to the parent initializer.
        """
        # Kept ahead of the parent initializer, exactly as in the original
        # ordering; the parent may read or overwrite this flag.
        has_memory = network_settings.memory is not None
        self.use_lstm = has_memory
        super().__init__(
            observation_specs,
            network_settings,
            action_spec,
            conditional_sigma,
            tanh_squash,
        )
        self.stream_names = stream_names
        # encoding_size is only available once the parent has been initialized.
        self.value_heads = ValueHeads(stream_names, self.encoding_size)

    def critic_pass(
        self,
        inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        """
        Get value outputs for the given obs using the shared network body.
        :param inputs: List of inputs as tensors.
        :param memories: Tensor of memories, if using memory. Otherwise, None.
        :param sequence_length: Sequence length forwarded to the network body.
        :returns: Tuple of (value head outputs for the encoding, memories
            returned by the network body).
        """
        hidden, next_memories = self.network_body(
            inputs, memories=memories, sequence_length=sequence_length
        )
        value_estimates = self.value_heads(hidden)
        return value_estimates, next_memories
|
|
|
|
|
|
|
|
|
|
|
class GlobalSteps(nn.Module): |
|
|
|