您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
125 行
3.9 KiB
125 行
3.9 KiB
using NUnit.Framework;
|
|
using Unity.MLAgents.Actuators;
|
|
using Unity.MLAgents.Policies;
|
|
using UnityEngine;
|
|
|
|
namespace Unity.MLAgents.Tests.Policies
{
    /// <summary>
    /// Tests for heuristic decision making. They check that the action buffers handed to an
    /// agent's <c>Heuristic</c> and to an actuator's <c>Heuristic</c> are all-zero on every call,
    /// even though each call leaves non-zero values behind.
    /// </summary>
    [TestFixture]
    public class HeuristicPolicyTest
    {
        [SetUp]
        public void SetUp()
        {
            DisposeAcademyIfInitialized();
        }

        [TearDown]
        public void TearDown()
        {
            // Don't leak the Academy singleton created by a test into later tests/fixtures.
            DisposeAcademyIfInitialized();
        }

        /// <summary>
        /// Dispose the Academy singleton if one currently exists.
        /// </summary>
        static void DisposeAcademyIfInitialized()
        {
            if (Academy.IsInitialized)
            {
                Academy.Instance.Dispose();
            }
        }

        /// <summary>
        /// Assert that the action buffers are initialized to zero, and then set them to non-zero values.
        /// </summary>
        /// <remarks>
        /// Dirtying the buffers means that if the caller reused them without clearing,
        /// the next invocation would fail the zero check.
        /// </remarks>
        /// <param name="actionsOut">
        /// Buffers passed to a heuristic. Every continuous and discrete entry is checked for zero
        /// and then overwritten with 1.
        /// </param>
        static void CheckAndSetBuffer(in ActionBuffers actionsOut)
        {
            var continuousActions = actionsOut.ContinuousActions;
            for (var continuousIndex = 0; continuousIndex < continuousActions.Length; continuousIndex++)
            {
                // NUnit's AreEqual is (expected, actual); keep that order so failure messages read correctly.
                Assert.AreEqual(0.0f, continuousActions[continuousIndex]);
                continuousActions[continuousIndex] = 1.0f;
            }

            var discreteActions = actionsOut.DiscreteActions;
            for (var discreteIndex = 0; discreteIndex < discreteActions.Length; discreteIndex++)
            {
                Assert.AreEqual(0, discreteActions[discreteIndex]);
                discreteActions[discreteIndex] = 1;
            }
        }

        /// <summary>
        /// Agent whose heuristic validates and then dirties its action buffers, counting each call.
        /// </summary>
        class ActionClearedAgent : Agent
        {
            /// <summary>Number of times <c>Heuristic</c> has been invoked on this agent.</summary>
            public int HeuristicCalls = 0;

            public override void Heuristic(in ActionBuffers actionsOut)
            {
                CheckAndSetBuffer(actionsOut);
                HeuristicCalls++;
            }
        }

        /// <summary>
        /// Actuator whose heuristic validates and then dirties its action buffers, counting each call.
        /// All other actuator callbacks are no-ops.
        /// </summary>
        class ActionClearedActuator : IActuator
        {
            /// <summary>Number of times <c>Heuristic</c> has been invoked on this actuator.</summary>
            public int HeuristicCalls = 0;

            /// <param name="actionSpec">Action spec reported through <see cref="ActionSpec"/>.</param>
            public ActionClearedActuator(ActionSpec actionSpec)
            {
                ActionSpec = actionSpec;
                Name = GetType().Name;
            }

            public void OnActionReceived(ActionBuffers actionBuffers)
            {
            }

            public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
            {
            }

            public void Heuristic(in ActionBuffers actionBuffersOut)
            {
                CheckAndSetBuffer(actionBuffersOut);
                HeuristicCalls++;
            }

            public ActionSpec ActionSpec { get; }

            public string Name { get; }

            public void ResetData()
            {
            }
        }

        /// <summary>
        /// Actuator component that creates a single <see cref="ActionClearedActuator"/> and keeps a
        /// reference to it so the test can inspect its call count.
        /// </summary>
        class ActionClearedActuatorComponent : ActuatorComponent
        {
            /// <summary>The actuator created by the most recent <c>CreateActuators</c> call.</summary>
            public ActionClearedActuator ActionClearedActuator;

            // Initialized inline rather than in an explicit constructor; Unity discourages
            // defining constructors on MonoBehaviour-derived types.
            public override ActionSpec ActionSpec { get; } = new ActionSpec(2, new[] { 3, 3 });

            public override IActuator[] CreateActuators()
            {
                ActionClearedActuator = new ActionClearedActuator(ActionSpec);
                return new IActuator[] { ActionClearedActuator };
            }
        }

        /// <summary>
        /// Runs a heuristic-only agent with an extra actuator for several steps. Each step requests a
        /// decision and steps the Academy. Both heuristics must run once per step, and
        /// <see cref="CheckAndSetBuffer"/> fails if a buffer still holds the previous step's values.
        /// </summary>
        [Test]
        public void TestActionsCleared()
        {
            var gameObj = new GameObject();
            try
            {
                var agent = gameObj.AddComponent<ActionClearedAgent>();
                var behaviorParameters = agent.GetComponent<BehaviorParameters>();
                behaviorParameters.BrainParameters.ActionSpec = new ActionSpec(1, new[] { 4 });
                behaviorParameters.BrainParameters.VectorObservationSize = 0;
                behaviorParameters.BehaviorType = BehaviorType.HeuristicOnly;

                var actuatorComponent = gameObj.AddComponent<ActionClearedActuatorComponent>();
                agent.LazyInitialize();

                const int k_NumSteps = 5;
                for (var i = 0; i < k_NumSteps; i++)
                {
                    agent.RequestDecision();
                    Academy.Instance.EnvironmentStep();
                }

                Assert.AreEqual(k_NumSteps, agent.HeuristicCalls);
                Assert.AreEqual(k_NumSteps, actuatorComponent.ActionClearedActuator.HeuristicCalls);
            }
            finally
            {
                // Remove the test object so it can't affect later tests. DestroyImmediate works in
                // both edit-mode and play-mode test runs.
                Object.DestroyImmediate(gameObj);
            }
        }
    }
}
|