Browse Source

Multi Brain Training and Recurrent state encoder (#166)

* `learn.py` is now the main script for training brains.
* Simultaneous multi-brain training is now possible.
* `ghost-trainer` allows for proper training in adversarial scenarios.
* `imitation-trainer` provides a basic implementation of real-time behavioral cloning.
* All trainer hyperparameters now exist in `.yaml` files.
* `PPO.ipynb` removed.
* LSTM model added.
* More dynamic buffer class to handle a greater variety of scenarios.

develop-generalizationTraining-TrainerController
Arthur Juliani
7 years ago
Current commit: de700c3a
21 files changed, with 2214 insertions and 439 deletions
Changed files (number of lines changed):
-   3  python/requirements.txt
- 159  python/test_unityagents.py
-  40  python/unityagents/curriculum.py
-  48  python/unityagents/environment.py
-  12  python/unityagents/exception.py
-   1  unity-environment/Assets/ML-Agents/Scripts/Agent.cs
-  15  unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs
-  20  unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs
- 223  python/learn.py
-  46  python/trainer_configurations.yaml
- 227  python/trainers/buffer.py
- 204  python/trainers/ghost_trainer.py
- 315  python/trainers/imitation_trainer.py
- 381  python/trainers/ppo_models.py
- 403  python/trainers/ppo_trainer.py
- 159  python/trainers/trainer.py
- 237  python/PPO.ipynb
- 160  python/ppo.py
-   0  python/trainers/__init__.py
python/learn.py

# # Unity ML Agents
# ## ML-Agent Learning (PPO)
# Launches trainers for each External Brain in a Unity Environment.

import logging
import os
import re
import yaml

from docopt import docopt

from trainers.ghost_trainer import GhostTrainer
from trainers.ppo_models import *
from trainers.ppo_trainer import PPOTrainer
from trainers.imitation_trainer import ImitationTrainer
from unityagents import UnityEnvironment, UnityEnvironmentException


def get_progress():
    if curriculum_file is not None:
        if env.curriculum.measure_type == "progress":
            progress = 0
            for brain_name in env.external_brain_names:
                progress += trainers[brain_name].get_step / trainers[brain_name].get_max_steps
            return progress / len(env.external_brain_names)
        elif env.curriculum.measure_type == "reward":
            progress = 0
            for brain_name in env.external_brain_names:
                progress += trainers[brain_name].get_last_reward
            return progress
        else:
            return None
    else:
        return None

if __name__ == '__main__':
    logger = logging.getLogger("unityagents")
    _USAGE = '''
    Usage:
      ppo (<env>) [options]

    Options:
      --help                   Show this message.
      --curriculum=<file>      Curriculum json file for environment [default: None].
      --slow                   Whether to run the game at training speed [default: False].
      --keep-checkpoints=<n>   How many model checkpoints to keep [default: 5].
      --lesson=<n>             Start learning from this lesson [default: 0].
      --load                   Whether to load the model or randomly initialize [default: False].
      --run-path=<path>        The sub-directory name for model and summary statistics [default: ppo].
      --save-freq=<n>          Frequency at which to save model [default: 50000].
      --train                  Whether to train model, or only run inference [default: False].
      --worker-id=<n>          Number to add to communication port (5005). Used for multi-environment [default: 0].
    '''

    options = docopt(_USAGE)
    logger.info(options)

    # General parameters
    model_path = './models/{}'.format(str(options['--run-path']))

    load_model = options['--load']
    train_model = options['--train']
    save_freq = int(options['--save-freq'])
    env_name = options['<env>']
    keep_checkpoints = int(options['--keep-checkpoints'])
    worker_id = int(options['--worker-id'])
    curriculum_file = str(options['--curriculum'])
    if curriculum_file == "None":
        curriculum_file = None
    lesson = int(options['--lesson'])
    fast_simulation = not bool(options['--slow'])

    env = UnityEnvironment(file_name=env_name, worker_id=worker_id, curriculum=curriculum_file)
    env.curriculum.set_lesson_number(lesson)
    logger.info(str(env))

    tf.reset_default_graph()

    try:
        if not os.path.exists(model_path):
            os.makedirs(model_path)
    except Exception:
        raise UnityEnvironmentException("The folder {} containing the generated model could not be accessed."
                                        " Please make sure the permissions are set correctly.".format(model_path))

    try:
        with open("trainer_configurations.yaml") as data_file:
            trainer_configurations = yaml.load(data_file)
    except IOError:
        raise UnityEnvironmentException("The file {} could not be found. Will use default hyperparameters."
                                        .format("trainer_configurations.yaml"))
    except UnicodeDecodeError:
        raise UnityEnvironmentException("There was an error decoding {}".format("trainer_configurations.yaml"))

    with tf.Session() as sess:
        trainers = {}
        trainer_parameters_dict = {}
        for brain_name in env.external_brain_names:
            trainer_parameters = trainer_configurations['default'].copy()
            if len(env.external_brain_names) > 1:
                graph_scope = re.sub('[^0-9a-zA-Z]+', '-', brain_name)
                trainer_parameters['graph_scope'] = graph_scope
                trainer_parameters['summary_path'] = './summaries/{}'.format(str(options['--run-path'])) + '_' + graph_scope
            else:
                trainer_parameters['graph_scope'] = ''
                trainer_parameters['summary_path'] = './summaries/{}'.format(str(options['--run-path']))
            if brain_name in trainer_configurations:
                _brain_key = brain_name
                while not isinstance(trainer_configurations[_brain_key], dict):
                    _brain_key = trainer_configurations[_brain_key]
                for k in trainer_configurations[_brain_key]:
                    trainer_parameters[k] = trainer_configurations[_brain_key][k]
            trainer_parameters_dict[brain_name] = trainer_parameters.copy()
        for brain_name in env.external_brain_names:
            if 'is_ghost' not in trainer_parameters_dict[brain_name]:
                trainer_parameters_dict[brain_name]['is_ghost'] = False
            if 'is_imitation' not in trainer_parameters_dict[brain_name]:
                trainer_parameters_dict[brain_name]['is_imitation'] = False
            if trainer_parameters_dict[brain_name]['is_ghost']:
                if trainer_parameters_dict[brain_name]['brain_to_copy'] not in env.external_brain_names:
                    raise UnityEnvironmentException("The external brain {0} could not be found in the environment "
                                                    "even though the ghost trainer of brain {1} is trying to ghost it."
                                                    .format(trainer_parameters_dict[brain_name]['brain_to_copy'], brain_name))
                trainer_parameters_dict[brain_name]['original_brain_parameters'] = trainer_parameters_dict[
                    trainer_parameters_dict[brain_name]['brain_to_copy']]
                trainers[brain_name] = GhostTrainer(sess, env, brain_name, trainer_parameters_dict[brain_name], train_model)
            elif trainer_parameters_dict[brain_name]['is_imitation']:
                trainers[brain_name] = ImitationTrainer(sess, env, brain_name, trainer_parameters_dict[brain_name], train_model)
            else:
                trainers[brain_name] = PPOTrainer(sess, env, brain_name, trainer_parameters_dict[brain_name], train_model)

        for k, t in trainers.items():
            logger.info(t)
        init = tf.global_variables_initializer()
        saver = tf.train.Saver(max_to_keep=keep_checkpoints)
        # Instantiate model parameters
        if load_model:
            logger.info('Loading Model...')
            ckpt = tf.train.get_checkpoint_state(model_path)
            if ckpt is None:
                logger.info('The model {0} could not be found. Make sure you specified the right '
                            '--run-path'.format(model_path))
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sess.run(init)
        global_step = 0  # This is only for saving the model
        env.curriculum.increment_lesson(get_progress())
        info = env.reset(train_mode=fast_simulation)
        if train_model:
            for brain_name, trainer in trainers.items():
                trainer.write_tensorboard_text('Hyperparameters', trainer.parameters)
        try:
            while any([t.get_step <= t.get_max_steps for k, t in trainers.items()]) or not train_model:
                if env.global_done:
                    env.curriculum.increment_lesson(get_progress())
                    info = env.reset(train_mode=fast_simulation)
                    for brain_name, trainer in trainers.items():
                        trainer.end_episode()
                # Decide and take an action
                take_action_actions = {}
                take_action_memories = {}
                take_action_values = {}
                take_action_outputs = {}
                for brain_name, trainer in trainers.items():
                    (take_action_actions[brain_name],
                     take_action_memories[brain_name],
                     take_action_values[brain_name],
                     take_action_outputs[brain_name]) = trainer.take_action(info)
                new_info = env.step(action=take_action_actions, memory=take_action_memories, value=take_action_values)
                for brain_name, trainer in trainers.items():
                    trainer.add_experiences(info, new_info, take_action_outputs[brain_name])
                info = new_info
                for brain_name, trainer in trainers.items():
                    trainer.process_experiences(info)
                    if trainer.is_ready_update() and train_model and trainer.get_step <= trainer.get_max_steps:
                        # Perform gradient descent with experience buffer
                        trainer.update_model()
                        # Write training statistics to Tensorboard.
                        trainer.write_summary(env.curriculum.lesson_number)
                    if train_model and trainer.get_step <= trainer.get_max_steps:
                        trainer.increment_step()
                        trainer.update_last_reward()
                if train_model and trainer.get_step <= trainer.get_max_steps:
                    global_step += 1
                if global_step % save_freq == 0 and global_step != 0 and train_model:
                    # Save Tensorflow model
                    save_model(sess, model_path=model_path, steps=global_step, saver=saver)

            # Final save Tensorflow model
            if global_step != 0 and train_model:
                save_model(sess, model_path=model_path, steps=global_step, saver=saver)
        except KeyboardInterrupt:
            if train_model:
                logger.info("Learning was interrupted. Please wait while the graph is generated.")
                save_model(sess, model_path=model_path, steps=global_step, saver=saver)
            pass
        env.close()
        if train_model:
            graph_name = (env_name.strip()
                          .replace('.app', '').replace('.exe', '').replace('.x86_64', '').replace('.x86', ''))
            graph_name = os.path.basename(os.path.normpath(graph_name))
            nodes = []
            scopes = []
            for brain_name in trainers.keys():
                if trainers[brain_name].graph_scope is not None:
                    scope = trainers[brain_name].graph_scope + '/'
                    if scope == '/':
                        scope = ''
                    scopes += [scope]
                    if trainers[brain_name].parameters["is_imitation"]:
                        nodes += [scope + x for x in ["action"]]
                    elif not trainers[brain_name].parameters["use_recurrent"]:
                        nodes += [scope + x for x in ["action", "value_estimate", "action_probs"]]
                    else:
                        nodes += [scope + x for x in ["action", "value_estimate", "action_probs", "recurrent_out"]]
            export_graph(model_path, graph_name, target_nodes=','.join(nodes))
            if len(scopes) > 1:
                logger.info("List of available scopes :")
                for scope in scopes:
                    logger.info("\t" + scope)
            logger.info("List of nodes exported :")
            for n in nodes:
                logger.info("\t" + n)

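As a quick illustration of the command-line interface above, the snippet below (not part of the commit) shows how the docopt usage string turns an argument list into the options dictionary that learn.py reads; the environment name "3DBall" and run path "first-run" are made-up examples, and only a subset of the options is reproduced.

# Illustration only: docopt parsing as used by learn.py.
from docopt import docopt

_USAGE = '''
Usage:
  ppo (<env>) [options]

Options:
  --train                  Whether to train model, or only run inference [default: False].
  --run-path=<path>        The sub-directory name for model and summary statistics [default: ppo].
  --worker-id=<n>          Number to add to communication port (5005). Used for multi-environment [default: 0].
'''

options = docopt(_USAGE, argv=['3DBall', '--train', '--run-path=first-run'])
print(options['<env>'])        # '3DBall'
print(options['--train'])      # True
print(options['--run-path'])   # 'first-run'
print(options['--worker-id'])  # '0'  (docopt returns strings; learn.py casts with int())
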

python/trainer_configurations.yaml

default:
  batch_size: 64
  beta: 2.5e-3
  buffer_size: 2048
  epsilon: 0.2
  gamma: 0.99
  hidden_units: 64
  lambd: 0.95
  learning_rate: 3.0e-4
  max_steps: 1.0e6
  normalize: false
  num_epoch: 5
  num_layers: 2
  time_horizon: 2048
  sequence_length: 32
  summary_freq: 10000
  use_recurrent: false

Ball3DBrain:
  summary_freq: 1000
  normalize: true
  batch_size: 1024
  max_steps: 1.0e4

TurretBrain: ExampleBrain

ghost-HunterBrain:
  brain_to_copy: HunterBrain
  is_ghost: true
  new_model_freq: 10000
  max_num_models: 20

ghost-HunteeBrain:
  brain_to_copy: HunteeBrain
  is_ghost: true
  new_model_freq: 10000
  max_num_models: 20

ghost-Ball3DBrain:
  brain_to_copy: Ball3DBrain
  is_ghost: true
  new_model_freq: 10000
  max_num_models: 3

Player:
  # is_imitation : true
  max_steps: 10000
  summary_freq: 1000

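A minimal sketch (not part of the commit) of how learn.py resolves the configuration above for one brain: start from the `default` section, follow string aliases such as `TurretBrain: ExampleBrain` until a mapping is reached, then overlay the brain-specific overrides. The helper name resolve_trainer_config is hypothetical, yaml.safe_load is used instead of the commit's yaml.load, and the script is assumed to run from the python/ directory so the file is found.

# Illustration only: configuration resolution mirroring the loop in learn.py.
import yaml

def resolve_trainer_config(configs, brain_name):
    params = configs['default'].copy()
    if brain_name in configs:
        key = brain_name
        # A plain string value is treated as an alias to another entry.
        while not isinstance(configs[key], dict):
            key = configs[key]
        params.update(configs[key])
    return params

with open("trainer_configurations.yaml") as f:
    configs = yaml.safe_load(f)
print(resolve_trainer_config(configs, "Ball3DBrain")["batch_size"])  # 1024 (override wins over default 64)
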

python/trainers/buffer.py

import numpy as np

from unityagents.exception import UnityException


class BufferException(UnityException):
    """
    Related to errors with the Buffer.
    """
    pass


class Buffer(dict):
    """
    Buffer contains a dictionary of AgentBuffers. The AgentBuffers are indexed by agent_id.
    Buffer also contains an update_buffer that corresponds to the buffer used when updating the model.
    """

    class AgentBuffer(dict):
        """
        AgentBuffer contains a dictionary of AgentBufferFields. Each agent has its own AgentBuffer.
        The keys correspond to the name of the field. Example: state, action
        """

        class AgentBufferField(list):
            """
            AgentBufferField is a list of numpy arrays. When an agent collects a field, you can add it to its
            AgentBufferField with the append method.
            """

            def __str__(self):
                return str(np.array(self).shape)

            def extend(self, data):
                """
                Adds a list of np.arrays to the end of the list of np.arrays.
                :param data: The np.array list to append.
                """
                self += list(np.array(data))

            def set(self, data):
                """
                Sets the list of np.array to the input data
                :param data: The np.array list to be set.
                """
                self[:] = []
                self[:] = list(np.array(data))

            def get_batch(self, batch_size=None, training_length=None, sequential=True):
                """
                Retrieve the last batch_size elements of length training_length
                from the list of np.array
                :param batch_size: The number of elements to retrieve. If None:
                All elements will be retrieved.
                :param training_length: The length of the sequence to be retrieved. If
                None: only takes one element.
                :param sequential: If true and training_length is not None: the elements
                will not repeat in the sequence. [a,b,c,d,e] with training_length = 2 and
                sequential=True gives [[0,a],[b,c],[d,e]]. If sequential=False gives
                [[a,b],[b,c],[c,d],[d,e]]
                """
                if training_length is None:
                    # When the training length is None, the method returns a list of elements,
                    # not a list of sequences of elements.
                    if batch_size is None:
                        # If batch_size is None: all the elements of the AgentBufferField are returned.
                        return np.array(self)
                    else:
                        # Return the batch_size last elements.
                        if batch_size > len(self):
                            raise BufferException("Batch size requested is too large")
                        return np.array(self[-batch_size:])
                else:
                    # The training_length is not None, the method returns a list of SEQUENCES of elements.
                    if not sequential:
                        # The sequences will have overlapping elements.
                        if batch_size is None:
                            # Retrieve the maximum number of elements.
                            batch_size = len(self) - training_length + 1
                        # The number of sequences of length training_length taken from a list of len(self) elements
                        # with overlapping is equal to batch_size.
                        if (len(self) - training_length + 1) < batch_size:
                            raise BufferException("The batch size and training length requested for get_batch were"
                                                  " too large given the current number of data points.")
                        tmp_list = []
                        for end in range(len(self) - batch_size + 1, len(self) + 1):
                            tmp_list += [np.array(self[end - training_length:end])]
                        return np.array(tmp_list)
                    if sequential:
                        # The sequences will not have overlapping elements (this involves padding).
                        leftover = len(self) % training_length
                        # leftover is the number of elements in the first sequence (this sequence might need 0 padding).
                        if batch_size is None:
                            # Retrieve the maximum number of elements.
                            batch_size = len(self) // training_length + 1 * (leftover != 0)
                        # The maximum number of sequences taken from a list of length len(self) without overlapping
                        # with padding is equal to batch_size.
                        if batch_size > (len(self) // training_length + 1 * (leftover != 0)):
                            raise BufferException("The batch size and training length requested for get_batch were"
                                                  " too large given the current number of data points.")
                        tmp_list = []
                        padding = np.array(self[-1]) * 0
                        # The padding is made with zeros and its shape is given by the shape of the last element.
                        for end in range(len(self), len(self) % training_length, -training_length)[:batch_size]:
                            tmp_list += [np.array(self[end - training_length:end])]
                        if (leftover != 0) and (len(tmp_list) < batch_size):
                            tmp_list += [np.array([padding] * (training_length - leftover) + self[:leftover])]
                        tmp_list.reverse()
                        return np.array(tmp_list)

            def reset_field(self):
                """
                Resets the AgentBufferField
                """
                self[:] = []

        def __str__(self):
            return ", ".join(["'{0}' : {1}".format(k, str(self[k])) for k in self.keys()])

        def reset_agent(self):
            """
            Resets the AgentBuffer
            """
            for k in self.keys():
                self[k].reset_field()

        def __getitem__(self, key):
            if key not in self.keys():
                self[key] = self.AgentBufferField()
            return super(Buffer.AgentBuffer, self).__getitem__(key)

        def check_length(self, key_list):
            """
            Some methods will require that some fields have the same length.
            check_length will return true if the fields in key_list
            have the same length.
            :param key_list: The fields whose lengths will be compared.
            """
            if len(key_list) < 2:
                return True
            length = None
            for key in key_list:
                if key not in self.keys():
                    return False
                if (length is not None) and (length != len(self[key])):
                    return False
                length = len(self[key])
            return True

        def shuffle(self, key_list=None):
            """
            Shuffles the fields in key_list in a consistent way: The reordering will
            be the same across fields.
            :param key_list: The fields that must be shuffled.
            """
            if key_list is None:
                key_list = list(self.keys())
            if not self.check_length(key_list):
                raise BufferException("Unable to shuffle if the fields are not of same length")
            s = np.arange(len(self[key_list[0]]))
            np.random.shuffle(s)
            for key in key_list:
                self[key][:] = [self[key][i] for i in s]

    def __init__(self):
        self.update_buffer = self.AgentBuffer()
        super(Buffer, self).__init__()

    def __str__(self):
        return "update buffer :\n\t{0}\nlocal_buffers :\n{1}".format(str(self.update_buffer),
                                                                     '\n'.join(['\tagent {0} :{1}'.format(k, str(self[k])) for k in self.keys()]))

    def __getitem__(self, key):
        if key not in self.keys():
            self[key] = self.AgentBuffer()
        return super(Buffer, self).__getitem__(key)

    def reset_update_buffer(self):
        """
        Resets the update buffer
        """
        self.update_buffer.reset_agent()

    def reset_all(self):
        """
        Resets the update buffer and all the local buffers
        """
        self.update_buffer.reset_agent()
        agent_ids = list(self.keys())
        for k in agent_ids:
            self[k].reset_agent()

    def append_update_buffer(self, agent_id, key_list=None, batch_size=None, training_length=None):
        """
        Appends the buffer of an agent to the update buffer.
        :param agent_id: The id of the agent whose data will be appended.
        :param key_list: The fields that must be added. If None: all fields will be appended.
        :param batch_size: The number of elements that must be appended. If None: All of them will be.
        :param training_length: The length of the samples that must be appended. If None: only takes one element.
        """
        if key_list is None:
            key_list = self[agent_id].keys()
        if not self[agent_id].check_length(key_list):
            raise BufferException("The fields {0} for agent {1} were not of the same length"
                                  .format(key_list, agent_id))
        for field_key in key_list:
            self.update_buffer[field_key].extend(
                self[agent_id][field_key].get_batch(batch_size=batch_size, training_length=training_length)
            )

    def append_all_agent_batch_to_update_buffer(self, key_list=None, batch_size=None, training_length=None):
        """
        Appends the buffer of all agents to the update buffer.
        :param key_list: The fields that must be added. If None: all fields will be appended.
        :param batch_size: The number of elements that must be appended. If None: All of them will be.
        :param training_length: The length of the samples that must be appended. If None: only takes one element.
        """
        for agent_id in self.keys():
            self.append_update_buffer(agent_id, key_list, batch_size, training_length)

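A small worked example (not part of the commit) of the two get_batch modes described in the docstring above, assuming the commit's python/ directory is on the import path so the trainers package resolves:

# Illustration only: AgentBufferField.get_batch with and without overlapping sequences.
from trainers.buffer import Buffer

b = Buffer()
field = b[0]['states']        # creates an AgentBufferField for agent 0
field.set([1, 2, 3, 4, 5])

# Non-overlapping sequences of length 2; the first sequence is zero-padded:
print(field.get_batch(training_length=2, sequential=True))
# -> [[0 1] [2 3] [4 5]]

# The last 3 overlapping windows of length 2:
print(field.get_batch(batch_size=3, training_length=2, sequential=False))
# -> [[2 3] [3 4] [4 5]]
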

python/trainers/ghost_trainer.py

import logging
import os

import numpy as np
import tensorflow as tf

from trainers.ppo_models import *
from trainers.trainer import UnityTrainerException, Trainer

logger = logging.getLogger("unityagents")


# This works only with PPO.
class GhostTrainer(Trainer):
    """Keeps copies of a PPOTrainer's past graphs and uses them to provide fixed opponents in adversarial scenarios."""

    def __init__(self, sess, env, brain_name, trainer_parameters, training):
        """
        Responsible for saving and reusing past models.
        :param sess: Tensorflow session.
        :param env: The UnityEnvironment.
        :param trainer_parameters: The parameters for the trainer (dictionary).
        :param training: Whether the trainer is set for training.
        """
        self.param_keys = ['brain_to_copy', 'is_ghost', 'new_model_freq', 'max_num_models']
        for k in self.param_keys:
            if k not in trainer_parameters:
                raise UnityTrainerException("The hyperparameter {0} could not be found for the PPO trainer of "
                                            "brain {1}.".format(k, brain_name))

        super(GhostTrainer, self).__init__(sess, env, brain_name, trainer_parameters, training)

        self.brain_to_copy = trainer_parameters['brain_to_copy']
        self.variable_scope = trainer_parameters['graph_scope']
        self.original_brain_parameters = trainer_parameters['original_brain_parameters']
        self.new_model_freq = trainer_parameters['new_model_freq']
        self.steps = 0
        self.models = []
        self.max_num_models = trainer_parameters['max_num_models']
        self.last_model_replaced = 0
        for i in range(self.max_num_models):
            with tf.variable_scope(self.variable_scope + '_' + str(i)):
                self.models += [create_agent_model(env.brains[self.brain_to_copy],
                                                   lr=float(self.original_brain_parameters['learning_rate']),
                                                   h_size=int(self.original_brain_parameters['hidden_units']),
                                                   epsilon=float(self.original_brain_parameters['epsilon']),
                                                   beta=float(self.original_brain_parameters['beta']),
                                                   max_step=float(self.original_brain_parameters['max_steps']),
                                                   normalize=self.original_brain_parameters['normalize'],
                                                   use_recurrent=self.original_brain_parameters['use_recurrent'],
                                                   num_layers=int(self.original_brain_parameters['num_layers']),
                                                   m_size=self.original_brain_parameters)]
        self.model = self.models[0]

        self.is_continuous = (env.brains[brain_name].action_space_type == "continuous")
        self.use_observations = (env.brains[brain_name].number_observations > 0)
        self.use_states = (env.brains[brain_name].state_space_size > 0)
        self.use_recurrent = self.original_brain_parameters["use_recurrent"]
        self.summary_path = trainer_parameters['summary_path']

    def __str__(self):
        return '''Hyperparameters for the Ghost Trainer of brain {0}: \n{1}'''.format(
            self.brain_name, '\n'.join(['\t{0}:\t{1}'.format(x, self.trainer_parameters[x]) for x in self.param_keys]))

    @property
    def parameters(self):
        """
        Returns the trainer parameters of the trainer.
        """
        return self.trainer_parameters

    @property
    def graph_scope(self):
        """
        Returns the graph scope of the trainer.
        """
        return None

    @property
    def get_max_steps(self):
        """
        Returns the maximum number of steps. Is used to know when the trainer should be stopped.
        :return: The maximum number of steps of the trainer
        """
        return 1

    @property
    def get_step(self):
        """
        Returns the number of steps the trainer has performed
        :return: the step count of the trainer
        """
        return 0

    @property
    def get_last_reward(self):
        """
        Returns the last reward the trainer has had
        :return: the new last reward
        """
        return 0

    def increment_step(self):
        """
        Increment the step count of the trainer
        """
        self.steps += 1

    def update_last_reward(self):
        """
        Updates the last reward
        """
        return

    def update_target_graph(self, from_scope, to_scope):
        from_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, from_scope)
        to_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, to_scope)
        op_holder = []
        for from_var, to_var in zip(from_vars, to_vars):
            op_holder.append(to_var.assign(from_var))
        return op_holder

    def take_action(self, info):
        """
        Decides actions given state/observation information, and takes them in environment.
        :param info: Current BrainInfo from environment.
        :return: a tuple containing action, memories, values and an object
        to be passed to add experiences
        """
        epsi = None
        info = info[self.brain_name]
        feed_dict = {self.model.batch_size: len(info.states), self.model.sequence_length: 1}
        run_list = [self.model.output]
        if self.is_continuous:
            epsi = np.random.randn(len(info.states), self.brain.action_space_size)
            feed_dict[self.model.epsilon] = epsi
        if self.use_observations:
            for i, _ in enumerate(info.observations):
                feed_dict[self.model.observation_in[i]] = info.observations[i]
        if self.use_states:
            feed_dict[self.model.state_in] = info.states
        if self.use_recurrent:
            feed_dict[self.model.memory_in] = info.memories
            run_list += [self.model.memory_out]
        if self.use_recurrent:
            actions, memories = self.sess.run(run_list, feed_dict=feed_dict)
        else:
            actions = self.sess.run(run_list, feed_dict=feed_dict)
            memories = None
        return (actions, memories, None, None)

    def add_experiences(self, info, next_info, take_action_outputs):
        """
        Adds experiences to each agent's experience history.
        :param info: Current BrainInfo.
        :param next_info: Next BrainInfo.
        :param take_action_outputs: The outputs of the take action method.
        """
        return

    def process_experiences(self, info):
        """
        Checks agent histories for processing condition, and processes them as necessary.
        Processing involves calculating value and advantage targets for model updating step.
        :param info: Current BrainInfo
        """
        return

    def end_episode(self):
        """
        A signal that the episode has ended. We must use another version of the graph.
        """
        self.model = self.models[np.random.randint(0, self.max_num_models)]

    def is_ready_update(self):
        """
        Returns whether or not the trainer has enough elements to run update model
        :return: A boolean corresponding to whether or not update_model() can be run
        """
        return self.steps % self.new_model_freq == 0

    def update_model(self):
        """
        Uses training_buffer to update model.
        """
        self.last_model_replaced = (self.last_model_replaced + 1) % self.max_num_models
        self.sess.run(self.update_target_graph(
            self.original_brain_parameters['graph_scope'],
            self.variable_scope + '_' + str(self.last_model_replaced))
        )
        return

    def write_summary(self, lesson_number):
        """
        Saves training statistics to Tensorboard.
        :param lesson_number: The lesson the trainer is at.
        """
        return

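The core of the ghost mechanism is update_target_graph, which copies the trainable variables of the live PPO scope into one of the frozen ghost scopes. A minimal standalone sketch of that copy (not part of the commit), assuming TensorFlow 1.x as used throughout this change and the hypothetical scope names "live" and "ghost_0":

# Illustration only: copying trainable variables between scopes, TF1-style.
import tensorflow as tf

with tf.variable_scope("live"):
    v_live = tf.get_variable("w", initializer=tf.ones([2]))
with tf.variable_scope("ghost_0"):
    v_ghost = tf.get_variable("w", initializer=tf.zeros([2]))

copy_ops = [to_var.assign(from_var) for from_var, to_var in zip(
    tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "live"),
    tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "ghost_0"))]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(copy_ops)
    print(sess.run(v_ghost))  # [1. 1.] -- the ghost scope now holds a snapshot of the live weights
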

python/trainers/imitation_trainer.py

# # Unity ML Agents
# ## ML-Agent Learning (Imitation)
# Contains a basic implementation of real-time behavioral cloning.

import logging
import os

import numpy as np
import tensorflow as tf

from trainers.buffer import Buffer
from trainers.ppo_models import *
from trainers.trainer import UnityTrainerException, Trainer

logger = logging.getLogger("unityagents")


class ImitationNN(object):
    def __init__(self, state_size, action_size, h_size, lr, action_type, n_layers):
        self.state = tf.placeholder(shape=[None, state_size], dtype=tf.float32, name="state")
        hidden = tf.layers.dense(self.state, h_size, activation=tf.nn.elu)
        for i in range(n_layers):
            hidden = tf.layers.dense(hidden, h_size, activation=tf.nn.elu)
        hidden_drop = tf.layers.dropout(hidden, 0.5)
        self.output = tf.layers.dense(hidden_drop, action_size, activation=None)

        if action_type == "discrete":
            self.action_probs = tf.nn.softmax(self.output)
            self.sample_action = tf.multinomial(self.output, 1, name="action")
            self.true_action = tf.placeholder(shape=[None], dtype=tf.int32)
            self.action_oh = tf.one_hot(self.true_action, action_size)
            self.loss = tf.reduce_sum(-tf.log(self.action_probs + 1e-10) * self.action_oh)

            self.action_percent = tf.reduce_mean(tf.cast(
                tf.equal(tf.cast(tf.argmax(self.action_probs, axis=1), tf.int32), self.sample_action), tf.float32))
        else:
            self.sample_action = tf.identity(self.output, name="action")
            self.true_action = tf.placeholder(shape=[None, action_size], dtype=tf.float32)
            self.loss = tf.reduce_sum(tf.squared_difference(self.true_action, self.sample_action))

        optimizer = tf.train.AdamOptimizer(learning_rate=lr)
        self.update = optimizer.minimize(self.loss)


class ImitationTrainer(Trainer):
    """The ImitationTrainer is a basic implementation of imitation learning (behavioral cloning)."""

    def __init__(self, sess, env, brain_name, trainer_parameters, training):
        """
        Responsible for collecting experiences and training the imitation model.
        :param sess: Tensorflow session.
        :param env: The UnityEnvironment.
        :param trainer_parameters: The parameters for the trainer (dictionary).
        :param training: Whether the trainer is set for training.
        """
        self.param_keys = ['is_imitation', 'brain_to_imitate', 'batch_size', 'time_horizon', 'graph_scope',
                           'summary_freq', 'max_steps']

        for k in self.param_keys:
            if k not in trainer_parameters:
                raise UnityTrainerException("The hyperparameter {0} could not be found for the Imitation trainer of "
                                            "brain {1}.".format(k, brain_name))

        super(ImitationTrainer, self).__init__(sess, env, brain_name, trainer_parameters, training)

        self.variable_scope = trainer_parameters['graph_scope']
        self.brain_to_imitate = trainer_parameters['brain_to_imitate']
        self.batch_size = trainer_parameters['batch_size']
        self.step = 0
        self.cumulative_rewards = {}
        self.episode_steps = {}

        self.stats = {'losses': [], 'episode_length': [], 'cumulative_reward': []}

        self.training_buffer = Buffer()
        self.is_continuous = (env.brains[brain_name].action_space_type == "continuous")
        self.use_observations = (env.brains[brain_name].number_observations > 0)
        if self.use_observations:
            logger.info('Cannot use observations with imitation learning')
        self.use_states = (env.brains[brain_name].state_space_size > 0)
        self.summary_path = trainer_parameters['summary_path']
        if not os.path.exists(self.summary_path):
            os.makedirs(self.summary_path)

        self.summary_writer = tf.summary.FileWriter(self.summary_path)
        s_size = self.brain.state_space_size * 1  # brain_parameters.stacked_states
        a_size = self.brain.action_space_size
        with tf.variable_scope(self.variable_scope):
            self.network = ImitationNN(state_size=s_size,
                                       action_size=a_size,
                                       h_size=int(trainer_parameters['hidden_units']),
                                       lr=float(trainer_parameters['learning_rate']),
                                       action_type=self.brain.action_space_type,
                                       n_layers=int(trainer_parameters['num_layers']))

    def __str__(self):
        return '''Hyperparameters for the Imitation Trainer of brain {0}: \n{1}'''.format(
            self.brain_name, '\n'.join(['\t{0}:\t{1}'.format(x, self.trainer_parameters[x]) for x in self.param_keys]))

    @property
    def parameters(self):
        """
        Returns the trainer parameters of the trainer.
        """
        return self.trainer_parameters

    @property
    def graph_scope(self):
        """
        Returns the graph scope of the trainer.
        """
        return self.variable_scope

    @property
    def get_max_steps(self):
        """
        Returns the maximum number of steps. Is used to know when the trainer should be stopped.
        :return: The maximum number of steps of the trainer
        """
        return self.trainer_parameters['max_steps']

    @property
    def get_step(self):
        """
        Returns the number of steps the trainer has performed
        :return: the step count of the trainer
        """
        return self.step

    @property
    def get_last_reward(self):
        """
        Returns the last reward the trainer has had
        :return: the new last reward
        """
        if len(self.stats['cumulative_reward']) > 0:
            return np.mean(self.stats['cumulative_reward'])
        else:
            return 0

    def increment_step(self):
        """
        Increment the step count of the trainer
        """
        self.step += 1

    def update_last_reward(self):
        """
        Updates the last reward
        """
        return

    def take_action(self, info):
        """
        Decides actions given state/observation information, and takes them in environment.
        :param info: Current BrainInfo from environment.
        :return: a tuple containing action, memories, values and an object
        to be passed to add experiences
        """
        E = info[self.brain_name]
        agent_action = self.sess.run(self.network.sample_action, feed_dict={self.network.state: E.states})

        return (agent_action, None, None, None)

    def add_experiences(self, info, next_info, take_action_outputs):
        """
        Adds experiences to each agent's experience history.
        :param info: Current BrainInfo.
        :param next_info: Next BrainInfo.
        :param take_action_outputs: The outputs of the take action method.
        """
        info_P = info[self.brain_to_imitate]
        next_info_P = next_info[self.brain_to_imitate]
        for agent_id in info_P.agents:
            if agent_id in next_info_P.agents:
                idx = info_P.agents.index(agent_id)
                next_idx = next_info_P.agents.index(agent_id)
                if not info_P.local_done[idx]:
                    self.training_buffer[agent_id]['states'].append(info_P.states[idx])
                    self.training_buffer[agent_id]['actions'].append(next_info_P.previous_actions[next_idx])
                    # self.training_buffer[agent_id]['rewards'].append(next_info.rewards[next_idx])

        info_E = next_info[self.brain_name]
        next_info_E = next_info[self.brain_name]
        for agent_id in info_E.agents:
            idx = info_E.agents.index(agent_id)
            next_idx = next_info_E.agents.index(agent_id)
            if not info_E.local_done[idx]:
                if agent_id not in self.cumulative_rewards:
                    self.cumulative_rewards[agent_id] = 0
                self.cumulative_rewards[agent_id] += next_info_E.rewards[next_idx]
                if agent_id not in self.episode_steps:
                    self.episode_steps[agent_id] = 0
                self.episode_steps[agent_id] += 1

    def process_experiences(self, info):
        """
        Checks agent histories for processing condition, and processes them as necessary.
        Processing involves calculating value and advantage targets for model updating step.
        :param info: Current BrainInfo
        """
        info_P = info[self.brain_to_imitate]
        for l in range(len(info_P.agents)):
            if ((info_P.local_done[l] or
                 len(self.training_buffer[info_P.agents[l]]['actions']) > self.trainer_parameters['time_horizon'])
                    and len(self.training_buffer[info_P.agents[l]]['actions']) > 0):
                agent_id = info_P.agents[l]
                self.training_buffer.append_update_buffer(agent_id,
                                                          batch_size=None, training_length=None)
                self.training_buffer[agent_id].reset_agent()

        info_E = info[self.brain_name]
        for l in range(len(info_E.agents)):
            if info_E.local_done[l]:
                agent_id = info_E.agents[l]
                self.stats['cumulative_reward'].append(self.cumulative_rewards[agent_id])
                self.stats['episode_length'].append(self.episode_steps[agent_id])
                self.cumulative_rewards[agent_id] = 0
                self.episode_steps[agent_id] = 0

    def end_episode(self):
        """
        A signal that the episode has ended. The buffer must be reset.
        Only gets called when the academy resets.
        """
        self.training_buffer.reset_all()
        for agent_id in self.cumulative_rewards:
            self.cumulative_rewards[agent_id] = 0
        for agent_id in self.episode_steps:
            self.episode_steps[agent_id] = 0

    def is_ready_update(self):
        """
        Returns whether or not the trainer has enough elements to run update model
        :return: A boolean corresponding to whether or not update_model() can be run
        """
        return len(self.training_buffer.update_buffer['actions']) > 1

    def update_model(self):
        """
        Uses training_buffer to update model.
        """
        # num_epoch = self.trainer_parameters['num_epoch']
        batch_size = self.trainer_parameters['batch_size']

        self.training_buffer.update_buffer.shuffle()
        batch_losses = []
        for j in range(len(self.training_buffer.update_buffer['actions']) // self.batch_size):
            _buffer = self.training_buffer.update_buffer
            # batch_states = shuffle_states[j * batch_size:(j + 1) * batch_size]
            batch_states = np.array(_buffer['states'][j * batch_size:(j + 1) * batch_size])
            # batch_actions = shuffle_actions[j * batch_size:(j + 1) * batch_size]
            batch_actions = np.array(_buffer['actions'][j * batch_size:(j + 1) * batch_size])
            if not self.is_continuous:
                feed_dict = {
                    self.network.state: batch_states.reshape([-1, 1]),
                    self.network.true_action: np.reshape(batch_actions, -1)
                }
            else:
                feed_dict = {
                    self.network.state: batch_states.reshape([self.batch_size, -1]),
                    self.network.true_action: batch_actions.reshape([self.batch_size, -1])
                }
            loss, _ = self.sess.run([self.network.loss, self.network.update], feed_dict=feed_dict)
            batch_losses.append(loss)
        if len(batch_losses) > 0:
            self.stats['losses'].append(np.mean(batch_losses))
        else:
            self.stats['losses'].append(0)

        self.training_buffer.reset_all()
        # Do we clear it at some point?
        # self.training_buffer.reset_update_buffer()

    def write_summary(self, lesson_number):
        """
        Saves training statistics to Tensorboard.
        :param lesson_number: The lesson the trainer is at.
        """
        if (self.get_step % self.trainer_parameters['summary_freq'] == 0 and self.get_step != 0 and
                self.is_training and self.get_step <= self.get_max_steps):
            steps = self.get_step
            if len(self.stats['cumulative_reward']) > 0:
                mean_reward = np.mean(self.stats['cumulative_reward'])
                logger.info("{0} : Step: {1}. Mean Reward: {2}. Std of Reward: {3}."
                            .format(self.brain_name, steps, mean_reward, np.std(self.stats['cumulative_reward'])))
            summary = tf.Summary()
            for key in self.stats:
                if len(self.stats[key]) > 0:
                    stat_mean = float(np.mean(self.stats[key]))
                    summary.value.add(tag='Info/{}'.format(key), simple_value=stat_mean)
                    self.stats[key] = []
            summary.value.add(tag='Info/Lesson', simple_value=lesson_number)
            self.summary_writer.add_summary(summary, steps)
            self.summary_writer.flush()

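A small, hypothetical smoke test of the discrete ImitationNN above (not part of the commit): random states and integer actions stand in for the recorded player demonstrations. It assumes TensorFlow 1.x and that the commit's python/ directory is on the import path.

# Illustration only: one behavioral-cloning update with toy data.
import numpy as np
import tensorflow as tf
from trainers.imitation_trainer import ImitationNN

tf.reset_default_graph()
net = ImitationNN(state_size=4, action_size=2, h_size=32, lr=1e-3,
                  action_type="discrete", n_layers=1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    states = np.random.randn(8, 4)               # 8 toy "teacher" states
    actions = np.random.randint(0, 2, size=8)    # 8 toy "teacher" actions
    loss, _ = sess.run([net.loss, net.update],
                       feed_dict={net.state: states, net.true_action: actions})
    print(loss)  # summed cross-entropy over the batch
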

python/trainers/ppo_models.py

import logging

import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as c_layers

from tensorflow.python.tools import freeze_graph
from unityagents import UnityEnvironmentException

logger = logging.getLogger("unityagents")


def create_agent_model(brain, lr=1e-4, h_size=128, epsilon=0.2, beta=1e-3, max_step=5e6,
                       normalize=False, use_recurrent=False, num_layers=2, m_size=None):
    """
    Takes a Unity environment and model-specific hyper-parameters and returns the
    appropriate PPO agent model for the environment.
    :param brain: The external brain to build a model for.
    :param lr: Learning rate.
    :param h_size: Size of hidden layers.
    :param epsilon: Value for policy-divergence threshold.
    :param beta: Strength of entropy regularization.
    :param max_step: Total number of training steps.
    :return: a sub-class of PPOAgent tailored to the environment.
    """
    if num_layers < 1:
        num_layers = 1

    if brain.action_space_type == "continuous":
        return ContinuousControlModel(lr, brain, h_size, epsilon, max_step, normalize, use_recurrent, num_layers, m_size)
    if brain.action_space_type == "discrete":
        return DiscreteControlModel(lr, brain, h_size, epsilon, beta, max_step, normalize, use_recurrent, num_layers, m_size)


def save_model(sess, saver, model_path="./", steps=0):
    """
    Saves current model to checkpoint folder.
    :param sess: Current Tensorflow session.
    :param saver: Tensorflow saver for session.
    :param model_path: Designated model path.
    :param steps: Current number of steps in training process.
    """
    last_checkpoint = model_path + '/model-' + str(steps) + '.cptk'
    saver.save(sess, last_checkpoint)
    tf.train.write_graph(sess.graph_def, model_path, 'raw_graph_def.pb', as_text=False)
    logger.info("Saved Model")


def export_graph(model_path, env_name="env", target_nodes="action,value_estimate,action_probs"):
    """
    Exports latest saved model to .bytes format for Unity embedding.
    :param model_path: path of model checkpoints.
    :param env_name: Name of associated Learning Environment.
    :param target_nodes: Comma separated string of needed output nodes for embedded graph.
    """
    ckpt = tf.train.get_checkpoint_state(model_path)
    freeze_graph.freeze_graph(input_graph=model_path + '/raw_graph_def.pb',
                              input_binary=True,
                              input_checkpoint=ckpt.model_checkpoint_path,
                              output_node_names=target_nodes,
                              output_graph=model_path + '/' + env_name + '.bytes',
                              clear_devices=True, initializer_nodes="", input_saver="",
                              restore_op_name="save/restore_all", filename_tensor_name="save/Const:0")


class PPOModel(object):
    def __init__(self):
        self.normalize = False
        self.use_recurrent = False
        self.observation_in = []

    def create_global_steps(self):
        """Creates TF ops to track and increment global training step."""
        self.global_step = tf.Variable(0, name="global_step", trainable=False, dtype=tf.int32)
        self.increment_step = tf.assign(self.global_step, self.global_step + 1)

    def create_reward_encoder(self):
        """Creates TF ops to track and increment recent average cumulative reward."""
        self.last_reward = tf.Variable(0, name="last_reward", trainable=False, dtype=tf.float32)
        self.new_reward = tf.placeholder(shape=[], dtype=tf.float32, name='new_reward')
        self.update_reward = tf.assign(self.last_reward, self.new_reward)

    def create_recurrent_encoder(self, s_size, input_state):
        """
        Builds a recurrent encoder for either state or observations (LSTM).
        :param s_size: Dimension of the input tensor.
        :param input_state: The input tensor to the LSTM cell.
        """
        self.lstm_input_state = tf.reshape(input_state, shape=[-1, self.sequence_length, s_size])
        self.memory_in = tf.placeholder(shape=[None, self.m_size], dtype=tf.float32, name='recurrent_in')
        _half_point = int(self.m_size / 2)
        rnn_cell = tf.contrib.rnn.BasicLSTMCell(_half_point)
        lstm_state_in = tf.contrib.rnn.LSTMStateTuple(self.memory_in[:, :_half_point], self.memory_in[:, _half_point:])
        self.recurrent_state, self.lstm_state_out = tf.nn.dynamic_rnn(rnn_cell, self.lstm_input_state,
                                                                      initial_state=lstm_state_in,
                                                                      time_major=False,
                                                                      dtype=tf.float32)
        self.memory_out = tf.concat([self.lstm_state_out.c, self.lstm_state_out.h], axis=1)
        self.memory_out = tf.identity(self.memory_out, name='recurrent_out')
        recurrent_state = tf.reshape(self.recurrent_state, shape=[-1, _half_point])
        return recurrent_state

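    # Note on the memory layout used by create_recurrent_encoder above (added for
    # clarity, not part of the original commit): memory_in / memory_out hold the
    # LSTM cell state c and hidden state h concatenated along the last axis, each
    # of size m_size // 2. The environment simply carries this flat vector between
    # steps; the encoder splits it into an LSTMStateTuple(c, h) on the way in and
    # re-concatenates [c, h] (exported as the 'recurrent_out' node) on the way out.
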
|||
def create_visual_encoder(self, o_size_h, o_size_w, bw, h_size, num_streams, activation, num_layers): |
|||
""" |
|||
Builds a set of visual (CNN) encoders. |
|||
:param o_size_h: Height observation size. |
|||
:param o_size_w: Width observation size. |
|||
:param bw: Whether image is greyscale {True} or color {False}. |
|||
:param h_size: Hidden layer size. |
|||
:param num_streams: Number of visual streams to construct. |
|||
:param activation: What type of activation function to use for layers. |
|||
:return: List of hidden layer tensors. |
|||
""" |
|||
if bw: |
|||
c_channels = 1 |
|||
else: |
|||
c_channels = 3 |
|||
|
|||
self.observation_in.append(tf.placeholder(shape=[None, o_size_h, o_size_w, c_channels], dtype=tf.float32, |
|||
name='observation_%d' % len(self.observation_in))) |
|||
|
|||
streams = [] |
|||
for i in range(num_streams): |
|||
self.conv1 = tf.layers.conv2d(self.observation_in[-1], 16, kernel_size=[8, 8], strides=[4, 4], |
|||
use_bias=False, activation=activation) |
|||
self.conv2 = tf.layers.conv2d(self.conv1, 32, kernel_size=[4, 4], strides=[2, 2], |
|||
use_bias=False, activation=activation) |
|||
|
|||
if self.use_recurrent: |
|||
_rec_input = c_layers.flatten(self.conv2) |
|||
hidden = self.create_recurrent_encoder(_rec_input.get_shape().as_list()[1], _rec_input) |
|||
else: |
|||
hidden = c_layers.flatten(self.conv2) |
|||
|
|||
for j in range(num_layers): |
|||
hidden = tf.layers.dense(hidden, h_size, use_bias=False, activation=activation) |
|||
streams.append(hidden) |
|||
return streams |
|||
|
|||
def create_continuous_state_encoder(self, s_size, h_size, num_streams, activation, num_layers): |
|||
""" |
|||
Builds a set of hidden state encoders. |
|||
:param s_size: state input size. |
|||
:param h_size: Hidden layer size. |
|||
:param num_streams: Number of state streams to construct. |
|||
:param activation: What type of activation function to use for layers. |
|||
:return: List of hidden layer tensors. |
|||
""" |
|||
self.state_in = tf.placeholder(shape=[None, s_size], dtype=tf.float32, name='state') |
|||
|
|||
if self.normalize: |
|||
self.running_mean = tf.get_variable("running_mean", [s_size], trainable=False, dtype=tf.float32, |
|||
initializer=tf.zeros_initializer()) |
|||
self.running_variance = tf.get_variable("running_variance", [s_size], trainable=False, dtype=tf.float32, |
|||
initializer=tf.ones_initializer()) |
|||
|
|||
self.normalized_state = tf.clip_by_value((self.state_in - self.running_mean) / tf.sqrt( |
|||
self.running_variance / (tf.cast(self.global_step, tf.float32) + 1)), -5, 5, name="normalized_state") |
|||
|
|||
self.new_mean = tf.placeholder(shape=[s_size], dtype=tf.float32, name='new_mean') |
|||
self.new_variance = tf.placeholder(shape=[s_size], dtype=tf.float32, name='new_variance') |
|||
self.update_mean = tf.assign(self.running_mean, self.new_mean) |
|||
self.update_variance = tf.assign(self.running_variance, self.new_variance) |
|||
else: |
|||
self.normalized_state = self.state_in |
|||
|
|||
if self.use_recurrent: |
|||
self.recurrent_state = self.create_recurrent_encoder(s_size, self.normalized_state) |
|||
else: |
|||
self.recurrent_state = self.normalized_state |
|||
|
|||
streams = [] |
|||
for i in range(num_streams): |
|||
hidden = self.recurrent_state |
|||
for j in range(num_layers): |
|||
hidden = tf.layers.dense(hidden, h_size, use_bias=False, activation=activation) |
|||
streams.append(hidden) |
|||
return streams |
|||
|
|||
def create_discrete_state_encoder(self, s_size, h_size, num_streams, activation, num_layers): |
|||
""" |
|||
Builds a set of hidden state encoders from discrete state input. |
|||
:param s_size: state input size (discrete). |
|||
:param h_size: Hidden layer size. |
|||
:param num_streams: Number of state streams to construct. |
|||
:param activation: What type of activation function to use for layers. |
|||
:return: List of hidden layer tensors. |
|||
""" |
|||
self.state_in = tf.placeholder(shape=[None, 1], dtype=tf.int32, name='state') |
|||
state_in = tf.reshape(self.state_in, [-1]) |
|||
state_onehot = c_layers.one_hot_encoding(state_in, s_size) |
|||
streams = [] |
|||
if self.use_recurrent: |
|||
hidden = self.create_recurrent_encoder(s_size, state_onehot) |
|||
else: |
|||
hidden = state_onehot |
|||
for i in range(num_streams): |
|||
for j in range(num_layers): |
|||
hidden = tf.layers.dense(hidden, h_size, use_bias=False, activation=activation) |
|||
streams.append(hidden) |
|||
return streams |
|||
|
|||
def create_ppo_optimizer(self, probs, old_probs, value, entropy, beta, epsilon, lr, max_step): |
|||
""" |
|||
Creates training-specific Tensorflow ops for PPO models. |
|||
:param probs: Current policy probabilities |
|||
:param old_probs: Past policy probabilities |
|||
:param value: Current value estimate |
|||
:param beta: Entropy regularization strength |
|||
:param entropy: Current policy entropy |
|||
:param epsilon: Value for policy-divergence threshold |
|||
:param lr: Learning rate |
|||
:param max_step: Total number of training steps. |
|||
""" |
|||
|
|||
self.returns_holder = tf.placeholder(shape=[None], dtype=tf.float32, name='discounted_rewards') |
|||
self.advantage = tf.placeholder(shape=[None, 1], dtype=tf.float32, name='advantages') |
|||
|
|||
decay_epsilon = tf.train.polynomial_decay(epsilon, self.global_step, |
|||
max_step, 0.1, |
|||
power=1.0) |
|||
|
|||
r_theta = probs / (old_probs + 1e-10) |
|||
p_opt_a = r_theta * self.advantage |
|||
p_opt_b = tf.clip_by_value(r_theta, 1 - decay_epsilon, 1 + decay_epsilon) * self.advantage |
|||
self.policy_loss = -tf.reduce_mean(tf.minimum(p_opt_a, p_opt_b)) |
|||
|
|||
self.value_loss = tf.reduce_mean(tf.squared_difference(self.returns_holder, |
|||
tf.reduce_sum(value, axis=1))) |
|||
|
|||
decay_beta = tf.train.polynomial_decay(beta, self.global_step, |
|||
max_step, 1e-5, |
|||
power=1.0) |
|||
self.loss = self.policy_loss + 0.5 * self.value_loss - decay_beta * tf.reduce_mean(entropy) |
|||
|
|||
self.learning_rate = tf.train.polynomial_decay(lr, self.global_step, |
|||
max_step, 1e-10, |
|||
power=1.0) |
|||
optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate) |
|||
self.update_batch = optimizer.minimize(self.loss) |
|||
|
|||
|
|||
class ContinuousControlModel(PPOModel): |
|||
def __init__(self, lr, brain, h_size, epsilon, max_step, normalize, use_recurrent, num_layers,m_size): |
|||
""" |
|||
Creates Continuous Control Actor-Critic model. |
|||
:param brain: State-space size |
|||
:param h_size: Hidden layer size |
|||
""" |
|||
self.m_size = m_size |
|||
super(ContinuousControlModel, self).__init__() |
|||
s_size = brain.state_space_size |
|||
a_size = brain.action_space_size |
|||
|
|||
self.batch_size = tf.placeholder(shape=None, dtype=tf.int32, name='batch_size') |
|||
|
|||
self.sequence_length = tf.placeholder(shape=None, dtype=tf.int32, name='sequence_length') |
|||
|
|||
self.normalize = normalize |
|||
self.use_recurrent = use_recurrent |
|||
self.create_global_steps() |
|||
self.create_reward_encoder() |
|||
|
|||
hidden_state, hidden_visual, hidden_policy, hidden_value = None, None, None, None |
|||
if brain.number_observations > 0: |
|||
visual_encoder_0 = [] |
|||
visual_encoder_1 = [] |
|||
for i in range(brain.number_observations): |
|||
height_size, width_size = brain.camera_resolutions[i]['height'], brain.camera_resolutions[i]['width'] |
|||
bw = brain.camera_resolutions[i]['blackAndWhite'] |
|||
encoded_visual = self.create_visual_encoder(height_size, width_size, bw, h_size, 2, tf.nn.tanh, num_layers) |
|||
visual_encoder_0.append(encoded_visual[0]) |
|||
visual_encoder_1.append(encoded_visual[1]) |
|||
hidden_visual = [tf.concat(visual_encoder_0, axis=1), tf.concat(visual_encoder_1, axis=1)] |
|||
if brain.state_space_size > 0: |
|||
s_size = brain.state_space_size |
|||
if brain.state_space_type == "continuous": |
|||
hidden_state = self.create_continuous_state_encoder(s_size, h_size, 2, tf.nn.tanh, num_layers) |
|||
else: |
|||
hidden_state = self.create_discrete_state_encoder(s_size, h_size, 2, tf.nn.tanh, num_layers) |
|||
|
|||
if hidden_visual is None and hidden_state is None: |
|||
raise Exception("No valid network configuration possible. " |
|||
"There are no states or observations in this brain") |
|||
elif hidden_visual is not None and hidden_state is None: |
|||
hidden_policy, hidden_value = hidden_visual |
|||
elif hidden_visual is None and hidden_state is not None: |
|||
hidden_policy, hidden_value = hidden_state |
|||
elif hidden_visual is not None and hidden_state is not None: |
|||
hidden_policy = tf.concat([hidden_visual[0], hidden_state[0]], axis=1) |
|||
hidden_value = tf.concat([hidden_visual[1], hidden_state[1]], axis=1) |
|||
|
|||
|
|||
self.mu = tf.layers.dense(hidden_policy, a_size, activation=None, use_bias=False, |
|||
kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01)) |
|||
self.log_sigma_sq = tf.get_variable("log_sigma_squared", [a_size], dtype=tf.float32, |
|||
initializer=tf.zeros_initializer()) |
|||
self.sigma_sq = tf.exp(self.log_sigma_sq) |
|||
|
|||
self.epsilon = tf.placeholder(shape=[None, a_size], dtype=tf.float32, name='epsilon') |
|||
|
|||
self.output = self.mu + tf.sqrt(self.sigma_sq) * self.epsilon |
|||
self.output = tf.identity(self.output, name='action') |
|||
|
|||
a = tf.exp(-1 * tf.pow(tf.stop_gradient(self.output) - self.mu, 2) / (2 * self.sigma_sq)) |
|||
b = 1 / tf.sqrt(2 * self.sigma_sq * np.pi) |
|||
self.probs = tf.multiply(a, b, name="action_probs") |
|||
|
|||
self.entropy = tf.reduce_sum(0.5 * tf.log(2 * np.pi * np.e * self.sigma_sq)) |
|||
|
|||
self.value = tf.layers.dense(hidden_value, 1, activation=None, use_bias=False) |
|||
self.value = tf.identity(self.value, name="value_estimate") |
|||
|
|||
self.old_probs = tf.placeholder(shape=[None, a_size], dtype=tf.float32, name='old_probabilities') |
|||
|
|||
self.create_ppo_optimizer(self.probs, self.old_probs, self.value, self.entropy, 0.0, epsilon, lr, max_step) |
|||
|
|||
|
|||
class DiscreteControlModel(PPOModel): |
|||
def __init__(self, lr, brain, h_size, epsilon, beta, max_step, normalize,use_recurrent, num_layers,m_size): |
|||
""" |
|||
Creates Discrete Control Actor-Critic model. |
|||
:param brain: State-space size |
|||
:param h_size: Hidden layer size |
|||
""" |
|||
self.m_size = m_size |
|||
super(DiscreteControlModel, self).__init__() |
|||
self.create_global_steps() |
|||
self.create_reward_encoder() |
|||
self.normalize = normalize |
|||
self.use_recurrent = use_recurrent |
|||
|
|||
self.batch_size = tf.placeholder(shape=None, dtype=tf.int32, name='batch_size') |
|||
|
|||
self.sequence_length = tf.placeholder(shape=None, dtype=tf.int32, name='sequence_length') |
|||
|
|||
hidden_state, hidden_visual, hidden = None, None, None |
|||
if brain.number_observations > 0: |
|||
visual_encoders = [] |
|||
for i in range(brain.number_observations): |
|||
height_size, width_size = brain.camera_resolutions[i]['height'], brain.camera_resolutions[i]['width'] |
|||
bw = brain.camera_resolutions[i]['blackAndWhite'] |
|||
visual_encoders.append(self.create_visual_encoder(height_size, width_size, bw, h_size, 2, tf.nn.tanh, num_layers)[0]) |
|||
hidden_visual = [tf.concat(visual_encoders, axis=1)] |
|||
if brain.state_space_size > 0: |
|||
s_size = brain.state_space_size |
|||
if brain.state_space_type == "continuous": |
|||
hidden_state = self.create_continuous_state_encoder(s_size, h_size, 1, tf.nn.elu, num_layers)[0] |
|||
else: |
|||
hidden_state = self.create_discrete_state_encoder(s_size, h_size, 1, tf.nn.elu, num_layers)[0] |
|||
|
|||
if hidden_visual is None and hidden_state is None: |
|||
raise Exception("No valid network configuration possible. " |
|||
"There are no states or observations in this brain") |
|||
elif hidden_visual is not None and hidden_state is None: |
|||
hidden = hidden_visual |
|||
elif hidden_visual is None and hidden_state is not None: |
|||
hidden = hidden_state |
|||
elif hidden_visual is not None and hidden_state is not None: |
|||
hidden = tf.concat([hidden_visual[0], hidden_state], axis=1) |
|||
|
|||
a_size = brain.action_space_size |
|||
|
|||
self.policy = tf.layers.dense(hidden, a_size, activation=None, use_bias=False, |
|||
kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01)) |
|||
self.probs = tf.nn.softmax(self.policy, name="action_probs") |
|||
self.output = tf.multinomial(self.policy, 1) |
|||
self.output = tf.identity(self.output, name="action") |
|||
self.value = tf.layers.dense(hidden, 1, activation=None, use_bias=False, |
|||
kernel_initializer=c_layers.variance_scaling_initializer(factor=1.0)) |
|||
self.value = tf.identity(self.value, name="value_estimate") |
|||
|
|||
self.entropy = -tf.reduce_sum(self.probs * tf.log(self.probs + 1e-10), axis=1) |
|||
|
|||
self.action_holder = tf.placeholder(shape=[None], dtype=tf.int32) |
|||
self.selected_actions = c_layers.one_hot_encoding(self.action_holder, a_size) |
|||
self.old_probs = tf.placeholder(shape=[None, a_size], dtype=tf.float32, name='old_probabilities') |
|||
self.responsible_probs = tf.reduce_sum(self.probs * self.selected_actions, axis=1) |
|||
self.old_responsible_probs = tf.reduce_sum(self.old_probs * self.selected_actions, axis=1) |
|||
|
|||
self.create_ppo_optimizer(self.responsible_probs, self.old_responsible_probs, |
|||
self.value, self.entropy, beta, epsilon, lr, max_step) |
|
|||
# # Unity ML Agents |
|||
# ## ML-Agent Learning (PPO) |
|||
# Contains an implementation of PPO as described [here](https://arxiv.org/abs/1707.06347). |
|||
|
|||
import logging |
|||
import os |
|||
|
|||
import numpy as np |
|||
import tensorflow as tf |
|||
|
|||
from trainers.buffer import Buffer |
|||
from trainers.ppo_models import * |
|||
from trainers.trainer import UnityTrainerException, Trainer |
|||
|
|||
logger = logging.getLogger("unityagents") |
|||
|
|||
|
|||
|
|||
class PPOTrainer(Trainer): |
|||
"""The PPOTrainer is an implementation of the PPO algorythm.""" |
|||
def __init__(self, sess, env, brain_name, trainer_parameters, training): |
|||
""" |
|||
Responsible for collecting experiences and training the PPO model. |
|||
:param sess: Tensorflow session. |
|||
:param env: The UnityEnvironment. |
|||
:param trainer_parameters: The parameters for the trainer (dictionary). |
|||
:param training: Whether the trainer is set for training. |
|||
""" |
|||
self.param_keys = ['batch_size', 'beta', 'buffer_size', 'epsilon', 'gamma', 'hidden_units', 'lambd', 'learning_rate', |
|||
'max_steps', 'normalize', 'num_epoch', 'num_layers', 'time_horizon', 'sequence_length', 'summary_freq', |
|||
'use_recurrent', 'graph_scope', 'summary_path'] |
|||
|
|||
for k in self.param_keys: |
|||
if k not in trainer_parameters: |
|||
raise UnityTrainerException("The hyperparameter {0} could not be found for the PPO trainer of " |
|||
"brain {1}.".format(k, brain_name)) |
|||
|
|||
super(PPOTrainer, self).__init__(sess, env, brain_name, trainer_parameters, training) |
|||
|
|||
self.use_recurrent = trainer_parameters["use_recurrent"] |
|||
self.sequence_length = 1 |
|||
self.m_size = None |
|||
if self.use_recurrent: |
|||
self.m_size = env.brains[brain_name].memory_space_size |
|||
self.sequence_length = trainer_parameters["sequence_length"] |
|||
self.variable_scope = trainer_parameters['graph_scope'] |
|||
with tf.variable_scope(self.variable_scope): |
|||
self.model = create_agent_model(env.brains[brain_name], |
|||
lr=float(trainer_parameters['learning_rate']), |
|||
h_size=int(trainer_parameters['hidden_units']), |
|||
epsilon=float(trainer_parameters['epsilon']), |
|||
beta=float(trainer_parameters['beta']), |
|||
max_step=float(trainer_parameters['max_steps']), |
|||
normalize=trainer_parameters['normalize'], |
|||
use_recurrent=trainer_parameters['use_recurrent'], |
|||
num_layers=int(trainer_parameters['num_layers']), |
|||
m_size=self.m_size) |
|||
|
|||
|
|||
stats = {'cumulative_reward': [], 'episode_length': [], 'value_estimate': [], |
|||
'entropy': [], 'value_loss': [], 'policy_loss': [], 'learning_rate': []} |
|||
self.stats = stats |
|||
|
|||
self.training_buffer = Buffer() |
|||
self.cumulative_rewards = {} |
|||
self.episode_steps = {} |
|||
self.is_continuous = (env.brains[brain_name].action_space_type == "continuous") |
|||
self.use_observations = (env.brains[brain_name].number_observations > 0) |
|||
self.use_states = (env.brains[brain_name].state_space_size > 0) |
|||
self.summary_path = trainer_parameters['summary_path'] |
|||
if not os.path.exists(self.summary_path): |
|||
os.makedirs(self.summary_path) |
|||
|
|||
self.summary_writer = tf.summary.FileWriter(self.summary_path) |
|||
|
|||
def __str__(self): |
|||
return '''Hyperparameters for the PPO Trainer of brain {0}: \n{1}'''.format( |
|||
self.brain_name, '\n'.join(['\t{0}:\t{1}'.format(x, self.trainer_parameters[x]) for x in self.param_keys])) |
|||
|
|||
@property |
|||
def parameters(self): |
|||
""" |
|||
Returns the trainer parameters of the trainer. |
|||
""" |
|||
return self.trainer_parameters |
|||
|
|||
@property |
|||
def graph_scope(self): |
|||
""" |
|||
Returns the graph scope of the trainer. |
|||
""" |
|||
return self.variable_scope |
|||
|
|||
@property |
|||
def get_max_steps(self): |
|||
""" |
|||
Returns the maximum number of steps. Is used to know when the trainer should be stopped. |
|||
:return: The maximum number of steps of the trainer |
|||
""" |
|||
return float(self.trainer_parameters['max_steps']) |
|||
|
|||
@property |
|||
def get_step(self): |
|||
""" |
|||
Returns the number of steps the trainer has performed |
|||
:return: the step count of the trainer |
|||
""" |
|||
return self.sess.run(self.model.global_step) |
|||
|
|||
@property |
|||
def get_last_reward(self): |
|||
""" |
|||
Returns the last reward the trainer has recorded. |
|||
:return: the last reward |
|||
""" |
|||
return self.sess.run(self.model.last_reward) |
|||
|
|||
def increment_step(self): |
|||
""" |
|||
Increment the step count of the trainer |
|||
""" |
|||
self.sess.run(self.model.increment_step) |
|||
|
|||
def update_last_reward(self): |
|||
""" |
|||
Updates the last reward |
|||
""" |
|||
if len(self.stats['cumulative_reward']) > 0: |
|||
mean_reward = np.mean(self.stats['cumulative_reward']) |
|||
self.sess.run(self.model.update_reward, feed_dict={self.model.new_reward: mean_reward}) |
|||
last_reward = self.sess.run(self.model.last_reward) |
|||
|
|||
def running_average(self, data, steps, running_mean, running_variance): |
|||
""" |
|||
Computes new running mean and variances. |
|||
:param data: New piece of data. |
|||
:param steps: Total number of data points seen so far. |
|||
:param running_mean: TF op corresponding to stored running mean. |
|||
:param running_variance: TF op corresponding to stored running variance. |
|||
:return: New mean and variance values. |
|||
""" |
|||
mean, var = self.sess.run([running_mean, running_variance]) |
|||
current_x = np.mean(data, axis=0) |
|||
new_mean = mean + (current_x - mean) / (steps + 1) |
|||
new_variance = var + (current_x - new_mean) * (current_x - mean) |
|||
return new_mean, new_variance |
|||
|
|||
def take_action(self, info): |
|||
""" |
|||
Decides actions given state/observation information, and takes them in environment. |
|||
:param info: Current BrainInfo from environment. |
|||
:return: a tuple containing actions, memories, values and an object |
|||
to be passed to add_experiences |
|||
""" |
|||
steps = self.get_step |
|||
info = info[self.brain_name] |
|||
epsi = None |
|||
feed_dict = {self.model.batch_size: len(info.states), self.model.sequence_length: 1} |
|||
run_list = [self.model.output, self.model.probs, self.model.value, self.model.entropy, |
|||
self.model.learning_rate] |
|||
if self.is_continuous: |
|||
epsi = np.random.randn(len(info.states), self.brain.action_space_size) |
|||
feed_dict[self.model.epsilon] = epsi |
|||
if self.use_observations: |
|||
for i, _ in enumerate(info.observations): |
|||
feed_dict[self.model.observation_in[i]] = info.observations[i] |
|||
if self.use_states: |
|||
feed_dict[self.model.state_in] = info.states |
|||
if self.use_recurrent: |
|||
feed_dict[self.model.memory_in] = info.memories |
|||
run_list += [self.model.memory_out] |
|||
if (self.is_training and self.brain.state_space_type == "continuous" and |
|||
self.use_states and self.trainer_parameters['normalize']): |
|||
new_mean, new_variance = self.running_average(info.states, steps, self.model.running_mean, |
|||
self.model.running_variance) |
|||
feed_dict[self.model.new_mean] = new_mean |
|||
feed_dict[self.model.new_variance] = new_variance |
|||
run_list = run_list + [self.model.update_mean, self.model.update_variance] |
|||
# Only ask for memories if use_recurrent. |
|||
if self.use_recurrent: |
|||
actions, a_dist, value, ent, learn_rate, memories, _, _ = self.sess.run(run_list, feed_dict=feed_dict) |
|||
else: |
|||
actions, a_dist, value, ent, learn_rate, _, _ = self.sess.run(run_list, feed_dict=feed_dict) |
|||
memories = None |
|||
else: |
|||
if self.use_recurrent: |
|||
actions, a_dist, value, ent, learn_rate, memories = self.sess.run(run_list, feed_dict=feed_dict) |
|||
else: |
|||
actions, a_dist, value, ent, learn_rate = self.sess.run(run_list, feed_dict=feed_dict) |
|||
memories = None |
|||
self.stats['value_estimate'].append(value) |
|||
self.stats['entropy'].append(ent) |
|||
self.stats['learning_rate'].append(learn_rate) |
|||
return (actions, memories, value, (actions, epsi, a_dist, value)) |
|||
|
|||
def add_experiences(self, info, next_info, take_action_outputs): |
|||
""" |
|||
Adds experiences to each agent's experience history. |
|||
:param info: Current BrainInfo. |
|||
:param next_info: Next BrainInfo. |
|||
:param take_action_outputs: The outputs of the take action method. |
|||
""" |
|||
info = info[self.brain_name] |
|||
next_info = next_info[self.brain_name] |
|||
actions, epsi, a_dist, value = take_action_outputs |
|||
for agent_id in info.agents: |
|||
if agent_id in next_info.agents: |
|||
idx = info.agents.index(agent_id) |
|||
next_idx = next_info.agents.index(agent_id) |
|||
if not info.local_done[idx]: |
|||
if self.use_observations: |
|||
for i, _ in enumerate(info.observations): |
|||
self.training_buffer[agent_id]['observations%d'%i].append(info.observations[i][idx]) |
|||
if self.use_states: |
|||
self.training_buffer[agent_id]['states'].append(info.states[idx]) |
|||
if self.use_recurrent: |
|||
self.training_buffer[agent_id]['memory'].append(info.memories[idx]) |
|||
if self.is_continuous: |
|||
self.training_buffer[agent_id]['epsilons'].append(epsi[idx]) |
|||
self.training_buffer[agent_id]['actions'].append(actions[idx]) |
|||
self.training_buffer[agent_id]['rewards'].append(next_info.rewards[next_idx]) |
|||
self.training_buffer[agent_id]['action_probs'].append(a_dist[idx]) |
|||
self.training_buffer[agent_id]['value_estimates'].append(value[idx][0]) |
|||
if agent_id not in self.cumulative_rewards: |
|||
self.cumulative_rewards[agent_id] = 0 |
|||
self.cumulative_rewards[agent_id] += next_info.rewards[next_idx] |
|||
if agent_id not in self.episode_steps: |
|||
self.episode_steps[agent_id] = 0 |
|||
self.episode_steps[agent_id] += 1 |
|||
|
|||
def process_experiences(self, info): |
|||
""" |
|||
Checks agent histories for processing condition, and processes them as necessary. |
|||
Processing involves calculating value and advantage targets for model updating step. |
|||
:param info: Current BrainInfo |
|||
""" |
|||
|
|||
info = info[self.brain_name] |
|||
for l in range(len(info.agents)): |
|||
if ((info.local_done[l] or |
|||
len(self.training_buffer[info.agents[l]]['actions']) > self.trainer_parameters['time_horizon']) |
|||
and len(self.training_buffer[info.agents[l]]['actions']) > 0): |
|||
|
|||
if info.local_done[l]: |
|||
value_next = 0.0 |
|||
else: |
|||
feed_dict = {self.model.batch_size: len(info.states), self.model.sequence_length: 1} |
|||
if self.use_observations: |
|||
for i, _ in enumerate(info.observations): |
|||
feed_dict[self.model.observation_in[i]] = info.observations[i] |
|||
if self.use_states: |
|||
feed_dict[self.model.state_in] = info.states |
|||
if self.use_recurrent: |
|||
feed_dict[self.model.memory_in] = info.memories |
|||
value_next = self.sess.run(self.model.value, feed_dict)[l] |
|||
agent_id = info.agents[l] |
|||
self.training_buffer[agent_id]['advantages'].set( |
|||
get_gae( |
|||
rewards=self.training_buffer[agent_id]['rewards'].get_batch(), |
|||
value_estimates=self.training_buffer[agent_id]['value_estimates'].get_batch(), |
|||
value_next=value_next, |
|||
gamma=self.trainer_parameters['gamma'], |
|||
lambd=self.trainer_parameters['lambd']) |
|||
) |
|||
self.training_buffer[agent_id]['discounted_returns'].set( \ |
|||
self.training_buffer[agent_id]['advantages'].get_batch() \ |
|||
+ self.training_buffer[agent_id]['value_estimates'].get_batch()) |
|||
|
|||
self.training_buffer.append_update_buffer(agent_id, |
|||
batch_size=None, training_length=self.sequence_length) |
|||
|
|||
self.training_buffer[agent_id].reset_agent() |
|||
if info.local_done[l]: |
|||
self.stats['cumulative_reward'].append(self.cumulative_rewards[agent_id]) |
|||
self.stats['episode_length'].append(self.episode_steps[agent_id]) |
|||
self.cumulative_rewards[agent_id] = 0 |
|||
self.episode_steps[agent_id] = 0 |
|||
|
|||
|
|||
def end_episode(self): |
|||
""" |
|||
A signal that the episode has ended. The buffer must be reset. |
|||
Gets called only when the academy resets. |
|||
""" |
|||
self.training_buffer.reset_all() |
|||
for agent_id in self.cumulative_rewards: |
|||
self.cumulative_rewards[agent_id] = 0 |
|||
for agent_id in self.episode_steps: |
|||
self.episode_steps[agent_id] = 0 |
|||
|
|||
def is_ready_update(self): |
|||
""" |
|||
Returns whether or not the trainer has enough elements to run update model |
|||
:return: A boolean corresponding to whether or not update_model() can be run |
|||
""" |
|||
return len(self.training_buffer.update_buffer['actions']) > self.trainer_parameters['buffer_size'] |
|||
|
|||
def update_model(self): |
|||
""" |
|||
Uses training_buffer to update model. |
|||
""" |
|||
num_epoch = self.trainer_parameters['num_epoch'] |
|||
batch_size = self.trainer_parameters['batch_size'] |
|||
total_v, total_p = 0, 0 |
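|||
# Standardize advantages to zero mean and unit variance over the whole update buffer |
|||
# before the epochs of gradient descent (a common variance-reduction step for PPO). |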
|||
advantages = self.training_buffer.update_buffer['advantages'].get_batch() |
|||
self.training_buffer.update_buffer['advantages'].set( |
|||
(advantages - advantages.mean()) / advantages.std()) |
|||
for k in range(num_epoch): |
|||
self.training_buffer.update_buffer.shuffle() |
|||
for l in range(len(self.training_buffer.update_buffer['actions']) // batch_size): |
|||
start = l * batch_size |
|||
end = (l + 1) * batch_size |
|||
|
|||
_buffer = self.training_buffer.update_buffer |
|||
feed_dict = {self.model.batch_size:batch_size, |
|||
self.model.sequence_length: self.sequence_length, |
|||
self.model.returns_holder: np.array(_buffer['discounted_returns'][start:end]).reshape([-1]), |
|||
self.model.advantage: np.array(_buffer['advantages'][start:end]).reshape([-1,1]), |
|||
self.model.old_probs: np.array( |
|||
_buffer['action_probs'][start:end]).reshape([-1,self.brain.action_space_size])} |
|||
if self.is_continuous: |
|||
feed_dict[self.model.epsilon] = np.array( |
|||
_buffer['epsilons'][start:end]).reshape([-1,self.brain.action_space_size]) |
|||
else: |
|||
feed_dict[self.model.action_holder] = np.array( |
|||
_buffer['actions'][start:end]).reshape([-1]) |
|||
if self.use_states: |
|||
if self.brain.state_space_type == "continuous": |
|||
feed_dict[self.model.state_in] = np.array( |
|||
_buffer['states'][start:end]).reshape([-1,self.brain.state_space_size]) |
|||
else: |
|||
feed_dict[self.model.state_in] = np.array( |
|||
_buffer['states'][start:end]).reshape([-1,1]) |
|||
if self.use_observations: |
|||
for i, _ in enumerate(self.model.observation_in): |
|||
_obs = np.array(_buffer['observations%d'%i][start:end]) |
|||
(_batch, _seq, _w, _h, _c) = _obs.shape |
|||
feed_dict[self.model.observation_in[i]] = _obs.reshape([-1,_w,_h,_c]) |
|||
# Memories are fed as zeros during the update pass. |
|||
if self.use_recurrent: |
|||
feed_dict[self.model.memory_in] = np.zeros([batch_size, self.m_size]) |
|||
v_loss, p_loss, _ = self.sess.run([self.model.value_loss, self.model.policy_loss, |
|||
self.model.update_batch], feed_dict=feed_dict) |
|||
total_v += v_loss |
|||
total_p += p_loss |
|||
self.stats['value_loss'].append(total_v) |
|||
self.stats['policy_loss'].append(total_p) |
|||
self.training_buffer.reset_update_buffer() |
|||
|
|||
def write_summary(self, lesson_number): |
|||
""" |
|||
Saves training statistics to Tensorboard. |
|||
:param lesson_number: The lesson the trainer is at. |
|||
""" |
|||
if (self.get_step % self.trainer_parameters['summary_freq'] == 0 and self.get_step != 0 and |
|||
self.is_training and self.get_step <= self.get_max_steps): |
|||
steps = self.get_step |
|||
if len(self.stats['cumulative_reward']) > 0: |
|||
mean_reward = np.mean(self.stats['cumulative_reward']) |
|||
logger.info(" {0}: Step: {1}. Mean Reward: {2}. Std of Reward: {3}." |
|||
.format(self.brain_name, steps, mean_reward, np.std(self.stats['cumulative_reward']))) |
|||
summary = tf.Summary() |
|||
for key in self.stats: |
|||
if len(self.stats[key]) > 0: |
|||
stat_mean = float(np.mean(self.stats[key])) |
|||
summary.value.add(tag='Info/{}'.format(key), simple_value=stat_mean) |
|||
self.stats[key] = [] |
|||
summary.value.add(tag='Info/Lesson', simple_value=lesson_number) |
|||
self.summary_writer.add_summary(summary, steps) |
|||
self.summary_writer.flush() |
|||
|
|||
|
|||
def discount_rewards(r, gamma=0.99, value_next=0.0): |
|||
""" |
|||
Computes discounted sum of future rewards for use in updating value estimate. |
|||
:param r: List of rewards. |
|||
:param gamma: Discount factor. |
|||
:param value_next: T+1 value estimate for returns calculation. |
|||
:return: discounted sum of future rewards as list. |
|||
""" |
|||
discounted_r = np.zeros_like(r) |
|||
running_add = value_next |
|||
for t in reversed(range(0, r.size)): |
|||
running_add = running_add * gamma + r[t] |
|||
discounted_r[t] = running_add |
|||
return discounted_r |
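|||
# Illustrative check (using the defaults above, gamma=0.99, value_next=0.0): |
|||
# discount_rewards(np.array([1., 1., 1.])) -> [2.9701, 1.99, 1.] |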
|||
|
|||
|
|||
def get_gae(rewards, value_estimates, value_next=0.0, gamma=0.99, lambd=0.95): |
|||
""" |
|||
Computes generalized advantage estimate for use in updating policy. |
|||
:param rewards: list of rewards for time-steps t to T. |
|||
:param value_next: Value estimate for time-step T+1. |
|||
:param value_estimates: list of value estimates for time-steps t to T. |
|||
:param gamma: Discount factor. |
|||
:param lambd: GAE weighing factor. |
|||
:return: list of advantage estimates for time-steps t to T. |
|||
""" |
|||
value_estimates = np.asarray(value_estimates.tolist() + [value_next]) |
|||
delta_t = rewards + gamma * value_estimates[1:] - value_estimates[:-1] |
|||
advantage = discount_rewards(r=delta_t, gamma=gamma*lambd) |
|||
return advantage |
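|||
# Illustrative check: with the defaults (gamma=0.99, lambd=0.95, value_next=0.0), |
|||
# get_gae(rewards=np.array([1., 1.]), value_estimates=np.array([0.5, 0.5])) |
|||
# computes deltas [0.995, 0.5] and returns advantages [1.46525, 0.5], i.e. |
|||
# discount_rewards applied to the one-step TD errors with factor gamma * lambd. |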
|||
|
|
|||
# # Unity ML Agents |
|||
import logging |
|||
|
|||
import tensorflow as tf |
|||
|
|||
from unityagents import UnityException |
|||
|
|||
logger = logging.getLogger("unityagents") |
|||
|
|||
class UnityTrainerException(UnityException): |
|||
""" |
|||
Related to errors with the Trainer. |
|||
""" |
|||
pass |
|||
|
|||
|
|||
class Trainer(object): |
|||
"""This class is the abstract class for the trainers""" |
|||
def __init__(self, sess, env, brain_name, trainer_parameters, training): |
|||
""" |
|||
Responsible for collecting experiences and training a model. |
|||
:param sess: Tensorflow session. |
|||
:param env: The UnityEnvironment. |
|||
:param trainer_parameters: The parameters for the trainer (dictionary). |
|||
:param training: Whether the trainer is set for training. |
|||
""" |
|||
self.brain_name = brain_name |
|||
self.brain = env.brains[self.brain_name] |
|||
self.trainer_parameters = trainer_parameters |
|||
self.is_training = training |
|||
self.sess = sess |
|||
|
|||
|
|||
def __str__(self): |
|||
return '''Empty Trainer''' |
|||
|
|||
@property |
|||
def parameters(self): |
|||
""" |
|||
Returns the trainer parameters of the trainer. |
|||
""" |
|||
raise UnityTrainerException("The parameters property was not implemented.") |
|||
|
|||
@property |
|||
def graph_scope(self): |
|||
""" |
|||
Returns the graph scope of the trainer. |
|||
""" |
|||
raise UnityTrainerException("The graph_scope property was not implemented.") |
|||
|
|||
@property |
|||
def get_max_steps(self): |
|||
""" |
|||
Returns the maximum number of steps. Is used to know when the trainer should be stopped. |
|||
:return: The maximum number of steps of the trainer |
|||
""" |
|||
raise UnityTrainerException("The get_max_steps property was not implemented.") |
|||
|
|||
@property |
|||
def get_step(self): |
|||
""" |
|||
Returns the number of steps the trainer has performed |
|||
:return: the step count of the trainer |
|||
""" |
|||
raise UnityTrainerException("The get_step property was not implemented.") |
|||
|
|||
@property |
|||
def get_last_reward(self): |
|||
""" |
|||
Returns the last reward the trainer has recorded. |
|||
:return: the last reward |
|||
""" |
|||
raise UnityTrainerException("The get_last_reward property was not implemented.") |
|||
|
|||
def increment_step(self): |
|||
""" |
|||
Increment the step count of the trainer |
|||
""" |
|||
raise UnityTrainerException("The increment_step method was not implemented.") |
|||
|
|||
def update_last_reward(self): |
|||
""" |
|||
Updates the last reward |
|||
""" |
|||
raise UnityTrainerException("The update_last_reward method was not implemented.") |
|||
|
|||
def take_action(self, info): |
|||
""" |
|||
Decides actions given state/observation information, and takes them in environment. |
|||
:param info: Current BrainInfo from environment. |
|||
:return: a tuple containing actions, memories, values and an object |
|||
to be passed to add_experiences |
|||
""" |
|||
raise UnityTrainerException("The take_action method was not implemented.") |
|||
|
|||
def add_experiences(self, info, next_info, take_action_outputs): |
|||
""" |
|||
Adds experiences to each agent's experience history. |
|||
:param info: Current BrainInfo. |
|||
:param next_info: Next BrainInfo. |
|||
:param take_action_outputs: The outputs of the take action method. |
|||
""" |
|||
raise UnityTrainerException("The add_experiences method was not implemented.") |
|||
|
|||
def process_experiences(self, info): |
|||
""" |
|||
Checks agent histories for processing condition, and processes them as necessary. |
|||
Processing involves calculating value and advantage targets for model updating step. |
|||
:param info: Current BrainInfo |
|||
""" |
|||
raise UnityTrainerException("The process_experiences method was not implemented.") |
|||
|
|||
|
|||
def end_episode(self): |
|||
""" |
|||
A signal that the episode has ended. The buffer must be reset. |
|||
Gets called only when the academy resets. |
|||
""" |
|||
raise UnityTrainerException("The end_episode method was not implemented.") |
|||
|
|||
def is_ready_update(self): |
|||
""" |
|||
Returns whether or not the trainer has enough elements to run update model |
|||
:return: A boolean corresponding to whether or not update_model() can be run |
|||
""" |
|||
raise UnityTrainerException("The is_ready_update method was not implemented.") |
|||
|
|||
def update_model(self): |
|||
""" |
|||
Uses training_buffer to update model. |
|||
""" |
|||
raise UnityTrainerException("The update_model method was not implemented.") |
|||
|
|||
def write_summary(self, lesson_number): |
|||
""" |
|||
Saves training statistics to Tensorboard. |
|||
:param lesson_number: The lesson the trainer is at. |
|||
""" |
|||
raise UnityTrainerException("The write_summary method was not implemented.") |
|||
|
|||
def write_tensorboard_text(self, key, input_dict): |
|||
""" |
|||
Saves text to Tensorboard. |
|||
Note: Only works on tensorflow r1.2 or above. |
|||
:param key: The name of the text. |
|||
:param input_dict: A dictionary that will be displayed in a table on Tensorboard. |
|||
""" |
|||
try: |
|||
s_op = tf.summary.text(key, |
|||
tf.convert_to_tensor(([[str(x), str(input_dict[x])] for x in input_dict])) |
|||
) |
|||
s = self.sess.run(s_op) |
|||
self.summary_writer.add_summary(s, self.get_step) |
|||
except: |
|||
logger.info("Cannot write text summary for Tensorboard. Tensorflow version must be r1.2 or above.") |
|||
pass |
|||
|
|||
|
|||
|
|
|||
{ |
|||
"cells": [ |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": {}, |
|||
"source": [ |
|||
"# Unity ML Agents\n", |
|||
"## Proximal Policy Optimization (PPO)\n", |
|||
"Contains an implementation of PPO as described [here](https://arxiv.org/abs/1707.06347)." |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"execution_count": null, |
|||
"metadata": { |
|||
"collapsed": true |
|||
}, |
|||
"outputs": [], |
|||
"source": [ |
|||
"import numpy as np\n", |
|||
"import os\n", |
|||
"import tensorflow as tf\n", |
|||
"\n", |
|||
"from ppo.history import *\n", |
|||
"from ppo.models import *\n", |
|||
"from ppo.trainer import Trainer\n", |
|||
"from unityagents import *" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": {}, |
|||
"source": [ |
|||
"### Hyperparameters" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"execution_count": null, |
|||
"metadata": { |
|||
"collapsed": true |
|||
}, |
|||
"outputs": [], |
|||
"source": [ |
|||
"### General parameters\n", |
|||
"max_steps = 5e5 # Set maximum number of steps to run environment.\n", |
|||
"run_path = \"ppo\" # The sub-directory name for model and summary statistics\n", |
|||
"load_model = False # Whether to load a saved model.\n", |
|||
"train_model = True # Whether to train the model.\n", |
|||
"summary_freq = 10000 # Frequency at which to save training statistics.\n", |
|||
"save_freq = 50000 # Frequency at which to save model.\n", |
|||
"env_name = \"environment\" # Name of the training environment file.\n", |
|||
"curriculum_file = None\n", |
|||
"lesson = 0 # Start learning from this lesson\n", |
|||
"\n", |
|||
"### Algorithm-specific parameters for tuning\n", |
|||
"gamma = 0.99 # Reward discount rate.\n", |
|||
"lambd = 0.95 # Lambda parameter for GAE.\n", |
|||
"time_horizon = 2048 # How many steps to collect per agent before adding to buffer.\n", |
|||
"beta = 1e-3 # Strength of entropy regularization\n", |
|||
"num_epoch = 5 # Number of gradient descent steps per batch of experiences.\n", |
|||
"num_layers = 2 # Number of hidden layers between state/observation encoding and value/policy layers.\n", |
|||
"epsilon = 0.2 # Acceptable threshold around ratio of old and new policy probabilities.\n", |
|||
"buffer_size = 2048 # How large the experience buffer should be before gradient descent.\n", |
|||
"learning_rate = 3e-4 # Model learning rate.\n", |
|||
"hidden_units = 64 # Number of units in hidden layer.\n", |
|||
"batch_size = 64 # How many experiences per gradient descent update step.\n", |
|||
"normalize = False\n", |
|||
"\n", |
|||
"### Logging dictionary for hyperparameters\n", |
|||
"hyperparameter_dict = {'max_steps':max_steps, 'run_path':run_path, 'env_name':env_name,\n", |
|||
" 'curriculum_file':curriculum_file, 'gamma':gamma, 'lambd':lambd, 'time_horizon':time_horizon,\n", |
|||
" 'beta':beta, 'num_epoch':num_epoch, 'epsilon':epsilon, 'buffe_size':buffer_size,\n", |
|||
" 'leaning_rate':learning_rate, 'hidden_units':hidden_units, 'batch_size':batch_size}" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": {}, |
|||
"source": [ |
|||
"### Load the environment" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"execution_count": null, |
|||
"metadata": { |
|||
"collapsed": true |
|||
}, |
|||
"outputs": [], |
|||
"source": [ |
|||
"env = UnityEnvironment(file_name=env_name, curriculum=curriculum_file, lesson=lesson)\n", |
|||
"print(str(env))\n", |
|||
"brain_name = env.external_brain_names[0]" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": {}, |
|||
"source": [ |
|||
"### Train the Agent(s)" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"execution_count": null, |
|||
"metadata": { |
|||
"collapsed": true, |
|||
"scrolled": true |
|||
}, |
|||
"outputs": [], |
|||
"source": [ |
|||
"tf.reset_default_graph()\n", |
|||
"\n", |
|||
"if curriculum_file == \"None\":\n", |
|||
" curriculum_file = None\n", |
|||
"\n", |
|||
"\n", |
|||
"def get_progress():\n", |
|||
" if curriculum_file is not None:\n", |
|||
" if env._curriculum.measure_type == \"progress\":\n", |
|||
" return steps / max_steps\n", |
|||
" elif env._curriculum.measure_type == \"reward\":\n", |
|||
" return last_reward\n", |
|||
" else:\n", |
|||
" return None\n", |
|||
" else:\n", |
|||
" return None\n", |
|||
"\n", |
|||
"# Create the Tensorflow model graph\n", |
|||
"ppo_model = create_agent_model(env, lr=learning_rate,\n", |
|||
" h_size=hidden_units, epsilon=epsilon,\n", |
|||
" beta=beta, max_step=max_steps, \n", |
|||
" normalize=normalize, num_layers=num_layers)\n", |
|||
"\n", |
|||
"is_continuous = (env.brains[brain_name].action_space_type == \"continuous\")\n", |
|||
"use_observations = (env.brains[brain_name].number_observations > 0)\n", |
|||
"use_states = (env.brains[brain_name].state_space_size > 0)\n", |
|||
"\n", |
|||
"model_path = './models/{}'.format(run_path)\n", |
|||
"summary_path = './summaries/{}'.format(run_path)\n", |
|||
"\n", |
|||
"if not os.path.exists(model_path):\n", |
|||
" os.makedirs(model_path)\n", |
|||
"\n", |
|||
"if not os.path.exists(summary_path):\n", |
|||
" os.makedirs(summary_path)\n", |
|||
"\n", |
|||
"init = tf.global_variables_initializer()\n", |
|||
"saver = tf.train.Saver()\n", |
|||
"\n", |
|||
"with tf.Session() as sess:\n", |
|||
" # Instantiate model parameters\n", |
|||
" if load_model:\n", |
|||
" print('Loading Model...')\n", |
|||
" ckpt = tf.train.get_checkpoint_state(model_path)\n", |
|||
" saver.restore(sess, ckpt.model_checkpoint_path)\n", |
|||
" else:\n", |
|||
" sess.run(init)\n", |
|||
" steps, last_reward = sess.run([ppo_model.global_step, ppo_model.last_reward]) \n", |
|||
" summary_writer = tf.summary.FileWriter(summary_path)\n", |
|||
" info = env.reset(train_mode=train_model, progress=get_progress())[brain_name]\n", |
|||
" trainer = Trainer(ppo_model, sess, info, is_continuous, use_observations, use_states, train_model)\n", |
|||
" if train_model:\n", |
|||
" trainer.write_text(summary_writer, 'Hyperparameters', hyperparameter_dict, steps)\n", |
|||
" while steps <= max_steps:\n", |
|||
" if env.global_done:\n", |
|||
" info = env.reset(train_mode=train_model, progress=get_progress())[brain_name]\n", |
|||
" # Decide and take an action\n", |
|||
" new_info = trainer.take_action(info, env, brain_name, steps, normalize)\n", |
|||
" info = new_info\n", |
|||
" trainer.process_experiences(info, time_horizon, gamma, lambd)\n", |
|||
" if len(trainer.training_buffer['actions']) > buffer_size and train_model:\n", |
|||
" # Perform gradient descent with experience buffer\n", |
|||
" trainer.update_model(batch_size, num_epoch)\n", |
|||
" if steps % summary_freq == 0 and steps != 0 and train_model:\n", |
|||
" # Write training statistics to tensorboard.\n", |
|||
" trainer.write_summary(summary_writer, steps, env._curriculum.lesson_number)\n", |
|||
" if steps % save_freq == 0 and steps != 0 and train_model:\n", |
|||
" # Save Tensorflow model\n", |
|||
" save_model(sess, model_path=model_path, steps=steps, saver=saver)\n", |
|||
" steps += 1\n", |
|||
" sess.run(ppo_model.increment_step)\n", |
|||
" if len(trainer.stats['cumulative_reward']) > 0:\n", |
|||
" mean_reward = np.mean(trainer.stats['cumulative_reward'])\n", |
|||
" sess.run(ppo_model.update_reward, feed_dict={ppo_model.new_reward: mean_reward})\n", |
|||
" last_reward = sess.run(ppo_model.last_reward)\n", |
|||
" # Final save Tensorflow model\n", |
|||
" if steps != 0 and train_model:\n", |
|||
" save_model(sess, model_path=model_path, steps=steps, saver=saver)\n", |
|||
"env.close()\n", |
|||
"export_graph(model_path, env_name)" |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "markdown", |
|||
"metadata": {}, |
|||
"source": [ |
|||
"### Export the trained Tensorflow graph\n", |
|||
"Once the model has been trained and saved, we can export it as a .bytes file which Unity can embed." |
|||
] |
|||
}, |
|||
{ |
|||
"cell_type": "code", |
|||
"execution_count": null, |
|||
"metadata": { |
|||
"collapsed": true |
|||
}, |
|||
"outputs": [], |
|||
"source": [ |
|||
"export_graph(model_path, env_name)" |
|||
] |
|||
} |
|||
], |
|||
"metadata": { |
|||
"anaconda-cloud": {}, |
|||
"kernelspec": { |
|||
"display_name": "Python 3", |
|||
"language": "python", |
|||
"name": "python3" |
|||
}, |
|||
"language_info": { |
|||
"codemirror_mode": { |
|||
"name": "ipython", |
|||
"version": 3 |
|||
}, |
|||
"file_extension": ".py", |
|||
"mimetype": "text/x-python", |
|||
"name": "python", |
|||
"nbconvert_exporter": "python", |
|||
"pygments_lexer": "ipython3", |
|||
"version": "3.6.2" |
|||
} |
|||
}, |
|||
"nbformat": 4, |
|||
"nbformat_minor": 1 |
|||
} |
|
|||
# # Unity ML Agents |
|||
# ## Proximal Policy Optimization (PPO) |
|||
# Contains an implementation of PPO as described [here](https://arxiv.org/abs/1707.06347). |
|||
|
|||
from docopt import docopt |
|||
|
|||
import os |
|||
from ppo.models import * |
|||
from ppo.trainer import Trainer |
|||
from unityagents import UnityEnvironment |
|||
|
|||
_USAGE = ''' |
|||
Usage: |
|||
ppo (<env>) [options] |
|||
|
|||
Options: |
|||
--help Show this message. |
|||
--batch-size=<n> How many experiences per gradient descent update step [default: 64]. |
|||
--beta=<n> Strength of entropy regularization [default: 2.5e-3]. |
|||
--buffer-size=<n> How large the experience buffer should be before gradient descent [default: 2048]. |
|||
--curriculum=<file> Curriculum json file for environment [default: None]. |
|||
--epsilon=<n> Acceptable threshold around ratio of old and new policy probabilities [default: 0.2]. |
|||
--gamma=<n> Reward discount rate [default: 0.99]. |
|||
--hidden-units=<n> Number of units in hidden layer [default: 64]. |
|||
--keep-checkpoints=<n> How many model checkpoints to keep [default: 5]. |
|||
--lambd=<n> Lambda parameter for GAE [default: 0.95]. |
|||
--learning-rate=<rate> Model learning rate [default: 3e-4]. |
|||
--load Whether to load the model or randomly initialize [default: False]. |
|||
--max-steps=<n> Maximum number of steps to run environment [default: 1e6]. |
|||
--normalize Whether to normalize the state input using running statistics [default: False]. |
|||
--num-epoch=<n> Number of gradient descent steps per batch of experiences [default: 5]. |
|||
--num-layers=<n> Number of hidden layers between state/observation and outputs [default: 2]. |
|||
--run-path=<path> The sub-directory name for model and summary statistics [default: ppo]. |
|||
--save-freq=<n> Frequency at which to save model [default: 50000]. |
|||
--summary-freq=<n> Frequency at which to save training statistics [default: 10000]. |
|||
--time-horizon=<n> How many steps to collect per agent before adding to buffer [default: 2048]. |
|||
--train Whether to train model, or only run inference [default: False]. |
|||
--worker-id=<n> Number to add to communication port (5005). Used for multi-environment [default: 0]. |
|||
--lesson=<n> Start learning from this lesson [default: 0]. |
|||
''' |
|||
|
|||
options = docopt(_USAGE) |
|||
print(options) |
|||
|
|||
# General parameters |
|||
max_steps = float(options['--max-steps']) |
|||
model_path = './models/{}'.format(str(options['--run-path'])) |
|||
summary_path = './summaries/{}'.format(str(options['--run-path'])) |
|||
load_model = options['--load'] |
|||
train_model = options['--train'] |
|||
summary_freq = int(options['--summary-freq']) |
|||
save_freq = int(options['--save-freq']) |
|||
env_name = options['<env>'] |
|||
keep_checkpoints = int(options['--keep-checkpoints']) |
|||
worker_id = int(options['--worker-id']) |
|||
curriculum_file = str(options['--curriculum']) |
|||
if curriculum_file == "None": |
|||
curriculum_file = None |
|||
lesson = int(options['--lesson']) |
|||
|
|||
# Algorithm-specific parameters for tuning |
|||
gamma = float(options['--gamma']) |
|||
lambd = float(options['--lambd']) |
|||
time_horizon = int(options['--time-horizon']) |
|||
beta = float(options['--beta']) |
|||
num_epoch = int(options['--num-epoch']) |
|||
num_layers = int(options['--num-layers']) |
|||
epsilon = float(options['--epsilon']) |
|||
buffer_size = int(options['--buffer-size']) |
|||
learning_rate = float(options['--learning-rate']) |
|||
hidden_units = int(options['--hidden-units']) |
|||
batch_size = int(options['--batch-size']) |
|||
normalize = options['--normalize'] |
|||
|
|||
env = UnityEnvironment(file_name=env_name, worker_id=worker_id, curriculum=curriculum_file, lesson=lesson) |
|||
print(str(env)) |
|||
brain_name = env.external_brain_names[0] |
|||
|
|||
tf.reset_default_graph() |
|||
|
|||
# Create the Tensorflow model graph |
|||
ppo_model = create_agent_model(env, lr=learning_rate, |
|||
h_size=hidden_units, epsilon=epsilon, |
|||
beta=beta, max_step=max_steps, |
|||
normalize=normalize, num_layers=num_layers) |
|||
|
|||
is_continuous = (env.brains[brain_name].action_space_type == "continuous") |
|||
use_observations = (env.brains[brain_name].number_observations > 0) |
|||
use_states = (env.brains[brain_name].state_space_size > 0) |
|||
|
|||
if not os.path.exists(model_path): |
|||
os.makedirs(model_path) |
|||
|
|||
if not os.path.exists(summary_path): |
|||
os.makedirs(summary_path) |
|||
|
|||
init = tf.global_variables_initializer() |
|||
saver = tf.train.Saver(max_to_keep=keep_checkpoints) |
|||
|
|||
|
|||
def get_progress(): |
|||
if curriculum_file is not None: |
|||
if env._curriculum.measure_type == "progress": |
|||
return steps / max_steps |
|||
elif env._curriculum.measure_type == "reward": |
|||
return last_reward |
|||
else: |
|||
return None |
|||
else: |
|||
return None |
|||
|
|||
with tf.Session() as sess: |
|||
# Instantiate model parameters |
|||
if load_model: |
|||
print('Loading Model...') |
|||
ckpt = tf.train.get_checkpoint_state(model_path) |
|||
if ckpt is None: |
|||
print('The model {0} could not be found. Make sure you specified the right ' |
|||
'--run-path'.format(model_path)) |
|||
saver.restore(sess, ckpt.model_checkpoint_path) |
|||
else: |
|||
sess.run(init) |
|||
steps, last_reward = sess.run([ppo_model.global_step, ppo_model.last_reward]) |
|||
summary_writer = tf.summary.FileWriter(summary_path) |
|||
info = env.reset(train_mode=train_model, progress=get_progress())[brain_name] |
|||
trainer = Trainer(ppo_model, sess, info, is_continuous, use_observations, use_states, train_model) |
|||
if train_model: |
|||
trainer.write_text(summary_writer, 'Hyperparameters', options, steps) |
|||
while steps <= max_steps or not train_model: |
|||
if env.global_done: |
|||
info = env.reset(train_mode=train_model, progress=get_progress())[brain_name] |
|||
trainer.reset_buffers(info, total=True) |
|||
# Decide and take an action |
|||
new_info = trainer.take_action(info, env, brain_name, steps, normalize) |
|||
info = new_info |
|||
trainer.process_experiences(info, time_horizon, gamma, lambd) |
|||
if len(trainer.training_buffer['actions']) > buffer_size and train_model: |
|||
# Perform gradient descent with experience buffer |
|||
trainer.update_model(batch_size, num_epoch) |
|||
if steps % summary_freq == 0 and steps != 0 and train_model: |
|||
# Write training statistics to tensorboard. |
|||
trainer.write_summary(summary_writer, steps, env._curriculum.lesson_number) |
|||
if steps % save_freq == 0 and steps != 0 and train_model: |
|||
# Save Tensorflow model |
|||
save_model(sess, model_path=model_path, steps=steps, saver=saver) |
|||
if train_model: |
|||
steps += 1 |
|||
sess.run(ppo_model.increment_step) |
|||
if len(trainer.stats['cumulative_reward']) > 0: |
|||
mean_reward = np.mean(trainer.stats['cumulative_reward']) |
|||
sess.run(ppo_model.update_reward, feed_dict={ppo_model.new_reward: mean_reward}) |
|||
last_reward = sess.run(ppo_model.last_reward) |
|||
# Final save Tensorflow model |
|||
if steps != 0 and train_model: |
|||
save_model(sess, model_path=model_path, steps=steps, saver=saver) |
|||
env.close() |
|||
graph_name = (env_name.strip() |
|||
.replace('.app', '').replace('.exe', '').replace('.x86_64', '').replace('.x86', '')) |
|||
graph_name = os.path.basename(os.path.normpath(graph_name)) |
|||
export_graph(model_path, graph_name) |