|
|
|
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
net_inputs: List[torch.Tensor], |
|
|
|
inputs: List[torch.Tensor], |
|
|
|
net_inputs, actions, memories, sequence_length |
|
|
|
inputs, actions, memories, sequence_length |
|
|
|
) |
|
|
|
output = self.value_heads(encoding) |
|
|
|
return output, memories |
|
|
|
|
|
|
|
|
|
|
def forward( |
|
|
|
self, |
|
|
|
net_inputs: List[List[torch.Tensor]], |
|
|
|
inputs: List[torch.Tensor], |
|
|
|
inputs: List[List[torch.Tensor]], |
|
|
|
actions: Optional[torch.Tensor] = None, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
|
|
|
self.act_size_vector_deprecated, |
|
|
|
] |
|
|
|
return tuple(export_out) |
|
|
|
|
|
|
|
def get_action_stats( |
|
|
|
self, |
|
|
|
net_inputs: List[torch.Tensor], |
|
|
|
masks: Optional[torch.Tensor] = None, |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
) -> Tuple[ |
|
|
|
AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor |
|
|
|
]: |
|
|
|
|
|
|
|
encoding, memories = self.network_body( |
|
|
|
net_inputs, memories=memories, sequence_length=sequence_length |
|
|
|
) |
|
|
|
action, log_probs, entropies = self.action_model(encoding, masks) |
|
|
|
return action, log_probs, entropies, memories |
|
|
|
|
|
|
|
|
|
|
|
class SharedActorCritic(SimpleActor, ActorCritic): |
|
|
|