Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过 35 个字符。
 
 
 
 
 

399 行
17 KiB

using System.Linq;
using NUnit.Framework;
using UnityEngine;
using UnityEditor;
using Unity.Barracuda;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Policies;
namespace Unity.MLAgents.Tests
{
/// <summary>
/// Minimal SensorComponent used by the tests: it hands back whatever ISensor
/// the test assigned to <see cref="Sensor"/> and reports that sensor's
/// first three observation dimensions as its shape.
/// </summary>
public class Test3DSensorComponent : SensorComponent
{
    /// <summary>Sensor injected by the test; returned unchanged by CreateSensor.</summary>
    public ISensor Sensor;

    /// <inheritdoc/>
    public override ISensor CreateSensor() => Sensor;

    /// <summary>
    /// Copies the first three entries of the wrapped sensor's observation
    /// spec shape into a new array.
    /// </summary>
    public override int[] GetObservationShape()
    {
        var spec = Sensor.GetObservationSpec();
        var dims = spec.Shape;
        return new[] { dims[0], dims[1], dims[2] };
    }
}
/// <summary>
/// Fake visual sensor for model-validation tests. It reports a visual
/// observation spec of (height, width, channels), writes all zeros, uses no
/// compression, and advertises a placeholder built-in sensor type.
/// </summary>
public class Test3DSensor : ISensor, IBuiltInSensor, IDimensionPropertiesSensor
{
    int m_Width;
    int m_Height;
    int m_Channels;
    string m_Name;

    // Placeholder reported through IBuiltInSensor; it is simply cast to
    // BuiltInSensorType in GetBuiltInSensorType.
    public const int k_BuiltInSensorType = -42;

    /// <summary>Creates a sensor with the given name and geometry.</summary>
    /// <param name="name">Value returned by GetName.</param>
    /// <param name="width">Width passed to ObservationSpec.Visual.</param>
    /// <param name="height">Height passed to ObservationSpec.Visual.</param>
    /// <param name="channels">Channel count passed to ObservationSpec.Visual.</param>
    public Test3DSensor(string name, int width, int height, int channels)
    {
        m_Name = name;
        m_Width = width;
        m_Height = height;
        m_Channels = channels;
    }

    /// <summary>Visual spec built with height first, then width, then channels.</summary>
    public ObservationSpec GetObservationSpec() => ObservationSpec.Visual(m_Height, m_Width, m_Channels);

    /// <summary>
    /// Writes width * height * channels zeros into the writer, using flat
    /// indices starting at 0.
    /// </summary>
    /// <returns>The number of values written.</returns>
    public int Write(ObservationWriter writer)
    {
        var total = m_Width * m_Height * m_Channels;
        var index = 0;
        while (index < total)
        {
            writer[index] = 0.0f;
            index++;
        }
        return total;
    }

    /// <summary>Always an empty byte array (this sensor is uncompressed).</summary>
    public byte[] GetCompressedObservation() => new byte[0];

    /// <summary>No per-step state to update.</summary>
    public void Update() { }

    /// <summary>No state to reset.</summary>
    public void Reset() { }

    public SensorCompressionType GetCompressionType() => SensorCompressionType.None;

    public string GetName() => m_Name;

    public BuiltInSensorType GetBuiltInSensorType() => (BuiltInSensorType)k_BuiltInSensorType;

    /// <summary>
    /// Returns a fresh array: the first two dimensions are marked
    /// TranslationalEquivariance and the last one None.
    /// </summary>
    public DimensionProperty[] GetDimensionProperties()
    {
        var equivariant = DimensionProperty.TranslationalEquivariance;
        return new[] { equivariant, equivariant, DimensionProperty.None };
    }
}
[TestFixture]
public class ParameterLoaderTest
{
    // ONNX model with continuous/discrete action output (support hybrid action)
    const string k_continuousONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action.onnx";
    const string k_discreteONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr.onnx";
    const string k_hybridONNXPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/hybrid0vis53vec_3c_2daction.onnx";
    // NN model with single action output (deprecated, does not support hybrid action).
    // Same BrainParameters settings as the corresponding ONNX model.
    const string k_continuousNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/continuous2vis8vec2action_deprecated.nn";
    const string k_discreteNNPath = "Packages/com.unity.ml-agents/Tests/Editor/TestModels/discrete1vis0vec_2_3action_recurr_deprecated.nn";
    NNModel continuousONNXModel;
    NNModel discreteONNXModel;
    NNModel hybridONNXModel;
    NNModel continuousNNModel;
    NNModel discreteNNModel;
    // GameObject hosting the two test sensor components; created in SetUp and
    // destroyed in TearDown so it does not outlive the test that made it.
    GameObject sensorGameObject;
    Test3DSensorComponent sensor_21_20_3;
    Test3DSensorComponent sensor_20_22_3;

    // Parameters matching continuous2vis8vec2action: 8 vector obs, no stacking, 2 continuous actions.
    BrainParameters GetContinuous2vis8vec2actionBrainParameters()
    {
        var validBrainParameters = new BrainParameters();
        validBrainParameters.VectorObservationSize = 8;
        validBrainParameters.NumStackedVectorObservations = 1;
        validBrainParameters.ActionSpec = ActionSpec.MakeContinuous(2);
        return validBrainParameters;
    }

    // Parameters matching discrete1vis0vec_2_3action_recurr: no vector obs, discrete branches of size 2 and 3.
    BrainParameters GetDiscrete1vis0vec_2_3action_recurrModelBrainParameters()
    {
        var validBrainParameters = new BrainParameters();
        validBrainParameters.VectorObservationSize = 0;
        validBrainParameters.NumStackedVectorObservations = 1;
        validBrainParameters.ActionSpec = ActionSpec.MakeDiscrete(2, 3);
        return validBrainParameters;
    }

    // Parameters matching hybrid0vis53vec_3c_2daction: 53 vector obs, 3 continuous actions, one discrete branch of size 2.
    BrainParameters GetHybridBrainParameters()
    {
        var validBrainParameters = new BrainParameters();
        validBrainParameters.VectorObservationSize = 53;
        validBrainParameters.NumStackedVectorObservations = 1;
        validBrainParameters.ActionSpec = new ActionSpec(3, new[] { 2 });
        return validBrainParameters;
    }

    [SetUp]
    public void SetUp()
    {
        continuousONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousONNXPath, typeof(NNModel));
        discreteONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_discreteONNXPath, typeof(NNModel));
        hybridONNXModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_hybridONNXPath, typeof(NNModel));
        continuousNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_continuousNNPath, typeof(NNModel));
        discreteNNModel = (NNModel)AssetDatabase.LoadAssetAtPath(k_discreteNNPath, typeof(NNModel));
        sensorGameObject = new GameObject("SensorA");
        sensor_21_20_3 = sensorGameObject.AddComponent<Test3DSensorComponent>();
        sensor_21_20_3.Sensor = new Test3DSensor("SensorA", 21, 20, 3);
        sensor_20_22_3 = sensorGameObject.AddComponent<Test3DSensorComponent>();
        sensor_20_22_3.Sensor = new Test3DSensor("SensorA", 20, 22, 3);
    }

    [TearDown]
    public void TearDown()
    {
        // SetUp creates a new GameObject for every test; destroy it here so
        // tests do not leave objects behind.
        if (sensorGameObject != null)
        {
            UnityEngine.Object.DestroyImmediate(sensorGameObject);
            sensorGameObject = null;
        }
    }

    [Test]
    public void TestModelExist()
    {
        Assert.IsNotNull(continuousONNXModel);
        Assert.IsNotNull(discreteONNXModel);
        Assert.IsNotNull(hybridONNXModel);
        Assert.IsNotNull(continuousNNModel);
        Assert.IsNotNull(discreteNNModel);
    }

    [TestCase(true)]
    [TestCase(false)]
    public void TestGetInputTensorsContinuous(bool useDeprecatedNNModel)
    {
        var model = useDeprecatedNNModel ? ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel);
        var inputNames = model.GetInputNames();
        // Model should contain 3 inputs : vector, visual 1 and visual 2
        Assert.AreEqual(3, inputNames.Count());
        Assert.Contains(TensorNames.VectorObservationPlaceholder, inputNames);
        Assert.Contains(TensorNames.VisualObservationPlaceholderPrefix + "0", inputNames);
        Assert.Contains(TensorNames.VisualObservationPlaceholderPrefix + "1", inputNames);
        Assert.AreEqual(2, model.GetNumVisualInputs());
        // Test if the model is null (the extension methods are expected to tolerate it)
        model = null;
        Assert.AreEqual(0, model.GetInputTensors().Count);
        Assert.AreEqual(0, model.GetNumVisualInputs());
    }

    [TestCase(true)]
    [TestCase(false)]
    public void TestGetInputTensorsDiscrete(bool useDeprecatedNNModel)
    {
        var model = useDeprecatedNNModel ? ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel);
        var inputNames = model.GetInputNames();
        // Model should contain 2 inputs : recurrent and visual 1
        Assert.Contains(TensorNames.VisualObservationPlaceholderPrefix + "0", inputNames);
        // TODO :There are some memory tensors as well
    }

    [Test]
    public void TestGetInputTensorsHybrid()
    {
        var model = ModelLoader.Load(hybridONNXModel);
        var inputNames = model.GetInputNames();
        Assert.Contains(TensorNames.VectorObservationPlaceholder, inputNames);
    }

    [TestCase(true)]
    [TestCase(false)]
    public void TestGetOutputTensorsContinuous(bool useDeprecatedNNModel)
    {
        var model = useDeprecatedNNModel ? ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel);
        var outputNames = model.GetOutputNames();
        // Deprecated .nn models expose a single legacy action output name.
        var actionOutputName = useDeprecatedNNModel ? TensorNames.ActionOutputDeprecated : TensorNames.ContinuousActionOutput;
        Assert.Contains(actionOutputName, outputNames);
        Assert.AreEqual(1, outputNames.Count());
        model = null;
        Assert.AreEqual(0, model.GetOutputNames().Count());
    }

    [TestCase(true)]
    [TestCase(false)]
    public void TestGetOutputTensorsDiscrete(bool useDeprecatedNNModel)
    {
        var model = useDeprecatedNNModel ? ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel);
        var outputNames = model.GetOutputNames();
        var actionOutputName = useDeprecatedNNModel ? TensorNames.ActionOutputDeprecated : TensorNames.DiscreteActionOutput;
        Assert.Contains(actionOutputName, outputNames);
        // TODO : There are some memory tensors as well
    }

    [Test]
    public void TestGetOutputTensorsHybrid()
    {
        var model = ModelLoader.Load(hybridONNXModel);
        var outputNames = model.GetOutputNames();
        Assert.AreEqual(2, outputNames.Count());
        Assert.Contains(TensorNames.ContinuousActionOutput, outputNames);
        Assert.Contains(TensorNames.DiscreteActionOutput, outputNames);
        model = null;
        Assert.AreEqual(0, model.GetOutputNames().Count());
    }

    [TestCase(true)]
    [TestCase(false)]
    public void TestCheckModelValidContinuous(bool useDeprecatedNNModel)
    {
        var model = useDeprecatedNNModel ? ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel);
        var validBrainParameters = GetContinuous2vis8vec2actionBrainParameters();
        var errors = BarracudaModelParamLoader.CheckModel(
            model, validBrainParameters,
            new ISensor[] { new VectorSensor(8), sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]
        );
        Assert.AreEqual(0, errors.Count()); // There should not be any errors
    }

    [TestCase(true)]
    [TestCase(false)]
    public void TestCheckModelValidDiscrete(bool useDeprecatedNNModel)
    {
        var model = useDeprecatedNNModel ? ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel);
        var validBrainParameters = GetDiscrete1vis0vec_2_3action_recurrModelBrainParameters();
        var errors = BarracudaModelParamLoader.CheckModel(
            model, validBrainParameters,
            new ISensor[] { sensor_21_20_3.CreateSensor() }, new ActuatorComponent[0]
        );
        Assert.AreEqual(0, errors.Count()); // There should not be any errors
    }

    [Test]
    public void TestCheckModelValidHybrid()
    {
        var model = ModelLoader.Load(hybridONNXModel);
        var validBrainParameters = GetHybridBrainParameters();
        var errors = BarracudaModelParamLoader.CheckModel(
            model, validBrainParameters,
            new ISensor[] { new VectorSensor(validBrainParameters.VectorObservationSize) }, new ActuatorComponent[0]
        );
        Assert.AreEqual(0, errors.Count()); // There should not be any errors
    }

    [TestCase(true)]
    [TestCase(false)]
    public void TestCheckModelThrowsVectorObservationContinuous(bool useDeprecatedNNModel)
    {
        var model = useDeprecatedNNModel ? ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel);
        var brainParameters = GetContinuous2vis8vec2actionBrainParameters();
        brainParameters.VectorObservationSize = 9; // Invalid observation
        var errors = BarracudaModelParamLoader.CheckModel(
            model, brainParameters,
            new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]
        );
        Assert.Greater(errors.Count(), 0);
        brainParameters = GetContinuous2vis8vec2actionBrainParameters();
        brainParameters.NumStackedVectorObservations = 2;// Invalid stacking
        errors = BarracudaModelParamLoader.CheckModel(
            model, brainParameters,
            new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]
        );
        Assert.Greater(errors.Count(), 0);
    }

    [TestCase(true)]
    [TestCase(false)]
    public void TestCheckModelThrowsVectorObservationDiscrete(bool useDeprecatedNNModel)
    {
        var model = useDeprecatedNNModel ? ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel);
        var brainParameters = GetDiscrete1vis0vec_2_3action_recurrModelBrainParameters();
        brainParameters.VectorObservationSize = 1; // Invalid observation
        var errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor() }, new ActuatorComponent[0]);
        Assert.Greater(errors.Count(), 0);
    }

    [Test]
    public void TestCheckModelThrowsVectorObservationHybrid()
    {
        var model = ModelLoader.Load(hybridONNXModel);
        var brainParameters = GetHybridBrainParameters();
        brainParameters.VectorObservationSize = 9; // Invalid observation
        var errors = BarracudaModelParamLoader.CheckModel(
            model, brainParameters,
            new ISensor[] { }, new ActuatorComponent[0]
        );
        Assert.Greater(errors.Count(), 0);
        // Start again from the hybrid model's own parameters so that only the
        // stacking is wrong (previously this reused the continuous-model parameters).
        brainParameters = GetHybridBrainParameters();
        brainParameters.NumStackedVectorObservations = 2;// Invalid stacking
        errors = BarracudaModelParamLoader.CheckModel(
            model, brainParameters,
            new ISensor[] { }, new ActuatorComponent[0]
        );
        Assert.Greater(errors.Count(), 0);
    }

    [TestCase(true)]
    [TestCase(false)]
    public void TestCheckModelThrowsActionContinuous(bool useDeprecatedNNModel)
    {
        var model = useDeprecatedNNModel ? ModelLoader.Load(continuousNNModel) : ModelLoader.Load(continuousONNXModel);
        var brainParameters = GetContinuous2vis8vec2actionBrainParameters();
        brainParameters.ActionSpec = ActionSpec.MakeContinuous(3); // Invalid action
        var errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]);
        Assert.Greater(errors.Count(), 0);
        brainParameters = GetContinuous2vis8vec2actionBrainParameters();
        brainParameters.ActionSpec = ActionSpec.MakeDiscrete(3); // Invalid SpaceType
        errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]);
        Assert.Greater(errors.Count(), 0);
    }

    [TestCase(true)]
    [TestCase(false)]
    public void TestCheckModelThrowsActionDiscrete(bool useDeprecatedNNModel)
    {
        var model = useDeprecatedNNModel ? ModelLoader.Load(discreteNNModel) : ModelLoader.Load(discreteONNXModel);
        var brainParameters = GetDiscrete1vis0vec_2_3action_recurrModelBrainParameters();
        brainParameters.ActionSpec = ActionSpec.MakeDiscrete(3, 3); // Invalid action
        var errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor() }, new ActuatorComponent[0]);
        Assert.Greater(errors.Count(), 0);
        // Use the discrete model's own parameters so only the action space type
        // is wrong (the continuous parameters would also mismatch the vector obs size).
        brainParameters = GetDiscrete1vis0vec_2_3action_recurrModelBrainParameters();
        brainParameters.ActionSpec = ActionSpec.MakeContinuous(2); // Invalid SpaceType
        errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor() }, new ActuatorComponent[0]);
        Assert.Greater(errors.Count(), 0);
    }

    [Test]
    public void TestCheckModelThrowsActionHybrid()
    {
        var model = ModelLoader.Load(hybridONNXModel);
        var brainParameters = GetHybridBrainParameters();
        brainParameters.ActionSpec = new ActionSpec(3, new[] { 3 }); // Invalid discrete action size
        // Same sensor set as TestCheckModelValidHybrid, so any error comes from
        // the action spec rather than from visual sensors the model has no input for.
        var errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { new VectorSensor(brainParameters.VectorObservationSize) }, new ActuatorComponent[0]);
        Assert.Greater(errors.Count(), 0);
        brainParameters = GetHybridBrainParameters();
        brainParameters.ActionSpec = ActionSpec.MakeDiscrete(2); // Missing continuous action
        errors = BarracudaModelParamLoader.CheckModel(model, brainParameters, new ISensor[] { new VectorSensor(brainParameters.VectorObservationSize) }, new ActuatorComponent[0]);
        Assert.Greater(errors.Count(), 0);
    }

    [Test]
    public void TestCheckModelThrowsNoModel()
    {
        var brainParameters = GetContinuous2vis8vec2actionBrainParameters();
        var errors = BarracudaModelParamLoader.CheckModel(null, brainParameters, new ISensor[] { sensor_21_20_3.CreateSensor(), sensor_20_22_3.CreateSensor() }, new ActuatorComponent[0]);
        Assert.Greater(errors.Count(), 0);
    }
}
}