self.continuous = continuous
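
# Constructor: wires up the reward streams, the critic and twin Q networks, the
# target network, and the per-action-type entropy coefficients used by SAC.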
def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
    reward_signal_configs = trainer_params.reward_signals
    reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]

    self.value_network = ValueNetwork(
        reward_signal_names,
        policy.behavior_spec.observation_specs,
        policy.network_settings,
    )
    super().__init__(policy, self.value_network, trainer_params)

    hyperparameters: SACSettings = cast(SACSettings, trainer_params.hyperparameters)
    self.tau = hyperparameters.tau
    self.init_entcoef = hyperparameters.init_entcoef

    }

    self._action_spec = self.policy.behavior_spec.action_spec
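
    # PolicyValueNetwork bundles the twin Q critics (q1 and q2) that SAC uses for
    # its min-over-two-Q backup, with one output head per reward stream.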
    self.q_network = TorchSACOptimizer.PolicyValueNetwork(
        self.stream_names,
        self.policy.behavior_spec.observation_specs,
        policy_network_settings,
        self._action_spec,
    )

    self.target_network = ValueNetwork(
        self.stream_names,
        self.policy.behavior_spec.observation_specs,
        policy_network_settings,
    )
    ModelUtils.soft_update(self.value_network, self.target_network, 1.0)
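    # soft_update blends parameters as target = tau * source + (1 - tau) * target;
    # with tau = 1.0 this is a hard copy, so the target starts identical to the critic.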

    # We create one entropy coefficient per action, whether discrete or continuous.
    _disc_log_ent_coef = torch.nn.Parameter(

    self.target_entropy = TorchSACOptimizer.TargetEntropy(
        continuous=_cont_target, discrete=_disc_target
    )
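    # Target entropies follow the usual SAC heuristics: roughly -dim(continuous action)
    # for the continuous part and a scaled log(branch size) for each discrete branch.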

    policy_params = list(self.policy.actor.parameters())
    value_params = list(self.q_network.parameters()) + list(
        self.value_network.parameters()
    )

    logger.debug("value_vars")
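
    # Separate optimizers for the policy, the value networks (the Q networks plus
    # the critic), and the entropy coefficients are built from these parameter lists.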

def _move_to_device(self, device: torch.device) -> None:
    self._log_ent_coef.to(device)
    self.target_network.to(device)
    self.value_network.to(device)
    self.q_network.to(device)
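    # Everything the optimizer owns (entropy coefficients, critic, target and Q
    # networks) has to follow the policy onto the training device.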

def sac_q_loss(
    self,
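
# From the update step: rebuild the recurrent memories for each stored sequence
# before computing the SAC losses.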
    offset = 1 if self.policy.sequence_length > 1 else 0
    next_memories_list = [
        ModelUtils.list_to_tensor(
            batch[BufferKey.MEMORY][i]
        )
        for i in range(
            offset, len(batch[BufferKey.MEMORY]), self.policy.sequence_length
        )
    ]

    else:
        memories = None
        next_memories = None
    # Q and V network memories are 0'ed out, since we don't have them during inference.
    q_memories = (
        torch.zeros_like(next_memories) if next_memories is not None else None
    )
    v_memories = (
        torch.zeros_like(next_memories) if next_memories is not None else None
    )
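    # The range offset of 1 makes next_memories line up one step ahead of the
    # current-step memories, giving the target network memories for the next state.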

    self.q_network.q1_network.network_body.copy_normalization(
        self.policy.actor.network_body
    )
    self.q_network.q2_network.network_body.copy_normalization(
        self.policy.actor.network_body
    )
    self.target_network.network_body.copy_normalization(
        self.policy.actor.network_body
    )
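    # Observation-normalization statistics live in the actor's network body; copying
    # them keeps the critic-side networks normalizing inputs exactly like the actor.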
    self.value_network.network_body.copy_normalization(
        self.policy.actor.network_body
    )
    sampled_actions, log_probs, _, _ = self.policy.actor.get_action_and_stats(

    )
    value_estimates, _ = self.value_network.critic_pass(
        current_obs, v_memories, sequence_length=self.policy.sequence_length
    )
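    # sampled_actions come from the current policy (for the policy and entropy
    # losses); their Q values are computed next as q1p_out/q2p_out, while
    # q1_out/q2_out below evaluate the actions actually stored in the buffer.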

    q1p_out, q2p_out = self.q_network(
        current_obs,
        cont_sampled_actions,
        memories=q_memories,

    q1_out, q2_out = self.q_network(
        current_obs,
        cont_actions,
        memories=q_memories,

    self.entropy_optimizer.step()
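    # The entropy coefficients are trained with their own optimizer, letting the
    # temperature adapt toward the target entropy.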

    # Update target network
    ModelUtils.soft_update(self.value_network, self.target_network, self.tau)
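    # Polyak averaging: each update moves the target network a fraction tau toward
    # the live critic. A minimal stand-alone sketch of what ModelUtils.soft_update
    # computes (illustrative, assuming matching parameter order in both modules):
    #
    #     @torch.no_grad()
    #     def polyak_update(source: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    #         for s, t in zip(source.parameters(), target.parameters()):
    #             t.data.mul_(1.0 - tau).add_(tau * s.data)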

    update_stats = {
        "Losses/Policy Loss": policy_loss.item(),
        "Losses/Value Loss": value_loss.item(),
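
    # The dictionary assembled here is what the trainer surfaces as per-update
    # training statistics (e.g. in TensorBoard).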

def get_modules(self):
    modules = {
        "Optimizer:value_network": self.q_network,
        "Optimizer:target_network": self.target_network,
        "Optimizer:policy_optimizer": self.policy_optimizer,
        "Optimizer:value_optimizer": self.value_optimizer,
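        # get_modules tells the checkpointing code which optimizer-owned modules
        # (networks and torch optimizers) to save and restore alongside the policy.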