浏览代码

no op buffer no acen

/develop/bisim-review
Andrew Cohen 4 年前
当前提交
9c012d6a
共有 4 个文件被更改,包括 16 次插入16 次删除
  1. 2
      config/ppo_transfer/3DBall.yaml
  2. 2
      config/ppo_transfer/3DBallHard.yaml
  3. 4
      config/ppo_transfer/3DBallHardTransfer.yaml
  4. 24
      ml-agents/mlagents/trainers/policy/transfer_policy.py

2
config/ppo_transfer/3DBall.yaml


action_feature_size: 32
reuse_encoder: true
in_epoch_alter: false
use_op_buffer: true
use_op_buffer: false
use_var_predict: true
with_prior: false
predict_return: true

2
config/ppo_transfer/3DBallHard.yaml


action_feature_size: 32
reuse_encoder: true
in_epoch_alter: false
use_op_buffer: true
use_op_buffer: false
use_var_predict: true
with_prior: false
predict_return: true

4
config/ppo_transfer/3DBallHardTransfer.yaml


action_feature_size: 32
reuse_encoder: true
in_epoch_alter: false
use_op_buffer: true
use_op_buffer: false
use_var_predict: true
with_prior: false
predict_return: true

load_model: true
train_action: false
load_action: true
load_action: false
train_policy: true
load_policy: false
train_value: true

24
ml-agents/mlagents/trainers/policy/transfer_policy.py


reuse_encoder,
)
self.action_encoder = self._create_action_encoder(
self.current_action,
self.h_size,
self.action_feature_size,
action_layers,
)
self.action_encoder = self.current_action # self._create_action_encoder(
# self.current_action,
# self.h_size,
# self.action_feature_size,
# action_layers,
# )
if not reuse_encoder:
self.targ_encoder = tf.stop_gradient(self.targ_encoder)

encoding_checkpoint = os.path.join(self.model_path, f"encoding.ckpt")
encoding_saver.save(self.sess, encoding_checkpoint)
action_vars = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, "action_enc"
)
action_saver = tf.train.Saver(action_vars)
action_checkpoint = os.path.join(self.model_path, f"action_enc.ckpt")
action_saver.save(self.sess, action_checkpoint)
# action_vars = tf.get_collection(
# tf.GraphKeys.TRAINABLE_VARIABLES, "action_enc"
# )
# action_saver = tf.train.Saver(action_vars)
# action_checkpoint = os.path.join(self.model_path, f"action_enc.ckpt")
# action_saver.save(self.sess, action_checkpoint)
latent_vars = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, "encoding/latent"

正在加载...
取消
保存