浏览代码

Merge branch 'master' into global-variables

/MLA-1734-demo-provider
Anupam Bhatnagar 4 年前
当前提交
abc1220f
共有 20 个文件被更改,包括 672 次插入90 次删除
  1. 2
      .yamato/com.unity.ml-agents-performance.yml
  2. 4
      com.unity.ml-agents.extensions/Editor/Unity.ML-Agents.Extensions.Editor.asmdef
  3. 1
      com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs
  4. 29
      com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
  5. 8
      com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
  6. 54
      com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
  7. 117
      com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
  8. 80
      com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
  9. 68
      com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
  10. 88
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
  11. 67
      com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
  12. 5
      com.unity.ml-agents/CHANGELOG.md
  13. 16
      ml-agents/mlagents/trainers/policy/tf_policy.py
  14. 2
      ml-agents/mlagents/trainers/stats.py
  15. 3
      ml-agents/mlagents/trainers/tests/mock_brain.py
  16. 113
      ml-agents/mlagents/trainers/tests/test_nn_policy.py
  17. 26
      ml-agents/mlagents/trainers/tf/models.py
  18. 2
      ml-agents/mlagents/trainers/trainer_controller.py
  19. 66
      com.unity.ml-agents.extensions/Editor/RigidBodySensorComponentEditor.cs
  20. 11
      com.unity.ml-agents.extensions/Editor/RigidBodySensorComponentEditor.cs.meta

2
.yamato/com.unity.ml-agents-performance.yml


variables:
UNITY_VERSION: {{ editor.version }}
commands:
- python -m pip install unity-downloader-cli --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple
- python -m pip install unity-downloader-cli --index-url https://artifactory.prd.it.unity3d.com/artifactory/api/pypi/pypi/simple --upgrade
- unity-downloader-cli -u {{ editor.version }} -c editor --wait --fast
- curl -s https://artifactory.internal.unity3d.com/core-automation/tools/utr-standalone/utr --output utr
- chmod +x ./utr

4
com.unity.ml-agents.extensions/Editor/Unity.ML-Agents.Extensions.Editor.asmdef


{
"name": "Unity.ML-Agents.Extensions.Editor",
"references": [
"Unity.ML-Agents.Extensions"
"Unity.ML-Agents.Extensions",
"Unity.ML-Agents",
"Unity.ML-Agents.Editor"
],
"includePlatforms": [
"Editor"

1
com.unity.ml-agents.extensions/Runtime/AssemblyInfo.cs


using System.Runtime.CompilerServices;
[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.EditorTests")]
[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.Editor")]

29
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs


return new Pose { rotation = t.rotation, position = t.position };
}
/// <inheritdoc/>
protected internal override Object GetObjectAt(int index)
{
return m_Bodies[index];
}
internal IEnumerable<ArticulationBody> GetEnabledArticulationBodies()
{
if (m_Bodies == null)
{
yield break;
}
for (var i = 0; i < m_Bodies.Length; i++)
{
var articBody = m_Bodies[i];
if (articBody == null)
{
// Ignore a virtual root.
continue;
}
if (IsPoseEnabled(i))
{
yield return articBody;
}
}
}
}
}
#endif // UNITY_2020_1_OR_NEWER

8
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs


var poseExtractor = new ArticulationBodyPoseExtractor(RootBody);
var numPoseObservations = poseExtractor.GetNumPoseObservations(Settings);
var numJointObservations = 0;
// Start from i=1 to ignore the root
for (var i = 1; i < poseExtractor.Bodies.Length; i++)
foreach(var articBody in poseExtractor.GetEnabledArticulationBodies())
numJointObservations += ArticulationBodyJointExtractor.NumObservations(
poseExtractor.Bodies[i], Settings
);
numJointObservations += ArticulationBodyJointExtractor.NumObservations(articBody, Settings);
}
return new[] { numPoseObservations + numJointObservations };
}

54
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs


using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;

string m_SensorName;
PoseExtractor m_PoseExtractor;
IJointExtractor[] m_JointExtractors;
List<IJointExtractor> m_JointExtractors;
/// Construct a new PhysicsBodySensor
/// Construct a new PhysicsBodySensor
/// <param name="rootBody">The root Rigidbody. This has no Joints on it (but other Joints may connect to it).</param>
/// <param name="rootGameObject">Optional GameObject used to find Rigidbodies in the hierarchy.</param>
/// <param name="virtualRoot">Optional GameObject used to determine the root of the poses,
/// <param name="poseExtractor"></param>
Rigidbody rootBody,
GameObject rootGameObject,
GameObject virtualRoot,
RigidBodyPoseExtractor poseExtractor,
string sensorName=null
string sensorName
var poseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject, virtualRoot);
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName;
m_SensorName = sensorName;
var rigidBodies = poseExtractor.Bodies;
if (rigidBodies != null)
{
m_JointExtractors = new IJointExtractor[rigidBodies.Length - 1]; // skip the root
for (var i = 1; i < rigidBodies.Length; i++)
{
var jointExtractor = new RigidBodyJointExtractor(rigidBodies[i]);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
m_JointExtractors[i - 1] = jointExtractor;
}
}
else
m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
foreach(var rb in poseExtractor.GetEnabledRigidbodies())
m_JointExtractors = new IJointExtractor[0];
var jointExtractor = new RigidBodyJointExtractor(rb);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
m_JointExtractors.Add(jointExtractor);
}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);

m_Settings = settings;
var numJointExtractorObservations = 0;
var articBodies = poseExtractor.Bodies;
if (articBodies != null)
m_JointExtractors = new List<IJointExtractor>(poseExtractor.NumEnabledPoses);
foreach(var articBody in poseExtractor.GetEnabledArticulationBodies())
m_JointExtractors = new IJointExtractor[articBodies.Length - 1]; // skip the root
for (var i = 1; i < articBodies.Length; i++)
{
var jointExtractor = new ArticulationBodyJointExtractor(articBodies[i]);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
m_JointExtractors[i - 1] = jointExtractor;
}
}
else
{
m_JointExtractors = new IJointExtractor[0];
var jointExtractor = new ArticulationBodyJointExtractor(articBody);
numJointExtractorObservations += jointExtractor.NumObservations(settings);
m_JointExtractors.Add(jointExtractor);
}
var numTransformObservations = m_PoseExtractor.GetNumPoseObservations(settings);

117
com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs


using System;
using Object = UnityEngine.Object;
namespace Unity.MLAgents.Extensions.Sensors
{

{
if (m_ParentIndices == null)
{
return -1;
throw new NullReferenceException("No parent indices set");
}
return m_ParentIndices[index];

public void SetPoseEnabled(int index, bool val)
{
m_PoseEnabled[index] = val;
}
public bool IsPoseEnabled(int index)
{
return m_PoseEnabled[index];
}
/// <summary>

/// <returns></returns>
protected internal abstract Vector3 GetLinearVelocityAt(int index);
/// <summary>
/// Return the underlying object at the given index. This is only
/// used for display in the inspector.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
protected internal virtual Object GetObjectAt(int index)
{
return null;
}
/// <summary>
/// Update the internal model space transform storage based on the underlying system.

Debug.DrawLine(current.position+offset, current.position+offset+.1f*localRight, Color.blue);
}
}
/// <summary>
/// Simplified representation of the a node in the hierarchy for display.
/// </summary>
internal struct DisplayNode
{
/// <summary>
/// Underlying object in the hierarchy. Pass to EditorGUIUtility.ObjectContent() for display.
/// </summary>
public Object NodeObject;
/// <summary>
/// Whether the poses for the object are enabled.
/// </summary>
public bool Enabled;
/// <summary>
/// Depth in the hierarchy, used for adjusting the indent level.
/// </summary>
public int Depth;
/// <summary>
/// The index of the corresponding object in the PoseExtractor.
/// </summary>
public int OriginalIndex;
}
/// <summary>
/// Get a list of display nodes in depth-first order.
/// </summary>
/// <returns></returns>
internal IList<DisplayNode> GetDisplayNodes()
{
if (NumPoses == 0)
{
return Array.Empty<DisplayNode>();
}
var nodesOut = new List<DisplayNode>(NumPoses);
// List of children for each node
var tree = new Dictionary<int, List<int>>();
for (var i = 0; i < NumPoses; i++)
{
var parent = GetParentIndex(i);
if (i == -1)
{
continue;
}
if (!tree.ContainsKey(parent))
{
tree[parent] = new List<int>();
}
tree[parent].Add(i);
}
// Store (index, depth) in the stack
var stack = new Stack<(int, int)>();
stack.Push((0, 0));
while (stack.Count != 0)
{
var (current, depth) = stack.Pop();
var obj = GetObjectAt(current);
var node = new DisplayNode
{
NodeObject = obj,
Enabled = IsPoseEnabled(current),
OriginalIndex = current,
Depth = depth
};
nodesOut.Add(node);
// Add children
if (tree.ContainsKey(current))
{
// Push to the stack in reverse order
var children = tree[current];
for (var childIdx = children.Count-1; childIdx >= 0; childIdx--)
{
stack.Push((children[childIdx], depth+1));
}
}
// Safety check
// This shouldn't even happen, but in case we have a cycle in the graph
// exit instead of looping forever and eating up all the memory.
if (nodesOut.Count > NumPoses)
{
return nodesOut;
}
}
return nodesOut;
}
}
/// <summary>

80
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs


/// <param name="rootGameObject">Optional GameObject used to find Rigidbodies in the hierarchy.</param>
/// <param name="virtualRoot">Optional GameObject used to determine the root of the poses,
/// separate from the actual Rigidbodies in the hierarchy. For locomotion tasks, with ragdolls, this provides
/// a stabilized refernece frame, which can improve learning.</param>
public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null, GameObject virtualRoot = null)
/// a stabilized reference frame, which can improve learning.</param>
/// <param name="enableBodyPoses">Optional mapping of whether a body's psoe should be enabled or not.</param>
public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null,
GameObject virtualRoot = null, Dictionary<Rigidbody, bool> enableBodyPoses = null)
{
if (rootBody == null)
{

Rigidbody[] rbs;
Joint[] joints;
joints = rootBody.GetComponentsInChildren <Joint>();
joints = rootGameObject.GetComponentsInChildren<Joint>();
}
if (rbs == null || rbs.Length == 0)

}
if (rbs[0] != rootBody)
if (rbs[0] != rootBody)
{
Debug.Log("Expected root body at index 0");
return;

}
}
var joints = rootBody.GetComponentsInChildren <Joint>();
foreach (var j in joints)
{
var parent = j.connectedBody;

// By default, ignore the root
SetPoseEnabled(0, false);
if (enableBodyPoses != null)
{
foreach (var pair in enableBodyPoses)
{
var rb = pair.Key;
if (bodyToIndex.TryGetValue(rb, out var index))
{
SetPoseEnabled(index, pair.Value);
}
}
}
}
/// <inheritdoc/>

return new Pose { rotation = body.rotation, position = body.position };
}
/// <inheritdoc/>
protected internal override Object GetObjectAt(int index)
{
if (index == 0 && m_VirtualRoot != null)
{
return m_VirtualRoot;
}
return m_Bodies[index];
}
/// <summary>
/// Get a dictionary indicating which Rigidbodies' poses are enabled or disabled.
/// </summary>
/// <returns></returns>
internal Dictionary<Rigidbody, bool> GetBodyPosesEnabled()
{
var bodyPosesEnabled = new Dictionary<Rigidbody, bool>(m_Bodies.Length);
for (var i = 0; i < m_Bodies.Length; i++)
{
var rb = m_Bodies[i];
if (rb == null)
{
continue; // skip virtual root
}
bodyPosesEnabled[rb] = IsPoseEnabled(i);
}
return bodyPosesEnabled;
}
internal IEnumerable<Rigidbody> GetEnabledRigidbodies()
{
if (m_Bodies == null)
{
yield break;
}
for (var i = 0; i < m_Bodies.Length; i++)
{
var rb = m_Bodies[i];
if (rb == null)
{
// Ignore a virtual root.
continue;
}
if (IsPoseEnabled(i))
{
yield return rb;
}
}
}
}
}

68
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs


using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;

/// <summary>
/// Optional sensor name. This must be unique for each Agent.
/// </summary>
[SerializeField]
[SerializeField]
[HideInInspector]
RigidBodyPoseExtractor m_PoseExtractor;
/// <summary>
/// Creates a PhysicsBodySensor.
/// </summary>

return new PhysicsBodySensor(RootBody, gameObject, VirtualRoot, Settings, sensorName);
var _sensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{RootBody?.name}" : sensorName;
return new PhysicsBodySensor(GetPoseExtractor(), Settings, _sensorName);
}
/// <inheritdoc/>

return new[] { 0 };
}
// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject, VirtualRoot);
var poseExtractor = GetPoseExtractor();
// Start from i=1 to ignore the root
for (var i = 1; i < poseExtractor.Bodies.Length; i++)
foreach(var rb in poseExtractor.GetEnabledRigidbodies())
var body = poseExtractor.Bodies[i];
var joint = body?.GetComponent<Joint>();
numJointObservations += RigidBodyJointExtractor.NumObservations(body, joint, Settings);
var joint = rb.GetComponent<Joint>();
numJointObservations += RigidBodyJointExtractor.NumObservations(rb, joint, Settings);
}
/// <summary>
/// Get the DisplayNodes of the hierarchy.
/// </summary>
/// <returns></returns>
internal IList<PoseExtractor.DisplayNode> GetDisplayNodes()
{
return GetPoseExtractor().GetDisplayNodes();
}
/// <summary>
/// Lazy construction of the PoseExtractor.
/// </summary>
/// <returns></returns>
RigidBodyPoseExtractor GetPoseExtractor()
{
if (m_PoseExtractor == null)
{
ResetPoseExtractor();
}
return m_PoseExtractor;
}
/// <summary>
/// Reset the pose extractor, trying to keep the enabled state of the corresponding poses the same.
/// </summary>
internal void ResetPoseExtractor()
{
// Get the current enabled state of each body, so that we can reinitialize with them.
Dictionary<Rigidbody, bool> bodyPosesEnabled = null;
if (m_PoseExtractor != null)
{
bodyPosesEnabled = m_PoseExtractor.GetBodyPosesEnabled();
}
m_PoseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject, VirtualRoot, bodyPosesEnabled);
}
/// <summary>
/// Toggle the pose at the given index.
/// </summary>
/// <param name="index"></param>
/// <param name="enabled"></param>
internal void SetPoseEnabled(int index, bool enabled)
{
GetPoseExtractor().SetPoseEnabled(index, enabled);
}
}

88
com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs


using System;
using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;

public class PoseExtractorTests
{
class UselessPoseExtractor : PoseExtractor
class BasicPoseExtractor : PoseExtractor
{
protected internal override Pose GetPoseAt(int index)
{

protected internal override Vector3 GetLinearVelocityAt(int index)
protected internal override Vector3 GetLinearVelocityAt(int index)
}
class UselessPoseExtractor : BasicPoseExtractor
{
public void Init(int[] parentIndices)
{
Setup(parentIndices);

poseExtractor.UpdateModelSpacePoses();
Assert.AreEqual(0, poseExtractor.NumPoses);
// Iterating through poses and velocities should be an empty loop
foreach (var pose in poseExtractor.GetEnabledModelSpacePoses())
{
throw new UnityAgentsException("This shouldn't happen");
}
foreach (var pose in poseExtractor.GetEnabledLocalSpacePoses())
{
throw new UnityAgentsException("This shouldn't happen");
}
foreach (var vel in poseExtractor.GetEnabledModelSpaceVelocities())
{
throw new UnityAgentsException("This shouldn't happen");
}
foreach (var vel in poseExtractor.GetEnabledLocalSpaceVelocities())
{
throw new UnityAgentsException("This shouldn't happen");
}
// Getting a parent index should throw an index exception
Assert.Throws <NullReferenceException>(
() => poseExtractor.GetParentIndex(0)
);
// DisplayNodes should be empty
var displayNodes = poseExtractor.GetDisplayNodes();
Assert.AreEqual(0, displayNodes.Count);
}
[Test]

Assert.AreEqual(size, localPoseIndex);
}
class BadPoseExtractor : PoseExtractor
[Test]
public void TestChainDisplayNodes()
{
var size = 4;
var chain = new ChainPoseExtractor(size);
var displayNodes = chain.GetDisplayNodes();
Assert.AreEqual(size, displayNodes.Count);
for (var i = 0; i < size; i++)
{
var displayNode = displayNodes[i];
Assert.AreEqual(i, displayNode.OriginalIndex);
Assert.AreEqual(null, displayNode.NodeObject);
Assert.AreEqual(i, displayNode.Depth);
Assert.AreEqual(true, displayNode.Enabled);
}
}
[Test]
public void TestDisplayNodesLoop()
{
// Degenerate case with a loop
var poseExtractor = new UselessPoseExtractor();
poseExtractor.Init(new[] {-1, 2, 1});
// This just shouldn't blow up
poseExtractor.GetDisplayNodes();
// Self-loop
poseExtractor.Init(new[] {-1, 1});
// This just shouldn't blow up
poseExtractor.GetDisplayNodes();
}
class BadPoseExtractor : BasicPoseExtractor
{
public BadPoseExtractor()
{

}
Setup(parents);
}
protected internal override Pose GetPoseAt(int index)
{
return Pose.identity;
}
protected internal override Vector3 GetLinearVelocityAt(int index)
{
return Vector3.zero;
}
}
[Test]

var bad = new BadPoseExtractor();
});
}
}
public class PoseExtensionTests

67
com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs


var rootRb = go.AddComponent<Rigidbody>();
var poseExtractor = new RigidBodyPoseExtractor(rootRb);
Assert.AreEqual(1, poseExtractor.NumPoses);
// Also pass the GameObject
poseExtractor = new RigidBodyPoseExtractor(rootRb, go);
Assert.AreEqual(1, poseExtractor.NumPoses);
}
[Test]
public void TestNoBodiesFound()
{
// Check that if we can't find any bodies under the game object, we get an empty extractor
var gameObj = new GameObject();
var rootRb = gameObj.AddComponent<Rigidbody>();
var otherGameObj = new GameObject();
var poseExtractor = new RigidBodyPoseExtractor(rootRb, otherGameObj);
Assert.AreEqual(0, poseExtractor.NumPoses);
// Add an RB under the other GameObject. Constructor will find a rigid body, but not the root.
var otherRb = otherGameObj.AddComponent<Rigidbody>();
poseExtractor = new RigidBodyPoseExtractor(rootRb, otherGameObj);
Assert.AreEqual(0, poseExtractor.NumPoses);
}
[Test]

Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(0).position);
Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(0).rotation);
Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(0));
// Check DisplayNodes gives expected results
var displayNodes = poseExtractor.GetDisplayNodes();
Assert.AreEqual(2, displayNodes.Count);
Assert.AreEqual(rb1, displayNodes[0].NodeObject);
Assert.AreEqual(false, displayNodes[0].Enabled);
Assert.AreEqual(rb2, displayNodes[1].NodeObject);
Assert.AreEqual(true, displayNodes[1].Enabled);
}
[Test]

Assert.AreEqual(rb1.position, poseExtractor.GetPoseAt(1).position);
Assert.IsTrue(rb1.rotation == poseExtractor.GetPoseAt(1).rotation);
Assert.AreEqual(rb1.velocity, poseExtractor.GetLinearVelocityAt(1));
}
[Test]
public void TestBodyPosesEnabledDictionary()
{
// * rootObj
// - rb1
// * go2
// - rb2
// - joint
var rootObj = new GameObject();
var rb1 = rootObj.AddComponent<Rigidbody>();
var go2 = new GameObject();
var rb2 = go2.AddComponent<Rigidbody>();
go2.transform.SetParent(rootObj.transform);
var joint = go2.AddComponent<ConfigurableJoint>();
joint.connectedBody = rb1;
var poseExtractor = new RigidBodyPoseExtractor(rb1);
// Expect the root body disabled and the attached one enabled.
Assert.IsFalse(poseExtractor.IsPoseEnabled(0));
Assert.IsTrue(poseExtractor.IsPoseEnabled(1));
var bodyPosesEnabled = poseExtractor.GetBodyPosesEnabled();
Assert.IsFalse(bodyPosesEnabled[rb1]);
Assert.IsTrue(bodyPosesEnabled[rb2]);
// Swap the values
bodyPosesEnabled[rb1] = true;
bodyPosesEnabled[rb2] = false;
var poseExtractor2 = new RigidBodyPoseExtractor(rb1, null, null, bodyPosesEnabled);
Assert.IsTrue(poseExtractor2.IsPoseEnabled(0));
Assert.IsFalse(poseExtractor2.IsPoseEnabled(1));
}
}
}

5
com.unity.ml-agents/CHANGELOG.md


#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
## [1.3.0-preview] 2020-08-12
## [1.3.0-preview] - 2020-08-12
### Major Changes
#### com.unity.ml-agents (C#)

Previously, this would result in an infinite loop and cause the editor to hang.
(#4226)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- The algorithm used to normalize observations was introducing NaNs if the initial observations were too large
due to incorrect initialization. The initialization was fixed and is now the observation means from the
first trajectory processed. (#4299)
## [1.2.0-preview] - 2020-07-15

16
ml-agents/mlagents/trainers/policy/tf_policy.py


self.assign_ops: List[tf.Operation] = []
self.update_dict: Dict[str, tf.Tensor] = {}
self.inference_dict: Dict[str, tf.Tensor] = {}
self.first_normalization_update: bool = False
self.graph = tf.Graph()
self.sess = tf.Session(

:param vector_obs: The vector observations to add to the running estimate of the distribution.
"""
if self.use_vec_obs and self.normalize:
self.sess.run(
self.update_normalization_op, feed_dict={self.vector_in: vector_obs}
)
if self.first_normalization_update:
self.sess.run(
self.init_normalization_op, feed_dict={self.vector_in: vector_obs}
)
self.first_normalization_update = False
else:
self.sess.run(
self.update_normalization_op, feed_dict={self.vector_in: vector_obs}
)
@property
def use_vis_obs(self):

self.normalization_steps: Optional[tf.Variable] = None
self.running_mean: Optional[tf.Variable] = None
self.running_variance: Optional[tf.Variable] = None
self.init_normalization_op: Optional[tf.Operation] = None
self.update_normalization_op: Optional[tf.Operation] = None
self.value: Optional[tf.Tensor] = None
self.all_log_probs: tf.Tensor = None

self.behavior_spec.observation_shapes
)
if self.normalize:
self.first_normalization_update = True
self.init_normalization_op = normalization_tensors.init_op
self.normalization_steps = normalization_tensors.steps
self.running_mean = normalization_tensors.running_mean
self.running_variance = normalization_tensors.running_variance

2
ml-agents/mlagents/trainers/stats.py


class GaugeWriter(StatsWriter):
"""
Write all stats that we recieve to the timer gauges, so we can track them offline easily
Write all stats that we receive to the timer gauges, so we can track them offline easily
"""
@staticmethod

3
ml-agents/mlagents/trainers/tests/mock_brain.py


memory=memory,
)
steps_list.append(experience)
obs = []
for _shape in observation_shapes:
obs.append(np.ones(_shape, dtype=np.float32))
last_experience = AgentExperience(
obs=obs,
reward=reward,

113
ml-agents/mlagents/trainers/tests/test_nn_policy.py


DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 32
NUM_AGENTS = 12
EPSILON = 1e-7
def create_policy_mock(

assert run_out["action"].shape == (NUM_AGENTS, VECTOR_ACTION_SPACE)
def test_large_normalization():
behavior_spec = mb.setup_test_behavior_specs(
use_discrete=True, use_visual=False, vector_action_space=[2], vector_obs_space=1
)
# Taken from Walker seed 3713 which causes NaN without proper initialization
large_obs1 = [
1800.00036621,
1799.96972656,
1800.01245117,
1800.07214355,
1800.02758789,
1799.98303223,
1799.88647461,
1799.89575195,
1800.03479004,
1800.14025879,
1800.17675781,
1800.20581055,
1800.33740234,
1800.36450195,
1800.43457031,
1800.45544434,
1800.44604492,
1800.56713867,
1800.73901367,
]
large_obs2 = [
1799.99975586,
1799.96679688,
1799.92980957,
1799.89550781,
1799.93774414,
1799.95300293,
1799.94067383,
1799.92993164,
1799.84057617,
1799.69873047,
1799.70605469,
1799.82849121,
1799.85095215,
1799.76977539,
1799.78283691,
1799.76708984,
1799.67163086,
1799.59191895,
1799.5135498,
1799.45556641,
1799.3717041,
]
policy = TFPolicy(
0,
behavior_spec,
TrainerSettings(network_settings=NetworkSettings(normalize=True)),
"testdir",
False,
)
time_horizon = len(large_obs1)
trajectory = make_fake_trajectory(
length=time_horizon,
max_step_complete=True,
observation_shapes=[(1,)],
action_space=[2],
)
for i in range(time_horizon):
trajectory.steps[i].obs[0] = np.array([large_obs1[i]], dtype=np.float32)
trajectory_buffer = trajectory.to_agentbuffer()
policy.update_normalization(trajectory_buffer["vector_obs"])
# Check that the running mean and variance is correct
steps, mean, variance = policy.sess.run(
[policy.normalization_steps, policy.running_mean, policy.running_variance]
)
assert mean[0] == pytest.approx(np.mean(large_obs1, dtype=np.float32), abs=0.01)
assert variance[0] / steps == pytest.approx(
np.var(large_obs1, dtype=np.float32), abs=0.01
)
time_horizon = len(large_obs2)
trajectory = make_fake_trajectory(
length=time_horizon,
max_step_complete=True,
observation_shapes=[(1,)],
action_space=[2],
)
for i in range(time_horizon):
trajectory.steps[i].obs[0] = np.array([large_obs2[i]], dtype=np.float32)
trajectory_buffer = trajectory.to_agentbuffer()
policy.update_normalization(trajectory_buffer["vector_obs"])
steps, mean, variance = policy.sess.run(
[policy.normalization_steps, policy.running_mean, policy.running_variance]
)
assert mean[0] == pytest.approx(
np.mean(large_obs1 + large_obs2, dtype=np.float32), abs=0.01
)
assert variance[0] / steps == pytest.approx(
np.var(large_obs1 + large_obs2, dtype=np.float32), abs=0.01
)
time_horizon = 6
trajectory = make_fake_trajectory(
length=time_horizon,

assert steps == 6
assert mean[0] == 0.5
# Note: variance is divided by number of steps, and initialized to 1 to avoid
# divide by 0. The right answer is 0.25
assert (variance[0] - 1) / steps == 0.25
# Note: variance is initalized to the variance of the initial trajectory + EPSILON
# (to avoid divide by 0) and multiplied by the number of steps. The correct answer is 0.25
assert variance[0] / steps == pytest.approx(0.25, abs=0.01)
# Make another update, this time with all 1's
time_horizon = 10
trajectory = make_fake_trajectory(

assert steps == 16
assert mean[0] == 0.8125
assert (variance[0] - 1) / steps == pytest.approx(0.152, abs=0.01)
assert variance[0] / steps == pytest.approx(0.152, abs=0.01)
def test_min_visual_size():

26
ml-agents/mlagents/trainers/tf/models.py


class NormalizerTensors(NamedTuple):
init_op: tf.Operation
update_op: tf.Operation
steps: tf.Tensor
running_mean: tf.Tensor

:return: A NormalizerTensors tuple that holds running mean, running variance, number of steps,
and the update operation.
"""
steps = tf.get_variable(
"normalization_steps",
[],

dtype=tf.float32,
initializer=tf.ones_initializer(),
)
update_normalization = ModelUtils.create_normalizer_update(
initialize_normalization, update_normalization = ModelUtils.create_normalizer_update(
update_normalization, steps, running_mean, running_variance
initialize_normalization,
update_normalization,
steps,
running_mean,
running_variance,
)
@staticmethod

running_mean: tf.Tensor,
running_variance: tf.Tensor,
) -> tf.Operation:
) -> Tuple[tf.Operation, tf.Operation]:
"""
Creates the update operation for the normalizer.
:param vector_input: Vector observation to use for updating the running mean and variance.

update_mean = tf.assign(running_mean, new_mean)
update_variance = tf.assign(running_variance, new_variance)
update_norm_step = tf.assign(steps, total_new_steps)
return tf.group([update_mean, update_variance, update_norm_step])
# First mean and variance calculated normally
initial_mean, initial_variance = tf.nn.moments(vector_input, axes=[0])
initialize_mean = tf.assign(running_mean, initial_mean)
# Multiplied by total_new_step because it is divided by total_new_step in the normalization
initialize_variance = tf.assign(
running_variance,
(initial_variance + EPSILON) * tf.cast(total_new_steps, dtype=tf.float32),
)
return (
tf.group([initialize_mean, initialize_variance, update_norm_step]),
tf.group([update_mean, update_variance, update_norm_step]),
)
@staticmethod
def create_vector_observation_encoder(

2
ml-agents/mlagents/trainers/trainer_controller.py


) in self.param_manager.get_current_lesson_number().items():
for trainer in self.trainers.values():
trainer.stats_reporter.set_stat(
f"Environment/Lesson/{param_name}", lesson_number
f"Environment/Lesson Number/{param_name}", lesson_number
)
for trainer in self.trainers.values():

66
com.unity.ml-agents.extensions/Editor/RigidBodySensorComponentEditor.cs


using UnityEngine;
using UnityEditor;
using Unity.MLAgents.Editor;
using Unity.MLAgents.Extensions.Sensors;
namespace Unity.MLAgents.Extensions.Editor
{
[CustomEditor(typeof(RigidBodySensorComponent))]
[CanEditMultipleObjects]
internal class RigidBodySensorComponentEditor : UnityEditor.Editor
{
bool ShowHierarchy = true;
public override void OnInspectorGUI()
{
var so = serializedObject;
so.Update();
var rbSensorComp = so.targetObject as RigidBodySensorComponent;
bool requireExtractorUpdate;
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
{
// All the fields affect the sensor order or observation size,
// So can't be changed at runtime.
EditorGUI.BeginChangeCheck();
EditorGUILayout.PropertyField(so.FindProperty("RootBody"), true);
EditorGUILayout.PropertyField(so.FindProperty("VirtualRoot"), true);
// Changing the root body or virtual root changes the hierarchy, so we need to reset later.
requireExtractorUpdate = EditorGUI.EndChangeCheck();
EditorGUILayout.PropertyField(so.FindProperty("Settings"), true);
// Collapsible tree for the body hierarchy
ShowHierarchy = EditorGUILayout.Foldout(ShowHierarchy, "Hierarchy", true);
if (ShowHierarchy)
{
var treeNodes = rbSensorComp.GetDisplayNodes();
var originalIndent = EditorGUI.indentLevel;
foreach (var node in treeNodes)
{
var obj = node.NodeObject;
var objContents = EditorGUIUtility.ObjectContent(obj, obj.GetType());
EditorGUI.indentLevel = originalIndent + node.Depth;
var enabled = EditorGUILayout.Toggle(objContents, node.Enabled);
rbSensorComp.SetPoseEnabled(node.OriginalIndex, enabled);
}
EditorGUI.indentLevel = originalIndent;
}
EditorGUILayout.PropertyField(so.FindProperty("sensorName"), true);
}
EditorGUI.EndDisabledGroup();
so.ApplyModifiedProperties();
if (requireExtractorUpdate)
{
rbSensorComp.ResetPoseExtractor();
}
}
}
}

11
com.unity.ml-agents.extensions/Editor/RigidBodySensorComponentEditor.cs.meta


fileFormatVersion: 2
guid: 8c3481f5312564501b381742673d3100
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存