"Losses/Value Loss": "value_loss", |
|
|
|
"Losses/Policy Loss": "policy_loss", |
|
|
|
"Losses/Model Loss": "model_loss", |
|
|
|
"Losses/Reward Loss": "reward_loss", |
|
|
|
"Policy/Learning Rate": "learning_rate", |
|
|
|
"Policy/Model Learning Rate": "model_learning_rate", |
|
|
|
"Policy/Epsilon": "decay_epsilon", |
|
|
|
|
|
|
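        # update() looks each of these keys up in the values returned by the TF session
        # run; keys that a given update type does not produce are skipped by the
        # membership check in the stats loop of update().
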
        if policy.use_continuous_act:
            self._create_cc_critic(h_size, hyperparameters.value_layers, vis_encode_type)
            self._create_cc_critic_old(h_size, hyperparameters.value_layers, vis_encode_type)
        else:
            self._create_dc_critic(h_size, hyperparameters.value_layers, vis_encode_type)
            self._create_dc_critic_old(h_size, hyperparameters.value_layers, vis_encode_type)
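        # The "_old" critics appear to be a second copy of the value heads that is
        # re-synchronized from the live network via self.policy.run_hard_copy(),
        # which is called in update().
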
with tf.variable_scope("optimizer/"): |
|
|
|
self.learning_rate = ModelUtils.create_schedule( |
|
|
|
|
|
|
"value_loss": self.value_loss, |
|
|
|
"policy_loss": self.abs_policy_loss, |
|
|
|
"model_loss": self.model_loss, |
|
|
|
"reward_loss": self.policy.reward_loss, |
|
|
|
"update_batch": self.update_batch, |
|
|
|
"learning_rate": self.learning_rate, |
|
|
|
"decay_epsilon": self.decay_epsilon, |
|
|
|
|
|
|
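        # The fetches above are evaluated together in one session run per update:
        # "update_batch" is the op that applies the gradient step, while the loss and
        # schedule entries are scalars reported back for logging.
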
        # self.model_loss += self.policy.predict_distribution.kl_standard()
        self.model_loss = self.policy.forward_loss
        if self.predict_return:
            self.model_loss += self.policy.reward_loss
        if self.with_prior:
            if self.use_var_encoder:
                self.model_loss += 0.2 * self.policy.encoder_distribution.kl_standard()
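        # The model loss combines the forward-dynamics prediction loss with an optional
        # reward-prediction term (predict_return) and, when with_prior is set, a KL term
        # pulling the variational state encoder toward a standard normal prior
        # (weighted by 0.2).
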
"learning_rate": self.learning_rate, |
|
|
|
"decay_epsilon": self.decay_epsilon, |
|
|
|
"decay_beta": self.decay_beta, |
|
|
|
"reward_loss": self.policy.reward_loss, |
|
|
|
} |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
"model_learning_rate": self.model_learning_rate, |
|
|
|
"decay_epsilon": self.decay_epsilon, |
|
|
|
"decay_beta": self.decay_beta, |
|
|
|
"reward_loss": self.policy.reward_loss, |
|
|
|
} |
|
|
|
) |
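        # Two separate fetch dicts are closed here: one for policy updates (keyed by
        # "learning_rate") and one for model updates (keyed by "model_learning_rate"),
        # presumably selected by the update_type argument handled in update() below.
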
        :return: Results of update.
        """
        feed_dict = self._construct_feed_dict(batch, num_sequences)
if update_type == "model": |
|
|
|
stats_needed = { |
|
|
|
"Losses/Model Loss": "model_loss", |
|
|
|
"Policy/Learning Rate": "model_learning_rate", |
|
|
|
"Policy/Epsilon": "decay_epsilon", |
|
|
|
"Policy/Beta": "decay_beta", |
|
|
|
} |
|
|
|
elif update_type == "policy": |
|
|
|
stats_needed = { |
|
|
|
"Losses/Value Loss": "value_loss", |
|
|
|
"Losses/Policy Loss": "policy_loss", |
|
|
|
"Policy/Learning Rate": "learning_rate", |
|
|
|
"Policy/Epsilon": "decay_epsilon", |
|
|
|
"Policy/Beta": "decay_beta", |
|
|
|
} |
|
|
|
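        # The assignment below supersedes the per-type selection above: every known
        # stat is requested, and the membership check in the stats loop drops the ones
        # this update did not compute.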
        stats_needed = self.stats_name_to_update_name

# if update_type == "model": |
|
|
|
# stats_needed = { |
|
|
|
# "Losses/Model Loss": "model_loss", |
|
|
|
# "Policy/Learning Rate": "model_learning_rate", |
|
|
|
# "Policy/Epsilon": "decay_epsilon", |
|
|
|
# "Policy/Beta": "decay_beta", |
|
|
|
# } |
|
|
|
# elif update_type == "policy": |
|
|
|
# stats_needed = { |
|
|
|
# "Losses/Value Loss": "value_loss", |
|
|
|
# "Losses/Policy Loss": "policy_loss", |
|
|
|
# "Policy/Learning Rate": "learning_rate", |
|
|
|
# "Policy/Epsilon": "decay_epsilon", |
|
|
|
# "Policy/Beta": "decay_beta", |
|
|
|
# } |
|
|
|
        update_stats = {}
        # Collect feed dicts for all reward signals.
        for _, reward_signal in self.reward_signals.items():
            feed_dict.update(
                reward_signal.prepare_update(self.policy, batch, num_sequences)
            )
        # Assumption: the combined fetch dict is executed here; the original may pick
        # a per-type fetch dict based on update_type.
        update_vals = self._execute_model(feed_dict, self.update_dict)
        self.policy.run_hard_copy()
        for stat_name, update_name in stats_needed.items():
            if update_name in update_vals:
                update_stats[stat_name] = update_vals[update_name]
        self.num_updates += 1
        return update_stats
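    # The returned dict maps stat display names (e.g. "Losses/Model Loss") to float
    # values for this update. Illustrative use (names below are not from this file):
    #     stats = optimizer.update(batch, num_sequences)
    #     for name, value in stats.items():
    #         stats_reporter.add_stat(name, value)
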
            self.policy.processed_vector_next: mini_batch["next_vector_in"],
            # self.policy.next_vector_in: mini_batch["next_vector_in"],
            self.policy.current_action: mini_batch["actions"],
            # Note: current_reward is bound to the raw extrinsic rewards; a duplicate
            # entry for the same key with discounted returns would be silently
            # overridden.
            # self.policy.current_reward: mini_batch["discounted_returns"],
            self.policy.current_reward: mini_batch["extrinsic_rewards"],
            # self.dis_returns: mini_batch["discounted_returns"]
        }
        for name in self.reward_signals:
|