Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目，它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头，可以包含连字符 (-)，并且长度不得超过 35 个字符。
 
 
 
 
 

167 行
6.1 KiB

using System;
using NUnit.Framework;
using System.Collections.Generic;
using System.Text;
using MLAgents.SideChannels;
namespace MLAgents.Tests
{
public class SideChannelTests
{
/// <summary>
/// Side channel used only by these tests. Each message it sends or reads
/// carries a single 32-bit integer.
/// </summary>
public class TestSideChannel : SideChannel
{
    /// <summary>
    /// Integers read from incoming messages, appended in the order
    /// <see cref="OnMessageReceived"/> was called.
    /// </summary>
    public List<int> messagesReceived = new List<int>();

    /// <summary>
    /// Creates the channel with a fixed id, so every instance shares the same channel id.
    /// </summary>
    public TestSideChannel()
    {
        ChannelId = new Guid("6afa2c06-4f82-11ea-b238-784f4387d1f7");
    }

    /// <summary>
    /// Reads one integer from <paramref name="msg"/> and records it.
    /// </summary>
    /// <param name="msg">Incoming message to read the integer from.</param>
    public override void OnMessageReceived(IncomingMessage msg)
    {
        var received = msg.ReadInt32();
        messagesReceived.Add(received);
    }

    /// <summary>
    /// Writes <paramref name="value"/> into a new outgoing message and queues it for sending.
    /// </summary>
    /// <param name="value">Integer to send.</param>
    public void SendInt(int value)
    {
        using (var outgoing = new OutgoingMessage())
        {
            outgoing.WriteInt32(value);
            QueueMessageToSend(outgoing);
        }
    }
}
/// <summary>
/// Queues three integers on one <see cref="TestSideChannel"/>, packs them into a byte
/// array with <c>RpcCommunicator.GetSideChannelMessage</c>, passes that array to a second
/// channel with the same id through <c>RpcCommunicator.ProcessSideChannelData</c>, and
/// checks that the receiver recorded the same integers in the same order.
/// </summary>
[Test]
public void TestIntegerSideChannel()
{
    var intSender = new TestSideChannel();
    var intReceiver = new TestSideChannel();
    var dictSender = new Dictionary<Guid, SideChannel> { { intSender.ChannelId, intSender } };
    var dictReceiver = new Dictionary<Guid, SideChannel> { { intReceiver.ChannelId, intReceiver } };

    intSender.SendInt(4);
    intSender.SendInt(5);
    intSender.SendInt(6);

    byte[] fakeData = RpcCommunicator.GetSideChannelMessage(dictSender);
    RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData);

    // Check the count first so a dropped or duplicated message fails as an
    // assertion instead of an out-of-range index (or passing silently).
    Assert.AreEqual(3, intReceiver.messagesReceived.Count);

    // NUnit's AreEqual takes (expected, actual); keeping that order makes
    // the failure message report the right value as "expected".
    Assert.AreEqual(4, intReceiver.messagesReceived[0]);
    Assert.AreEqual(5, intReceiver.messagesReceived[1]);
    Assert.AreEqual(6, intReceiver.messagesReceived[2]);
}
/// <summary>
/// Sends two ASCII-encoded strings through a <c>RawBytesChannel</c>, transfers the packed
/// bytes to a second channel created with the same id, and checks that the receiver
/// returns both payloads, in order, from <c>GetAndClearReceivedMessages</c>.
/// </summary>
[Test]
public void TestRawBytesSideChannel()
{
    var str1 = "Test string";
    var str2 = "Test string, second";

    var strSender = new RawBytesChannel(new Guid("9a5b8954-4f82-11ea-b238-784f4387d1f7"));
    var strReceiver = new RawBytesChannel(new Guid("9a5b8954-4f82-11ea-b238-784f4387d1f7"));
    var dictSender = new Dictionary<Guid, SideChannel> { { strSender.ChannelId, strSender } };
    var dictReceiver = new Dictionary<Guid, SideChannel> { { strReceiver.ChannelId, strReceiver } };

    strSender.SendRawBytes(Encoding.ASCII.GetBytes(str1));
    strSender.SendRawBytes(Encoding.ASCII.GetBytes(str2));

    byte[] fakeData = RpcCommunicator.GetSideChannelMessage(dictSender);
    RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData);

    var messages = strReceiver.GetAndClearReceivedMessages();

    // NUnit's AreEqual takes (expected, actual); keeping that order makes
    // the failure message report the right value as "expected".
    Assert.AreEqual(2, messages.Count);
    Assert.AreEqual(str1, Encoding.ASCII.GetString(messages[0]));
    Assert.AreEqual(str2, Encoding.ASCII.GetString(messages[1]));
}
/// <summary>
/// Exercises two <c>FloatPropertiesChannel</c> instances. <c>propB</c> sets properties
/// and acts as the sender; <c>propA</c> receives the packed data and has a callback
/// registered for the "gravity" key. The test checks default and explicit property
/// values, that the callback runs only after data containing "gravity" has been
/// processed by <c>propA</c>, and that both channels end up listing both keys.
/// </summary>
[Test]
public void TestFloatPropertiesSideChannel()
{
    var k1 = "gravity";
    var k2 = "length";
    int wasCalled = 0;

    var propA = new FloatPropertiesChannel();
    var propB = new FloatPropertiesChannel();
    var dictReceiver = new Dictionary<Guid, SideChannel> { { propA.ChannelId, propA } };
    var dictSender = new Dictionary<Guid, SideChannel> { { propB.ChannelId, propB } };

    // Count invocations of the callback registered on propA for k1.
    propA.RegisterCallback(k1, f => { wasCalled++; });

    // Before k2 is set, the supplied default is expected back.
    var tmp = propB.GetPropertyWithDefault(k2, 3.0f);
    Assert.AreEqual(3.0f, tmp);

    propB.SetProperty(k2, 1.0f);
    tmp = propB.GetPropertyWithDefault(k2, 3.0f);
    Assert.AreEqual(1.0f, tmp);

    // Transfer propB's queued data to propA.
    byte[] fakeData = RpcCommunicator.GetSideChannelMessage(dictSender);
    RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData);

    tmp = propA.GetPropertyWithDefault(k2, 3.0f);
    Assert.AreEqual(1.0f, tmp);

    // Only k2 has been transferred so far, so the k1 callback must not have run.
    Assert.AreEqual(0, wasCalled);

    // Setting k1 on the sender alone must not trigger the receiver's callback.
    propB.SetProperty(k1, 1.0f);
    Assert.AreEqual(0, wasCalled);

    fakeData = RpcCommunicator.GetSideChannelMessage(dictSender);
    RpcCommunicator.ProcessSideChannelData(dictReceiver, fakeData);
    Assert.AreEqual(1, wasCalled);

    var keysA = propA.ListProperties();
    Assert.AreEqual(2, keysA.Count);
    Assert.IsTrue(keysA.Contains(k1));
    Assert.IsTrue(keysA.Contains(k2));

    // Read the sender's keys from propB; the original read propA a second
    // time, so the sender side was never actually checked.
    var keysB = propB.ListProperties();
    Assert.AreEqual(2, keysB.Count);
    Assert.IsTrue(keysB.Contains(k1));
    Assert.IsTrue(keysB.Contains(k2));
}
/// <summary>
/// Makes sure that <c>SetRawBytes</c> resets the message buffer: 8 bytes (an int and a
/// float) are written first, then <c>SetRawBytes</c> is called with 4 bytes, and
/// <c>ToByteArray</c> is expected to return exactly those 4 bytes.
/// </summary>
[Test]
public void TestOutgoingMessageRawBytes()
{
    var data = new byte[] { 1, 2, 3, 4 };

    // OutgoingMessage is disposable (the other tests wrap it in `using`),
    // so dispose it here too, even if the assertion throws.
    using (var msg = new OutgoingMessage())
    {
        msg.WriteInt32(42);
        msg.WriteFloat32(1.0f);

        msg.SetRawBytes(data);

        var result = msg.ToByteArray();
        Assert.AreEqual(data, result);
    }
}
/// <summary>
/// Round-trips one value of each kind (bool, int, float, string, float list) by writing
/// them to an <c>OutgoingMessage</c>, building an <c>IncomingMessage</c> from its bytes,
/// and reading them back in the same order they were written.
/// </summary>
[Test]
public void TestMessageReadWrites()
{
    bool expectedBool = true;
    int expectedInt = 1337;
    float expectedFloat = 4.2f;
    float[] expectedFloatList = new[] { 1001f, 1002f };
    string expectedString = "mlagents!";

    IncomingMessage reader;
    using (var writer = new OutgoingMessage())
    {
        writer.WriteBoolean(expectedBool);
        writer.WriteInt32(expectedInt);
        writer.WriteFloat32(expectedFloat);
        writer.WriteString(expectedString);
        writer.WriteFloatList(expectedFloatList);

        // Build the reader while the writer is still alive, as the bytes come from it.
        reader = new IncomingMessage(writer.ToByteArray());
    }

    // Reads must follow the write order: bool, int, float, string, float list.
    Assert.AreEqual(expectedBool, reader.ReadBoolean());
    Assert.AreEqual(expectedInt, reader.ReadInt32());
    Assert.AreEqual(expectedFloat, reader.ReadFloat32());
    Assert.AreEqual(expectedString, reader.ReadString());
    Assert.AreEqual(expectedFloatList, reader.ReadFloatList());
}
}
}