浏览代码

changes on the ppo.py

/tag-0.2.0
vincentpierre 7 年前
当前提交
d71ee998
共有 1 个文件被更改,包括 3 次插入3 次删除
  1. 6
      python/ppo.py

6
python/ppo.py


hidden_units = int(options['--hidden-units'])
batch_size = int(options['--batch-size'])
env = UnityEnvironment(file_name=env_name, worker_id=worker_id, curriculum=curriculum_path)
env = UnityEnvironment(file_name=env_name, worker_id=worker_id, curriculum=curriculum_file)
print(str(env))
brain_name = env.brain_names[0]

def get_progress():
if use_curriculum:
if curriculum_file is not None:
if env._curriculum.measure_type == "progress":
return steps / max_steps
elif env._curriculum.measure_type == "reward":

trainer.update_model(batch_size, num_epoch)
if steps % summary_freq == 0 and steps != 0 and train_model:
# Write training statistics to tensorboard.
trainer.write_summary(summary_writer, steps)
trainer.write_summary(summary_writer, steps, env._curriculum.lesson_number)
if steps % save_freq == 0 and steps != 0 and train_model:
# Save Tensorflow model
save_model(sess, model_path=model_path, steps=steps, saver=saver)

正在加载...
取消
保存