
Merge master

Branch: release_4_branch
Arthur Juliani, 4 years ago
Commit: 6bee0fd1
53 files changed, with 1432 insertions and 260 deletions (per-file changed-line counts are shown in parentheses in the list below)
  1. Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs (25)
  2. Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs (33)
  3. com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs (28)
  4. com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs (44)
  5. com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs (85)
  6. com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs (28)
  7. com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs (11)
  8. com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs (1)
  9. com.unity.ml-agents.extensions/Tests/Editor/Unity.ML-Agents.Extensions.EditorTests.asmdef (8)
  10. com.unity.ml-agents/CHANGELOG.md (13)
  11. com.unity.ml-agents/Runtime/Academy.cs (69)
  12. com.unity.ml-agents/Tests/Editor/AcademyTests.cs (32)
  13. com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs (22)
  14. docs/Training-ML-Agents.md (4)
  15. gym-unity/gym_unity/__init__.py (4)
  16. ml-agents-envs/mlagents_envs/__init__.py (4)
  17. ml-agents/mlagents/model_serialization.py (20)
  18. ml-agents/mlagents/trainers/__init__.py (4)
  19. ml-agents/mlagents/trainers/ghost/trainer.py (17)
  20. ml-agents/mlagents/trainers/learn.py (1)
  21. ml-agents/mlagents/trainers/policy/tf_policy.py (31)
  22. ml-agents/mlagents/trainers/ppo/trainer.py (1)
  23. ml-agents/mlagents/trainers/sac/trainer.py (22)
  24. ml-agents/mlagents/trainers/tests/test_barracuda_converter.py (12)
  25. ml-agents/mlagents/trainers/tests/test_config_conversion.py (4)
  26. ml-agents/mlagents/trainers/tests/test_nn_policy.py (7)
  27. ml-agents/mlagents/trainers/tests/test_ppo.py (27)
  28. ml-agents/mlagents/trainers/tests/test_rl_trainer.py (43)
  29. ml-agents/mlagents/trainers/tests/test_sac.py (31)
  30. ml-agents/mlagents/trainers/tests/test_trainer_controller.py (8)
  31. ml-agents/mlagents/trainers/tests/test_training_status.py (64)
  32. ml-agents/mlagents/trainers/trainer/rl_trainer.py (66)
  33. ml-agents/mlagents/trainers/trainer/trainer.py (22)
  34. ml-agents/mlagents/trainers/trainer_controller.py (15)
  35. ml-agents/mlagents/trainers/trainer_util.py (5)
  36. ml-agents/mlagents/trainers/training_status.py (2)
  37. com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs (41)
  38. com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs.meta (11)
  39. com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs (93)
  40. com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs.meta (11)
  41. com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs (52)
  42. com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs.meta (11)
  43. com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs (63)
  44. com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs.meta (11)
  45. com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs (113)
  46. com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs.meta (11)
  47. com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs (113)
  48. com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs.meta (11)
  49. com.unity.ml-agents/Runtime/SensorHelper.cs (66)
  50. com.unity.ml-agents/Runtime/SensorHelper.cs.meta (11)
  51. ml-agents/mlagents/trainers/policy/checkpoint_manager.py (98)
  52. ml-agents/mlagents/trainers/tests/test_tf_policy.py (92)
  53. ml-agents/mlagents/trainers/tests/test_policy.py (71)

25
Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs


using System;
using Random = UnityEngine.Random;
[RequireComponent(typeof(JointDriveController))] // Required to set joint forces
public class CrawlerAgent : Agent

{
//Add body rotation delta relative to orientation cube
sensor.AddObservation(Quaternion.FromToRotation(body.forward, orientationCube.transform.forward));
//Add pos of target relative to orientation cube
sensor.AddObservation(orientationCube.transform.InverseTransformPoint(target.transform.position));

{
var movingTowardsDot = Vector3.Dot(orientationCube.transform.forward,
Vector3.ClampMagnitude(m_JdController.bodyPartsDict[body].rb.velocity, maximumWalkingSpeed));
;
if (float.IsNaN(movingTowardsDot))
{
throw new ArgumentException(
"NaN in movingTowardsDot.\n" +
$" orientationCube.transform.forward: {orientationCube.transform.forward}\n"+
$" body.velocity: {m_JdController.bodyPartsDict[body].rb.velocity}\n"+
$" maximumWalkingSpeed: {maximumWalkingSpeed}"
);
}
AddReward(0.03f * movingTowardsDot);
}

void RewardFunctionFacingTarget()
{
AddReward(0.01f * Vector3.Dot(orientationCube.transform.forward, body.forward));
var facingReward = Vector3.Dot(orientationCube.transform.forward, body.forward);
if (float.IsNaN(facingReward))
{
throw new ArgumentException(
"NaN in movingTowardsDot.\n" +
$" orientationCube.transform.forward: {orientationCube.transform.forward}\n"+
$" body.forward: {body.forward}"
);
}
AddReward(0.01f * facingReward);
}
/// <summary>

33
Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs


using System;
using MLAgentsExamples;
using UnityEngine;
using Unity.MLAgents;

using Random = UnityEngine.Random;
public class WalkerAgent : Agent
{

// a. Velocity alignment with goal direction.
var moveTowardsTargetReward = Vector3.Dot(cubeForward,
Vector3.ClampMagnitude(m_JdController.bodyPartsDict[hips].rb.velocity, maximumWalkingSpeed));
if (float.IsNaN(moveTowardsTargetReward))
{
throw new ArgumentException(
"NaN in moveTowardsTargetReward.\n" +
$" cubeForward: {cubeForward}\n"+
$" hips.velocity: {m_JdController.bodyPartsDict[hips].rb.velocity}\n"+
$" maximumWalkingSpeed: {maximumWalkingSpeed}"
);
}
if (float.IsNaN(lookAtTargetReward))
{
throw new ArgumentException(
"NaN in lookAtTargetReward.\n" +
$" cubeForward: {cubeForward}\n"+
$" head.forward: {head.forward}"
);
}
var headHeightOverFeetReward =
if (float.IsNaN(headHeightOverFeetReward))
{
throw new ArgumentException(
"NaN in headHeightOverFeetReward.\n" +
$" head.position: {head.position}\n"+
$" footL.position: {footL.position}\n"+
$" footR.position: {footR.position}"
);
}
AddReward(
+ 0.02f * moveTowardsTargetReward
+ 0.02f * lookAtTargetReward

28
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs


namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// Utility class to track a hierarchy of ArticulationBodies.
/// </summary>
public class ArticulationBodyPoseExtractor : PoseExtractor
{
ArticulationBody[] m_Bodies;

if (rootBody == null)
{
return;
}
if (!rootBody.isRoot)
{
Debug.Log("Must pass ArticulationBody.isRoot");

for (var i = 1; i < numBodies; i++)
{
var body = m_Bodies[i];
var parent = body.GetComponentInParent<ArticulationBody>();
parentIndices[i] = bodyToIndex[parent];
var currentArticBody = m_Bodies[i];
// Component.GetComponentInParent will consider the provided object as well.
// So start looking from the parent.
var currentGameObject = currentArticBody.gameObject;
var parentGameObject = currentGameObject.transform.parent;
var parentArticBody = parentGameObject.GetComponentInParent<ArticulationBody>();
parentIndices[i] = bodyToIndex[parentArticBody];
/// <inheritdoc/>
protected override Vector3 GetLinearVelocityAt(int index)
{
return m_Bodies[index].velocity;
}
/// <inheritdoc/>
protected override Pose GetPoseAt(int index)
{
var body = m_Bodies[index];

}
}
}
#endif // UNITY_2020_1_OR_NEWER

44
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs


namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// Settings that define the observations generated for physics-based sensors.
/// </summary>
[Serializable]
public struct PhysicsSensorSettings
{

public bool UseModelSpaceTranslations;
/// <summary>
/// Whether to use model space (relative to the root body) rotatoins as observations.
/// Whether to use model space (relative to the root body) rotations as observations.
/// </summary>
public bool UseModelSpaceRotations;

public bool UseLocalSpaceRotations;
/// <summary>
/// Whether to use model space (relative to the root body) linear velocities as observations.
/// </summary>
public bool UseModelSpaceLinearVelocity;
/// <summary>
/// Whether to use local space (relative to the parent body) linear velocities as observations.
/// </summary>
public bool UseLocalSpaceLinearVelocity;
/// <summary>
/// Creates a PhysicsSensorSettings with reasonable default values.
/// </summary>
/// <returns></returns>

/// </summary>
public bool UseModelSpace
{
get { return UseModelSpaceTranslations || UseModelSpaceRotations; }
get { return UseModelSpaceTranslations || UseModelSpaceRotations || UseModelSpaceLinearVelocity; }
}
/// <summary>

{
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations; }
get { return UseLocalSpaceTranslations || UseLocalSpaceRotations || UseLocalSpaceLinearVelocity; }
}

obsPerTransform += UseModelSpaceRotations ? 4 : 0;
obsPerTransform += UseLocalSpaceTranslations ? 3 : 0;
obsPerTransform += UseLocalSpaceRotations ? 4 : 0;
obsPerTransform += UseModelSpaceLinearVelocity ? 3 : 0;
obsPerTransform += UseLocalSpaceLinearVelocity ? 3 : 0;
return numTransforms * obsPerTransform;
}
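As a worked example of the size computation above: enabling model-space translations (3 floats), model-space rotations (4 floats), and local-space linear velocities (3 floats) gives obsPerTransform = 3 + 4 + 3 = 10, so a hierarchy with 6 poses produces 6 * 10 = 60 observation values.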

var offset = baseOffset;
if (settings.UseModelSpace)
{
foreach (var pose in poseExtractor.ModelSpacePoses)
var poses = poseExtractor.ModelSpacePoses;
var vels = poseExtractor.ModelSpaceVelocities;
for(var i=0; i<poseExtractor.NumPoses; i++)
var pose = poses[i];
if(settings.UseModelSpaceTranslations)
{
writer.Add(pose.position, offset);

{
writer.Add(pose.rotation, offset);
offset += 4;
}
if (settings.UseModelSpaceLinearVelocity)
{
writer.Add(vels[i], offset);
offset += 3;
}
}
}

foreach (var pose in poseExtractor.LocalSpacePoses)
var poses = poseExtractor.LocalSpacePoses;
var vels = poseExtractor.LocalSpaceVelocities;
for(var i=0; i<poseExtractor.NumPoses; i++)
var pose = poses[i];
if(settings.UseLocalSpaceTranslations)
{
writer.Add(pose.position, offset);

{
writer.Add(pose.rotation, offset);
offset += 4;
}
if (settings.UseLocalSpaceLinearVelocity)
{
writer.Add(vels[i], offset);
offset += 3;
}
}
}

85
com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs


Pose[] m_ModelSpacePoses;
Pose[] m_LocalSpacePoses;
Vector3[] m_ModelSpaceLinearVelocities;
Vector3[] m_LocalSpaceLinearVelocities;
/// <summary>
/// Read access to the model space transforms.
/// </summary>

}
/// <summary>
/// Number of transforms in the hierarchy (read-only).
/// Read access to the model space linear velocities.
/// </summary>
public IList<Vector3> ModelSpaceVelocities
{
get { return m_ModelSpaceLinearVelocities; }
}
/// <summary>
/// Read access to the local space linear velocities.
/// </summary>
public IList<Vector3> LocalSpaceVelocities
{
get { return m_LocalSpaceLinearVelocities; }
}
/// <summary>
/// Number of poses in the hierarchy (read-only).
/// </summary>
public int NumPoses
{

/// <summary>
/// Get the parent index of the body at the specified index.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
public int GetParentIndex(int index)
{
if (m_ParentIndices == null)
{
return -1;
}
return m_ParentIndices[index];
}
/// <summary>
/// Initialize with the mapping of parent indices.
/// The 0th element is assumed to be -1, indicating that it's the root.
/// </summary>

var numTransforms = parentIndices.Length;
m_ModelSpacePoses = new Pose[numTransforms];
m_LocalSpacePoses = new Pose[numTransforms];
m_ModelSpaceLinearVelocities = new Vector3[numTransforms];
m_LocalSpaceLinearVelocities = new Vector3[numTransforms];
}
/// <summary>

protected abstract Pose GetPoseAt(int index);
/// <summary>
/// Return the world space linear velocity of the i'th object.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
protected abstract Vector3 GetLinearVelocityAt(int index);
/// <summary>
/// Update the internal model space transform storage based on the underlying system.
/// </summary>
public void UpdateModelSpacePoses()

return;
}
var worldTransform = GetPoseAt(0);
var worldToModel = worldTransform.Inverse();
var rootWorldTransform = GetPoseAt(0);
var worldToModel = rootWorldTransform.Inverse();
var rootLinearVel = GetLinearVelocityAt(0);
var currentTransform = GetPoseAt(i);
m_ModelSpacePoses[i] = worldToModel.Multiply(currentTransform);
var currentWorldSpacePose = GetPoseAt(i);
var currentModelSpacePose = worldToModel.Multiply(currentWorldSpacePose);
m_ModelSpacePoses[i] = currentModelSpacePose;
var currentBodyLinearVel = GetLinearVelocityAt(i);
var relativeVelocity = currentBodyLinearVel - rootLinearVel;
m_ModelSpaceLinearVelocities[i] = worldToModel.rotation * relativeVelocity;
}
}

var invParent = parentTransform.Inverse();
var currentTransform = GetPoseAt(i);
m_LocalSpacePoses[i] = invParent.Multiply(currentTransform);
var parentLinearVel = GetLinearVelocityAt(m_ParentIndices[i]);
var currentLinearVel = GetLinearVelocityAt(i);
m_LocalSpaceLinearVelocities[i] = invParent.rotation * (currentLinearVel - parentLinearVel);
m_LocalSpaceLinearVelocities[i] = Vector3.zero;
public void DrawModelSpace(Vector3 offset)
internal void DrawModelSpace(Vector3 offset)
{
UpdateLocalSpacePoses();
UpdateModelSpacePoses();

}
}
/// <summary>
/// Extension methods for the Pose struct, in order to improve the readability of some math.
/// </summary>
public static class PoseExtensions
{
/// <summary>

public static Pose Multiply(this Pose pose, Pose rhs)
{
return rhs.GetTransformedBy(pose);
}
/// <summary>
/// Transform the vector by the pose. Conceptually this is equivalent to treating the Pose
/// as a 4x4 matrix and multiplying the augmented vector.
/// See https://en.wikipedia.org/wiki/Affine_transformation#Augmented_matrix for more details.
/// </summary>
/// <param name="pose"></param>
/// <param name="rhs"></param>
/// <returns></returns>
public static Vector3 Multiply(this Pose pose, Vector3 rhs)
{
return pose.rotation * rhs + pose.position;
}
// TODO optimize inv(A)*B?
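Written out, the Multiply(Pose, Vector3) doc comment above is just the usual affine identity: treating the pose as a rotation R and translation t,

\begin{pmatrix} R & t \\ 0 & 1 \end{pmatrix} \begin{pmatrix} v \\ 1 \end{pmatrix} = \begin{pmatrix} R v + t \\ 1 \end{pmatrix}

which matches the returned value pose.rotation * rhs + pose.position.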

28
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs


/// Initialize given a root RigidBody.
/// </summary>
/// <param name="rootBody"></param>
public RigidBodyPoseExtractor(Rigidbody rootBody)
public RigidBodyPoseExtractor(Rigidbody rootBody, GameObject rootGameObject = null)
var rbs = rootBody.GetComponentsInChildren <Rigidbody>();
Rigidbody[] rbs;
if (rootGameObject == null)
{
rbs = rootBody.GetComponentsInChildren<Rigidbody>();
}
else
{
rbs = rootGameObject.GetComponentsInChildren<Rigidbody>();
}
var bodyToIndex = new Dictionary<Rigidbody, int>(rbs.Length);
var parentIndices = new int[rbs.Length];

SetParentIndices(parentIndices);
}
/// <summary>
/// Get the pose of the i'th RigidBody.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
/// <inheritdoc/>
protected override Vector3 GetLinearVelocityAt(int index)
{
return m_Bodies[index].velocity;
}
/// <inheritdoc/>
}
}
}

11
com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs


return Pose.identity;
}
protected override Vector3 GetLinearVelocityAt(int index)
{
return Vector3.zero;
}
public void Init(int[] parentIndices)
{
SetParentIndices(parentIndices);

position = translation
};
}
protected override Vector3 GetLinearVelocityAt(int index)
{
return Vector3.zero;
}
}
[Test]

1
com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs


using System.Collections.Generic;
using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;

8
com.unity.ml-agents.extensions/Tests/Editor/Unity.ML-Agents.Extensions.EditorTests.asmdef


"name": "Unity.ML-Agents.Extensions.EditorTests",
"references": [
"Unity.ML-Agents.Extensions.Editor",
"Unity.ML-Agents.Extensions"
"Unity.ML-Agents.Extensions",
"Unity.ML-Agents"
],
"optionalUnityReferences": [
"TestAssemblies"

],
"excludePlatforms": []
"excludePlatforms": [],
"defineConstraints": [
"UNITY_INCLUDE_TESTS"
]
}

13
com.unity.ml-agents/CHANGELOG.md


and this project adheres to
[Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Major Changes
### Minor Changes
### Bug Fixes
#### com.unity.ml-agents (C#)
- Academy.EnvironmentStep() will now throw an exception if it is called
recursively (for example, by an Agent's CollectObservations method).
Previously, this would result in an infinite loop and cause the editor to hang.
(#4226)
## [1.2.0-preview] - 2020-07-15
### Major Changes

69
com.unity.ml-agents/Runtime/Academy.cs


// Flag used to keep track of the first time the Academy is reset.
bool m_HadFirstReset;
// Whether the Academy is in the middle of a step. This is used to detect an Academy
// step triggered by user code that is itself being executed as part of an Academy step.
bool m_IsStepping;
// Random seed used for inference.
int m_InferenceSeed;

/// </summary>
public void EnvironmentStep()
{
if (!m_HadFirstReset)
// Check whether we're already in the middle of a step.
// This shouldn't happen generally, but could happen if user code (e.g. CollectObservations)
// that is called by EnvironmentStep() also calls EnvironmentStep(). This would result
// in an infinite loop and/or stack overflow, so stop it before it happens.
if (m_IsStepping)
ForcedFullReset();
throw new UnityAgentsException(
"Academy.EnvironmentStep() called recursively. " +
"This might happen if you call EnvironmentStep() from custom code such as " +
"CollectObservations() or OnActionReceived()."
);
AgentPreStep?.Invoke(m_StepCount);
m_IsStepping = true;
try
{
if (!m_HadFirstReset)
{
ForcedFullReset();
}
m_StepCount += 1;
m_TotalStepCount += 1;
AgentIncrementStep?.Invoke();
AgentPreStep?.Invoke(m_StepCount);
m_StepCount += 1;
m_TotalStepCount += 1;
AgentIncrementStep?.Invoke();
using (TimerStack.Instance.Scoped("AgentSendState"))
{
AgentSendState?.Invoke();
}
using (TimerStack.Instance.Scoped("AgentSendState"))
{
AgentSendState?.Invoke();
}
using (TimerStack.Instance.Scoped("DecideAction"))
{
DecideAction?.Invoke();
}
using (TimerStack.Instance.Scoped("DecideAction"))
{
DecideAction?.Invoke();
}
// If the communicator is not on, we need to clear the SideChannel sending queue
if (!IsCommunicatorOn)
{
SideChannelManager.GetSideChannelMessage();
}
// If the communicator is not on, we need to clear the SideChannel sending queue
if (!IsCommunicatorOn)
{
SideChannelManager.GetSideChannelMessage();
using (TimerStack.Instance.Scoped("AgentAct"))
{
AgentAct?.Invoke();
}
using (TimerStack.Instance.Scoped("AgentAct"))
finally
AgentAct?.Invoke();
// Reset m_IsStepping when we're done (or if an exception occurred).
m_IsStepping = false;
}
}

32
com.unity.ml-agents/Tests/Editor/AcademyTests.cs


using NUnit.Framework;
using Unity.MLAgents.Sensors;
using UnityEngine;
#if UNITY_2019_3_OR_NEWER
using System.Reflection;

Assert.AreEqual("com.unity.ml-agents", packageInfo.name);
Assert.AreEqual(Academy.k_PackageVersion, packageInfo.version);
#endif
}
class RecursiveAgent : Agent
{
int m_collectObsCount;
public override void CollectObservations(VectorSensor sensor)
{
m_collectObsCount++;
if (m_collectObsCount == 1)
{
// NEVER DO THIS IN REAL CODE!
Academy.Instance.EnvironmentStep();
}
}
}
[Test]
public void TestRecursiveStepThrows()
{
var gameObj = new GameObject();
var agent = gameObj.AddComponent<RecursiveAgent>();
agent.LazyInitialize();
agent.RequestDecision();
Assert.Throws<UnityAgentsException>(() =>
{
Academy.Instance.EnvironmentStep();
});
// Make sure the Academy reset to a good state and is still steppable.
Academy.Instance.EnvironmentStep();
}

22
com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs


namespace Unity.MLAgents.Tests
{
public class SensorTestHelper
public static class SensorTestHelper
var numExpected = expected.Length;
const float fill = -1337f;
var output = new float[numExpected];
for (var i = 0; i < numExpected; i++)
{
output[i] = fill;
}
Assert.AreEqual(fill, output[0]);
ObservationWriter writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
// Make sure ObservationWriter didn't touch anything
Assert.AreEqual(fill, output[0]);
sensor.Write(writer);
Assert.AreEqual(expected, output);
string errorMessage;
bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage);
Assert.IsTrue(isOK, errorMessage);
}
}

4
docs/Training-ML-Agents.md


blocks. See [Profiling in Python](Profiling-Python.md) for more information
on the timers generated.
These artifacts (except the `.nn` file) are updated throughout the training
process and finalized when training completes or is interrupted.
These artifacts are updated throughout the training
process and finalized when training is completed or is interrupted.
#### Stopping and Resuming Training

4
gym-unity/gym_unity/__init__.py


# Version of the library that will be used to upload to pypi
__version__ = "0.18.0"
__version__ = "0.19.0.dev0"
__release_tag__ = "release_4"
__release_tag__ = None

4
ml-agents-envs/mlagents_envs/__init__.py


# Version of the library that will be used to upload to pypi
__version__ = "0.18.0"
__version__ = "0.19.0.dev0"
__release_tag__ = "release_4"
__release_tag__ = None

20
ml-agents/mlagents/model_serialization.py


def export_policy_model(
settings: SerializationSettings, graph: tf.Graph, sess: tf.Session
output_filepath: str,
settings: SerializationSettings,
graph: tf.Graph,
sess: tf.Session,
Exports latest saved model to .nn format for Unity embedding.
Exports a TF graph for a Policy to .nn and/or .onnx format for Unity embedding.
:param output_filepath: file path to output the model (without file suffix)
:param settings: SerializationSettings describing how to export the model
:param graph: Tensorflow Graph for the policy
:param sess: Tensorflow session for the policy
if not os.path.exists(settings.model_path):
os.makedirs(settings.model_path)
# Save frozen graph
frozen_graph_def_path = settings.model_path + "/frozen_graph_def.pb"
with gfile.GFile(frozen_graph_def_path, "wb") as f:

if settings.convert_to_barracuda:
tf2bc.convert(frozen_graph_def_path, settings.model_path + ".nn")
logger.info(f"Exported {settings.model_path}.nn file")
tf2bc.convert(frozen_graph_def_path, f"{output_filepath}.nn")
logger.info(f"Exported {output_filepath}.nn")
# Save to onnx too (if we were able to import it)
if ONNX_EXPORT_ENABLED:

onnx_output_path = settings.model_path + ".onnx"
onnx_output_path = f"{output_filepath}.onnx"
with open(onnx_output_path, "wb") as f:
f.write(onnx_graph.SerializeToString())
logger.info(f"Converting to {onnx_output_path}")

4
ml-agents/mlagents/trainers/__init__.py


# Version of the library that will be used to upload to pypi
__version__ = "0.18.0"
__version__ = "0.19.0.dev0"
__release_tag__ = "release_4"
__release_tag__ = None

17
ml-agents/mlagents/trainers/ghost/trainer.py


self.current_policy_snapshot: Dict[str, List[float]] = {}
self.snapshot_counter: int = 0
self.policies: Dict[str, TFPolicy] = {}
# wrapped_training_team and learning team need to be separate
# in the situation where new agents are created and destroyed

"""
self.trainer.end_episode()
def save_model(self, name_behavior_id: str) -> None:
"""
Forwarding call to wrapped trainers save_model
"""
parsed_behavior_id = self._name_to_parsed_behavior_id[name_behavior_id]
brain_name = parsed_behavior_id.brain_name
self.trainer.save_model(brain_name)
def export_model(self, name_behavior_id: str) -> None:
def save_model(self) -> None:
Forwarding call to wrapped trainers export_model.
Forwarding call to wrapped trainers save_model.
parsed_behavior_id = self._name_to_parsed_behavior_id[name_behavior_id]
brain_name = parsed_behavior_id.brain_name
self.trainer.export_model(brain_name)
self.trainer.save_model()
def create_policy(
self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec

1
ml-agents/mlagents/trainers/learn.py


GlobalTrainingStatus.load_state(
os.path.join(run_logs_dir, "training_status.json")
)
# Configure CSV, Tensorboard Writers and StatsReporter
# We assume reward and episode length are needed in the CSV.
csv_writer = CSVWriter(

31
ml-agents/mlagents/trainers/policy/tf_policy.py


from typing import Any, Dict, List, Optional, Tuple
import abc
import os
from mlagents.model_serialization import SerializationSettings, export_policy_model
from mlagents.tf_utils import tf
from mlagents import tf_utils
from mlagents_envs.exception import UnityException

"""
return list(self.update_dict.keys())
def save_model(self, steps):
def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> None:
Saves the model
:param steps: The number of steps the model was trained for
:return:
Checkpoints the policy on disk.
:param checkpoint_path: filepath to write the checkpoint
:param settings: SerializationSettings for exporting the model.
# Save the TF checkpoint and graph definition
last_checkpoint = os.path.join(self.model_path, f"model-{steps}.ckpt")
self.saver.save(self.sess, last_checkpoint)
if self.saver:
self.saver.save(self.sess, f"{checkpoint_path}.ckpt")
# also save the policy so we have optimized model files for each checkpoint
self.save(checkpoint_path, settings)
def save(self, output_filepath: str, settings: SerializationSettings) -> None:
"""
Saves the serialized model, given a path and SerializationSettings
This method will save the policy graph to the given filepath. The path
should be provided without an extension as multiple serialized model formats
may be generated as a result.
:param output_filepath: path (without suffix) for the model file(s)
:param settings: SerializationSettings for how to save the model.
"""
export_policy_model(output_filepath, settings, self.graph, self.sess)
def update_normalization(self, vector_obs: np.ndarray) -> None:
"""

1
ml-agents/mlagents/trainers/ppo/trainer.py


if not isinstance(policy, NNPolicy):
raise RuntimeError("Non-NNPolicy passed to PPOTrainer.add_policy()")
self.policy = policy
self.policies[parsed_behavior_id.behavior_id] = policy
self.optimizer = PPOOptimizer(self.policy, self.trainer_settings)
for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)

22
ml-agents/mlagents/trainers/sac/trainer.py


import os
import numpy as np
from mlagents.trainers.policy.checkpoint_manager import NNCheckpoint
from mlagents_envs.logging_util import get_logger
from mlagents_envs.timers import timed

self.checkpoint_replay_buffer = self.hyperparameters.save_replay_buffer
def save_model(self, name_behavior_id: str) -> None:
def _checkpoint(self) -> NNCheckpoint:
Saves the model. Overrides the default save_model since we want to save
the replay buffer as well.
Writes a checkpoint model to memory
Overrides the default to save the replay buffer.
self.policy.save_model(self.get_step)
ckpt = super()._checkpoint()
if self.checkpoint_replay_buffer:
self.save_replay_buffer()
return ckpt
def save_model(self) -> None:
"""
Saves the final training model to memory
Overrides the default to save the replay buffer.
"""
super().save_model()
if self.checkpoint_replay_buffer:
self.save_replay_buffer()

) -> None:
"""
Adds policy to trainer.
:param brain_parameters: specifications for policy construction
"""
if self.policy:
logger.warning(

if not isinstance(policy, NNPolicy):
raise RuntimeError("Non-SACPolicy passed to SACTrainer.add_policy()")
self.policy = policy
self.policies[parsed_behavior_id.behavior_id] = policy
self.optimizer = SACOptimizer(self.policy, self.trainer_settings)
for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)

12
ml-agents/mlagents/trainers/tests/test_barracuda_converter.py


from mlagents.trainers.tests.test_nn_policy import create_policy_mock
from mlagents.trainers.settings import TrainerSettings
from mlagents.tf_utils import tf
from mlagents.model_serialization import SerializationSettings, export_policy_model
from mlagents.model_serialization import SerializationSettings
def test_barracuda_converter():

use_discrete=discrete,
use_visual=visual,
)
policy.save_model(1000)
settings = SerializationSettings(policy.model_path, os.path.join(tmpdir, "test"))
export_policy_model(settings, policy.graph, policy.sess)
settings = SerializationSettings(policy.model_path, "MockBrain")
checkpoint_path = f"{tmpdir}/MockBrain-1"
policy.checkpoint(checkpoint_path, settings)
assert os.path.isfile(os.path.join(tmpdir, "test.nn"))
assert os.path.getsize(os.path.join(tmpdir, "test.nn")) > 100
assert os.path.isfile(checkpoint_path + ".nn")
assert os.path.getsize(checkpoint_path + ".nn") > 100

4
ml-agents/mlagents/trainers/tests/test_config_conversion.py


if trainer_type == TrainerType.PPO:
trainer_config = PPO_CONFIG
trainer_settings_type = PPOSettings
elif trainer_type == TrainerType.SAC:
else:
old_config = yaml.load(trainer_config)
old_config = yaml.safe_load(trainer_config)
old_config[BRAIN_NAME]["use_recurrent"] = use_recurrent
new_config = convert_behaviors(old_config)

7
ml-agents/mlagents/trainers/tests/test_nn_policy.py


import tempfile
import numpy as np
from mlagents.model_serialization import SerializationSettings
from mlagents.tf_utils import tf

policy = create_policy_mock(trainer_params, model_path=path1)
policy.initialize_or_load()
policy._set_step(2000)
policy.save_model(2000)
mock_brain_name = "MockBrain"
checkpoint_path = f"{policy.model_path}/{mock_brain_name}-2000"
serialization_settings = SerializationSettings(policy.model_path, mock_brain_name)
policy.checkpoint(checkpoint_path, serialization_settings)
assert len(os.listdir(tmp_path)) > 0

27
ml-agents/mlagents/trainers/tests/test_ppo.py


from mlagents.tf_utils import tf
import copy
import attr
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.ppo.trainer import PPOTrainer, discount_rewards
from mlagents.trainers.ppo.optimizer import PPOOptimizer

5 # 10 hacked because this function is no longer called through trainer
)
policy_mock.increment_step = mock.Mock(return_value=step_count)
trainer.add_policy("testbehavior", policy_mock)
behavior_id = BehaviorIdentifiers.from_name_behavior_id(trainer.brain_name)
trainer.add_policy(behavior_id, policy_mock)
trainer._increment_step(5, "testbehavior")
trainer._increment_step(5, trainer.brain_name)
policy_mock.increment_step.assert_called_with(5)
assert trainer.step == step_count

dummy_config, curiosity_dummy_config, use_discrete # noqa: F811
):
mock_brain = mb.setup_test_behavior_specs(
mock_behavior_spec = mb.setup_test_behavior_specs(
use_discrete,
False,
vector_action_space=DISCRETE_ACTION_SPACE

# Test curiosity reward signal
trainer_params.reward_signals = curiosity_dummy_config
mock_brain_name = "MockBrain"
behavior_id = BehaviorIdentifiers.from_name_behavior_id(mock_brain_name)
policy = trainer.create_policy("test", mock_brain)
trainer.add_policy("test", policy)
policy = trainer.create_policy(behavior_id, mock_behavior_spec)
trainer.add_policy(behavior_id, policy)
buffer = mb.simulate_rollout(BUFFER_INIT_SAMPLES, mock_brain)
buffer = mb.simulate_rollout(BUFFER_INIT_SAMPLES, mock_behavior_spec)
# Mock out reward signal eval
buffer["extrinsic_rewards"] = buffer["environment_rewards"]
buffer["extrinsic_returns"] = buffer["environment_rewards"]

vector_action_space=DISCRETE_ACTION_SPACE,
vector_obs_space=VECTOR_OBS_SPACE,
)
mock_brain_name = "MockBrain"
behavior_id = BehaviorIdentifiers.from_name_behavior_id(mock_brain_name)
policy = trainer.create_policy("test_brain", behavior_spec)
trainer.add_policy("test_brain", policy)
policy = trainer.create_policy(behavior_id, behavior_spec)
trainer.add_policy(behavior_id, policy)
trajectory_queue = AgentManagerQueue("testbrain")
trainer.subscribe_trajectory_queue(trajectory_queue)
time_horizon = 15

policy = mock.Mock(spec=NNPolicy)
policy.get_current_step.return_value = 2000
trainer.add_policy("test_policy", policy)
behavior_id = BehaviorIdentifiers.from_name_behavior_id(trainer.brain_name)
trainer.add_policy(behavior_id, policy)
assert trainer.get_policy("test_policy") == policy
# Make sure the summary steps were loaded properly

policy = mock.Mock()
with pytest.raises(RuntimeError):
trainer.add_policy("test_policy", policy)
trainer.add_policy(behavior_id, policy)
if __name__ == "__main__":

43
ml-agents/mlagents/trainers/tests/test_rl_trainer.py


from unittest import mock
import pytest
import mlagents.trainers.tests.mock_brain as mb
from mlagents.trainers.policy.checkpoint_manager import NNCheckpoint
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.tests.test_buffer import construct_fake_buffer
from mlagents.trainers.agent_processor import AgentManagerQueue

def _update_policy(self):
return self.update_policy
def add_policy(self):
pass
def add_policy(self, mock_behavior_id, mock_policy):
self.policies[mock_behavior_id] = mock_policy
def create_policy(self):
return mock.Mock()

assert len(arr) == 0
@mock.patch("mlagents.trainers.trainer.trainer.Trainer.save_model")
def test_advance(mocked_clear_update_buffer):
def test_advance(mocked_clear_update_buffer, mocked_save_model):
mock_policy = mock.Mock()
mock_policy.model_path = "mock_model_path"
trainer.add_policy("TestBrain", mock_policy)
trajectory_queue = AgentManagerQueue("testbrain")
policy_queue = AgentManagerQueue("testbrain")
trainer.subscribe_trajectory_queue(trajectory_queue)

# Check that the buffer has been cleared
assert not trainer.should_still_train
assert mocked_clear_update_buffer.call_count > 0
assert mocked_save_model.call_count == 0
@mock.patch("mlagents.trainers.trainer.trainer.Trainer.save_model")
def test_summary_checkpoint(mock_write_summary, mock_save_model):
@mock.patch("mlagents.trainers.trainer.rl_trainer.NNCheckpointManager.add_checkpoint")
def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary):
mock_policy = mock.Mock()
mock_policy.model_path = "mock_model_path"
trainer.add_policy("TestBrain", mock_policy)
trajectory_queue = AgentManagerQueue("testbrain")
policy_queue = AgentManagerQueue("testbrain")
trainer.subscribe_trajectory_queue(trajectory_queue)

]
mock_write_summary.assert_has_calls(calls, any_order=True)
checkpoint_range = range(
checkpoint_interval, num_trajectories * time_horizon, checkpoint_interval
)
mock.call(trainer.brain_name)
for step in range(
checkpoint_interval, num_trajectories * time_horizon, checkpoint_interval
mock.call(f"{mock_policy.model_path}/{trainer.brain_name}-{step}", mock.ANY)
for step in checkpoint_range
]
mock_policy.checkpoint.assert_has_calls(calls, any_order=True)
add_checkpoint_calls = [
mock.call(
trainer.brain_name,
NNCheckpoint(
step,
f"{mock_policy.model_path}/{trainer.brain_name}-{step}.nn",
None,
mock.ANY,
),
trainer.trainer_settings.keep_checkpoints,
for step in checkpoint_range
mock_save_model.assert_has_calls(calls, any_order=True)
mock_add_checkpoint.assert_has_calls(add_checkpoint_calls)

31
ml-agents/mlagents/trainers/tests/test_sac.py


import copy
from mlagents.tf_utils import tf
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.sac.trainer import SACTrainer
from mlagents.trainers.sac.optimizer import SACOptimizer

trainer_params = dummy_config
trainer_params.hyperparameters.save_replay_buffer = True
trainer = SACTrainer("test", 1, trainer_params, True, False, 0, "testdir")
policy = trainer.create_policy("test", mock_specs)
trainer.add_policy("test", policy)
behavior_id = BehaviorIdentifiers.from_name_behavior_id(trainer.brain_name)
policy = trainer.create_policy(behavior_id, mock_specs)
trainer.add_policy(behavior_id, policy)
trainer.save_model(trainer.brain_name)
trainer.save_model()
policy = trainer2.create_policy("test", mock_specs)
trainer2.add_policy("test", policy)
policy = trainer2.create_policy(behavior_id, mock_specs)
trainer2.add_policy(behavior_id, policy)
assert trainer2.update_buffer.num_experiences == buffer_len

trainer = SACTrainer("test", 0, dummy_config, True, False, 0, "0")
policy = mock.Mock(spec=NNPolicy)
policy.get_current_step.return_value = 2000
trainer.add_policy("test", policy)
assert trainer.get_policy("test") == policy
behavior_id = BehaviorIdentifiers.from_name_behavior_id(trainer.brain_name)
trainer.add_policy(behavior_id, policy)
assert trainer.get_policy(behavior_id.behavior_id) == policy
# Make sure the summary steps were loaded properly
assert trainer.get_step == 2000

with pytest.raises(RuntimeError):
trainer.add_policy("test", policy)
trainer.add_policy(behavior_id, policy)
def test_advance(dummy_config):

dummy_config.hyperparameters.reward_signal_steps_per_update = 20
dummy_config.hyperparameters.buffer_init_steps = 0
trainer = SACTrainer("test", 0, dummy_config, True, False, 0, "0")
policy = trainer.create_policy("test", specs)
trainer.add_policy("test", policy)
behavior_id = BehaviorIdentifiers.from_name_behavior_id(trainer.brain_name)
policy = trainer.create_policy(behavior_id, specs)
trainer.add_policy(behavior_id, policy)
trajectory_queue = AgentManagerQueue("testbrain")
policy_queue = AgentManagerQueue("testbrain")

# Call add_policy and check that we update the correct number of times.
# This is to emulate a load from checkpoint.
policy = trainer.create_policy("test", specs)
behavior_id = BehaviorIdentifiers.from_name_behavior_id(trainer.brain_name)
policy = trainer.create_policy(behavior_id, specs)
trainer.add_policy("test", policy)
trainer.add_policy(behavior_id, policy)
trainer.optimizer.update = mock.Mock()
trainer.optimizer.update_reward_signals = mock.Mock()
trainer.optimizer.update_reward_signals.return_value = {}

8
ml-agents/mlagents/trainers/tests/test_trainer_controller.py


tc.advance.side_effect = take_step_sideeffect
tc._export_graph = MagicMock()
tc._save_model = MagicMock()
tc._save_models = MagicMock()
return tc, trainer_mock

tf_reset_graph.assert_called_once()
env_mock.reset.assert_called_once()
assert tc.advance.call_count == 11
tc._export_graph.assert_not_called()
tc._save_model.assert_not_called()
tc._save_models.assert_not_called()
@patch.object(tf, "reset_default_graph")

tf_reset_graph.assert_called_once()
env_mock.reset.assert_called_once()
assert tc.advance.call_count == trainer_mock.get_max_steps + 1
tc._save_model.assert_called_once()
tc._save_models.assert_called_once()
@pytest.fixture

64
ml-agents/mlagents/trainers/tests/test_training_status.py


import unittest
import json
from enum import Enum
import time
)
from mlagents.trainers.policy.checkpoint_manager import (
NNCheckpointManager,
NNCheckpoint,
)

)
assert unknown_category is None
assert unknown_key is None
def test_model_management(tmpdir):
results_path = os.path.join(tmpdir, "results")
brain_name = "Mock_brain"
final_model_path = os.path.join(results_path, brain_name)
test_checkpoint_list = [
{
"steps": 1,
"file_path": os.path.join(final_model_path, f"{brain_name}-1.nn"),
"reward": 1.312,
"creation_time": time.time(),
},
{
"steps": 2,
"file_path": os.path.join(final_model_path, f"{brain_name}-2.nn"),
"reward": 1.912,
"creation_time": time.time(),
},
{
"steps": 3,
"file_path": os.path.join(final_model_path, f"{brain_name}-3.nn"),
"reward": 2.312,
"creation_time": time.time(),
},
]
GlobalTrainingStatus.set_parameter_state(
brain_name, StatusType.CHECKPOINTS, test_checkpoint_list
)
new_checkpoint_4 = NNCheckpoint(
4, os.path.join(final_model_path, f"{brain_name}-4.nn"), 2.678, time.time()
)
NNCheckpointManager.add_checkpoint(brain_name, new_checkpoint_4, 4)
assert len(NNCheckpointManager.get_checkpoints(brain_name)) == 4
new_checkpoint_5 = NNCheckpoint(
5, os.path.join(final_model_path, f"{brain_name}-5.nn"), 3.122, time.time()
)
NNCheckpointManager.add_checkpoint(brain_name, new_checkpoint_5, 4)
assert len(NNCheckpointManager.get_checkpoints(brain_name)) == 4
final_model_path = f"{final_model_path}.nn"
final_model_time = time.time()
current_step = 6
final_model = NNCheckpoint(current_step, final_model_path, 3.294, final_model_time)
NNCheckpointManager.track_final_checkpoint(brain_name, final_model)
assert len(NNCheckpointManager.get_checkpoints(brain_name)) == 4
check_checkpoints = GlobalTrainingStatus.saved_state[brain_name][
StatusType.CHECKPOINTS.value
]
assert check_checkpoints is not None
final_model = GlobalTrainingStatus.saved_state[StatusType.FINAL_CHECKPOINT.value]
assert final_model is not None
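Based on how the test above constructs checkpoint records (steps, file_path, reward, creation_time), NNCheckpoint from the new checkpoint_manager.py, whose source is not shown in this diff, presumably looks roughly like this sketch:

from typing import Optional
import attr

@attr.s(auto_attribs=True)
class NNCheckpoint:
    # Sketch only; field names inferred from the serialized keys used in the test above.
    steps: int
    file_path: str
    reward: Optional[float]
    creation_time: float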
class StatsMetaDataTest(unittest.TestCase):

66
ml-agents/mlagents/trainers/trainer/rl_trainer.py


# # Unity ML-Agents Toolkit
from typing import Dict, List
import os
from typing import Dict, List, Optional
import attr
from mlagents.model_serialization import SerializationSettings
from mlagents.trainers.policy.checkpoint_manager import (
NNCheckpoint,
NNCheckpointManager,
)
from mlagents_envs.timers import timed
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trainer import Trainer

"""
return False
def _policy_mean_reward(self) -> Optional[float]:
""" Returns the mean episode reward for the current policy. """
rewards = self.cumulative_returns_since_policy_update
if len(rewards) == 0:
return None
else:
return sum(rewards) / len(rewards)
@timed
def _checkpoint(self) -> NNCheckpoint:
"""
Checkpoints the policy associated with this trainer.
"""
n_policies = len(self.policies.keys())
if n_policies > 1:
logger.warning(
"Trainer has multiple policies, but default behavior only saves the first."
)
policy = list(self.policies.values())[0]
model_path = policy.model_path
settings = SerializationSettings(model_path, self.brain_name)
checkpoint_path = os.path.join(model_path, f"{self.brain_name}-{self.step}")
policy.checkpoint(checkpoint_path, settings)
new_checkpoint = NNCheckpoint(
int(self.step),
f"{checkpoint_path}.nn",
self._policy_mean_reward(),
time.time(),
)
NNCheckpointManager.add_checkpoint(
self.brain_name, new_checkpoint, self.trainer_settings.keep_checkpoints
)
return new_checkpoint
def save_model(self) -> None:
"""
Saves the policy associated with this trainer.
"""
n_policies = len(self.policies.keys())
if n_policies > 1:
logger.warning(
"Trainer has multiple policies, but default behavior only saves the first."
)
policy = list(self.policies.values())[0]
settings = SerializationSettings(policy.model_path, self.brain_name)
model_checkpoint = self._checkpoint()
final_checkpoint = attr.evolve(
model_checkpoint, file_path=f"{policy.model_path}.nn"
)
policy.save(policy.model_path, settings)
NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)
@abc.abstractmethod
def _update_policy(self) -> bool:
"""

self.trainer_settings.checkpoint_interval
)
if step_after_process >= self._next_save_step and self.get_step != 0:
logger.info(f"Checkpointing model for {self.brain_name}.")
self.save_model(self.brain_name)
self._checkpoint()
def advance(self) -> None:
"""

22
ml-agents/mlagents/trainers/trainer/trainer.py


# # Unity ML-Agents Toolkit
from typing import List, Deque
from typing import List, Deque, Dict
from mlagents_envs.timers import timed
from mlagents.model_serialization import export_policy_model, SerializationSettings
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.trajectory import Trajectory

self.step: int = 0
self.artifact_path = artifact_path
self.summary_freq = self.trainer_settings.summary_freq
self.policies: Dict[str, TFPolicy] = {}
@property
def stats_reporter(self):

"""
return self._reward_buffer
@timed
def save_model(self, name_behavior_id: str) -> None:
"""
Saves the model
"""
self.get_policy(name_behavior_id).save_model(self.get_step)
def export_model(self, name_behavior_id: str) -> None:
@abc.abstractmethod
def save_model(self) -> None:
Exports the model
Saves model file(s) for the policy or policies associated with this trainer.
policy = self.get_policy(name_behavior_id)
settings = SerializationSettings(policy.model_path, self.brain_name)
export_policy_model(settings, policy.graph, policy.sess)
pass
@abc.abstractmethod
def end_episode(self):

15
ml-agents/mlagents/trainers/trainer_controller.py


tf.set_random_seed(training_seed)
@timed
def _save_model(self):
def _save_models(self):
for name_behavior_id in self.brain_name_to_identifier[brain_name]:
self.trainers[brain_name].save_model(name_behavior_id)
self.trainers[brain_name].save_model()
self.logger.info("Saved Model")
def _save_model_when_interrupted(self):

self._save_model()
self._save_models()
Exports latest saved models to .nn format for Unity embedding.
Saves models for all trainers.
for name_behavior_id in self.brain_name_to_identifier[brain_name]:
self.trainers[brain_name].export_model(name_behavior_id)
self.trainers[brain_name].save_model()
@staticmethod
def _create_output_path(output_path):

raise ex
finally:
if self.train_model:
self._save_model()
self._export_graph()
self._save_models()
def end_trainer_episodes(self) -> None:
# Reward buffers reset takes place only for curriculum learning

5
ml-agents/mlagents/trainers/trainer_util.py


self.ghost_controller = GhostController()
def generate(self, brain_name: str) -> Trainer:
if brain_name not in self.trainer_config.keys():
logger.warning(
f"Behavior name {brain_name} does not match any behaviors specified in the trainer configuration file:"
f"{sorted(self.trainer_config.keys())}"
)
return initialize_trainer(
self.trainer_config[brain_name],
brain_name,

2
ml-agents/mlagents/trainers/training_status.py


class StatusType(Enum):
LESSON_NUM = "lesson_num"
STATS_METADATA = "metadata"
CHECKPOINTS = "checkpoints"
FINAL_CHECKPOINT = "final_checkpoint"
@attr.s(auto_attribs=True)

41
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs


#if UNITY_2020_1_OR_NEWER
using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
{
public class ArticulationBodySensorComponent : SensorComponent
{
public ArticulationBody RootBody;
[SerializeField]
public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default();
public string sensorName;
/// <summary>
/// Creates a PhysicsBodySensor.
/// </summary>
/// <returns></returns>
public override ISensor CreateSensor()
{
return new PhysicsBodySensor(RootBody, Settings, sensorName);
}
/// <inheritdoc/>
public override int[] GetObservationShape()
{
if (RootBody == null)
{
return new[] { 0 };
}
// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
var poseExtractor = new ArticulationBodyPoseExtractor(RootBody);
var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses);
return new[] { numTransformObservations };
}
}
}
#endif // UNITY_2020_1_OR_NEWER

11
com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs.meta


fileFormatVersion: 2
guid: e57a788acd5e049c6aa9642b450ca318
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

93
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs


using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// ISensor implementation that generates observations for a group of Rigidbodies or ArticulationBodies.
/// </summary>
public class PhysicsBodySensor : ISensor
{
int[] m_Shape;
string m_SensorName;
PoseExtractor m_PoseExtractor;
PhysicsSensorSettings m_Settings;
/// <summary>
/// Construct a new PhysicsBodySensor
/// </summary>
/// <param name="rootBody"></param>
/// <param name="settings"></param>
/// <param name="sensorName"></param>
public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null)
{
m_PoseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootBody?.name}" : sensorName;
m_Settings = settings;
var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
m_Shape = new[] { numTransformObservations };
}
#if UNITY_2020_1_OR_NEWER
public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null)
{
m_PoseExtractor = new ArticulationBodyPoseExtractor(rootBody);
m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootBody?.name}" : sensorName;
m_Settings = settings;
var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
m_Shape = new[] { numTransformObservations };
}
#endif
/// <inheritdoc/>
public int[] GetObservationShape()
{
return m_Shape;
}
/// <inheritdoc/>
public int Write(ObservationWriter writer)
{
var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
return numWritten;
}
/// <inheritdoc/>
public byte[] GetCompressedObservation()
{
return null;
}
/// <inheritdoc/>
public void Update()
{
if (m_Settings.UseModelSpace)
{
m_PoseExtractor.UpdateModelSpacePoses();
}
if (m_Settings.UseLocalSpace)
{
m_PoseExtractor.UpdateLocalSpacePoses();
}
}
/// <inheritdoc/>
public void Reset() {}
/// <inheritdoc/>
public SensorCompressionType GetCompressionType()
{
return SensorCompressionType.None;
}
/// <inheritdoc/>
public string GetName()
{
return m_SensorName;
}
}
}

11
com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs.meta


fileFormatVersion: 2
guid: 254640b3578a24bd2838c1fa39f1011a
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

52
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs


using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Extensions.Sensors
{
/// <summary>
/// Editor component that creates a PhysicsBodySensor for the Agent.
/// </summary>
public class RigidBodySensorComponent : SensorComponent
{
/// <summary>
/// The root Rigidbody of the system.
/// </summary>
public Rigidbody RootBody;
/// <summary>
/// Settings defining what types of observations will be generated.
/// </summary>
[SerializeField]
public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default();
/// <summary>
/// Optional sensor name. This must be unique for each Agent.
/// </summary>
public string sensorName;
/// <summary>
/// Creates a PhysicsBodySensor.
/// </summary>
/// <returns></returns>
public override ISensor CreateSensor()
{
return new PhysicsBodySensor(RootBody, gameObject, Settings, sensorName);
}
/// <inheritdoc/>
public override int[] GetObservationShape()
{
if (RootBody == null)
{
return new[] { 0 };
}
// TODO static method in PhysicsBodySensor?
// TODO only update PoseExtractor when body changes?
var poseExtractor = new RigidBodyPoseExtractor(RootBody, gameObject);
var numTransformObservations = Settings.TransformSize(poseExtractor.NumPoses);
return new[] { numTransformObservations };
}
}
}

11
com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs.meta


fileFormatVersion: 2
guid: df0f8be9a37d6486498061e2cbc4cd94
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

63
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs


#if UNITY_2020_1_OR_NEWER
using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;
namespace Unity.MLAgents.Extensions.Tests.Sensors
{
public class ArticulationBodyPoseExtractorTests
{
[TearDown]
public void RemoveGameObjects()
{
var objects = GameObject.FindObjectsOfType<GameObject>();
foreach (var o in objects)
{
UnityEngine.Object.DestroyImmediate(o);
}
}
[Test]
public void TestNullRoot()
{
var poseExtractor = new ArticulationBodyPoseExtractor(null);
// These should be no-ops
poseExtractor.UpdateLocalSpacePoses();
poseExtractor.UpdateModelSpacePoses();
Assert.AreEqual(0, poseExtractor.NumPoses);
}
[Test]
public void TestSingleBody()
{
var go = new GameObject();
var rootArticBody = go.AddComponent<ArticulationBody>();
var poseExtractor = new ArticulationBodyPoseExtractor(rootArticBody);
Assert.AreEqual(1, poseExtractor.NumPoses);
}
[Test]
public void TestTwoBodies()
{
// * rootObj
// - rootArticBody
// * leafGameObj
// - leafArticBody
var rootObj = new GameObject();
var rootArticBody = rootObj.AddComponent<ArticulationBody>();
var leafGameObj = new GameObject();
var leafArticBody = leafGameObj.AddComponent<ArticulationBody>();
leafGameObj.transform.SetParent(rootObj.transform);
leafArticBody.jointType = ArticulationJointType.RevoluteJoint;
var poseExtractor = new ArticulationBodyPoseExtractor(rootArticBody);
Assert.AreEqual(2, poseExtractor.NumPoses);
Assert.AreEqual(-1, poseExtractor.GetParentIndex(0));
Assert.AreEqual(0, poseExtractor.GetParentIndex(1));
}
}
}
#endif

11
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs.meta


fileFormatVersion: 2
guid: 934ea08cde59d4356bc41e040d333c3d
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

113
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs


#if UNITY_2020_1_OR_NEWER
using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;
namespace Unity.MLAgents.Extensions.Tests.Sensors
{
public class ArticulationBodySensorTests
{
[Test]
public void TestNullRootBody()
{
var gameObj = new GameObject();
var sensorComponent = gameObj.AddComponent<ArticulationBodySensorComponent>();
var sensor = sensorComponent.CreateSensor();
SensorTestHelper.CompareObservation(sensor, new float[0]);
}
[Test]
public void TestSingleBody()
{
var gameObj = new GameObject();
var articulationBody = gameObj.AddComponent<ArticulationBody>();
var sensorComponent = gameObj.AddComponent<ArticulationBodySensorComponent>();
sensorComponent.RootBody = articulationBody;
sensorComponent.Settings = new PhysicsSensorSettings
{
UseModelSpaceLinearVelocity = true,
UseLocalSpaceTranslations = true,
UseLocalSpaceRotations = true
};
var sensor = sensorComponent.CreateSensor();
sensor.Update();
var expected = new[]
{
0f, 0f, 0f, // ModelSpaceLinearVelocity
0f, 0f, 0f, // LocalSpaceTranslations
0f, 0f, 0f, 1f // LocalSpaceRotations
};
SensorTestHelper.CompareObservation(sensor, expected);
}
[Test]
public void TestBodiesWithJoint()
{
var rootObj = new GameObject();
var rootArticBody = rootObj.AddComponent<ArticulationBody>();
var middleGamObj = new GameObject();
var middleArticBody = middleGamObj.AddComponent<ArticulationBody>();
middleArticBody.AddForce(new Vector3(0f, 1f, 0f));
middleGamObj.transform.SetParent(rootObj.transform);
middleGamObj.transform.localPosition = new Vector3(13.37f, 0f, 0f);
middleArticBody.jointType = ArticulationJointType.RevoluteJoint;
var leafGameObj = new GameObject();
var leafArticBody = leafGameObj.AddComponent<ArticulationBody>();
leafGameObj.transform.SetParent(middleGamObj.transform);
leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f);
leafArticBody.jointType = ArticulationJointType.RevoluteJoint;
#if UNITY_2020_2_OR_NEWER
// ArticulationBody.velocity is read-only in 2020.1
rootArticBody.velocity = new Vector3(1f, 0f, 0f);
middleArticBody.velocity = new Vector3(0f, 1f, 0f);
leafArticBody.velocity = new Vector3(0f, 0f, 1f);
#endif
var sensorComponent = rootObj.AddComponent<ArticulationBodySensorComponent>();
sensorComponent.RootBody = rootArticBody;
sensorComponent.Settings = new PhysicsSensorSettings
{
UseModelSpaceTranslations = true,
UseLocalSpaceTranslations = true,
#if UNITY_2020_2_OR_NEWER
UseLocalSpaceLinearVelocity = true
#endif
};
var sensor = sensorComponent.CreateSensor();
sensor.Update();
var expected = new[]
{
// Model space
0f, 0f, 0f, // Root pos
13.37f, 0f, 0f, // Middle pos
leafGameObj.transform.position.x, 0f, 0f, // Leaf pos
// Local space
0f, 0f, 0f, // Root pos
#if UNITY_2020_2_OR_NEWER
0f, 0f, 0f, // Root vel
#endif
13.37f, 0f, 0f, // Attached pos
#if UNITY_2020_2_OR_NEWER
-1f, 1f, 0f, // Attached vel
#endif
4.2f, 0f, 0f, // Leaf pos
#if UNITY_2020_2_OR_NEWER
0f, -1f, 1f // Leaf vel
#endif
};
SensorTestHelper.CompareObservation(sensor, expected);
}
}
}
#endif // #if UNITY_2020_1_OR_NEWER

11
com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs.meta


fileFormatVersion: 2
guid: 0ef757469348342418a68826f51d0783
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

113
com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs


using UnityEngine;
using NUnit.Framework;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Extensions.Sensors;
namespace Unity.MLAgents.Extensions.Tests.Sensors
{
public static class SensorTestHelper
{
public static void CompareObservation(ISensor sensor, float[] expected)
{
string errorMessage;
bool isOK = SensorHelper.CompareObservation(sensor, expected, out errorMessage);
Assert.IsTrue(isOK, errorMessage);
}
}
public class RigidBodySensorTests
{
[Test]
public void TestNullRootBody()
{
var gameObj = new GameObject();
var sensorComponent = gameObj.AddComponent<RigidBodySensorComponent>();
var sensor = sensorComponent.CreateSensor();
SensorTestHelper.CompareObservation(sensor, new float[0]);
}
[Test]
public void TestSingleRigidbody()
{
var gameObj = new GameObject();
var rootRb = gameObj.AddComponent<Rigidbody>();
var sensorComponent = gameObj.AddComponent<RigidBodySensorComponent>();
sensorComponent.RootBody = rootRb;
sensorComponent.Settings = new PhysicsSensorSettings
{
UseModelSpaceLinearVelocity = true,
UseLocalSpaceTranslations = true,
UseLocalSpaceRotations = true
};
var sensor = sensorComponent.CreateSensor();
sensor.Update();
var expected = new[]
{
0f, 0f, 0f, // ModelSpaceLinearVelocity
0f, 0f, 0f, // LocalSpaceTranslations
0f, 0f, 0f, 1f // LocalSpaceRotations
};
SensorTestHelper.CompareObservation(sensor, expected);
}
[Test]
public void TestBodiesWithJoint()
{
var rootObj = new GameObject();
var rootRb = rootObj.AddComponent<Rigidbody>();
rootRb.velocity = new Vector3(1f, 0f, 0f);
var middleGamObj = new GameObject();
var middleRb = middleGamObj.AddComponent<Rigidbody>();
middleRb.velocity = new Vector3(0f, 1f, 0f);
middleGamObj.transform.SetParent(rootObj.transform);
middleGamObj.transform.localPosition = new Vector3(13.37f, 0f, 0f);
var joint = middleGamObj.AddComponent<ConfigurableJoint>();
joint.connectedBody = rootRb;
var leafGameObj = new GameObject();
var leafRb = leafGameObj.AddComponent<Rigidbody>();
leafRb.velocity = new Vector3(0f, 0f, 1f);
leafGameObj.transform.SetParent(middleGamObj.transform);
leafGameObj.transform.localPosition = new Vector3(4.2f, 0f, 0f);
var joint2 = leafGameObj.AddComponent<ConfigurableJoint>();
joint2.connectedBody = middleRb;
var sensorComponent = rootObj.AddComponent<RigidBodySensorComponent>();
sensorComponent.RootBody = rootRb;
sensorComponent.Settings = new PhysicsSensorSettings
{
UseModelSpaceTranslations = true,
UseLocalSpaceTranslations = true,
UseLocalSpaceLinearVelocity = true
};
var sensor = sensorComponent.CreateSensor();
sensor.Update();
var expected = new[]
{
// Model space
0f, 0f, 0f, // Root pos
13.37f, 0f, 0f, // Middle pos
leafGameObj.transform.position.x, 0f, 0f, // Leaf pos
// Local space
0f, 0f, 0f, // Root pos
0f, 0f, 0f, // Root vel
13.37f, 0f, 0f, // Attached pos
-1f, 1f, 0f, // Attached vel
4.2f, 0f, 0f, // Leaf pos
0f, -1f, 1f // Leaf vel
};
SensorTestHelper.CompareObservation(sensor, expected);
}
}
}

11
com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs.meta


fileFormatVersion: 2
guid: d8daf5517a7c94bfd9ac7f45f8d1bcd3
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

66
com.unity.ml-agents/Runtime/SensorHelper.cs


using UnityEngine;
namespace Unity.MLAgents.Sensors
{
/// <summary>
/// Utility methods related to <see cref="ISensor"/> implementations.
/// </summary>
public static class SensorHelper
{
/// <summary>
/// Generates the observations for the provided sensor, and returns true if they equal the
/// expected values. If they are unequal, errorMessage is also set.
/// This should not generally be used in production code. It is only intended for
/// simplifying unit tests.
/// </summary>
/// <param name="sensor"></param>
/// <param name="expected"></param>
/// <param name="errorMessage"></param>
/// <returns></returns>
public static bool CompareObservation(ISensor sensor, float[] expected, out string errorMessage)
{
var numExpected = expected.Length;
const float fill = -1337f;
var output = new float[numExpected];
for (var i = 0; i < numExpected; i++)
{
output[i] = fill;
}
if (numExpected > 0)
{
if (fill != output[0])
{
errorMessage = "Error setting output buffer.";
return false;
}
}
ObservationWriter writer = new ObservationWriter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
// Make sure ObservationWriter didn't touch anything
if (numExpected > 0)
{
if (fill != output[0])
{
errorMessage = "ObservationWriter.SetTarget modified a buffer it shouldn't have.";
return false;
}
}
sensor.Write(writer);
for (var i = 0; i < output.Length; i++)
{
if (expected[i] != output[i])
{
errorMessage = $"Expected and actual differed in position {i}. Expected: {expected[i]} Actual: {output[i]} ";
return false;
}
}
errorMessage = null;
return true;
}
}
}

11
com.unity.ml-agents/Runtime/SensorHelper.cs.meta


fileFormatVersion: 2
guid: 7c1189c0af42c46f7b533350d49ad3e7
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

98
ml-agents/mlagents/trainers/policy/checkpoint_manager.py


# # Unity ML-Agents Toolkit
from typing import Dict, Any, Optional, List
import os
import attr
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType
from mlagents_envs.logging_util import get_logger
logger = get_logger(__name__)
@attr.s(auto_attribs=True)
class NNCheckpoint:
steps: int
file_path: str
reward: Optional[float]
creation_time: float
class NNCheckpointManager:
@staticmethod
def get_checkpoints(behavior_name: str) -> List[Dict[str, Any]]:
checkpoint_list = GlobalTrainingStatus.get_parameter_state(
behavior_name, StatusType.CHECKPOINTS
)
if not checkpoint_list:
checkpoint_list = []
GlobalTrainingStatus.set_parameter_state(
behavior_name, StatusType.CHECKPOINTS, checkpoint_list
)
return checkpoint_list
@staticmethod
def remove_checkpoint(checkpoint: Dict[str, Any]) -> None:
"""
Removes a checkpoint stored in checkpoint_list.
If the checkpoint file cannot be found, no action is taken.
:param checkpoint: A checkpoint stored in checkpoint_list
"""
file_path: str = checkpoint["file_path"]
if os.path.exists(file_path):
os.remove(file_path)
logger.info(f"Removed checkpoint model {file_path}.")
else:
logger.info(f"Checkpoint at {file_path} could not be found.")
return
@classmethod
def _cleanup_extra_checkpoints(
cls, checkpoints: List[Dict], keep_checkpoints: int
) -> List[Dict]:
"""
Ensures that the number of checkpoints stored is within the number
of checkpoints the user defines. If the limit is hit, the oldest
checkpoints are removed to create room for the next checkpoint.
:param checkpoints: The list of checkpoints to manage.
:param keep_checkpoints: Number of checkpoints to record (user-defined).
"""
while len(checkpoints) > keep_checkpoints:
if keep_checkpoints <= 0 or len(checkpoints) == 0:
break
NNCheckpointManager.remove_checkpoint(checkpoints.pop(0))
return checkpoints
@classmethod
def add_checkpoint(
cls, behavior_name: str, new_checkpoint: NNCheckpoint, keep_checkpoints: int
) -> None:
"""
Make room for a new checkpoint if needed and insert the new checkpoint information.
:param behavior_name: Behavior name for the checkpoint.
:param new_checkpoint: The new checkpoint to be recorded.
:param keep_checkpoints: Number of checkpoints to record (user-defined).
"""
new_checkpoint_dict = attr.asdict(new_checkpoint)
checkpoints = cls.get_checkpoints(behavior_name)
checkpoints.append(new_checkpoint_dict)
cls._cleanup_extra_checkpoints(checkpoints, keep_checkpoints)
GlobalTrainingStatus.set_parameter_state(
behavior_name, StatusType.CHECKPOINTS, checkpoints
)
@classmethod
def track_final_checkpoint(
cls, behavior_name: str, final_checkpoint: NNCheckpoint
) -> None:
"""
Ensures the number of checkpoints stored is within the maximum number
of checkpoints defined by the user, and stores the information about the
final model (or the intermediate model if training is interrupted).
:param behavior_name: Behavior name of the model.
:param final_checkpoint: Checkpoint information for the final model.
"""
final_model_dict = attr.asdict(final_checkpoint)
GlobalTrainingStatus.set_parameter_state(
behavior_name, StatusType.FINAL_CHECKPOINT, final_model_dict
)

92
ml-agents/mlagents/trainers/tests/test_tf_policy.py


from mlagents.model_serialization import SerializationSettings
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents_envs.base_env import DecisionSteps, BehaviorSpec
from mlagents.trainers.action_info import ActionInfo
from unittest.mock import MagicMock
from unittest import mock
from mlagents.trainers.settings import TrainerSettings
import numpy as np
def basic_mock_brain():
mock_brain = MagicMock()
mock_brain.vector_action_space_type = "continuous"
mock_brain.vector_observation_space_size = 1
mock_brain.vector_action_space_size = [1]
mock_brain.brain_name = "MockBrain"
return mock_brain
class FakePolicy(TFPolicy):
def create_tf_graph(self):
pass
def get_trainable_variables(self):
return []
def test_take_action_returns_empty_with_no_agents():
test_seed = 3
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output")
# Doesn't really matter what this is
dummy_groupspec = BehaviorSpec([(1,)], "continuous", 1)
no_agent_step = DecisionSteps.empty(dummy_groupspec)
result = policy.get_action(no_agent_step)
assert result == ActionInfo.empty()
def test_take_action_returns_nones_on_missing_values():
test_seed = 3
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output")
policy.evaluate = MagicMock(return_value={})
policy.save_memories = MagicMock()
step_with_agents = DecisionSteps(
[], np.array([], dtype=np.float32), np.array([0]), None
)
result = policy.get_action(step_with_agents, worker_id=0)
assert result == ActionInfo(None, None, {}, [0])
def test_take_action_returns_action_info_when_available():
test_seed = 3
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output")
policy_eval_out = {
"action": np.array([1.0], dtype=np.float32),
"memory_out": np.array([[2.5]], dtype=np.float32),
"value": np.array([1.1], dtype=np.float32),
}
policy.evaluate = MagicMock(return_value=policy_eval_out)
step_with_agents = DecisionSteps(
[], np.array([], dtype=np.float32), np.array([0]), None
)
result = policy.get_action(step_with_agents)
expected = ActionInfo(
policy_eval_out["action"], policy_eval_out["value"], policy_eval_out, [0]
)
assert result == expected
def test_convert_version_string():
result = TFPolicy._convert_version_string("200.300.100")
assert result == (200, 300, 100)
# Test dev versions
result = TFPolicy._convert_version_string("200.300.100.dev0")
assert result == (200, 300, 100)
@mock.patch("mlagents.trainers.policy.tf_policy.export_policy_model")
@mock.patch("time.time", mock.MagicMock(return_value=12345))
def test_checkpoint_writes_tf_and_nn_checkpoints(export_policy_model_mock):
mock_brain = basic_mock_brain()
test_seed = 4 # moving up in the world
policy = FakePolicy(test_seed, mock_brain, TrainerSettings(), "output")
n_steps = 5
policy.get_current_step = MagicMock(return_value=n_steps)
policy.saver = MagicMock()
serialization_settings = SerializationSettings("output", mock_brain.brain_name)
checkpoint_path = f"output/{mock_brain.brain_name}-{n_steps}"
policy.checkpoint(checkpoint_path, serialization_settings)
policy.saver.save.assert_called_once_with(policy.sess, f"{checkpoint_path}.ckpt")
export_policy_model_mock.assert_called_once_with(
checkpoint_path, serialization_settings, policy.graph, policy.sess
)

71
ml-agents/mlagents/trainers/tests/test_policy.py


from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents_envs.base_env import DecisionSteps, BehaviorSpec
from mlagents.trainers.action_info import ActionInfo
from unittest.mock import MagicMock
from mlagents.trainers.settings import TrainerSettings
import numpy as np
def basic_mock_brain():
mock_brain = MagicMock()
mock_brain.vector_action_space_type = "continuous"
mock_brain.vector_observation_space_size = 1
mock_brain.vector_action_space_size = [1]
return mock_brain
class FakePolicy(TFPolicy):
def create_tf_graph(self):
pass
def get_trainable_variables(self):
return []
def test_take_action_returns_empty_with_no_agents():
test_seed = 3
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output")
# Doesn't really matter what this is
dummy_groupspec = BehaviorSpec([(1,)], "continuous", 1)
no_agent_step = DecisionSteps.empty(dummy_groupspec)
result = policy.get_action(no_agent_step)
assert result == ActionInfo.empty()
def test_take_action_returns_nones_on_missing_values():
test_seed = 3
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output")
policy.evaluate = MagicMock(return_value={})
policy.save_memories = MagicMock()
step_with_agents = DecisionSteps(
[], np.array([], dtype=np.float32), np.array([0]), None
)
result = policy.get_action(step_with_agents, worker_id=0)
assert result == ActionInfo(None, None, {}, [0])
def test_take_action_returns_action_info_when_available():
test_seed = 3
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output")
policy_eval_out = {
"action": np.array([1.0], dtype=np.float32),
"memory_out": np.array([[2.5]], dtype=np.float32),
"value": np.array([1.1], dtype=np.float32),
}
policy.evaluate = MagicMock(return_value=policy_eval_out)
step_with_agents = DecisionSteps(
[], np.array([], dtype=np.float32), np.array([0]), None
)
result = policy.get_action(step_with_agents)
expected = ActionInfo(
policy_eval_out["action"], policy_eval_out["value"], policy_eval_out, [0]
)
assert result == expected
def test_convert_version_string():
result = TFPolicy._convert_version_string("200.300.100")
assert result == (200, 300, 100)
# Test dev versions
result = TFPolicy._convert_version_string("200.300.100.dev0")
assert result == (200, 300, 100)