Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

170 行
4.9 KiB

using System;
using System.Collections.Generic;
using System.Reflection;
using System.Runtime.CompilerServices;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
[assembly: InternalsVisibleTo("Unity.ML-Agents.Editor.Tests")]
[assembly: InternalsVisibleTo("Unity.ML-Agents.Runtime.Sensor.Tests")]
[assembly: InternalsVisibleTo("Unity.ML-Agents.Extensions.EditorTests")]
namespace Unity.MLAgents.Utils.Tests
{
/// <summary>
/// Minimal <see cref="IPolicy"/> stub for tests. It drains each sensor's
/// observations into a throwaway writer (so sensor write/compress counters
/// advance as in production) and always returns an empty one-float
/// continuous action.
/// </summary>
internal class TestPolicy : IPolicy
{
    /// <summary>Hook invoked after every <see cref="RequestDecision"/>, so tests can count calls.</summary>
    public Action OnRequestDecision;
    ObservationWriter m_ObsWriter = new ObservationWriter();
    // Shared zero-valued action returned by ref from DecideAction; never mutated.
    static ActionBuffers s_EmptyActionBuffers = new ActionBuffers(new float[1], Array.Empty<int>());
    public void RequestDecision(AgentInfo info, List<ISensor> sensors)
    {
        // Serialize every sensor so its write/compression paths are exercised,
        // mirroring what a real policy backend would do with the observations.
        foreach (var sensor in sensors)
        {
            sensor.GetObservationProto(m_ObsWriter);
        }
        OnRequestDecision?.Invoke();
    }
    public ref readonly ActionBuffers DecideAction() { return ref s_EmptyActionBuffers; }
    public void Dispose() { }
}
/// <summary>
/// Instrumented <see cref="Agent"/> used by tests: every overridable lifecycle
/// callback increments a public counter, and reflection-based accessors expose
/// the base class's private <c>m_Info</c> and <c>m_Brain</c> fields.
/// </summary>
public class TestAgent : Agent
{
    // Cache the FieldInfo handles once instead of re-running GetField on every
    // property/method call; the field names are private members of Agent.
    static readonly FieldInfo s_InfoField =
        typeof(Agent).GetField("m_Info", BindingFlags.Instance | BindingFlags.NonPublic);
    static readonly FieldInfo s_BrainField =
        typeof(Agent).GetField("m_Brain", BindingFlags.Instance | BindingFlags.NonPublic);

    /// <summary>Reads/writes the base Agent's private <c>m_Info</c> field via reflection.</summary>
    internal AgentInfo _Info
    {
        get
        {
            return (AgentInfo)s_InfoField.GetValue(this);
        }
        set
        {
            s_InfoField.SetValue(this, value);
        }
    }

    /// <summary>Replaces the base Agent's private <c>m_Brain</c> policy via reflection.</summary>
    internal void SetPolicy(IPolicy policy)
    {
        s_BrainField.SetValue(this, policy);
    }

    /// <summary>Reads the base Agent's private <c>m_Brain</c> policy via reflection.</summary>
    internal IPolicy GetPolicy()
    {
        return (IPolicy)s_BrainField.GetValue(this);
    }

    // Call counters incremented by the overridden lifecycle methods below.
    public int initializeAgentCalls;
    public int collectObservationsCalls;
    public int collectObservationsCallsForEpisode; // reset each episode
    public int agentActionCalls;
    public int agentActionCallsForEpisode;         // reset each episode
    public int agentOnEpisodeBeginCalls;
    public int heuristicCalls;
    public TestSensor sensor1;
    public TestSensor sensor2;

    // Exercises the reflection-based Observable sensor path.
    [Observable("observableFloat")]
    public float observableFloat;

    public override void Initialize()
    {
        initializeAgentCalls += 1;
        // Add in some custom Sensors so we can confirm they get sorted as expected.
        // sensor2 is added first on purpose: tests rely on the Agent re-sorting them.
        sensor1 = new TestSensor("testsensor1");
        sensor2 = new TestSensor("testsensor2");
        sensor2.compressionType = SensorCompressionType.PNG;
        sensors.Add(sensor2);
        sensors.Add(sensor1);
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        collectObservationsCalls += 1;
        collectObservationsCallsForEpisode += 1;
        // Observe the per-episode call count so each step's observation differs.
        sensor.AddObservation(collectObservationsCallsForEpisode);
    }

    public override void OnActionReceived(ActionBuffers buffers)
    {
        agentActionCalls += 1;
        agentActionCallsForEpisode += 1;
        AddReward(0.1f);
    }

    public override void OnEpisodeBegin()
    {
        agentOnEpisodeBeginCalls += 1;
        collectObservationsCallsForEpisode = 0;
        agentActionCallsForEpisode = 0;
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var obs = GetObservations();
        var continuousActions = actionsOut.ContinuousActions;
        // Echo the first observation back (truncated to an integer) so tests
        // can verify the heuristic saw the observation written this step.
        continuousActions[0] = (int)obs[0];
        heuristicCalls++;
    }
}
/// <summary>
/// Spy <see cref="ISensor"/> for tests: it produces a zero-length vector
/// observation and counts how often each interface method is invoked.
/// </summary>
public class TestSensor : ISensor
{
    public string sensorName;
    public int numWriteCalls;
    public int numCompressedCalls;
    public int numResetCalls;
    public SensorCompressionType compressionType = SensorCompressionType.None;

    public TestSensor(string n) => sensorName = n;

    // Empty vector spec: the sensor contributes no observation data.
    public ObservationSpec GetObservationSpec() => ObservationSpec.Vector(0);

    public int Write(ObservationWriter writer)
    {
        numWriteCalls++;
        // Intentionally writes nothing; only the call count matters.
        return 0;
    }

    public byte[] GetCompressedObservation()
    {
        numCompressedCalls++;
        return new byte[] { 0 };
    }

    public SensorCompressionType GetCompressionType() => compressionType;

    public string GetName() => sensorName;

    public void Update() { }

    public void Reset() => numResetCalls++;
}
// Empty placeholder type; presumably kept so the file defines a public type
// matching its name for the test assembly — TODO confirm it is still needed.
public class TestClasses
{
}
}