|
|
|
|
|
|
lr, |
|
|
|
max_step, |
|
|
|
) |
|
|
|
self.create_ppo_optimizer() |
|
|
|
|
|
|
|
self.update_dict.update( |
|
|
|
{ |
|
|
|
"value_loss": self.value_loss, |
|
|
|
"policy_loss": self.abs_policy_loss, |
|
|
|
"update_batch": self.update_batch, |
|
|
|
} |
|
|
|
) |
|
|
|
|
|
|
|
def create_cc_critic( |
|
|
|
self, h_size: int, num_layers: int, vis_encode_type: EncoderType |
|
|
|
|
|
|
def construct_feed_dict( |
|
|
|
self, model: PPOModel, mini_batch: AgentBuffer, num_sequences: int |
|
|
|
) -> Dict[tf.Tensor, Any]: |
|
|
|
|
|
|
|
model.sequence_length: model.sequence_length, |
|
|
|
model.sequence_length: len(mini_batch["advantages"]) |
|
|
|
/ num_sequences, # TODO: Fix LSTM |
|
|
|
model.mask_input: mini_batch["masks"], |
|
|
|
self.advantage: mini_batch["advantages"], |
|
|
|
self.all_old_log_probs: mini_batch["action_probs"], |
|
|
|
|
|
|
:param out_dict: Output dictionary mapping names to nodes. |
|
|
|
:return: Dictionary mapping names to input data. |
|
|
|
""" |
|
|
|
print(feed_dict) |
|
|
|
network_out = self.sess.run(list(out_dict.values()), feed_dict=feed_dict) |
|
|
|
run_out = dict(zip(list(out_dict.keys()), network_out)) |
|
|
|
return run_out |