
Remove graph scope (#1205)

* initial commit : Only works with PPO balance ball

* Fix for recurrent

* [Fix indentation error]

* Fixed BC

* Remove Dead code

* Addressing comment : Removing dead code

* Fixing the Pytest

* edited comments

* Removing GraphScope from the InternalBrain (#1227)

* Documentation changes for removing graph scope (#1226)

* Documentation changes

* removed the keep checkpoint printing
Branch: /develop-generalizationTraining-TrainerController
Committed by GitHub 6 years ago
Current commit: d2c320dd
18 files changed, 393 insertions(+), 416 deletions(-)
  1. UnitySDK/Assets/ML-Agents/Scripts/CoreBrainInternal.cs (61 changes)
  2. docs/Basic-Guide.md (2 changes)
  3. docs/FAQ.md (22 changes)
  4. docs/Learning-Environment-Design-External-Internal-Brains.md (5 changes)
  5. docs/Learning-Environment-Executable.md (2 changes)
  6. docs/Using-TensorFlow-Sharp-in-Unity.md (4 changes)
  7. ml-agents/mlagents/trainers/bc/models.py (93 changes)
  8. ml-agents/mlagents/trainers/bc/policy.py (32 changes)
  9. ml-agents/mlagents/trainers/bc/trainer.py (20 changes)
  10. ml-agents/mlagents/trainers/policy.py (83 changes)
  11. ml-agents/mlagents/trainers/ppo/models.py (39 changes)
  12. ml-agents/mlagents/trainers/ppo/policy.py (42 changes)
  13. ml-agents/mlagents/trainers/ppo/trainer.py (17 changes)
  14. ml-agents/mlagents/trainers/trainer.py (28 changes)
  15. ml-agents/mlagents/trainers/trainer_controller.py (286 changes)
  16. ml-agents/tests/trainers/test_bc.py (26 changes)
  17. ml-agents/tests/trainers/test_ppo.py (28 changes)
  18. ml-agents/tests/trainers/test_trainer_controller.py (19 changes)

UnitySDK/Assets/ML-Agents/Scripts/CoreBrainInternal.cs (61 changes)


/// Modify only in inspector : Reference to the Graph asset
public TextAsset graphModel;
/// Modify only in inspector : If a scope was used when training the model, specify it here
public string graphScope;
[SerializeField]
[Tooltip(
"If your graph takes additional inputs that are fixed (example: noise level) you can specify them here.")]

// TODO: Make this a loop over a dynamic set of graph inputs
if ((graphScope.Length > 1) && (graphScope[graphScope.Length - 1] != '/'))
{
graphScope = graphScope + '/';
}
if (graph[graphScope + BatchSizePlaceholderName] != null)
if (graph[BatchSizePlaceholderName] != null)
if ((graph[graphScope + RecurrentInPlaceholderName] != null) &&
(graph[graphScope + RecurrentOutPlaceholderName] != null))
if ((graph[RecurrentInPlaceholderName] != null) &&
(graph[RecurrentOutPlaceholderName] != null))
runner.Fetch(graph[graphScope + "memory_size"][0]);
runner.Fetch(graph["memory_size"][0]);
if (graph[graphScope + VectorObservationPlacholderName] != null)
if (graph[VectorObservationPlacholderName] != null)
if (graph[graphScope + PreviousActionPlaceholderName] != null)
if (graph[PreviousActionPlaceholderName] != null)
if (graph[graphScope + "value_estimate"] != null)
if (graph["value_estimate"] != null)
if (graph[graphScope + ActionMaskPlaceholderName] != null)
if (graph[ActionMaskPlaceholderName] != null)
{
hasMaskedActions = true;
}

var runner = session.GetRunner();
try
{
runner.Fetch(graph[graphScope + ActionPlaceholderName][0]);
runner.Fetch(graph[ActionPlaceholderName][0]);
@"The node {0} could not be found. Please make sure the graphScope {1} is correct",
graphScope + ActionPlaceholderName, graphScope));
@"The node {0} could not be found. Please make sure the node name is correct",
ActionPlaceholderName));
runner.AddInput(graph[graphScope + BatchSizePlaceholderName][0], new int[] {currentBatchSize});
runner.AddInput(graph[BatchSizePlaceholderName][0], new int[] {currentBatchSize});
}
foreach (TensorFlowAgentPlaceholder placeholder in graphPlaceholders)

if (placeholder.valueType == TensorFlowAgentPlaceholder.TensorType.FloatingPoint)
{
runner.AddInput(graph[graphScope + placeholder.name][0],
runner.AddInput(graph[placeholder.name][0],
runner.AddInput(graph[graphScope + placeholder.name][0],
runner.AddInput(graph[placeholder.name][0],
new int[] {Random.Range((int) placeholder.minValue, (int) placeholder.maxValue + 1)});
}
}

@"One of the Tensorflow placeholder cound nout be found.
In brain {0}, there are no {1} placeholder named {2}.",
brain.gameObject.name, placeholder.valueType.ToString(), graphScope + placeholder.name));
brain.gameObject.name, placeholder.valueType.ToString(), placeholder.name));
}
}

runner.AddInput(graph[graphScope + VectorObservationPlacholderName][0], inputState);
runner.AddInput(graph[VectorObservationPlacholderName][0], inputState);
runner.AddInput(graph[graphScope + PreviousActionPlaceholderName][0], inputPrevAction);
runner.AddInput(graph[PreviousActionPlaceholderName][0], inputPrevAction);
runner.AddInput(graph[graphScope + ActionMaskPlaceholderName][0], maskedActions);
runner.AddInput(graph[ActionMaskPlaceholderName][0], maskedActions);
}
// Create the observation tensors

obsNumber++)
{
runner.AddInput(graph[graphScope + VisualObservationPlaceholderName[obsNumber]][0],
runner.AddInput(graph[VisualObservationPlaceholderName[obsNumber]][0],
runner.AddInput(graph[graphScope + "sequence_length"][0], 1);
runner.AddInput(graph[graphScope + RecurrentInPlaceholderName][0], inputOldMemories);
runner.Fetch(graph[graphScope + RecurrentOutPlaceholderName][0]);
runner.AddInput(graph["sequence_length"][0], 1);
runner.AddInput(graph[RecurrentInPlaceholderName][0], inputOldMemories);
runner.Fetch(graph[RecurrentOutPlaceholderName][0]);
runner.Fetch(graph[graphScope + "value_estimate"][0]);
runner.Fetch(graph["value_estimate"][0]);
}
TFTensor[] networkOutput;

{
EditorGUILayout.HelpBox("Please provide a tensorflow graph as a bytes file.", MessageType.Error);
}
graphScope =
EditorGUILayout.TextField(new GUIContent("Graph Scope",
"If you set a scope while training your tensorflow model, " +
"all your placeholder name will have a prefix. You must specify that prefix here."), graphScope);
if (BatchSizePlaceholderName == "")
{
BatchSizePlaceholderName = "batch_size";

docs/Basic-Guide.md (2 changes)


sequence_length: 64
summary_freq: 1000
use_recurrent: False
graph_scope:
model_path: ./models/first-run-0/Ball3DBrain
INFO:mlagents.trainers: first-run-0: Ball3DBrain: Step: 1000. Mean Reward: 1.242. Std of Reward: 0.746. Training.
INFO:mlagents.trainers: first-run-0: Ball3DBrain: Step: 2000. Mean Reward: 1.319. Std of Reward: 0.693. Training.
INFO:mlagents.trainers: first-run-0: Ball3DBrain: Step: 3000. Mean Reward: 1.804. Std of Reward: 1.056. Training.

docs/FAQ.md (22 changes)


Internal mode, your TensorFlowSharp plugin is imported and the
ENABLE_TENSORFLOW flag is set. This fix is only valid locally and unstable.
## Tensorflow epsilon placeholder error
If you have a graph placeholder set in the Internal Brain inspector that is not
present in the TensorFlow graph, you will see an error like this:
```console
UnityAgentsException: One of the TensorFlow placeholder could not be found. In brain <some_brain_name>, there are no FloatingPoint placeholder named <some_placeholder_name>.
```
Solution: Go to each of your Brain objects, find `Graph placeholders` and change
its `size` to 0 to remove the `epsilon` placeholder.
Similarly, if the graph scope set in the Internal Brain inspector does not match
the scope used during training, you will see an error like this:
```console
UnityAgentsException: The node <Wrong_Graph_Scope>/action could not be found. Please make sure the graphScope <Wrong_Graph_Scope>/ is correct
```
Solution: Make sure your Graph Scope field matches the corresponding Brain
object name in your Hierarchy Inspector when there are multiple Brains.
## Environment Permission Error
If you directly import your Unity environment without building it in the

docs/Learning-Environment-Design-External-Internal-Brains.md (5 changes)


Only change the following Internal Brain properties if you have created your own
TensorFlow model and are not using an ML-Agents model:
* `Graph Scope` : If you set a scope while training your TensorFlow model, all
your placeholder names will have a prefix. You must specify that prefix here.
Note that if more than one Brain was set to External during training, you
must give a `Graph Scope` to the Internal Brain corresponding to the name of
the Brain GameObject.
* `Batch Size Node Name` : If the batch size is one of the inputs of your
graph, you must specify the name of the placeholder here. The Brain will make
the batch size equal to the number of Agents connected to the Brain.

docs/Learning-Environment-Executable.md (2 changes)


sequence_length: 64
summary_freq: 1000
use_recurrent: False
graph_scope:
model_path: ./models/first-run-0/Ball3DBrain
INFO:mlagents.trainers: first-run-0: Ball3DBrain: Step: 1000. Mean Reward: 1.242. Std of Reward: 0.746. Training.
INFO:mlagents.trainers: first-run-0: Ball3DBrain: Step: 2000. Mean Reward: 1.319. Std of Reward: 0.693. Training.
INFO:mlagents.trainers: first-run-0: Ball3DBrain: Step: 3000. Mean Reward: 1.804. Std of Reward: 1.056. Training.

docs/Using-TensorFlow-Sharp-in-Unity.md (4 changes)


both the graph and associated weights. Note that you must save your graph as a
.bytes file so Unity can load it.
In the Unity Editor, you must specify the names of the nodes used by your graph
in the **Internal** Brain Inspector window. If you used a scope when defining
your graph, specify it in the `Graph Scope` field.
![Internal Brain Inspector](images/internal_brain.png)
See

ml-agents/mlagents/trainers/bc/models.py (93 changes)


class BehavioralCloningModel(LearningModel):
def __init__(self, brain, h_size=128, lr=1e-4, n_layers=2, m_size=128,
normalize=False, use_recurrent=False, scope='PPO', seed=0):
with tf.variable_scope(scope):
LearningModel.__init__(self, m_size, normalize, use_recurrent, brain, seed)
num_streams = 1
hidden_streams = self.create_observation_streams(num_streams, h_size, n_layers)
hidden = hidden_streams[0]
self.dropout_rate = tf.placeholder(dtype=tf.float32, shape=[], name="dropout_rate")
hidden_reg = tf.layers.dropout(hidden, self.dropout_rate)
if self.use_recurrent:
tf.Variable(self.m_size, name="memory_size", trainable=False, dtype=tf.int32)
self.memory_in = tf.placeholder(shape=[None, self.m_size], dtype=tf.float32, name='recurrent_in')
hidden_reg, self.memory_out = self.create_recurrent_encoder(hidden_reg, self.memory_in,
self.sequence_length)
self.memory_out = tf.identity(self.memory_out, name='recurrent_out')
normalize=False, use_recurrent=False, seed=0):
LearningModel.__init__(self, m_size, normalize, use_recurrent, brain, seed)
num_streams = 1
hidden_streams = self.create_observation_streams(num_streams, h_size, n_layers)
hidden = hidden_streams[0]
self.dropout_rate = tf.placeholder(dtype=tf.float32, shape=[], name="dropout_rate")
hidden_reg = tf.layers.dropout(hidden, self.dropout_rate)
if self.use_recurrent:
tf.Variable(self.m_size, name="memory_size", trainable=False, dtype=tf.int32)
self.memory_in = tf.placeholder(shape=[None, self.m_size], dtype=tf.float32, name='recurrent_in')
hidden_reg, self.memory_out = self.create_recurrent_encoder(hidden_reg, self.memory_in,
self.sequence_length)
self.memory_out = tf.identity(self.memory_out, name='recurrent_out')
if brain.vector_action_space_type == "discrete":
policy_branches = []
for size in self.act_size:
policy_branches.append(
tf.layers.dense(
hidden,
size,
activation=None,
use_bias=False,
kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01)))
self.action_probs = tf.concat(
[tf.nn.softmax(branch) for branch in policy_branches], axis=1, name="action_probs")
self.action_masks = tf.placeholder(shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks")
self.sample_action_float, _ = self.create_discrete_action_masking_layer(
tf.concat(policy_branches, axis = 1), self.action_masks, self.act_size)
self.sample_action_float = tf.identity(self.sample_action_float, name="action")
self.sample_action = tf.cast(self.sample_action_float, tf.int32)
self.true_action = tf.placeholder(shape=[None, len(policy_branches)], dtype=tf.int32, name="teacher_action")
self.action_oh = tf.concat([
tf.one_hot(self.true_action[:, i], self.act_size[i]) for i in range(len(self.act_size))], axis=1)
self.loss = tf.reduce_sum(-tf.log(self.action_probs + 1e-10) * self.action_oh)
self.action_percent = tf.reduce_mean(tf.cast(
tf.equal(tf.cast(tf.argmax(self.action_probs, axis=1), tf.int32), self.sample_action), tf.float32))
else:
self.policy = tf.layers.dense(hidden_reg, self.act_size[0], activation=None, use_bias=False, name='pre_action',
kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01))
self.clipped_sample_action = tf.clip_by_value(self.policy, -1, 1)
self.sample_action = tf.identity(self.clipped_sample_action, name="action")
self.true_action = tf.placeholder(shape=[None, self.act_size[0]], dtype=tf.float32, name="teacher_action")
self.clipped_true_action = tf.clip_by_value(self.true_action, -1, 1)
self.loss = tf.reduce_sum(tf.squared_difference(self.clipped_true_action, self.sample_action))
if brain.vector_action_space_type == "discrete":
policy_branches = []
for size in self.act_size:
policy_branches.append(
tf.layers.dense(
hidden,
size,
activation=None,
use_bias=False,
kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01)))
self.action_probs = tf.concat(
[tf.nn.softmax(branch) for branch in policy_branches], axis=1, name="action_probs")
self.action_masks = tf.placeholder(shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks")
self.sample_action_float, _ = self.create_discrete_action_masking_layer(
tf.concat(policy_branches, axis = 1), self.action_masks, self.act_size)
self.sample_action_float = tf.identity(self.sample_action_float, name="action")
self.sample_action = tf.cast(self.sample_action_float, tf.int32)
self.true_action = tf.placeholder(shape=[None, len(policy_branches)], dtype=tf.int32, name="teacher_action")
self.action_oh = tf.concat([
tf.one_hot(self.true_action[:, i], self.act_size[i]) for i in range(len(self.act_size))], axis=1)
self.loss = tf.reduce_sum(-tf.log(self.action_probs + 1e-10) * self.action_oh)
self.action_percent = tf.reduce_mean(tf.cast(
tf.equal(tf.cast(tf.argmax(self.action_probs, axis=1), tf.int32), self.sample_action), tf.float32))
else:
self.policy = tf.layers.dense(hidden_reg, self.act_size[0], activation=None, use_bias=False, name='pre_action',
kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01))
self.clipped_sample_action = tf.clip_by_value(self.policy, -1, 1)
self.sample_action = tf.identity(self.clipped_sample_action, name="action")
self.true_action = tf.placeholder(shape=[None, self.act_size[0]], dtype=tf.float32, name="teacher_action")
self.clipped_true_action = tf.clip_by_value(self.true_action, -1, 1)
self.loss = tf.reduce_sum(tf.squared_difference(self.clipped_true_action, self.sample_action))
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
self.update = optimizer.minimize(self.loss)
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
self.update = optimizer.minimize(self.loss)
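The structural point of the hunk above is that `BehavioralCloningModel` no longer wraps its construction in `with tf.variable_scope(scope):`; the same node names now simply live at the top level of whichever graph the policy builds them in. A stripped-down sketch of that pattern (TensorFlow 1.x, toy layer sizes, not the real model):

```python
# Sketch only: build a model at the top level of a dedicated graph,
# instead of under tf.variable_scope('<brain-name>'). TensorFlow 1.x style.
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Bare names such as 'vector_observation' and 'action', no scope prefix.
    vector_obs = tf.placeholder(tf.float32, [None, 8], name="vector_observation")
    hidden = tf.layers.dense(vector_obs, 32, activation=tf.nn.relu)
    tf.identity(tf.layers.dense(hidden, 2), name="action")
    init = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init)
    names = [n.name for n in graph.as_graph_def().node]
    print([n for n in names if n in ("vector_observation", "action")])
```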

ml-agents/mlagents/trainers/bc/policy.py (32 changes)


class BCPolicy(Policy):
def __init__(self, seed, brain, trainer_parameters, sess):
def __init__(self, seed, brain, trainer_parameters, load):
:param sess: TensorFlow session.
:param load: Whether a pre-trained model will be loaded or a new one created.
super().__init__(seed, brain, trainer_parameters, sess)
super().__init__(seed, brain, trainer_parameters)
self.model = BehavioralCloningModel(
h_size=int(trainer_parameters['hidden_units']),
lr=float(trainer_parameters['learning_rate']),
n_layers=int(trainer_parameters['num_layers']),
m_size=self.m_size,
normalize=False,
use_recurrent=trainer_parameters['use_recurrent'],
brain=brain,
scope=self.variable_scope,
seed=seed)
with self.graph.as_default():
with self.graph.as_default():
self.model = BehavioralCloningModel(
h_size=int(trainer_parameters['hidden_units']),
lr=float(trainer_parameters['learning_rate']),
n_layers=int(trainer_parameters['num_layers']),
m_size=self.m_size,
normalize=False,
use_recurrent=trainer_parameters['use_recurrent'],
brain=brain,
seed=seed)
if load:
self._load_graph()
else:
self._initialize_graph()
self.inference_dict = {'action': self.model.sample_action}
self.update_dict = {'policy_loss': self.model.loss,

ml-agents/mlagents/trainers/bc/trainer.py (20 changes)


class BehavioralCloningTrainer(Trainer):
"""The ImitationTrainer is an implementation of the imitation learning."""
def __init__(self, sess, brain, trainer_parameters, training, seed, run_id):
def __init__(self, brain, trainer_parameters, training, load, seed, run_id):
:param sess: Tensorflow session.
:param load: Whether the model should be loaded.
:param seed: The seed the model will be initialized with
:param run_id: The identifier of the current run
super(BehavioralCloningTrainer, self).__init__(sess, brain, trainer_parameters, training, run_id)
'graph_scope', 'summary_freq', 'max_steps',
'summary_freq', 'max_steps',
'hidden_units','learning_rate', 'num_layers',
'sequence_length', 'memory_size']
'hidden_units', 'learning_rate', 'num_layers',
'sequence_length', 'memory_size', 'model_path']
print(k)
print(k not in trainer_parameters)
self.policy = BCPolicy(seed, brain, trainer_parameters, sess)
super(BehavioralCloningTrainer, self).__init__(brain, trainer_parameters, training, run_id)
self.policy = BCPolicy(seed, brain, trainer_parameters, load)
self.brain_name = brain.brain_name
self.brain_to_imitate = trainer_parameters['brain_to_imitate']
self.batches_per_epoch = trainer_parameters['batches_per_epoch']
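Note the parameter check above: `graph_scope` drops out of the required keys and `model_path` takes its place. For illustration only, a BC `trainer_parameters` dictionary under the new scheme could look like this (a non-exhaustive subset with placeholder values, not recommended settings):

```python
# Illustrative subset of trainer_parameters for the BC trainer after this change.
# Keys mirror the list checked in trainer.py; values here are arbitrary examples.
trainer_parameters = {
    'trainer': 'imitation',
    'brain_to_imitate': 'TeacherBrain',
    'batches_per_epoch': 10,
    'summary_freq': 1000,
    'max_steps': 5.0e4,
    'hidden_units': 128,
    'learning_rate': 3.0e-4,
    'num_layers': 2,
    'sequence_length': 64,
    'memory_size': 256,
    'use_recurrent': False,
    'model_path': './models/first-run-0/StudentBrain',  # replaces 'graph_scope'
    'keep_checkpoints': 5,
}
```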

ml-agents/mlagents/trainers/policy.py (83 changes)


import logging
import numpy as np
import tensorflow as tf
from tensorflow.python.tools import freeze_graph
logger = logging.getLogger("mlagents.trainers")

Contains a learning model, and the necessary
functions to interact with it to perform evaluate and updating.
"""
possible_output_nodes = ['action', 'value_estimate',
'action_probs', 'recurrent_out', 'memory_size']
def __init__(self, seed, brain, trainer_parameters, sess):
def __init__(self, seed, brain, trainer_parameters):
:param sess: The current TensorFlow session.
"""
self.m_size = None
self.model = None

self.seed = seed
self.brain = brain
self.variable_scope = trainer_parameters['graph_scope']
self.sess = sess
self.model_path = trainer_parameters["model_path"]
self.keep_checkpoints = trainer_parameters.get("keep_checkpoints", 5)
self.graph = tf.Graph()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
self.sess = tf.Session(config=config, graph=self.graph)
self.saver = None
if self.use_recurrent:
self.m_size = trainer_parameters["memory_size"]
self.sequence_length = trainer_parameters["sequence_length"]

"but it must be divisible by 4."
.format(brain.brain_name, self.m_size))
def _initialize_graph(self):
with self.graph.as_default():
self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
init = tf.global_variables_initializer()
self.sess.run(init)
def _load_graph(self):
with self.graph.as_default():
self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
logger.info('Loading Model for brain {}'.format(self.brain.brain_name))
ckpt = tf.train.get_checkpoint_state(self.model_path)
if ckpt is None:
logger.info('The model {0} could not be found. Make '
'sure you specified the right '
'--run-id'
.format(self.model_path))
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
def evaluate(self, brain_info):
"""
Evaluates policy for the agent experiences provided.

:return: Numpy array of zeros.
"""
return np.zeros((num_agents, self.m_size))
@property
def graph_scope(self):
"""
Returns the graph scope of the trainer.
"""
return self.variable_scope
def get_current_step(self):
"""

:return:list of update var names
"""
return list(self.update_dict.keys())
def save_model(self, steps):
"""
Saves the model
:param steps: The number of steps the model was trained for
:return:
"""
with self.graph.as_default():
last_checkpoint = self.model_path + '/model-' + str(steps) + '.cptk'
self.saver.save(self.sess, last_checkpoint)
tf.train.write_graph(self.graph, self.model_path,
'raw_graph_def.pb', as_text=False)
def export_model(self):
"""
Exports latest saved model to .bytes format for Unity embedding.
"""
with self.graph.as_default():
target_nodes = ','.join(self._process_graph())
ckpt = tf.train.get_checkpoint_state(self.model_path)
freeze_graph.freeze_graph(
input_graph=self.model_path + '/raw_graph_def.pb',
input_binary=True,
input_checkpoint=ckpt.model_checkpoint_path,
output_node_names=target_nodes,
output_graph=(self.model_path + '.bytes'),
clear_devices=True, initializer_nodes='', input_saver='',
restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0')
def _process_graph(self):
"""
Gets the list of the output nodes present in the graph for inference
:return: list of node names
"""
all_nodes = [x.name for x in self.graph.as_graph_def().node]
nodes = [x for x in all_nodes if x in self.possible_output_nodes]
logger.info('List of nodes to export for brain :' + self.brain.brain_name)
for n in nodes:
logger.info('\t' + n)
return nodes
@property
def vis_obs_size(self):
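The hunk above is the heart of the refactor: `Policy` now owns a private `tf.Graph` and `tf.Session`, and checkpointing, graph freezing, and output-node discovery all happen per brain instead of in the `TrainerController`. A stripped-down sketch of that ownership pattern (TensorFlow 1.x; `TinyPolicy` is a hypothetical stand-in, not the real class):

```python
# Sketch only: each policy owns its own graph and session (TensorFlow 1.x).
import os
import tensorflow as tf


class TinyPolicy:
    def __init__(self, model_path, keep_checkpoints=5):
        self.model_path = model_path
        self.graph = tf.Graph()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config, graph=self.graph)
        with self.graph.as_default():
            # Stand-in for BehavioralCloningModel / PPOModel construction.
            obs = tf.placeholder(tf.float32, [None, 4], name="vector_observation")
            tf.identity(tf.layers.dense(obs, 2), name="action")
            self.saver = tf.train.Saver(max_to_keep=keep_checkpoints)
            self.sess.run(tf.global_variables_initializer())

    def save_model(self, steps):
        # One checkpoint and one raw_graph_def.pb per brain.
        os.makedirs(self.model_path, exist_ok=True)
        with self.graph.as_default():
            self.saver.save(self.sess, self.model_path + '/model-' + str(steps) + '.cptk')
            tf.train.write_graph(self.graph, self.model_path,
                                 'raw_graph_def.pb', as_text=False)


policy = TinyPolicy('./models/example-run/ExampleBrain')
policy.save_model(steps=0)
```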

ml-agents/mlagents/trainers/ppo/models.py (39 changes)


class PPOModel(LearningModel):
def __init__(self, brain, lr=1e-4, h_size=128, epsilon=0.2, beta=1e-3, max_step=5e6,
normalize=False, use_recurrent=False, num_layers=2, m_size=None, use_curiosity=False,
curiosity_strength=0.01, curiosity_enc_size=128, scope='Model', seed=0):
curiosity_strength=0.01, curiosity_enc_size=128, seed=0):
"""
Takes a Unity environment and model-specific hyper-parameters and returns the
appropriate PPO agent model for the environment.

:param num_layers: Number of hidden layers between encoded input and policy & value layers
:param m_size: Size of brain memory.
"""
with tf.variable_scope(scope):
LearningModel.__init__(self, m_size, normalize, use_recurrent, brain, seed)
self.use_curiosity = use_curiosity
if num_layers < 1:
num_layers = 1
self.last_reward, self.new_reward, self.update_reward = self.create_reward_encoder()
if brain.vector_action_space_type == "continuous":
self.create_cc_actor_critic(h_size, num_layers)
self.entropy = tf.ones_like(tf.reshape(self.value, [-1])) * self.entropy
else:
self.create_dc_actor_critic(h_size, num_layers)
if self.use_curiosity:
self.curiosity_enc_size = curiosity_enc_size
self.curiosity_strength = curiosity_strength
encoded_state, encoded_next_state = self.create_curiosity_encoders()
self.create_inverse_model(encoded_state, encoded_next_state)
self.create_forward_model(encoded_state, encoded_next_state)
self.create_ppo_optimizer(self.log_probs, self.old_log_probs, self.value,
self.entropy, beta, epsilon, lr, max_step)
LearningModel.__init__(self, m_size, normalize, use_recurrent, brain, seed)
self.use_curiosity = use_curiosity
if num_layers < 1:
num_layers = 1
self.last_reward, self.new_reward, self.update_reward = self.create_reward_encoder()
if brain.vector_action_space_type == "continuous":
self.create_cc_actor_critic(h_size, num_layers)
self.entropy = tf.ones_like(tf.reshape(self.value, [-1])) * self.entropy
else:
self.create_dc_actor_critic(h_size, num_layers)
if self.use_curiosity:
self.curiosity_enc_size = curiosity_enc_size
self.curiosity_strength = curiosity_strength
encoded_state, encoded_next_state = self.create_curiosity_encoders()
self.create_inverse_model(encoded_state, encoded_next_state)
self.create_forward_model(encoded_state, encoded_next_state)
self.create_ppo_optimizer(self.log_probs, self.old_log_probs, self.value,
self.entropy, beta, epsilon, lr, max_step)
@staticmethod
def create_reward_encoder():

ml-agents/mlagents/trainers/ppo/policy.py (42 changes)


import logging
import numpy as np
from mlagents.trainers.ppo.models import PPOModel
from mlagents.trainers.policy import Policy

class PPOPolicy(Policy):
def __init__(self, seed, brain, trainer_params, sess, is_training):
def __init__(self, seed, brain, trainer_params, is_training, load):
:param sess: TensorFlow session.
:param load: Whether a pre-trained model will be loaded or a new one created.
super().__init__(seed, brain, trainer_params, sess)
super().__init__(seed, brain, trainer_params)
self.model = PPOModel(brain,
lr=float(trainer_params['learning_rate']),
h_size=int(trainer_params['hidden_units']),
epsilon=float(trainer_params['epsilon']),
beta=float(trainer_params['beta']),
max_step=float(trainer_params['max_steps']),
normalize=trainer_params['normalize'],
use_recurrent=trainer_params['use_recurrent'],
num_layers=int(trainer_params['num_layers']),
m_size=self.m_size,
use_curiosity=bool(trainer_params['use_curiosity']),
curiosity_strength=float(trainer_params['curiosity_strength']),
curiosity_enc_size=float(trainer_params['curiosity_enc_size']),
scope=self.variable_scope, seed=seed)
with self.graph.as_default():
self.model = PPOModel(brain,
lr=float(trainer_params['learning_rate']),
h_size=int(trainer_params['hidden_units']),
epsilon=float(trainer_params['epsilon']),
beta=float(trainer_params['beta']),
max_step=float(trainer_params['max_steps']),
normalize=trainer_params['normalize'],
use_recurrent=trainer_params['use_recurrent'],
num_layers=int(trainer_params['num_layers']),
m_size=self.m_size,
use_curiosity=bool(trainer_params['use_curiosity']),
curiosity_strength=float(trainer_params['curiosity_strength']),
curiosity_enc_size=float(trainer_params['curiosity_enc_size']),
seed=seed)
if load:
self._load_graph()
else:
self._initialize_graph()
self.inference_dict = {'action': self.model.output, 'log_probs': self.model.all_log_probs,
'value': self.model.value, 'entropy': self.model.entropy,

ml-agents/mlagents/trainers/ppo/trainer.py (17 changes)


class PPOTrainer(Trainer):
"""The PPOTrainer is an implementation of the PPO algorithm."""
def __init__(self, sess, brain, reward_buff_cap, trainer_parameters, training, seed, run_id):
def __init__(self, brain, reward_buff_cap, trainer_parameters, training, load, seed, run_id):
:param sess: Tensorflow session.
:param trainer_parameters: The parameters for the trainer (dictionary).
:param trainer_parameters: The parameters for the trainer (dictionary).
:param load: Whether the model should be loaded.
:param seed: The seed the model will be initialized with
:param run_id: The identifier of the current run
super(PPOTrainer, self).__init__(sess, brain.brain_name, trainer_parameters, training, run_id)
'graph_scope', 'summary_path', 'memory_size', 'use_curiosity', 'curiosity_strength',
'curiosity_enc_size']
'summary_path', 'memory_size', 'use_curiosity', 'curiosity_strength',
'curiosity_enc_size', 'model_path']
super(PPOTrainer, self).__init__(brain.brain_name, trainer_parameters, training, run_id)
self.use_curiosity = bool(trainer_parameters['use_curiosity'])

sess, self.is_training)
self.is_training, load)
stats = {'cumulative_reward': [], 'episode_length': [], 'value_estimate': [],
'entropy': [], 'value_loss': [], 'policy_loss': [], 'learning_rate': []}

ml-agents/mlagents/trainers/trainer.py (28 changes)


class Trainer(object):
"""This class is the abstract class for the mlagents.trainers"""
def __init__(self, sess, brain_name, trainer_parameters, training, run_id):
def __init__(self, brain_name, trainer_parameters, training, run_id):
:param sess: Tensorflow session.
:param run_id: The identifier of the current run
self.sess = sess
self.brain_name = brain_name
self.run_id = run_id
self.trainer_parameters = trainer_parameters

self.policy = None
def __str__(self):
return '''Empty Trainer'''

"""
raise UnityTrainerException("The update_model method was not implemented.")
def save_model(self, steps):
"""
Saves the model
:param steps: The number of steps of training
"""
self.policy.save_model(steps)
def export_model(self):
"""
Exports the model
"""
self.policy.export_model()
def write_summary(self, global_step, lesson_num=0):
"""
Saves training statistics to Tensorboard.

:param input_dict: A dictionary that will be displayed in a table on Tensorboard.
"""
try:
s_op = tf.summary.text(key, tf.convert_to_tensor(
([[str(x), str(input_dict[x])] for x in input_dict])))
s = self.sess.run(s_op)
self.summary_writer.add_summary(s, self.get_step)
with tf.Session() as sess:
s_op = tf.summary.text(key, tf.convert_to_tensor(
([[str(x), str(input_dict[x])] for x in input_dict])))
s = sess.run(s_op)
self.summary_writer.add_summary(s, self.get_step)
except:
logger.info(
"Cannot write text summary for Tensorboard. Tensorflow version must be r1.2 or above.")

ml-agents/mlagents/trainers/trainer_controller.py (286 changes)


else:
return None
def _process_graph(self):
nodes = []
scopes = []
for brain_name in self.trainers.keys():
if self.trainers[brain_name].policy.graph_scope is not None:
scope = self.trainers[brain_name].policy.graph_scope + '/'
if scope == '/':
scope = ''
scopes += [scope]
if self.trainers[brain_name].parameters['trainer'] \
== 'imitation':
nodes += [scope + x for x in ['action']]
else:
nodes += [scope + x for x in ['action', 'value_estimate',
'action_probs',
'value_estimate']]
if self.trainers[brain_name].parameters['use_recurrent']:
nodes += [scope + x for x in ['recurrent_out',
'memory_size']]
if len(scopes) > 1:
self.logger.info('List of available scopes :')
for scope in scopes:
self.logger.info('\t' + scope)
self.logger.info('List of nodes to export :')
for n in nodes:
self.logger.info('\t' + n)
return nodes
def _save_model(self, sess, saver, steps=0):
def _save_model(self,steps=0):
:param sess: Current Tensorflow session.
last_checkpoint = self.model_path + '/model-' + str(steps) + '.cptk'
saver.save(sess, last_checkpoint)
tf.train.write_graph(sess.graph_def, self.model_path,
'raw_graph_def.pb', as_text=False)
for brain_name in self.trainers.keys():
self.trainers[brain_name].save_model(steps)
Exports latest saved model to .bytes format for Unity embedding.
Exports latest saved models to .bytes format for Unity embedding.
target_nodes = ','.join(self._process_graph())
ckpt = tf.train.get_checkpoint_state(self.model_path)
freeze_graph.freeze_graph(
input_graph=self.model_path + '/raw_graph_def.pb',
input_binary=True,
input_checkpoint=ckpt.model_checkpoint_path,
output_node_names=target_nodes,
output_graph=(self.model_path + '/' + self.env_name + '_'
+ self.run_id + '.bytes'),
clear_devices=True, initializer_nodes='', input_saver='',
restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0')
for brain_name in self.trainers.keys():
self.trainers[brain_name].export_model()
def _initialize_trainers(self, trainer_config, sess):
def _initialize_trainers(self, trainer_config):
"""
Initialization of the trainers
:param trainer_config: The configurations of the trainers
"""
# TODO: This probably doesn't need to be reinitialized.
self.trainers = {}
if len(self.env.external_brain_names) > 1:
graph_scope = re.sub('[^0-9a-zA-Z]+', '-', brain_name)
trainer_parameters['graph_scope'] = graph_scope
trainer_parameters['summary_path'] = '{basedir}/{name}'.format(
basedir=self.summaries_dir,
name=str(self.run_id) + '_' + graph_scope)
else:
trainer_parameters['graph_scope'] = ''
trainer_parameters['summary_path'] = '{basedir}/{name}'.format(
basedir=self.summaries_dir,
name=str(self.run_id))
trainer_parameters['summary_path'] = '{basedir}/{name}'.format(
basedir=self.summaries_dir,
name=str(self.run_id) + '_' + brain_name)
trainer_parameters['model_path'] = '{basedir}/{name}'.format(
basedir=self.model_path,
name=brain_name)
trainer_parameters['keep_checkpoints'] = self.keep_checkpoints
if brain_name in trainer_config:
_brain_key = brain_name
while not isinstance(trainer_config[_brain_key], dict):

for brain_name in self.env.external_brain_names:
if trainer_parameters_dict[brain_name]['trainer'] == 'imitation':
self.trainers[brain_name] = BehavioralCloningTrainer(
sess, self.env.brains[brain_name],
self.env.brains[brain_name],
self.seed, self.run_id)
self.load_model, self.seed, self.run_id)
sess, self.env.brains[brain_name],
self.env.brains[brain_name],
self.train_model, self.seed, self.run_id)
self.train_model, self.load_model, self.seed, self.run_id)
else:
raise UnityEnvironmentException('The trainer config contains '
'an unknown trainer type for '

tf.reset_default_graph()
# Prevent a single session from taking all GPU memory.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
self._initialize_trainers(trainer_config, sess)
for _, t in self.trainers.items():
self.logger.info(t)
init = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
# Instantiate model parameters
if self.load_model:
self.logger.info('Loading Model...')
ckpt = tf.train.get_checkpoint_state(self.model_path)
if ckpt is None:
self.logger.info('The model {0} could not be found. Make '
'sure you specified the right '
'--run-id'
.format(self.model_path))
saver.restore(sess, ckpt.model_checkpoint_path)
else:
sess.run(init)
global_step = 0 # This is only for saving the model
curr_info = self._reset_env()
if self.train_model:
for brain_name, trainer in self.trainers.items():
trainer.write_tensorboard_text('Hyperparameters',
trainer.parameters)
try:
while any([t.get_step <= t.get_max_steps \
for k, t in self.trainers.items()]) \
or not self.train_model:
if self.meta_curriculum:
# Get the sizes of the reward buffers.
reward_buff_sizes = {k:len(t.reward_buffer) \
for (k,t) in self.trainers.items()}
# Attempt to increment the lessons of the brains who
# were ready.
lessons_incremented = \
self.meta_curriculum.increment_lessons(
self._get_measure_vals(),
reward_buff_sizes=reward_buff_sizes)
# If any lessons were incremented or the environment is
# ready to be reset
if (self.meta_curriculum
and any(lessons_incremented.values())):
curr_info = self._reset_env()
for brain_name, trainer in self.trainers.items():
trainer.end_episode()
for brain_name, changed in lessons_incremented.items():
if changed:
self.trainers[brain_name].reward_buffer.clear()
elif self.env.global_done:
curr_info = self._reset_env()
for brain_name, trainer in self.trainers.items():
trainer.end_episode()
self._initialize_trainers(trainer_config)
for _, t in self.trainers.items():
self.logger.info(t)
global_step = 0 # This is only for saving the model
curr_info = self._reset_env()
if self.train_model:
for brain_name, trainer in self.trainers.items():
trainer.write_tensorboard_text('Hyperparameters',
trainer.parameters)
try:
while any([t.get_step <= t.get_max_steps \
for k, t in self.trainers.items()]) \
or not self.train_model:
if self.meta_curriculum:
# Get the sizes of the reward buffers.
reward_buff_sizes = {k:len(t.reward_buffer) \
for (k,t) in self.trainers.items()}
# Attempt to increment the lessons of the brains who
# were ready.
lessons_incremented = \
self.meta_curriculum.increment_lessons(
self._get_measure_vals(),
reward_buff_sizes=reward_buff_sizes)
# Decide and take an action
take_action_vector, \
take_action_memories, \
take_action_text, \
take_action_value, \
take_action_outputs \
= {}, {}, {}, {}, {}
# If any lessons were incremented or the environment is
# ready to be reset
if (self.meta_curriculum
and any(lessons_incremented.values())):
curr_info = self._reset_env()
(take_action_vector[brain_name],
take_action_memories[brain_name],
take_action_text[brain_name],
take_action_value[brain_name],
take_action_outputs[brain_name]) = \
trainer.take_action(curr_info)
new_info = self.env.step(vector_action=take_action_vector,
memory=take_action_memories,
text_action=take_action_text,
value=take_action_value)
trainer.end_episode()
for brain_name, changed in lessons_incremented.items():
if changed:
self.trainers[brain_name].reward_buffer.clear()
elif self.env.global_done:
curr_info = self._reset_env()
trainer.add_experiences(curr_info, new_info,
take_action_outputs[brain_name])
trainer.process_experiences(curr_info, new_info)
if trainer.is_ready_update() and self.train_model \
and trainer.get_step <= trainer.get_max_steps:
# Perform gradient descent with experience buffer
trainer.update_policy()
# Write training statistics to Tensorboard.
if self.meta_curriculum is not None:
trainer.write_summary(
global_step,
lesson_num=self.meta_curriculum
.brains_to_curriculums[brain_name]
.lesson_num)
else:
trainer.write_summary(global_step)
if self.train_model \
and trainer.get_step <= trainer.get_max_steps:
trainer.increment_step_and_update_last_reward()
global_step += 1
if global_step % self.save_freq == 0 and global_step != 0 \
and self.train_model:
# Save Tensorflow model
self._save_model(sess, steps=global_step, saver=saver)
curr_info = new_info
# Final save Tensorflow model
if global_step != 0 and self.train_model:
self._save_model(sess, steps=global_step, saver=saver)
except KeyboardInterrupt:
print('--------------------------Now saving model--------------'
'-----------')
if self.train_model:
self.logger.info('Learning was interrupted. Please wait '
'while the graph is generated.')
self._save_model(sess, steps=global_step, saver=saver)
pass
trainer.end_episode()
# Decide and take an action
take_action_vector, \
take_action_memories, \
take_action_text, \
take_action_value, \
take_action_outputs \
= {}, {}, {}, {}, {}
for brain_name, trainer in self.trainers.items():
(take_action_vector[brain_name],
take_action_memories[brain_name],
take_action_text[brain_name],
take_action_value[brain_name],
take_action_outputs[brain_name]) = \
trainer.take_action(curr_info)
new_info = self.env.step(vector_action=take_action_vector,
memory=take_action_memories,
text_action=take_action_text,
value=take_action_value)
for brain_name, trainer in self.trainers.items():
trainer.add_experiences(curr_info, new_info,
take_action_outputs[brain_name])
trainer.process_experiences(curr_info, new_info)
if trainer.is_ready_update() and self.train_model \
and trainer.get_step <= trainer.get_max_steps:
# Perform gradient descent with experience buffer
trainer.update_policy()
# Write training statistics to Tensorboard.
if self.meta_curriculum is not None:
trainer.write_summary(
global_step,
lesson_num=self.meta_curriculum
.brains_to_curriculums[brain_name]
.lesson_num)
else:
trainer.write_summary(global_step)
if self.train_model \
and trainer.get_step <= trainer.get_max_steps:
trainer.increment_step_and_update_last_reward()
global_step += 1
if global_step % self.save_freq == 0 and global_step != 0 \
and self.train_model:
# Save Tensorflow model
self._save_model(steps=global_step)
curr_info = new_info
# Final save Tensorflow model
if global_step != 0 and self.train_model:
self._save_model(steps=global_step)
except KeyboardInterrupt:
print('--------------------------Now saving model--------------'
'-----------')
if self.train_model:
self.logger.info('Learning was interrupted. Please wait '
'while the graph is generated.')
self._save_model(steps=global_step)
pass
self.env.close()
if self.train_model:
self._export_graph()
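At the controller level the consequence is what the hunk above shows: `_save_model` and `_export_graph` no longer touch a session, a saver, or a single frozen graph; they simply delegate to each trainer, which in turn delegates to its policy. A condensed, pure-Python sketch of that delegation (`TinyTrainer` is a hypothetical stand-in for the real trainer classes):

```python
# Sketch only: controller-side delegation after removing the shared session.
class TinyTrainer:
    def __init__(self, brain_name):
        self.brain_name = brain_name

    def save_model(self, steps):
        # In the real code this calls self.policy.save_model(steps).
        print('saving checkpoint for {} at step {}'.format(self.brain_name, steps))

    def export_model(self):
        # In the real code this calls self.policy.export_model().
        print('exporting {}.bytes'.format(self.brain_name))


trainers = {'Ball3DBrain': TinyTrainer('Ball3DBrain')}


def _save_model(steps=0):
    # One checkpoint per brain instead of one controller-owned checkpoint.
    for brain_name in trainers.keys():
        trainers[brain_name].save_model(steps)


def _export_graph():
    # One frozen .bytes graph per brain instead of one combined graph.
    for brain_name in trainers.keys():
        trainers[brain_name].export_model()


_save_model(steps=1000)
_export_graph()
```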

ml-agents/tests/trainers/test_bc.py (26 changes)


@mock.patch('mlagents.envs.UnityEnvironment.get_communicator')
def test_bc_policy_evaluate(mock_communicator, mock_launcher):
tf.reset_default_graph()
with tf.Session() as sess:
mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=0)
env = UnityEnvironment(' ')
brain_infos = env.reset()
brain_info = brain_infos[env.brain_names[0]]
mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=0)
env = UnityEnvironment(' ')
brain_infos = env.reset()
brain_info = brain_infos[env.brain_names[0]]
trainer_parameters = dummy_config()
graph_scope = env.brain_names[0]
trainer_parameters['graph_scope'] = graph_scope
policy = BCPolicy(0, env.brains[env.brain_names[0]], trainer_parameters, sess)
init = tf.global_variables_initializer()
sess.run(init)
run_out = policy.evaluate(brain_info)
assert run_out['action'].shape == (3, 2)
trainer_parameters = dummy_config()
model_path = env.brain_names[0]
trainer_parameters['model_path'] = model_path
trainer_parameters['keep_checkpoints'] = 3
policy = BCPolicy(0, env.brains[env.brain_names[0]], trainer_parameters, False)
run_out = policy.evaluate(brain_info)
assert run_out['action'].shape == (3, 2)
env.close()

ml-agents/tests/trainers/test_ppo.py (28 changes)


@mock.patch('mlagents.envs.UnityEnvironment.get_communicator')
def test_ppo_policy_evaluate(mock_communicator, mock_launcher):
tf.reset_default_graph()
with tf.Session() as sess:
mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=0)
env = UnityEnvironment(' ')
brain_infos = env.reset()
brain_info = brain_infos[env.brain_names[0]]
mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=0)
env = UnityEnvironment(' ')
brain_infos = env.reset()
brain_info = brain_infos[env.brain_names[0]]
trainer_parameters = dummy_config()
graph_scope = env.brain_names[0]
trainer_parameters['graph_scope'] = graph_scope
policy = PPOPolicy(0, env.brains[env.brain_names[0]], trainer_parameters, sess, False)
init = tf.global_variables_initializer()
sess.run(init)
run_out = policy.evaluate(brain_info)
assert run_out['action'].shape == (3, 2)
env.close()
trainer_parameters = dummy_config()
model_path = env.brain_names[0]
trainer_parameters['model_path'] = model_path
trainer_parameters['keep_checkpoints'] = 3
policy = PPOPolicy(0, env.brains[env.brain_names[0]], trainer_parameters, False, False)
run_out = policy.evaluate(brain_info)
assert run_out['action'].shape == (3, 2)
env.close()
@mock.patch('mlagents.envs.UnityEnvironment.executable_launcher')

ml-agents/tests/trainers/test_trainer_controller.py (19 changes)


with mock.patch(open_name, create=True) as _:
mock_communicator.return_value = MockCommunicator(
discrete_action=True, visual_inputs=1)
tc = TrainerController(' ', ' ', 1, None, True, True, False, 1, 1,
tc = TrainerController(' ', ' ', 1, None, True, False, False, 1, 1,
1, 1, '', "tests/test_mlagents.trainers.py",
False)

tf.reset_default_graph()
with tf.Session() as sess:
tc._initialize_trainers(config, sess)
assert(len(tc.trainers) == 1)
assert(isinstance(tc.trainers['RealFakeBrain'], PPOTrainer))
tc._initialize_trainers(config)
assert(len(tc.trainers) == 1)
assert(isinstance(tc.trainers['RealFakeBrain'], PPOTrainer))
with tf.Session() as sess:
tc._initialize_trainers(config, sess)
assert(isinstance(tc.trainers['RealFakeBrain'], BehavioralCloningTrainer))
tc._initialize_trainers(config)
assert(isinstance(tc.trainers['RealFakeBrain'], BehavioralCloningTrainer))
with tf.Session() as sess:
with pytest.raises(UnityEnvironmentException):
tc._initialize_trainers(config, sess)
with pytest.raises(UnityEnvironmentException):
tc._initialize_trainers(config)