
Merge branch 'master' into dev-broadcast

/develop-generalizationTraining-TrainerController
vincentpierre 7 years ago
Current commit
3f85bb56
47 changed files, with 2683 insertions and 905 deletions
  1. README.md (10)
  2. python/PPO.ipynb (12)
  3. python/README.md (4)
  4. python/ppo.py (9)
  5. python/ppo/models.py (209)
  6. python/ppo/trainer.py (50)
  7. python/setup.py (2)
  8. unity-environment/Assets/ML-Agents/Examples/3DBall/Scene.unity (992)
  9. unity-environment/Assets/ML-Agents/Scripts/Brain.cs (24)
  10. unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs (47)
  11. unity-environment/README.md (4)
  12. unity-environment/Assets/ML-Agents/Examples/3DBall/Prefabs.meta (9)
  13. unity-environment/Assets/ML-Agents/Examples/Basic.meta (9)
  14. docs/Agents-Editor-Interface.md (71)
  15. docs/Example-Environments.md (58)
  16. docs/Getting-Started-with-Balance-Ball.md (134)
  17. docs/Limitations-&-Common-Issues.md (43)
  18. docs/Making-a-new-Unity-Environment.md (143)
  19. docs/Organizing-the-Scene.md (33)
  20. docs/Training-on-Amazon-Web-Service.md (40)
  21. docs/Unity-Agents---Python-API.md (44)
  22. docs/Unity-Agents-Overview.md (43)
  23. docs/Using-TensorFlow-Sharp-in-Unity-(Experimental).md (112)
  24. docs/Readme.md (19)
  25. docs/best-practices.md (20)
  26. docs/installation.md (51)
  27. unity-environment/Assets/ML-Agents/Examples/3DBall/Prefabs/Game.prefab (354)
  28. unity-environment/Assets/ML-Agents/Examples/3DBall/Prefabs/Game.prefab.meta (9)
  29. unity-environment/Assets/ML-Agents/Examples/Basic/Materials.meta (9)
  30. unity-environment/Assets/ML-Agents/Examples/Basic/Materials/agent.mat (76)
  31. unity-environment/Assets/ML-Agents/Examples/Basic/Materials/agent.mat.meta (9)
  32. unity-environment/Assets/ML-Agents/Examples/Basic/Materials/goal.mat (76)
  33. unity-environment/Assets/ML-Agents/Examples/Basic/Materials/goal.mat.meta (9)
  34. unity-environment/Assets/ML-Agents/Examples/Basic/Scene.unity (702)
  35. unity-environment/Assets/ML-Agents/Examples/Basic/Scene.unity.meta (8)
  36. unity-environment/Assets/ML-Agents/Examples/Basic/Scripts.meta (9)
  37. unity-environment/Assets/ML-Agents/Examples/Basic/Scripts/BasicAcademy.cs (17)
  38. unity-environment/Assets/ML-Agents/Examples/Basic/Scripts/BasicAcademy.cs.meta (12)
  39. unity-environment/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs (64)
  40. unity-environment/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs.meta (12)
  41. unity-environment/Assets/ML-Agents/Examples/Basic/Scripts/BasicDecision.cs (18)
  42. unity-environment/Assets/ML-Agents/Examples/Basic/Scripts/BasicDecision.cs.meta (12)

10
README.md


<img src="images/unity-wide.png" align="middle" width="3000"/>
# Unity ML - Agents
# Unity ML - Agents (Beta)
**Unity Machine Learning Agents** allows researchers and developers to
create games and simulations using the Unity Editor which serve as

the [wiki page](../../wiki).
the [documentation page](docs).
[here](../../wiki/Getting-Started-with-Balance-Ball).
[here](docs/Getting-Started-with-Balance-Ball.md).
## Features
* Unity Engine flexibility and simplicity

The _Agents SDK_, including example environment scenes, is located in the
`unity-environment` folder. For requirements, instructions, and other
information, see the contained Readme and the relevant
[wiki page](../../wiki/Making-a-new-Unity-Environment).
[documentation](docs/Making-a-new-Unity-Environment.md).
## Training your Agents

contained Readme and the relevant
[wiki page](../../wiki/Unity-Agents---Python-API).
[documentation](docs/Unity-Agents---Python-API.md).

12
python/PPO.ipynb


"train_model = True # Whether to train the model.\n",
"summary_freq = 10000 # Frequency at which to save training statistics.\n",
"save_freq = 50000 # Frequency at which to save model.\n",
"env_name = \"simple\" # Name of the training environment file.\n",
"env_name = \"environment\" # Name of the training environment file.\n",
"\n",
"### Algorithm-specific parameters for tuning\n",
"gamma = 0.99 # Reward discount rate.\n",

{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"env = UnityEnvironment(file_name=env_name)\n",

"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"scrolled": true
},
"outputs": [],

"# Create the Tensorflow model graph\n",
"ppo_model = create_agent_model(env, lr=learning_rate,\n",
" h_size=hidden_units, epsilon=epsilon,\n",
" beta=beta)\n",
" beta=beta, max_step=max_steps)\n",
"use_states = (env.brains[brain_name].state_space_size > 0)\n",
"\n",
"model_path = './models/{}'.format(run_path)\n",
"summary_path = './summaries/{}'.format(run_path)\n",

" steps = sess.run(ppo_model.global_step)\n",
" summary_writer = tf.summary.FileWriter(summary_path)\n",
" info = env.reset(train_mode=train_model)[brain_name]\n",
" trainer = Trainer(ppo_model, sess, info, is_continuous, use_observations)\n",
" trainer = Trainer(ppo_model, sess, info, is_continuous, use_observations, use_states)\n",
" while steps <= max_steps:\n",
" if env.global_done:\n",
" info = env.reset(train_mode=train_model)[brain_name]\n",

4
python/README.md


To monitor training progress, run the following from the root directory of this repo:
`tensorboard --logdir='./summaries'`
`tensorboard --logdir=summaries`
Then navigate to `localhost:6006` to monitor progress with Tensorboard.

`python3 ppo.py --help`
## Using Python API
See this [wiki page](https://github.com/Unity-Technologies/python-rl-control/wiki/Unity-Agents---Python-API) for a detailed description of the functions and uses of the Python API.
See this [documentation](../docs/Unity-Agents---Python-API.md) for a detailed description of the functions and uses of the Python API.
## Training on AWS
See this related [blog post](https://medium.com/towards-data-science/how-to-run-unity-on-amazon-cloud-or-without-monitor-3c10ce022639) for a description of how to run Unity Environments on AWS EC2 instances with the GPU.
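As an aside to the Python API pointer above, a minimal connection script looks roughly like the following. The `reset(train_mode=...)` call and the `brains[brain_name]` fields mirror snippets elsewhere in this commit; the file name and the `brain_names` attribute are assumptions for illustration.

```python
from unityagents import UnityEnvironment

# Minimal sketch: open a built environment binary and inspect its first brain.
# "environment" and env.brain_names are illustrative assumptions.
env = UnityEnvironment(file_name="environment")
brain_name = env.brain_names[0]
brain = env.brains[brain_name]
print(brain_name, brain.state_space_size, brain.action_space_size, brain.action_space_type)

# Reset once in inference mode, then shut the connection down.
info = env.reset(train_mode=False)[brain_name]
env.close()
```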

9
python/ppo.py


Options:
--help Show this message.
--max-step=<n> Maximum number of steps to run environment [default: 5e6].
--max-steps=<n> Maximum number of steps to run environment [default: 1e6].
--run-path=<path> The sub-directory name for model and summary statistics [default: ppo].
--load Whether to load the model or randomly initialize [default: False].
--train Whether to train model, or only run inference [default: True].

print(options)
# General parameters
max_steps = float(options['--max-step'])
max_steps = float(options['--max-steps'])
model_path = './models/{}'.format(str(options['--run-path']))
summary_path = './summaries/{}'.format(str(options['--run-path']))
load_model = options['--load']

# Create the Tensorflow model graph
ppo_model = create_agent_model(env, lr=learning_rate,
h_size=hidden_units, epsilon=epsilon,
beta=beta)
beta=beta, max_step=max_steps)
use_states = (env.brains[brain_name].state_space_size > 0)
if not os.path.exists(model_path):
os.makedirs(model_path)

steps = sess.run(ppo_model.global_step)
summary_writer = tf.summary.FileWriter(summary_path)
info = env.reset(train_mode=train_model)[brain_name]
trainer = Trainer(ppo_model, sess, info, is_continuous, use_observations)
trainer = Trainer(ppo_model, sess, info, is_continuous, use_observations, use_states)
while steps <= max_steps or not train_model:
if env.global_done:
info = env.reset(train_mode=train_model)[brain_name]
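With the flag renamed from `--max-step` to `--max-steps`, a training run would presumably be launched along the lines of `python3 ppo.py --max-steps=1e6 --run-path=ppo` (this particular combination of options is only illustrative; `python3 ppo.py --help` prints the full list shown above).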

209
python/ppo/models.py


import tensorflow as tf
import tensorflow.contrib.layers as c_layers
from tensorflow.python.tools import freeze_graph
from unityagents import UnityEnvironmentException
def create_agent_model(env, lr=1e-4, h_size=128, epsilon=0.2, beta=1e-3):
def create_agent_model(env, lr=1e-4, h_size=128, epsilon=0.2, beta=1e-3, max_step=5e6):
Takes a Unity environment and model-specific hyperparameters and returns the
Takes a Unity environment and model-specific hyper-parameters and returns the
appropriate PPO agent model for the environment.
:param env: a Unity environment.
:param lr: Learning rate.

:return: a sub-class of PPOAgent tailored to the environment.
:param max_step: Total number of training steps.
if env.brains[brain_name].action_space_type == "continuous":
return ContinuousControlModel(lr, env.brains[brain_name].state_space_size,
env.brains[brain_name].action_space_size, h_size, epsilon, beta)
if env.brains[brain_name].action_space_type == "discrete":
if env.brains[brain_name].number_observations == 0:
return DiscreteControlModel(lr, env.brains[brain_name].state_space_size,
env.brains[brain_name].action_space_size, h_size, epsilon, beta)
else:
brain = env.brains[brain_name]
h, w = brain.camera_resolutions[0]['height'], brain.camera_resolutions[0]['height']
return VisualDiscreteControlModel(lr, h, w, env.brains[brain_name].action_space_size, h_size, epsilon, beta)
brain = env.brains[brain_name]
if brain.action_space_type == "continuous":
return ContinuousControlModel(lr, brain, h_size, epsilon, max_step)
if brain.action_space_type == "discrete":
return DiscreteControlModel(lr, brain, h_size, epsilon, beta, max_step)
def save_model(sess, saver, model_path="./", steps=0):

:param steps: Current number of steps in training process.
:param saver: Tensorflow saver for session.
"""
last_checkpoint = model_path+'/model-'+str(steps)+'.cptk'
last_checkpoint = model_path + '/model-' + str(steps) + '.cptk'
saver.save(sess, last_checkpoint)
tf.train.write_graph(sess.graph_def, model_path, 'raw_graph_def.pb', as_text=False)
print("Saved Model")

class PPOModel(object):
def __init__(self, probs, old_probs, value, entropy, beta, epsilon, lr):
def create_visual_encoder(self, o_size_h, o_size_w, bw, h_size, num_streams, activation):
"""
Builds a set of visual (CNN) encoders.
:param o_size_h: Height observation size.
:param o_size_w: Width observation size.
:param bw: Whether image is greyscale {True} or color {False}.
:param h_size: Hidden layer size.
:param num_streams: Number of visual streams to construct.
:param activation: What type of activation function to use for layers.
:return: List of hidden layer tensors.
"""
if bw:
c_channels = 1
else:
c_channels = 3
self.observation_in = tf.placeholder(shape=[None, o_size_h, o_size_w, c_channels], dtype=tf.float32,
name='observation_0')
streams = []
for i in range(num_streams):
self.conv1 = tf.layers.conv2d(self.observation_in, 32, kernel_size=[3, 3], strides=[2, 2],
use_bias=False, activation=activation)
self.conv2 = tf.layers.conv2d(self.conv1, 64, kernel_size=[3, 3], strides=[2, 2],
use_bias=False, activation=activation)
hidden = tf.layers.dense(c_layers.flatten(self.conv2), h_size, use_bias=False, activation=activation)
streams.append(hidden)
return streams
def create_continuous_state_encoder(self, s_size, h_size, num_streams, activation):
"""
Builds a set of hidden state encoders.
:param s_size: state input size.
:param h_size: Hidden layer size.
:param num_streams: Number of state streams to construct.
:param activation: What type of activation function to use for layers.
:return: List of hidden layer tensors.
"""
self.state_in = tf.placeholder(shape=[None, s_size], dtype=tf.float32, name='state')
streams = []
for i in range(num_streams):
hidden_1 = tf.layers.dense(self.state_in, h_size, use_bias=False, activation=activation)
hidden_2 = tf.layers.dense(hidden_1, h_size, use_bias=False, activation=activation)
streams.append(hidden_2)
return streams
def create_discrete_state_encoder(self, s_size, h_size, num_streams, activation):
"""
Builds a set of hidden state encoders from discrete state input.
:param s_size: state input size (discrete).
:param h_size: Hidden layer size.
:param num_streams: Number of state streams to construct.
:param activation: What type of activation function to use for layers.
:return: List of hidden layer tensors.
"""
self.state_in = tf.placeholder(shape=[None, 1], dtype=tf.int32, name='state')
state_in = tf.reshape(self.state_in, [-1])
state_onehot = c_layers.one_hot_encoding(state_in, s_size)
streams = []
for i in range(num_streams):
hidden = tf.layers.dense(state_onehot, h_size, use_bias=False, activation=activation)
streams.append(hidden)
return streams
def create_ppo_optimizer(self, probs, old_probs, value, entropy, beta, epsilon, lr, max_step):
"""
Creates training-specific Tensorflow ops for PPO models.
:param probs: Current policy probabilities

:param entropy: Current policy entropy
:param epsilon: Value for policy-divergence threshold
:param lr: Learning rate
:param max_step: Total number of training steps.
"""
self.returns_holder = tf.placeholder(shape=[None], dtype=tf.float32, name='discounted_rewards')
self.advantage = tf.placeholder(shape=[None, 1], dtype=tf.float32, name='advantages')

self.loss = self.policy_loss + self.value_loss - beta * tf.reduce_mean(entropy)
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
self.global_step = tf.Variable(0, trainable=False, name='global_step', dtype=tf.int32)
self.learning_rate = tf.train.polynomial_decay(lr, self.global_step,
max_step, 1e-10,
power=1.0)
optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
self.global_step = tf.Variable(0, trainable=False, name='global_step', dtype=tf.int32)
self.increment_step = tf.assign(self.global_step, self.global_step+1)
self.increment_step = tf.assign(self.global_step, self.global_step + 1)
def __init__(self, lr, s_size, a_size, h_size, epsilon, beta):
def __init__(self, lr, brain, h_size, epsilon, max_step):
:param s_size: State-space size
:param a_size: Action-space size
:param brain: State-space size
self.state_in = tf.placeholder(shape=[None, s_size], dtype=tf.float32, name='state')
s_size = brain.state_space_size
a_size = brain.action_space_size
hidden_state, hidden_visual, hidden_policy, hidden_value = None, None, None, None
if brain.number_observations > 0:
h_size, w_size = brain.camera_resolutions[0]['height'], brain.camera_resolutions[0]['width']
bw = brain.camera_resolutions[0]['blackAndWhite']
hidden_visual = self.create_visual_encoder(h_size, w_size, bw, h_size, 2, tf.nn.tanh)
if brain.state_space_size > 0:
s_size = brain.state_space_size
if brain.state_space_type == "continuous":
hidden_state = self.create_continuous_state_encoder(s_size, h_size, 2, tf.nn.tanh)
else:
hidden_state = self.create_discrete_state_encoder(s_size, h_size, 2, tf.nn.tanh)
if hidden_visual is None and hidden_state is None:
raise Exception("No valid network configuration possible. "
"There are no states or observations in this brain")
elif hidden_visual is not None and hidden_state is None:
hidden_policy, hidden_value = hidden_visual
elif hidden_visual is None and hidden_state is not None:
hidden_policy, hidden_value = hidden_state
elif hidden_visual is not None and hidden_state is not None:
hidden_policy = tf.concat([hidden_visual[0], hidden_state[0]], axis=1)
hidden_value = tf.concat([hidden_visual[1], hidden_state[1]], axis=1)
hidden_policy = tf.layers.dense(self.state_in, h_size, use_bias=False, activation=tf.nn.tanh)
hidden_value = tf.layers.dense(self.state_in, h_size, use_bias=False, activation=tf.nn.tanh)
hidden_policy_2 = tf.layers.dense(hidden_policy, h_size, use_bias=False, activation=tf.nn.tanh)
hidden_value_2 = tf.layers.dense(hidden_value, h_size, use_bias=False, activation=tf.nn.tanh)
self.mu = tf.layers.dense(hidden_policy_2, a_size, activation=None, use_bias=False,
self.mu = tf.layers.dense(hidden_policy, a_size, activation=None, use_bias=False,
kernel_initializer=c_layers.variance_scaling_initializer(factor=0.1))
self.log_sigma_sq = tf.Variable(tf.zeros([a_size]))
self.sigma_sq = tf.exp(self.log_sigma_sq)

self.entropy = tf.reduce_sum(0.5 * tf.log(2 * np.pi * np.e * self.sigma_sq))
self.value = tf.layers.dense(hidden_value_2, 1, activation=None, use_bias=False)
self.value = tf.layers.dense(hidden_value, 1, activation=None, use_bias=False)
PPOModel.__init__(self, self.probs, self.old_probs, self.value, self.entropy, 0.0, epsilon, lr)
self.create_ppo_optimizer(self.probs, self.old_probs, self.value, self.entropy, 0.0, epsilon, lr, max_step)
def __init__(self, lr, s_size, a_size, h_size, epsilon, beta):
def __init__(self, lr, brain, h_size, epsilon, beta, max_step):
:param s_size: State-space size
:param a_size: Action-space size
:param brain: State-space size
self.state_in = tf.placeholder(shape=[None, s_size], dtype=tf.float32, name='state')
self.batch_size = tf.placeholder(shape=None, dtype=tf.int32, name='batch_size')
hidden_1 = tf.layers.dense(self.state_in, h_size, use_bias=False, activation=tf.nn.elu)
hidden_2 = tf.layers.dense(hidden_1, h_size, use_bias=False, activation=tf.nn.elu)
self.policy = tf.layers.dense(hidden_2, a_size, activation=None, use_bias=False,
kernel_initializer=c_layers.variance_scaling_initializer(factor=0.1))
self.probs = tf.nn.softmax(self.policy)
self.action = tf.multinomial(self.policy, 1)
self.output = tf.identity(self.action, name='action')
self.value = tf.layers.dense(hidden_2, 1, activation=None, use_bias=False)
self.entropy = -tf.reduce_sum(self.probs * tf.log(self.probs + 1e-10), axis=1)
self.action_holder = tf.placeholder(shape=[None], dtype=tf.int32)
self.selected_actions = c_layers.one_hot_encoding(self.action_holder, a_size)
self.old_probs = tf.placeholder(shape=[None, a_size], dtype=tf.float32, name='old_probabilities')
self.responsible_probs = tf.reduce_sum(self.probs * self.selected_actions, axis=1)
self.old_responsible_probs = tf.reduce_sum(self.old_probs * self.selected_actions, axis=1)
hidden_state, hidden_visual, hidden = None, None, None
if brain.number_observations > 0:
h_size, w_size = brain.camera_resolutions[0]['height'], brain.camera_resolutions[0]['width']
bw = brain.camera_resolutions[0]['blackAndWhite']
hidden_visual = self.create_visual_encoder(h_size, w_size, bw, h_size, 1, tf.nn.elu)[0]
if brain.state_space_size > 0:
s_size = brain.state_space_size
if brain.state_space_type == "continuous":
hidden_state = self.create_continuous_state_encoder(s_size, h_size, 1, tf.nn.elu)[0]
else:
hidden_state = self.create_discrete_state_encoder(s_size, h_size, 1, tf.nn.elu)[0]
PPOModel.__init__(self, self.responsible_probs, self.old_responsible_probs,
self.value, self.entropy, beta, epsilon, lr)
if hidden_visual is None and hidden_state is None:
raise Exception("No valid network configuration possible. "
"There are no states or observations in this brain")
elif hidden_visual is not None and hidden_state is None:
hidden = hidden_visual
elif hidden_visual is None and hidden_state is not None:
hidden = hidden_state
elif hidden_visual is not None and hidden_state is not None:
hidden = tf.concat([hidden_visual, hidden_state], axis=1)
a_size = brain.action_space_size
class VisualDiscreteControlModel(PPOModel):
def __init__(self, lr, o_size_h, o_size_w, a_size, h_size, epsilon, beta):
"""
Creates Discrete Control Actor-Critic model for use with visual observations (images).
:param o_size_h: Observation height.
:param o_size_w: Observation width.
:param a_size: Action-space size.
:param h_size: Hidden layer size.
"""
self.observation_in = tf.placeholder(shape=[None, o_size_h, o_size_w, 1], dtype=tf.float32,
name='observation_0')
self.conv1 = tf.layers.conv2d(self.observation_in, 32, kernel_size=[3, 3], strides=[2, 2],
use_bias=False, activation=tf.nn.elu)
self.conv2 = tf.layers.conv2d(self.conv1, 64, kernel_size=[3, 3], strides=[2, 2],
use_bias=False, activation=tf.nn.elu)
self.batch_size = tf.placeholder(shape=None, dtype=tf.int32)
hidden = tf.layers.dense(c_layers.flatten(self.conv2), h_size, use_bias=False, activation=tf.nn.elu)
self.batch_size = tf.placeholder(shape=None, dtype=tf.int32, name='batch_size')
self.policy = tf.layers.dense(hidden, a_size, activation=None, use_bias=False,
kernel_initializer=c_layers.variance_scaling_initializer(factor=0.1))
self.probs = tf.nn.softmax(self.policy)

self.responsible_probs = tf.reduce_sum(self.probs * self.selected_actions, axis=1)
self.old_responsible_probs = tf.reduce_sum(self.old_probs * self.selected_actions, axis=1)
PPOModel.__init__(self, self.responsible_probs, self.old_responsible_probs,
self.value, self.entropy, beta, epsilon, lr)
self.create_ppo_optimizer(self.responsible_probs, self.old_responsible_probs,
self.value, self.entropy, beta, epsilon, lr, max_step)
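The behavioral change threaded through this file is that the optimizer's learning rate now anneals over `max_step` instead of staying fixed. A rough sketch of the schedule implied by the `polynomial_decay` call above (end learning rate `1e-10`, `power=1.0`, i.e. a linear ramp) is:

```python
def annealed_lr(lr, step, max_step, end_lr=1e-10, power=1.0):
    # Mirrors tf.train.polynomial_decay as configured above: a linear ramp
    # from lr down to (nearly) zero over max_step global steps.
    step = min(step, max_step)
    return (lr - end_lr) * (1 - step / max_step) ** power + end_lr

# With the default lr=1e-4 and max_step=5e6: 1e-4 at step 0,
# about 5e-5 halfway through, and ~1e-10 at the final step.
```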

50
python/ppo/trainer.py


class Trainer(object):
def __init__(self, ppo_model, sess, info, is_continuous, use_observations):
def __init__(self, ppo_model, sess, info, is_continuous, use_observations, use_states):
"""
Responsible for collecting experiences and training the PPO model.
:param ppo_model: Tensorflow graph defining model.

self.model = ppo_model
self.sess = sess
stats = {'cumulative_reward': [], 'episode_length': [], 'value_estimate': [],
'entropy': [], 'value_loss': [], 'policy_loss': []}
'entropy': [], 'value_loss': [], 'policy_loss': [], 'learning_rate': []}
self.stats = stats
self.training_buffer = vectorize_history(empty_local_history({}))

self.is_continuous = is_continuous
self.use_observations = use_observations
self.use_states = use_states
def take_action(self, info, env, brain_name):
"""

:return: BrainInfo corresponding to new environment state.
"""
epsi = None
feed_dict = {self.model.batch_size: len(info.states)}
feed_dict = {self.model.state_in: info.states, self.model.batch_size: len(info.states),
self.model.epsilon: epsi}
elif self.use_observations:
feed_dict = {self.model.observation_in: np.vstack(info.observations),
self.model.batch_size: len(info.states)}
else:
feed_dict = {self.model.state_in: info.states, self.model.batch_size: len(info.states)}
actions, a_dist, value, ent = self.sess.run([self.model.output, self.model.probs,
self.model.value, self.model.entropy],
feed_dict=feed_dict)
feed_dict[self.model.epsilon] = epsi
if self.use_observations:
feed_dict[self.model.observation_in] = np.vstack(info.observations)
if self.use_states:
feed_dict[self.model.state_in] = info.states
actions, a_dist, value, ent, learn_rate = self.sess.run([self.model.output, self.model.probs,
self.model.value, self.model.entropy,
self.model.learning_rate],
feed_dict=feed_dict)
self.stats['learning_rate'].append(learn_rate)
new_info = env.step(actions, value={brain_name: value})[brain_name]
self.add_experiences(info, new_info, epsi, actions, a_dist, value)
return new_info

if not info.local_done[idx]:
if self.use_observations:
history['observations'].append(info.observations[idx])
else:
if self.use_states:
if self.is_continuous:
history['epsilons'].append(epsi[idx])
if self.is_continuous:
history['epsilons'].append(epsi[idx])
history['value_estimates'].append(value[idx][0])
history['cumulative_reward'] += next_info.rewards[idx]
history['episode_steps'] += 1

if info.local_done[l]:
value_next = 0.0
else:
feed_dict = {self.model.batch_size: len(info.states)}
feed_dict = {self.model.observation_in: np.vstack(info.observations),
self.model.batch_size: len(info.states)}
else:
feed_dict = {self.model.state_in: info.states,
self.model.batch_size: len(info.states)}
feed_dict[self.model.observation_in] = np.vstack(info.observations)
if self.use_states:
feed_dict[self.model.state_in] = info.states
value_next = self.sess.run(self.model.value, feed_dict)[l]
history = vectorize_history(self.history_dict[info.agents[l]])
history['advantages'] = get_gae(rewards=history['rewards'],

self.model.old_probs: np.vstack(training_buffer['action_probs'][start:end])}
if self.is_continuous:
feed_dict[self.model.epsilon] = np.vstack(training_buffer['epsilons'][start:end])
feed_dict[self.model.state_in] = np.vstack(training_buffer['states'][start:end])
if self.use_observations:
feed_dict[self.model.observation_in] = np.vstack(training_buffer['observations'][start:end])
else:
feed_dict[self.model.state_in] = np.vstack(training_buffer['states'][start:end])
if self.use_states:
feed_dict[self.model.state_in] = np.vstack(training_buffer['states'][start:end])
if self.use_observations:
feed_dict[self.model.observation_in] = np.vstack(training_buffer['observations'][start:end])
v_loss, p_loss, _ = self.sess.run([self.model.value_loss, self.model.policy_loss,
self.model.update_batch], feed_dict=feed_dict)
total_v += v_loss
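To situate the new `use_states` argument: the caller now passes both modality switches when constructing the `Trainer`, roughly as in the sketch below. The `use_states` line matches the notebook and `ppo.py` hunks above; the `use_observations` and `is_continuous` derivations are assumptions written by analogy.

```python
# Sketch of the Trainer construction (env, sess, info and ppo_model come from
# the surrounding training script; use_observations is assumed by analogy).
brain = env.brains[brain_name]
use_states = (brain.state_space_size > 0)
use_observations = (brain.number_observations > 0)
is_continuous = (brain.action_space_type == "continuous")

trainer = Trainer(ppo_model, sess, info, is_continuous, use_observations, use_states)
```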

2
python/setup.py


required = f.read().splitlines()
setup(name='unityagents',
version='0.1',
version='0.1.1',
description='Unity Machine Learning Agents',
license='Apache License 2.0',
author='Unity Technologies',

992
unity-environment/Assets/ML-Agents/Examples/3DBall/Scene.unity
File diff is too large to display

24
unity-environment/Assets/ML-Agents/Scripts/Brain.cs


* Defines brain-specific parameters
*/
[System.Serializable]
public struct BrainParameters
public class BrainParameters
public int stateSize;
public int stateSize = 1;
public int actionSize;
public int actionSize = 1;
public int memorySize;
public int memorySize = 0;
/**< \brief The length of the float vector that holds the memory for the agent */
public resolution[] cameraResolutions;
/**<\brief The list of observation resolutions for the brain */

public StateType actionSpaceType;
public StateType actionSpaceType = StateType.discrete;
public StateType stateSpaceType;
public StateType stateSpaceType = StateType.continuous;
}
/**

*/
public class Brain : MonoBehaviour
{
public BrainParameters brainParameters;
public BrainParameters brainParameters = new BrainParameters();
/**< \brief Defines brain specific parameters such as the state size*/
public BrainType brainType;
/**< \brief Defines what is the type of the brain :

foreach (KeyValuePair<int, Agent> idAgent in agents)
{
List<float> states = idAgent.Value.CollectState();
if (states.Count != brainParameters.stateSize)
if ((states.Count != brainParameters.stateSize) && (brainParameters.stateSpaceType == StateType.continuous ))
{
throw new UnityAgentsException(string.Format(@"The number of states does not match for agent {0}:
Was expecting {1} continuous states but received {2}.", idAgent.Value.gameObject.name, brainParameters.stateSize, states.Count));
}
if ((states.Count != 1) && (brainParameters.stateSpaceType == StateType.discrete ))
Was expecting {1} states but received {2}.", idAgent.Value.gameObject.name, brainParameters.stateSize, states.Count));
Was expecting 1 discrete states but received {1}.", idAgent.Value.gameObject.name, states.Count));
}
result.Add(idAgent.Key, states);
}

47
unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs


public string[] ObservationPlaceholderName;
/// Modify only in inspector : Name of the action node
public string ActionPlaceholderName = "action";
#if ENABLE_TENSORFLOW
#if ENABLE_TENSORFLOW
TFGraph graph;
TFSession session;
bool hasRecurrent;

float[,] inputState;
List<float[,,,]> observationMatrixList;
float[,] inputOldMemories;
#endif
#endif
/// Reference to the brain that uses this CoreBrainInternal
public Brain brain;

foreach (TensorFlowAgentPlaceholder placeholder in graphPlaceholders)
{
if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.FloatingPoint)
try
runner.AddInput(graph[graphScope + placeholder.name][0], new float[] { Random.Range(placeholder.minValue, placeholder.maxValue) });
if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.FloatingPoint)
{
runner.AddInput(graph[graphScope + placeholder.name][0], new float[] { Random.Range(placeholder.minValue, placeholder.maxValue) });
}
else if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.Integer)
{
runner.AddInput(graph[graphScope + placeholder.name][0], new int[] { Random.Range((int)placeholder.minValue, (int)placeholder.maxValue + 1) });
}
else if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.Integer)
catch
runner.AddInput(graph[graphScope + placeholder.name][0], new int[] { Random.Range((int)placeholder.minValue, (int)placeholder.maxValue + 1) });
throw new UnityAgentsException(string.Format(@"One of the Tensorflow placeholders could not be found.
In brain {0}, there is no {1} placeholder named {2}.",
brain.gameObject.name, placeholder.valueType.ToString(), graphScope + placeholder.name));
}
}

runner.AddInput(graph[graphScope + ObservationPlaceholderName[obs_number]][0], observationMatrixList[obs_number]);
}
TFTensor[] networkOutput;
try
{
networkOutput = runner.Run();
}
catch (TFException e)
{
string errorMessage = e.Message;
try
{
errorMessage = string.Format(@"The tensorflow graph needs an input for {0} of type {1}",
e.Message.Split(new string[]{ "Node: " }, 0)[1].Split('=')[0],
e.Message.Split(new string[]{ "dtype=" }, 0)[1].Split(',')[0]);
}
finally
{
throw new UnityAgentsException(errorMessage);
}
}
// Create the recurrent tensor
if (hasRecurrent)

runner.AddInput(graph[graphScope + RecurrentInPlaceholderName][0], inputOldMemories);
runner.Fetch(graph[graphScope + RecurrentOutPlaceholderName][0]);
float[,] recurrent_tensor = runner.Run()[1].GetValue() as float[,];
float[,] recurrent_tensor = networkOutput[1].GetValue() as float[,];
int i = 0;
foreach (int k in agentKeys)

if (brain.brainParameters.actionSpaceType == StateType.continuous)
{
float[,] output = runner.Run()[0].GetValue() as float[,];
float[,] output = networkOutput[0].GetValue() as float[,];
int i = 0;
foreach (int k in agentKeys)
{

}
else if (brain.brainParameters.actionSpaceType == StateType.discrete)
{
long[,] output = runner.Run()[0].GetValue() as long[,];
long[,] output = networkOutput[0].GetValue() as long[,];
int i = 0;
foreach (int k in agentKeys)
{

4
unity-environment/README.md


* **GridWorld** - A simple gridworld containing regions which provide positive and negative reward. The agent must learn to move to the rewarding regions (green) and avoid the negatively rewarding ones (red). Supports discrete control.
* **Tennis** - An adversarial game where two agents control rackets, which must be used to bounce a ball back and forth between them. Supports continuous control.
For more information on each of these environments, see this [wiki page](../../../wiki/Example-Environments).
For more information on each of these environments, see this [documentation page](../docs/Example-Environments.md).
Within `ML-Agents/Template` there also exists:
* **Template** - An empty Unity scene with a single _Academy_, _Brain_, and _Agent_. Designed to be used as a template for new environments.

For information on the use of each script, see the comments and documentation within the files themselves, or read the [documentation](../../../wiki).
## Creating your own Unity Environment
For information on how to create a new Unity Environment, see the walkthrough [here](../../../wiki/Making-a-new-Unity-Environment). If you have questions or run into issues, please feel free to create issues through the repo, and we will do our best to address them.
For information on how to create a new Unity Environment, see the walkthrough [here](../docs/Making-a-new-Unity-Environment.md). If you have questions or run into issues, please feel free to create issues through the repo, and we will do our best to address them.
## Embedding Models with TensorflowSharp _[Experimental]_
If you will be using Tensorflow Sharp in Unity, you must:

9
unity-environment/Assets/ML-Agents/Examples/3DBall/Prefabs.meta


fileFormatVersion: 2
guid: 7f76e451b3030e54eac0f7c5488d22e9
folderAsset: yes
timeCreated: 1506066534
licenseType: Free
DefaultImporter:
userData:
assetBundleName:
assetBundleVariant:

9
unity-environment/Assets/ML-Agents/Examples/Basic.meta


fileFormatVersion: 2
guid: 230c334ab2f144bcda6eea42d18ebdc8
folderAsset: yes
timeCreated: 1506189168
licenseType: Pro
DefaultImporter:
userData:
assetBundleName:
assetBundleVariant:

71
docs/Agents-Editor-Interface.md


# ML Agents Editor Interface
This page contains an explanation of the use of each of the inspector panels relating to the `Academy`, `Brain`, and `Agent` objects.
## Academy
![Academy Inspector](../images/academy.png)
* `Max Steps` - Total number of steps per-episode. `0` corresponds to episodes without a maximum number
of steps. Once the step counter reaches maximum, the environment will reset.
* `Frames To Skip` - How many steps of the environment to skip before asking Brains for decisions.
* `Wait Time` - How many seconds to wait between steps when running in `Inference`.
* `Configuration` - The engine-level settings which correspond to rendering quality and engine speed.
* `Width` - Width of the environment window in pixels.
* `Height` - Height of the environment window in pixels.
* `Quality Level` - Rendering quality of environment. (Higher is better)
* `Time Scale` - Speed at which environment is run. (Higher is faster)
* `Target Frame Rate` - FPS engine attempts to maintain.
* `Default Reset Parameters` - List of custom parameters that can be changed in the environment on reset.
## Brain
![Brain Inspector](../images/brain.png)
* `Brain Parameters` - Define state, observation, and action spaces for the Brain.
* `State Size` - Length of state vector for brain (in _Continuous_ state space), or number of possible values (in _Discrete_ state space).
* `Action Size` - Length of action vector for brain (in _Continuous_ action space), or number of possible values (in _Discrete_ action space).
* `Memory Size` - Length of memory vector for brain. Used with Recurrent networks and frame-stacking CNNs.
* `Camera Resolution` - Describes height, width, and whether to greyscale visual observations for the Brain.
* `Action Descriptions` - A list of strings used to name the available actions for the Brain.
* `State Space Type` - Corresponds to whether state vector contains a single integer (Discrete) or a series of real-valued floats (Continuous).
* `Action Space Type` - Corresponds to whether action vector contains a single integer (Discrete) or a series of real-valued floats (Continuous).
* `Type of Brain` - Describes how Brain will decide actions.
* `External` - Actions are decided using Python API.
* `Internal` - Actions are decided using internal TensorflowSharp model.
* `Player` - Actions are decided using Player input mappings.
* `Heuristic` - Actions are decided using custom `Decision` script, which should be attached to the Brain game object.
### Internal Brain
![Internal Brain Inspector](../images/internal_brain.png)
* `Graph Model` : This must be the `bytes` file corresponding to the pretrained Tensorflow graph. (You must first drag this file into your Resources folder and then from the Resources folder into the inspector)
* `Graph Scope` : If you set a scope while training your tensorflow model, all your placeholder names will have a prefix. You must specify that prefix here.
* `Batch Size Node Name` : If the batch size is one of the inputs of your graph, you must specify the name of the placeholder here. The brain will automatically make the batch size equal to the number of agents connected to the brain.
* `State Node Name` : If your graph uses the state as an input, you must specify the name of the placeholder here.
* `Recurrent Input Node Name` : If your graph uses a recurrent input / memory as input and outputs new recurrent input / memory, you must specify the name of the input placeholder here.
* `Recurrent Output Node Name` : If your graph uses a recurrent input / memory as input and outputs new recurrent input / memory, you must specify the name of the output placeholder here.
* `Observation Placeholder Name` : If your graph uses observations as input, you must specify it here. Note that the number of observations is equal to the length of `Camera Resolutions` in the brain parameters.
* `Action Node Name` : Specify the name of the placeholder corresponding to the actions of the brain in your graph. If the action space type is continuous, the output must be a one-dimensional tensor of floats of length `Action Space Size`; if the action space type is discrete, the output must be a one-dimensional tensor of ints of length 1.
* `Graph Placeholder` : If your graph takes additional inputs that are fixed (example: noise level) you can specify them here. Note that in your graph, these must correspond to one dimensional tensors of int or float of size 1.
* `Name` : Corresponds to the name of the placeholder.
* `Value Type` : Either Integer or Floating Point.
* `Min Value` and `Max Value` : Specify the range of the value here. The value will be sampled from the uniform distribution ranging from `Min Value` to `Max Value` inclusive.
### Player Brain
![Player Brain Inspector](../images/player_brain.png)
If the action space is discrete, you must map input keys to their corresponding integer values. If the action space is continuous, you must map input keys to their corresponding indices and float values.
## Agent
![Agent Inspector](../images/agent.png)
* `Brain` - The brain to register this agent to. Can be dragged into the inspector using the Editor.
* `Observations` - A list of `Cameras` which will be used to generate observations.
* `Max Step` - The per-agent maximum number of steps. Once this number is reached, the agent will be reset if `Reset On Done` is checked.
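Tying the node-name fields above back to the training code: the graph exported by the Python side uses plain op names such as `state`, `batch_size`, and `action` (see `models.py` earlier in this commit), and those are the strings entered in the Internal Brain inspector, prefixed by the graph scope if one is set. A hypothetical minimal graph with matching names:

```python
import tensorflow as tf

# Hypothetical graph whose node names line up with the Internal Brain fields:
# 'state' and 'batch_size' as inputs, 'action' as the output node.
state = tf.placeholder(shape=[None, 8], dtype=tf.float32, name='state')
batch_size = tf.placeholder(shape=None, dtype=tf.int32, name='batch_size')
hidden = tf.layers.dense(state, 64, activation=tf.nn.tanh)
action = tf.identity(tf.layers.dense(hidden, 2, activation=None), name='action')
```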

58
docs/Example-Environments.md


# Example Learning Environments
### About Example Environments
Unity ML Agents currently contains three example environments which demonstrate various features of the platform. In the coming months more will be added. We are also actively open to adding community contributed environments as examples, as long as they are small, simple, demonstrate a unique feature of the platform, and provide a unique non-trivial challenge to modern RL algorithms. Feel free to submit these environments with a Pull-Request explaining the nature of the environment and task.
Environments are located in `unity-environment/ML-Agents/Examples`.
## 3DBall
![Balance Ball](../images/balance.png)
* Set-up: A balance-ball task, where the agent controls the platform.
* Goal: The agent must balance the platform in order to keep the ball on it for as long as possible.
* Agents: The environment contains 12 agents of the same kind, all linked to a single brain.
* Agent Reward Function:
* +0.1 for every step the ball remains on the platform.
* -1.0 if the ball falls from the platform.
* Brains: One brain with the following state/action space.
* State space: (Continuous) 8 variables corresponding to rotation of platform, and position, rotation, and velocity of ball.
* Action space: (Continuous) Size of 2, with one value corresponding to X-rotation, and the other to Z-rotation.
* Observations: 0
* Reset Parameters: None
## GridWorld
![GridWorld](../images/gridworld.png)
* Set-up: A version of the classic grid-world task. Scene contains agent, goal, and obstacles.
* Goal: The agent must navigate the grid to the goal while avoiding the obstacles.
* Agents: The environment contains one agent linked to a single brain.
* Agent Reward Function:
* -0.01 for every step.
* +1.0 if the agent navigates to the goal position of the grid (episode ends).
* -1.0 if the agent navigates to an obstacle (episode ends).
* Brains: One brain with the following state/action space.
* State space: (Continuous) 6 variables corresponding to position of agent and nearest goal and obstacle.
* Action space: (Discrete) Size of 4, corresponding to movement in cardinal directions.
* Observations: One corresponding to top-down view of GridWorld.
* Reset Parameters: Three, corresponding to grid size, number of obstacles, and number of goals.
## Tennis
![Tennis](../images/tennis.png)
* Set-up: Two-player game where agents control rackets to bounce ball over a net.
* Goal: The agents must bounce ball between one another while not dropping or sending ball out of bounds.
* Agents: The environment contains two agents linked to a single brain.
* Agent Reward Function (independent):
* -0.1 To last agent to hit ball before going out of bounds or hitting ground/net (episode ends).
* +0.1 To agent when hitting ball after ball was hit by the other agent.
* +0.1 To agent who didn't hit ball last when ball hits ground.
* Brains: One brain with the following state/action space.
* State space: (Continuous) 6 variables corresponding to position of agent and nearest goal and obstacle.
* Action space: (Discrete) Size of 4, corresponding to movement toward net, away from net, jumping, and no-movement.
* Observations: None
* Reset Parameters: One, corresponding to size of ball.

134
docs/Getting-Started-with-Balance-Ball.md


# Getting Started with the Balance Ball Example
![Balance Ball](../images/balance.png)
This tutorial will walk through the end-to-end process of installing Unity Agents, building an example environment, training an agent in it, and finally embedding the trained model into the Unity environment.
Unity ML Agents contains a number of example environments which can be used as templates for new environments, or as ways to test a new ML algorithm to ensure it is functioning correctly.
In this walkthrough we will be using the **3D Balance Ball** environment. The environment contains a number of platforms and balls. Platforms can act to keep the ball up by rotating either horizontally or vertically. Each platform is an agent which is rewarded the longer it can keep a ball balanced on it, and provided a negative reward for dropping the ball. The goal of the training process is to have the platforms learn to never drop the ball.
Let's get started!
## Getting Unity ML Agents
### Start by installing **Unity 2017.1** or later (required)
Download link available [here](https://store.unity.com/download?ref=update).
If you are new to using the Unity Editor, you can find the general documentation [here](https://docs.unity3d.com/Manual/index.html).
### Clone the repository
Once installed, you will want to clone the Agents GitHub repository. References will be made throughout to `unity-environment` and `python` directories. Both are located at the root of the repository.
## Building Unity Environment
Launch the Unity Editor, and log in, if necessary.
1. Open the `unity-environment` folder using the Unity editor. *(If this is not your first time running Unity, you can skip most of these immediate steps and choose the project directly from the list of recently opened projects)*
- On the initial dialog, choose `Open` from the options at the top
- On the file dialog, choose `unity-environment` and click `Open` *(It is safe to ignore any warning message about non-matching editor installation)*
- Once the project is open, on the `Project` panel (bottom of the tool), navigate to the folder `Assets/ML-Agents/Examples/3DBall/`
- Double-click the `Scene` icon (Unity logo) to load all environment assets
2. Go to `Edit -> Project Settings -> Player`
- Ensure that `Resolution and Presentation -> Run in Background` is Checked.
- Ensure that `Resolution and Presentation -> Display Resolution Dialog` is set to Disabled.
3. Expand the `Ball3DAcademy` GameObject and locate its child object `Ball3DBrain` within the Scene hierarchy in the editor. Ensure Type of Brain for this object is set to `External`.
4. *File -> Build Settings*
5. Choose your target platform:
- (optional) Select “Developer Build” to log debug messages.
6. Click *Build*:
- Save environment binary to the `python` sub-directory of the cloned repository *(you may need to click on the down arrow on the file chooser to be able to select that folder)*
## Installing Python API
In order to train an agent within the framework, you will need to install Python 2 or 3, and the dependencies described below.
### Windows Users
If you are a Windows user who is new to Python/TensorFlow, follow [this guide](https://nitishmutha.github.io/tensorflow/2017/01/22/TensorFlow-with-gpu-for-windows.html) to set up your Python environment.
### Requirements
* Jupyter
* Matplotlib
* numpy
* Pillow
* Python (2 or 3)
* TensorFlow (1.0+)
### Installing Dependencies
To install dependencies, go into the `python` directory and run:
`pip install .`
or
`pip3 install .`
If your Python environment doesn't include `pip`, see these [instructions](https://packaging.python.org/guides/installing-using-linux-tools/#installing-pip-setuptools-wheel-with-linux-package-managers) on installing it.
Once dependencies are installed, you are ready to test the Ball balance environment from Python.
### Testing Python API
To launch jupyter, run in the command line:
`jupyter notebook`
Then navigate to `localhost:8888` to access the notebooks. If you're new to jupyter, check out the [quick start guide](https://jupyter-notebook-beginner-guide.readthedocs.io/en/latest/execute.html) before you continue.
To ensure that your environment and the Python API work as expected, you can use the `python/Basics` Jupyter notebook. This notebook contains a simple walkthrough of the functionality of the API. Within `Basics`, be sure to set `env_name` to the name of the environment file you built earlier.
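The gist of that sanity check can be sketched as follows. The `reset`, `step`, and brain fields follow the API used elsewhere in this repository; the file name, the `brain_names` attribute, and the random-action loop are illustrative assumptions rather than the exact contents of `Basics`.

```python
from unityagents import UnityEnvironment
import numpy as np

# Rough sketch of a Basics-style check against the binary built above.
env = UnityEnvironment(file_name="environment")
brain_name = env.brain_names[0]
brain = env.brains[brain_name]

info = env.reset(train_mode=False)[brain_name]
for _ in range(100):
    # One random continuous action vector per agent currently in the scene.
    actions = np.random.randn(len(info.agents), brain.action_space_size)
    info = env.step(actions)[brain_name]
env.close()
```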
## Training the Brain with Reinforcement Learning
### Training with PPO
In order to train an agent to correctly balance the ball, we will use a Reinforcement Learning algorithm called Proximal Policy Optimization (PPO). This is a method that has been shown to be safe, efficient, and more general purpose than many other RL algorithms; as such, we have chosen it as the example algorithm for use with ML Agents. For more information on PPO, OpenAI has a recent [blog post](https://blog.openai.com/openai-baselines-ppo/) explaining it.
In order to train the agents within the Ball Balance environment:
1. Open `python/PPO.ipynb` notebook from Jupyter.
2. Set `env_name` to whatever you named your environment file.
3. (optional) Set `run_path` directory to your choice.
4. Run all cells of notebook except for final.
### Observing Training Progress
In order to observe the training process in more detail, you can use Tensorboard.
In your command line, run :
`tensorboard --logdir=summaries`
Then navigate to `localhost:6006`.
From Tensorboard, you will see the summary statistics of six variables:
* Cumulative Reward - The mean cumulative episode reward over all agents. Should increase during a successful training session.
* Value Loss - The mean loss of the value function update. Correlates to how well the model is able to predict the value of each state. This should decrease during a successful training session.
* Policy Loss - The mean loss of the policy function update. Correlates to how much the policy (process for deciding actions) is changing. The magnitude of this should decrease during a successful training session.
* Episode Length - The mean length of each episode in the environment for all agents.
* Value Estimates - The mean value estimate for all states visited by the agent. Should increase during a successful training session.
* Policy Entropy - How random the decisions of the model are. Should slowly decrease during a successful training process. If it decreases too quickly, the `beta` hyperparameter should be increased.
## Embedding Trained Brain into Unity Environment _[Experimental]_
Once the training process displays an average reward of ~75 or greater, and there has been a recently saved model (denoted by the `Saved Model` message) you can choose to stop the training process by stopping the cell execution. Once this is done, you now have a trained TensorFlow model. You must now convert the saved model to a Unity-ready format which can be embedded directly into the Unity project by following the steps below.
### Setting up TensorFlowSharp Support
Because TensorFlowSharp support is still experimental, it is disabled by default. In order to enable it, you must follow these steps. Please note that the `Internal` Brain mode will only be available once these steps are completed.
1. Make sure you are using Unity 2017.1 or newer.
2. Make sure the TensorFlowSharp plugin is in your Asset folder. A Plugins folder which includes TF# can be downloaded [here](https://s3.amazonaws.com/unity-agents/TFSharpPlugin.unitypackage).
3. Go to `Edit` -> `Project Settings` -> `Player`
4. For each of the platforms you target (**`PC, Mac and Linux Standalone`**, **`iOS`** or **`Android`**):
1. Go into `Other Settings`.
2. Select `Scripting Runtime Version` to `Experimental (.NET 4.6 Equivalent)`
3. In `Scripting Defined Symbols`, add the flag `ENABLE_TENSORFLOW`
5. Restart the Unity Editor.
### Embedding the trained model into Unity
1. Run the final cell of the notebook under "Export the trained TensorFlow graph" to produce an `<env_name>.bytes` file.
2. Move `<env_name>.bytes` from `python/models/...` into `unity-environment/Assets/ML-Agents/Examples/3DBall/TFModels/`.
3. Open the Unity Editor, and select the `3DBall` scene as described above.
4. Select the `3DBallBrain` object from the Scene hierarchy.
5. Change the `Type of Brain` to `Internal`.
6. Drag the `<env_name>.bytes` file from the Project window of the Editor to the `Graph Model` placeholder in the `3DBallBrain` inspector window.
7. Set the `Graph Placeholder` size to 1.
8. Add a placeholder called `epsilon` with a type of `floating point` and a range of values from 0 to 0.
9. Press the Play button at the top of the editor.
If you followed these steps correctly, you should now see the trained model being used to control the behavior of the balance ball within the Editor itself. From here you can re-build the Unity binary, and run it standalone with your agent's new learned behavior built right in.

43
docs/Limitations-&-Common-Issues.md


# Limitations and Common Issues
## Unity SDK
### Headless Mode
Currently headless mode is disabled. We hope to address this in a future version of Unity.
### Rendering Speed and Synchronization
Currently the speed of the game physics can only be increased to 100x real-time. The Academy also moves in time with FixedUpdate() rather than Update(), so game behavior tied to frame updates may be out of sync.
### macOS Metal Support
When running a Unity Environment on macOS using Metal rendering, the application can crash when the lock-screen is open. The solution is to set rendering to OpenGL. This can be done by navigating to `Edit -> Project Settings -> Player`, clicking on `Other Settings`, unchecking `Auto Graphics API for Mac`, and setting `OpenGL Core` above `Metal` in the priority list.
## Python API
### Environment Permission Error
If you directly import your Unity environment without building it in the editor, you might need to give it additional permissions to execute it.
If you receive such a permission error on macOS, run:
`chmod -R 755 *.app`
or on Linux:
`chmod -R 755 *.x86_64`
On Windows, you can find instructions [here](https://technet.microsoft.com/en-us/library/cc754344(v=ws.11).aspx).
### Environment Connection Timeout
If you are able to launch the environment from `UnityEnvironment` but then receive a timeout error, there may be a number of possible causes.
* _Cause_: There may be no Brains in your environment which are set to `External`. In this case, the environment will not attempt to communicate with python. _Solution_: Set the Brain you wish to externally control through the Python API to `External` from the Unity Editor, and rebuild the environment.
* _Cause_: On OSX, the firewall may be preventing communication with the environment. _Solution_: Add the built environment binary to the list of exceptions on the firewall by following instructions [here](https://support.apple.com/en-us/HT201642).
### Filename not found
If you receive a file-not-found error while attempting to launch an environment, ensure that the environment files are in the root repository directory. For example, if there is a sub-folder containing the environment files, those files should be removed from the sub-folder and moved to the root.
### Communication port {} still in use
If you receive an exception `"Couldn't launch new environment because communication port {} is still in use. "`, you can change the worker number in the python script when calling
`UnityEnvironment(file_name=filename, worker_num=X)`
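For instance, a second environment instance could be opened on a different worker number so that it does not contend for the default port (the value `1` below is only an example):

```python
from unityagents import UnityEnvironment

# Pick a different worker number so this instance does not reuse the default
# communication port held by another (possibly stale) environment process.
env = UnityEnvironment(file_name="environment", worker_num=1)
```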

143
docs/Making-a-new-Unity-Environment.md


# Making a new Learning Environment
This tutorial walks through the process of creating a Unity Environment. A Unity Environment is an application built using the Unity Engine which can be used to train Reinforcement Learning agents.
## Setting up the Unity Project
1. Open an existing Unity project, or create a new one and import the RL interface package:
* [ML-Agents package without TensorflowSharp](https://s3.amazonaws.com/unity-agents/ML-AgentsNoPlugin.unitypackage)
* [ML-Agents package with TensorflowSharp](https://s3.amazonaws.com/unity-agents/ML-AgentsWithPlugin.unitypackage)
2. Rename `TemplateAcademy.cs` (and the contained class name) to the desired name of your new academy class. All Template files are in the folder `Assets -> Template -> Scripts`. Typical naming convention is `YourNameAcademy`.
3. Attach `YourNameAcademy.cs` to a new empty game object in the currently opened scene (`Unity` -> `GameObject` -> `Create Empty`) and rename this game object to `YourNameAcademy`. Since `YourNameAcademy` will be used to control all the environment logic, ensure the attached-to object is one which will remain in the scene regardless of the environment resetting, or other within-environment behavior.
4. Attach `Brain.cs` to a new empty game object and rename this game object to `YourNameBrain1`. Set this game object as a child of `YourNameAcademy` (Drag `YourNameBrain1` into `YourNameAcademy`). Note that you can have multiple brains in the Academy but they all must have different names.
5. Disable Window Resolution dialogue box and Splash Screen.
1. Go to `Edit` -> `Project Settings` -> `Player` -> `Resolution and Presentation`.
2. Set `Display Resolution Dialogue` to `Disabled`.
3. Check `Run In Background`.
4. Click `Splash Image`.
5. Uncheck `Show Splash Screen` _(Unity Pro only)_.
6. If you will be using Tensorflow Sharp in Unity, you must:
1. Make sure you are using Unity 2017.1 or newer.
2. Make sure the TensorflowSharp plugin is in your Asset folder. It can be downloaded [here](https://s3.amazonaws.com/unity-agents/TFSharpPlugin.unitypackage).
3. Go to `Edit` -> `Project Settings` -> `Player`
4. For each of the platforms you target (**`PC, Mac and Linux Standalone`**, **`iOS`** or **`Android`**):
1. Go into `Other Settings`.
2. Select `Scripting Runtime Version` to `Experimental (.NET 4.6 Equivalent)`
3. In `Scripting Defined Symbols`, add the flag `ENABLE_TENSORFLOW`
5. Note that some of these changes will require a Unity Restart
## Implementing `YourNameAcademy`
1. Click on the game object **`YourNameAcademy`**.
2. In the inspector tab, you can modify the characteristics of the academy:
* **`Max Steps`** Maximum length of each episode (set to 0 if you do not want the environment to reset after a certain time).
* **`Wait Time`** Real-time between steps when running environment in test-mode. Used only when steps happen in `Update()`.
* **`Frames To Skip`** Number of frames (or physics updates) to skip between steps. The agents will act at every frame but get new actions only at every step.
* **`Training Configuration`** and **`Inference Configuration`** The first defines the configuration of the Engine at training time and the second at test / inference time. The training mode corresponds only to external training when the reset parameter `train_model` was set to True. The adjustable parameters are as follows:
* `Width` and `Height` Correspond to the width and height in pixels of the window (must be both greater than 0). Typically set it to a small size during training, and a larger size for visualization during inference.
* `Quality Level` Determines how much rendering is performed. Typically set to a small value during training and a higher value for visualization during inference.
* `Time Scale` Physics speed. If the environment utilizes physics calculations, increase this during training, and set it to `1.0f` during inference. Otherwise, set it to `1.0f`.
* `Target Frame Rate` Frequency of frame rendering. If environment utilizes observations, increase this during training, and set to `60` during inference. If no observations are used, this can be set to `1` during training.
* **`Default Reset Parameters`** You can set the default configuration to be passed at reset. This will be a mapping from strings to float values that you can call in the academy with `resetParameters["YourDefaultParameter"]`
3. Within **`InitializeAcademy()`**, you can define the initialization of the Academy. Note that this command is run only once at the beginning of the training session.
4. Within **`AcademyStep()`**, you can define the environment logic each step. Use this function to modify the environment for the agents that will live in it.
5. Within **`AcademyReset()`**, you can reset the environment for a new episode. It should contain environment-specific code for setting up the environment. Note that `AcademyReset()` is called at the beginning of the training session to ensure the first episode is similar to the others.
## Implementing `YourNameBrain`
For each Brain game object in your academy :
1. Click on the game object `YourNameBrain`
2. In the inspector tab, you can modify the characteristics of the brain in **`Brain Parameters`**
* `State Size` Number of variables within the state provided to the agent(s).
* `Action Size` The number of possible actions for each individual agent to take.
* `Memory Size` The number of floats the agents will remember each step.
* `Camera Resolutions` A list of flexible length that contains resolution parameters : `height` and `width` define the dimensions of the camera outputs in pixels. Check `Black And White` if you want the camera outputs to be black and white.
* `Action Descriptions` A list describing in human-readable language the meaning of each available action.
* `State Space Type` and `Action Space Type`. Either `discrete` or `continuous`.
* `discrete` corresponds to describing the action space with an `int`.
* `continuous` corresponds to describing the action space with an array of `float`.
3. If you want to collect data on your play, you can check the box **`Collect Data`**. This way, the states and actions passing through the brain will be saved in the folder `saved_plays`.
4. You can choose what kind of brain you want `YourNameBrain` to be. There are four possibilities:
* `External` : You need at least one of your brains to be external if you wish to interact with your environment from python.
* `Player` : To control your agents manually. If the action space is discrete, you must map input keys to their corresponding integer values. If the action space is continuous, you must map input keys to their corresponding indices and float values.
* `Heuristic` : You can have your brain automatically react to the observations and states in a customizable way. You will need to drag a `Decision` script into `YourNameBrain`. To create a custom reaction, you must :
* Rename `TemplateDecision.cs` (and the contained class name) to the desired name of your new reaction. Typical naming convention is `YourNameDecision`.
* Implement `Decide`: Given the state, observation and memory of an agent, this function must return an array of floats corresponding to the actions taken by the agent. If the action space type is discrete, the array must be of size 1.
* Optionally, implement `MakeMemory`: Given the state, observation and memory of an agent, this function must return an array of floats corresponding to the new memories of the agent. (A sketch of such a `Decision` script is shown after this list.)
* `Internal` : Note that you must have Tensorflow Sharp setup (see top of this page). Here are the fields that must be completed:
* `Graph Model` : This must be the `bytes` file corresponding to the pretrained Tensorflow graph. (You must first drag this file into your Resources folder and then from the Resources folder into the inspector)
* `Graph Scope` : If you set a scope while training your tensorflow model, all your placeholder names will have a prefix. You must specify that prefix here.
* `Batch Size Node Name` : If the batch size is one of the inputs of your graph, you must specify the name of the placeholder here. The brain will automatically make the batch size equal to the number of agents connected to the brain.
* `State Node Name` : If your graph uses the state as an input, you must specify the name of the placeholder here.
* `Recurrent Input Node Name` : If your graph uses a recurrent input / memory as input and outputs new recurrent input / memory, you must specify the name of the input placeholder here.
* `Recurrent Output Node Name` : If your graph uses a recurrent input / memory as input and outputs new recurrent input / memory, you must specify the name of the output placeholder here.
* `Observation Placeholder Name` : If your graph uses observations as input, you must specify it here. Note that the number of observations is equal to the length of `Camera Resolutions` in the brain parameters.
* `Action Node Name` : Specify the name of the placeholder corresponding to the actions of the brain in your graph. If the action space type is continuous, the output must be a one dimensional tensor of floats of length `Action Space Size`; if the action space type is discrete, the output must be a one dimensional tensor of ints of length 1.
* `Graph Placeholder` : If your graph takes additional inputs that are fixed (example: noise level) you can specify them here. Note that in your graph, these must correspond to one dimensional tensors of int or float of size 1.
* `Name` : Corresponds to the name of the placeholder.
* `Value Type` : Either Integer or Floating Point.
* `Min Value` and `Max Value` : Specify the minimum and maximum values (inclusive) the placeholder can take. The value will be sampled from the uniform distribution at each step. If you want this value to be fixed, set both `Min Value` and `Max Value` to the same number.
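As a reference for the `Heuristic` brain type described above, here is a minimal sketch of a custom `Decision` script. The parameter lists follow `TemplateDecision.cs` (state, camera observations, reward, done flag, memory); check the template in your version of the SDK if they differ. The decision logic itself (push the agent back toward the origin based on the first state value) is purely illustrative:
```csharp
using System.Collections.Generic;
using UnityEngine;

public class YourNameDecision : MonoBehaviour, Decision
{
    public float[] Decide(List<float> state, List<Camera> observation,
                          float reward, bool done, float[] memory)
    {
        // Continuous action of size 1: move toward the origin depending on the first state value.
        return new float[] { state[0] > 0f ? -1f : 1f };
    }

    public float[] MakeMemory(List<float> state, List<Camera> observation,
                              float reward, bool done, float[] memory)
    {
        // This heuristic keeps no memory.
        return new float[0];
    }
}
```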
## Implementing `YourNameAgent`
1. Rename `TemplateAgent.cs` (and the contained class name) to the desired name of your new agent. Typical naming convention is `YourNameAgent`.
2. Attach `YourNameAgent.cs` to the game object that represents your agent. (Example: if you want to make a self driving car, attach `YourNameAgent.cs` to a car looking game object)
3. In the inspector menu of your agent, drag the brain game object you want to use with this agent into the corresponding `Brain` box. Please note that you can have multiple agents with the same brain. If you want to give an agent a brain or change its brain via script, please use the method `ChangeBrain()`.
4. In the inspector menu of your agent, you can specify which cameras your agent will use as its observations. To do so, drag the desired number of cameras into the `Observations` field. Note that if you want a camera to move along with your agent, you can set this camera as a child of your agent.
5. If `Reset On Done` is checked, `Reset()` will be called when the agent is done. Else, `AgentOnDone()` will be called. Note that if `Reset On Done` is unchecked, the agent will remain "done" until the Academy resets. This means that it will not take actions in the environment.
6. Implement the following functions in `YourNameAgent.cs` :
* `InitializeAgent()` : Use this method to initialize your agent. This method is called when the agent is created.
* `CollectState()` : Must return a list of floats corresponding to the state the agent is in. If the state space type is discrete, return a list of length 1 containing the float equivalent of your state.
* `AgentStep()` : This function will be called every frame; you must define what your agent will do given the input actions. You must also specify the rewards and whether or not the agent is done. To do so, modify the public fields of the agent, `reward` and `done` (see the sketch at the end of this section).
* `AgentReset()` : This function is called at start, when the Academy resets and when the agent is done (if `Reset On Done` is checked).
* `AgentOnDone()` : If `Reset On Done` is not checked, this function will be called when the agent is done. `Reset()` will only be called when the Academy resets.
If you create Agents via script, we recommend you save them as prefabs and instantiate them either during steps or resets. If you do, you can use `GiveBrain(brain)` to have the agent subscribe to a specific brain. You can also use `RemoveBrain()` to unsubscribe from a brain.
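To tie these functions together, here is a minimal sketch of an agent with a state of size 2 and a continuous action of size 1, following the override names above; the `target` field, the movement logic and the reward values are illustrative:
```csharp
using System.Collections.Generic;
using UnityEngine;

public class YourNameAgent : Agent
{
    // Hypothetical goal object, assigned in the inspector.
    public Transform target;

    public override void InitializeAgent()
    {
        // One-time setup when the agent is created.
    }

    public override List<float> CollectState()
    {
        // State of size 2: the agent's x position and its distance to the target.
        List<float> state = new List<float>();
        state.Add(transform.position.x);
        state.Add(Vector3.Distance(transform.position, target.position));
        return state;
    }

    public override void AgentStep(float[] act)
    {
        // Continuous action of size 1: move along x, then shape the reward.
        transform.position += new Vector3(act[0], 0f, 0f) * 0.1f;
        reward += -0.01f;  // small per-step penalty to encourage finishing quickly
        if (Vector3.Distance(transform.position, target.position) < 1f)
        {
            reward += 1f;  // reached the goal
            done = true;
        }
    }

    public override void AgentReset()
    {
        transform.position = Vector3.zero;
    }
}
```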
# Defining the reward function
The reward function defines the set of circumstances and events for which we want to reward or punish the agent. Here are some examples of positive and negative rewards:
* Positive
* Reaching a goal
* Staying alive
* Defeating an enemy
* Gaining health
* Finishing a level
* Negative
* Taking damage
* Failing a level
* The agent’s death
Small negative rewards are also typically used each step in scenarios where the optimal agent behavior is to complete an episode as quickly as possible.
Note that the reward is reset to 0 at every step; you must therefore add to the reward (`reward += rewardIncrement`) rather than set it. If you use `skipFrame` in the Academy and set your rewards instead of incrementing them, you might lose information, since the reward is sent at every step, not at every frame.
## Agent Monitor
* You can add the script `AgentMonitor.cs` to any gameObject with a component `YourNameAgent.cs`. In the inspector of this component, you will see:
* `Fixed Position` : If this box is checked, the monitor will be anchored to the left corner of the screen and will remain there. Note that you can only have one agent with a fixed monitor, otherwise multiple monitors will overlap.
* `Vertical Offset`: If `Fixed Position` is unchecked, the monitor will follow the Agent on the screen. Use `Vertical Offset` to decide how far above the agent the monitor should be.
* `Display Brain Name` : If this box is checked, the name of the brain will appear in the monitor. (Can be useful if you have similar agents using different brains).
* `Display Brain Type` : If this box is checked, the type of the brain of the agent will be displayed.
* `Display FrameCount` : If this box is checked, the number of frames that elapsed since the agent was reset will be displayed.
* `Display Current Reward`: If this box is checked, the current reward of the agent will be displayed.
* `Display Max Reward` : If this box is checked, the maximum reward obtained during this training session will be displayed.
* `Display State` : If this box is checked, the current state of the agent will be displayed.
* `Display Action` : If this box is checked, the current action the agent performs will be displayed.
If you passed a `value` from an external brain, the value will be displayed as a bar (green if value is positive / red if value is negative) above the monitor. The bar's maximum value is set to 1 by default but if the value of the agent is above this number, it becomes the new maximum.

33
docs/Organizing-the-Scene.md


# Organizing the Scene Layout
This tutorial will help you understand how to organize your scene when using Agents in your Unity environment.
## ML-Agents Game Objects
There are three kinds of game objects you need to include in your scene in order to use Unity ML-Agents:
* Academy
* Brain
* Agents
#### Keep in mind :
* There can only be one Academy game object in a scene.
* You can have multiple Brain game objects but they must be children of the Academy game object.
#### Here is an example of what your scene hierarchy should look like :
![Scene Hierarchy](../images/scene-hierarchy.png)
### Functionality
#### The Academy
The Academy is responsible for:
* Synchronizing the environment and keeping all agents' steps in sync. As such, there can only be one per scene.
* Determining the speed of the engine, its quality, and the display's resolution.
* Modifying the environment at every step and every reset according to the logic defined in `AcademyStep()` and `AcademyReset()`.
* Coordinating the Brains, which must be set as children of the Academy.
#### Brains
Each brain corresponds to a specific decision-making method. This often aligns with a specific neural network model. A Brain is responsible for deciding the actions of all the Agents which are linked to it. There can be multiple brains in the same scene and multiple agents can subscribe to the same brain.
#### Agents
Each agent within a scene takes actions according to the decisions provided by its linked Brain. There can be as many Agents of as many types as you like in the scene. The state size and action size of each agent must match the brain's parameters in order for the Brain to decide actions for it.

40
docs/Training-on-Amazon-Web-Service.md


# Training on Amazon Web Service
This page contains instructions for setting up an EC2 instance on Amazon Web Service for use in training ML-Agents environments. Current limitations of the Unity Engine require that a screen be available to render to. In order to make this possible when training on a remote server, a virtual screen is required. We can do this by installing Xorg and creating a virtual screen. Once installed and created, we can render the Unity environment to the virtual screen and train as we would on a local machine.
## Pre-Configured AMI
A public pre-configured AMI is available with the ID: `ami-30ec184a` in the `us-east-1` region. It was created as a modification of the Amazon Deep Learning [AMI](https://aws.amazon.com/marketplace/pp/B01M0AXXQB).
## Configuring your own Instance
Instructions here are adapted from this [Medium post](https://medium.com/towards-data-science/how-to-run-unity-on-amazon-cloud-or-without-monitor-3c10ce022639) on running general Unity applications in the cloud.
1. To begin with, you will need an EC2 instance which contains the latest Nvidia drivers, CUDA8, and cuDNN. There are a number of external tutorials which describe this, such as:
* [Getting CUDA 8 to Work With openAI Gym on AWS and Compiling Tensorflow for CUDA 8 Compatibility](https://davidsanwald.github.io/2016/11/13/building-tensorflow-with-gpu-support.html)
* [Installing TensorFlow on an AWS EC2 P2 GPU Instance](http://expressionflow.com/2016/10/09/installing-tensorflow-on-an-aws-ec2-p2-gpu-instance/)
* [Updating Nvidia CUDA to 8.0.x in Ubuntu 16.04 – EC2 Gx instance](https://aichamp.wordpress.com/2016/11/09/updating-nvidia-cuda-to-8-0-x-in-ubuntu-16-04-ec2-gx-instance/)
2. Move the `python` directory to the remote instance.
3. Install the required packages with `pip install .`.
4. Run the following commands to install Xorg:
```
sudo apt-get update
sudo apt-get install -y xserver-xorg mesa-utils
sudo nvidia-xconfig -a --use-display-device=None --virtual=1280x1024
```
5. Restart the EC2 instance.
6. On start-up, run:
```
sudo /usr/bin/X :0 &
export DISPLAY=:0
```
Depending on how Xorg is configured, you may need to run `sudo killall Xorg` before starting Xorg with the above command.
7. To ensure the installation was successful, run `glxgears`. If there are no errors, then Xorg is correctly configured.
8. There is a bug in _Unity 2017.1_ which requires the uninstallation of `libxrandr2`, which can be removed with `apt-get remove --purge libxrandr2`. This is scheduled to be fixed in 2017.3.
If all steps worked correctly, upload an example binary built for Linux to the instance, and test it from python with:
```python
from unityagents import UnityEnvironment
env = UnityEnvironment(your_env)
```
You should receive a message confirming that the environment was loaded successfully.

44
docs/Unity-Agents---Python-API.md


# Python API
_Notice: Currently communication between Unity and Python takes place over an open socket without authentication. As such, please make sure that the network where training takes place is secure. This will be addressed in a future release._
## Loading a Unity Environment
Python-side communication happens through `UnityEnvironment` which is located in `python/unityagents`. To load a Unity environment from a built binary file, put the file in the same directory as `unityagents`. In python, run:
```python
from unityagents import UnityEnvironment
env = UnityEnvironment(file_name=filename, worker_num=0)
```
* `file_name` is the name of the environment binary (located in the root directory of the python project).
* `worker_num` indicates which port to use for communication with the environment. For use in parallel training regimes such as A3C.
## Interacting with a Unity Environment
A BrainInfo object contains the following fields:
* **`observations`** : A list of 4 dimensional numpy arrays. Element n of the list corresponds to the n<sup>th</sup> observation of the brain.
* **`states`** : A two dimensional numpy array of dimension `(batch size, state size)` if the state space is continuous and `(batch size, 1)` if the state space is discrete.
* **`memories`** : A two dimensional numpy array of dimension `(batch size, memory size)` which corresponds to the memories sent at the previous step.
* **`rewards`** : A list as long as the number of agents using the brain containing the rewards they each obtained at the previous step.
* **`local_done`** : A list as long as the number of agents using the brain containing `done` flags (whether or not the agent is done).
* **`agents`** : A list of the unique ids of the agents using the brain.
Once loaded, `env` can be used in the following way:
- **Print : `print(str(env))`**
Prints all parameters relevant to the loaded environment and the external brains.
- **Reset : `env.reset(train_model=True, config=None)`**
Sends a reset signal to the environment and returns a dictionary mapping brain names to BrainInfo objects (see the usage sketch at the end of this page).
- `train_model` indicates whether to run the environment in train (`True`) or test (`False`) mode.
- `config` is an optional dictionary of configuration flags specific to the environment. For more information on adding optional config flags to an environment, see [here](Making-a-new-Unity-Environment.md#implementing-yournameacademy). For generic environments, `config` can be ignored. `config` is a dictionary of strings to floats where the keys are the names of the `resetParameters` and the values are their corresponding float values.
- **Step : `env.step(action, memory=None, value = None)`**
Sends a step signal to the environment using the actions. Note that if you have more than one brain in the environment, you must provide a dictionary from brain names to actions.
- `action` can be a one dimensional array, or a two dimensional array if you have multiple agents per brain.
- `memory` is an optional input that can be used to send a list of floats per agent, to be retrieved at the next step.
- `value` is an optional input that can be used to send a single float per agent, to be displayed if an `AgentMonitor.cs` component is attached to the agent.
Returns a dictionary mapping brain names to BrainInfo objects.
- **Close : `env.close()`**
Sends a shutdown signal to the environment and closes the communication socket.

43
docs/Unity-Agents-Overview.md


# Learning Environments Overview
![diagram](../images/agents_diagram.png)
A visual depiction of how a Learning Environment might be configured within ML-Agents.
The three main kinds of objects within any Agents Learning Environment are:
* Agent - Each Agent can have a unique set of states and observations, take unique actions within the environment, and can receive unique rewards for events within the environment. An agent's actions are decided by the brain it is linked to.
* Brain - Each Brain defines a specific state and action space, and is responsible for deciding which actions each of its linked agents will take. Brains can be set to one of four modes:
* External - Action decisions are made using TensorFlow (or your ML library of choice) through communication over an open socket with our Python API.
* Internal (Experimental) - Action decisions are made using a trained model embedded into the project via TensorFlowSharp.
* Player - Action decisions are made using player input.
* Heuristic - Action decisions are made using hand-coded behavior.
* Academy - The Academy object within a scene also contains as children all Brains within the environment. Each environment contains a single Academy which defines the scope of the environment, in terms of:
* Engine Configuration - The speed and rendering quality of the game engine in both training and inference modes.
* Frameskip - How many engine steps to skip between each agent making a new decision.
* Global episode length - How long the episode will last. When reached, all agents are set to done.
The states and observations of all agents with brains set to External are collected by the External Communicator, and communicated via the Python API. By setting multiple agents to a single brain, actions can be decided in a batch fashion, taking advantage of the inherently parallel computations of neural networks. For more information on how these objects work together within a scene, see the rest of this documentation.
## Flexible Training Scenarios
With Unity ML-Agents, a variety of different kinds of training scenarios are possible, depending on how agents, brains, and rewards are connected. We are excited to see what kinds of novel and fun environments the community creates. For those new to training intelligent agents, below are a few examples that can serve as inspiration. Each is a prototypical environment configuration with a description of how it can be created using the ML-Agents SDK.
* **Single-Agent** - A single agent linked to a single brain. The traditional way of training an agent. An example is any single-player game, such as Chicken. [Video Link](https://www.youtube.com/watch?v=fiQsmdwEGT8&feature=youtu.be).
* **Simultaneous Single-Agent** - Multiple independent agents with independent reward functions linked to a single brain. A parallelized version of the traditional training scenario, which can speed-up and stabilize the training process. An example might be training a dozen robot-arms to each open a door simultaneously. [Video Link](https://www.youtube.com/watch?v=fq0JBaiCYNA).
* **Adversarial Self-Play** - Two interacting agents with inverse reward functions linked to a single brain. In two-player games, adversarial self-play can allow an agent to become increasingly more skilled, while always having the perfectly matched opponent: itself. This was the strategy employed when training AlphaGo, and more recently used by OpenAI to train a human-beating 1v1 Dota 2 agent.
* **Cooperative Multi-Agent** - Multiple interacting agents with a shared reward function linked to either a single or multiple different brains. In this scenario, all agents must work together to accomplish a task that couldn't be done alone. Examples include environments where each agent only has access to partial information, which needs to be shared in order to accomplish the task or collaboratively solve a puzzle. (Demo project coming soon)
* **Competitive Multi-Agent** - Multiple interacting agents with inverse reward function linked to either a single or multiple different brains. In this scenario, agents must compete with one another to either win a competition, or obtain some limited set of resources. All team sports would fall into this scenario. (Demo project coming soon)
* **Ecosystem** - Multiple interacting agents with independent reward functions linked to either a single or multiple different brains. This scenario can be thought of as creating a small world in which animals with different goals all interact, such as a savanna in which there might be zebras, elephants, and giraffes, or an autonomous driving simulation within an urban environment. (Demo project coming soon)
## Additional Features
Beyond the flexible training scenarios made possible by the Academy/Brain/Agent system, ML-Agents also includes other features which improve the flexibility and interpretability of the training process.
* **Monitoring Agent’s Decision Making** - Since communication in ML-Agents is a two-way street, we provide an Agent Monitor class in Unity which can display aspects of the trained agent, such as policy and value output within the Unity environment itself. By providing these outputs in real-time, researchers and developers can more easily debug an agent’s behavior.
* **Curriculum Learning** - It is often difficult for agents to learn a complex task at the beginning of the training process. Curriculum learning is the process of gradually increasing the difficulty of a task to allow more efficient learning. ML-Agents supports setting custom environment parameters every time the environment is reset. This allows elements of the environment related to difficulty or complexity to be dynamically adjusted based on training progress.
* **Complex Visual Observations** - Unlike other platforms, where the agent’s observation might be limited to a single vector or image, ML-Agents allows multiple cameras to be used for observations per agent. This enables agents to learn to integrate information from multiple visual streams, as would be the case when training a self-driving car which required multiple cameras with different viewpoints, a navigational agent which might need to integrate aerial and first-person visuals, or an agent which takes both a raw visual input, as well as a depth-map or object-segmented image.
* **Imitation Learning (Coming Soon)** - It is often more intuitive to simply demonstrate the behavior we want an agent to perform, rather than attempting to have it learn via trial-and-error methods. In a future release, ML-Agents will provide the ability to record all state/action/reward information for use in supervised learning scenarios, such as imitation learning. By utilizing imitation learning, a player can provide demonstrations of how an agent should behave in an environment, and then utilize those demonstrations to train an agent in either a standalone fashion, or as a first-step in a reinforcement learning process.

112
docs/Using-TensorFlow-Sharp-in-Unity-(Experimental).md


# Using TensorFlowSharp in Unity (Experimental)
Unity now offers the possibility to use pretrained TensorFlow graphs inside of the game engine. This was made possible thanks to [the TensorFlowSharp project](https://github.com/migueldeicaza/TensorFlowSharp).
_Notice: This feature is still experimental. While it is possible to embed trained models into Unity games, Unity Technologies does not officially support this use-case for production games at this time. As such, no guarantees are provided regarding the quality of experience. If you encounter issues regarding battery life, or general performance (especially on mobile), please let us know._
## Supported devices :
* Linux 64 bits
* Mac OSX 64 bits
* Windows 64 bits
* iOS (Requires additional steps)
* Android
## Requirements
* Unity 2017.1 or above
* Unity Tensorflow Plugin ([Download here](https://s3.amazonaws.com/unity-agents/TFSharpPlugin.unitypackage))
# Using TensorflowSharp with ML-Agents
In order to bring a fully trained agent back into Unity, you will need to make sure the nodes of your graph have appropriate names. You can give names to nodes in Tensorflow :
```python
variable = tf.identity(variable, name="variable_name")
```
We recommend using the following naming convention:
* Name the batch size input placeholder `batch_size`
* Name the input state placeholder `state`
* Name the output node `action`
* Name the recurrent vector (memory) input placeholder `recurrent_in` (if any)
* Name the recurrent vector (memory) output node `recurrent_out` (if any)
* Name the observation input placeholders `observation_i` where `i` is the index of the observation (starting at 0)
You can have additional placeholders for floats or integers, but they must be placed in placeholders of dimension 1 and size 1. (Be sure to name them.)
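As an illustration of this convention, here is a sketch with a trivial one-layer network; the state size of 8 and the action size of 2 are placeholders for your own model:
```python
import tensorflow as tf

state = tf.placeholder(tf.float32, shape=[None, 8], name="state")   # input placeholder
hidden = tf.layers.dense(state, 32, activation=tf.nn.tanh)
output = tf.layers.dense(hidden, 2, activation=None)
action = tf.identity(output, name="action")                         # named output node
```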
It is important that the inputs and outputs of the graph are exactly the ones you receive and give when training your model with an `External` brain. This means you cannot have any operations, such as reshaping, outside of the graph.
The object you get by calling `step` or `reset` has fields `states`, `observations` and `memories` which must correspond to the placeholders of your graph. Similarly, the arguments `action` and `memory` you pass to `step` must correspond to the output nodes of your graph.
While training your Agent using the Python API, you can save your graph at any point of the training. Note that the argument `output_node_names` must be the name of the tensor your graph outputs (separated by a comma if there are multiple outputs). In this case, it will be either `action` or `action,recurrent_out` if you have recurrent outputs.
```python
from tensorflow.python.tools import freeze_graph
freeze_graph.freeze_graph(input_graph=model_path + '/raw_graph_def.pb',
                          input_binary=True,
                          input_checkpoint=last_checkpoint,
                          output_node_names="action",
                          output_graph=model_path + '/your_name_graph.bytes',
                          clear_devices=True, initializer_nodes="", input_saver="",
                          restore_op_name="save/restore_all",
                          filename_tensor_name="save/Const:0")
```
Your model will be saved with the name `your_name_graph.bytes` and will contain both the graph and associated weights. Note that you must save your graph as a bytes file so Unity can load it.
## Inside Unity
Go to `Edit` -> `Player Settings` and add `ENABLE_TENSORFLOW` to the `Scripting Define Symbols` for each type of device you want to use (**`PC, Mac and Linux Standalone`**, **`iOS`** or **`Android`**).
Set the Brain you used for training to `Internal`. Drag `your_name_graph.bytes` into Unity and then drag it into the `Graph Model` field in the Brain. If you used a scope when training your graph, specify it in the `Graph Scope` field. Specify the names of the nodes you used in your graph. If you followed these instructions, the agents in your environment that use this brain will use your fully trained network to make decisions.
# iOS additional instructions for building
* Once you build for iOS in the editor, Xcode will launch.
* In `General` -> `Linked Frameworks and Libraries`:
* Add a framework called `Accelerate.framework`
* Remove the library `libtensorflow-core.a`
* In `Build Settings`->`Linking`->`Other Linker Flags`:
* Double Click on the flag list
* Type `-force_load`
* Drag the library `libtensorflow-core.a` from the `Project Navigator` on the left under `Libraries/ML-Agents/Plugins/iOS` into the flag list.
# Using TensorflowSharp without ML-Agents
Beyond controlling an in-game agent, you may desire to use TensorFlowSharp for more general computation. The below instructions describe how to generally embed Tensorflow models without using the ML-Agents framework.
You must have a Tensorflow graph, such as `your_name_graph.bytes`, made using Tensorflow's `freeze_graph.py`. The process to create such a graph is explained above.
## Inside of Unity
Put the file `your_name_graph.bytes` into Resources.
In your C# script :
At the top, add the line
```csharp
using TensorFlow;
```
If you will be building for Android, you must add this block at the start of your code :
```csharp
#if UNITY_ANDROID
TensorFlowSharp.Android.NativeBinding.Init();
#endif
```
Put your graph as a text asset in the variable `graphModel`. You can do so in the inspector by making `graphModel` a public variable and dragging your asset into the inspector, or load it from the Resources folder :
```csharp
TextAsset graphModel = Resources.Load ("your_name_graph") as TextAsset;  // the asset name is passed as a string, without extension
```
You then must recreate the graph in Unity by adding the code :
```csharp
TFGraph graph = new TFGraph ();             // create an empty graph
graph.Import (graphModel.bytes);            // load the frozen graph bytes into it
TFSession session = new TFSession (graph);  // create a session to run the graph
```
Your newly created graph needs to be given input tensors. Here is an example with a one dimensional tensor of size 2:
```csharp
var runner = session.GetRunner ();
runner.AddInput (graph ["input_placeholder_name"] [0], new float[]{ placeholder_value1, placeholder_value2 });
```
You need to give all required inputs to the graph. There is one input per TensorFlow placeholder.
To retrieve the output of your graph run the following code. Note that this is for an output that would be a two dimensional tensor of floats. Cast to a long array if your outputs are integers.
```csharp
runner.Fetch (graph["output_placeholder_name"][0]);
float[,] recurrent_tensor = runner.Run () [0].GetValue () as float[,];
```

19
docs/Readme.md


# Unity ML Agents Documentation
## Basic
* [Unity ML Agents Overview](Unity-Agents-Overview.md)
* [Installation & Set-up](installation.md)
* [Getting Started with the Balance Ball Environment](Getting-Started-with-Balance-Ball.md)
* [Example Environments](Example-Environments.md)
## Advanced
* [How to make a new Unity Environment](Making-a-new-Unity-Environment.md)
* [Best practices when designing an Environment](best-practices.md)
* [How to organize the Scene](Organizing-the-Scene.md)
* [How to use the Python API](Unity-Agents---Python-API.md)
* [How to use TensorflowSharp inside Unity [Experimental]](Using-TensorFlow-Sharp-in-Unity-(Experimental).md)
* [Agents SDK Inspector Descriptions](Agents-Editor-Interface.md)
* [Training on the Cloud with Amazon Web Services](Training-on-Amazon-Web-Service.md)
## Help
* [Limitations & Common Issues](Limitations-&-Common-Issues.md)

20
docs/best-practices.md


# Environment Design Best Practices
## General
* It is often helpful to begin with the simplest version of the problem, to ensure the agent can learn it. From there, increase
complexity over time.
* When possible, it is often helpful to ensure that you can complete the task by using a Player Brain to control the agent.
## Rewards
* The magnitude of any given reward should typically not be greater than 1.0 in order to ensure a more stable learning process.
* Positive rewards are often more helpful to shaping the desired behavior of an agent than negative rewards.
* For locomotion tasks, a small positive reward (+0.1) for forward progress is typically used.
* If you want the agent to finish a task quickly, it is often helpful to provide a small penalty every step (-0.1).
## States
* The magnitude of each state variable should be normalized to around 1.0.
* States should include all variables relevant to allowing the agent to take the optimally informed decision.
* Categorical state variables such as type of object (Sword, Shield, Bow) should be encoded in one-hot fashion (i.e. `3` -> `0, 0, 1`), as in the sketch below.
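For example, here is a small illustrative helper for such one-hot encoding; the class and method names are hypothetical, not part of the SDK:
```csharp
using System.Collections.Generic;

public static class StateEncoding
{
    // One-hot encode a categorical value, e.g. 0 = Sword, 1 = Shield, 2 = Bow.
    public static List<float> OneHot(int category, int numCategories)
    {
        List<float> encoded = new List<float>();
        for (int i = 0; i < numCategories; i++)
        {
            encoded.Add(i == category ? 1.0f : 0.0f);
        }
        return encoded;
    }
}
```
The returned list can be appended to the state list built in `CollectState()`.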
## Actions
* When using continuous control, action values should be clipped to an appropriate range (see the sketch below).
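A minimal illustrative helper for such clipping; the class and method names are hypothetical:
```csharp
using UnityEngine;

public static class ActionUtils
{
    // Clamp each continuous action value to the expected range before applying it.
    public static float[] ClipActions(float[] act, float min, float max)
    {
        for (int i = 0; i < act.Length; i++)
        {
            act[i] = Mathf.Clamp(act[i], min, max);
        }
        return act;
    }
}
```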

51
docs/installation.md


# Installation & Set-up
## Install **Unity 2017.1** or later (required)
Download link available [here](https://store.unity.com/download?ref=update).
## Clone the repository
Once installed, you will want to clone the Agents GitHub repository. References will be made
throughout to `unity-environment` and `python` directories. Both are located at the root of the repository.
## Installing Python API
In order to train an agent within the framework, you will need to install Python 2 or 3, and the dependencies described below.
### Windows Users
If you are a Windows user who is new to Python/TensorFlow, follow [this guide](https://nitishmutha.github.io/tensorflow/2017/01/22/TensorFlow-with-gpu-for-windows.html) to set up your Python environment.
### Requirements
* Jupyter
* Matplotlib
* numpy
* Pillow
* Python (2 or 3)
* docopt (Training)
* TensorFlow (1.0+) (Training)
### Installing Dependencies
To install dependencies, go into the `python` directory and run (depending on your python version):
`pip install .`
or
`pip3 install .`
If your Python environment doesn't include `pip`, see these [instructions](https://packaging.python.org/guides/installing-using-linux-tools/#installing-pip-setuptools-wheel-with-linux-package-managers) on installing it.
Once the requirements are successfully installed, the next step is to check out the [Getting Started guide](Getting-Started-with-Balance-Ball.md).
## Installation Help
### Using Jupyter Notebook
For a walkthrough of how to use Jupyter notebook, see [here](http://jupyter-notebook-beginner-guide.readthedocs.io/en/latest/execute.html).
### General Issues
If you run into issues while attempting to install and run Unity ML Agents, see [here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Limitations-&-Common-Issues.md) for a list of common issues and solutions.
If you have an issue that isn't covered here, feel free to contact us at ml-agents@unity3d.com. Alternatively, you can create an issue on the repository.
Be sure to include relevant information on OS, Python version, and exact error message if possible.

354
unity-environment/Assets/ML-Agents/Examples/3DBall/Prefabs/Game.prefab


%YAML 1.1
%TAG !u! tag:unity3d.com,2011:
--- !u!1001 &100100000
Prefab:
m_ObjectHideFlags: 1
serializedVersion: 2
m_Modification:
m_TransformParent: {fileID: 0}
m_Modifications: []
m_RemovedComponents: []
m_ParentPrefab: {fileID: 0}
m_RootGameObject: {fileID: 1665577603478558}
m_IsPrefabParent: 1
--- !u!1 &1536511242562482
GameObject:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
serializedVersion: 5
m_Component:
- component: {fileID: 4095153324598508}
- component: {fileID: 33624138613246242}
- component: {fileID: 135537954369514846}
- component: {fileID: 23926879167055144}
- component: {fileID: 54606488393645454}
m_Layer: 0
m_Name: Ball
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!1 &1665577603478558
GameObject:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
serializedVersion: 5
m_Component:
- component: {fileID: 4162486845013972}
m_Layer: 0
m_Name: Game
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!1 &1796982831911906
GameObject:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
serializedVersion: 5
m_Component:
- component: {fileID: 4426913383042176}
- component: {fileID: 33771194379694972}
- component: {fileID: 64845544452232838}
- component: {fileID: 23074177913792258}
m_Layer: 0
m_Name: Plane (1)
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!1 &1914042422505674
GameObject:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
serializedVersion: 5
m_Component:
- component: {fileID: 4014323973652284}
- component: {fileID: 33653230733766482}
- component: {fileID: 65551894134645910}
- component: {fileID: 23487775825466554}
- component: {fileID: 114980646877373948}
- component: {fileID: 114290313258162170}
m_Layer: 0
m_Name: Platform
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!4 &4014323973652284
Transform:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
m_GameObject: {fileID: 1914042422505674}
m_LocalRotation: {x: -0.069583125, y: 0.0049145464, z: 0.0702813, w: 0.99508524}
m_LocalPosition: {x: 0, y: 2.22, z: -5}
m_LocalScale: {x: 5, y: 0.19999997, z: 5}
m_Children:
- {fileID: 4426913383042176}
m_Father: {fileID: 4162486845013972}
m_RootOrder: 0
m_LocalEulerAnglesHint: {x: -8, y: 0, z: 8.08}
--- !u!4 &4095153324598508
Transform:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
m_GameObject: {fileID: 1536511242562482}
m_LocalRotation: {x: -0, y: -0, z: -0, w: 1}
m_LocalPosition: {x: 0, y: 6.2200003, z: -5}
m_LocalScale: {x: 1, y: 1, z: 1}
m_Children: []
m_Father: {fileID: 4162486845013972}
m_RootOrder: 1
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!4 &4162486845013972
Transform:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
m_GameObject: {fileID: 1665577603478558}
m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
m_LocalPosition: {x: -10.3, y: 9, z: 5}
m_LocalScale: {x: 1, y: 1, z: 1}
m_Children:
- {fileID: 4014323973652284}
- {fileID: 4095153324598508}
m_Father: {fileID: 0}
m_RootOrder: 0
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!4 &4426913383042176
Transform:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
m_GameObject: {fileID: 1796982831911906}
m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
m_LocalPosition: {x: 0, y: 0.51, z: 0}
m_LocalScale: {x: 0.1, y: 0.1, z: 0.1}
m_Children: []
m_Father: {fileID: 4014323973652284}
m_RootOrder: 0
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!23 &23074177913792258
MeshRenderer:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
m_GameObject: {fileID: 1796982831911906}
m_Enabled: 1
m_CastShadows: 0
m_ReceiveShadows: 0
m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_Materials:
- {fileID: 2100000, guid: e35c6159207d7448e988c8cf0c137ab6, type: 2}
m_StaticBatchInfo:
firstSubMesh: 0
subMeshCount: 0
m_StaticBatchRoot: {fileID: 0}
m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0
m_SelectedEditorRenderState: 3
m_MinimumChartSize: 4
m_AutoUVMaxDistance: 0.5
m_AutoUVMaxAngle: 89
m_LightmapParameters: {fileID: 0}
m_SortingLayerID: 0
m_SortingLayer: 0
m_SortingOrder: 0
--- !u!23 &23487775825466554
MeshRenderer:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
m_GameObject: {fileID: 1914042422505674}
m_Enabled: 1
m_CastShadows: 1
m_ReceiveShadows: 1
m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_Materials:
- {fileID: 2100000, guid: f8f8a9c03cd1f4fdbb7a3e95be9ea341, type: 2}
m_StaticBatchInfo:
firstSubMesh: 0
subMeshCount: 0
m_StaticBatchRoot: {fileID: 0}
m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0
m_SelectedEditorRenderState: 3
m_MinimumChartSize: 4
m_AutoUVMaxDistance: 0.5
m_AutoUVMaxAngle: 89
m_LightmapParameters: {fileID: 0}
m_SortingLayerID: 0
m_SortingLayer: 0
m_SortingOrder: 0
--- !u!23 &23926879167055144
MeshRenderer:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
m_GameObject: {fileID: 1536511242562482}
m_Enabled: 1
m_CastShadows: 1
m_ReceiveShadows: 1
m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_Materials:
- {fileID: 2100000, guid: edd958d75ed1448138de86f3335ea4fa, type: 2}
m_StaticBatchInfo:
firstSubMesh: 0
subMeshCount: 0
m_StaticBatchRoot: {fileID: 0}
m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0
m_SelectedEditorRenderState: 3
m_MinimumChartSize: 4
m_AutoUVMaxDistance: 0.5
m_AutoUVMaxAngle: 89
m_LightmapParameters: {fileID: 0}
m_SortingLayerID: 0
m_SortingLayer: 0
m_SortingOrder: 0
--- !u!33 &33624138613246242
MeshFilter:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
m_GameObject: {fileID: 1536511242562482}
m_Mesh: {fileID: 10207, guid: 0000000000000000e000000000000000, type: 0}
--- !u!33 &33653230733766482
MeshFilter:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
m_GameObject: {fileID: 1914042422505674}
m_Mesh: {fileID: 10202, guid: 0000000000000000e000000000000000, type: 0}
--- !u!33 &33771194379694972
MeshFilter:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
m_GameObject: {fileID: 1796982831911906}
m_Mesh: {fileID: 10209, guid: 0000000000000000e000000000000000, type: 0}
--- !u!54 &54606488393645454
Rigidbody:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
m_GameObject: {fileID: 1536511242562482}
serializedVersion: 2
m_Mass: 1
m_Drag: 0
m_AngularDrag: 0.01
m_UseGravity: 1
m_IsKinematic: 0
m_Interpolate: 0
m_Constraints: 0
m_CollisionDetection: 0
--- !u!64 &64845544452232838
MeshCollider:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
m_GameObject: {fileID: 1796982831911906}
m_Material: {fileID: 0}
m_IsTrigger: 0
m_Enabled: 0
serializedVersion: 2
m_Convex: 0
m_InflateMesh: 0
m_SkinWidth: 0.01
m_Mesh: {fileID: 10209, guid: 0000000000000000e000000000000000, type: 0}
--- !u!65 &65551894134645910
BoxCollider:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
m_GameObject: {fileID: 1914042422505674}
m_Material: {fileID: 0}
m_IsTrigger: 0
m_Enabled: 1
serializedVersion: 2
m_Size: {x: 1, y: 1, z: 1}
m_Center: {x: 0, y: 0, z: 0}
--- !u!114 &114290313258162170
MonoBehaviour:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
m_GameObject: {fileID: 1914042422505674}
m_Enabled: 0
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: e040eaa8759024abbbb14994dc4c55ee, type: 3}
m_Name:
m_EditorClassIdentifier:
fixedPosition: 1
verticalOffset: 10
DisplayBrainName: 1
DisplayBrainType: 1
DisplayFrameCount: 1
DisplayCurrentReward: 0
DisplayMaxReward: 1
DisplayState: 0
DisplayAction: 0
--- !u!114 &114980646877373948
MonoBehaviour:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
m_GameObject: {fileID: 1914042422505674}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: aaba48bf82bee4751aa7b89569e57f73, type: 3}
m_Name:
m_EditorClassIdentifier:
brain: {fileID: 0}
observations: []
maxStep: 5000
resetOnDone: 1
reward: 0
done: 0
value: 0
CummulativeReward: 0
stepCounter: 0
agentStoredAction: []
memory: []
id: 0
ball: {fileID: 1536511242562482}
--- !u!135 &135537954369514846
SphereCollider:
m_ObjectHideFlags: 1
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 100100000}
m_GameObject: {fileID: 1536511242562482}
m_Material: {fileID: 13400000, guid: 56162663048874fd4b10e065f9cf78b7, type: 2}
m_IsTrigger: 0
m_Enabled: 1
serializedVersion: 2
m_Radius: 0.5
m_Center: {x: 0, y: 0, z: 0}

9
unity-environment/Assets/ML-Agents/Examples/3DBall/Prefabs/Game.prefab.meta


fileFormatVersion: 2
guid: ff026d63a00abdc48ad6ddcff89aba04
timeCreated: 1506066551
licenseType: Free
NativeFormatImporter:
mainObjectFileID: 100100000
userData:
assetBundleName:
assetBundleVariant:

9
unity-environment/Assets/ML-Agents/Examples/Basic/Materials.meta


fileFormatVersion: 2
guid: 0f9b2a7b3f61045b8a791eeae8175dc5
folderAsset: yes
timeCreated: 1506189694
licenseType: Pro
DefaultImporter:
userData:
assetBundleName:
assetBundleVariant:

76
unity-environment/Assets/ML-Agents/Examples/Basic/Materials/agent.mat


%YAML 1.1
%TAG !u! tag:unity3d.com,2011:
--- !u!21 &2100000
Material:
serializedVersion: 6
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_Name: agent
m_Shader: {fileID: 46, guid: 0000000000000000f000000000000000, type: 0}
m_ShaderKeywords:
m_LightmapFlags: 4
m_EnableInstancingVariants: 0
m_DoubleSidedGI: 0
m_CustomRenderQueue: -1
stringTagMap: {}
disabledShaderPasses: []
m_SavedProperties:
serializedVersion: 3
m_TexEnvs:
- _BumpMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailAlbedoMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailMask:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailNormalMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _EmissionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MainTex:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MetallicGlossMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _OcclusionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _ParallaxMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
m_Floats:
- _BumpScale: 1
- _Cutoff: 0.5
- _DetailNormalMapScale: 1
- _DstBlend: 0
- _GlossMapScale: 1
- _Glossiness: 0.5
- _GlossyReflections: 1
- _Metallic: 0
- _Mode: 0
- _OcclusionStrength: 1
- _Parallax: 0.02
- _SmoothnessTextureChannel: 0
- _SpecularHighlights: 1
- _SrcBlend: 1
- _UVSec: 0
- _ZWrite: 1
m_Colors:
- _Color: {r: 0.10980392, g: 0.6039216, b: 1, a: 0.8392157}
- _EmissionColor: {r: 0, g: 0, b: 0, a: 1}

9
unity-environment/Assets/ML-Agents/Examples/Basic/Materials/agent.mat.meta


fileFormatVersion: 2
guid: 260483cdfc6b14e26823a02f23bd8baa
timeCreated: 1506189720
licenseType: Pro
NativeFormatImporter:
mainObjectFileID: 2100000
userData:
assetBundleName:
assetBundleVariant:

76
unity-environment/Assets/ML-Agents/Examples/Basic/Materials/goal.mat


%YAML 1.1
%TAG !u! tag:unity3d.com,2011:
--- !u!21 &2100000
Material:
serializedVersion: 6
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_Name: goal
m_Shader: {fileID: 46, guid: 0000000000000000f000000000000000, type: 0}
m_ShaderKeywords:
m_LightmapFlags: 4
m_EnableInstancingVariants: 0
m_DoubleSidedGI: 0
m_CustomRenderQueue: -1
stringTagMap: {}
disabledShaderPasses: []
m_SavedProperties:
serializedVersion: 3
m_TexEnvs:
- _BumpMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailAlbedoMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailMask:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _DetailNormalMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _EmissionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MainTex:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _MetallicGlossMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _OcclusionMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
- _ParallaxMap:
m_Texture: {fileID: 0}
m_Scale: {x: 1, y: 1}
m_Offset: {x: 0, y: 0}
m_Floats:
- _BumpScale: 1
- _Cutoff: 0.5
- _DetailNormalMapScale: 1
- _DstBlend: 0
- _GlossMapScale: 1
- _Glossiness: 0.5
- _GlossyReflections: 1
- _Metallic: 0
- _Mode: 0
- _OcclusionStrength: 1
- _Parallax: 0.02
- _SmoothnessTextureChannel: 0
- _SpecularHighlights: 1
- _SrcBlend: 1
- _UVSec: 0
- _ZWrite: 1
m_Colors:
- _Color: {r: 0.5058824, g: 0.74509805, b: 0.25490198, a: 1}
- _EmissionColor: {r: 0, g: 0, b: 0, a: 1}

9
unity-environment/Assets/ML-Agents/Examples/Basic/Materials/goal.mat.meta


fileFormatVersion: 2
guid: 624b24bbec31f44babfb57ef2dfbc537
timeCreated: 1506189863
licenseType: Pro
NativeFormatImporter:
mainObjectFileID: 2100000
userData:
assetBundleName:
assetBundleVariant:

702
unity-environment/Assets/ML-Agents/Examples/Basic/Scene.unity


%YAML 1.1
%TAG !u! tag:unity3d.com,2011:
--- !u!29 &1
OcclusionCullingSettings:
m_ObjectHideFlags: 0
serializedVersion: 2
m_OcclusionBakeSettings:
smallestOccluder: 5
smallestHole: 0.25
backfaceThreshold: 100
m_SceneGUID: 00000000000000000000000000000000
m_OcclusionCullingData: {fileID: 0}
--- !u!104 &2
RenderSettings:
m_ObjectHideFlags: 0
serializedVersion: 8
m_Fog: 0
m_FogColor: {r: 0.5, g: 0.5, b: 0.5, a: 1}
m_FogMode: 3
m_FogDensity: 0.01
m_LinearFogStart: 0
m_LinearFogEnd: 300
m_AmbientSkyColor: {r: 0.212, g: 0.227, b: 0.259, a: 1}
m_AmbientEquatorColor: {r: 0.114, g: 0.125, b: 0.133, a: 1}
m_AmbientGroundColor: {r: 0.047, g: 0.043, b: 0.035, a: 1}
m_AmbientIntensity: 1
m_AmbientMode: 0
m_SubtractiveShadowColor: {r: 0.42, g: 0.478, b: 0.627, a: 1}
m_SkyboxMaterial: {fileID: 10304, guid: 0000000000000000f000000000000000, type: 0}
m_HaloStrength: 0.5
m_FlareStrength: 1
m_FlareFadeSpeed: 3
m_HaloTexture: {fileID: 0}
m_SpotCookie: {fileID: 10001, guid: 0000000000000000e000000000000000, type: 0}
m_DefaultReflectionMode: 0
m_DefaultReflectionResolution: 128
m_ReflectionBounces: 1
m_ReflectionIntensity: 1
m_CustomReflection: {fileID: 0}
m_Sun: {fileID: 0}
m_IndirectSpecularColor: {r: 0, g: 0, b: 0, a: 1}
--- !u!157 &3
LightmapSettings:
m_ObjectHideFlags: 0
serializedVersion: 11
m_GIWorkflowMode: 1
m_GISettings:
serializedVersion: 2
m_BounceScale: 1
m_IndirectOutputScale: 1
m_AlbedoBoost: 1
m_TemporalCoherenceThreshold: 1
m_EnvironmentLightingMode: 0
m_EnableBakedLightmaps: 1
m_EnableRealtimeLightmaps: 1
m_LightmapEditorSettings:
serializedVersion: 9
m_Resolution: 2
m_BakeResolution: 40
m_TextureWidth: 1024
m_TextureHeight: 1024
m_AO: 0
m_AOMaxDistance: 1
m_CompAOExponent: 1
m_CompAOExponentDirect: 0
m_Padding: 2
m_LightmapParameters: {fileID: 0}
m_LightmapsBakeMode: 1
m_TextureCompression: 1
m_FinalGather: 0
m_FinalGatherFiltering: 1
m_FinalGatherRayCount: 256
m_ReflectionCompression: 2
m_MixedBakeMode: 2
m_BakeBackend: 0
m_PVRSampling: 1
m_PVRDirectSampleCount: 32
m_PVRSampleCount: 500
m_PVRBounces: 2
m_PVRFiltering: 0
m_PVRFilteringMode: 1
m_PVRCulling: 1
m_PVRFilteringGaussRadiusDirect: 1
m_PVRFilteringGaussRadiusIndirect: 5
m_PVRFilteringGaussRadiusAO: 2
m_PVRFilteringAtrousColorSigma: 1
m_PVRFilteringAtrousNormalSigma: 1
m_PVRFilteringAtrousPositionSigma: 1
m_LightingDataAsset: {fileID: 0}
m_UseShadowmask: 1
--- !u!196 &4
NavMeshSettings:
serializedVersion: 2
m_ObjectHideFlags: 0
m_BuildSettings:
serializedVersion: 2
agentTypeID: 0
agentRadius: 0.5
agentHeight: 2
agentSlope: 45
agentClimb: 0.4
ledgeDropHeight: 0
maxJumpAcrossDistance: 0
minRegionArea: 2
manualCellSize: 0
cellSize: 0.16666667
manualTileSize: 0
tileSize: 256
accuratePlacement: 0
m_NavMeshData: {fileID: 0}
--- !u!1 &282272644
GameObject:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
serializedVersion: 5
m_Component:
- component: {fileID: 282272648}
- component: {fileID: 282272647}
- component: {fileID: 282272646}
- component: {fileID: 282272645}
- component: {fileID: 282272649}
m_Layer: 0
m_Name: Agent
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!23 &282272645
MeshRenderer:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 282272644}
m_Enabled: 1
m_CastShadows: 1
m_ReceiveShadows: 1
m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_Materials:
- {fileID: 2100000, guid: 260483cdfc6b14e26823a02f23bd8baa, type: 2}
m_StaticBatchInfo:
firstSubMesh: 0
subMeshCount: 0
m_StaticBatchRoot: {fileID: 0}
m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0
m_SelectedEditorRenderState: 3
m_MinimumChartSize: 4
m_AutoUVMaxDistance: 0.5
m_AutoUVMaxAngle: 89
m_LightmapParameters: {fileID: 0}
m_SortingLayerID: 0
m_SortingLayer: 0
m_SortingOrder: 0
--- !u!65 &282272646
BoxCollider:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 282272644}
m_Material: {fileID: 0}
m_IsTrigger: 0
m_Enabled: 1
serializedVersion: 2
m_Size: {x: 1, y: 1, z: 1}
m_Center: {x: 0, y: 0, z: 0}
--- !u!33 &282272647
MeshFilter:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 282272644}
m_Mesh: {fileID: 10202, guid: 0000000000000000e000000000000000, type: 0}
--- !u!4 &282272648
Transform:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 282272644}
m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
m_LocalPosition: {x: 0, y: 0, z: 0}
m_LocalScale: {x: 1, y: 1, z: 1}
m_Children: []
m_Father: {fileID: 0}
m_RootOrder: 3
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!114 &282272649
MonoBehaviour:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 282272644}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 624480a72e46148118ab2e2d89b537de, type: 3}
m_Name:
m_EditorClassIdentifier:
brain: {fileID: 846768605}
observations: []
maxStep: 0
resetOnDone: 1
reward: 0
done: 0
value: 0
CummulativeReward: 0
stepCounter: 0
agentStoredAction: []
memory: []
id: 0
position: 0
smallGoalPosition: -3
largeGoalPosition: 7
largeGoal: {fileID: 984725368}
smallGoal: {fileID: 1178588871}
minPosition: -10
maxPosition: 10
--- !u!114 &395380616
MonoBehaviour:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 0}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 943466ab374444748a364f9d6c3e2fe2, type: 3}
m_Name: (Clone)
m_EditorClassIdentifier:
brain: {fileID: 0}
--- !u!114 &577874698
MonoBehaviour:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 0}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 41e9bda8f3cf1492fa74926a530f6f70, type: 3}
m_Name: (Clone)
m_EditorClassIdentifier:
continuousPlayerActions: []
discretePlayerActions:
- key: 97
value: 0
- key: 100
value: 1
defaultAction: -1
brain: {fileID: 846768605}
--- !u!1 &762086410
GameObject:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
serializedVersion: 5
m_Component:
- component: {fileID: 762086412}
- component: {fileID: 762086411}
m_Layer: 0
m_Name: Directional Light
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!108 &762086411
Light:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 762086410}
m_Enabled: 1
serializedVersion: 8
m_Type: 1
m_Color: {r: 1, g: 0.95686275, b: 0.8392157, a: 1}
m_Intensity: 1
m_Range: 10
m_SpotAngle: 30
m_CookieSize: 10
m_Shadows:
m_Type: 2
m_Resolution: -1
m_CustomResolution: -1
m_Strength: 1
m_Bias: 0.05
m_NormalBias: 0.4
m_NearPlane: 0.2
m_Cookie: {fileID: 0}
m_DrawHalo: 0
m_Flare: {fileID: 0}
m_RenderMode: 0
m_CullingMask:
serializedVersion: 2
m_Bits: 4294967295
m_Lightmapping: 4
m_AreaSize: {x: 1, y: 1}
m_BounceIntensity: 1
m_ColorTemperature: 6570
m_UseColorTemperature: 0
m_ShadowRadius: 0
m_ShadowAngle: 0
--- !u!4 &762086412
Transform:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 762086410}
m_LocalRotation: {x: 0.40821788, y: -0.23456968, z: 0.10938163, w: 0.8754261}
m_LocalPosition: {x: 0, y: 3, z: 0}
m_LocalScale: {x: 1, y: 1, z: 1}
m_Children: []
m_Father: {fileID: 0}
m_RootOrder: 1
m_LocalEulerAnglesHint: {x: 50, y: -30, z: 0}
--- !u!1 &846768603
GameObject:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
serializedVersion: 5
m_Component:
- component: {fileID: 846768604}
- component: {fileID: 846768605}
m_Layer: 0
m_Name: Brain
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!4 &846768604
Transform:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 846768603}
m_LocalRotation: {x: -0, y: -0, z: -0, w: 1}
m_LocalPosition: {x: 0, y: 0, z: 0}
m_LocalScale: {x: 1, y: 1, z: 1}
m_Children: []
m_Father: {fileID: 1574236049}
m_RootOrder: 0
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!114 &846768605
MonoBehaviour:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 846768603}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: c676a8ddf5a5f4f64b35e9ed5028679d, type: 3}
m_Name:
m_EditorClassIdentifier:
brainParameters:
stateSize: 1
actionSize: 2
memorySize: 0
cameraResolutions: []
actionDescriptions:
- Left
- Right
actionSpaceType: 0
stateSpaceType: 0
brainType: 0
CoreBrains:
- {fileID: 577874698}
- {fileID: 395380616}
- {fileID: 1503497339}
instanceID: 10208
--- !u!1 &984725368
GameObject:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
serializedVersion: 5
m_Component:
- component: {fileID: 984725372}
- component: {fileID: 984725371}
- component: {fileID: 984725370}
- component: {fileID: 984725369}
m_Layer: 0
m_Name: largeGoal
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!23 &984725369
MeshRenderer:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 984725368}
m_Enabled: 1
m_CastShadows: 1
m_ReceiveShadows: 1
m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_Materials:
- {fileID: 2100000, guid: 624b24bbec31f44babfb57ef2dfbc537, type: 2}
m_StaticBatchInfo:
firstSubMesh: 0
subMeshCount: 0
m_StaticBatchRoot: {fileID: 0}
m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0
m_SelectedEditorRenderState: 3
m_MinimumChartSize: 4
m_AutoUVMaxDistance: 0.5
m_AutoUVMaxAngle: 89
m_LightmapParameters: {fileID: 0}
m_SortingLayerID: 0
m_SortingLayer: 0
m_SortingOrder: 0
--- !u!135 &984725370
SphereCollider:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 984725368}
m_Material: {fileID: 0}
m_IsTrigger: 0
m_Enabled: 1
serializedVersion: 2
m_Radius: 0.5
m_Center: {x: 0, y: 0, z: 0}
--- !u!33 &984725371
MeshFilter:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 984725368}
m_Mesh: {fileID: 10207, guid: 0000000000000000e000000000000000, type: 0}
--- !u!4 &984725372
Transform:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 984725368}
m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
m_LocalPosition: {x: 0, y: 0, z: 0}
m_LocalScale: {x: 1, y: 1, z: 1}
m_Children: []
m_Father: {fileID: 0}
m_RootOrder: 4
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!1 &1178588871
GameObject:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
serializedVersion: 5
m_Component:
- component: {fileID: 1178588875}
- component: {fileID: 1178588874}
- component: {fileID: 1178588873}
- component: {fileID: 1178588872}
m_Layer: 0
m_Name: smallGoal
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!23 &1178588872
MeshRenderer:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 1178588871}
m_Enabled: 1
m_CastShadows: 1
m_ReceiveShadows: 1
m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_Materials:
- {fileID: 2100000, guid: 624b24bbec31f44babfb57ef2dfbc537, type: 2}
m_StaticBatchInfo:
firstSubMesh: 0
subMeshCount: 0
m_StaticBatchRoot: {fileID: 0}
m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0
m_SelectedEditorRenderState: 3
m_MinimumChartSize: 4
m_AutoUVMaxDistance: 0.5
m_AutoUVMaxAngle: 89
m_LightmapParameters: {fileID: 0}
m_SortingLayerID: 0
m_SortingLayer: 0
m_SortingOrder: 0
--- !u!135 &1178588873
SphereCollider:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 1178588871}
m_Material: {fileID: 0}
m_IsTrigger: 0
m_Enabled: 1
serializedVersion: 2
m_Radius: 0.5
m_Center: {x: 0, y: 0, z: 0}
--- !u!33 &1178588874
MeshFilter:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 1178588871}
m_Mesh: {fileID: 10207, guid: 0000000000000000e000000000000000, type: 0}
--- !u!4 &1178588875
Transform:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 1178588871}
m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
m_LocalPosition: {x: 0, y: 0, z: 0}
m_LocalScale: {x: 0.5, y: 0.5, z: 0.5}
m_Children: []
m_Father: {fileID: 0}
m_RootOrder: 5
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!114 &1503497339
MonoBehaviour:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 0}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 35813a1be64e144f887d7d5f15b963fa, type: 3}
m_Name: (Clone)
m_EditorClassIdentifier:
brain: {fileID: 846768605}
--- !u!1 &1574236047
GameObject:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
serializedVersion: 5
m_Component:
- component: {fileID: 1574236049}
- component: {fileID: 1574236048}
m_Layer: 0
m_Name: Academy
m_TagString: Untagged
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!114 &1574236048
MonoBehaviour:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 1574236047}
m_Enabled: 1
m_EditorHideFlags: 0
m_Script: {fileID: 11500000, guid: 19276d4dc78ee49f1ba258293f17636c, type: 3}
m_Name:
m_EditorClassIdentifier:
maxSteps: 0
frameToSkip: 0
waitTime: 0.5
trainingConfiguration:
width: 80
height: 80
qualityLevel: 1
timeScale: 100
targetFrameRate: 60
inferenceConfiguration:
width: 1280
height: 720
qualityLevel: 5
timeScale: 1
targetFrameRate: 60
defaultResetParameters: []
done: 0
episodeCount: 1
currentStep: 0
isInference: 0
windowResize: 0
--- !u!4 &1574236049
Transform:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 1574236047}
m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
m_LocalPosition: {x: 0.71938086, y: 0.27357092, z: 4.1970553}
m_LocalScale: {x: 1, y: 1, z: 1}
m_Children:
- {fileID: 846768604}
m_Father: {fileID: 0}
m_RootOrder: 2
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}
--- !u!1 &1715640920
GameObject:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
serializedVersion: 5
m_Component:
- component: {fileID: 1715640925}
- component: {fileID: 1715640924}
- component: {fileID: 1715640923}
- component: {fileID: 1715640922}
- component: {fileID: 1715640921}
m_Layer: 0
m_Name: Main Camera
m_TagString: MainCamera
m_Icon: {fileID: 0}
m_NavMeshLayer: 0
m_StaticEditorFlags: 0
m_IsActive: 1
--- !u!81 &1715640921
AudioListener:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 1715640920}
m_Enabled: 1
--- !u!124 &1715640922
Behaviour:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 1715640920}
m_Enabled: 1
--- !u!92 &1715640923
Behaviour:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 1715640920}
m_Enabled: 1
--- !u!20 &1715640924
Camera:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 1715640920}
m_Enabled: 1
serializedVersion: 2
m_ClearFlags: 2
m_BackGroundColor: {r: 0.7411765, g: 0.7411765, b: 0.7529412, a: 0}
m_NormalizedViewPortRect:
serializedVersion: 2
x: 0
y: 0
width: 1
height: 1
near clip plane: 0.3
far clip plane: 1000
field of view: 60
orthographic: 0
orthographic size: 5
m_Depth: -1
m_CullingMask:
serializedVersion: 2
m_Bits: 4294967295
m_RenderingPath: -1
m_TargetTexture: {fileID: 0}
m_TargetDisplay: 0
m_TargetEye: 3
m_HDR: 1
m_AllowMSAA: 1
m_ForceIntoRT: 0
m_OcclusionCulling: 1
m_StereoConvergence: 10
m_StereoSeparation: 0.022
m_StereoMirrorMode: 0
--- !u!4 &1715640925
Transform:
m_ObjectHideFlags: 0
m_PrefabParentObject: {fileID: 0}
m_PrefabInternal: {fileID: 0}
m_GameObject: {fileID: 1715640920}
m_LocalRotation: {x: 0, y: 0, z: 0, w: 1}
m_LocalPosition: {x: 0, y: 1, z: -10}
m_LocalScale: {x: 1, y: 1, z: 1}
m_Children: []
m_Father: {fileID: 0}
m_RootOrder: 0
m_LocalEulerAnglesHint: {x: 0, y: 0, z: 0}

8
unity-environment/Assets/ML-Agents/Examples/Basic/Scene.unity.meta


fileFormatVersion: 2
guid: cf1d119a8748d406e90ecb623b45f92f
timeCreated: 1504127824
licenseType: Pro
DefaultImporter:
userData:
assetBundleName:
assetBundleVariant:

9
unity-environment/Assets/ML-Agents/Examples/Basic/Scripts.meta


fileFormatVersion: 2
guid: fbcbd038eb29041f580c463e454e10fc
folderAsset: yes
timeCreated: 1503355437
licenseType: Free
DefaultImporter:
userData:
assetBundleName:
assetBundleVariant:

17
unity-environment/Assets/ML-Agents/Examples/Basic/Scripts/BasicAcademy.cs


using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class BasicAcademy : Academy {

    public override void AcademyReset()
    {
    }

    public override void AcademyStep()
    {
    }

}

12
unity-environment/Assets/ML-Agents/Examples/Basic/Scripts/BasicAcademy.cs.meta


fileFormatVersion: 2
guid: 19276d4dc78ee49f1ba258293f17636c
timeCreated: 1503355437
licenseType: Free
MonoImporter:
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

64
unity-environment/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs


using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class BasicAgent : Agent
{
    public int position;
    public int smallGoalPosition;
    public int largeGoalPosition;
    public GameObject largeGoal;
    public GameObject smallGoal;
    public int minPosition;
    public int maxPosition;

    // The state is the agent's current position on the 1D track.
    public override List<float> CollectState()
    {
        List<float> state = new List<float>();
        state.Add(position);
        return state;
    }

    // Discrete action 0 moves left, action 1 moves right; reaching the small
    // goal yields a reward of 0.1, reaching the large goal yields 1.0.
    public override void AgentStep(float[] act)
    {
        float movement = act[0];
        int direction = 0;
        if (movement == 0) { direction = -1; }
        if (movement == 1) { direction = 1; }
        position += direction;
        if (position < minPosition) { position = minPosition; }
        if (position > maxPosition) { position = maxPosition; }
        gameObject.transform.position = new Vector3(position, 0f, 0f);

        if (position == smallGoalPosition)
        {
            done = true;
            reward = 0.1f;
        }
        if (position == largeGoalPosition)
        {
            done = true;
            reward = 1f;
        }
    }

    public override void AgentReset()
    {
        position = 0;
        minPosition = -10;
        maxPosition = 10;
        smallGoalPosition = -3;
        largeGoalPosition = 7;
        smallGoal.transform.position = new Vector3(smallGoalPosition, 0f, 0f);
        largeGoal.transform.position = new Vector3(largeGoalPosition, 0f, 0f);
    }

    public override void AgentOnDone()
    {
    }
}

12
unity-environment/Assets/ML-Agents/Examples/Basic/Scripts/BasicAgent.cs.meta


fileFormatVersion: 2
guid: 624480a72e46148118ab2e2d89b537de
timeCreated: 1503355437
licenseType: Free
MonoImporter:
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

18
unity-environment/Assets/ML-Agents/Examples/Basic/Scripts/BasicDecision.cs


using System.Collections;
using System.Collections.Generic;
using UnityEngine;

public class BasicDecision : MonoBehaviour, Decision {

    public float[] Decide (List<float> state, List<Camera> observation, float reward, bool done, float[] memory)
    {
        return default(float[]);
    }

    public float[] MakeMemory (List<float> state, List<Camera> observation, float reward, bool done, float[] memory)
    {
        return default(float[]);
    }

}
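
The Decide and MakeMemory stubs above return default values. As a minimal, hypothetical sketch (not part of this commit), a heuristic Decision for this environment could always select action 1, which BasicAgent.AgentStep above treats as a step toward the large goal; the class name HeuristicBasicDecision is illustrative only.

// Hypothetical heuristic Decision, for illustration only (not part of this commit).
// It always returns action 1 ("Right"), which BasicAgent.AgentStep interprets as
// moving one unit toward the large goal at position 7.
using System.Collections.Generic;
using UnityEngine;

public class HeuristicBasicDecision : MonoBehaviour, Decision {

    public float[] Decide (List<float> state, List<Camera> observation, float reward, bool done, float[] memory)
    {
        // The action space is discrete with two actions: 0 = Left, 1 = Right.
        return new float[] { 1f };
    }

    public float[] MakeMemory (List<float> state, List<Camera> observation, float reward, bool done, float[] memory)
    {
        // This agent is memoryless (memorySize is 0 in the Brain above).
        return new float[0];
    }
}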

12
unity-environment/Assets/ML-Agents/Examples/Basic/Scripts/BasicDecision.cs.meta


fileFormatVersion: 2
guid: 99399d2439f894b149d8e67b85b6e07a
timeCreated: 1503355437
licenseType: Free
MonoImporter:
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant: