|
|
|
|
|
|
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer |
|
|
|
from mlagents.trainers.buffer import AgentBuffer |
|
|
|
from mlagents.trainers.settings import TrainerSettings, PPOSettings, PPOTransferSettings |
|
|
|
import tf_slim as slim |
|
|
|
# import tf_slim as slim |
|
|
|
|
|
|
|
class PPOTransferOptimizer(TFOptimizer): |
|
|
|
def __init__(self, policy: TransferPolicy, trainer_params: TrainerSettings): |
|
|
|
|
|
|
self.model_update_dict: Dict[str, tf.Tensor] = {} |
|
|
|
|
|
|
|
# Create the graph here to give more granular control of the TF graph to the Optimizer. |
|
|
|
policy.create_tf_graph(hyperparameters.encoder_layers, hyperparameters.policy_layers, |
|
|
|
policy.create_tf_graph(hyperparameters.encoder_layers, hyperparameters.policy_layers, hyperparameters.feature_size, |
|
|
|
self.use_transfer, self.separate_policy_train, self.use_var_encoder, self.use_var_predict, |
|
|
|
self.predict_return, self.use_inverse_model, self.reuse_encoder) |
|
|
|
|
|
|
|
|
|
|
# saver.restore(self.sess, model_checkpoint) |
|
|
|
# self.policy._set_step(0) |
|
|
|
|
|
|
|
slim.model_analyzer.analyze_vars(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), print_info=True) |
|
|
|
# slim.model_analyzer.analyze_vars(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), print_info=True) |
|
|
|
|
|
|
|
print("All variables in the graph:") |
|
|
|
for variable in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES): |
|
|
|