|
|
|
|
|
|
|
|
|
from typing import Optional, Any, Dict, cast, List, Tuple |
|
|
|
import numpy as np |
|
|
|
import os |
|
|
|
import copy |
|
|
|
|
|
|
from mlagents.trainers.trajectory import SplitObservations |
|
|
|
from mlagents.trainers.policy.tf_policy import TFPolicy |
|
|
|
from mlagents.trainers.components.reward_signals.curiosity.model import CuriosityModel |
|
|
|
from mlagents.trainers.policy.transfer_policy import TransferPolicy |
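# The symbols below (tf, ModelUtils, ScheduleType, EncoderType, AgentBuffer) are used
# later in this module but their imports are missing from this fragment. These paths
# follow the ml-agents TF trainer layout and are assumptions, not part of the original.
from mlagents.tf_utils import tf
from mlagents.trainers.models import ModelUtils, ScheduleType, EncoderType
from mlagents.trainers.buffer import AgentBuffer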
|
|
|
|
|
|
else: |
|
|
|
self._create_dc_critic(h_size, hyperparameters.value_layers, vis_encode_type) |
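# In addition to the online critic above, a separate "target_value" critic is built
# below; its weights only change through the soft-copy op created further down.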
|
|
|
|
|
|
|
with tf.variable_scope("target_value"): |
|
|
|
if policy.use_continuous_act: |
|
|
|
self._create_cc_critic_target(h_size, hyperparameters.value_layers, vis_encode_type) |
|
|
|
else: |
|
|
|
self._create_dc_critic_target(h_size, hyperparameters.value_layers, vis_encode_type) |
|
|
|
|
|
|
|
self._create_soft_critic_copy() |
|
|
|
|
|
|
|
with tf.variable_scope("optimizer/"): |
|
|
|
self.learning_rate = ModelUtils.create_schedule( |
|
|
|
self._schedule, |
|
|
|
|
|
|
min_value=1e-10, |
|
|
|
) |
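# Separate schedules: the PPO learning rate above follows the configured schedule,
# while the dynamics-model (and, presumably, bisimulation) updates below use their own rates.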
|
|
|
self.model_learning_rate = ModelUtils.create_schedule(
    # ScheduleType.LINEAR,
    # self._schedule,
    ScheduleType.CONSTANT,
    lr,
    self.policy.global_step,
    int(max_step),
    min_value=1e-10,
)
|
|
|
|
|
|
# Assumed: the assignment target for this second, smaller constant schedule is missing
# from the fragment; it is used for the bisimulation-encoder updates below.
self.bisim_learning_rate = ModelUtils.create_schedule(
    ScheduleType.CONSTANT,
    lr / 10,
    self.policy.global_step,
    int(max_step),
    min_value=1e-10,
)
|
|
|
|
|
|
self.old_log_probs, |
|
|
|
self.value_heads, |
|
|
|
self.policy.entropy, |
|
|
|
self.policy.targ_encoder,
# self.policy.next_encoder,
|
|
|
self.policy.predict, |
|
|
|
beta, |
|
|
|
epsilon, |
|
|
|
|
|
|
if self.use_transfer: |
|
|
|
self.policy.load_graph_partial(
    self.transfer_path,
    self.transfer_type,
    hyperparameters.load_model,
    hyperparameters.load_policy,
    hyperparameters.load_value,
)
|
|
|
self.run_soft_critic_copy() |
|
|
|
self.policy.get_encoder_weights() |
|
|
|
self.policy.get_policy_weights() |
|
|
|
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
def _create_losses( |
|
|
|
self, probs, old_probs, value_heads, entropy, targ_encoder, predict,
|
|
|
beta, epsilon, lr, max_step |
|
|
|
): |
|
|
|
""" |
|
|
|
|
|
|
# For cleaner stats reporting |
|
|
|
self.abs_policy_loss = tf.abs(self.policy_loss) |
|
|
|
|
|
|
|
# encoder and predict loss |
|
|
|
# self.dis_returns = tf.placeholder( |
|
|
|
# shape=[None], dtype=tf.float32, name="dis_returns" |
|
|
|
# ) |
|
|
|
# target = tf.concat([targ_encoder, tf.expand_dims(self.dis_returns, -1)], axis=1) |
|
|
|
# if self.predict_return: |
|
|
|
# self.model_loss = tf.reduce_mean(tf.squared_difference(predict, target)) |
|
|
|
# else: |
|
|
|
# self.model_loss = tf.reduce_mean(tf.squared_difference(predict, targ_encoder)) |
|
|
|
# if self.with_prior: |
|
|
|
# if self.use_var_encoder: |
|
|
|
# self.model_loss += encoder_distribution.kl_standard() |
|
|
|
# if self.use_var_predict: |
|
|
|
# self.model_loss += self.policy.predict_distribution.kl_standard() |
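# The live model loss below reuses the forward-prediction loss computed inside the
# policy graph, optionally adding half of the reward-prediction loss.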
|
|
|
|
|
|
|
self.model_loss = self.policy.forward_loss |
|
|
|
if self.predict_return: |
|
|
|
self.model_loss += 0.5 * self.policy.reward_loss |
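# Bisimulation objective: the distance between the current and "bisim" encodings is
# regressed onto the discounted difference in predicted next encodings plus the
# absolute difference in predicted rewards (see self.bisim_loss below).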
|
|
|
|
|
|
reward_diff = tf.reduce_mean( |
|
|
|
tf.squared_difference(self.policy.bisim_pred_reward, self.policy.pred_reward) |
|
|
|
) |
|
|
|
predict_diff = self.reward_signals["extrinsic"].gamma * predict_diff + tf.abs(reward_diff)
# predict_diff = 0.99 * predict_diff + tf.abs(reward_diff)
encode_dist = tf.reduce_mean(
    # tf.squared_difference(self.policy.encoder, self.policy.bisim_encoder)
    tf.abs(self.policy.encoder - self.policy.bisim_encoder)
)
|
|
|
self.encode_dist_val = encode_dist |
|
|
|
self.predict_diff_val = predict_diff |
|
|
|
self.bisim_loss = tf.squared_difference(encode_dist, predict_diff) |
|
|
|
|
|
|
|
self.loss = ( |
|
|
|
|
|
|
update_vals = self._execute_model(feed_dict, self.update_dict) |
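# update_dict presumably bundles the PPO and model losses into a single optimizer step;
# the model-only path further down runs the same feed against model_only_update_dict.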
|
|
|
|
|
|
|
# update target encoder |
|
|
|
if not self.reuse_encoder and self.num_updates % self.copy_every == 0:
# if self.num_updates % self.copy_every == 0:
|
|
|
self.run_soft_critic_copy() |
|
|
|
# print("copy") |
|
|
|
# self.policy.get_encoder_weights() |
|
|
|
|
|
|
|
|
|
|
update_vals = self._execute_model(feed_dict, self.model_only_update_dict) |
|
|
|
|
|
|
|
# update target encoder |
|
|
|
if not self.reuse_encoder and self.num_updates % self.copy_every == 0:
# if self.num_updates % self.copy_every == 0:
|
|
|
self.run_soft_critic_copy() |
|
|
|
# print("copy") |
|
|
|
# self.policy.get_encoder_weights() |
|
|
|
self.num_updates += 1 |
|
|
|
return update_stats |
|
|
|
|
|
|
|
def update_encoder(self, mini_batch1: AgentBuffer, mini_batch2: AgentBuffer): |
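"""
Updates the encoder using the bisimulation objective on a pair of mini-batches.
(Summary inferred from the bisim_update_dict usage below.)
"""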
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
update_vals = self._execute_model(feed_dict, self.bisim_update_dict) |
|
|
|
# print("model difference:", self.policy.sess.run(self.predict_diff_val, feed_dict=feed_dict)) |
|
|
|
# print("encoder distance:", self.policy.sess.run(self.encode_dist_val, feed_dict=feed_dict)) |
|
|
|
|
|
|
|
for stat_name, update_name in stats_needed.items(): |
|
|
|
if update_name in update_vals.keys(): |
|
|
|
|
|
|
keepdims=True, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def _create_cc_critic_target( |
|
|
|
self, h_size: int, num_layers: int, vis_encode_type: EncoderType |
|
|
|
) -> None: |
|
|
|
""" |
|
|
|
Creates the target network for the continuous-control critic (value) network.
|
|
|
:param h_size: Size of hidden linear layers. |
|
|
|
:param num_layers: Number of hidden linear layers. |
|
|
|
:param vis_encode_type: The type of visual encoder to use. |
|
|
|
""" |
|
|
|
input_state = self.policy.target_encoder |
|
|
|
|
|
|
|
hidden_value = ModelUtils.create_vector_observation_encoder( |
|
|
|
input_state, |
|
|
|
h_size, |
|
|
|
ModelUtils.swish, |
|
|
|
num_layers, |
|
|
|
scope=f"main_graph", |
|
|
|
reuse=False |
|
|
|
) |
|
|
|
self.target_value_heads, self.target_value = ModelUtils.create_value_heads( |
|
|
|
self.stream_names, hidden_value |
|
|
|
) |
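# Gradients are stopped on the target heads below so they are only updated through
# the soft-copy op, never by the optimizer directly.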
|
|
|
|
|
|
|
for name in self.stream_names: |
|
|
|
self.target_value_heads[name] = tf.stop_gradient(self.target_value_heads[name]) |
|
|
|
|
|
|
|
self.target_all_old_log_probs = tf.placeholder( |
|
|
|
shape=[None, sum(self.policy.act_size)], |
|
|
|
dtype=tf.float32, |
|
|
|
name="old_probabilities", |
|
|
|
) |
|
|
|
|
|
|
|
self.target_old_log_probs = tf.reduce_sum( |
|
|
|
(tf.identity(self.target_all_old_log_probs)), axis=1, keepdims=True
|
|
|
) |
|
|
|
|
|
|
|
def _create_dc_critic_target( |
|
|
|
self, h_size: int, num_layers: int, vis_encode_type: EncoderType |
|
|
|
) -> None: |
|
|
|
""" |
|
|
|
Creates the target network for the discrete-control critic (value) network.
|
|
|
:param h_size: Size of hidden linear layers. |
|
|
|
:param num_layers: Number of hidden linear layers. |
|
|
|
:param vis_encode_type: The type of visual encoder to use. |
|
|
|
""" |
|
|
|
input_state = self.policy.target_encoder |
|
|
|
|
|
|
|
hidden_value = ModelUtils.create_vector_observation_encoder( |
|
|
|
input_state, |
|
|
|
h_size, |
|
|
|
ModelUtils.swish, |
|
|
|
num_layers, |
|
|
|
scope=f"main_graph", |
|
|
|
reuse=False |
|
|
|
) |
|
|
|
self.target_value_heads, self.target_value = ModelUtils.create_value_heads( |
|
|
|
self.stream_names, hidden_value |
|
|
|
) |
|
|
|
for name in self.stream_names: |
|
|
|
self.target_value_heads[name] = tf.stop_gradient(self.target_value_heads[name]) |
|
|
|
# self.target_value[name] = tf.stop_gradient(self.target_value[name]) |
|
|
|
|
|
|
|
self.target_all_old_log_probs = tf.placeholder( |
|
|
|
shape=[None, sum(self.policy.act_size)], |
|
|
|
dtype=tf.float32, |
|
|
|
name="old_probabilities", |
|
|
|
) |
|
|
|
|
|
|
|
# Break old log probs into separate branches |
|
|
|
old_log_prob_branches = ModelUtils.break_into_branches( |
|
|
|
self.target_all_old_log_probs, self.policy.act_size |
|
|
|
) |
|
|
|
|
|
|
|
_, _, old_normalized_logits = ModelUtils.create_discrete_action_masking_layer( |
|
|
|
old_log_prob_branches, self.policy.action_masks, self.policy.act_size |
|
|
|
) |
|
|
|
|
|
|
|
action_idx = [0] + list(np.cumsum(self.policy.act_size)) |
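# action_idx holds the cumulative offsets of each discrete branch inside the
# concatenated action / logit tensors, so each branch can be sliced out below.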
|
|
|
|
|
|
|
self.target_old_log_probs = tf.reduce_sum( |
|
|
|
( |
|
|
|
tf.stack( |
|
|
|
[ |
|
|
|
-tf.nn.softmax_cross_entropy_with_logits_v2( |
|
|
|
labels=self.policy.selected_actions[ |
|
|
|
:, action_idx[i] : action_idx[i + 1] |
|
|
|
], |
|
|
|
logits=old_normalized_logits[ |
|
|
|
:, action_idx[i] : action_idx[i + 1] |
|
|
|
], |
|
|
|
) |
|
|
|
for i in range(len(self.policy.act_size)) |
|
|
|
], |
|
|
|
axis=1, |
|
|
|
) |
|
|
|
), |
|
|
|
axis=1, |
|
|
|
keepdims=True, |
|
|
|
) |
|
|
|
|
|
|
|
def get_trajectory_value_estimates( |
|
|
|
self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool |
|
|
|
) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]: |
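"""
Evaluates the target value heads over a whole trajectory and, in a second pass,
bootstraps value estimates for the final observation.
"""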
|
|
|
feed_dict: Dict[tf.Tensor, Any] = { |
|
|
|
self.policy.batch_size_ph: batch.num_experiences, |
|
|
|
self.policy.sequence_length_ph: batch.num_experiences, # We want to feed data in batch-wise, not time-wise. |
|
|
|
} |
|
|
|
|
|
|
|
if self.policy.vec_obs_size > 0: |
|
|
|
feed_dict[self.policy.vector_in] = batch["vector_obs"] |
|
|
|
if self.policy.vis_obs_size > 0: |
|
|
|
for i in range(len(self.policy.visual_in)): |
|
|
|
_obs = batch["visual_obs%d" % i] |
|
|
|
feed_dict[self.policy.visual_in[i]] = _obs |
|
|
|
if self.policy.use_recurrent: |
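# Zero-initialize both the policy memory and the value-network memory for this pass.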
|
|
|
feed_dict[self.policy.memory_in] = [ |
|
|
|
np.zeros((self.policy.m_size), dtype=np.float32) |
|
|
|
] |
|
|
|
feed_dict[self.memory_in] = [np.zeros((self.m_size), dtype=np.float32)] |
|
|
|
if self.policy.prev_action is not None: |
|
|
|
feed_dict[self.policy.prev_action] = batch["prev_action"] |
|
|
|
|
|
|
|
if self.policy.use_recurrent: |
|
|
|
value_estimates, policy_mem, value_mem = self.sess.run( |
|
|
|
[self.target_value_heads, self.policy.memory_out, self.memory_out], feed_dict |
|
|
|
) |
|
|
|
prev_action = ( |
|
|
|
batch["actions"][-1] if not self.policy.use_continuous_act else None |
|
|
|
) |
|
|
|
else: |
|
|
|
value_estimates = self.sess.run(self.target_value_heads, feed_dict) |
|
|
|
prev_action = None |
|
|
|
policy_mem = None |
|
|
|
value_mem = None |
|
|
|
value_estimates = {k: np.squeeze(v, axis=1) for k, v in value_estimates.items()} |
|
|
|
|
|
|
|
# We do this in a separate step to feed the memory outs - a further optimization would |
|
|
|
# be to append to the obs before running sess.run. |
|
|
|
final_value_estimates = self._get_value_estimates( |
|
|
|
next_obs, done, policy_mem, value_mem, prev_action |
|
|
|
) |
|
|
|
|
|
|
|
return value_estimates, final_value_estimates |
|
|
|
|
|
|
|
def _create_soft_critic_copy(self): |
|
|
|
t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_value') |
|
|
|
e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='value') |
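# Despite the "hard_replacement" scope name, this is a soft (Polyak) update: each run
# moves the target critic 10% of the way toward the online critic.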
|
|
|
|
|
|
|
with tf.variable_scope('hard_replacement'): |
|
|
|
self.target_replace_op = [tf.assign(t, 0.9*t + 0.1*e) for t, e in zip(t_params, e_params)] |
|
|
|
|
|
|
|
def run_soft_critic_copy(self): |
|
|
|
with self.policy.graph.as_default(): |
|
|
|
# print("before") |
|
|
|
# val = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "value/extrinsic_value/bias:0") |
|
|
|
# print("value:", self.sess.run(val)) |
|
|
|
# target_val = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "target_value/extrinsic_value/bias:0") |
|
|
|
# print("target_val:", self.sess.run(target_val)) |
|
|
|
self.policy.sess.run(self.target_replace_op) |
|
|
|
# print("copy") |
|
|
|
# val = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "value/extrinsic_value/bias:0") |
|
|
|
# print("value:", self.sess.run(val)) |
|
|
|
# target_val = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "target_value/extrinsic_value/bias:0") |
|
|
|
# print("target_val:", self.sess.run(target_val)) |