Arthur Juliani
5 年前
当前提交
6bee0fd1
共有 53 个文件被更改,包括 1432 次插入 和 260 次删除
-
25Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
-
33Project/Assets/ML-Agents/Examples/Walker/Scripts/WalkerAgent.cs
-
28com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodyPoseExtractor.cs
-
44com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsSensorSettings.cs
-
85com.unity.ml-agents.extensions/Runtime/Sensors/PoseExtractor.cs
-
28com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodyPoseExtractor.cs
-
11com.unity.ml-agents.extensions/Tests/Editor/Sensors/PoseExtractorTests.cs
-
1com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodyPoseExtractorTests.cs
-
8com.unity.ml-agents.extensions/Tests/Editor/Unity.ML-Agents.Extensions.EditorTests.asmdef
-
13com.unity.ml-agents/CHANGELOG.md
-
69com.unity.ml-agents/Runtime/Academy.cs
-
32com.unity.ml-agents/Tests/Editor/AcademyTests.cs
-
22com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs
-
4docs/Training-ML-Agents.md
-
4gym-unity/gym_unity/__init__.py
-
4ml-agents-envs/mlagents_envs/__init__.py
-
20ml-agents/mlagents/model_serialization.py
-
4ml-agents/mlagents/trainers/__init__.py
-
17ml-agents/mlagents/trainers/ghost/trainer.py
-
1ml-agents/mlagents/trainers/learn.py
-
31ml-agents/mlagents/trainers/policy/tf_policy.py
-
1ml-agents/mlagents/trainers/ppo/trainer.py
-
22ml-agents/mlagents/trainers/sac/trainer.py
-
12ml-agents/mlagents/trainers/tests/test_barracuda_converter.py
-
4ml-agents/mlagents/trainers/tests/test_config_conversion.py
-
7ml-agents/mlagents/trainers/tests/test_nn_policy.py
-
27ml-agents/mlagents/trainers/tests/test_ppo.py
-
43ml-agents/mlagents/trainers/tests/test_rl_trainer.py
-
31ml-agents/mlagents/trainers/tests/test_sac.py
-
8ml-agents/mlagents/trainers/tests/test_trainer_controller.py
-
64ml-agents/mlagents/trainers/tests/test_training_status.py
-
66ml-agents/mlagents/trainers/trainer/rl_trainer.py
-
22ml-agents/mlagents/trainers/trainer/trainer.py
-
15ml-agents/mlagents/trainers/trainer_controller.py
-
5ml-agents/mlagents/trainers/trainer_util.py
-
2ml-agents/mlagents/trainers/training_status.py
-
41com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs
-
11com.unity.ml-agents.extensions/Runtime/Sensors/ArticulationBodySensorComponent.cs.meta
-
93com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs
-
11com.unity.ml-agents.extensions/Runtime/Sensors/PhysicsBodySensor.cs.meta
-
52com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs
-
11com.unity.ml-agents.extensions/Runtime/Sensors/RigidBodySensorComponent.cs.meta
-
63com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs
-
11com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodyPoseExtractorTests.cs.meta
-
113com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs
-
11com.unity.ml-agents.extensions/Tests/Editor/Sensors/ArticulationBodySensorTests.cs.meta
-
113com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs
-
11com.unity.ml-agents.extensions/Tests/Editor/Sensors/RigidBodySensorTests.cs.meta
-
66com.unity.ml-agents/Runtime/SensorHelper.cs
-
11com.unity.ml-agents/Runtime/SensorHelper.cs.meta
-
98ml-agents/mlagents/trainers/policy/checkpoint_manager.py
-
92ml-agents/mlagents/trainers/tests/test_tf_policy.py
-
71ml-agents/mlagents/trainers/tests/test_policy.py
|
|||
# Version of the library that will be used to upload to pypi |
|||
__version__ = "0.18.0" |
|||
__version__ = "0.19.0.dev0" |
|||
__release_tag__ = "release_4" |
|||
__release_tag__ = None |
|
|||
# Version of the library that will be used to upload to pypi |
|||
__version__ = "0.18.0" |
|||
__version__ = "0.19.0.dev0" |
|||
__release_tag__ = "release_4" |
|||
__release_tag__ = None |
|
|||
# Version of the library that will be used to upload to pypi |
|||
__version__ = "0.18.0" |
|||
__version__ = "0.19.0.dev0" |
|||
__release_tag__ = "release_4" |
|||
__release_tag__ = None |
|
|||
#if UNITY_2020_1_OR_NEWER
|
|||
using UnityEngine; |
|||
using Unity.MLAgents.Sensors; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
/// <summary>
/// Component that creates a <see cref="PhysicsBodySensor"/> for a hierarchy of ArticulationBodies.
/// </summary>
public class ArticulationBodySensorComponent : SensorComponent
{
    /// <summary>
    /// Root ArticulationBody handed to the pose extractor and to the sensor.
    /// </summary>
    public ArticulationBody RootBody;

    /// <summary>
    /// Settings forwarded to the sensor; also used to compute the observation size.
    /// </summary>
    [SerializeField]
    public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default();

    /// <summary>
    /// Name forwarded to the sensor. May be left empty.
    /// </summary>
    public string sensorName;

    /// <summary>
    /// Builds a PhysicsBodySensor from the configured root body, settings and name.
    /// </summary>
    /// <returns>The newly constructed sensor.</returns>
    public override ISensor CreateSensor() => new PhysicsBodySensor(RootBody, Settings, sensorName);

    /// <inheritdoc/>
    public override int[] GetObservationShape()
    {
        if (RootBody != null)
        {
            // TODO static method in PhysicsBodySensor?
            // TODO only update PoseExtractor when body changes?
            var extractor = new ArticulationBodyPoseExtractor(RootBody);
            return new[] { Settings.TransformSize(extractor.NumPoses) };
        }

        // Without a root body the shape is a single dimension of size zero.
        return new[] { 0 };
    }
}
|||
|
|||
} |
|||
#endif // UNITY_2020_1_OR_NEWER
|
|
|||
fileFormatVersion: 2 |
|||
guid: e57a788acd5e049c6aa9642b450ca318 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using UnityEngine; |
|||
using Unity.MLAgents.Sensors; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
/// <summary>
/// ISensor implementation that generates observations for a group of Rigidbodies or ArticulationBodies.
/// </summary>
public class PhysicsBodySensor : ISensor
{
    int[] m_Shape;
    string m_SensorName;

    PoseExtractor m_PoseExtractor;
    PhysicsSensorSettings m_Settings;

    /// <summary>
    /// Construct a new PhysicsBodySensor that reads poses through a RigidBodyPoseExtractor.
    /// </summary>
    /// <param name="rootBody">Root Rigidbody passed to the pose extractor. May be null.</param>
    /// <param name="rootGameObject">GameObject passed to the pose extractor together with the root body.</param>
    /// <param name="settings">Settings that select the observations and determine the observation size.</param>
    /// <param name="sensorName">Sensor name; when null or empty, a name derived from the root body is used.</param>
    public PhysicsBodySensor(Rigidbody rootBody, GameObject rootGameObject, PhysicsSensorSettings settings, string sensorName=null)
    {
        m_PoseExtractor = new RigidBodyPoseExtractor(rootBody, rootGameObject);

        // Use Unity's overloaded null comparison rather than "?.": the null-conditional
        // operator skips that overload, so an unassigned or destroyed body would still
        // have its name read and could throw.
        var rootName = rootBody != null ? rootBody.name : string.Empty;
        m_SensorName = string.IsNullOrEmpty(sensorName) ? $"PhysicsBodySensor:{rootName}" : sensorName;
        m_Settings = settings;

        var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
        m_Shape = new[] { numTransformObservations };
    }

#if UNITY_2020_1_OR_NEWER
    /// <summary>
    /// Construct a new PhysicsBodySensor that reads poses through an ArticulationBodyPoseExtractor.
    /// </summary>
    /// <param name="rootBody">Root ArticulationBody passed to the pose extractor. May be null.</param>
    /// <param name="settings">Settings that select the observations and determine the observation size.</param>
    /// <param name="sensorName">Sensor name; when null or empty, a name derived from the root body is used.</param>
    public PhysicsBodySensor(ArticulationBody rootBody, PhysicsSensorSettings settings, string sensorName=null)
    {
        m_PoseExtractor = new ArticulationBodyPoseExtractor(rootBody);

        // Same reasoning as the Rigidbody constructor: avoid "?." on a UnityEngine.Object.
        var rootName = rootBody != null ? rootBody.name : string.Empty;
        m_SensorName = string.IsNullOrEmpty(sensorName) ? $"ArticulationBodySensor:{rootName}" : sensorName;
        m_Settings = settings;

        var numTransformObservations = settings.TransformSize(m_PoseExtractor.NumPoses);
        m_Shape = new[] { numTransformObservations };
    }
#endif

    /// <inheritdoc/>
    public int[] GetObservationShape()
    {
        // Computed once in the constructor.
        return m_Shape;
    }

    /// <inheritdoc/>
    public int Write(ObservationWriter writer)
    {
        // The count reported by WritePoses is returned unchanged.
        var numWritten = writer.WritePoses(m_Settings, m_PoseExtractor);
        return numWritten;
    }

    /// <inheritdoc/>
    public byte[] GetCompressedObservation()
    {
        return null;
    }

    /// <inheritdoc/>
    public void Update()
    {
        // Only refresh the pose sets that the settings say are in use.
        if (m_Settings.UseModelSpace)
        {
            m_PoseExtractor.UpdateModelSpacePoses();
        }

        if (m_Settings.UseLocalSpace)
        {
            m_PoseExtractor.UpdateLocalSpacePoses();
        }
    }

    /// <inheritdoc/>
    public void Reset() {}

    /// <inheritdoc/>
    public SensorCompressionType GetCompressionType()
    {
        return SensorCompressionType.None;
    }

    /// <inheritdoc/>
    public string GetName()
    {
        return m_SensorName;
    }
}
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 254640b3578a24bd2838c1fa39f1011a |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using UnityEngine; |
|||
using Unity.MLAgents.Sensors; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Sensors |
|||
{ |
|||
/// <summary>
/// Component that builds a PhysicsBodySensor for a hierarchy of Rigidbodies.
/// </summary>
public class RigidBodySensorComponent : SensorComponent
{
    /// <summary>
    /// Root Rigidbody handed to the pose extractor and to the sensor.
    /// </summary>
    public Rigidbody RootBody;

    /// <summary>
    /// Settings forwarded to the sensor; also used to compute the observation size.
    /// </summary>
    [SerializeField]
    public PhysicsSensorSettings Settings = PhysicsSensorSettings.Default();

    /// <summary>
    /// Optional sensor name. This must be unique for each Agent.
    /// </summary>
    public string sensorName;

    /// <summary>
    /// Builds a PhysicsBodySensor from the root body, this component's GameObject, the settings and the name.
    /// </summary>
    /// <returns>The newly constructed sensor.</returns>
    public override ISensor CreateSensor() => new PhysicsBodySensor(RootBody, gameObject, Settings, sensorName);

    /// <inheritdoc/>
    public override int[] GetObservationShape()
    {
        if (RootBody != null)
        {
            // TODO static method in PhysicsBodySensor?
            // TODO only update PoseExtractor when body changes?
            var extractor = new RigidBodyPoseExtractor(RootBody, gameObject);
            return new[] { Settings.TransformSize(extractor.NumPoses) };
        }

        // Without a root body the shape is a single dimension of size zero.
        return new[] { 0 };
    }
}
|||
|
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: df0f8be9a37d6486498061e2cbc4cd94 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
#if UNITY_2020_1_OR_NEWER
|
|||
using UnityEngine; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Extensions.Sensors; |
|||
|
|||
namespace Unity.MLAgents.Extensions.Tests.Sensors |
|||
{ |
|||
/// <summary>
/// Tests for ArticulationBodyPoseExtractor pose counts and parent indices.
/// </summary>
public class ArticulationBodyPoseExtractorTests
{
    /// <summary>
    /// Destroys every GameObject found after each test.
    /// </summary>
    [TearDown]
    public void RemoveGameObjects()
    {
        var leftovers = GameObject.FindObjectsOfType<GameObject>();
        for (var index = 0; index < leftovers.Length; index++)
        {
            UnityEngine.Object.DestroyImmediate(leftovers[index]);
        }
    }

    /// <summary>
    /// A null root yields zero poses and the update calls complete.
    /// </summary>
    [Test]
    public void TestNullRoot()
    {
        var extractor = new ArticulationBodyPoseExtractor(null);

        // These should be no-ops
        extractor.UpdateLocalSpacePoses();
        extractor.UpdateModelSpacePoses();

        Assert.AreEqual(0, extractor.NumPoses);
    }

    /// <summary>
    /// A lone ArticulationBody yields exactly one pose.
    /// </summary>
    [Test]
    public void TestSingleBody()
    {
        var body = new GameObject().AddComponent<ArticulationBody>();
        var extractor = new ArticulationBodyPoseExtractor(body);
        Assert.AreEqual(1, extractor.NumPoses);
    }

    /// <summary>
    /// A root with one child body yields two poses, with the child parented to index 0.
    /// </summary>
    [Test]
    public void TestTwoBodies()
    {
        // Hierarchy under test:
        // * parentObject
        //   - parentBody
        //   * childObject
        //     - childBody
        var parentObject = new GameObject();
        var parentBody = parentObject.AddComponent<ArticulationBody>();

        var childObject = new GameObject();
        var childBody = childObject.AddComponent<ArticulationBody>();
        childObject.transform.SetParent(parentObject.transform);

        childBody.jointType = ArticulationJointType.RevoluteJoint;

        var extractor = new ArticulationBodyPoseExtractor(parentBody);
        Assert.AreEqual(2, extractor.NumPoses);
        Assert.AreEqual(-1, extractor.GetParentIndex(0));
        Assert.AreEqual(0, extractor.GetParentIndex(1));
    }
}
|||
} |
|||
#endif
|
|
|||
fileFormatVersion: 2 |
|||
guid: 934ea08cde59d4356bc41e040d333c3d |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
#if UNITY_2020_1_OR_NEWER
|
|||
using UnityEngine; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Extensions.Sensors; |
|||
|
|||
|
|||
namespace Unity.MLAgents.Extensions.Tests.Sensors |
|||
{ |
|||
|
|||
/// <summary>
/// Tests for the sensor produced by ArticulationBodySensorComponent.
/// </summary>
public class ArticulationBodySensorTests
{
    /// <summary>
    /// With no root body assigned, the sensor produces no observations.
    /// </summary>
    [Test]
    public void TestNullRootBody()
    {
        var host = new GameObject();

        var component = host.AddComponent<ArticulationBodySensorComponent>();
        var sensor = component.CreateSensor();
        SensorTestHelper.CompareObservation(sensor, new float[0]);
    }

    /// <summary>
    /// A single body reports zero velocity, zero translation and the rotation (0, 0, 0, 1).
    /// </summary>
    [Test]
    public void TestSingleBody()
    {
        var host = new GameObject();
        var body = host.AddComponent<ArticulationBody>();
        var component = host.AddComponent<ArticulationBodySensorComponent>();
        component.RootBody = body;
        component.Settings = new PhysicsSensorSettings
        {
            UseModelSpaceLinearVelocity = true,
            UseLocalSpaceTranslations = true,
            UseLocalSpaceRotations = true
        };

        var sensor = component.CreateSensor();
        sensor.Update();
        var expectedObservation = new[]
        {
            0f, 0f, 0f, // ModelSpaceLinearVelocity
            0f, 0f, 0f, // LocalSpaceTranslations
            0f, 0f, 0f, 1f // LocalSpaceRotations
        };
        SensorTestHelper.CompareObservation(sensor, expectedObservation);
    }

    /// <summary>
    /// A three-body chain reports model-space and local-space translations, plus
    /// local-space linear velocities where the Unity version allows setting them.
    /// </summary>
    [Test]
    public void TestBodiesWithJoint()
    {
        var rootObject = new GameObject();
        var rootBody = rootObject.AddComponent<ArticulationBody>();

        var middleObject = new GameObject();
        var middleBody = middleObject.AddComponent<ArticulationBody>();
        middleBody.AddForce(new Vector3(0f, 1f, 0f));
        middleObject.transform.SetParent(rootObject.transform);
        middleObject.transform.localPosition = new Vector3(13.37f, 0f, 0f);
        middleBody.jointType = ArticulationJointType.RevoluteJoint;

        var leafObject = new GameObject();
        var leafBody = leafObject.AddComponent<ArticulationBody>();
        leafObject.transform.SetParent(middleObject.transform);
        leafObject.transform.localPosition = new Vector3(4.2f, 0f, 0f);
        leafBody.jointType = ArticulationJointType.RevoluteJoint;

#if UNITY_2020_2_OR_NEWER
        // ArticulationBody.velocity is read-only in 2020.1
        rootBody.velocity = new Vector3(1f, 0f, 0f);
        middleBody.velocity = new Vector3(0f, 1f, 0f);
        leafBody.velocity = new Vector3(0f, 0f, 1f);
#endif

        var component = rootObject.AddComponent<ArticulationBodySensorComponent>();
        component.RootBody = rootBody;
        component.Settings = new PhysicsSensorSettings
        {
            UseModelSpaceTranslations = true,
            UseLocalSpaceTranslations = true,
#if UNITY_2020_2_OR_NEWER
            UseLocalSpaceLinearVelocity = true
#endif
        };

        var sensor = component.CreateSensor();
        sensor.Update();
        var expectedObservation = new[]
        {
            // Model space
            0f, 0f, 0f, // Root pos
            13.37f, 0f, 0f, // Middle pos
            leafObject.transform.position.x, 0f, 0f, // Leaf pos

            // Local space
            0f, 0f, 0f, // Root pos
#if UNITY_2020_2_OR_NEWER
            0f, 0f, 0f, // Root vel
#endif

            13.37f, 0f, 0f, // Attached pos
#if UNITY_2020_2_OR_NEWER
            -1f, 1f, 0f, // Attached vel
#endif

            4.2f, 0f, 0f, // Leaf pos
#if UNITY_2020_2_OR_NEWER
            0f, -1f, 1f // Leaf vel
#endif
        };
        SensorTestHelper.CompareObservation(sensor, expectedObservation);
    }
}
|||
} |
|||
#endif // #if UNITY_2020_1_OR_NEWER
|
|
|||
fileFormatVersion: 2 |
|||
guid: 0ef757469348342418a68826f51d0783 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using UnityEngine; |
|||
using NUnit.Framework; |
|||
using Unity.MLAgents.Sensors; |
|||
using Unity.MLAgents.Extensions.Sensors; |
|||
|
|||
|
|||
namespace Unity.MLAgents.Extensions.Tests.Sensors |
|||
{ |
|||
|
|||
/// <summary>
/// NUnit wrapper around SensorHelper.CompareObservation.
/// </summary>
public static class SensorTestHelper
{
    /// <summary>
    /// Asserts that the sensor's observation matches the expected values, using the
    /// message produced by SensorHelper as the failure text.
    /// </summary>
    /// <param name="sensor">Sensor whose observation is generated and checked.</param>
    /// <param name="expected">Values the observation is compared against.</param>
    public static void CompareObservation(ISensor sensor, float[] expected)
    {
        string failureMessage;
        var observationsMatch = SensorHelper.CompareObservation(sensor, expected, out failureMessage);
        Assert.IsTrue(observationsMatch, failureMessage);
    }
}
|||
|
|||
/// <summary>
/// Tests for the sensor produced by RigidBodySensorComponent.
/// </summary>
public class RigidBodySensorTests
{
    /// <summary>
    /// With no root body assigned, the sensor produces no observations.
    /// </summary>
    [Test]
    public void TestNullRootBody()
    {
        var host = new GameObject();

        var component = host.AddComponent<RigidBodySensorComponent>();
        var sensor = component.CreateSensor();
        SensorTestHelper.CompareObservation(sensor, new float[0]);
    }

    /// <summary>
    /// A single Rigidbody reports zero velocity, zero translation and the rotation (0, 0, 0, 1).
    /// </summary>
    [Test]
    public void TestSingleRigidbody()
    {
        var host = new GameObject();
        var body = host.AddComponent<Rigidbody>();
        var component = host.AddComponent<RigidBodySensorComponent>();
        component.RootBody = body;
        component.Settings = new PhysicsSensorSettings
        {
            UseModelSpaceLinearVelocity = true,
            UseLocalSpaceTranslations = true,
            UseLocalSpaceRotations = true
        };

        var sensor = component.CreateSensor();
        sensor.Update();
        var expectedObservation = new[]
        {
            0f, 0f, 0f, // ModelSpaceLinearVelocity
            0f, 0f, 0f, // LocalSpaceTranslations
            0f, 0f, 0f, 1f // LocalSpaceRotations
        };
        SensorTestHelper.CompareObservation(sensor, expectedObservation);
    }

    /// <summary>
    /// A three-body chain connected by ConfigurableJoints reports model-space translations
    /// followed by local-space translations and linear velocities.
    /// </summary>
    [Test]
    public void TestBodiesWithJoint()
    {
        var rootObject = new GameObject();
        var rootBody = rootObject.AddComponent<Rigidbody>();
        rootBody.velocity = new Vector3(1f, 0f, 0f);

        var middleObject = new GameObject();
        var middleBody = middleObject.AddComponent<Rigidbody>();
        middleBody.velocity = new Vector3(0f, 1f, 0f);
        middleObject.transform.SetParent(rootObject.transform);
        middleObject.transform.localPosition = new Vector3(13.37f, 0f, 0f);
        var rootToMiddle = middleObject.AddComponent<ConfigurableJoint>();
        rootToMiddle.connectedBody = rootBody;

        var leafObject = new GameObject();
        var leafBody = leafObject.AddComponent<Rigidbody>();
        leafBody.velocity = new Vector3(0f, 0f, 1f);
        leafObject.transform.SetParent(middleObject.transform);
        leafObject.transform.localPosition = new Vector3(4.2f, 0f, 0f);
        var middleToLeaf = leafObject.AddComponent<ConfigurableJoint>();
        middleToLeaf.connectedBody = middleBody;

        var component = rootObject.AddComponent<RigidBodySensorComponent>();
        component.RootBody = rootBody;
        component.Settings = new PhysicsSensorSettings
        {
            UseModelSpaceTranslations = true,
            UseLocalSpaceTranslations = true,
            UseLocalSpaceLinearVelocity = true
        };

        var sensor = component.CreateSensor();
        sensor.Update();
        var expectedObservation = new[]
        {
            // Model space
            0f, 0f, 0f, // Root pos
            13.37f, 0f, 0f, // Middle pos
            leafObject.transform.position.x, 0f, 0f, // Leaf pos

            // Local space
            0f, 0f, 0f, // Root pos
            0f, 0f, 0f, // Root vel

            13.37f, 0f, 0f, // Attached pos
            -1f, 1f, 0f, // Attached vel

            4.2f, 0f, 0f, // Leaf pos
            0f, -1f, 1f // Leaf vel
        };
        SensorTestHelper.CompareObservation(sensor, expectedObservation);
    }
}
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: d8daf5517a7c94bfd9ac7f45f8d1bcd3 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Sensors |
|||
{ |
|||
/// <summary>
/// Helper routines for working with <see cref="ISensor"/> implementations.
/// </summary>
public static class SensorHelper
{
    /// <summary>
    /// Writes the sensor's observation into a scratch buffer and compares it, element by
    /// element, with the expected values. Intended for unit tests rather than production code.
    /// </summary>
    /// <param name="sensor">Sensor that supplies the observation shape and writes the values.</param>
    /// <param name="expected">Expected values; its length sets the scratch buffer size.</param>
    /// <param name="errorMessage">Description of the first problem found, or null on success.</param>
    /// <returns>True when every element matches exactly, false otherwise.</returns>
    public static bool CompareObservation(ISensor sensor, float[] expected, out string errorMessage)
    {
        // Sentinel written to every slot so untouched slots can be recognised later.
        const float sentinel = -1337f;
        var count = expected.Length;
        var buffer = new float[count];
        for (var index = 0; index < count; index++)
        {
            buffer[index] = sentinel;
        }

        if (count > 0 && sentinel != buffer[0])
        {
            errorMessage = "Error setting output buffer.";
            return false;
        }

        var writer = new ObservationWriter();
        writer.SetTarget(buffer, sensor.GetObservationShape(), 0);

        // Setting the target must leave the buffer untouched.
        if (count > 0 && sentinel != buffer[0])
        {
            errorMessage = "ObservationWriter.SetTarget modified a buffer it shouldn't have.";
            return false;
        }

        sensor.Write(writer);
        for (var index = 0; index < buffer.Length; index++)
        {
            var wanted = expected[index];
            var actual = buffer[index];
            if (wanted != actual)
            {
                errorMessage = $"Expected and actual differed in position {index}. Expected: {wanted} Actual: {actual} ";
                return false;
            }
        }

        errorMessage = null;
        return true;
    }
}
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 7c1189c0af42c46f7b533350d49ad3e7 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
# # Unity ML-Agents Toolkit |
|||
from typing import Dict, Any, Optional, List |
|||
import os |
|||
import attr |
|||
from mlagents.trainers.training_status import GlobalTrainingStatus, StatusType |
|||
from mlagents_envs.logging_util import get_logger |
|||
|
|||
logger = get_logger(__name__) |
|||
|
|||
|
|||
@attr.s(auto_attribs=True)
class NNCheckpoint:
    """Record describing one saved checkpoint; converted to a dict with attr.asdict before storage."""

    # Step count associated with the checkpoint.
    steps: int
    # Path of the checkpoint file; this is the file deleted when the checkpoint is removed.
    file_path: str
    # Reward associated with the checkpoint, if any.
    reward: Optional[float]
    # Creation time as a float (presumably a time.time() timestamp - confirm with callers).
    creation_time: float
|||
|
|||
|
|||
class NNCheckpointManager:
    """Keeps the per-behavior checkpoint records held in GlobalTrainingStatus."""

    @staticmethod
    def get_checkpoints(behavior_name: str) -> List[Dict[str, Any]]:
        """
        Return the checkpoint list stored for a behavior.

        When nothing (or an empty list) is stored, a new empty list is stored
        under StatusType.CHECKPOINTS and returned.

        :param behavior_name: Behavior name whose checkpoints are requested.
        :return: The list of checkpoint dicts for that behavior.
        """
        stored = GlobalTrainingStatus.get_parameter_state(
            behavior_name, StatusType.CHECKPOINTS
        )
        if stored:
            return stored
        fresh: List[Dict[str, Any]] = []
        GlobalTrainingStatus.set_parameter_state(
            behavior_name, StatusType.CHECKPOINTS, fresh
        )
        return fresh

    @staticmethod
    def remove_checkpoint(checkpoint: Dict[str, Any]) -> None:
        """
        Delete the file a checkpoint record points at.

        If the file does not exist, only a log message is emitted.

        :param checkpoint: A checkpoint dict containing a "file_path" entry.
        """
        file_path: str = checkpoint["file_path"]
        if not os.path.exists(file_path):
            logger.info(f"Checkpoint at {file_path} could not be found.")
            return
        os.remove(file_path)
        logger.info(f"Removed checkpoint model {file_path}.")

    @classmethod
    def _cleanup_extra_checkpoints(
        cls, checkpoints: List[Dict], keep_checkpoints: int
    ) -> List[Dict]:
        """
        Trim the list, oldest entry first, until at most keep_checkpoints remain.

        The list is modified in place and the file of every dropped entry is
        removed. Nothing is trimmed when keep_checkpoints is zero or negative.

        :param checkpoints: Checkpoint dicts, oldest first.
        :param keep_checkpoints: Number of checkpoints to record (user-defined).
        :return: The same list object that was passed in.
        """
        # A non-positive limit disables trimming entirely.
        while keep_checkpoints > 0 and len(checkpoints) > keep_checkpoints:
            oldest = checkpoints.pop(0)
            NNCheckpointManager.remove_checkpoint(oldest)
        return checkpoints

    @classmethod
    def add_checkpoint(
        cls, behavior_name: str, new_checkpoint: NNCheckpoint, keep_checkpoints: int
    ) -> None:
        """
        Append a checkpoint record, trim the list to the limit, and store it.

        :param behavior_name: Behavior name for the checkpoint.
        :param new_checkpoint: The new checkpoint to be recorded.
        :param keep_checkpoints: Number of checkpoints to record (user-defined).
        """
        tracked = cls.get_checkpoints(behavior_name)
        tracked.append(attr.asdict(new_checkpoint))
        cls._cleanup_extra_checkpoints(tracked, keep_checkpoints)
        GlobalTrainingStatus.set_parameter_state(
            behavior_name, StatusType.CHECKPOINTS, tracked
        )

    @classmethod
    def track_final_checkpoint(
        cls, behavior_name: str, final_checkpoint: NNCheckpoint
    ) -> None:
        """
        Store the record of the final model (or of an intermediate model if
        training is interrupted) under StatusType.FINAL_CHECKPOINT.

        :param behavior_name: Behavior name of the model.
        :param final_checkpoint: Checkpoint information for the final model.
        """
        GlobalTrainingStatus.set_parameter_state(
            behavior_name, StatusType.FINAL_CHECKPOINT, attr.asdict(final_checkpoint)
        )
|
|||
from mlagents.model_serialization import SerializationSettings |
|||
from mlagents.trainers.policy.tf_policy import TFPolicy |
|||
from mlagents_envs.base_env import DecisionSteps, BehaviorSpec |
|||
from mlagents.trainers.action_info import ActionInfo |
|||
from unittest.mock import MagicMock |
|||
from unittest import mock |
|||
from mlagents.trainers.settings import TrainerSettings |
|||
import numpy as np |
|||
|
|||
|
|||
def basic_mock_brain():
    """Return a MagicMock with the brain attributes these tests rely on preset."""
    return MagicMock(
        vector_action_space_type="continuous",
        vector_observation_space_size=1,
        vector_action_space_size=[1],
        brain_name="MockBrain",
    )
|||
|
|||
|
|||
class FakePolicy(TFPolicy):
    """TFPolicy subclass whose graph creation does nothing and which has no trainable variables."""

    def create_tf_graph(self):
        # Deliberately empty.
        return None

    def get_trainable_variables(self):
        return list()
|||
|
|||
|
|||
def test_take_action_returns_empty_with_no_agents():
    """get_action on an empty DecisionSteps returns ActionInfo.empty()."""
    policy = FakePolicy(3, basic_mock_brain(), TrainerSettings(), "output")
    # The spec contents are arbitrary; it is only used to build an empty step.
    empty_step = DecisionSteps.empty(BehaviorSpec([(1,)], "continuous", 1))
    assert policy.get_action(empty_step) == ActionInfo.empty()
|||
|
|||
|
|||
def test_take_action_returns_nones_on_missing_values():
    """When evaluate returns an empty dict, the ActionInfo carries None for action and value."""
    policy = FakePolicy(3, basic_mock_brain(), TrainerSettings(), "output")
    policy.evaluate = MagicMock(return_value={})
    policy.save_memories = MagicMock()
    # Third argument is np.array([0]); the expected ActionInfo ends with [0].
    agent_step = DecisionSteps(
        [], np.array([], dtype=np.float32), np.array([0]), None
    )
    action_info = policy.get_action(agent_step, worker_id=0)
    assert action_info == ActionInfo(None, None, {}, [0])
|||
|
|||
|
|||
def test_take_action_returns_action_info_when_available():
    """The action and value returned by evaluate are forwarded into the ActionInfo."""
    policy = FakePolicy(3, basic_mock_brain(), TrainerSettings(), "output")
    eval_output = {
        "action": np.array([1.0], dtype=np.float32),
        "memory_out": np.array([[2.5]], dtype=np.float32),
        "value": np.array([1.1], dtype=np.float32),
    }
    policy.evaluate = MagicMock(return_value=eval_output)
    agent_step = DecisionSteps(
        [], np.array([], dtype=np.float32), np.array([0]), None
    )
    action_info = policy.get_action(agent_step)
    assert action_info == ActionInfo(
        eval_output["action"], eval_output["value"], eval_output, [0]
    )
|||
|
|||
|
|||
def test_convert_version_string():
    """Version strings convert to a 3-tuple of ints; a trailing dev suffix is dropped."""
    plain = TFPolicy._convert_version_string("200.300.100")
    assert plain == (200, 300, 100)
    # A ".dev0" suffix must not change the result.
    dev = TFPolicy._convert_version_string("200.300.100.dev0")
    assert dev == (200, 300, 100)
|||
|
|||
|
|||
@mock.patch("mlagents.trainers.policy.tf_policy.export_policy_model")
@mock.patch("time.time", mock.MagicMock(return_value=12345))
def test_checkpoint_writes_tf_and_nn_checkpoints(export_policy_model_mock):
    """policy.checkpoint saves through the saver and calls export_policy_model exactly once each."""
    brain = basic_mock_brain()
    step_count = 5
    policy = FakePolicy(4, brain, TrainerSettings(), "output")
    policy.get_current_step = MagicMock(return_value=step_count)
    policy.saver = MagicMock()
    settings = SerializationSettings("output", brain.brain_name)
    checkpoint_path = f"output/{brain.brain_name}-{step_count}"

    policy.checkpoint(checkpoint_path, settings)

    # The saver is expected to receive the path with a ".ckpt" suffix.
    policy.saver.save.assert_called_once_with(policy.sess, f"{checkpoint_path}.ckpt")
    export_policy_model_mock.assert_called_once_with(
        checkpoint_path, settings, policy.graph, policy.sess
    )
|
|||
from mlagents.trainers.policy.tf_policy import TFPolicy |
|||
from mlagents_envs.base_env import DecisionSteps, BehaviorSpec |
|||
from mlagents.trainers.action_info import ActionInfo |
|||
from unittest.mock import MagicMock |
|||
from mlagents.trainers.settings import TrainerSettings |
|||
import numpy as np |
|||
|
|||
|
|||
def basic_mock_brain():
    """Return a MagicMock with the brain attributes these tests rely on preset."""
    return MagicMock(
        vector_action_space_type="continuous",
        vector_observation_space_size=1,
        vector_action_space_size=[1],
    )
|||
|
|||
|
|||
class FakePolicy(TFPolicy):
    """TFPolicy subclass whose graph creation does nothing and which has no trainable variables."""

    def create_tf_graph(self):
        # Deliberately empty.
        return None

    def get_trainable_variables(self):
        return list()
|||
|
|||
|
|||
def test_take_action_returns_empty_with_no_agents():
    """get_action on an empty DecisionSteps returns ActionInfo.empty()."""
    policy = FakePolicy(3, basic_mock_brain(), TrainerSettings(), "output")
    # The spec contents are arbitrary; it is only used to build an empty step.
    empty_step = DecisionSteps.empty(BehaviorSpec([(1,)], "continuous", 1))
    assert policy.get_action(empty_step) == ActionInfo.empty()
|||
|
|||
|
|||
def test_take_action_returns_nones_on_missing_values():
    """When evaluate returns an empty dict, the ActionInfo carries None for action and value."""
    policy = FakePolicy(3, basic_mock_brain(), TrainerSettings(), "output")
    policy.evaluate = MagicMock(return_value={})
    policy.save_memories = MagicMock()
    # Third argument is np.array([0]); the expected ActionInfo ends with [0].
    agent_step = DecisionSteps(
        [], np.array([], dtype=np.float32), np.array([0]), None
    )
    action_info = policy.get_action(agent_step, worker_id=0)
    assert action_info == ActionInfo(None, None, {}, [0])
|||
|
|||
|
|||
def test_take_action_returns_action_info_when_available():
    """The action and value returned by evaluate are forwarded into the ActionInfo."""
    policy = FakePolicy(3, basic_mock_brain(), TrainerSettings(), "output")
    eval_output = {
        "action": np.array([1.0], dtype=np.float32),
        "memory_out": np.array([[2.5]], dtype=np.float32),
        "value": np.array([1.1], dtype=np.float32),
    }
    policy.evaluate = MagicMock(return_value=eval_output)
    agent_step = DecisionSteps(
        [], np.array([], dtype=np.float32), np.array([0]), None
    )
    action_info = policy.get_action(agent_step)
    assert action_info == ActionInfo(
        eval_output["action"], eval_output["value"], eval_output, [0]
    )
|||
|
|||
|
|||
def test_convert_version_string():
    """Version strings convert to a 3-tuple of ints; a trailing dev suffix is dropped."""
    plain = TFPolicy._convert_version_string("200.300.100")
    assert plain == (200, 300, 100)
    # A ".dev0" suffix must not change the result.
    dev = TFPolicy._convert_version_string("200.300.100.dev0")
    assert dev == (200, 300, 100)
撰写
预览
正在加载...
取消
保存
Reference in new issue