|
|
|
|
|
|
using NUnit.Framework; |
|
|
|
using UnityEngine; |
|
|
|
using System.IO.Abstractions.TestingHelpers; |
|
|
|
using System.Reflection; |
|
|
|
using MLAgents.CommunicatorObjects; |
|
|
|
using Google.Protobuf; |
|
|
|
|
|
|
|
namespace MLAgents.Tests |
|
|
|
{ |
|
|
|
|
|
|
|
|
|
|
demoStore.Record(agentInfo); |
|
|
|
demoStore.Close(); |
|
|
|
} |
|
|
|
|
|
|
|
public class ObservationAgent : TestAgent |
|
|
|
{ |
|
|
|
public override void CollectObservations() |
|
|
|
{ |
|
|
|
collectObservationsCalls += 1; |
|
|
|
AddVectorObs(1f); |
|
|
|
AddVectorObs(2f); |
|
|
|
AddVectorObs(3f); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
[Test] |
|
|
|
public void TestAgentWrite() |
|
|
|
{ |
|
|
|
var agentGo1 = new GameObject("TestAgent"); |
|
|
|
var bpA = agentGo1.AddComponent<BehaviorParameters>(); |
|
|
|
bpA.brainParameters.vectorObservationSize = 3; |
|
|
|
bpA.brainParameters.numStackedVectorObservations = 1; |
|
|
|
bpA.brainParameters.vectorActionDescriptions = new[] { "TestActionA", "TestActionB" }; |
|
|
|
bpA.brainParameters.vectorActionSize = new[] { 2, 2 }; |
|
|
|
bpA.brainParameters.vectorActionSpaceType = SpaceType.Discrete; |
|
|
|
|
|
|
|
agentGo1.AddComponent<ObservationAgent>(); |
|
|
|
var agent1 = agentGo1.GetComponent<ObservationAgent>(); |
|
|
|
|
|
|
|
agentGo1.AddComponent<DemonstrationRecorder>(); |
|
|
|
var demoRecorder = agentGo1.GetComponent<DemonstrationRecorder>(); |
|
|
|
var fileSystem = new MockFileSystem(); |
|
|
|
demoRecorder.demonstrationName = "TestBrain"; |
|
|
|
demoRecorder.record = true; |
|
|
|
demoRecorder.InitializeDemoStore(fileSystem); |
|
|
|
|
|
|
|
var acaGo = new GameObject("TestAcademy"); |
|
|
|
acaGo.AddComponent<TestAcademy>(); |
|
|
|
var aca = acaGo.GetComponent<TestAcademy>(); |
|
|
|
aca.resetParameters = new ResetParameters(); |
|
|
|
|
|
|
|
var academyInitializeMethod = typeof(Academy).GetMethod("InitializeEnvironment", |
|
|
|
BindingFlags.Instance | BindingFlags.NonPublic); |
|
|
|
var agentEnableMethod = typeof(Agent).GetMethod("OnEnable", |
|
|
|
BindingFlags.Instance | BindingFlags.NonPublic); |
|
|
|
var agentSendInfo = typeof(Agent).GetMethod("SendInfo", |
|
|
|
BindingFlags.Instance | BindingFlags.NonPublic); |
|
|
|
|
|
|
|
agentEnableMethod?.Invoke(agent1, new object[] { }); |
|
|
|
academyInitializeMethod?.Invoke(aca, new object[] { }); |
|
|
|
|
|
|
|
// Step the agent
|
|
|
|
agent1.RequestDecision(); |
|
|
|
agentSendInfo?.Invoke(agent1, new object[] { }); |
|
|
|
|
|
|
|
demoRecorder.Close(); |
|
|
|
|
|
|
|
// Read back the demo file and make sure observations were written
|
|
|
|
var reader = fileSystem.File.OpenRead("Assets/Demonstrations/TestBrain.demo"); |
|
|
|
reader.Seek(DemonstrationStore.MetaDataBytes + 1, 0); |
|
|
|
BrainParametersProto.Parser.ParseDelimitedFrom(reader); |
|
|
|
|
|
|
|
var agentInfoProto = AgentInfoProto.Parser.ParseDelimitedFrom(reader); |
|
|
|
var obs = agentInfoProto.StackedVectorObservation; |
|
|
|
Assert.AreEqual(obs.Count, bpA.brainParameters.vectorObservationSize); |
|
|
|
for (var i = 0; i < obs.Count; i++) |
|
|
|
{ |
|
|
|
Assert.AreEqual((float) i+1, obs[i]); |
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
} |