
New Learning Brain (#1303)

* Initial Commit

* attempt at refactor

* Put all static methods into the CoreInternalBrain

* improvements

* more testing

* modifications

* renamed epsilon

* misc

* Now supports discrete actions

* added discrete support and RNN and visual. Left to do is refactor and save variables into models

* code cleaning

* made a tensor generator and applier

* fix on the models.py file

* Moved the Checks to a different Class

* Added some unit tests

* BugFix

* Need to generate the output tensors as well as inputs before executing the graph

* Made NodeNames static and created a new namespace

* Added comments to the TensorAppliers

* Started adding comments on the TensorGenerators code

* Added comments for the Tensor Generator

* Moving the helper classes into a separate folder

* Added initial comments to the TensorChecks

* Renamed NodeNames -> TensorNames

* Removing warnings in tests

* Now using Aut...
Branch: develop-generalizationTraining-TrainerController
GitHub committed 6 years ago
Current commit: 6c354d16
31 files changed, with 1,626 insertions and 458 deletions
 56  UnitySDK/Assets/ML-Agents/Editor/LearningBrainEditor.cs
  1  UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs
  3  UnitySDK/Assets/ML-Agents/Editor/Tests/UtilitiesTests.cs
 20  UnitySDK/Assets/ML-Agents/Scripts/Brain.cs
  1  UnitySDK/Assets/ML-Agents/Scripts/HeuristicBrain.cs
498  UnitySDK/Assets/ML-Agents/Scripts/LearningBrain.cs
  3  UnitySDK/Assets/ML-Agents/Scripts/LearningBrain.cs.meta
  2  UnitySDK/Assets/ML-Agents/Scripts/PlayerBrain.cs
  6  ml-agents/mlagents/trainers/bc/models.py
 27  ml-agents/mlagents/trainers/models.py
  8  ml-agents/mlagents/trainers/policy.py
 10  ml-agents/mlagents/trainers/ppo/policy.py
  3  ml-agents/mlagents/trainers/ppo/trainer.py
138  UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs
  3  UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs.meta
158  UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs
 11  UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs.meta
  8  UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain.meta
143  UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs
  3  UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs.meta
224  UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs
  3  UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs.meta
531  UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ModelParamLoader.cs
  3  UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ModelParamLoader.cs.meta
 80  UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs
 11  UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs.meta
 99  UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs
  3  UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs.meta
 25  UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorNames.cs
  3  UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorNames.cs.meta

56
UnitySDK/Assets/ML-Agents/Editor/LearningBrainEditor.cs


[CustomEditor(typeof(LearningBrain))]
public class LearningBrainEditor : BrainEditor
{
private const string ModelPropName = "model";
private const float TimeBetweenModelReloads = 2f;
// Time since the last reload of the model
private float _timeSinceModelReload;
// Whether or not the model needs to be reloaded
private bool _requireReload;
/// <summary>
/// Called when the user opens the Inspector for the LearningBrain
/// </summary>
public void OnEnable()
{
_requireReload = true;
EditorApplication.update += IncreaseTimeSinceLastModelReload;
}
/// <summary>
/// Called when the user leaves the Inspector for the LearningBrain
/// </summary>
public void OnDisable()
{
EditorApplication.update -= IncreaseTimeSinceLastModelReload;
}
EditorGUI.BeginChangeCheck();
EditorGUILayout.PropertyField(serializedBrain.FindProperty("graphModel"), true);
var tfGraphModel = serializedBrain.FindProperty(ModelPropName);
EditorGUILayout.ObjectField(tfGraphModel);
if (EditorGUI.EndChangeCheck())
{
_requireReload = true;
}
if (_requireReload && _timeSinceModelReload > TimeBetweenModelReloads)
{
brain.ReloadModel();
_requireReload = false;
_timeSinceModelReload = 0;
}
// Display all failed checks
var failedChecks = brain.GetModelFailedChecks();
foreach (var check in failedChecks)
{
if (check != null)
{
EditorGUILayout.HelpBox(check, MessageType.Warning);
}
}
}
/// <summary>
/// Increases the time since last model reload by the deltaTime since the last Update call
/// from the UnityEditor
/// </summary>
private void IncreaseTimeSinceLastModelReload()
{
_timeSinceModelReload += Time.deltaTime;
}
}
}

1
UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs


protected override void DecideAction()
{
base.DecideAction();
numberOfCallsToDecideAction++;
agentInfos.Clear();
}

3
UnitySDK/Assets/ML-Agents/Editor/Tests/UtilitiesTests.cs


using NUnit.Framework;
using UnityEngine;
namespace MLAgents.Tests
{

{
var output = Utilities.CumSum(new int[]{1, 2, 3, 10});
CollectionAssert.AreEqual(output, new int[] {0, 1, 3, 6, 16});
output = Utilities.CumSum(new int[0]);
CollectionAssert.AreEqual(output, new int[]{0});
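For reference, a minimal sketch of a cumulative-sum helper that satisfies the assertions above: the result carries a leading zero, so element i holds the sum of the first i inputs. This is inferred from the test's expected outputs only; the actual Utilities.CumSum implementation is not shown in this diff and may differ.
public static int[] CumSum(int[] input)
{
    // result[0] = 0 and result[i + 1] = result[i] + input[i],
    // so {1, 2, 3, 10} maps to {0, 1, 3, 6, 16} as asserted above.
    var result = new int[input.Length + 1];
    for (var i = 0; i < input.Length; i++)
    {
        result[i + 1] = result[i] + input[i];
    }
    return result;
}
This leading-zero layout is what lets the DiscreteActionOutputApplier further down use the results as per-branch start offsets into the concatenated action logits.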

20
UnitySDK/Assets/ML-Agents/Scripts/Brain.cs


{
if (!_isInitialized)
{
FindObjectOfType<Academy>().BrainDecideAction += DecideAction;
FindObjectOfType<Academy>().BrainDecideAction += BrainDecideAction;
/// <summary>
/// Calls the DecideAction method that the concrete brain implements.
/// </summary>
private void BrainDecideAction()
{
brainBatcher?.SendBrainInfo(name, agentInfos);
DecideAction();
}
/// Is called once per Environment Step when the Brain has been initialized.
/// Is called once per Environment Step after the Brain has been initialized.
protected virtual void DecideAction()
{
brainBatcher?.SendBrainInfo(name, agentInfos);
}
protected abstract void DecideAction();
}
}

1
UnitySDK/Assets/ML-Agents/Scripts/HeuristicBrain.cs


/// Uses the Decision Component to decide what action to take
protected override void DecideAction()
{
base.DecideAction();
if (decision == null)
{
throw new UnityAgentsException(

498
UnitySDK/Assets/ML-Agents/Scripts/LearningBrain.cs


using System;
using System.Collections;
#if UNITY_EDITOR
using UnityEditor;
#endif
using Random = UnityEngine.Random;
#if ENABLE_TENSORFLOW
using TensorFlow;
#endif
using MLAgents.InferenceBrain;
using UnityEngine.MachineLearning.InferenceEngine;
using UnityEngine.Profiling;
namespace MLAgents
{

/// the checkbox Control. When using a pretrained model, just drag the Model file into the
/// Model property of the Learning Brain.
/// When the Learning Braom is noe training, it uses a TensorFlow model to make decisions.
/// The property model corresponds to the Model currently attached to the Brain. Before
/// being used, a call to ReloadModel is required.
/// When the Learning Brain is not training, it uses a TensorFlow model to make decisions.
/// The Proximal Policy Optimization (PPO) and Behavioral Cloning algorithms included with
/// the ML-Agents SDK produce trained TensorFlow models that you can use with the
/// Learning Brain.

{
[System.Serializable]
private struct TensorFlowAgentPlaceholder
{
public enum TensorType
{
Integer,
FloatingPoint
};
private TensorGenerator _tensorGenerator;
private TensorApplier _tensorApplier;
private ModelParamLoader _modelParamLoader;
public Model model;
public string name;
public TensorType valueType;
public float minValue;
public float maxValue;
}
[Tooltip("This must be the bytes file corresponding to the pretrained TensorFlow graph.")]
/// Modify only in inspector : Reference to the Graph asset
public TextAsset graphModel;
private InferenceEngine _engine;
private IEnumerable<Tensor> _inferenceInputs;
private IEnumerable<Tensor> _inferenceOutputs;
private bool isControlled;
[SerializeField]
[Tooltip(
"If your graph takes additional inputs that are fixed (example: noise level) you can specify them here.")]
/// Modify only in inspector : If your graph takes additional inputs that are fixed you can specify them here.
private TensorFlowAgentPlaceholder[] graphPlaceholders;
/// Modify only in inspector : Name of the placeholder of the batch size
public string BatchSizePlaceholderName = "batch_size";
/// Modify only in inspector : Name of the state placeholder
public string VectorObservationPlacholderName = "vector_observation";
/// Modify only in inspector : Name of the recurrent input
public string RecurrentInPlaceholderName = "recurrent_in";
/// Modify only in inspector : Name of the recurrent output
public string RecurrentOutPlaceholderName = "recurrent_out";
/// Modify only in inspector : Names of the observations placeholders
public string[] VisualObservationPlaceholderName;
/// Modify only in inspector : Name of the action node
public string ActionPlaceholderName = "action";
/// Modify only in inspector : Name of the previous action node
public string PreviousActionPlaceholderName = "prev_action";
/// Name of the action mask node
private string ActionMaskPlaceholderName = "action_masks";
#if ENABLE_TENSORFLOW
TFGraph graph;
TFSession session;
bool hasRecurrent;
bool hasState;
bool hasBatchSize;
bool hasPrevAction;
bool hasMaskedActions;
bool hasValueEstimate;
float[,] inputState;
int[,] inputPrevAction;
List<float[,,,]> observationMatrixList;
float[,] inputOldMemories;
float[,] maskedActions;
List<Texture2D> texturesHolder;
int memorySize;
#endif
private bool _isControlled;
/// <summary>
/// When Called, the brain will be controlled externally. It will not use the

{
isControlled = true;
_isControlled = true;
/// <inheritdoc />
#if ENABLE_TENSORFLOW
#if UNITY_ANDROID && !UNITY_EDITOR
// This needs to be called only once and will raise an exception if
// there are multiple internal brains
try{
TensorFlowSharp.Android.NativeBinding.Init();
ReloadModel();
catch{
}
#endif
if (graphModel != null)
/// <summary>
/// Initializes the Brain with the Model that it will use when selecting actions for
/// the agents
/// </summary>
/// <param name="seed"> The seed that will be used to initialize the RandomNormal
/// and Multinomial objects used when running inference.</param>
/// <exception cref="UnityAgentsException">Throws an error when the model is null
/// </exception>
public void ReloadModel(int seed = 0)
{
if (model != null)
graph = new TFGraph();
graph.Import(graphModel.bytes);
session = new TFSession(graph);
if (graph[BatchSizePlaceholderName] != null)
var config = new InferenceEngineConfig
hasBatchSize = true;
}
if ((graph[RecurrentInPlaceholderName] != null) &&
(graph[RecurrentOutPlaceholderName] != null))
{
hasRecurrent = true;
var runner = session.GetRunner();
runner.Fetch(graph["memory_size"][0]);
var networkOutput = runner.Run()[0].GetValue();
memorySize = (int) networkOutput;
}
if (graph[VectorObservationPlacholderName] != null)
{
hasState = true;
}
if (graph[PreviousActionPlaceholderName] != null)
{
hasPrevAction = true;
}
if (graph["value_estimate"] != null)
{
hasValueEstimate = true;
}
if (graph[ActionMaskPlaceholderName] != null)
{
hasMaskedActions = true;
}
Device = InferenceEngineConfig.DeviceType.CPU
};
_engine = InferenceAPI.LoadModel(model, config);
observationMatrixList = new List<float[,,,]>();
texturesHolder = new List<Texture2D>();
#endif
else
{
_engine = null;
}
_modelParamLoader = ModelParamLoader.GetLoaderAndCheck(_engine, brainParameters);
_inferenceInputs = _modelParamLoader.GetInputTensors();
_inferenceOutputs = _modelParamLoader.GetOutputTensors();
_tensorGenerator = new TensorGenerator(brainParameters, seed);
_tensorApplier = new TensorApplier(brainParameters, seed);
/// <summary>
/// Return a list of failed checks corresponding to the failed compatibility checks
/// between the Model and the BrainParameters. Note : This does not reload the model.
/// If changes have been made to the BrainParameters or the Model, the model must be
/// reloaded using GiveModel before trying to get the compatibility checks.
/// </summary>
/// <returns> The list of the failed compatibility checks between the Model and the
/// Brain Parameters</returns>
public IEnumerable<string> GetModelFailedChecks()
{
return (_modelParamLoader != null) ? _modelParamLoader.GetChecks() : new List<string>();
}
/// Uses the stored information to run the tensorflow graph and generate
/// the actions.
/// <inheritdoc />
#if ENABLE_TENSORFLOW
base.DecideAction();
if (isControlled)
if (_isControlled)
int currentBatchSize = agentInfos.Count();
List<Agent> agentList = agentInfos.Keys.ToList();
var currentBatchSize = agentInfos.Count();
// Create the state tensor
if (hasState)
{
int stateLength = 1;
stateLength = brainParameters.vectorObservationSize;
inputState =
new float[currentBatchSize, stateLength * brainParameters.numStackedVectorObservations];
var i = 0;
foreach (Agent agent in agentList)
{
List<float> stateList = agentInfos[agent].stackedVectorObservation;
for (int j =
0;
j < stateLength * brainParameters.numStackedVectorObservations;
j++)
{
inputState[i, j] = stateList[j];
}
i++;
}
}
// Create the state tensor
if (hasPrevAction)
{
int totalNumberActions = brainParameters.vectorActionSize.Length;
inputPrevAction = new int[currentBatchSize, totalNumberActions];
var i = 0;
foreach (Agent agent in agentList)
{
float[] actionList = agentInfos[agent].storedVectorActions;
for (var j = 0 ; j < totalNumberActions; j++)
{
inputPrevAction[i,j] = Mathf.FloorToInt(actionList[j]);
}
i++;
}
}
if (hasMaskedActions)
{
maskedActions = new float[
currentBatchSize,
brainParameters.vectorActionSize.Sum()
];
var i = 0;
foreach (Agent agent in agentList)
{
for (int j = 0; j < brainParameters.vectorActionSize.Sum(); j++)
{
if (agentInfos[agent].actionMasks != null)
{
maskedActions[i, j] = agentInfos[agent].actionMasks[j] ? 0.0f : 1.0f;
}
else
{
maskedActions[i, j] = 1.0f;
}
}
i++;
}
}
// Prepare the input tensors to be fed into the engine
_tensorGenerator.GenerateTensors(_inferenceInputs, currentBatchSize, agentInfos);
observationMatrixList.Clear();
for (int observationIndex =
0;
observationIndex < brainParameters.cameraResolutions.Length;
observationIndex++)
{
texturesHolder.Clear();
foreach (Agent agent in agentList)
{
texturesHolder.Add(agentInfos[agent].visualObservations[observationIndex]);
}
// Prepare the output tensors to be fed into the engine
_tensorGenerator.GenerateTensors(_inferenceOutputs, currentBatchSize, agentInfos);
observationMatrixList.Add(
Utilities.TextureToFloatArray(texturesHolder,
brainParameters.cameraResolutions[observationIndex].blackAndWhite));
}
// Execute the Model
Profiler.BeginSample($"MLAgents.{name}.ExecuteGraph");
_engine.ExecuteGraph(_inferenceInputs, _inferenceOutputs);
Profiler.EndSample();
// Create the recurrent tensor
if (hasRecurrent)
{
// Need to have variable memory size
inputOldMemories = new float[currentBatchSize, memorySize];
var i = 0;
foreach (Agent agent in agentList)
{
float[] m = agentInfos[agent].memories.ToArray();
for (int j = 0; j < m.Length; j++)
{
inputOldMemories[i, j] = m[j];
}
i++;
}
}
var runner = session.GetRunner();
try
{
runner.Fetch(graph[ActionPlaceholderName][0]);
}
catch
{
throw new UnityAgentsException(string.Format(
@"The node {0} could not be found. Please make sure the node name is correct",
ActionPlaceholderName));
}
if (hasBatchSize)
{
runner.AddInput(graph[BatchSizePlaceholderName][0], new int[] {currentBatchSize});
}
foreach (TensorFlowAgentPlaceholder placeholder in graphPlaceholders)
{
try
{
if (placeholder.valueType == TensorFlowAgentPlaceholder.TensorType.FloatingPoint)
{
runner.AddInput(graph[placeholder.name][0],
new float[] {Random.Range(placeholder.minValue, placeholder.maxValue)});
}
else if (placeholder.valueType == TensorFlowAgentPlaceholder.TensorType.Integer)
{
runner.AddInput(graph[placeholder.name][0],
new int[] {Random.Range((int) placeholder.minValue, (int) placeholder.maxValue + 1)});
}
}
catch
{
throw new UnityAgentsException(string.Format(
@"One of the TensorFlow placeholders could not be found.
In brain {0}, there is no {1} placeholder named {2}.",
name, placeholder.valueType.ToString(), placeholder.name));
}
}
// Create the state tensor
if (hasState)
{
runner.AddInput(graph[VectorObservationPlacholderName][0], inputState);
}
// Create the previous action tensor
if (hasPrevAction)
{
runner.AddInput(graph[PreviousActionPlaceholderName][0], inputPrevAction);
}
// Create the mask action tensor
if (hasMaskedActions)
{
runner.AddInput(graph[ActionMaskPlaceholderName][0], maskedActions);
}
// Create the observation tensors
for (int obsNumber =
0;
obsNumber < brainParameters.cameraResolutions.Length;
obsNumber++)
{
runner.AddInput(graph[VisualObservationPlaceholderName[obsNumber]][0],
observationMatrixList[obsNumber]);
}
if (hasRecurrent)
{
runner.AddInput(graph["sequence_length"][0], 1);
runner.AddInput(graph[RecurrentInPlaceholderName][0], inputOldMemories);
runner.Fetch(graph[RecurrentOutPlaceholderName][0]);
}
if (hasValueEstimate)
{
runner.Fetch(graph["value_estimate"][0]);
}
TFTensor[] networkOutput;
try
{
networkOutput = runner.Run();
}
catch (TFException e)
{
string errorMessage = e.Message;
try
{
errorMessage =
$@"The tensorflow graph needs an input for {e.Message.Split(new string[] {"Node: "}, 0)[1].Split('=')[0]} of type {e.Message.Split(new string[] {"dtype="}, 0)[1].Split(',')[0]}";
}
finally
{
throw new UnityAgentsException(errorMessage);
}
}
// Create the recurrent tensor
if (hasRecurrent)
{
float[,] recurrentTensor = networkOutput[1].GetValue() as float[,];
var i = 0;
foreach (Agent agent in agentList)
{
var m = new float[memorySize];
for (int j = 0; j < memorySize; j++)
{
m[j] = recurrentTensor[i, j];
}
agent.UpdateMemoriesAction(m.ToList());
i++;
}
}
if (hasValueEstimate)
{
float[,] value_estimates = new float[currentBatchSize,1];
if (hasRecurrent)
{
value_estimates = networkOutput[2].GetValue() as float[,];
}
else
{
value_estimates = networkOutput[1].GetValue() as float[,];
}
var i = 0;
foreach (Agent agent in agentList)
{
agent.UpdateValueAction(value_estimates[i,0]);
}
}
if (brainParameters.vectorActionSpaceType == SpaceType.continuous)
{
var output = networkOutput[0].GetValue() as float[,];
var i = 0;
foreach (Agent agent in agentList)
{
var a = new float[brainParameters.vectorActionSize[0]];
for (int j = 0; j < brainParameters.vectorActionSize[0]; j++)
{
a[j] = output[i, j];
}
agent.UpdateVectorAction(a);
i++;
}
}
else if (brainParameters.vectorActionSpaceType == SpaceType.discrete)
{
long[,] output = networkOutput[0].GetValue() as long[,];
var i = 0;
foreach (Agent agent in agentList)
{
var actSize = brainParameters.vectorActionSize.Length;
var a = new float[actSize];
for (int actIdx = 0; actIdx < actSize; actIdx++)
{
a[actIdx] = output[i, actIdx];
}
agent.UpdateVectorAction(a);
i++;
}
}
#else
base.DecideAction();
if (isControlled)
{
agentInfos.Clear();
return;
}
if (agentInfos.Count > 0)
{
throw new UnityAgentsException(string.Format(
@"The brain {0} was set to Internal but the Tensorflow
library is not present in the Unity project.",
name));
}
#endif
// Update the outputs
_tensorApplier.ApplyTensors(_inferenceOutputs, agentInfos);
}
}
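Taken together, the comments above describe the intended workflow for the new LearningBrain: assign a Model, call ReloadModel to rebuild the ModelParamLoader, TensorGenerator and TensorApplier, then read GetModelFailedChecks for any Model/BrainParameters incompatibilities. A hedged sketch of that flow from a setup script, assuming a LearningBrain reference whose model field is already assigned; the helper class and method names are illustrative and not part of this change.
using MLAgents;
using UnityEngine;

public static class LearningBrainValidation
{
    // Illustrative helper only: re-runs the compatibility checks for a brain
    // whose Model was just swapped in the Inspector or from code.
    public static void ValidateBrain(LearningBrain brain)
    {
        // Rebuilds the engine, the ModelParamLoader and the tensor helpers.
        brain.ReloadModel();

        // Each string is one failed Model/BrainParameters compatibility check.
        foreach (var check in brain.GetModelFailedChecks())
        {
            Debug.LogWarning(check);
        }
    }
}
The LearningBrainEditor shown earlier follows essentially the same sequence, debounced by TimeBetweenModelReloads so the model is not reloaded on every Inspector repaint.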

3
UnitySDK/Assets/ML-Agents/Scripts/LearningBrain.cs.meta


fileFormatVersion: 2
guid: 8b23992c8eb17439887f5e944bf04a40
timeCreated: 1504070347
licenseType: Free
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0

2
UnitySDK/Assets/ML-Agents/Scripts/PlayerBrain.cs


/// decide action
protected override void DecideAction()
{
base.DecideAction();
if (brainParameters.vectorActionSpaceType == SpaceType.continuous)
{
foreach (Agent agent in agentInfos.Keys)

6
ml-agents/mlagents/trainers/bc/models.py


self.action_probs = tf.concat(
[tf.nn.softmax(branch) for branch in policy_branches], axis=1, name="action_probs")
self.action_masks = tf.placeholder(shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks")
self.sample_action_float, _ = self.create_discrete_action_masking_layer(
tf.concat(policy_branches, axis = 1), self.action_masks, self.act_size)
self.sample_action_float = tf.identity(self.sample_action_float, name="action")
self.sample_action_float, normalized_logits = self.create_discrete_action_masking_layer(
tf.concat(policy_branches, axis=1), self.action_masks, self.act_size)
tf.identity(normalized_logits, name='action')
self.sample_action = tf.cast(self.sample_action_float, tf.int32)
self.true_action = tf.placeholder(shape=[None, len(policy_branches)], dtype=tf.int32, name="teacher_action")
self.action_oh = tf.concat([

27
ml-agents/mlagents/trainers/models.py


class LearningModel(object):
_version_number_ = 1
def __init__(self, m_size, normalize, use_recurrent, brain, seed):
tf.set_random_seed(seed)
self.brain = brain

self.sequence_length = tf.placeholder(shape=None, dtype=tf.int32, name='sequence_length')
self.mask_input = tf.placeholder(shape=[None], dtype=tf.float32, name='masks')
self.mask = tf.cast(self.mask_input, tf.int32)
self.m_size = m_size
self.use_recurrent = use_recurrent
if self.use_recurrent:
self.m_size = m_size
else:
self.m_size = 0
self.use_recurrent = use_recurrent
tf.Variable(int(brain.vector_action_space_type == 'continuous'),
name='is_continuous_control', trainable=False, dtype=tf.int32)
tf.Variable(self._version_number_, name='version_number', trainable=False, dtype=tf.int32)
tf.Variable(self.m_size, name="memory_size", trainable=False, dtype=tf.int32)
if brain.vector_action_space_type == 'continuous':
tf.Variable(self.act_size[0], name="action_output_shape", trainable=False, dtype=tf.int32)
else:
tf.Variable(sum(self.act_size), name="action_output_shape", trainable=False, dtype=tf.int32)
@staticmethod
def create_global_steps():

hidden_streams = self.create_observation_streams(2, h_size, num_layers)
if self.use_recurrent:
tf.Variable(self.m_size, name="memory_size", trainable=False, dtype=tf.int32)
self.memory_in = tf.placeholder(shape=[None, self.m_size], dtype=tf.float32,
name='recurrent_in')
_half_point = int(self.m_size / 2)

sigma_sq = tf.exp(log_sigma_sq)
epsilon = tf.random_normal(tf.shape(mu), dtype=tf.float32)
self.epsilon = tf.placeholder(shape=[None, self.act_size[0]], dtype=tf.float32, name='epsilon')
self.output_pre = mu + tf.sqrt(sigma_sq) * epsilon
self.output_pre = mu + tf.sqrt(sigma_sq) * self.epsilon
output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3
self.output = tf.identity(output_post, name='action')
self.selected_actions = tf.stop_gradient(output_post)

hidden = hidden_streams[0]
if self.use_recurrent:
tf.Variable(self.m_size, name="memory_size", trainable=False, dtype=tf.int32)
self.prev_action = tf.placeholder(shape=[None, len(self.act_size)], dtype=tf.int32,
name='prev_action')
prev_action_oh = tf.concat([

output, normalized_logits = self.create_discrete_action_masking_layer(
self.all_log_probs, self.action_masks, self.act_size)
self.output = tf.identity(output, name="action")
self.output = tf.identity(output)
self.normalized_logits = tf.identity(normalized_logits, name='action')
value = tf.layers.dense(hidden, 1, activation=None)
self.value = tf.identity(value, name="value_estimate")

8
ml-agents/mlagents/trainers/policy.py


functions to interact with it to perform evaluation and updating.
"""
possible_output_nodes = ['action', 'value_estimate',
'action_probs', 'recurrent_out', 'memory_size']
'action_probs', 'recurrent_out', 'memory_size',
'version_number', 'is_continuous_control',
'action_output_shape']
def __init__(self, seed, brain, trainer_parameters):
"""

def export_model(self):
"""
Exports latest saved model to .bytes format for Unity embedding.
Exports latest saved model to .tf format for Unity embedding.
"""
with self.graph.as_default():
target_nodes = ','.join(self._process_graph())

input_binary=True,
input_checkpoint=ckpt.model_checkpoint_path,
output_node_names=target_nodes,
output_graph=(self.model_path + '.bytes'),
output_graph=(self.model_path + '.tf'),
clear_devices=True, initializer_nodes='', input_saver='',
restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0')

10
ml-agents/mlagents/trainers/ppo/policy.py


import logging
import numpy as np
from mlagents.trainers.ppo.models import PPOModel
from mlagents.trainers.policy import Policy

"""
feed_dict = {self.model.batch_size: len(brain_info.vector_observations),
self.model.sequence_length: 1}
epsilon = None
if self.use_recurrent:
if not self.use_continuous_act:
feed_dict[self.model.prev_action] = brain_info.previous_vector_actions.reshape(

feed_dict[self.model.memory_in] = brain_info.memories
if self.use_continuous_act:
epsilon = np.random.normal(
size=(len(brain_info.vector_observations), self.model.act_size[0]))
feed_dict[self.model.epsilon] = epsilon
if self.use_continuous_act:
run_out['random_normal_epsilon'] = epsilon
return run_out
def update(self, mini_batch, num_sequences):

[-1, sum(self.model.act_size)])}
if self.use_continuous_act:
feed_dict[self.model.output_pre] = mini_batch['actions_pre'].reshape(
[-1, self.model.act_size[0]])
feed_dict[self.model.epsilon] = mini_batch['random_normal_epsilon'].reshape(
[-1, self.model.act_size[0]])
else:
feed_dict[self.model.action_holder] = mini_batch['actions'].reshape(

3
ml-agents/mlagents/trainers/ppo/trainer.py


if self.policy.use_continuous_act:
actions_pre = stored_take_action_outputs['pre_action']
self.training_buffer[agent_id]['actions_pre'].append(actions_pre[idx])
epsilons = stored_take_action_outputs['random_normal_epsilon']
self.training_buffer[agent_id]['random_normal_epsilon'].append(
epsilons[idx])
else:
self.training_buffer[agent_id]['action_mask'].append(
stored_info.action_masks[idx])

138
UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs


using System.Collections.Generic;
using System.Linq;
using NUnit.Framework;
using UnityEngine;
using UnityEngine.MachineLearning.InferenceEngine;
using UnityEngine.MachineLearning.InferenceEngine.Util;
using System.Reflection;
using MLAgents.InferenceBrain;
namespace MLAgents.Tests
{
public class EditModeTestInternalBrainTensorApplier
{
private class TestAgent : Agent
{
public AgentAction GetAction()
{
FieldInfo f = typeof(Agent).GetField(
"action", BindingFlags.Instance | BindingFlags.NonPublic);
return (AgentAction) f.GetValue(this);
}
}
private Dictionary<Agent, AgentInfo> GetFakeAgentInfos()
{
var goA = new GameObject("goA");
var agentA = goA.AddComponent<TestAgent>();
var infoA = new AgentInfo();
var goB = new GameObject("goB");
var agentB = goB.AddComponent<TestAgent>();
var infoB = new AgentInfo();
return new Dictionary<Agent, AgentInfo>(){{agentA, infoA},{agentB, infoB}};
}
[Test]
public void Contruction()
{
var bp = new BrainParameters();
var tensorGenerator = new TensorApplier(bp, 0);
Assert.IsNotNull(tensorGenerator);
}
[Test]
public void ApplyContinuousActionOutput()
{
var inputTensor = new Tensor()
{
Shape = new long[] {2, 3},
Data = new float[,] {{1, 2, 3}, {4, 5, 6}}
};
var agentInfos = GetFakeAgentInfos();
var applier = new ContinuousActionOutputApplier();
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos.Keys.ToList();
var agent = agents[0] as TestAgent;
var action = agent.GetAction();
Assert.AreEqual(action.vectorActions[0], 1);
Assert.AreEqual(action.vectorActions[1], 2);
Assert.AreEqual(action.vectorActions[2], 3);
agent = agents[1] as TestAgent;
action = agent.GetAction();
Assert.AreEqual(action.vectorActions[0], 4);
Assert.AreEqual(action.vectorActions[1], 5);
Assert.AreEqual(action.vectorActions[2], 6);
}
[Test]
public void ApplyDiscreteActionOutput()
{
var inputTensor = new Tensor()
{
Shape = new long[] {2, 5},
Data = new float[,] {{0.5f, 22.5f, 0.1f, 5f, 1f},
{4f, 5f, 6f, 7f, 8f}}
};
var agentInfos = GetFakeAgentInfos();
var applier = new DiscreteActionOutputApplier(new int[]{2, 3}, 0);
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos.Keys.ToList();
var agent = agents[0] as TestAgent;
var action = agent.GetAction();
Assert.AreEqual(action.vectorActions[0], 1);
Assert.AreEqual(action.vectorActions[1], 1);
agent = agents[1] as TestAgent;
action = agent.GetAction();
Assert.AreEqual(action.vectorActions[0], 1);
Assert.AreEqual(action.vectorActions[1], 2);
}
[Test]
public void ApplyMemoryOutput()
{
var inputTensor = new Tensor()
{
Shape = new long[] {2, 5},
Data = new float[,] {{0.5f, 22.5f, 0.1f, 5f, 1f},
{4f, 5f, 6f, 7f, 8f}}
};
var agentInfos = GetFakeAgentInfos();
var applier = new MemoryOutputApplier();
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos.Keys.ToList();
var agent = agents[0] as TestAgent;
var action = agent.GetAction();
Assert.AreEqual(action.memories[0], 0.5f);
Assert.AreEqual(action.memories[1], 22.5f);
agent = agents[1] as TestAgent;
action = agent.GetAction();
Assert.AreEqual(action.memories[2], 6);
Assert.AreEqual(action.memories[3], 7);
}
[Test]
public void ApplyValueEstimate()
{
var inputTensor = new Tensor()
{
Shape = new long[] {2, 1},
Data = new float[,] {{0.5f}, {8f}}
};
var agentInfos = GetFakeAgentInfos();
var applier = new ValueEstimateApplier();
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos.Keys.ToList();
var agent = agents[0] as TestAgent;
var action = agent.GetAction();
Assert.AreEqual(action.value, 0.5f);
agent = agents[1] as TestAgent;
action = agent.GetAction();
Assert.AreEqual(action.value, 8);
}
}
}

3
UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs.meta


fileFormatVersion: 2
guid: be419f7ed5c24b24a6f2636d3b107535
timeCreated: 1537915674

158
UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs


using System;
using System.Collections.Generic;
using System.Linq;
using NUnit.Framework;
using UnityEngine;
using UnityEngine.MachineLearning.InferenceEngine;
using UnityEngine.MachineLearning.InferenceEngine.Util;
using MLAgents.InferenceBrain;
namespace MLAgents.Tests
{
public class EditModeTestInternalBrainTensorGenerator
{
private class TestAgent : Agent
{
}
private Dictionary<Agent, AgentInfo> GetFakeAgentInfos()
{
var goA = new GameObject("goA");
var agentA = goA.AddComponent<TestAgent>();
var infoA = new AgentInfo()
{
stackedVectorObservation = (new float[] {1f, 2f, 3f}).ToList(),
memories = null,
storedVectorActions = new float[] {1, 2},
actionMasks = null,
};
var goB = new GameObject("goB");
var agentB = goB.AddComponent<TestAgent>();
var infoB = new AgentInfo()
{
stackedVectorObservation = (new float[] {4f, 5f, 6f}).ToList(),
memories = (new float[] {1f, 1f, 1f}).ToList(),
storedVectorActions = new float[] {3, 4},
actionMasks = new bool[] {true, false, false, false, false},
};
return new Dictionary<Agent, AgentInfo>(){{agentA, infoA},{agentB, infoB}};
}
[Test]
public void Contruction()
{
var bp = new BrainParameters();
var tensorGenerator = new TensorGenerator(bp, 0);
Assert.IsNotNull(tensorGenerator);
}
[Test]
public void GenerateBatchSize()
{
var inputTensor = new Tensor();
var batchSize = 4;
var generator = new BatchSizeGenerator();
generator.Generate(inputTensor, batchSize, null);
Assert.IsNotNull(inputTensor.Data as int[]);
Assert.AreEqual((inputTensor.Data as int[])[0], batchSize);
}
[Test]
public void GenerateSequenceLength()
{
var inputTensor = new Tensor();
var batchSize = 4;
var generator = new SequenceLengthGenerator();
generator.Generate(inputTensor, batchSize, null);
Assert.IsNotNull(inputTensor.Data as int[]);
Assert.AreEqual((inputTensor.Data as int[])[0], 1);
}
[Test]
public void GenerateVectorObservation()
{
var inputTensor = new Tensor()
{
Shape = new long[] {2, 3}
};
var batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var generator = new VectorObservationGenerator();
generator.Generate(inputTensor, batchSize, agentInfos);
Assert.IsNotNull(inputTensor.Data as float[,]);
Assert.AreEqual((inputTensor.Data as float[,])[0, 0], 1);
Assert.AreEqual((inputTensor.Data as float[,])[0, 2], 3);
Assert.AreEqual((inputTensor.Data as float[,])[1, 0], 4);
Assert.AreEqual((inputTensor.Data as float[,])[1, 2], 6);
}
[Test]
public void GenerateRecurrentInput()
{
var inputTensor = new Tensor()
{
Shape = new long[] {2, 5}
};
var batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var generator = new RecurrentInputGenerator();
generator.Generate(inputTensor, batchSize, agentInfos);
Assert.IsNotNull(inputTensor.Data as float[,]);
Assert.AreEqual((inputTensor.Data as float[,])[0, 0], 0);
Assert.AreEqual((inputTensor.Data as float[,])[0, 4], 0);
Assert.AreEqual((inputTensor.Data as float[,])[1, 0], 1);
Assert.AreEqual((inputTensor.Data as float[,])[1, 4], 0);
}
[Test]
public void GeneratePreviousActionInput()
{
var inputTensor = new Tensor()
{
Shape = new long[] {2, 2},
ValueType = Tensor.TensorType.FloatingPoint
};
var batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var generator = new PreviousActionInputGenerator();
Assert.Catch<NotImplementedException>(
() => generator.Generate(inputTensor, batchSize, agentInfos));
inputTensor.ValueType = Tensor.TensorType.Integer;
generator.Generate(inputTensor, batchSize, agentInfos);
Assert.IsNotNull(inputTensor.Data as int[,]);
Assert.AreEqual((inputTensor.Data as int[,])[0, 0], 1);
Assert.AreEqual((inputTensor.Data as int[,])[0, 1], 2);
Assert.AreEqual((inputTensor.Data as int[,])[1, 0], 3);
Assert.AreEqual((inputTensor.Data as int[,])[1, 1], 4);
}
[Test]
public void GenerateActionMaskInput()
{
var inputTensor = new Tensor()
{
Shape = new long[] {2, 5},
ValueType = Tensor.TensorType.FloatingPoint
};
var batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var generator = new ActionMaskInputGenerator();
generator.Generate(inputTensor, batchSize, agentInfos);
Assert.IsNotNull(inputTensor.Data as float[,]);
Assert.AreEqual((inputTensor.Data as float[,])[0, 0], 1);
Assert.AreEqual((inputTensor.Data as float[,])[0, 4], 1);
Assert.AreEqual((inputTensor.Data as float[,])[1, 0], 0);
Assert.AreEqual((inputTensor.Data as float[,])[1, 4], 1);
}
}
}

11
UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs.meta


fileFormatVersion: 2
guid: d2d2076c51c414ac7a91f8fbf15d4f7c
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

8
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain.meta


fileFormatVersion: 2
guid: 79c170c0af66140e68d7eca827f0d788
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

143
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs


using UnityEngine.MachineLearning.InferenceEngine;
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.MachineLearning.InferenceEngine.Util;
namespace MLAgents.InferenceBrain
{
/// <summary>
/// The Applier for the Continuous Action output tensor. Tensor is assumed to contain the
/// continuous action data of the agents in the batch.
/// </summary>
public class ContinuousActionOutputApplier : TensorApplier.Applier
{
public void Apply(Tensor tensor, Dictionary<Agent, AgentInfo> agentInfo)
{
var tensorDataAction = tensor.Data as float[,];
var actionSize = tensor.Shape[1];
var agentIndex = 0;
foreach (var agent in agentInfo.Keys)
{
var action = new float[actionSize];
for (var j = 0; j < actionSize; j++)
{
action[j] = tensorDataAction[agentIndex, j];
}
agent.UpdateVectorAction(action);
agentIndex++;
}
}
}
/// <summary>
/// The Applier for the Discrete Action output tensor. Uses multinomial to sample discrete
/// actions from the logits contained in the tensor.
/// </summary>
public class DiscreteActionOutputApplier : TensorApplier.Applier
{
private int[] _actionSize;
private Multinomial _multinomial;
public DiscreteActionOutputApplier(int[] actionSize, int seed)
{
_actionSize = actionSize;
_multinomial = new Multinomial(seed);
}
public void Apply(Tensor tensor, Dictionary<Agent, AgentInfo> agentInfo)
{
var tensorDataProbabilities = tensor.Data as float[,];
var batchSize = agentInfo.Keys.Count;
var actions = new float[batchSize, _actionSize.Length];
var startActionIndices = Utilities.CumSum(_actionSize);
for (var actionIndex=0; actionIndex < _actionSize.Length; actionIndex++)
{
var nBranchAction = _actionSize[actionIndex];
var actionProbs = new float[batchSize, nBranchAction];
for (var batchIndex = 0; batchIndex < batchSize; batchIndex++)
{
for (var branchActionIndex = 0;
branchActionIndex < nBranchAction;
branchActionIndex++)
{
actionProbs[batchIndex, branchActionIndex] =
tensorDataProbabilities[
batchIndex, startActionIndices[actionIndex] + branchActionIndex];
}
}
var inputTensor = new Tensor()
{
ValueType = Tensor.TensorType.FloatingPoint,
Shape = new long[]{batchSize, _actionSize[actionIndex]},
Data = actionProbs
};
var outputTensor = new Tensor()
{
ValueType = Tensor.TensorType.FloatingPoint,
Shape = new long[]{batchSize, 1},
Data = new float[batchSize, 1]
};
_multinomial.Eval(inputTensor, outputTensor);
var outTensor = outputTensor.Data as float[,];
for (var ii = 0; ii < batchSize; ii++)
{
actions[ii, actionIndex] = outTensor[ii, 0];
}
}
var agentIndex = 0;
foreach (var agent in agentInfo.Keys)
{
var action = new float[_actionSize.Length];
for (var j = 0; j < _actionSize.Length; j++)
{
action[j] = actions[agentIndex, j];
}
agent.UpdateVectorAction(action);
agentIndex++;
}
}
}
/// <summary>
/// The Applier for the Memory output tensor. Tensor is assumed to contain the new
/// memory data of the agents in the batch.
/// </summary>
public class MemoryOutputApplier : TensorApplier.Applier
{
public void Apply(Tensor tensor, Dictionary<Agent, AgentInfo> agentInfo)
{
var tensorDataMemory = tensor.Data as float[,];
var agentIndex = 0;
var memorySize = tensor.Shape[1];
foreach (var agent in agentInfo.Keys)
{
var memory = new List<float>();
for (var j = 0; j < memorySize; j++)
{
memory.Add(tensorDataMemory[agentIndex, j]);
}
agent.UpdateMemoriesAction(memory);
agentIndex++;
}
}
}
/// <summary>
/// The Applier for the Value Estimate output tensor. Tensor is assumed to contain the
/// value estimates of the agents in the batch.
/// </summary>
public class ValueEstimateApplier : TensorApplier.Applier
{
public void Apply(Tensor tensor, Dictionary<Agent, AgentInfo> agentInfo)
{
var tensorDataValue = tensor.Data as float[,];
var agentIndex = 0;
foreach (var agent in agentInfo.Keys)
{
agent.UpdateValueAction(tensorDataValue[agentIndex, 0]);
agentIndex++;
}
}
}
}
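All four appliers above share the same TensorApplier.Applier contract: read a batch-major array out of tensor.Data and push one row per agent back through the matching Agent update method. As a hedged illustration of that pattern, and not part of this change, a hypothetical applier that clamps continuous actions before handing them to the agents could look like this:
using System.Collections.Generic;
using UnityEngine;
using UnityEngine.MachineLearning.InferenceEngine;

namespace MLAgents.InferenceBrain
{
    // Hypothetical example only: clamps each continuous action to [-1, 1]
    // before applying it, following the same pattern as the appliers above.
    public class ClampedContinuousActionApplier : TensorApplier.Applier
    {
        public void Apply(Tensor tensor, Dictionary<Agent, AgentInfo> agentInfo)
        {
            var tensorDataAction = tensor.Data as float[,];
            var actionSize = tensor.Shape[1];
            var agentIndex = 0;
            foreach (var agent in agentInfo.Keys)
            {
                var action = new float[actionSize];
                for (var j = 0; j < actionSize; j++)
                {
                    action[j] = Mathf.Clamp(tensorDataAction[agentIndex, j], -1f, 1f);
                }
                agent.UpdateVectorAction(action);
                agentIndex++;
            }
        }
    }
}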

3
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs.meta


fileFormatVersion: 2
guid: 99d5dc2d52e442d1a1f466a246cfb28d
timeCreated: 1539118675

224
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs


using UnityEngine.MachineLearning.InferenceEngine;
using System.Collections.Generic;
using System;
using UnityEngine.MachineLearning.InferenceEngine.Util;
using System.Linq;
namespace MLAgents.InferenceBrain
{
/// <summary>
/// Reshapes a Tensor so that its first dimension becomes equal to the current batch size
/// and initializes its content to be zeros. Will only work on 2-dimensional tensors.
/// The second dimension of the Tensor will not be modified.
/// </summary>
public class BiDimensionalOutputGenerator : TensorGenerator.Generator
{
public void Generate(Tensor tensor, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
{
var shapeSecondAxis = tensor.Shape[1];
tensor.Shape[0] = batchSize;
if (tensor.ValueType == Tensor.TensorType.FloatingPoint)
{
tensor.Data = new float[batchSize, shapeSecondAxis];
}
else
{
tensor.Data = new int[batchSize, shapeSecondAxis];
}
}
}
/// <summary>
/// Generates the Tensor corresponding to the BatchSize input : Will be a one dimensional
/// integer array of size 1 containing the batch size.
/// </summary>
public class BatchSizeGenerator : TensorGenerator.Generator
{
public void Generate(Tensor tensor, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
{
tensor.Data = new int[] {batchSize};
}
}
/// <summary>
/// Generates the Tensor corresponding to the SequenceLength input : Will be a one
/// dimensional integer array of size 1 containing 1.
/// Note : the sequence length is always one since recurrent networks only predict for
/// one step at a time.
/// </summary>
public class SequenceLengthGenerator : TensorGenerator.Generator
{
public void Generate(Tensor tensor, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
{
tensor.Data = new int[] {1};
}
}
/// <summary>
/// Generates the Tensor corresponding to the VectorObservation input : Will be a two
/// dimensional float array of dimension [batchSize x vectorObservationSize].
/// It will use the Vector Observation data contained in the agentInfo to fill the data
/// of the tensor.
/// </summary>
public class VectorObservationGenerator : TensorGenerator.Generator
{
public void Generate(Tensor tensor, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
{
tensor.Shape[0] = batchSize;
var vecObsSizeT = tensor.Shape[1];
tensor.Data = new float[batchSize, vecObsSizeT];
var agentIndex = 0;
foreach (var agent in agentInfo.Keys)
{
var vectorObs = agentInfo[agent].stackedVectorObservation;
for (var j = 0; j < vecObsSizeT; j++)
{
tensor.Data.SetValue(vectorObs[j], new int[2] {agentIndex, j});
}
agentIndex++;
}
}
}
/// <summary>
/// Generates the Tensor corresponding to the Recurrent input : Will be a two
/// dimensional float array of dimension [batchSize x memorySize].
/// It will use the Memory data contained in the agentInfo to fill the data
/// of the tensor.
/// </summary>
public class RecurrentInputGenerator : TensorGenerator.Generator
{
public void Generate(Tensor tensor, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
{
tensor.Shape[0] = batchSize;
var memorySize = tensor.Shape[1];
tensor.Data = new float[batchSize, memorySize];
var agentIndex = 0;
foreach (var agent in agentInfo.Keys)
{
var memory = agentInfo[agent].memories;
if (memory == null)
{
agentIndex++;
continue;
}
for (var j = 0; j < Math.Min(memorySize, memory.Count); j++)
{
if (j >= memory.Count)
{
break;
}
tensor.Data.SetValue(memory[j], new int[2] {agentIndex, j});
}
agentIndex++;
}
}
}
/// <summary>
/// Generates the Tensor corresponding to the Previous Action input : Will be a two
/// dimensional integer array of dimension [batchSize x actionSize].
/// It will use the previous action data contained in the agentInfo to fill the data
/// of the tensor.
/// </summary>
public class PreviousActionInputGenerator : TensorGenerator.Generator
{
public void Generate(Tensor tensor, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
{
if (tensor.ValueType != Tensor.TensorType.Integer)
{
throw new NotImplementedException(
"Previous Action Inputs are only valid for discrete control");
}
tensor.Shape[0] = batchSize;
var actionSize = tensor.Shape[1];
tensor.Data = new int[batchSize, actionSize];
var agentIndex = 0;
foreach (var agent in agentInfo.Keys)
{
var pastAction = agentInfo[agent].storedVectorActions;
for (var j = 0; j < actionSize; j++)
{
tensor.Data.SetValue((int) pastAction[j], new int[2] {agentIndex, j});
}
agentIndex++;
}
}
}
/// <summary>
/// Generates the Tensor corresponding to the Action Mask input : Will be a two
/// dimensional float array of dimension [batchSize x numActionLogits].
/// It will use the Action Mask data contained in the agentInfo to fill the data
/// of the tensor.
/// </summary>
public class ActionMaskInputGenerator : TensorGenerator.Generator
{
public void Generate(Tensor tensor, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
{
tensor.Shape[0] = batchSize;
var maskSize = tensor.Shape[1];
tensor.Data = new float[batchSize, maskSize];
var agentIndex = 0;
foreach (var agent in agentInfo.Keys)
{
var maskList = agentInfo[agent].actionMasks;
for (var j = 0; j < maskSize; j++)
{
var isUnmasked = (maskList != null && maskList[j]) ? 0.0f : 1.0f;
tensor.Data.SetValue(isUnmasked, new int[2] {agentIndex, j});
}
agentIndex++;
}
}
}
/// <summary>
/// Generates the Tensor corresponding to the Epsilon input : Will be a two
/// dimensional float array of dimension [batchSize x actionSize].
/// It will generate random input data from a normal distribution.
/// </summary>
public class RandomNormalInputGenerator : TensorGenerator.Generator
{
private RandomNormal _randomNormal;
public RandomNormalInputGenerator(int seed)
{
_randomNormal = new RandomNormal(seed);
}
public void Generate(Tensor tensor, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
{
tensor.Shape[0] = batchSize;
var actionSize = tensor.Shape[1];
tensor.Data = new float[batchSize, actionSize];
_randomNormal.FillTensor(tensor);
}
}
/// <summary>
/// Generates the Tensor corresponding to the Visual Observation input : Will be a 4
/// dimensional float array of dimension [batchSize x width x height x numChannels].
/// It will use the Texture input data contained in the agentInfo to fill the data
/// of the tensor.
/// </summary>
public class VisualObservationInputGenerator : TensorGenerator.Generator
{
private int _index;
private bool _grayScale;
public VisualObservationInputGenerator(int index, bool grayScale)
{
_index = index;
_grayScale = grayScale;
}
public void Generate(Tensor tensor, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
{
var textures = agentInfo.Keys.Select(
agent => agentInfo[agent].visualObservations[_index]).ToList();
tensor.Data = Utilities.TextureToFloatArray(textures, _grayScale);
}
}
}
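The generators above all implement the same TensorGenerator.Generator interface: resize the tensor's first dimension to the batch size and fill tensor.Data either from the AgentInfo batch or with constant data. As a hedged sketch, and not part of this change, a generator for a fixed extra placeholder such as a noise level (the kind of input the old graphPlaceholders array handled) might look like this; the class name and constant-value behaviour are illustrative assumptions.
using System.Collections.Generic;
using UnityEngine.MachineLearning.InferenceEngine;

namespace MLAgents.InferenceBrain
{
    // Hypothetical example only: fills a [batchSize x 1] float input with a
    // constant value, mirroring the two-dimensional generators above.
    public class ConstantFloatInputGenerator : TensorGenerator.Generator
    {
        private readonly float _value;

        public ConstantFloatInputGenerator(float value)
        {
            _value = value;
        }

        public void Generate(Tensor tensor, int batchSize, Dictionary<Agent, AgentInfo> agentInfo)
        {
            // Assumes tensor.Shape was populated by the ModelParamLoader,
            // as for the generators above.
            tensor.Shape[0] = batchSize;
            var data = new float[batchSize, 1];
            for (var i = 0; i < batchSize; i++)
            {
                data[i, 0] = _value;
            }
            tensor.Data = data;
        }
    }
}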

3
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs.meta


fileFormatVersion: 2
guid: c57a4989c7e54b93ab56293698d7d237
timeCreated: 1539109542

531
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ModelParamLoader.cs


using UnityEngine.MachineLearning.InferenceEngine;
using UnityEngine;
using System;
using System.Collections.Generic;
using System.Linq;
namespace MLAgents.InferenceBrain
{
/// <summary>
/// Prepares the Tensors for the Learning Brain and exposes a list of failed checks if Model
/// and BrainParameters are incompatible.
/// </summary>
public class ModelParamLoader
{
private enum ModelActionType
{
Unknown,
Discrete,
Continuous
}
private const long ApiVersion = 1;
private InferenceEngine _engine;
private BrainParameters _brainParameters;
private List<string> _failedModelChecks = new List<string>();
/// <summary>
/// Factory for the ModelParamLoader : Creates a ModelParamLoader and runs the checks
/// on it.
/// </summary>
/// <param name="engine"> The InferenceEngine we get the parameters and the checks from
/// </param>
/// <param name="brainParameters"> The BrainParameters that are used to verify the
/// compatibility with the InferenceEngine</param>
/// <returns></returns>
public static ModelParamLoader GetLoaderAndCheck(InferenceEngine engine,
BrainParameters brainParameters)
{
ModelParamLoader modelParamLoader = new ModelParamLoader(engine, brainParameters);
modelParamLoader.GenerateChecks();
return modelParamLoader;
}
private ModelParamLoader(InferenceEngine engine, BrainParameters brainParameters)
{
_engine = engine;
_brainParameters = brainParameters;
}
/// <summary>
/// Generates the Tensor inputs that are expected to be present in the Model.
/// </summary>
/// <returns>Tensor IEnumerable with the expected Tensor inputs</returns>
public IEnumerable<Tensor> GetInputTensors()
{
return _engine?.InputFeatures();
}
/// <summary>
/// Generates the Tensor outputs that are expected to be present in the Model.
/// </summary>
/// <returns>Tensor IEnumerable with the expected Tensor outputs</returns>
public IEnumerable<Tensor> GetOutputTensors()
{
var tensorList = new List<Tensor>();
if (_brainParameters.vectorActionSpaceType == SpaceType.continuous)
{
tensorList.Add(new Tensor()
{
Name = TensorNames.ActionOutput,
Shape = new long[]
{
-1, _brainParameters.vectorActionSize[0]
},
ValueType = Tensor.TensorType.FloatingPoint,
Data = null
});
}
else
{
tensorList.Add(
new Tensor()
{
Name = TensorNames.ActionOutput,
Shape = new long[]
{
-1, _brainParameters.vectorActionSize.Sum()
},
ValueType = Tensor.TensorType.FloatingPoint,
Data = null
});
}
var memory = GetIntScalar(TensorNames.MemorySize);
if (memory > 0)
{
tensorList.Add(new Tensor()
{
Name = TensorNames.RecurrentOutput,
Shape = new long[2]
{
-1, memory
},
ValueType = Tensor.TensorType.FloatingPoint,
Data = null
});
}
return tensorList;
}
/// <summary>
/// Queries the InferenceEngine for the value of a variable in the graph given its name.
/// Only works with int32 Tensors with zero dimensions containing a unique element.
/// If the node was not found or could not be retrieved, the value -1 will be returned.
/// </summary>
/// <param name="name">The name of the Tensor variable</param>
/// <returns>The value of the scalar variable in the model. (-1 if not found)</returns>
private int GetIntScalar(string name)
{
var outputs = new Tensor[]
{
new Tensor()
{
Name = name,
ValueType = Tensor.TensorType.Integer,
Shape = new long[] { },
Data = new long[1]
},
};
try
{
_engine.ExecuteGraph(new Tensor[0], outputs);
}
catch (Exception e)
{
Debug.Log("Node not in graph: " + name + ". The following error occurred : \n" + e);
return -1;
}
return (outputs[0].Data as int[])[0];
}
/// <summary>
/// Retrieves an IEnumerable of string corresponding to the failed compatibility checks
/// between the InferenceEngine and the BrainParameters.
/// </summary>
public IEnumerable<string> GetChecks()
{
return _failedModelChecks;
}
/// <summary>
/// Generates the list of failed checks that failed when comparing the data from the Model
/// and from the BrainParameters
/// </summary>
private void GenerateChecks()
{
_failedModelChecks.Clear();
if (_engine == null)
{
_failedModelChecks.Add(
"There is no model for this Brain, cannot run inference. " +
"(But can still train)");
return;
}
var modelApiVersion = GetIntScalar(TensorNames.VersionNumber);
var memorySize = GetIntScalar(TensorNames.MemorySize);
var isContinuousInt = GetIntScalar(TensorNames.IsContinuousControl);
var isContinuous = GetActionType(isContinuousInt);
var actionSize = GetIntScalar(TensorNames.ActionOutputShape);
if (modelApiVersion == -1)
{
_failedModelChecks.Add(
"Model was not trained using the right version of ML-Agents. Cannot use this " +
"model.");
return;
}
if (modelApiVersion != ApiVersion)
{
_failedModelChecks.Add(
$"Version of the trainer the model was trained with ({modelApiVersion}) " +
$"is not compatible with the Brain's version ({ApiVersion}).");
return;
}
CheckIntScalarPresenceHelper(new Dictionary<string, int>()
{
{TensorNames.MemorySize, memorySize},
{TensorNames.IsContinuousControl, isContinuousInt},
{TensorNames.ActionOutputShape, actionSize}
});
CheckInputTensorPresence(memorySize, isContinuous);
CheckOutputTensorPresence(memorySize);
CheckInputTensorShape();
CheckOutputTensorShape(isContinuous, actionSize);
}
/// <summary>
/// Converts the integer value in the model corresponding to the type of control to a
/// ModelActionType.
/// </summary>
/// <param name="isContinuousInt"> The integer value in the model indicating the
/// type of control</param>
/// <returns>The equivalent ModelActionType</returns>
private static ModelActionType GetActionType(int isContinuousInt)
{
ModelActionType isContinuous;
switch (isContinuousInt)
{
case 0:
isContinuous = ModelActionType.Discrete;
break;
case 1:
isContinuous = ModelActionType.Continuous;
break;
default:
isContinuous = ModelActionType.Unknown;
break;
}
return isContinuous;
}
/// <summary>
/// Given a Dictionary of node names to int values, create checks if the values have the
/// invalid value of -1.
/// </summary>
/// <param name="requiredScalarFields"> Mapping from node names to int values</param>
private void CheckIntScalarPresenceHelper(Dictionary<string, int> requiredScalarFields)
{
foreach(var field in requiredScalarFields)
if (field.Value == -1)
{
_failedModelChecks.Add(
$"Missing node in the model provided : {field.Key}");
}
}
/// <summary>
/// Generates failed checks that correspond to inputs expected by the model that are not
/// present in the BrainParameters.
/// </summary>
/// <param name="memory"> The memory size that the model is expecting/</param>
/// <param name="isContinuous"> Whether the model is expecting continuous or
/// discrete control.</param>
/// <returns>An IEnumerable of string corresponding to the failed input presence
/// checks.</returns>
private void CheckInputTensorPresence(int memory, ModelActionType isContinuous)
{
var tensorsNames = GetInputTensors().Select(x => x.Name).ToList();
// If there is no Vector Observation Input but the Brain Parameters expect one.
if ((_brainParameters.vectorObservationSize != 0) &&
(!tensorsNames.Contains(TensorNames.VectorObservationPlacholder)))
{
_failedModelChecks.Add(
"The model does not contain a Vector Observation Placeholder Input. " +
"You must set the Vector Observation Space Size to 0.");
}
// If there are not enough Visual Observation Input compared to what the
// Brain Parameters expect.
for (var visObsIndex = 0;
visObsIndex < _brainParameters.cameraResolutions.Length;
visObsIndex++)
{
if (!tensorsNames.Contains(
TensorNames.VisualObservationPlaceholderPrefix + visObsIndex))
{
_failedModelChecks.Add(
"The model does not contain a Visual Observation Placeholder Input " +
"for visual observation "+visObsIndex+".");
}
}
// If the model has a non-negative memory size but requires a recurrent input
if (memory > 0)
{
if (!tensorsNames.Contains(TensorNames.RecurrentInPlaceholder))
{
_failedModelChecks.Add(
"The model does not contain a Recurrent Input Node but has memory_size.");
}
}
// If the model uses discrete control but does not have an input for action masks
if (isContinuous == ModelActionType.Discrete)
{
if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder))
{
_failedModelChecks.Add(
"The model does not contain an Action Mask but is using Discrete Control.");
}
}
}
/// <summary>
/// Generates failed checks that correspond to outputs expected by the model that are not
/// present in the BrainParameters.
/// </summary>
/// <param name="memory"> The memory size that the model is expecting/</param>
/// <returns>An IEnumerable of string corresponding to the failed output presence
/// checks.</returns>
private void CheckOutputTensorPresence(int memory)
{
var tensorsNames = GetOutputTensors().Select(x => x.Name).ToList();
// If there is no Action Output.
if (!tensorsNames.Contains(TensorNames.ActionOutput))
{
_failedModelChecks.Add("The model does not contain an Action Output Node.");
}
// If there is no Recurrent Output but the model is Recurrent.
if (memory > 0)
{
if (!tensorsNames.Contains(TensorNames.RecurrentOutput))
{
_failedModelChecks.Add(
"The model does not contain a Recurrent Output Node but has memory_size.");
}
}
}
/// <summary>
/// Generates failed checks that correspond to inputs shapes incompatibilities between
/// the model and the BrainParameters.
/// </summary>
private void CheckInputTensorShape()
{
var tensorTester =
new Dictionary<string, Func<Tensor, string>>()
{
{TensorNames.VectorObservationPlacholder, CheckVectorObsShape},
{TensorNames.PreviousActionPlaceholder, CheckPreviousActionShape},
{TensorNames.RandomNormalEpsilonPlaceholder, ((tensor) => null)},
{TensorNames.ActionMaskPlaceholder, ((tensor) => null)},
{TensorNames.SequenceLengthPlaceholder, ((tensor) => null)},
{TensorNames.RecurrentInPlaceholder, ((tensor) => null)},
};
for (var obsIndex = 0; obsIndex < _brainParameters.cameraResolutions.Length; obsIndex++)
{
var index = obsIndex;
tensorTester[TensorNames.VisualObservationPlaceholderPrefix + obsIndex] =
(tensor) => CheckVisualObsShape(tensor, index);
}
// If the model expects an input but it is not in this list
foreach (var tensor in GetInputTensors())
{
if (!tensorTester.ContainsKey(tensor.Name))
{
_failedModelChecks.Add(
"Model requires an unknown input named : " + tensor.Name);
}
else
{
var tester = tensorTester[tensor.Name];
var error = tester.Invoke(tensor);
if (error != null)
{
_failedModelChecks.Add(error);
}
}
}
}
/// <summary>
/// Checks that the shape of the Vector Observation input placeholder is the same in the
/// model and in the Brain Parameters.
/// </summary>
/// <param name="tensor"> The tensor that is expected by the model</param>
/// <returns>If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.</returns>
private string CheckVectorObsShape(Tensor tensor)
{
var vecObsSizeBp = _brainParameters.vectorObservationSize;
var numStackedVector = _brainParameters.numStackedVectorObservations;
var totalVecObsSizeT = tensor.Shape[1];
if (vecObsSizeBp * numStackedVector != totalVecObsSizeT)
{
return string.Format(
"Vector Observation Size of the model does not match. " +
"Received {0} x {1} but was expecting {2}.",
vecObsSizeBp, numStackedVector, totalVecObsSizeT);
}
return null;
}
/// <summary>
/// Checks that the shape of the Previous Vector Action input placeholder is the same in the
/// model and in the Brain Parameters.
/// </summary>
/// <param name="tensor"> The tensor that is expected by the model</param>
/// <returns>If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.</returns>
private string CheckPreviousActionShape(Tensor tensor)
{
var numberActionsBp = _brainParameters.vectorActionSize.Length;
var numberActionsT = tensor.Shape[1];
if (numberActionsBp != numberActionsT)
{
return string.Format(
"Previous Action Size of the model does not match. " +
"Received {0} but was expecting {1}.",
numberActionsBp, numberActionsT);
}
return null;
}
/// <summary>
/// Checks that the shape of the visual observation input placeholder is the same in the
/// model and in the Brain Parameters.
/// </summary>
/// <param name="tensor"> The tensor that is expected by the model</param>
/// <param name="visObsIndex"> The index of the visual observation.</param>
/// <returns>If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.</returns>
private string CheckVisualObsShape(Tensor tensor, int visObsIndex)
{
var resolutionBp = _brainParameters.cameraResolutions[visObsIndex];
var widthBp = resolutionBp.width;
var heightBp = resolutionBp.height;
var pixelBp = resolutionBp.blackAndWhite ? 1 : 3;
var widthT = tensor.Shape[1];
var heightT = tensor.Shape[2];
var pixelT = tensor.Shape[3];
if ((widthBp != widthT) || (heightBp != heightT) || (pixelBp != pixelT))
{
return string.Format(
"The visual Observation {0} of the model does not match. " +
"Received Tensor of shape [?x{1}x{2}x{3}] but was expecting [?x{4}x{5}x{6}].",
visObsIndex, widthBp, heightBp, pixelBp, widthT, heightT, pixelT);
}
return null;
}
/// <summary>
/// Generates failed checks that correspond to output shape incompatibilities between
/// the model and the BrainParameters.
/// </summary>
/// <param name="isContinuous"> Whether the model is expecting continuous or
/// discrete control.</param>
/// <param name="modelActionSize"> The size of the action output that is expected
/// by the model.</param>
/// <returns>An IEnumerable of strings corresponding to the incompatible shapes between
/// model and BrainParameters.</returns>
private void CheckOutputTensorShape(ModelActionType isContinuous, int modelActionSize)
{
if (isContinuous == ModelActionType.Unknown)
{
_failedModelChecks.Add(
"Cannot infer type of Control from the provided model.");
return;
}
if (isContinuous == ModelActionType.Continuous &&
_brainParameters.vectorActionSpaceType != SpaceType.continuous)
{
_failedModelChecks.Add(
"Model has been trained using Continuous Control but the Brain Parameters " +
"suggest Discrete Control.");
return;
}
if (isContinuous == ModelActionType.Discrete &&
_brainParameters.vectorActionSpaceType != SpaceType.discrete)
{
_failedModelChecks.Add(
"Model has been trained using Discrete Control but the Brain Parameters " +
"suggest Continuous Control.");
return;
}
var tensorTester = new Dictionary<string, Func<Tensor, int, string>>();
if (_brainParameters.vectorActionSpaceType == SpaceType.continuous)
{
tensorTester[TensorNames.ActionOutput] = CheckContinuousActionOutputShape;
}
else
{
tensorTester[TensorNames.ActionOutput] = CheckDiscreteActionOutputShape;
}
// If the model expects an output but it is not in this list
foreach (var tensor in GetOutputTensors())
{
if (tensorTester.ContainsKey(tensor.Name))
{
var tester = tensorTester[tensor.Name];
var error = tester.Invoke(tensor, modelActionSize);
if (error != null)
{
_failedModelChecks.Add(error);
}
}
}
}
/// <summary>
/// Checks that the shape of the discrete action output is the same in the
/// model and in the Brain Parameters.
/// </summary>
/// <param name="tensor"> The tensor that is expected by the model</param>
/// <param name="modelActionSize"> The size of the action output that is expected
/// by the model.</param>
/// <returns>If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.</returns>
private string CheckDiscreteActionOutputShape(Tensor tensor, int modelActionSize)
{
var bpActionSize = _brainParameters.vectorActionSize.Sum();
if (modelActionSize != bpActionSize)
{
return string.Format(
"Action Size of the model does not match. " +
"The BrainParameters expect {0} but the model contains {1}.",
bpActionSize, modelActionSize);
}
return null;
}
/// <summary>
/// Checks that the shape of the continuous action output is the same in the
/// model and in the Brain Parameters.
/// </summary>
/// <param name="tensor"> The tensor that is expected by the model</param>
/// <param name="modelActionSize"> The size of the action output that is expected
/// by the model.</param>
/// <returns>If the Check failed, returns a string containing information about why the
/// check failed. If the check passed, returns null.</returns>
private string CheckContinuousActionOutputShape(Tensor tensor, int modelActionSize)
{
var bpActionSize = _brainParameters.vectorActionSize[0];
if (modelActionSize != bpActionSize)
{
return string.Format(
"Action Size of the model does not match. " +
"The BrainParameters expect {0} but the model contains {1}.",
bpActionSize, modelActionSize);
}
return null;
}
}
}
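To make the pattern above concrete, here is a minimal standalone sketch of the same "tensor name to check function" dictionary that CheckInputTensorShape builds: each check returns an error string, or null when it passes, and unknown inputs produce their own failure message. The ShapeCheck delegate, the hard-coded expected size and the RunChecks entry point are illustrative assumptions for this sketch and are not part of the ML-Agents API.

using System;
using System.Collections.Generic;

public static class ShapeCheckSketch
{
    // A check receives the shape the model declares for an input and returns an
    // error string, or null when the check passes (mirroring the Func<Tensor, string>
    // pattern used by ModelParamLoader above).
    private delegate string ShapeCheck(long[] modelShape);

    public static List<string> RunChecks(Dictionary<string, long[]> modelInputs)
    {
        var failures = new List<string>();
        var checks = new Dictionary<string, ShapeCheck>
        {
            // Expect a [batch, 8] vector observation (8 is an arbitrary example size).
            ["vector_observation"] = shape =>
                shape[1] == 8 ? null : string.Format(
                    "Vector Observation Size of the model does not match. " +
                    "Expected 8 but the model contains {0}.", shape[1]),
            // Placeholders that need no shape validation simply return null.
            ["epsilon"] = shape => null,
        };
        foreach (var input in modelInputs)
        {
            if (!checks.ContainsKey(input.Key))
            {
                failures.Add("Model requires an unknown input named : " + input.Key);
            }
            else
            {
                var error = checks[input.Key](input.Value);
                if (error != null)
                {
                    failures.Add(error);
                }
            }
        }
        return failures;
    }
}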

3
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ModelParamLoader.cs.meta


fileFormatVersion: 2
guid: 259e3a0e37204794a885219327bd4c02
timeCreated: 1539197357

80
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs


using UnityEngine.MachineLearning.InferenceEngine;
using System.Collections.Generic;
using UnityEngine.MachineLearning.InferenceEngine.Util;
using System;
namespace MLAgents.InferenceBrain
{
/// <summary>
/// Mapping between the output Tensor names and the Appliers that will use the
/// output tensors and the Agents present in the batch to update their actions, memories and
/// value estimates.
/// A TensorApplier implements a Dictionary of strings (node names) to Appliers.
/// Each Applier takes as input the output Tensor and the Dictionary of Agent to AgentInfo
/// for the current batch.
/// </summary>
public class TensorApplier
{
/// <summary>
/// An Applier's Apply method takes a Tensor and a Dictionary of Agent to AgentInfo and
/// uses the data contained inside the Tensor to modify the state of the Agents. The Tensors
/// are assumed to have the batch size on the first dimension and the agents to be ordered
/// the same way in the dictionary and in the Tensor.
/// </summary>
public interface Applier
{
/// <summary>
/// Applies the values in the Tensor to the Agents present in the agentInfos
/// </summary>
/// <param name="tensor"> The Tensor containing the data to be applied to the Agents</param>
/// <param name="agentInfo"> Dictionary of Agent to AgentInfo that will receive
/// the values of the Tensor.</param>
void Apply(Tensor tensor, Dictionary<Agent, AgentInfo> agentInfo);
}
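// Maps each output tensor name to the Applier that knows how to consume it.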
Dictionary<string, Applier> _dict = new Dictionary<string, Applier>();
/// <summary>
/// Returns a new TensorApplier object.
/// </summary>
/// <param name="bp"> The BrainParameters used to determine what Appliers will be
/// used</param>
/// <param name="seed"> The seed the Appliers will be initialized with.</param>
public TensorApplier(BrainParameters bp, int seed)
{
_dict[TensorNames.ValueEstimateOutput] = new ValueEstimateApplier();
if (bp.vectorActionSpaceType == SpaceType.continuous)
{
_dict[TensorNames.ActionOutput] = new ContinuousActionOutputApplier();
}
else
{
_dict[TensorNames.ActionOutput] = new DiscreteActionOutputApplier(
bp.vectorActionSize, seed);
}
_dict[TensorNames.RecurrentOutput] = new MemoryOutputApplier();
}
/// <summary>
/// Updates the state of the agents based on the data present in the tensor.
/// </summary>
/// <param name="tensors"> Enumerable of tensors containing the data.</param>
/// <param name="agentInfos"> Dictionary of Agent to AgentInfo that contains the
/// Agents that will be updated using the tensor's data</param>
/// <exception cref="UnityAgentsException"> One of the tensors does not have an
/// associated applier.</exception>
public void ApplyTensors(
IEnumerable<Tensor> tensors, Dictionary<Agent, AgentInfo> agentInfos)
{
foreach (var tensor in tensors)
{
if (!_dict.ContainsKey(tensor.Name))
{
throw new UnityAgentsException(
"Unknown tensor expected as output : " + tensor.Name);
}
_dict[tensor.Name].Apply(tensor, agentInfos);
}
}
}
}
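The Applier interface above is the extension point: one class per output tensor name, registered in the constructor's dictionary. Below is a hedged sketch of what such a class can look like. The float[,] layout of Tensor.Data and the Agent.UpdateValueAction call are assumptions about the surrounding ML-Agents and InferenceEngine code, not something this file guarantees.

using System.Collections.Generic;
using UnityEngine.MachineLearning.InferenceEngine;

namespace MLAgents.InferenceBrain
{
    // Hedged sketch of a custom Applier for a [batchSize, 1] output tensor.
    public class ScalarOutputApplierSketch : TensorApplier.Applier
    {
        public void Apply(Tensor tensor, Dictionary<Agent, AgentInfo> agentInfo)
        {
            // Assumption: the inference engine exposes the batched output as float[,],
            // indexed [agentIndex, valueIndex], with agents in dictionary order.
            var data = tensor.Data as float[,];
            var agentIndex = 0;
            foreach (var agent in agentInfo.Keys)
            {
                // Assumption: Agent exposes UpdateValueAction, as the ValueEstimateApplier
                // elsewhere in this PR is expected to use it.
                agent.UpdateValueAction(data[agentIndex, 0]);
                agentIndex++;
            }
        }
    }
}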

11
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs.meta


fileFormatVersion: 2
guid: d1bef4f4ae72645108f16614355473e8
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

99
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs


using UnityEngine.MachineLearning.InferenceEngine;
using System.Collections.Generic;
using UnityEngine.MachineLearning.InferenceEngine.Util;
using System.Linq;
using System;
namespace MLAgents.InferenceBrain
{
/// <summary>
/// Mapping between Tensor names and generators.
/// A TensorGenerator implements a Dictionary of strings (node names) to Generators.
/// Each Generator takes as arguments the tensor, the current batch size and a Dictionary of
/// Agent to AgentInfo corresponding to the current batch.
/// Each Generator reshapes and fills the data of the tensor based on the data of the batch.
/// When the Tensor is an Input to the model, the shape of the Tensor will be modified
/// depending on the current batch size and the data of the Tensor will be filled using the
/// Dictionary of Agent to AgentInfo.
/// When the Tensor is an Output of the model, only the shape of the Tensor will be modified
/// using the current batch size. The data will be prefilled with zeros.
/// </summary>
public class TensorGenerator
{
public interface Generator
{
/// <summary>
/// Modifies the data inside a Tensor according to the information contained in the
/// AgentInfos of the current batch.
/// </summary>
/// <param name="tensor"> The tensor whose data and shape will be modified</param>
/// <param name="batchSize"> The number of agents present in the current batch</param>
/// <param name="agentInfo"> Dictionary of Agent to AgentInfo containing the
/// information that will be used to populate the tensor's data</param>
void Generate(Tensor tensor, int batchSize, Dictionary<Agent, AgentInfo> agentInfo);
}
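// Maps each tensor name to the Generator that knows how to resize and fill it.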
Dictionary<string, Generator> _dict = new Dictionary<string, Generator>();
/// <summary>
/// Returns a new TensorGenerator object.
/// </summary>
/// <param name="bp"> The BrainParameters used to determine what Generators will be
/// used</param>
/// <param name="seed"> The seed the Generators will be initialized with.</param>
public TensorGenerator(BrainParameters bp, int seed)
{
// Generator for Inputs
_dict[TensorNames.BatchSizePlaceholder] = new BatchSizeGenerator();
_dict[TensorNames.SequenceLengthPlaceholder] = new SequenceLengthGenerator();
_dict[TensorNames.VectorObservationPlacholder] = new VectorObservationGenerator();
_dict[TensorNames.RecurrentInPlaceholder] = new RecurrentInputGenerator();
_dict[TensorNames.PreviousActionPlaceholder] = new PreviousActionInputGenerator();
_dict[TensorNames.ActionMaskPlaceholder] = new ActionMaskInputGenerator();
_dict[TensorNames.RandomNormalEpsilonPlaceholder] = new RandomNormalInputGenerator(seed);
if (bp.cameraResolutions != null)
{
for (var visIndex = 0;
visIndex < bp.cameraResolutions.Length;
visIndex++)
{
var index = visIndex;
var bw = bp.cameraResolutions[visIndex].blackAndWhite;
_dict[TensorNames.VisualObservationPlaceholderPrefix + visIndex] = new
VisualObservationInputGenerator(index, bw);
}
}
// Generators for Outputs
_dict[TensorNames.ActionOutput] = new BiDimensionalOutputGenerator();
_dict[TensorNames.RecurrentOutput] = new BiDimensionalOutputGenerator();
_dict[TensorNames.ValueEstimateOutput] = new BiDimensionalOutputGenerator();
}
/// <summary>
/// Populates the data of the tensor inputs given the data contained in the current batch
/// of agents.
/// </summary>
/// <param name="tensors"> Enumerable of tensors that will be modified.</param>
/// <param name="currentBatchSize"> The number of agents present in the current batch
/// </param>
/// <param name="agentInfos"> Dictionary of Agent to AgentInfo that contains the
/// data that will be used to modify the tensors</param>
/// <exception cref="UnityAgentsException"> One of the tensors does not have an
/// associated generator.</exception>
public void GenerateTensors(IEnumerable<Tensor> tensors,
int currentBatchSize,
Dictionary<Agent, AgentInfo> agentInfos)
{
foreach (var tensor in tensors)
{
if (!_dict.ContainsKey(tensor.Name))
{
throw new UnityAgentsException(
"Unknown tensor expected as input : " + tensor.Name);
}
_dict[tensor.Name].Generate(tensor, currentBatchSize, agentInfos);
}
}
}
}
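Taken together, the generator and applier give the LearningBrain a three-step loop per decision batch: fill the input tensors (and zero the outputs), run the model, then write the outputs back onto the agents. The sketch below shows that wiring; the commented-out ExecuteGraph call stands in for whatever the InferenceEngine actually exposes and is an assumption, as are the class and method names.

using System.Collections.Generic;
using System.Linq;
using UnityEngine.MachineLearning.InferenceEngine;

namespace MLAgents.InferenceBrain
{
    // Hedged sketch of the generate -> execute -> apply loop a LearningBrain is expected
    // to run for each batch of decisions. Only TensorGenerator and TensorApplier come
    // from this PR; the inference call itself is left as a placeholder.
    public static class InferenceLoopSketch
    {
        public static void DecideBatch(
            BrainParameters bp,
            IList<Tensor> inputs,
            IList<Tensor> outputs,
            Dictionary<Agent, AgentInfo> agentInfos,
            int seed)
        {
            var batchSize = agentInfos.Count;

            // 1. Resize and fill every input tensor, and pre-fill the outputs with zeros.
            var generator = new TensorGenerator(bp, seed);
            generator.GenerateTensors(inputs.Concat(outputs), batchSize, agentInfos);

            // 2. Run the model (placeholder; depends on the inference engine API).
            // engine.ExecuteGraph(inputs, outputs);

            // 3. Write actions, memories and value estimates back onto the agents.
            var applier = new TensorApplier(bp, seed);
            applier.ApplyTensors(outputs, agentInfos);
        }
    }
}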

3
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs.meta


fileFormatVersion: 2
guid: 6a24e86bc77c4a5088a5fd04d6d30e81
timeCreated: 1537484304

25
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorNames.cs


namespace MLAgents.InferenceBrain
{
/// <summary>
/// Contains the names of the input and output Tensor for the Inference Brain.
/// </summary>
public static class TensorNames
{
public const string BatchSizePlaceholder = "batch_size";
public const string SequenceLengthPlaceholder = "sequence_length";
public const string VectorObservationPlacholder = "vector_observation";
public const string RecurrentInPlaceholder = "recurrent_in";
public const string VisualObservationPlaceholderPrefix = "visual_observation_";
public const string PreviousActionPlaceholder = "prev_action";
public const string ActionMaskPlaceholder = "action_masks";
public const string RandomNormalEpsilonPlaceholder = "epsilon";
public const string ValueEstimateOutput = "value_estimate";
public const string RecurrentOutput = "recurrent_out";
public const string MemorySize = "memory_size";
public const string VersionNumber = "version_number";
public const string IsContinuousControl = "is_continuous_control";
public const string ActionOutputShape = "action_output_shape";
public const string ActionOutput = "action";
}
}

3
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorNames.cs.meta


fileFormatVersion: 2
guid: b28a46ea97c2445794d29d5a8a718a4a
timeCreated: 1538158527