
new cloud training change

/develop/bisim-review
yanchaosun, 4 years ago
Current commit
666c8ba9
4 changed files with 15 additions and 12 deletions
1. config/ppo_transfer/CrawlerStaticOpbuffer.yaml (3 changes)
2. ml-agents/mlagents/trainers/policy/transfer_policy.py (4 changes)
3. ml-agents/mlagents/trainers/ppo_transfer/optimizer.py (18 changes)
4. ml-agents/mlagents/trainers/ppo_transfer/trainer.py (2 changes)

config/ppo_transfer/CrawlerStaticOpbuffer.yaml (3 changes)


  num_epoch: 3
  learning_rate_schedule: linear
  encoder_layers: 2
- policy_layers: 3
+ policy_layers: 2
  value_layers: 2
  forward_layers: 1
  inverse_layers: 1

  in_batch_alter: true
  use_var_predict: true
  network_settings:
    normalize: true
    hidden_units: 512
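For reference, these layer-count keys translate directly into stacked hidden layers in the transfer networks. A minimal sketch of that mapping, assuming hypothetical names (`build_mlp` and the `settings` dict below are illustrative stand-ins, not the actual ml-agents builders):

```python
# Illustrative sketch only: how layer-count hyperparameters could drive
# MLP construction. Not the ml-agents API.
import numpy as np

settings = {
    "encoder_layers": 2,
    "policy_layers": 2,  # reduced from 3 in this commit
    "hidden_units": 512,
}

def build_mlp(in_dim: int, hidden_units: int, num_layers: int):
    """Return weight matrices for a stack of `num_layers` hidden layers."""
    dims = [in_dim] + [hidden_units] * num_layers
    rng = np.random.default_rng(0)
    return [rng.standard_normal((dims[i], dims[i + 1])) * 0.01
            for i in range(num_layers)]

encoder = build_mlp(64, settings["hidden_units"], settings["encoder_layers"])
policy = build_mlp(settings["hidden_units"], settings["hidden_units"],
                   settings["policy_layers"])
print([w.shape for w in policy])  # [(512, 512), (512, 512)]
```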

ml-agents/mlagents/trainers/policy/transfer_policy.py (4 changes)


      h_size,
      num_layers,
      vis_encode_type,
+     reuse=reuse_encoder

-     feature_size
+     feature_size,
+     reuse=reuse_encoder
  )
  latent_targ = latent_targ_distribution.sample()
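The `latent_targ_distribution.sample()` call draws the target-encoder latent from a Gaussian. Below is a plain-numpy sketch of the reparameterized sampling and of the `kl_standard()` prior term used later in the optimizer; the class and method names mirror the diff, but the implementation is an assumption, not the fork's actual distribution code:

```python
import numpy as np

class GaussianLatent:
    """Illustrative diagonal Gaussian; names are assumptions, not the fork's API."""
    def __init__(self, mu: np.ndarray, log_sigma: np.ndarray):
        self.mu, self.log_sigma = mu, log_sigma

    def sample(self) -> np.ndarray:
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
        eps = np.random.standard_normal(self.mu.shape)
        return self.mu + np.exp(self.log_sigma) * eps

    def kl_standard(self) -> float:
        # KL( N(mu, sigma^2) || N(0, I) ) for a diagonal Gaussian
        var = np.exp(2 * self.log_sigma)
        return 0.5 * float(np.sum(self.mu ** 2 + var - 2 * self.log_sigma - 1))

dist = GaussianLatent(mu=np.zeros(4), log_sigma=np.zeros(4))
latent = dist.sample()                   # analogous to latent_targ above
print(latent.shape, dist.kl_standard())  # (4,) 0.0
```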

ml-agents/mlagents/trainers/ppo_transfer/optimizer.py (18 changes)


      self.policy.entropy,
      self.policy.targ_encoder,
      self.policy.predict,
-     self.policy.encoder_distribution,
      beta,
      epsilon,
      lr,

  self.policy.load_graph_partial(self.transfer_path, self.transfer_type)
  self.policy.get_encoder_weights()
  self.policy.get_policy_weights()
  # saver = tf.train.Saver()
  # model_checkpoint = os.path.join(self.transfer_path, f"model-4000544.ckpt")
  # saver.restore(self.sess, model_checkpoint)
  # self.policy._set_step(0)
  for variable in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
      print(variable.name)
  for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
      print(variable)
  # tf.summary.FileWriter(self.policy.model_path, self.sess.graph)
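The prints above dump every global and trainable variable after the partial graph load. When checking what a checkpoint such as the commented-out `model-4000544.ckpt` contains, `tf.train.list_variables` reads the same information without restoring it into a session; the path below is a placeholder, not the actual `transfer_path`:

```python
import tensorflow as tf

# Inspect a checkpoint's variable names and shapes without a restore.
# Placeholder path; substitute the real transfer_path directory.
ckpt = "path/to/transfer_path/model-4000544.ckpt"
for name, shape in tf.train.list_variables(ckpt):
    print(name, shape)
```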

  )

  def _create_losses(
-     self, probs, old_probs, value_heads, entropy, targ_encoder, predict, encoder_distribution,
+     self, probs, old_probs, value_heads, entropy, targ_encoder, predict,
      beta, epsilon, lr, max_step
  ):
"""

      # self.model_loss += self.policy.predict_distribution.kl_standard()
      self.model_loss = self.policy.forward_loss
      if self.with_prior:
          if self.use_var_encoder:
              self.model_loss += 0.2 * self.policy.encoder_distribution.kl_standard()
          if self.use_var_predict:
              self.model_loss += 0.2 * self.policy.predict_distribution.kl_standard()
-     self.model_loss += self.policy.inverse_loss
+     self.model_loss += 0.5 * self.policy.inverse_loss
      # self.model_loss = 0.2 * self.policy.forward_loss + 0.8 * self.policy.inverse_loss
      self.loss = (
          self.policy_loss
ml-agents/mlagents/trainers/ppo_transfer/trainer.py (2 changes)


  num_epoch = self.hyperparameters.num_epoch
  batch_update_stats = defaultdict(list)
  for _ in range(num_epoch):
-     self.off_policy_buffer.shuffle(sequence_length=self.policy.sequence_length)
+     # self.off_policy_buffer.shuffle(sequence_length=self.policy.sequence_length)
      buffer = self.off_policy_buffer
      max_num_batch = buffer_length // batch_size
      for i in range(0, max_num_batch * batch_size, batch_size):
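The loop runs `num_epoch` passes over the off-policy buffer and slices it into `batch_size` chunks, with the per-epoch shuffle commented out in this commit. A self-contained sketch of the same batching pattern on a plain list; names mirror the diff, and the list is a stand-in for the AgentBuffer:

```python
from collections import defaultdict

num_epoch, batch_size = 3, 4
buffer = list(range(10))  # stand-in for self.off_policy_buffer
batch_update_stats = defaultdict(list)

for _ in range(num_epoch):
    # random.shuffle(buffer)  # per-epoch shuffle, disabled as in the commit
    buffer_length = len(buffer)
    max_num_batch = buffer_length // batch_size  # drop the ragged tail
    for i in range(0, max_num_batch * batch_size, batch_size):
        batch = buffer[i : i + batch_size]
        batch_update_stats["batch_sizes"].append(len(batch))

print(batch_update_stats["batch_sizes"])  # [4, 4] per epoch -> [4, 4, 4, 4, 4, 4]
```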
