|
|
|
|
|
|
|
|
|
|
class Actor(abc.ABC): |
|
|
|
@abc.abstractmethod |
|
|
|
def update_normalization(self, vector_obs: List[torch.Tensor]) -> None: |
|
|
|
def update_normalization(self, net_inputs: List[torch.Tensor]) -> None: |
|
|
|
""" |
|
|
|
Updates normalization of Actor based on the provided List of vector obs. |
|
|
|
:param vector_obs: A List of vector obs as tensors. |
|
|
|
|
|
|
@abc.abstractmethod |
|
|
|
def get_dists( |
|
|
|
self, |
|
|
|
vec_inputs: List[torch.Tensor], |
|
|
|
vis_inputs: List[torch.Tensor], |
|
|
|
net_inputs: List[torch.Tensor], |
|
|
|
masks: Optional[torch.Tensor] = None, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
|
|
|
@abc.abstractmethod |
|
|
|
def forward( |
|
|
|
self, |
|
|
|
vec_inputs: List[torch.Tensor], |
|
|
|
vis_inputs: List[torch.Tensor], |
|
|
|
net_inputs: List[torch.Tensor], |
|
|
|
masks: Optional[torch.Tensor] = None, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
) -> Tuple[torch.Tensor, int, int, int, int]: |
|
|
|
|
|
|
@abc.abstractmethod |
|
|
|
def critic_pass( |
|
|
|
self, |
|
|
|
vec_inputs: List[torch.Tensor], |
|
|
|
vis_inputs: List[torch.Tensor], |
|
|
|
net_inputs: List[torch.Tensor], |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
|
|
|
@abc.abstractmethod |
|
|
|
def get_dist_and_value( |
|
|
|
self, |
|
|
|
vec_inputs: List[torch.Tensor], |
|
|
|
vis_inputs: List[torch.Tensor], |
|
|
|
net_inputs: List[torch.Tensor], |
|
|
|
masks: Optional[torch.Tensor] = None, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
|
|
|
def memory_size(self) -> int:
    """Return the memory size reported by this object's ``network_body``."""
    return self.network_body.memory_size
|
|
|
|
|
|
|
def update_normalization(self, net_inputs: List[torch.Tensor]) -> None:
    """
    Forward the given inputs to ``network_body.update_normalization``.

    :param net_inputs: A List of input tensors, passed through unchanged.
    """
    # Only one definition is kept: an earlier duplicate of this method (with
    # the parameter named ``obs``) was shadowed by this one and never ran.
    self.network_body.update_normalization(net_inputs)
|
|
|
|
|
|
|
def sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]: |
|
|
|
actions = [] |
|
|
|
|
|
|
|
|
|
|
def critic_pass( |
|
|
|
self, |
|
|
|
vec_inputs: List[torch.Tensor], |
|
|
|
vis_inputs: List[torch.Tensor], |
|
|
|
net_inputs: List[torch.Tensor], |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: |
|
|
|
|
|
|
actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1) |
|
|
|
value_outputs, critic_mem_out = self.critic( |
|
|
|
vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
net_inputs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
) |
|
|
|
if actor_mem is not None: |
|
|
|
# Make memories with the actor mem unchanged |
|
|
|
|
|
|
critic_mem = None |
|
|
|
actor_mem = None |
|
|
|
dists, actor_mem_outs = self.get_dists( |
|
|
|
net_inputs, |
|
|
|
memories=actor_mem, |
|
|
|
sequence_length=sequence_length, |
|
|
|
masks=masks, |
|
|
|
net_inputs, memories=actor_mem, sequence_length=sequence_length, masks=masks |
|
|
|
) |
|
|
|
value_outputs, critic_mem_outs = self.critic( |
|
|
|
net_inputs, memories=critic_mem, sequence_length=sequence_length |
|
|
|
|
|
|
mem_out = None |
|
|
|
return dists, value_outputs, mem_out |
|
|
|
|
|
|
|
def update_normalization(self, net_inputs: List[torch.Tensor]) -> None:
    """
    Update normalization on the parent class and on the critic's network body.

    The same inputs are first handed to the parent implementation and then to
    ``self.critic.network_body.update_normalization``.

    :param net_inputs: A List of input tensors, passed through unchanged.
    """
    # Only one definition is kept: an earlier duplicate of this method (with
    # the parameter named ``vector_obs``) was shadowed by this one.
    super().update_normalization(net_inputs)
    self.critic.network_body.update_normalization(net_inputs)
|
|
|
|
|
|
|
|
|
|
|
class GlobalSteps(nn.Module): |
|
|
|