import numpy as np from typing import Dict, List, Optional, Any, Mapping, cast import copy from mlagents.tf_utils import tf from mlagents_envs.logging_util import get_logger from mlagents.trainers.sac_transfer.network import SACTransferPolicyNetwork, SACTransferTargetNetwork from mlagents.trainers.sac.network import SACPolicyNetwork, SACTargetNetwork from mlagents.trainers.models import ModelUtils from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer from mlagents.trainers.policy.tf_policy import TFPolicy from mlagents.trainers.policy.transfer_policy import TransferPolicy from mlagents.trainers.buffer import AgentBuffer from mlagents_envs.timers import timed from mlagents.trainers.settings import TrainerSettings, SACSettings, SACTransferSettings EPSILON = 1e-6 # Small value to avoid divide by zero logger = get_logger(__name__) POLICY_SCOPE = "" TARGET_SCOPE = "target_network" class SACTransferOptimizer(TFOptimizer): def __init__(self, policy: TFPolicy, trainer_params: TrainerSettings): """ Takes a Unity environment and model-specific hyper-parameters and returns the appropriate PPO agent model for the environment. :param brain: Brain parameters used to generate specific network graph. :param lr: Learning rate. :param lr_schedule: Learning rate decay schedule. :param h_size: Size of hidden layers :param init_entcoef: Initial value for entropy coefficient. Set lower to learn faster, set higher to explore more. :return: a sub-class of PPOAgent tailored to the environment. :param max_step: Total number of training steps. :param normalize: Whether to normalize vector observation input. :param use_recurrent: Whether to use an LSTM layer in the network. :param num_layers: Number of hidden layers between encoded input and policy & value layers :param tau: Strength of soft-Q update. :param m_size: Size of brain memory. """ hyperparameters: SACTransferSettings = cast( SACTransferSettings, trainer_params.hyperparameters ) self.batch_size = hyperparameters.batch_size self.separate_value_train = hyperparameters.separate_value_train self.separate_policy_train = hyperparameters.separate_policy_train self.separate_model_train = hyperparameters.separate_model_train self.use_var_encoder = hyperparameters.use_var_encoder self.use_var_predict = hyperparameters.use_var_predict self.with_prior = hyperparameters.with_prior self.use_inverse_model = hyperparameters.use_inverse_model self.predict_return = hyperparameters.predict_return self.reuse_encoder = hyperparameters.reuse_encoder self.use_bisim = hyperparameters.use_bisim self.use_alter = hyperparameters.use_alter self.in_batch_alter = hyperparameters.in_batch_alter self.in_epoch_alter = hyperparameters.in_epoch_alter self.op_buffer = hyperparameters.use_op_buffer self.train_encoder = hyperparameters.train_encoder self.train_action = hyperparameters.train_action self.train_model = hyperparameters.train_model self.train_policy = hyperparameters.train_policy self.train_value = hyperparameters.train_value # Transfer self.use_transfer = hyperparameters.use_transfer self.transfer_path = ( hyperparameters.transfer_path ) self.smart_transfer = hyperparameters.smart_transfer self.conv_thres = hyperparameters.conv_thres self.sac_update_dict: Dict[str, tf.Tensor] = {} self.model_update_dict: Dict[str, tf.Tensor] = {} self.model_only_update_dict: Dict[str, tf.Tensor] = {} self.bisim_update_dict: Dict[str, tf.Tensor] = {} # Create the graph here to give more granular control of the TF graph to the Optimizer. policy.create_tf_graph( hyperparameters.encoder_layers, hyperparameters.action_layers, hyperparameters.policy_layers, hyperparameters.forward_layers, hyperparameters.inverse_layers, hyperparameters.feature_size, hyperparameters.action_feature_size, self.use_transfer, self.separate_policy_train, self.separate_model_train, self.use_var_encoder, self.use_var_predict, self.predict_return, self.use_inverse_model, self.reuse_encoder, self.use_bisim, hyperparameters.tau ) with policy.graph.as_default(): with tf.variable_scope(""): super().__init__(policy, trainer_params) lr = hyperparameters.learning_rate lr_schedule = hyperparameters.learning_rate_schedule max_step = trainer_params.max_steps self.tau = hyperparameters.tau self.init_entcoef = hyperparameters.init_entcoef self.policy = policy self.act_size = policy.act_size policy_network_settings = policy.network_settings h_size = policy_network_settings.hidden_units num_layers = policy_network_settings.num_layers vis_encode_type = policy_network_settings.vis_encode_type self.tau = hyperparameters.tau self.burn_in_ratio = 0.0 # Non-exposed SAC parameters self.discrete_target_entropy_scale = ( 0.2 ) # Roughly equal to e-greedy 0.05 self.continuous_target_entropy_scale = 1.0 stream_names = list(self.reward_signals.keys()) # Use to reduce "survivor bonus" when using Curiosity or GAIL. self.gammas = [ _val.gamma for _val in trainer_params.reward_signals.values() ] self.use_dones_in_backup = { name: tf.Variable(1.0) for name in stream_names } self.disable_use_dones = { name: self.use_dones_in_backup[name].assign(0.0) for name in stream_names } if num_layers < 1: num_layers = 1 self.target_init_op: List[tf.Tensor] = [] self.target_update_op: List[tf.Tensor] = [] self.update_batch_policy: Optional[tf.Operation] = None self.update_batch_value: Optional[tf.Operation] = None self.update_batch_entropy: Optional[tf.Operation] = None if not hyperparameters.separate_value_net: self.policy_network = SACTransferPolicyNetwork( policy=self.policy, m_size=self.policy.m_size, # 3x policy.m_size h_size=h_size, normalize=self.policy.normalize, use_recurrent=self.policy.use_recurrent, encoder_layers=hyperparameters.encoder_layers, num_layers=hyperparameters.value_layers, action_layers=hyperparameters.action_layers, action_features=hyperparameters.action_feature_size, stream_names=stream_names, vis_encode_type=vis_encode_type, separate_train=hyperparameters.separate_value_train, ) self.target_network = SACTransferTargetNetwork( policy=self.policy, m_size=self.policy.m_size, # 1x policy.m_size h_size=h_size, normalize=self.policy.normalize, use_recurrent=self.policy.use_recurrent, encoder_layers=hyperparameters.encoder_layers, num_layers=hyperparameters.value_layers, action_layers=hyperparameters.action_layers, action_features=hyperparameters.action_feature_size, stream_names=stream_names, vis_encode_type=vis_encode_type, separate_train=hyperparameters.separate_value_train, ) else: self.policy_network = SACPolicyNetwork( policy=self.policy, m_size=self.policy.m_size, # 3x policy.m_size h_size=h_size, normalize=self.policy.normalize, use_recurrent=self.policy.use_recurrent, num_layers=num_layers, stream_names=stream_names, vis_encode_type=vis_encode_type, ) self.target_network = SACTargetNetwork( policy=self.policy, m_size=self.policy.m_size, # 1x policy.m_size h_size=h_size, normalize=self.policy.normalize, use_recurrent=self.policy.use_recurrent, num_layers=num_layers, stream_names=stream_names, vis_encode_type=vis_encode_type, ) # The optimizer's m_size is 3 times the policy (Q1, Q2, and Value) self.m_size = 3 * self.policy.m_size self._create_inputs_and_outputs() self.learning_rate = ModelUtils.create_schedule( lr_schedule, lr, self.policy.global_step, int(max_step), min_value=1e-10, ) self.model_learning_rate = ModelUtils.create_schedule( hyperparameters.model_schedule, lr, self.policy.global_step, int(max_step), min_value=1e-10, ) self.bisim_learning_rate = ModelUtils.create_schedule( hyperparameters.model_schedule, lr / 10, self.policy.global_step, int(max_step), min_value=1e-10, ) self._create_losses( self.policy_network.q1_heads, self.policy_network.q2_heads, lr, int(max_step), stream_names, discrete=not self.policy.use_continuous_act, ) self._create_sac_optimizer_ops() self.selected_actions = ( self.policy.selected_actions ) # For GAIL and other reward signals if self.policy.normalize: target_update_norm = self.target_network.copy_normalization( self.policy.running_mean, self.policy.running_variance, self.policy.normalization_steps, ) # Update the normalization of the optimizer when the policy does. self.policy.update_normalization_op = tf.group( [self.policy.update_normalization_op, target_update_norm] ) self.policy.initialize_or_load() if self.use_transfer: self.policy.load_graph_partial( self.transfer_path, hyperparameters.load_model, hyperparameters.load_policy, hyperparameters.load_value, hyperparameters.load_encoder, hyperparameters.load_action, ) self.policy.run_hard_copy() # self.sess.run(self.target_init_op) self.num_updates = 0 print("All variables in the graph:") for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES): print(variable) self.stats_name_to_update_name = { "Losses/Value Loss": "value_loss", "Losses/Model Loss": "model_loss", "Losses/Policy Loss": "policy_loss", "Losses/Q1 Loss": "q1_loss", "Losses/Q2 Loss": "q2_loss", "Policy/Entropy Coeff": "entropy_coef", "Policy/Learning Rate": "learning_rate", "Policy/Model Learning Rate": "model_learning_rate", } if self.predict_return: self.stats_name_to_update_name.update({ "Losses/Reward Loss": "reward_loss", }) self.update_dict = { "value_loss": self.total_value_loss, "policy_loss": self.policy_loss, "q1_loss": self.q1_loss, "q2_loss": self.q2_loss, "entropy_coef": self.ent_coef, "update_batch": self.update_batch_policy, "update_value": self.update_batch_value, "update_entropy": self.update_batch_entropy, "learning_rate": self.learning_rate, } if self.use_transfer: self.update_dict.update({ "model_loss": self.model_loss, "model_learning_rate": self.model_learning_rate, "reward_loss": self.policy.reward_loss, }) def _create_inputs_and_outputs(self) -> None: """ Assign the higher-level SACModel's inputs and outputs to those of its policy or target network. """ self.vector_in = self.policy.vector_in self.visual_in = self.policy.visual_in self.next_vector_in = self.target_network.vector_in self.next_visual_in = self.target_network.visual_in self.sequence_length_ph = self.policy.sequence_length_ph self.next_sequence_length_ph = self.target_network.sequence_length_ph if not self.policy.use_continuous_act: self.action_masks = self.policy_network.action_masks else: self.output_pre = self.policy_network.output_pre # Don't use value estimate during inference. self.value = tf.identity( self.policy_network.value, name="value_estimate_unused" ) self.value_heads = self.policy_network.value_heads self.dones_holder = tf.placeholder( shape=[None], dtype=tf.float32, name="dones_holder" ) if self.policy.use_recurrent: self.memory_in = self.policy_network.memory_in self.memory_out = self.policy_network.memory_out if not self.policy.use_continuous_act: self.prev_action = self.policy_network.prev_action self.next_memory_in = self.target_network.memory_in def _create_losses( self, q1_streams: Dict[str, tf.Tensor], q2_streams: Dict[str, tf.Tensor], lr: tf.Tensor, max_step: int, stream_names: List[str], discrete: bool = False, ) -> None: """ Creates training-specific Tensorflow ops for SAC models. :param q1_streams: Q1 streams from policy network :param q1_streams: Q2 streams from policy network :param lr: Learning rate :param max_step: Total number of training steps. :param stream_names: List of reward stream names. :param discrete: Whether or not to use discrete action losses. """ if discrete: self.target_entropy = [ self.discrete_target_entropy_scale * np.log(i).astype(np.float32) for i in self.act_size ] discrete_action_probs = tf.exp(self.policy.all_log_probs) per_action_entropy = discrete_action_probs * self.policy.all_log_probs else: self.target_entropy = ( -1 * self.continuous_target_entropy_scale * np.prod(self.act_size[0]).astype(np.float32) ) self.rewards_holders = {} self.min_policy_qs = {} for name in stream_names: if discrete: _branched_mpq1 = ModelUtils.break_into_branches( self.policy_network.q1_pheads[name] * discrete_action_probs, self.act_size, ) branched_mpq1 = tf.stack( [ tf.reduce_sum(_br, axis=1, keep_dims=True) for _br in _branched_mpq1 ] ) _q1_p_mean = tf.reduce_mean(branched_mpq1, axis=0) _branched_mpq2 = ModelUtils.break_into_branches( self.policy_network.q2_pheads[name] * discrete_action_probs, self.act_size, ) branched_mpq2 = tf.stack( [ tf.reduce_sum(_br, axis=1, keep_dims=True) for _br in _branched_mpq2 ] ) _q2_p_mean = tf.reduce_mean(branched_mpq2, axis=0) self.min_policy_qs[name] = tf.minimum(_q1_p_mean, _q2_p_mean) else: self.min_policy_qs[name] = tf.minimum( self.policy_network.q1_pheads[name], self.policy_network.q2_pheads[name], ) rewards_holder = tf.placeholder( shape=[None], dtype=tf.float32, name="{}_rewards".format(name) ) self.rewards_holders[name] = rewards_holder q1_losses = [] q2_losses = [] # Multiple q losses per stream expanded_dones = tf.expand_dims(self.dones_holder, axis=-1) for i, name in enumerate(stream_names): _expanded_rewards = tf.expand_dims(self.rewards_holders[name], axis=-1) q_backup = tf.stop_gradient( _expanded_rewards + (1.0 - self.use_dones_in_backup[name] * expanded_dones) * self.gammas[i] * self.target_network.value_heads[name] ) if discrete: # We need to break up the Q functions by branch, and update them individually. branched_q1_stream = ModelUtils.break_into_branches( self.policy.selected_actions * q1_streams[name], self.act_size ) branched_q2_stream = ModelUtils.break_into_branches( self.policy.selected_actions * q2_streams[name], self.act_size ) # Reduce each branch into scalar branched_q1_stream = [ tf.reduce_sum(_branch, axis=1, keep_dims=True) for _branch in branched_q1_stream ] branched_q2_stream = [ tf.reduce_sum(_branch, axis=1, keep_dims=True) for _branch in branched_q2_stream ] q1_stream = tf.reduce_mean(branched_q1_stream, axis=0) q2_stream = tf.reduce_mean(branched_q2_stream, axis=0) else: q1_stream = q1_streams[name] q2_stream = q2_streams[name] _q1_loss = 0.5 * tf.reduce_mean( tf.to_float(self.policy.mask) * tf.squared_difference(q_backup, q1_stream) ) _q2_loss = 0.5 * tf.reduce_mean( tf.to_float(self.policy.mask) * tf.squared_difference(q_backup, q2_stream) ) q1_losses.append(_q1_loss) q2_losses.append(_q2_loss) self.q1_loss = tf.reduce_mean(q1_losses) self.q2_loss = tf.reduce_mean(q2_losses) # Learn entropy coefficient if discrete: # Create a log_ent_coef for each branch self.log_ent_coef = tf.get_variable( "log_ent_coef", dtype=tf.float32, initializer=np.log([self.init_entcoef] * len(self.act_size)).astype( np.float32 ), trainable=True, ) else: self.log_ent_coef = tf.get_variable( "log_ent_coef", dtype=tf.float32, initializer=np.log(self.init_entcoef).astype(np.float32), trainable=True, ) self.ent_coef = tf.exp(self.log_ent_coef) if discrete: # We also have to do a different entropy and target_entropy per branch. branched_per_action_ent = ModelUtils.break_into_branches( per_action_entropy, self.act_size ) branched_ent_sums = tf.stack( [ tf.reduce_sum(_lp, axis=1, keep_dims=True) + _te for _lp, _te in zip(branched_per_action_ent, self.target_entropy) ], axis=1, ) self.entropy_loss = -tf.reduce_mean( tf.to_float(self.policy.mask) * tf.reduce_mean( self.log_ent_coef * tf.squeeze(tf.stop_gradient(branched_ent_sums), axis=2), axis=1, ) ) # Same with policy loss, we have to do the loss per branch and average them, # so that larger branches don't get more weight. # The equivalent KL divergence from Eq 10 of Haarnoja et al. is also pi*log(pi) - Q branched_q_term = ModelUtils.break_into_branches( discrete_action_probs * self.policy_network.q1_p, self.act_size ) branched_policy_loss = tf.stack( [ tf.reduce_sum(self.ent_coef[i] * _lp - _qt, axis=1, keep_dims=True) for i, (_lp, _qt) in enumerate( zip(branched_per_action_ent, branched_q_term) ) ] ) self.policy_loss = tf.reduce_mean( tf.to_float(self.policy.mask) * tf.squeeze(branched_policy_loss) ) # Do vbackup entropy bonus per branch as well. branched_ent_bonus = tf.stack( [ tf.reduce_sum(self.ent_coef[i] * _lp, axis=1, keep_dims=True) for i, _lp in enumerate(branched_per_action_ent) ] ) value_losses = [] for name in stream_names: v_backup = tf.stop_gradient( self.min_policy_qs[name] - tf.reduce_mean(branched_ent_bonus, axis=0) ) value_losses.append( 0.5 * tf.reduce_mean( tf.to_float(self.policy.mask) * tf.squared_difference( self.policy_network.value_heads[name], v_backup ) ) ) else: self.entropy_loss = -tf.reduce_mean( self.log_ent_coef * tf.to_float(self.policy.mask) * tf.stop_gradient( tf.reduce_sum( self.policy.all_log_probs + self.target_entropy, axis=1, keep_dims=True, ) ) ) batch_policy_loss = tf.reduce_mean( self.ent_coef * self.policy.all_log_probs - self.policy_network.q1_p, axis=1, ) self.policy_loss = tf.reduce_mean( tf.to_float(self.policy.mask) * batch_policy_loss ) value_losses = [] for name in stream_names: v_backup = tf.stop_gradient( self.min_policy_qs[name] - tf.reduce_sum(self.ent_coef * self.policy.all_log_probs, axis=1) ) value_losses.append( 0.5 * tf.reduce_mean( tf.to_float(self.policy.mask) * tf.squared_difference( self.policy_network.value_heads[name], v_backup ) ) ) self.value_loss = tf.reduce_mean(value_losses) self.total_value_loss = self.q1_loss + self.q2_loss + self.value_loss self.entropy = self.policy_network.entropy self.model_loss = self.policy.forward_loss if self.predict_return: self.model_loss += 0.5 * self.policy.reward_loss if self.with_prior: if self.use_var_encoder: self.model_loss += 0.2 * self.policy.encoder_distribution.kl_standard() if self.use_var_predict: self.model_loss += 0.2 * self.policy.predict_distribution.kl_standard() if self.use_bisim: if self.use_var_predict: predict_diff = self.policy.predict_distribution.w_distance( self.policy.bisim_predict_distribution ) else: predict_diff = tf.reduce_mean( tf.reduce_sum( tf.squared_difference( self.policy.bisim_predict, self.policy.predict ), axis=1, ) ) if self.predict_return: reward_diff = tf.reduce_sum( tf.abs(self.policy.bisim_pred_reward - self.policy.pred_reward), axis=1, ) predict_diff = ( self.reward_signals["extrinsic"].gamma * predict_diff + reward_diff ) encode_dist = tf.reduce_sum( tf.abs(self.policy.encoder - self.policy.bisim_encoder), axis=1 ) self.predict_difference = predict_diff self.reward_difference = reward_diff self.encode_difference = encode_dist self.bisim_loss = tf.reduce_mean( tf.squared_difference(encode_dist, predict_diff) ) def _create_sac_optimizer_ops(self) -> None: """ Creates the Adam optimizers and update ops for SAC, including the policy, value, and entropy updates, as well as the target network update. """ policy_optimizer = self.create_optimizer_op( learning_rate=self.learning_rate, name="sac_policy_opt" ) entropy_optimizer = self.create_optimizer_op( learning_rate=self.learning_rate, name="sac_entropy_opt" ) value_optimizer = self.create_optimizer_op( learning_rate=self.learning_rate, name="sac_value_opt" ) self.target_update_op = [ tf.assign(target, (1 - self.tau) * target + self.tau * source) for target, source in zip( self.target_network.value_vars, self.policy_network.value_vars ) ] policy_vars = self.policy.get_trainable_variables( train_encoder=not self.separate_policy_train, train_action=not self.separate_policy_train, train_model=False, train_policy=self.train_policy ) model_vars = self.policy.get_trainable_variables( train_encoder=self.train_encoder, train_action=self.train_action, train_model=self.train_model, train_policy=False ) encoding_vars = self.policy.encoding_variables if self.train_value: critic_vars = self.policy_network.critic_vars + model_vars else: critic_vars = model_vars self.target_init_op = [ tf.assign(target, source) for target, source in zip( self.target_network.value_vars, self.policy_network.value_vars ) ] self.update_batch_policy = policy_optimizer.minimize( self.policy_loss, var_list=policy_vars ) # print("value trainable:", critic_vars) # Make sure policy is updated first, then value, then entropy. if self.use_transfer: value_loss = self.total_value_loss + self.model_loss else: value_loss = self.total_value_loss with tf.control_dependencies([self.update_batch_policy]): self.update_batch_value = value_optimizer.minimize( value_loss, var_list=critic_vars ) # Add entropy coefficient optimization operation with tf.control_dependencies([self.update_batch_value]): self.update_batch_entropy = entropy_optimizer.minimize( self.entropy_loss, var_list=self.log_ent_coef ) model_optimizer = self.create_optimizer_op( learning_rate=self.model_learning_rate, name="sac_model_opt" ) self.update_batch_model = model_optimizer.minimize( self.model_loss, var_list=model_vars ) self.model_update_dict.update( { "model_loss": self.model_loss, "update_batch": self.update_batch_model, "model_learning_rate": self.model_learning_rate, } ) if self.predict_return: self.model_update_dict.update({"reward_loss": self.policy.reward_loss}) if self.use_bisim: bisim_train_vars = tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES, "encoding" ) self.bisim_optimizer = self.create_optimizer_op(self.bisim_learning_rate) self.bisim_update_batch = self.bisim_optimizer.minimize( self.bisim_loss, var_list=bisim_train_vars ) self.bisim_update_dict.update( { "bisim_loss": self.bisim_loss, "update_batch": self.bisim_update_batch, "bisim_learning_rate": self.bisim_learning_rate, } ) def print_all_vars(self, variables): for _var in variables: logger.debug(_var) @timed def update(self, batch: AgentBuffer, batch_bisim: AgentBuffer, num_sequences: int) -> Dict[str, float]: """ Updates model using buffer. :param num_sequences: Number of trajectories in batch. :param batch: Experience mini-batch. :param update_target: Whether or not to update target value network :param reward_signal_batches: Minibatches to use for updating the reward signals, indexed by name. If none, don't update the reward signals. :return: Output from update process. """ feed_dict = self._construct_feed_dict(self.policy, batch, num_sequences) stats_needed = self.stats_name_to_update_name update_stats: Dict[str, float] = {} if self.use_transfer: update_vals = self._execute_model(feed_dict, self.update_dict) else: update_vals = self._execute_model(feed_dict, self.model_update_dict) update_vals.update(self._execute_model(feed_dict, self.update_dict)) for stat_name, update_name in stats_needed.items(): if update_name in update_vals.keys(): update_stats[stat_name] = update_vals[update_name] if self.use_bisim: bisim_stats = self.update_encoder(batch, batch_bisim) update_stats.update(bisim_stats) # Update target network. By default, target update happens at every policy update. self.sess.run(self.target_update_op) self.policy.run_soft_copy() self.num_updates += 1 return update_stats def update_encoder(self, mini_batch1: AgentBuffer, mini_batch2: AgentBuffer): stats_needed = { "Losses/Bisim Loss": "bisim_loss", "Policy/Bisim Learning Rate": "bisim_learning_rate", } update_stats = {} selected_action_1 = self.policy.sess.run( self.policy.selected_actions, feed_dict={self.policy.vector_in: mini_batch1["vector_obs"]}, ) selected_action_2 = self.policy.sess.run( self.policy.selected_actions, feed_dict={self.policy.vector_in: mini_batch2["vector_obs"]}, ) feed_dict = { self.policy.vector_in: mini_batch1["vector_obs"], self.policy.vector_bisim: mini_batch2["vector_obs"], self.policy.current_action: selected_action_1, self.policy.bisim_action: selected_action_2, } update_vals = self._execute_model(feed_dict, self.bisim_update_dict) for stat_name, update_name in stats_needed.items(): if update_name in update_vals.keys(): update_stats[stat_name] = update_vals[update_name] return update_stats def update_reward_signals( self, reward_signal_minibatches: Mapping[str, AgentBuffer], num_sequences: int ) -> Dict[str, float]: """ Only update the reward signals. :param reward_signal_batches: Minibatches to use for updating the reward signals, indexed by name. If none, don't update the reward signals. """ # Collect feed dicts for all reward signals. feed_dict: Dict[tf.Tensor, Any] = {} update_dict: Dict[str, tf.Tensor] = {} update_stats: Dict[str, float] = {} stats_needed: Dict[str, str] = {} if reward_signal_minibatches: self.add_reward_signal_dicts( feed_dict, update_dict, stats_needed, reward_signal_minibatches, num_sequences, ) update_vals = self._execute_model(feed_dict, update_dict) for stat_name, update_name in stats_needed.items(): update_stats[stat_name] = update_vals[update_name] return update_stats def add_reward_signal_dicts( self, feed_dict: Dict[tf.Tensor, Any], update_dict: Dict[str, tf.Tensor], stats_needed: Dict[str, str], reward_signal_minibatches: Mapping[str, AgentBuffer], num_sequences: int, ) -> None: """ Adds the items needed for reward signal updates to the feed_dict and stats_needed dict. :param feed_dict: Feed dict needed update :param update_dit: Update dict that needs update :param stats_needed: Stats needed to get from the update. :param reward_signal_minibatches: Minibatches to use for updating the reward signals, indexed by name. """ for name, r_batch in reward_signal_minibatches.items(): feed_dict.update( self.reward_signals[name].prepare_update( self.policy, r_batch, num_sequences ) ) update_dict.update(self.reward_signals[name].update_dict) stats_needed.update(self.reward_signals[name].stats_name_to_update_name) def _construct_feed_dict( self, policy: TFPolicy, batch: AgentBuffer, num_sequences: int ) -> Dict[tf.Tensor, Any]: """ Builds the feed dict for updating the SAC model. :param model: The model to update. May be different when, e.g. using multi-GPU. :param batch: Mini-batch to use to update. :param num_sequences: Number of LSTM sequences in batch. """ # Do an optional burn-in for memories num_burn_in = int(self.burn_in_ratio * self.policy.sequence_length) burn_in_mask = np.ones((self.policy.sequence_length), dtype=np.float32) burn_in_mask[range(0, num_burn_in)] = 0 burn_in_mask = np.tile(burn_in_mask, num_sequences) feed_dict = { policy.batch_size_ph: num_sequences, policy.sequence_length_ph: self.policy.sequence_length, self.next_sequence_length_ph: self.policy.sequence_length, self.policy.mask_input: batch["masks"] * burn_in_mask, self.policy.current_action: batch["actions"], self.policy.current_reward: batch["extrinsic_rewards"], } for name in self.reward_signals: feed_dict[self.rewards_holders[name]] = batch["{}_rewards".format(name)] if self.policy.use_continuous_act: feed_dict[self.policy_network.external_action_in] = batch["actions"] else: feed_dict[policy.output] = batch["actions"] if self.policy.use_recurrent: feed_dict[policy.prev_action] = batch["prev_action"] feed_dict[policy.action_masks] = batch["action_mask"] if self.policy.use_vec_obs: feed_dict[policy.vector_in] = batch["vector_obs"] feed_dict[self.next_vector_in] = batch["next_vector_in"] feed_dict[policy.vector_next] = batch["next_vector_in"] if self.policy.vis_obs_size > 0: for i, _ in enumerate(policy.visual_in): _obs = batch["visual_obs%d" % i] feed_dict[policy.visual_in[i]] = _obs for i, _ in enumerate(self.next_visual_in): _obs = batch["next_visual_obs%d" % i] feed_dict[self.next_visual_in[i]] = _obs feed_dict[policy.visual_next[i]] = _obs if self.policy.use_recurrent: feed_dict[policy.memory_in] = [ batch["memory"][i] for i in range(0, len(batch["memory"]), self.policy.sequence_length) ] feed_dict[self.policy_network.memory_in] = self._make_zero_mem( self.m_size, batch.num_experiences ) feed_dict[self.target_network.memory_in] = self._make_zero_mem( self.m_size // 3, batch.num_experiences ) feed_dict[self.dones_holder] = batch["done"] return feed_dict