        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
        diff: bool = False,
    ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]:
        # critic_obs and diff are accepted for interface compatibility with the
        # ActorCritic variants below; they are unused here.
        encoding, memories = self.network_body(
            net_inputs, memories=memories, sequence_length=sequence_length
        )
        action, log_probs, entropies = self.action_model(encoding, masks)
        return action, log_probs, entropies, memories


class SharedActorCritic(SimpleActor, ActorCritic):
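    """
    Actor-critic with a shared network body: the encoding used by the policy's
    action model is also fed to a set of value heads, one per reward stream.
    """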
    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        action_spec: ActionSpec,
        stream_names: List[str],
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        self.use_lstm = network_settings.memory is not None
        super().__init__(
            observation_shapes,
            network_settings,
            action_spec,
            conditional_sigma,
            tanh_squash,
        )
        self.stream_names = stream_names
        self.value_heads = ValueHeads(stream_names, self.encoding_size)

    def critic_pass(
        self,
        net_inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
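        """
        Get value estimates from the shared encoding.

        :return: A dict mapping each reward stream to its value estimate, plus
            the updated memories.
        """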
        encoding, memories_out = self.network_body(
            net_inputs, memories=memories, sequence_length=sequence_length
        )
        return self.value_heads(encoding), memories_out

    def get_stats_and_value(
        self,
        net_inputs: List[torch.Tensor],
        actions: AgentAction,
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]:
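        """
        Evaluate the given actions under the current policy.

        :return: Log-probabilities and entropies of the actions, plus value
            estimates. critic_obs is accepted for interface compatibility with
            SeparateActorCritic and is unused here.
        """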
        encoding, memories = self.network_body(
            net_inputs, memories=memories, sequence_length=sequence_length
        )
        log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
        value_outputs = self.value_heads(encoding)
        return log_probs, entropies, value_outputs

    def get_action_stats_and_value(
        self,
        net_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[
        AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
    ]:
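        """
        Sample an action from the current policy.

        :return: The sampled action, its log-probabilities and entropies, value
            estimates for each reward stream, and the updated memories.
        """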
        encoding, memories = self.network_body(
            net_inputs, memories=memories, sequence_length=sequence_length
        )
        action, log_probs, entropies = self.action_model(encoding, masks)
        value_outputs = self.value_heads(encoding)
        return action, log_probs, entropies, value_outputs, memories


class SeparateActorCritic(SimpleActor, ActorCritic):
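    """
    Actor-critic with separate bodies: the actor uses the SimpleActor network
    body while the critic is an independent ValueNetwork. With LSTM enabled,
    each memory vector is split in half: the front half for the actor and the
    back half for the critic.
    """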
    def __init__(
        self,
        observation_shapes: List[Tuple[int, ...]],
        network_settings: NetworkSettings,
        action_spec: ActionSpec,
        stream_names: List[str],
        conditional_sigma: bool = False,
        tanh_squash: bool = False,
    ):
        self.use_lstm = network_settings.memory is not None
        super().__init__(
            observation_shapes,
            network_settings,
            action_spec,
            conditional_sigma,
            tanh_squash,
        )
        self.stream_names = stream_names
        self.critic = ValueNetwork(stream_names, observation_shapes, network_settings)
        # Alternative centralized critic over both agents' observations:
        # self.critic = CentralizedValueNetwork(
        #     stream_names, observation_shapes, network_settings, num_agents=2
        # )

    @property
    def memory_size(self) -> int:
        return self.network_body.memory_size + self.critic.memory_size

    def critic_pass(
        self,
        net_inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        actor_mem, critic_mem = None, None
        if self.use_lstm:
            # Use only the back half of memories for critic
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
        # Centralized critic input (disabled): extend with other agents' obs.
        # all_net_inputs = [net_inputs]
        # if critic_obs is not None:
        #     all_net_inputs.extend(critic_obs)
        all_net_inputs = net_inputs
        value_outputs, critic_mem_out = self.critic(
            all_net_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if actor_mem is not None:
            # Leave the actor's half of the memories unchanged
            memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1)
        else:
            memories_out = None
        return value_outputs, memories_out

    def get_stats_and_value(
        self,
        net_inputs: List[torch.Tensor],
        actions: AgentAction,
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]:
        actor_mem, critic_mem = None, None
        if self.use_lstm:
            # Use only the back half of memories for critic and actor
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
        encoding, actor_mem_outs = self.network_body(
            net_inputs, memories=actor_mem, sequence_length=sequence_length
        )
        log_probs, entropies = self.action_model.evaluate(encoding, masks, actions)
        # Centralized critic input (disabled): extend with other agents' obs.
        # all_net_inputs = [net_inputs]
        # if critic_obs is not None:
        #     all_net_inputs.extend(critic_obs)
        all_net_inputs = net_inputs
        value_outputs, critic_mem_outs = self.critic(
            all_net_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        return log_probs, entropies, value_outputs

    def get_action_stats_and_value(
        self,
        net_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
        critic_obs: Optional[List[List[torch.Tensor]]] = None,
    ) -> Tuple[
        AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
    ]:
        if self.use_lstm:
            # Use only the back half of memories for critic and actor
            actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, -1)
        else:
            critic_mem = None
            actor_mem = None
        # Centralized critic input (disabled): extend with other agents' obs.
        # all_net_inputs = [net_inputs]
        # if critic_obs is not None:
        #     all_net_inputs.extend(critic_obs)
        all_net_inputs = net_inputs
        encoding, actor_mem_outs = self.network_body(
            net_inputs, memories=actor_mem, sequence_length=sequence_length
        )
        action, log_probs, entropies = self.action_model(encoding, masks)
        value_outputs, critic_mem_outs = self.critic(
            all_net_inputs, memories=critic_mem, sequence_length=sequence_length
        )
        if self.use_lstm:
            mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
        else:
            mem_out = None
        return action, log_probs, entropies, value_outputs, mem_out

    def get_comms(
        self,
        net_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> torch.Tensor:
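        """
        Get the action model's communication output for these observations.
        Note that the updated memories from the network body are discarded.
        """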
        encoding, memories = self.network_body(
            net_inputs, memories=memories, sequence_length=sequence_length
        )
        comms = self.action_model.get_comms(encoding, masks)
        return comms

    def update_normalization(self, net_inputs: List[torch.Tensor]) -> None:
        super().update_normalization(net_inputs)
        self.critic.network_body.update_normalization(net_inputs)
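
# Minimal usage sketch (illustrative only; `policy` and `obs` are assumed to
# be an LSTM-enabled SeparateActorCritic and a list of observation tensors,
# and the zero-initialised memory shape follows the trainer's convention):
#
#   memories = torch.zeros(1, 1, policy.memory_size)
#   action, log_probs, entropies, values, memories = (
#       policy.get_action_stats_and_value(obs, memories=memories)
#   )
#
# The returned memories concatenate the actor's updated front half with the
# critic's updated back half, ready to be fed back in on the next step.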