
Add curriculum support to PPO

/develop-generalizationTraining-TrainerController
Arthur Juliani, 7 years ago
Current commit
b6ce30bf
5 files changed, 62 insertions(+), 27 deletions(-)
  1. python/curriculum.json (4 changes)
  2. python/ppo.py (36 changes)
  3. python/ppo/models.py (9 changes)
  4. python/unityagents/curriculum.py (32 changes)
  5. python/unityagents/environment.py (8 changes)

python/curriculum.json (4 changes)


 {
-    "measure" : "progress",
-    "thresholds" : [0.1, 0.2, 0.5],
+    "measure" : "reward",
+    "thresholds" : [10, 20, 50],
     "min_lesson_length" : 3,
     "signal_smoothing" : true,
     "parameters" :

python/ppo.py (36 changes)


 Options:
   --help                    Show this message.
   --max-steps=<n>           Maximum number of steps to run environment [default: 1e6].
+  --curriculum              Whether to use curriculum for training (requires curriculum json) [default: False]
+  --curriculum-path=<path>  Path to curriculum json file for environment [default: curriculum.json]
-  --train                   Whether to train model, or only run inference [default: True].
+  --train                   Whether to train model, or only run inference [default: False].
   --summary-freq=<n>        Frequency at which to save training statistics [default: 10000].
   --save-freq=<n>           Frequency at which to save model [default: 50000].
   --gamma=<n>               Reward discount rate [default: 0.99].

 env_name = options['<env>']
 keep_checkpoints = int(options['--keep-checkpoints'])
 worker_id = int(options['--worker-id'])
+use_curriculum = options['--curriculum']
+if use_curriculum:
+    curriculum_path = str(options['--curriculum-path'])
+else:
+    curriculum_path = None
 # Algorithm-specific parameters for tuning
 gamma = float(options['--gamma'])

 hidden_units = int(options['--hidden-units'])
 batch_size = int(options['--batch-size'])

-env = UnityEnvironment(file_name=env_name, worker_id=worker_id)
+env = UnityEnvironment(file_name=env_name, worker_id=worker_id, curriculum=curriculum_path)
 print(str(env))
 brain_name = env.brain_names[0]

 init = tf.global_variables_initializer()
 saver = tf.train.Saver(max_to_keep=keep_checkpoints)

+def get_progress():
+    if use_curriculum:
+        if env._curriculum.measure_type == "progress":
+            return steps / max_steps
+        elif env._curriculum.measure_type == "reward":
+            return last_reward
+        else:
+            return None
+    else:
+        return None
 with tf.Session() as sess:
     # Instantiate model parameters
     if load_model:

     else:
         sess.run(init)
-    steps = sess.run(ppo_model.global_step)
+    steps, last_reward = sess.run([ppo_model.global_step, ppo_model.last_reward])
-    info = env.reset(train_mode=train_model)[brain_name]
+    info = env.reset(train_mode=train_model, progress=get_progress())[brain_name]
-            info = env.reset(train_mode=train_model)[brain_name]
+            info = env.reset(train_mode=train_model, progress=get_progress())[brain_name]
         # Decide and take an action
         new_info = trainer.take_action(info, env, brain_name)
         info = new_info

             save_model(sess, model_path=model_path, steps=steps, saver=saver)
         steps += 1
         sess.run(ppo_model.increment_step)
+        if len(trainer.stats['cumulative_reward']) > 0:
+            mean_reward = np.mean(trainer.stats['cumulative_reward'])
+            print(mean_reward)
+            sess.run(ppo_model.update_reward, feed_dict={ppo_model.new_reward: mean_reward})
+            last_reward = sess.run(ppo_model.last_reward)
     # Final save Tensorflow model
     if steps != 0 and train_model:
         save_model(sess, model_path=model_path, steps=steps, saver=saver)
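For context, the progress value handed to reset() is just a scalar: the fraction of training completed when the curriculum measure is "progress", or the most recent mean cumulative reward when it is "reward". A minimal sketch of that selection logic, using a stand-in curriculum object so it runs without TensorFlow or Unity (FakeCurriculum is hypothetical):

    class FakeCurriculum:
        # Stand-in for the curriculum object attached to the environment.
        def __init__(self, measure_type):
            self.measure_type = measure_type

    def get_progress(curriculum, steps, max_steps, last_reward):
        # Mirrors the selection in ppo.py: progress is either a completion
        # fraction or the last recorded mean reward, depending on the measure.
        if curriculum is None:
            return None
        if curriculum.measure_type == "progress":
            return steps / max_steps
        elif curriculum.measure_type == "reward":
            return last_reward
        return None

    print(get_progress(FakeCurriculum("progress"), 250000, 1000000, 12.5))  # 0.25
    print(get_progress(FakeCurriculum("reward"), 250000, 1000000, 12.5))    # 12.5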

python/ppo/models.py (9 changes)


 class PPOModel(object):
+    def create_reward_encoder(self):
+        self.last_reward = tf.Variable(0, name="last_reward", trainable=False, dtype=tf.float32)
+        self.new_reward = tf.placeholder(shape=[], dtype=tf.float32, name='new_reward')
+        self.update_reward = tf.assign(self.last_reward, self.new_reward)
+
     def create_visual_encoder(self, o_size_h, o_size_w, bw, h_size, num_streams, activation):
         """
         Builds a set of visual (CNN) encoders.

         s_size = brain.state_space_size
         a_size = brain.action_space_size
+        self.create_reward_encoder()
         hidden_state, hidden_visual, hidden_policy, hidden_value = None, None, None, None
         if brain.number_observations > 0:
             h_size, w_size = brain.camera_resolutions[0]['height'], brain.camera_resolutions[0]['width']

         :param brain: State-space size
         :param h_size: Hidden layer size
         """
+        self.create_reward_encoder()
         hidden_state, hidden_visual, hidden = None, None, None
         if brain.number_observations > 0:
             h_size, w_size = brain.camera_resolutions[0]['height'], brain.camera_resolutions[0]['width']
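create_reward_encoder() follows a common TF1 pattern: a non-trainable variable holds the latest mean reward (so it is saved and restored with checkpoints), and an assign op driven by a placeholder overwrites it from the training loop. A minimal standalone sketch of that pattern, assuming a TensorFlow 1.x session API:

    import tensorflow as tf  # assumes TensorFlow 1.x

    last_reward = tf.Variable(0, name="last_reward", trainable=False, dtype=tf.float32)
    new_reward = tf.placeholder(shape=[], dtype=tf.float32, name="new_reward")
    update_reward = tf.assign(last_reward, new_reward)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Push a new mean reward into the graph, then read it back.
        sess.run(update_reward, feed_dict={new_reward: 12.5})
        print(sess.run(last_reward))  # 12.5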

python/unityagents/curriculum.py (32 changes)


 from .exception import UnityEnvironmentException

-        if location == None:
+        if location is None:
-            except:
+            except FileNotFoundError:
                     "The file {0} could not be found.".format(location))
             except UnicodeDecodeError:
                 raise UnityEnvironmentException("There was an error decoding {}".format(location))
                         'min_lesson_length', 'signal_smoothing']:
                                                 "{1} field.".format(location, key))
         parameters = self.data['parameters']
         self.measure_type = self.data['measure']
         self.max_lesson_number = len(self.data['thresholds'])

             if len(parameters[key]) != self.max_lesson_number + 1:
                 raise UnityEnvironmentException(
                     "The parameter {0} in Curriculum {1} must have {2} values "
                     "but {3} were found".format(key, location,
                                                 self.max_lesson_number + 1, len(parameters[key])))

     @property
     def measure(self):

     def set_lesson_number(self, value):
         self.lesson_length = 0
-        self.lesson_number = max(0,min(value,self.max_lesson_number))
+        self.lesson_number = max(0, min(value, self.max_lesson_number))

-        if (self.data == None ) or (progress == None):
+        if self.data is None or progress is None:
-            progress = self.smoothing_value*0.9 + 0.1*progress
+            progress = self.smoothing_value * 0.9 + 0.1 * progress
-        if ((progress > self.data['thresholds'][self.lesson_number])
-                and (self.lesson_length > self.data['min_lesson_length'])):
+        if ((progress > self.data['thresholds'][self.lesson_number]) and
+                (self.lesson_length > self.data['min_lesson_length'])):
             self.lesson_length = 0
             self.lesson_number += 1

         config = {}

         return config
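The lesson-advancement rule visible above is small enough to restate on its own: optionally smooth the incoming progress signal, then advance one lesson once progress exceeds the current threshold and the lesson has run for at least min_lesson_length updates. A rough standalone sketch under those assumptions (not the Curriculum class itself; the state dict and the lesson-cap guard are illustrative):

    def increment_lesson(state, data, progress):
        # state: dict with 'lesson_number', 'lesson_length', 'smoothing_value'
        # data: parsed curriculum.json contents
        if data is None or progress is None:
            return state
        if data["signal_smoothing"]:
            # Exponential smoothing of the raw progress/reward signal.
            progress = state["smoothing_value"] * 0.9 + 0.1 * progress
            state["smoothing_value"] = progress
        state["lesson_length"] += 1
        max_lesson_number = len(data["thresholds"])
        if state["lesson_number"] < max_lesson_number:
            if (progress > data["thresholds"][state["lesson_number"]] and
                    state["lesson_length"] > data["min_lesson_length"]):
                state["lesson_length"] = 0
                state["lesson_number"] += 1
        return state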

python/unityagents/environment.py (8 changes)


         state_dict = json.loads(state)
         return state_dict

-    def reset(self, train_mode=True, config=None, progress = None):
+    def reset(self, train_mode=True, config=None, progress=None):
         """
         Sends a signal to reset the unity environment.
         :return: A Data structure corresponding to the initial reset state of the environment.

         if old_lesson != self._curriculum.get_lesson_number():
-            logger.info("\nLesson changed. Now in Lesson {0} : \n\t{1}"
+            logger.info("\nLesson changed. Now in Lesson {0} : \t{1}"
                         .format(self._curriculum.get_lesson_number(),
                                 ', '.join([str(x)+' -> '+str(config[x]) for x in config])))
         else:
             logger.info("\nEpisode Reset. In Lesson {0} : \t{1}"
                         .format(self._curriculum.get_lesson_number(),
                                 ', '.join([str(x)+' -> '+str(config[x]) for x in config])))
         if self._loaded:
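From the caller's side, the only visible change is the extra keyword argument on reset(). A hedged usage sketch; the environment file name and the progress value are placeholders, not values from this commit:

    from unityagents import UnityEnvironment

    env = UnityEnvironment(file_name="my_env", worker_id=0, curriculum="curriculum.json")
    brain_name = env.brain_names[0]

    # progress is whatever scalar the trainer reports: a completion fraction or
    # a recent mean reward, depending on the curriculum's measure.
    info = env.reset(train_mode=True, progress=0.25)[brain_name]
    env.close()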
