Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过 35 个字符。
 
 
 
 
 

130 行
4.7 KiB

using UnityEngine;
using Unity.Barracuda;
namespace Unity.MLAgents.Sensors
{
/// <summary>
/// Utility methods related to <see cref="ISensor"/> implementations.
/// </summary>
public static class SensorHelper
{
    /// <summary>
    /// Generates the observations for the provided sensor, and returns true if they equal the
    /// expected values. If they are unequal, errorMessage is also set.
    /// This should not generally be used in production code. It is only intended for
    /// simplifying unit tests.
    /// </summary>
    /// <param name="sensor">The sensor whose observation is written and checked.</param>
    /// <param name="expected">The expected flat observation; its length sizes the output buffer.</param>
    /// <param name="errorMessage">
    /// Set to a description of the first mismatch or failed sanity check; null when the observations match.
    /// </param>
    /// <returns>True if every written value equals the corresponding expected value (exact comparison).</returns>
    public static bool CompareObservation(ISensor sensor, float[] expected, out string errorMessage)
    {
        if (!TryWriteObservation(sensor, expected.Length, out var output, out errorMessage))
        {
            return false;
        }

        for (var i = 0; i < output.Length; i++)
        {
            if (expected[i] != output[i])
            {
                errorMessage = $"Expected and actual differed in position {i}. Expected: {expected[i]} Actual: {output[i]} ";
                return false;
            }
        }

        errorMessage = null;
        return true;
    }

    /// <summary>
    /// Generates the observations for the provided sensor, and returns true if they equal the
    /// expected values. If they are unequal, errorMessage is also set.
    /// This should not generally be used in production code. It is only intended for
    /// simplifying unit tests.
    /// </summary>
    /// <param name="sensor">The sensor whose observation is written and checked.</param>
    /// <param name="expected">
    /// The expected observation indexed as [height, width, channel]; each element is compared with the
    /// output position given by Barracuda's <c>TensorShape.Index(0, h, w, c)</c>.
    /// </param>
    /// <param name="errorMessage">
    /// Set to a description of the first mismatch or failed sanity check; null when the observations match.
    /// </param>
    /// <returns>True if every written value equals the corresponding expected value (exact comparison).</returns>
    public static bool CompareObservation(ISensor sensor, float[,,] expected, out string errorMessage)
    {
        // Batch dimension is 0; only the per-observation (h, w, c) extents matter here.
        var tensorShape = new TensorShape(0, expected.GetLength(0), expected.GetLength(1), expected.GetLength(2));
        var numExpected = tensorShape.height * tensorShape.width * tensorShape.channels;

        if (!TryWriteObservation(sensor, numExpected, out var output, out errorMessage))
        {
            return false;
        }

        for (var h = 0; h < tensorShape.height; h++)
        {
            for (var w = 0; w < tensorShape.width; w++)
            {
                for (var c = 0; c < tensorShape.channels; c++)
                {
                    // Position of (h, w, c) in the flat output buffer, for batch index 0.
                    var index = tensorShape.Index(0, h, w, c);
                    if (expected[h, w, c] != output[index])
                    {
                        errorMessage = $"Expected and actual differed in position [{h}, {w}, {c}]. " +
                            $"Expected: {expected[h, w, c]} Actual: {output[index]} ";
                        return false;
                    }
                }
            }
        }

        errorMessage = null;
        return true;
    }

    /// <summary>
    /// Allocates a buffer of <paramref name="length"/> floats filled with a sentinel value, points an
    /// <see cref="ObservationWriter"/> at it, and lets <paramref name="sensor"/> write its observation.
    /// </summary>
    /// <param name="sensor">The sensor to write from.</param>
    /// <param name="length">Number of floats to allocate for the observation.</param>
    /// <param name="output">The buffer the sensor wrote into.</param>
    /// <param name="errorMessage">Set when the pre-write sanity check fails; otherwise null.</param>
    /// <returns>False if <c>ObservationWriter.SetTarget</c> altered the buffer; true otherwise.</returns>
    static bool TryWriteObservation(ISensor sensor, int length, out float[] output, out string errorMessage)
    {
        // Sentinel value: any element the sensor fails to write will still hold it and show up as a mismatch.
        const float fill = -1337f;
        output = new float[length];
        for (var i = 0; i < length; i++)
        {
            output[i] = fill;
        }

        var writer = new ObservationWriter();
        writer.SetTarget(output, sensor.GetObservationShape(), 0);

        // SetTarget should only record the destination; make sure it didn't touch any element.
        for (var i = 0; i < length; i++)
        {
            if (fill != output[i])
            {
                errorMessage = "ObservationWriter.SetTarget modified a buffer it shouldn't have.";
                return false;
            }
        }

        sensor.Write(writer);
        errorMessage = null;
        return true;
    }
}
}