您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
170 行
4.9 KiB
170 行
4.9 KiB
using System;
|
|
using System.Collections.Generic;
|
|
using System.Reflection;
|
|
using System.Runtime.CompilerServices;
|
|
using Unity.MLAgents.Actuators;
|
|
using Unity.MLAgents.Policies;
|
|
using Unity.MLAgents.Sensors;
|
|
using Unity.MLAgents.Sensors.Reflection;
|
|
|
|
[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")]
|
|
[assembly: InternalsVisibleTo("Unity.ML-Agents.Runtime.Sensor.Tests")]
|
|
[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.EditorTests")]
|
|
|
|
namespace Unity.MLAgents.Utils.Tests
|
|
{
|
|
internal class TestPolicy : IPolicy
|
|
{
|
|
public Action OnRequestDecision;
|
|
ObservationWriter m_ObsWriter = new ObservationWriter();
|
|
static ActionSpec s_ActionSpec = ActionSpec.MakeContinuous(1);
|
|
static ActionBuffers s_EmptyActionBuffers = new ActionBuffers(new float[1], Array.Empty<int>());
|
|
|
|
public void RequestDecision(AgentInfo info, List<ISensor> sensors)
|
|
{
|
|
foreach (var sensor in sensors)
|
|
{
|
|
sensor.GetObservationProto(m_ObsWriter);
|
|
}
|
|
OnRequestDecision?.Invoke();
|
|
}
|
|
|
|
public ref readonly ActionBuffers DecideAction() { return ref s_EmptyActionBuffers; }
|
|
|
|
public void Dispose() { }
|
|
}
|
|
|
|
public class TestAgent : Agent
|
|
{
|
|
internal AgentInfo _Info
|
|
{
|
|
get
|
|
{
|
|
return (AgentInfo)typeof(Agent).GetField("m_Info", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this);
|
|
}
|
|
set
|
|
{
|
|
typeof(Agent).GetField("m_Info", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(this, value);
|
|
}
|
|
}
|
|
|
|
internal void SetPolicy(IPolicy policy)
|
|
{
|
|
typeof(Agent).GetField("m_Brain", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(this, policy);
|
|
}
|
|
|
|
internal IPolicy GetPolicy()
|
|
{
|
|
return (IPolicy)typeof(Agent).GetField("m_Brain", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this);
|
|
}
|
|
|
|
public int initializeAgentCalls;
|
|
public int collectObservationsCalls;
|
|
public int collectObservationsCallsForEpisode;
|
|
public int agentActionCalls;
|
|
public int agentActionCallsForEpisode;
|
|
public int agentOnEpisodeBeginCalls;
|
|
public int heuristicCalls;
|
|
public TestSensor sensor1;
|
|
public TestSensor sensor2;
|
|
|
|
[Observable("observableFloat")]
|
|
public float observableFloat;
|
|
|
|
public override void Initialize()
|
|
{
|
|
initializeAgentCalls += 1;
|
|
|
|
// Add in some custom Sensors so we can confirm they get sorted as expected.
|
|
sensor1 = new TestSensor("testsensor1");
|
|
sensor2 = new TestSensor("testsensor2");
|
|
sensor2.compressionType = SensorCompressionType.PNG;
|
|
|
|
sensors.Add(sensor2);
|
|
sensors.Add(sensor1);
|
|
}
|
|
|
|
public override void CollectObservations(VectorSensor sensor)
|
|
{
|
|
collectObservationsCalls += 1;
|
|
collectObservationsCallsForEpisode += 1;
|
|
sensor.AddObservation(collectObservationsCallsForEpisode);
|
|
}
|
|
|
|
public override void OnActionReceived(ActionBuffers buffers)
|
|
{
|
|
agentActionCalls += 1;
|
|
agentActionCallsForEpisode += 1;
|
|
AddReward(0.1f);
|
|
}
|
|
|
|
public override void OnEpisodeBegin()
|
|
{
|
|
agentOnEpisodeBeginCalls += 1;
|
|
collectObservationsCallsForEpisode = 0;
|
|
agentActionCallsForEpisode = 0;
|
|
}
|
|
|
|
public override void Heuristic(in ActionBuffers actionsOut)
|
|
{
|
|
var obs = GetObservations();
|
|
var continuousActions = actionsOut.ContinuousActions;
|
|
continuousActions[0] = (int)obs[0];
|
|
heuristicCalls++;
|
|
}
|
|
}
|
|
|
|
public class TestSensor : ISensor
|
|
{
|
|
public string sensorName;
|
|
public int numWriteCalls;
|
|
public int numCompressedCalls;
|
|
public int numResetCalls;
|
|
public SensorCompressionType compressionType = SensorCompressionType.None;
|
|
|
|
public TestSensor(string n)
|
|
{
|
|
sensorName = n;
|
|
}
|
|
|
|
public ObservationSpec GetObservationSpec()
|
|
{
|
|
return ObservationSpec.Vector(0);
|
|
}
|
|
|
|
public int Write(ObservationWriter writer)
|
|
{
|
|
numWriteCalls++;
|
|
// No-op
|
|
return 0;
|
|
}
|
|
|
|
public byte[] GetCompressedObservation()
|
|
{
|
|
numCompressedCalls++;
|
|
return new byte[] { 0 };
|
|
}
|
|
|
|
public SensorCompressionType GetCompressionType()
|
|
{
|
|
return compressionType;
|
|
}
|
|
|
|
public string GetName()
|
|
{
|
|
return sensorName;
|
|
}
|
|
|
|
public void Update() { }
|
|
|
|
public void Reset()
|
|
{
|
|
numResetCalls++;
|
|
}
|
|
}
|
|
|
|
public class TestClasses
|
|
{
|
|
|
|
}
|
|
}
|