        """
        pass

    @abc.abstractmethod
    def get_action_stats(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]:
        """
        Returns sampled actions.
        If memory is enabled, return the memories as well.
        :param vec_inputs: A List of vector inputs as tensors.
        :param vis_inputs: A List of visual inputs as tensors.
        :param masks: If using discrete actions, a Tensor of action masks.
        :param memories: If using memory, a Tensor of initial memories.
        :param sequence_length: If using memory, the sequence length.
        :return: A Tuple of AgentAction, ActionLogProbs, entropies, and memories.
            Memories will be None if not using memory.
        """
        pass
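
    # Usage sketch (illustrative, not part of the original file): sampling one
    # step from a concrete Actor subclass. `actor`, `vec_obs`, and `vis_obs`
    # are assumed to exist; `mems` may be None when memory is disabled.
    #
    #     action, log_probs, entropies, mems = actor.get_action_stats(
    #         vec_obs, vis_obs, masks=None, memories=mems, sequence_length=1
    #     )
    #     # `action` is an AgentAction; `mems` is fed back in on the next step.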

    @abc.abstractmethod
    def forward(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[
        AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
    ]:
        """
        Returns sampled actions and value estimates.
        If memory is enabled, return the memories as well.
        :param vec_inputs: A List of vector inputs as tensors.
        :param vis_inputs: A List of visual inputs as tensors.
        :param masks: If using discrete actions, a Tensor of action masks.
        :param memories: If using memory, a Tensor of initial memories.
        :param sequence_length: If using memory, the sequence length.
        :return: A Tuple of AgentAction, ActionLogProbs, entropies, a Dict of
            value estimates, and memories. Memories will be None if not using memory.
        """
        pass

    def update_normalization(self, vector_obs: List[torch.Tensor]) -> None:
        self.network_body.update_normalization(vector_obs)

    def get_action_stats(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]:
        # Encode observations (and roll the recurrent state forward if memory
        # is enabled).
        encoding, memories = self.network_body(
            vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
        )
        # Sample actions, with their log-probs and entropies, from the encoding.
        action, log_probs, entropies = self.action_model(encoding, masks)
        return action, log_probs, entropies, memories
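
    # The two-stage split above is the pattern the rest of this file builds
    # on: the network body maps raw observations to a flat encoding, and the
    # action model maps that encoding to a sampled action with its statistics.
    # Hypothetical shapes, for a batch of 32 and a hidden size of 128:
    #
    #     encoding, _ = self.network_body(vec_inputs, vis_inputs)   # [32, 128]
    #     action, log_probs, entropies = self.action_model(encoding, masks)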

    @property
    def memory_size(self) -> int:
        # The combined memory is the actor's memories followed by the critic's.
        return self.network_body.memory_size + self.critic.memory_size

    def _get_actor_critic_mem(
        self, memories: Optional[torch.Tensor] = None
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        if self.use_lstm and memories is not None:
            # The front half of the memories belongs to the actor, the back
            # half to the critic.
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
        else:
            critic_mem = None
            actor_mem = None
        return actor_mem, critic_mem
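
    # Illustrative sketch (not from the original file) of the layout that
    # _get_actor_critic_mem assumes: a flat memory tensor whose first half is
    # the actor's and whose second half is the critic's, so a split followed
    # by a concat is lossless. `half` is a hypothetical per-network size.
    #
    #     import torch
    #     half = 8
    #     memories = torch.cat([torch.zeros(1, half), torch.ones(1, half)], dim=-1)
    #     actor_mem, critic_mem = torch.split(memories, half, dim=-1)
    #     assert torch.equal(torch.cat([actor_mem, critic_mem], dim=-1), memories)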

    def critic_pass(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        actor_mem, critic_mem = self._get_actor_critic_mem(memories)
        value_outputs, critic_mem_out = self.critic(
            vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if actor_mem is not None:
            # Recombine the memories, leaving the actor half unchanged.
            memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
        else:
            memories_out = None
        return value_outputs, memories_out
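
    # Usage sketch (illustrative): critic_pass is what a trainer would call to
    # bootstrap returns at the end of a trajectory. The returned dict maps a
    # value-head name to its estimate; "extrinsic" below is a hypothetical key.
    #
    #     value_outputs, _ = actor_critic.critic_pass(vec_obs, vis_obs, memories=mems)
    #     last_value = value_outputs["extrinsic"]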

    def get_stats_and_value(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        actions: AgentAction,
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]:
        actor_mem, critic_mem = self._get_actor_critic_mem(memories)
        encoding, actor_mem_outs = self.network_body(
            vec_inputs, vis_inputs, memories=actor_mem, sequence_length=sequence_length
        )
        # Evaluate the given actions under the current policy, then get values
        # from the separate critic.
        log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
        value_outputs, critic_mem_outs = self.critic(
            vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        return log_probs, entropies, value_outputs
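
    # Usage sketch (illustrative): during a policy update, actions stored in
    # the buffer are re-evaluated to get fresh log-probs, entropies, and value
    # estimates. `buffer_actions` is a hypothetical AgentAction rebuilt from
    # saved data.
    #
    #     log_probs, entropies, values = actor_critic.get_stats_and_value(
    #         vec_obs, vis_obs, buffer_actions, masks=masks,
    #         memories=mems, sequence_length=seq_len,
    #     )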

    def get_action_stats(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]:
        actor_mem, critic_mem = self._get_actor_critic_mem(memories)
        action, log_probs, entropies, actor_mem_out = super().get_action_stats(
            vec_inputs,
            vis_inputs,
            masks=masks,
            memories=actor_mem,
            sequence_length=sequence_length,
        )
        if critic_mem is not None:
            # Recombine the memories, leaving the critic half unchanged.
            memories_out = torch.cat([actor_mem_out, critic_mem], dim=-1)
        else:
            memories_out = None
        return action, log_probs, entropies, memories_out
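
    # Usage sketch (illustrative): with an LSTM enabled, the combined memory
    # tensor is threaded through successive calls; the actor half is updated
    # here while the critic half is updated by critic_pass.
    #
    #     mems = torch.zeros(1, actor_critic.memory_size)
    #     for vec_obs, vis_obs in episode:  # hypothetical iterable of obs
    #         action, log_probs, entropies, mems = actor_critic.get_action_stats(
    #             vec_obs, vis_obs, memories=mems, sequence_length=1
    #         )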

    def get_action_stats_and_value(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[
        AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
    ]:
        actor_mem, critic_mem = self._get_actor_critic_mem(memories)
        encoding, actor_mem_outs = self.network_body(
            vec_inputs, vis_inputs, memories=actor_mem, sequence_length=sequence_length
        )