            hyperparameters.load_action,
        )
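
        # SAC keeps target value networks alongside the learned ones; start
        # them off as an exact (hard) copy before training begins.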
        self.policy.run_hard_copy()

        self.num_updates = 0
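
        # Debug aid: list every trainable variable currently in the graph.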
        print("All variables in the graph:")
        for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
            print(variable)

        self.update_dict = {
            "value_loss": self.total_value_loss,
            "policy_loss": self.policy_loss,
            "q1_loss": self.q1_loss,
            "q2_loss": self.q2_loss,
            "entropy_coef": self.ent_coef,
            "learning_rate": self.learning_rate,
        }
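
        # With transfer learning enabled, also report the dynamics-model and
        # reward-model losses alongside the SAC losses.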
        if self.use_transfer:
            self.update_dict.update({
                "model_loss": self.model_loss,
                "model_learning_rate": self.model_learning_rate,
                "reward_loss": self.policy.reward_loss,
            })

    def _create_inputs_and_outputs(self) -> None:
        """
        Assign the model's inputs and outputs to those of its policy or
        target network.
        """
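
        # Policy update: minimize the policy loss over the actor variables only.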
        self.update_batch_policy = policy_optimizer.minimize(
            self.policy_loss, var_list=policy_vars
        )

        print("value trainable:", critic_vars)
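
        # With transfer enabled, the dynamics-model loss is folded into the
        # value objective so both are minimized by a single train op.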
        if self.use_transfer:
            value_loss = self.total_value_loss + self.model_loss
        else:
            value_loss = self.total_value_loss
        # the value_optimizer name is assumed, by analogy with policy_optimizer
        self.update_batch_value = value_optimizer.minimize(
            value_loss, var_list=critic_vars
        )

        # Add entropy coefficient optimization operation (the op body below
        # assumes the standard SAC names entropy_optimizer, entropy_loss, and
        # log_ent_coef)
        with tf.control_dependencies([self.update_batch_value]):
            self.update_batch_entropy = entropy_optimizer.minimize(
                self.entropy_loss, var_list=self.log_ent_coef
            )
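
        # After running the update ops, gather the stats the trainer asked
        # for, keyed by their display names.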
        stats_needed = self.stats_name_to_update_name
        update_stats: Dict[str, float] = {}

        if self.use_transfer:
            update_vals = self._execute_model(feed_dict, self.update_dict)
        else:
            update_vals = self._execute_model(feed_dict, self.model_update_dict)
            update_vals.update(self._execute_model(feed_dict, self.update_dict))
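
        # Copy only the values that were actually fetched this step; the two
        # branches above run different sets of ops.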
        for stat_name, update_name in stats_needed.items():
            if update_name in update_vals.keys():
                update_stats[stat_name] = update_vals[update_name]
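
        # Optionally train the bisimulation encoder on a second minibatch.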
        if self.use_bisim:
            bisim_stats = self.update_encoder(batch, batch_bisim)
            # merging into update_stats is assumed; bisim_stats is otherwise unused
            update_stats.update(bisim_stats)
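
        # Soft-update (Polyak-average) the target networks toward the current
        # value networks, the counterpart of run_hard_copy() at init.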
        self.policy.run_soft_copy()

        self.num_updates += 1

        return update_stats