
Multi Brain Training and Recurrent state encoder (#166)

* `learn.py` is now the main script for training brains.
* Simultaneous multi-brain training is now possible (see the loop sketch after the file list below).
* `ghost-trainer` allows for proper training in adversarial scenarios.
* `imitation-trainer` provides a basic implementation of real-time behavioral cloning.
* All trainer hyperparameters now exist in `.yaml` files.
* `PPO.ipynb` removed.
* LSTM model added.
* More dynamic buffer class to handle a greater variety of scenarios.
/develop-generalizationTraining-TrainerController
Arthur Juliani, 7 years ago
Current commit: de700c3a
21 files changed, with 2,214 insertions and 439 deletions
  1. python/requirements.txt (3 changed lines)
  2. python/test_unityagents.py (159 changed lines)
  3. python/unityagents/curriculum.py (40 changed lines)
  4. python/unityagents/environment.py (48 changed lines)
  5. python/unityagents/exception.py (12 changed lines)
  6. unity-environment/Assets/ML-Agents/Scripts/Agent.cs (1 changed line)
  7. unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs (15 changed lines)
  8. unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs (20 changed lines)
  9. python/learn.py (223 changed lines)
  10. python/trainer_configurations.yaml (46 changed lines)
  11. python/trainers/buffer.py (227 changed lines)
  12. python/trainers/ghost_trainer.py (204 changed lines)
  13. python/trainers/imitation_trainer.py (315 changed lines)
  14. python/trainers/ppo_models.py (381 changed lines)
  15. python/trainers/ppo_trainer.py (403 changed lines)
  16. python/trainers/trainer.py (159 changed lines)
  17. python/PPO.ipynb (237 changed lines)
  18. python/ppo.py (160 changed lines)
  19. /python/trainers/__init__.py (0 changed lines)
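
For orientation, here is a condensed, illustrative sketch of the training loop that `learn.py` now drives: one trainer per external brain, all stepping a single shared environment. The names come from the trainer interface in the diffs below; setup, saving, and curriculum handling are omitted.

def run_training_loop(env, trainers, fast_simulation=True):
    # Illustrative only: condensed from the main loop of learn.py shown below.
    info = env.reset(train_mode=fast_simulation)
    while any(t.get_step <= t.get_max_steps for t in trainers.values()):
        actions, memories, values, outputs = {}, {}, {}, {}
        for brain_name, trainer in trainers.items():
            (actions[brain_name], memories[brain_name],
             values[brain_name], outputs[brain_name]) = trainer.take_action(info)
        new_info = env.step(action=actions, memory=memories, value=values)
        for brain_name, trainer in trainers.items():
            trainer.add_experiences(info, new_info, outputs[brain_name])
            trainer.process_experiences(new_info)
            if trainer.is_ready_update():
                trainer.update_model()
            trainer.increment_step()
        info = new_info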

python/requirements.txt (3 changed lines)


jupyter
mock>=2.0.0
pytest>=3.2.2
docopt
pyyaml

python/test_unityagents.py (159 changed lines)


import struct
import json
from unityagents import UnityEnvironment, UnityEnvironmentException, UnityActionException, BrainInfo, BrainParameters, Curriculum
import tensorflow as tf
from unityagents import UnityEnvironment, UnityEnvironmentException, UnityActionException, \
BrainInfo, BrainParameters, Curriculum
from trainers.ppo_models import *
from trainers.buffer import Buffer
def append_length(input):
return struct.pack("I", len(input.encode())) + input.encode()

UnityEnvironment(' ')
def test_initialialization():
def test_initialization():
with mock.patch('subprocess.Popen') as mock_subproc_popen:
with mock.patch('socket.socket') as mock_socket:
with mock.patch('glob.glob') as mock_glob:

with pytest.raises(UnityEnvironmentException):
curriculum = Curriculum('test_unityagents.py', {"param1":1,"param2":1})
curriculum = Curriculum('test_unityagents.py', {"param1":1,"param2":1,"param3":1})
assert curriculum.get_lesson_number() == 0
assert curriculum.get_lesson_number == 0
assert curriculum.get_lesson_number() == 1
curriculum.get_lesson(10)
assert curriculum.get_lesson_number() == 1
curriculum.get_lesson(30)
curriculum.get_lesson(30)
assert curriculum.get_lesson_number() == 1
assert curriculum.get_lesson_number == 1
curriculum.increment_lesson(10)
assert curriculum.get_lesson_number == 1
curriculum.increment_lesson(30)
curriculum.increment_lesson(30)
assert curriculum.get_lesson_number == 1
assert curriculum.get_lesson(30) == {'param1': 0.3, 'param2': 20, 'param3': 0.7}
curriculum.increment_lesson(30)
assert curriculum.get_config() == {'param1': 0.3, 'param2': 20, 'param3': 0.7}
assert curriculum.get_config(0) == {"param1":0.7,"param2":100,"param3":0.2}
assert curriculum.get_lesson_number() == 2
assert curriculum.get_lesson_number == 2
c_action_c_state_start = '''{
"AcademyName": "RealFakeAcademy",
"resetParameters": {},
"brainNames": ["RealFakeBrain"],
"externalBrainNames": ["RealFakeBrain"],
"logPath":"RealFakePath",
"apiNumber":"API-2",
"brainParameters": [{
"stateSize": 3,
"actionSize": 2,
"memorySize": 0,
"cameraResolutions": [],
"actionDescriptions": ["",""],
"actionSpaceType": 1,
"stateSpaceType": 1
}]
}'''.encode()
def test_ppo_model_continuous():
tf.reset_default_graph()
with mock.patch('subprocess.Popen') as mock_subproc_popen:
with mock.patch('socket.socket') as mock_socket:
with mock.patch('glob.glob') as mock_glob:
# End of mock
with tf.Session() as sess:
with tf.variable_scope("FakeGraphScope"):
mock_glob.return_value = ['FakeLaunchPath']
mock_socket.return_value.accept.return_value = (mock_socket, 0)
mock_socket.recv.return_value.decode.return_value = c_action_c_state_start
env = UnityEnvironment(' ')
model = create_agent_model(env.brains["RealFakeBrain"])
init = tf.global_variables_initializer()
sess.run(init)
run_list = [model.output, model.probs, model.value, model.entropy,
model.learning_rate]
feed_dict = {model.batch_size: 2,
model.sequence_length: 1,
model.state_in : np.array([[1,2,3],[3,4,5]]),
model.epsilon :np.random.randn(2, 2)
}
sess.run(run_list, feed_dict = feed_dict)
env.close()
d_action_c_state_start = '''{
"AcademyName": "RealFakeAcademy",
"resetParameters": {},
"brainNames": ["RealFakeBrain"],
"externalBrainNames": ["RealFakeBrain"],
"logPath":"RealFakePath",
"apiNumber":"API-2",
"brainParameters": [{
"stateSize": 3,
"actionSize": 2,
"memorySize": 0,
"cameraResolutions": [{"width":30,"height":40,"blackAndWhite":false}],
"actionDescriptions": ["",""],
"actionSpaceType": 0,
"stateSpaceType": 1
}]
}'''.encode()
def test_ppo_model_discrete():
tf.reset_default_graph()
with mock.patch('subprocess.Popen') as mock_subproc_popen:
with mock.patch('socket.socket') as mock_socket:
with mock.patch('glob.glob') as mock_glob:
# End of mock
with tf.Session() as sess:
with tf.variable_scope("FakeGraphScope"):
mock_glob.return_value = ['FakeLaunchPath']
mock_socket.return_value.accept.return_value = (mock_socket, 0)
mock_socket.recv.return_value.decode.return_value = d_action_c_state_start
env = UnityEnvironment(' ')
model = create_agent_model(env.brains["RealFakeBrain"])
init = tf.global_variables_initializer()
sess.run(init)
run_list = [model.output, model.probs, model.value, model.entropy,
model.learning_rate]
feed_dict = {model.batch_size: 2,
model.sequence_length: 1,
model.state_in : np.array([[1,2,3],[3,4,5]]),
model.observation_in[0] : np.ones([2,40,30,3])
}
sess.run(run_list, feed_dict = feed_dict)
env.close()
def assert_array(a, b):
assert a.shape == b.shape
la = list(a.flatten())
lb = list(b.flatten())
for i in range(len(la)):
assert la[i] == lb[i]
def test_buffer():
b = Buffer()
for fake_agent_id in range(4):
for i in range(9):
b[fake_agent_id]['state'].append(
[100*fake_agent_id+10*i +1, 100*fake_agent_id+10*i +2, 100*fake_agent_id+10*i +3]
)
b[fake_agent_id]['action'].append([100*fake_agent_id+10*i +4,100*fake_agent_id+10*i +5])
a = b[1]['state'].get_batch(batch_size = 2, training_length = None, sequential = True)
assert_array(a, np.array([[171,172,173], [181,182,183]]))
a = b[2]['state'].get_batch(batch_size = 2, training_length = 3, sequential = True)
assert_array(a, np.array([
[[231,232,233], [241,242,243], [251,252,253]],
[[261,262,263], [271,272,273], [281,282,283]]
]))
a = b[2]['state'].get_batch(batch_size = 2, training_length = 3, sequential = False)
assert_array(a, np.array([
[[251,252,253], [261,262,263], [271,272,273]],
[[261,262,263], [271,272,273], [281,282,283]]
]))
b[4].reset_agent()
assert len(b[4]) == 0
b.append_update_buffer(3,
batch_size = None, training_length=2)
b.append_update_buffer(2,
batch_size = None, training_length=2)
assert len(b.update_buffer['action']) == 10
assert np.array(b.update_buffer['action']).shape == (10,2,2)

python/unityagents/curriculum.py (40 changed lines)


from .exception import UnityEnvironmentException
import logging
logger = logging.getLogger("unityagents")
def __init__(self, location, default_reset_parameters, lesson):
def __init__(self, location, default_reset_parameters):
"""
Initializes a Curriculum object.
:param location: Path to JSON defining curriculum.

self.max_lesson_number = 0
self.measure_type = None
if location is None:
self.data = None

self.data = json.load(data_file)
except FileNotFoundError:
except IOError:
raise UnityEnvironmentException(
"The file {0} could not be found.".format(location))
except UnicodeDecodeError:

"The parameter {0} in Curriculum {1} must have {2} values "
"but {3} were found".format(key, location,
self.max_lesson_number + 1, len(parameters[key])))
self.set_lesson_number(lesson)
self.set_lesson_number(0)
@property
def get_lesson_number(self):
return self.lesson_number

def get_lesson(self, progress):
def increment_lesson(self, progress):
Returns reset parameters which correspond to current lesson.
Increments the lesson number depending on the progress given.
:return: Dictionary containing reset parameters.
return {}
return
if self.data["signal_smoothing"]:
progress = self.smoothing_value * 0.25 + 0.75 * progress
self.smoothing_value = progress

(self.lesson_length > self.data['min_lesson_length'])):
self.lesson_length = 0
self.lesson_number += 1
config = {}
parameters = self.data["parameters"]
for key in parameters:
config[key] = parameters[key][self.lesson_number]
logger.info("\nLesson changed. Now in Lesson {0} : \t{1}"
.format(self.lesson_number,
', '.join([str(x) + ' -> ' + str(config[x]) for x in config])))
def get_config(self, lesson = None):
"""
Returns reset parameters which correspond to the lesson.
:param lesson: The lesson you want to get the config of. If None, the current lesson is returned.
:return: The configuration of the reset parameters.
"""
if self.data is None:
return {}
if lesson is None:
lesson = self.lesson_number
lesson = max(0, min(lesson, self.max_lesson_number))
config[key] = parameters[key][self.lesson_number]
config[key] = parameters[key][lesson]
return config
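
The curriculum API is now split in two: `increment_lesson(progress)` advances the lesson number when the progress measure and `min_lesson_length` allow it, and `get_config(lesson)` returns the reset parameters for a lesson (the current lesson if `lesson` is None). A minimal usage sketch; the JSON file name is hypothetical and the parameter values mirror `test_unityagents.py` above:

from unityagents import Curriculum

default_reset_parameters = {"param1": 1, "param2": 1, "param3": 1}
curriculum = Curriculum("my_curriculum.json", default_reset_parameters)  # hypothetical JSON file

curriculum.increment_lesson(30)        # progress measure; may advance the lesson number
config = curriculum.get_config()       # reset parameters of the current lesson
lesson = curriculum.get_lesson_number  # now a property rather than a method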

python/unityagents/environment.py (48 changed lines)


class UnityEnvironment(object):
def __init__(self, file_name, worker_id=0,
base_port=5005, curriculum=None, lesson=0):
base_port=5005, curriculum=None):
"""
Starts a new unity environment and establishes a connection with the environment.
Notice: Currently communication between Unity and Python takes place over an open socket without authentication.

self._num_brains = len(self._brain_names)
self._num_external_brains = len(self._external_brain_names)
self._resetParameters = p["resetParameters"]
self._curriculum = Curriculum(curriculum, self._resetParameters, lesson)
self._curriculum = Curriculum(curriculum, self._resetParameters)
for i in range(self._num_brains):
self._brains[self._brain_names[i]] = BrainParameters(self._brain_names[i], p["brainParameters"][i])
self._loaded = True

proc1.kill()
self.close()
raise
@property
def curriculum(self):
return self._curriculum
@property
def logfile_path(self):

return s
def __str__(self):
_new_reset_param = self._curriculum.get_config()
for k in _new_reset_param:
self._resetParameters[k] = _new_reset_param[k]
Number of brains: {1}
Reset Parameters :\n\t\t{2}'''.format(self._academy_name, str(self._num_brains),
"\n\t\t".join([str(k) + " -> " + str(self._resetParameters[k])
for k in self._resetParameters])) + '\n' + \
Number of Brains: {1}
Number of External Brains : {2}
Lesson number : {3}
Reset Parameters :\n\t\t{4}'''.format(self._academy_name, str(self._num_brains),
str(self._num_external_brains), self._curriculum.get_lesson_number,
"\n\t\t".join([str(k) + " -> " + str(self._resetParameters[k])
for k in self._resetParameters])) + '\n' + \
'\n'.join([str(self._brains[b]) for b in self._brains])
def _recv_bytes(self):

state_dict = json.loads(state)
return state_dict
def reset(self, train_mode=True, config=None, progress=None):
def reset(self, train_mode=True, config=None, lesson=None):
old_lesson = self._curriculum.get_lesson_number()
config = self._curriculum.get_lesson(progress)
if old_lesson != self._curriculum.get_lesson_number():
logger.info("\nLesson changed. Now in Lesson {0} : \t{1}"
.format(self._curriculum.get_lesson_number(),
', '.join([str(x) + ' -> ' + str(config[x]) for x in config])))
config = self._curriculum.get_config(lesson)
elif config != {}:
logger.info("\nAcademy Reset with parameters : \t{0}"
.format(', '.join([str(x) + ' -> ' + str(config[x]) for x in config])))

except socket.timeout as e:
raise UnityTimeOutException("The environment took too long to respond.", self._log_path)
action_message = {"action": action, "memory": memory, "value": value}
self._conn.send(json.dumps(action_message).encode('utf-8'))
self._conn.send(self._append_length(json.dumps(action_message).encode('utf-8')))
@staticmethod
def _append_length(message):
return struct.pack("I", len(message)) + message
@staticmethod
def _flatten(arr):

if b not in memory:
memory[b] = [0.0] * self._brains[b].memory_space_size * n_agent
else:
memory[b] = self._flatten(memory[b])
if memory[b] is None:
memory[b] = [0.0] * self._brains[b].memory_space_size * n_agent
else:
memory[b] = self._flatten(memory[b])
value[b] = self._flatten(value[b])
if value[b] is None:
value[b] = [0.0] * n_agent
else:
value[b] = self._flatten(value[b])
if not (len(value[b]) == n_agent):
raise UnityActionException(
"There was a mismatch between the provided value and environment's expectation: "

python/unityagents/exception.py (12 changed lines)


import logging
logger = logging.getLogger("unityagents")
class UnityEnvironmentException(Exception):
class UnityException(Exception):
"""
Any error related to the ml-agents environment.
"""
pass
class UnityEnvironmentException(UnityException):
"""
Related to errors starting and closing environment.
"""

class UnityActionException(Exception):
class UnityActionException(UnityException):
class UnityTimeOutException(Exception):
class UnityTimeOutException(UnityException):
"""
Related to errors with communication timeouts.
"""

unity-environment/Assets/ML-Agents/Scripts/Agent.cs (1 changed line)


memory = new float[brain.brainParameters.memorySize];
stepCounter = 0;
AgentReset();
CumulativeReward = -reward;
}
/// Do not modify : Is used by the brain to collect rewards.

unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs (15 changed lines)


session = new TFSession(graph);
if ((graphScope.Length > 1) && (graphScope[graphScope.Length - 1] != '/')){
graphScope = graphScope + '/';
}
if (graph[graphScope + BatchSizePlaceholderName] != null)
{
hasBatchSize = true;

}
var runner = session.GetRunner();
runner.Fetch(graph[graphScope + ActionPlaceholderName][0]);
try
{
runner.Fetch(graph[graphScope + ActionPlaceholderName][0]);
}
catch
{
throw new UnityAgentsException(string.Format(@"The node {0} could not be found. Please make sure the graphScope {1} is correct",
graphScope + ActionPlaceholderName, graphScope));
}
if (hasBatchSize)
{

if (hasRecurrent)
{
runner.AddInput(graph[graphScope + "sequence_length"][0], 1 );
runner.AddInput(graph[graphScope + RecurrentInPlaceholderName][0], inputOldMemories);
runner.Fetch(graph[graphScope + RecurrentOutPlaceholderName][0]);
}
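
On the Python side, the same recurrent plumbing appears as the `memory_in`/`memory_out` tensors and the `sequence_length` placeholder. A sketch of one inference step with the LSTM state threaded through, mirroring `GhostTrainer.take_action` later in this diff (`sess`, `model`, and `info` come from that trainer code):

def recurrent_inference_step(sess, model, info):
    # Feed the previous memories in, fetch the updated memories out.
    feed_dict = {model.batch_size: len(info.states),
                 model.sequence_length: 1,
                 model.state_in: info.states,
                 model.memory_in: info.memories}
    actions, memories = sess.run([model.output, model.memory_out], feed_dict=feed_dict)
    # `memories` is passed back through env.step(memory=...) so the recurrent state
    # survives across environment steps at inference time.
    return actions, memories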

unity-environment/Assets/ML-Agents/Scripts/ExternalCommunicator.cs (20 changed lines)


private int comPort;
Socket sender;
byte[] messageHolder;
byte[] lengthHolder;
const int messageLength = 12000;

logWriter.WriteLine(" ");
logWriter.Close();
messageHolder = new byte[messageLength];
lengthHolder = new byte[4];
// Create a TCP/IP socket.
sender = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);

rMessage = Encoding.ASCII.GetString(messageHolder, 0, location);
}
/// Receives a message and can reconstruct it if it was too long
private string ReceiveAll(){
sender.Receive(lengthHolder);
int totalLength = System.BitConverter.ToInt32(lengthHolder, 0);
int location = 0;
rMessage = "";
while (location != totalLength){
int fragment = sender.Receive(messageHolder);
location += fragment;
rMessage += Encoding.ASCII.GetString(messageHolder, 0, fragment);
}
return rMessage;
}
/// Ends connection and closes environment
private void OnApplicationQuit()
{

{
// TO MODIFY --------------------------------------------
sender.Send(Encoding.ASCII.GetBytes("STEPPING"));
Receive();
ReceiveAll();
var agentMessage = JsonConvert.DeserializeObject<AgentMessage>(rMessage);
foreach (Brain brain in brains)

return storedValues[brainName];
}
}
}

python/learn.py (223 changed lines)


# # Unity ML Agents
# ## ML-Agent Learning (PPO)
# Launches a trainer for each External Brain in a Unity Environment
import logging
import os
import re
import yaml
from docopt import docopt
from trainers.ghost_trainer import GhostTrainer
from trainers.ppo_models import *
from trainers.ppo_trainer import PPOTrainer
from trainers.imitation_trainer import ImitationTrainer
from unityagents import UnityEnvironment, UnityEnvironmentException
def get_progress():
if curriculum_file is not None:
if env.curriculum.measure_type == "progress":
progress = 0
for brain_name in env.external_brain_names:
progress += trainers[brain_name].get_step / trainers[brain_name].get_max_steps
return progress / len(env.external_brain_names)
elif env.curriculum.measure_type == "reward":
progress = 0
for brain_name in env.external_brain_names:
progress += trainers[brain_name].get_last_reward
return progress
else:
return None
else:
return None
if __name__ == '__main__' :
logger = logging.getLogger("unityagents")
_USAGE = '''
Usage:
ppo (<env>) [options]
Options:
--help Show this message.
--curriculum=<file> Curriculum json file for environment [default: None].
--slow Whether to run the game at training speed [default: False].
--keep-checkpoints=<n> How many model checkpoints to keep [default: 5].
--lesson=<n> Start learning from this lesson [default: 0].
--load Whether to load the model or randomly initialize [default: False].
--run-path=<path> The sub-directory name for model and summary statistics [default: ppo].
--save-freq=<n> Frequency at which to save model [default: 50000].
--train Whether to train model, or only run inference [default: False].
--worker-id=<n> Number to add to communication port (5005). Used for multi-environment [default: 0].
'''
options = docopt(_USAGE)
logger.info(options)
# General parameters
model_path = './models/{}'.format(str(options['--run-path']))
load_model = options['--load']
train_model = options['--train']
save_freq = int(options['--save-freq'])
env_name = options['<env>']
keep_checkpoints = int(options['--keep-checkpoints'])
worker_id = int(options['--worker-id'])
curriculum_file = str(options['--curriculum'])
if curriculum_file == "None":
curriculum_file = None
lesson = int(options['--lesson'])
fast_simulation = not bool(options['--slow'])
env = UnityEnvironment(file_name=env_name, worker_id=worker_id, curriculum=curriculum_file)
env.curriculum.set_lesson_number(lesson)
logger.info(str(env))
tf.reset_default_graph()
try:
if not os.path.exists(model_path):
os.makedirs(model_path)
except:
raise UnityEnvironmentException("The folder {} containing the generated model could not be accessed."
" Please make sure the permissions are set correctly.".format(model_path))
try:
with open("trainer_configurations.yaml") as data_file:
trainer_configurations = yaml.load(data_file)
except IOError:
raise UnityEnvironmentException("The file {} could not be found. Will use default Hyperparameters"
.format("trainer_configurations.yaml"))
except UnicodeDecodeError:
raise UnityEnvironmentException("There was an error decoding {}".format("trainer_configurations.yaml"))
with tf.Session() as sess:
trainers = {}
trainer_parameters_dict = {}
for brain_name in env.external_brain_names:
trainer_parameters = trainer_configurations['default'].copy()
if len(env.external_brain_names) > 1:
graph_scope = re.sub('[^0-9a-zA-Z]+', '-', brain_name)
trainer_parameters['graph_scope'] = graph_scope
trainer_parameters['summary_path'] = './summaries/{}'.format(str(options['--run-path']))+'_'+graph_scope
else :
trainer_parameters['graph_scope'] = ''
trainer_parameters['summary_path'] = './summaries/{}'.format(str(options['--run-path']))
if brain_name in trainer_configurations:
_brain_key = brain_name
while not isinstance(trainer_configurations[_brain_key], dict):
_brain_key = trainer_configurations[_brain_key]
for k in trainer_configurations[_brain_key]:
trainer_parameters[k] = trainer_configurations[_brain_key][k]
trainer_parameters_dict[brain_name] = trainer_parameters.copy()
for brain_name in env.external_brain_names:
if 'is_ghost' not in trainer_parameters_dict[brain_name]:
trainer_parameters_dict[brain_name]['is_ghost'] = False
if 'is_imitation' not in trainer_parameters_dict[brain_name]:
trainer_parameters_dict[brain_name]['is_imitation'] = False
if trainer_parameters_dict[brain_name]['is_ghost']:
if trainer_parameters_dict[brain_name]['brain_to_copy'] not in env.external_brain_names:
raise UnityEnvironmentException("The external brain {0} could not be found in the environment "
"even though the ghost trainer of brain {1} is trying to ghost it."
.format(trainer_parameters_dict[brain_name]['brain_to_copy'], brain_name))
trainer_parameters_dict[brain_name]['original_brain_parameters'] = trainer_parameters_dict[
trainer_parameters_dict[brain_name]['brain_to_copy']]
trainers[brain_name] = GhostTrainer(sess, env, brain_name, trainer_parameters_dict[brain_name], train_model)
elif trainer_parameters_dict[brain_name]['is_imitation']:
trainers[brain_name] = ImitationTrainer(sess, env, brain_name, trainer_parameters_dict[brain_name], train_model)
else:
trainers[brain_name] = PPOTrainer(sess, env, brain_name, trainer_parameters_dict[brain_name], train_model)
for k, t in trainers.items():
logger.info(t)
init = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=keep_checkpoints)
# Instantiate model parameters
if load_model:
logger.info('Loading Model...')
ckpt = tf.train.get_checkpoint_state(model_path)
if ckpt == None:
logger.info('The model {0} could not be found. Make sure you specified the right '
'--run-path'.format(model_path))
saver.restore(sess, ckpt.model_checkpoint_path)
else:
sess.run(init)
global_step = 0 # This is only for saving the model
env.curriculum.increment_lesson(get_progress())
info = env.reset(train_mode= fast_simulation)
if train_model:
for brain_name, trainer in trainers.items():
trainer.write_tensorboard_text('Hyperparameters', trainer.parameters)
try:
while any([t.get_step <= t.get_max_steps for k, t in trainers.items()]) or not train_model:
if env.global_done:
env.curriculum.increment_lesson(get_progress())
info = env.reset(train_mode=fast_simulation)
for brain_name, trainer in trainers.items():
trainer.end_episode()
# Decide and take an action
take_action_actions = {}
take_action_memories = {}
take_action_values = {}
take_action_outputs = {}
for brain_name, trainer in trainers.items():
(take_action_actions[brain_name],
take_action_memories[brain_name],
take_action_values[brain_name],
take_action_outputs[brain_name]) = trainer.take_action(info)
new_info = env.step(action = take_action_actions, memory = take_action_memories, value = take_action_values)
for brain_name, trainer in trainers.items():
trainer.add_experiences(info, new_info, take_action_outputs[brain_name])
info = new_info
for brain_name, trainer in trainers.items():
trainer.process_experiences(info)
if trainer.is_ready_update() and train_model and trainer.get_step <= trainer.get_max_steps:
# Perform gradient descent with experience buffer
trainer.update_model()
# Write training statistics to tensorboard.
trainer.write_summary(env.curriculum.lesson_number)
if train_model and trainer.get_step <= trainer.get_max_steps:
trainer.increment_step()
trainer.update_last_reward()
if train_model and trainer.get_step <= trainer.get_max_steps:
global_step += 1
if global_step % save_freq == 0 and global_step != 0 and train_model:
# Save Tensorflow model
save_model(sess, model_path=model_path, steps=global_step, saver=saver)
# Final save Tensorflow model
if global_step != 0 and train_model:
save_model(sess, model_path=model_path, steps=global_step, saver=saver)
except KeyboardInterrupt:
if train_model:
logger.info("Learning was interupted. Please wait while the graph is generated.")
save_model(sess, model_path=model_path, steps=global_step, saver=saver)
pass
env.close()
if train_model:
graph_name = (env_name.strip()
.replace('.app', '').replace('.exe', '').replace('.x86_64', '').replace('.x86', ''))
graph_name = os.path.basename(os.path.normpath(graph_name))
nodes = []
scopes = []
for brain_name in trainers.keys():
if trainers[brain_name].graph_scope is not None:
scope = trainers[brain_name].graph_scope + '/'
if scope == '/':
scope = ''
scopes += [scope]
if trainers[brain_name].parameters["is_imitation"]:
nodes +=[scope + x for x in ["action"]]
elif not trainers[brain_name].parameters["use_recurrent"]:
nodes +=[scope + x for x in ["action","value_estimate","action_probs"]]
else:
nodes +=[scope + x for x in ["action","value_estimate","action_probs","recurrent_out"]]
export_graph(model_path, graph_name, target_nodes=','.join(nodes))
if len(scopes) > 1:
logger.info("List of available scopes :")
for scope in scopes:
logger.info("\t" + scope )
logger.info("List of nodes exported :")
for n in nodes:
logger.info("\t" + n)

python/trainer_configurations.yaml (46 changed lines)


default:
batch_size: 64
beta: 2.5e-3
buffer_size: 2048
epsilon: 0.2
gamma: 0.99
hidden_units: 64
lambd: 0.95
learning_rate: 3.0e-4
max_steps: 1.0e6
normalize: false
num_epoch: 5
num_layers: 2
time_horizon: 2048
sequence_length: 32
summary_freq: 10000
use_recurrent: false
Ball3DBrain:
summary_freq: 1000
normalize: true
batch_size: 1024
max_steps: 1.0e4
TurretBrain: ExampleBrain
ghost-HunterBrain:
brain_to_copy : HunterBrain
is_ghost : true
new_model_freq : 10000
max_num_models : 20
ghost-HunteeBrain :
brain_to_copy : HunteeBrain
is_ghost : true
new_model_freq : 10000
max_num_models : 20
ghost-Ball3DBrain:
brain_to_copy : Ball3DBrain
is_ghost : true
new_model_freq : 10000
max_num_models : 3
Player:
# is_imitation : true
max_steps: 10000
summary_freq: 1000
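
Per-brain sections override the `default` block, and a brain can alias another section by mapping straight to its name, as `TurretBrain: ExampleBrain` does above. A condensed sketch of how `learn.py` resolves this, equivalent to the loop shown earlier:

def resolve_trainer_parameters(trainer_configurations, brain_name):
    parameters = trainer_configurations['default'].copy()
    _brain_key = brain_name
    # Follow string aliases (e.g. "TurretBrain: ExampleBrain") until a dict is reached.
    while not isinstance(trainer_configurations.get(_brain_key, {}), dict):
        _brain_key = trainer_configurations[_brain_key]
    # Brain-specific keys override the defaults.
    parameters.update(trainer_configurations.get(_brain_key, {}))
    return parameters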

python/trainers/buffer.py (227 changed lines)


import numpy as np
from unityagents.exception import UnityException
class BufferException(UnityException):
"""
Related to errors with the Buffer.
"""
pass
class Buffer(dict):
"""
Buffer contains a dictionary of AgentBuffer. The AgentBuffers are indexed by agent_id.
Buffer also contains an update_buffer that corresponds to the buffer used when updating the model.
"""
class AgentBuffer(dict):
"""
AgentBuffer contains a dictionary of AgentBufferFields. Each agent has its own AgentBuffer.
The keys correspond to the name of the field. Example: state, action
"""
class AgentBufferField(list):
"""
AgentBufferField is a list of numpy arrays. When an agent collects a field, you can add it to its
AgentBufferField with the append method.
"""
def __str__(self):
return str(np.array(self).shape)
def extend(self, data):
"""
Adds a list of np.arrays to the end of the list of np.arrays.
:param data: The np.array list to append.
"""
self += list(np.array(data))
def set(self, data):
"""
Sets the list of np.array to the input data
:param data: The np.array list to be set.
"""
self[:] = []
self[:] = list(np.array(data))
def get_batch(self, batch_size = None, training_length = None, sequential = True):
"""
Retrieve the last batch_size elements of length training_length
from the list of np.array
:param batch_size: The number of elements to retrieve. If None:
All elements will be retrieved.
:param training_length: The length of the sequence to be retrieved. If
None: only takes one element.
:param sequential: If true and training_length is not None: the elements
will not repeat in the sequence. [a,b,c,d,e] with training_length = 2 and
sequential=True gives [[0,a],[b,c],[d,e]]. If sequential=False gives
[[a,b],[b,c],[c,d],[d,e]]
"""
if training_length is None:
# When the training length is None, the method returns a list of elements,
# not a list of sequences of elements.
if batch_size is None:
# If batch_size is None : All the elements of the AgentBufferField are returned.
return np.array(self)
else:
# return the batch_size last elements
if batch_size > len(self):
raise BufferException("Batch size requested is too large")
return np.array(self[-batch_size:])
else:
# The training_length is not None, the method returns a list of SEQUENCES of elements
if not sequential:
# The sequences will have overlapping elements
if batch_size is None:
# retrieve the maximum number of elements
batch_size = len(self) - training_length + 1
# The number of sequences of length training_length taken from a list of len(self) elements
# with overlapping is equal to batch_size
if (len(self) - training_length + 1) < batch_size :
raise BufferException("The batch size and training length requested for get_batch where"
" too large given the current number of data points.")
return
tmp_list = []
for end in range(len(self)-batch_size+1, len(self)+1):
tmp_list += [np.array(self[end-training_length:end])]
return np.array(tmp_list)
if sequential:
# The sequences will not have overlapping elements (this involves padding)
leftover = len(self) % training_length
# leftover is the number of elements in the first sequence (this sequence might need 0 padding)
if batch_size is None:
# retrieve the maximum number of elements
batch_size = len(self) // training_length +1 *(leftover != 0)
# The maximum number of sequences taken from a list of length len(self) without overlapping
# with padding is equal to batch_size
if batch_size > (len(self) // training_length +1 *(leftover != 0)):
raise BufferException("The batch size and training length requested for get_batch where"
" too large given the current number of data points.")
return
tmp_list = []
padding = np.array(self[-1]) * 0
# The padding is made with zeros and its shape is given by the shape of the last element
for end in range(len(self), len(self) % training_length , -training_length)[:batch_size]:
tmp_list += [np.array(self[end-training_length:end])]
if (leftover != 0) and (len(tmp_list) < batch_size):
tmp_list +=[np.array([padding]*(training_length - leftover)+self[:leftover])]
tmp_list.reverse()
return np.array(tmp_list)
def reset_field(self):
"""
Resets the AgentBufferField
"""
self[:] = []
def __str__(self):
return ", ".join(["'{0}' : {1}".format(k, str(self[k])) for k in self.keys()])
def reset_agent(self):
"""
Resets the AgentBuffer
"""
for k in self.keys():
self[k].reset_field()
def __getitem__(self, key):
if key not in self.keys():
self[key] = self.AgentBufferField()
return super(Buffer.AgentBuffer, self).__getitem__(key)
def check_length(self, key_list):
"""
Some methods will require that some fields have the same length.
check_length will return true if the fields in key_list
have the same length.
:param key_list: The fields which length will be compared
"""
if len(key_list) < 2:
return True
l = None
for key in key_list:
if key not in self.keys():
return False
if ((l != None) and (l!=len(self[key]))):
return False
l = len(self[key])
return True
def shuffle(self, key_list = None):
"""
Shuffles the fields in key_list in a consistent way: The reordering will
be the same across fields.
:param key_list: The fields that must be shuffled.
"""
if key_list is None:
key_list = list(self.keys())
if not self.check_length(key_list):
raise BufferException("Unable to shuffle if the fields are not of same length")
return
s = np.arange(len(self[key_list[0]]))
np.random.shuffle(s)
for key in key_list:
self[key][:] = [self[key][i] for i in s]
def __init__(self):
self.update_buffer = self.AgentBuffer()
super(Buffer, self).__init__()
def __str__(self):
return "update buffer :\n\t{0}\nlocal_buffers :\n{1}".format(str(self.update_buffer),
'\n'.join(['\tagent {0} :{1}'.format(k, str(self[k])) for k in self.keys()]))
def __getitem__(self, key):
if key not in self.keys():
self[key] = self.AgentBuffer()
return super(Buffer, self).__getitem__(key)
def reset_update_buffer(self):
"""
Resets the update buffer
"""
self.update_buffer.reset_agent()
def reset_all(self):
"""
Resets the update buffer and all the local buffers
"""
self.update_buffer.reset_agent()
agent_ids = list(self.keys())
for k in agent_ids:
self[k].reset_agent()
def append_update_buffer(self, agent_id ,key_list = None, batch_size = None, training_length = None):
"""
Appends the buffer of an agent to the update buffer.
:param agent_id: The id of the agent which data will be appended
:param key_list: The fields that must be added. If None: all fields will be appended.
:param batch_size: The number of elements that must be appended. If None: All of them will be.
:param training_length: The length of the samples that must be appended. If None: only takes one element.
"""
if key_list is None:
key_list = self[agent_id].keys()
if not self[agent_id].check_length(key_list):
raise BufferException("The length of the fields {0} for agent {1} where not of same length"
.format(key_list, agent_id))
for field_key in key_list:
self.update_buffer[field_key].extend(
self[agent_id][field_key].get_batch(batch_size =batch_size, training_length =training_length)
)
def append_all_agent_batch_to_update_buffer(self, key_list = None, batch_size = None, training_length = None):
"""
Appends the buffer of all agents to the update buffer.
:param key_list: The fields that must be added. If None: all fields will be appended.
:param batch_size: The number of elements that must be appended. If None: All of them will be.
:param training_length: The length of the samples that must be appended. If None: only takes one element.
"""
for agent_id in self.keys():
self.append_update_buffer(agent_id ,key_list, batch_size, training_length)
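
The `sequential` flag in `get_batch` is easiest to see with the numbers from `test_buffer` above; a short, runnable illustration:

from trainers.buffer import Buffer

b = Buffer()
for i in range(9):
    # Same values as agent 2 in test_buffer: 201..283.
    b[2]['state'].append([201 + 10 * i, 202 + 10 * i, 203 + 10 * i])

# Non-overlapping length-3 sequences taken from the end of the field:
b[2]['state'].get_batch(batch_size=2, training_length=3, sequential=True)
# -> [[[231, 232, 233], [241, 242, 243], [251, 252, 253]],
#     [[261, 262, 263], [271, 272, 273], [281, 282, 283]]]

# Overlapping (sliding) length-3 sequences:
b[2]['state'].get_batch(batch_size=2, training_length=3, sequential=False)
# -> [[[251, 252, 253], [261, 262, 263], [271, 272, 273]],
#     [[261, 262, 263], [271, 272, 273], [281, 282, 283]]]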

python/trainers/ghost_trainer.py (204 changed lines)


import logging
import os
import numpy as np
import tensorflow as tf
from trainers.ppo_models import *
from trainers.trainer import UnityTrainerException, Trainer
logger = logging.getLogger("unityagents")
#This works only with PPO
class GhostTrainer(Trainer):
"""Keeps copies of a PPOTrainer past graphs and uses them to other Trainers."""
def __init__(self, sess, env, brain_name, trainer_parameters, training):
"""
Responsible for saving and reusing past models.
:param sess: Tensorflow session.
:param env: The UnityEnvironment.
:param trainer_parameters: The parameters for the trainer (dictionary).
:param training: Whether the trainer is set for training.
"""
self.param_keys = ['brain_to_copy', 'is_ghost', 'new_model_freq', 'max_num_models']
for k in self.param_keys:
if k not in trainer_parameters:
raise UnityTrainerException("The hyperparameter {0} could not be found for the PPO trainer of "
"brain {1}.".format(k, brain_name))
super(GhostTrainer, self).__init__(sess, env, brain_name, trainer_parameters, training)
self.brain_to_copy = trainer_parameters['brain_to_copy']
self.variable_scope = trainer_parameters['graph_scope']
self.original_brain_parameters = trainer_parameters['original_brain_parameters']
self.new_model_freq = trainer_parameters['new_model_freq']
self.steps = 0
self.models = []
self.max_num_models = trainer_parameters['max_num_models']
self.last_model_replaced = 0
for i in range(self.max_num_models):
with tf.variable_scope(self.variable_scope+'_'+str(i)):
self.models += [create_agent_model(env.brains[self.brain_to_copy],
lr=float(self.original_brain_parameters['learning_rate']),
h_size=int(self.original_brain_parameters['hidden_units']),
epsilon=float(self.original_brain_parameters['epsilon']),
beta=float(self.original_brain_parameters['beta']),
max_step=float(self.original_brain_parameters['max_steps']),
normalize=self.original_brain_parameters['normalize'],
use_recurrent=self.original_brain_parameters['use_recurrent'],
num_layers=int(self.original_brain_parameters['num_layers']),
m_size = self.original_brain_parameters)]
self.model = self.models[0]
self.is_continuous = (env.brains[brain_name].action_space_type == "continuous")
self.use_observations = (env.brains[brain_name].number_observations > 0)
self.use_states = (env.brains[brain_name].state_space_size > 0)
self.use_recurrent = self.original_brain_parameters["use_recurrent"]
self.summary_path = trainer_parameters['summary_path']
def __str__(self):
return '''Hyperparameters for the Ghost Trainer of brain {0}: \n{1}'''.format(
self.brain_name, '\n'.join(['\t{0}:\t{1}'.format(x, self.trainer_parameters[x]) for x in self.param_keys]))
@property
def parameters(self):
"""
Returns the trainer parameters of the trainer.
"""
return self.trainer_parameters
@property
def graph_scope(self):
"""
Returns the graph scope of the trainer.
"""
return None
@property
def get_max_steps(self):
"""
Returns the maximum number of steps. Is used to know when the trainer should be stopped.
:return: The maximum number of steps of the trainer
"""
return 1
@property
def get_step(self):
"""
Returns the number of steps the trainer has performed
:return: the step count of the trainer
"""
return 0
@property
def get_last_reward(self):
"""
Returns the last reward the trainer has had
:return: the new last reward
"""
return 0
def increment_step(self):
"""
Increment the step count of the trainer
"""
self.steps += 1
def update_last_reward(self):
"""
Updates the last reward
"""
return
def update_target_graph(self, from_scope, to_scope):
from_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, from_scope)
to_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, to_scope)
op_holder = []
for from_var,to_var in zip(from_vars,to_vars):
op_holder.append(to_var.assign(from_var))
return op_holder
def take_action(self, info):
"""
Decides actions given state/observation information, and takes them in environment.
:param info: Current BrainInfo from environment.
:return: a tuple containing action, memories, values and an object
to be passed to add experiences
"""
epsi = None
info = info[self.brain_name]
feed_dict = {self.model.batch_size: len(info.states), self.model.sequence_length: 1}
run_list = [self.model.output]
if self.is_continuous:
epsi = np.random.randn(len(info.states), self.brain.action_space_size)
feed_dict[self.model.epsilon] = epsi
if self.use_observations:
for i, _ in enumerate(info.observations):
feed_dict[self.model.observation_in[i]] = info.observations[i]
if self.use_states:
feed_dict[self.model.state_in] = info.states
if self.use_recurrent:
feed_dict[self.model.memory_in] = info.memories
run_list += [self.model.memory_out]
if self.use_recurrent:
actions, memories = self.sess.run(run_list, feed_dict=feed_dict)
else:
actions = self.sess.run(run_list, feed_dict=feed_dict)
memories = None
return (actions, memories, None, None)
def add_experiences(self, info, next_info, take_action_outputs):
"""
Adds experiences to each agent's experience history.
:param info: Current BrainInfo.
:param next_info: Next BrainInfo.
:param take_action_outputs: The outputs of the take action method.
"""
return
def process_experiences(self, info):
"""
Checks agent histories for processing condition, and processes them as necessary.
Processing involves calculating value and advantage targets for model updating step.
:param info: Current BrainInfo
"""
return
def end_episode(self):
"""
A signal that the Episode has ended. We must use another version of the graph.
"""
self.model = self.models[np.random.randint(0, self.max_num_models)]
def is_ready_update(self):
"""
Returns whether or not the trainer has enough elements to run update model
:return: A boolean corresponding to whether or not update_model() can be run
"""
return self.steps % self.new_model_freq == 0
def update_model(self):
"""
Uses training_buffer to update model.
"""
self.last_model_replaced = (self.last_model_replaced + 1) % self.max_num_models
self.sess.run(self.update_target_graph(
self.original_brain_parameters['graph_scope'],
self.variable_scope+'_'+str(self.last_model_replaced))
)
return
def write_summary(self, lesson_number):
"""
Saves training statistics to Tensorboard.
:param lesson_number: The lesson the trainer is at.
"""
return
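
In short, the ghost trainer keeps up to `max_num_models` frozen copies of the opponent PPO graph under scopes `<graph_scope>_0 .. <graph_scope>_N`, refreshes the oldest copy every `new_model_freq` steps via `update_target_graph`, and acts with a randomly chosen copy after each episode. A schematic of that lifecycle as driven by `learn.py` (object setup omitted; the parameter values come from the `ghost-HunterBrain` entry in the yaml above):

# Schematic, not a literal excerpt from learn.py.
ghost = GhostTrainer(sess, env, 'ghost-HunterBrain',
                     trainer_parameters_dict['ghost-HunterBrain'], train_model)
ghost.increment_step()           # counts environment steps
if ghost.is_ready_update():      # true once every new_model_freq steps
    ghost.update_model()         # copy the current HunterBrain weights into the oldest snapshot slot
ghost.end_episode()              # sample one of the stored snapshots to act with next episode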

python/trainers/imitation_trainer.py (315 changed lines)


# # Unity ML Agents
# ## ML-Agent Learning (Imitation)
# Contains a basic implementation of real-time behavioral cloning.
import logging
import os
import numpy as np
import tensorflow as tf
from trainers.buffer import Buffer
from trainers.ppo_models import *
from trainers.trainer import UnityTrainerException, Trainer
logger = logging.getLogger("unityagents")
class ImitationNN(object):
def __init__(self, state_size, action_size, h_size, lr, action_type, n_layers):
self.state = tf.placeholder(shape=[None, state_size], dtype=tf.float32, name="state")
hidden = tf.layers.dense(self.state, h_size, activation=tf.nn.elu)
for i in range(n_layers):
hidden = tf.layers.dense(hidden, h_size, activation=tf.nn.elu)
hidden_drop = tf.layers.dropout(hidden, 0.5)
self.output = tf.layers.dense(hidden_drop, action_size, activation=None)
if action_type == "discrete":
self.action_probs = tf.nn.softmax(self.output)
self.sample_action = tf.multinomial(self.output, 1, name="action")
self.true_action = tf.placeholder(shape=[None], dtype=tf.int32)
self.action_oh = tf.one_hot(self.true_action, action_size)
self.loss = tf.reduce_sum(-tf.log(self.action_probs + 1e-10) * self.action_oh)
self.action_percent = tf.reduce_mean(tf.cast(
tf.equal(tf.cast(tf.argmax(self.action_probs, axis=1), tf.int32), self.sample_action), tf.float32))
else:
self.sample_action = tf.identity(self.output, name="action")
self.true_action = tf.placeholder(shape=[None, action_size], dtype=tf.float32)
self.loss = tf.reduce_sum(tf.squared_difference(self.true_action, self.sample_action))
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
self.update = optimizer.minimize(self.loss)
class ImitationTrainer(Trainer):
"""The ImitationTrainer is an implementation of the imitation learning."""
def __init__(self, sess, env, brain_name, trainer_parameters, training):
"""
Responsible for collecting experiences and training the imitation model.
:param sess: Tensorflow session.
:param env: The UnityEnvironment.
:param trainer_parameters: The parameters for the trainer (dictionary).
:param training: Whether the trainer is set for training.
"""
self.param_keys = [ 'is_imitation', 'brain_to_imitate', 'batch_size', 'time_horizon', 'graph_scope',
'summary_freq', 'max_steps']
for k in self.param_keys:
if k not in trainer_parameters:
raise UnityTrainerException("The hyperparameter {0} could not be found for the Imitation trainer of "
"brain {1}.".format(k, brain_name))
super(ImitationTrainer, self).__init__(sess, env, brain_name, trainer_parameters, training)
self.variable_scope = trainer_parameters['graph_scope']
self.brain_to_imitate = trainer_parameters['brain_to_imitate']
self.batch_size = trainer_parameters['batch_size']
self.step = 0
self.cumulative_rewards = {}
self.episode_steps = {}
self.stats = {'losses': [], 'episode_length': [], 'cumulative_reward' : []}
self.training_buffer = Buffer()
self.is_continuous = (env.brains[brain_name].action_space_type == "continuous")
self.use_observations = (env.brains[brain_name].number_observations > 0)
if self.use_observations:
logger.info('Cannot use observations with imitation learning')
self.use_states = (env.brains[brain_name].state_space_size > 0)
self.summary_path = trainer_parameters['summary_path']
if not os.path.exists(self.summary_path):
os.makedirs(self.summary_path)
self.summary_writer = tf.summary.FileWriter(self.summary_path)
s_size = self.brain.state_space_size * 1#brain_parameters.stacked_states
a_size = self.brain.action_space_size
with tf.variable_scope(self.variable_scope):
self.network = ImitationNN(state_size = s_size,
action_size = a_size,
h_size = int(trainer_parameters['hidden_units']),
lr = float(trainer_parameters['learning_rate']),
action_type = self.brain.action_space_type,
n_layers=int(trainer_parameters['num_layers']))
def __str__(self):
return '''Hyperparameters for the Imitation Trainer of brain {0}: \n{1}'''.format(
self.brain_name, '\n'.join(['\t{0}:\t{1}'.format(x, self.trainer_parameters[x]) for x in self.param_keys]))
@property
def parameters(self):
"""
Returns the trainer parameters of the trainer.
"""
return self.trainer_parameters
@property
def graph_scope(self):
"""
Returns the graph scope of the trainer.
"""
return self.variable_scope
@property
def get_max_steps(self):
"""
Returns the maximum number of steps. Is used to know when the trainer should be stopped.
:return: The maximum number of steps of the trainer
"""
return self.trainer_parameters['max_steps']
@property
def get_step(self):
"""
Returns the number of steps the trainer has performed
:return: the step count of the trainer
"""
return self.step
@property
def get_last_reward(self):
"""
Returns the last reward the trainer has had
:return: the new last reward
"""
if len(self.stats['cumulative_reward']) > 0:
return np.mean(self.stats['cumulative_reward'])
else:
return 0
def increment_step(self):
"""
Increment the step count of the trainer
"""
self.step += 1
def update_last_reward(self):
"""
Updates the last reward
"""
return
def take_action(self, info):
"""
Decides actions given state/observation information, and takes them in environment.
:param info: Current BrainInfo from environment.
:return: a tuple containing action, memories, values and an object
to be passed to add experiences
"""
E = info[self.brain_name]
agent_action = self.sess.run(self.network.sample_action, feed_dict={self.network.state: E.states})
return (agent_action, None, None, None)
def add_experiences(self, info, next_info, take_action_outputs):
"""
Adds experiences to each agent's experience history.
:param info: Current BrainInfo.
:param next_info: Next BrainInfo.
:param take_action_outputs: The outputs of the take action method.
"""
info_P = info[self.brain_to_imitate]
next_info_P = next_info[self.brain_to_imitate]
for agent_id in info_P.agents:
if agent_id in next_info_P.agents:
idx = info_P.agents.index(agent_id)
next_idx = next_info_P.agents.index(agent_id)
if not info_P.local_done[idx]:
self.training_buffer[agent_id]['states'].append(info_P.states[idx])
self.training_buffer[agent_id]['actions'].append(next_info_P.previous_actions[next_idx])
# self.training_buffer[agent_id]['rewards'].append(next_info.rewards[next_idx])
info_E = next_info[self.brain_name]
next_info_E = next_info[self.brain_name]
for agent_id in info_E.agents:
idx = info_E.agents.index(agent_id)
next_idx = next_info_E.agents.index(agent_id)