Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目，它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头，可以包含连字符 (-)，并且长度不得超过 35 个字符。
 
 
 
 
 

106 行
4.1 KiB

using NUnit.Framework;
using System.IO;
using Unity.MLAgents.SideChannels;
namespace Unity.MLAgents.Tests
{
/// <summary>
/// Sends sampler definitions through the <see cref="EnvironmentParametersChannel"/> and checks
/// the values subsequently returned by <c>GetWithDefault</c> for the same parameter name.
/// </summary>
public class SamplerTests
{
    // Seed written into every sampler message. The hard-coded expected values in the
    // tests are presumably tied to this seed -- re-record them if it changes.
    const int k_Seed = 1337;

    // Tolerance passed to Assert.AreEqual when comparing sampled values.
    const double k_Epsilon = 0.0001;

    // Tag written right after the parameter name; 1 indicates the message is a Sampler.
    const int k_SamplerMessageType = 1;

    readonly EnvironmentParametersChannel m_Channel;

    /// <summary>
    /// Reuses the <see cref="EnvironmentParametersChannel"/> returned by
    /// <see cref="SideChannelManager"/>, or creates and registers a new one when none is
    /// returned (for example when this test class runs on its own).
    /// </summary>
    public SamplerTests()
    {
        m_Channel = SideChannelManager.GetSideChannel<EnvironmentParametersChannel>();
        // if running test on its own
        if (m_Channel == null)
        {
            m_Channel = new EnvironmentParametersChannel();
            SideChannelManager.RegisterSideChannel(m_Channel);
        }
    }

    /// <summary>
    /// Sends a Uniform sampler with bounds [1.0, 2.0] and checks the first two values read back.
    /// </summary>
    [Test]
    public void UniformSamplerTest()
    {
        float minValue = 1.0f;
        float maxValue = 2.0f;
        string parameter = "parameter1";

        using (var outgoingMsg = new OutgoingMessage())
        {
            WriteSamplerHeader(outgoingMsg, parameter, SamplerType.Uniform);
            outgoingMsg.WriteFloat32(minValue);
            outgoingMsg.WriteFloat32(maxValue);
            Send(outgoingMsg);
        }

        Assert.AreEqual(1.208888f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
        Assert.AreEqual(1.118017f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
    }

    /// <summary>
    /// Sends a Gaussian sampler with mean 3.0 and standard deviation 0.2 and checks the
    /// first two values read back.
    /// </summary>
    [Test]
    public void GaussianSamplerTest()
    {
        float mean = 3.0f;
        float stddev = 0.2f;
        string parameter = "parameter2";

        using (var outgoingMsg = new OutgoingMessage())
        {
            WriteSamplerHeader(outgoingMsg, parameter, SamplerType.Gaussian);
            outgoingMsg.WriteFloat32(mean);
            outgoingMsg.WriteFloat32(stddev);
            Send(outgoingMsg);
        }

        Assert.AreEqual(2.936162f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
        Assert.AreEqual(2.951348f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
    }

    /// <summary>
    /// Sends a MultiRangeUniform sampler built from four interval values and checks the
    /// first two values read back.
    /// </summary>
    [Test]
    public void MultiRangeUniformSamplerTest()
    {
        float[] intervals = { 1.2f, 2f, 3.2f, 4.1f };
        string parameter = "parameter3";

        using (var outgoingMsg = new OutgoingMessage())
        {
            WriteSamplerHeader(outgoingMsg, parameter, SamplerType.MultiRangeUniform);
            outgoingMsg.WriteFloatList(intervals);
            Send(outgoingMsg);
        }

        Assert.AreEqual(3.387999f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
        Assert.AreEqual(1.294413f, m_Channel.GetWithDefault(parameter, 1.0f), k_Epsilon);
    }

    /// <summary>
    /// Writes the fields common to every sampler message, in order: parameter name,
    /// sampler message tag, seed, and sampler type. Sampler-specific arguments must be
    /// written by the caller afterwards.
    /// </summary>
    /// <param name="msg">Message being built.</param>
    /// <param name="parameter">Name of the environment parameter the sampler is bound to.</param>
    /// <param name="samplerType">Sampler kind, written as its integer value.</param>
    static void WriteSamplerHeader(OutgoingMessage msg, string parameter, SamplerType samplerType)
    {
        msg.WriteString(parameter);
        msg.WriteInt32(k_SamplerMessageType);
        msg.WriteInt32(k_Seed);
        msg.WriteInt32((int)samplerType);
    }

    /// <summary>
    /// Frames <paramref name="msg"/> for the test channel and hands the bytes to
    /// <see cref="SideChannelManager"/> for processing.
    /// </summary>
    /// <param name="msg">Fully written sampler message.</param>
    void Send(OutgoingMessage msg)
    {
        byte[] framed = GetByteMessage(m_Channel, msg);
        SideChannelManager.ProcessSideChannelData(framed);
    }

    /// <summary>
    /// Serializes a message as: the channel id bytes, the payload length as a 32-bit
    /// integer, then the payload itself.
    /// </summary>
    /// <param name="sideChannel">Channel whose <c>ChannelId</c> prefixes the frame.</param>
    /// <param name="msg">Message whose bytes form the payload.</param>
    /// <returns>The framed bytes.</returns>
    internal static byte[] GetByteMessage(SideChannel sideChannel, OutgoingMessage msg)
    {
        byte[] message = msg.ToByteArray();
        using (var memStream = new MemoryStream())
        {
            using (var binaryWriter = new BinaryWriter(memStream))
            {
                binaryWriter.Write(sideChannel.ChannelId.ToByteArray());
                binaryWriter.Write(message.Length);
                binaryWriter.Write(message);
            }
            // MemoryStream.ToArray still works after the writer has closed the stream.
            return memStream.ToArray();
        }
    }
}
}