using System;
using System.Collections.Generic;
using System.Reflection;
using NUnit.Framework;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Sensors.Reflection;
using Unity.MLAgents.Utils.Tests;
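
// Minimal IPolicy stub: it writes every sensor observation, invokes the OnRequestDecision
// callback so tests can hook in, and always returns an empty single-continuous-action buffer.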
internal class TestPolicy : IPolicy
{
    public Action OnRequestDecision;
    ObservationWriter m_ObsWriter = new ObservationWriter();
    static ActionSpec s_ActionSpec = ActionSpec.MakeContinuous(1);
    static ActionBuffers s_EmptyActionBuffers = new ActionBuffers(new float[1], Array.Empty<int>());

    public void RequestDecision(AgentInfo info, List<ISensor> sensors)
    {
        foreach (var sensor in sensors)
        {
            sensor.GetObservationProto(m_ObsWriter);
        }
        OnRequestDecision?.Invoke();
    }

    public ref readonly ActionBuffers DecideAction() { return ref s_EmptyActionBuffers; }

    public void Dispose() { }
}
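
// Test double for Agent that counts lifecycle callbacks (Initialize, CollectObservations,
// OnActionReceived, OnEpisodeBegin, Heuristic) and uses reflection to reach the private
// m_Info / m_Brain fields so tests can inspect or swap them.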
public class TestAgent : Agent
{
    internal AgentInfo _Info
    {
        get
        {
            return (AgentInfo)typeof(Agent).GetField("m_Info", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this);
        }
        set
        {
            typeof(Agent).GetField("m_Info", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(this, value);
        }
    }

    internal void SetPolicy(IPolicy policy)
    {
        typeof(Agent).GetField("m_Brain", BindingFlags.Instance | BindingFlags.NonPublic).SetValue(this, policy);
    }

    internal IPolicy GetPolicy()
    {
        return (IPolicy)typeof(Agent).GetField("m_Brain", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(this);
    }

    public int initializeAgentCalls;
    public int collectObservationsCalls;
    public int collectObservationsCallsForEpisode;
    public int agentActionCalls;
    public int agentActionCallsForEpisode;
    public int agentOnEpisodeBeginCalls;
    public int heuristicCalls;
    public TestSensor sensor1;
    public TestSensor sensor2;

    [Observable("observableFloat")]
    public float observableFloat;

    public override void Initialize()
    {
        initializeAgentCalls += 1;

        // Add in some custom Sensors so we can confirm they get sorted as expected.
        sensor1 = new TestSensor("testsensor1");
        sensor2 = new TestSensor("testsensor2");
        sensor2.compressionType = SensorCompressionType.PNG;
        sensors.Add(sensor2);
        sensors.Add(sensor1);
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        collectObservationsCalls += 1;
        collectObservationsCallsForEpisode += 1;
        sensor.AddObservation(collectObservationsCallsForEpisode);
    }

    public override void OnActionReceived(ActionBuffers buffers)
    {
        agentActionCalls += 1;
        agentActionCallsForEpisode += 1;
        AddReward(0.1f);
    }

    public override void OnEpisodeBegin()
    {
        agentOnEpisodeBeginCalls += 1;
        collectObservationsCallsForEpisode = 0;
        agentActionCallsForEpisode = 0;
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var obs = GetObservations();
        var continuousActions = actionsOut.ContinuousActions;
        continuousActions[0] = (int)obs[0];
        heuristicCalls++;
    }
}
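
// Simple ISensor implementation that reports a zero-length vector observation and counts
// Write / GetCompressedObservation / Reset calls so tests can verify how sensors are used.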
public class TestSensor : ISensor
{
    public string sensorName;
    public int numWriteCalls;
    public int numCompressedCalls;
    public int numResetCalls;
    public SensorCompressionType compressionType = SensorCompressionType.None;

    public TestSensor(string n)
    {
        sensorName = n;
    }

    public ObservationSpec GetObservationSpec()
    {
        return ObservationSpec.Vector(0);
    }

    public int Write(ObservationWriter writer)
    {
        numWriteCalls++;
        // No-op
        return 0;
    }

    public byte[] GetCompressedObservation()
    {
        numCompressedCalls++;
        return new byte[] { 0 };
    }

    public SensorCompressionType GetCompressionType()
    {
        return compressionType;
    }

    public string GetName()
    {
        return sensorName;
    }

    public void Update() { }

    public void Reset()
    {
        numResetCalls++;
    }
}

[TestFixture]
public class EditModeTestGeneration
{
    [Test]
    public void TestAcademyAndAgentStepping()
    {
        // Assumed setup: the original setup for this stepping test is missing here, so the
        // agents, Academy handle, counters, method name, and loop length below are
        // reconstructed placeholders based on how the loop body uses them.
        var agentGo1 = new GameObject("TestAgent");
        agentGo1.AddComponent<BehaviorParameters>().BrainParameters.ActionSpec = ActionSpec.MakeContinuous(1);
        var agent1 = agentGo1.AddComponent<TestAgent>();
        var agentGo2 = new GameObject("TestAgent");
        agentGo2.AddComponent<BehaviorParameters>().BrainParameters.ActionSpec = ActionSpec.MakeContinuous(1);
        var agent2 = agentGo2.AddComponent<TestAgent>();
        var aca = Academy.Instance;
        agent2.LazyInitialize();

        var stepsSinceReset = 0;
        var numberAcaReset = 0;
        var numberAgent1Episodes = 0;
        var numberAgent2Episodes = 0;
        var agent2StepForEpisode = 0;

        for (var i = 0; i < 50; i++)
        {
            Assert.AreEqual(stepsSinceReset, aca.StepCount);
            Assert.AreEqual(numberAcaReset, aca.EpisodeCount);
            Assert.AreEqual(i, aca.TotalStepCount);
            // Academy resets at the first step
            if (i == 0)
            {
                numberAcaReset += 1;
                numberAgent2Episodes += 1;
            }
            // Agent 1 is only initialized at step 2
            if (i == 2)
            {
                Assert.AreEqual(numberAgent1Episodes, agent1.agentOnEpisodeBeginCalls);
            }
            // Set agent 1 to done every 11 steps to test behavior
            if (i % 11 == 5)
            {
                agent1.EndEpisode();
                numberAgent1Episodes += 1;
                Assert.AreEqual(numberAgent1Episodes, agent1.agentOnEpisodeBeginCalls);
            }
            // Ending the episode for agent 2 regularly
            if (i % 13 == 3)
            {
                agent2.EndEpisode();
                numberAgent2Episodes += 1;
                agent2StepForEpisode = 0;
                Assert.AreEqual(numberAgent2Episodes, agent2.agentOnEpisodeBeginCalls);
            }
            // Request a decision for agent 2 regularly
            if (i % 3 == 2)
            {
                agent2.RequestDecision();
                agent2StepForEpisode += 1;
            }

            // Step the Academy once per iteration so the step-count asserts above hold.
            stepsSinceReset += 1;
            aca.EnvironmentStep();
        }
    }

    [Test]
    public void AssertStackingReset()
    {
        var agentGo1 = new GameObject("TestAgent");
        var bp1 = agentGo1.AddComponent<BehaviorParameters>();
        bp1.BrainParameters.ActionSpec = ActionSpec.MakeContinuous(1);
        var agent1 = agentGo1.AddComponent<TestAgent>();
        var behaviorParameters = agentGo1.GetComponent<BehaviorParameters>();
        behaviorParameters.BrainParameters.NumStackedVectorObservations = 3;
        var aca = Academy.Instance;
        agent1.LazyInitialize();
        var policy = new TestPolicy();
        agent1.SetPolicy(policy);

        StackingSensor sensor = null;
        foreach (ISensor s in agent1.sensors)
        {
            if (s is StackingSensor)
            {
                sensor = s as StackingSensor;
            }
        }
        Assert.NotNull(sensor);

        for (int i = 0; i < 20; i++)
        {
            agent1.RequestDecision();
            aca.EnvironmentStep();
        }
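
        // EndEpisode() triggers one final decision request, at which point the stacking
        // sensor should still hold the most recent stacked values; after the episode is
        // reset, the stack should be cleared back to zeros.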
        policy.OnRequestDecision = () => SensorTestHelper.CompareObservation(sensor, new[] { 18f, 19f, 21f });
        agent1.EndEpisode();
        SensorTestHelper.CompareObservation(sensor, new[] { 0f, 0f, 0f });
    }
}