浏览代码

Make the agent begin episode at initialization (#3605)

* Make the agent begin episode at initialization

* Renaming and adding a comment

* [skip ci] Update com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* [skip ci] Update com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* renamed test variables and modified some test statements

* Use TotalStepCount rather than HadFirstReset

* [skip ci] Renamed HadFirstReset to m_HadFirstReset

Co-authored-by: Chris Elion <chris.elion@unity3d.com>
/bug-failed-api-check
GitHub 5 年前
当前提交
119141fb
共有 3 个文件被更改,包括 70 次插入51 次删除
  1. 6
      com.unity.ml-agents/Runtime/Academy.cs
  2. 9
      com.unity.ml-agents/Runtime/Agent.cs
  3. 106
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs

6
com.unity.ml-agents/Runtime/Academy.cs


List<ModelRunner> m_ModelRunners = new List<ModelRunner>();
// Flag used to keep track of the first time the Academy is reset.
bool m_FirstAcademyReset;
bool m_HadFirstReset;
// The Academy uses a series of events to communicate with agents
// to facilitate synchronization. More specifically, it ensure

{
EnvironmentReset();
AgentForceReset?.Invoke();
m_FirstAcademyReset = true;
m_HadFirstReset = true;
}
/// <summary>

public void EnvironmentStep()
{
if (!m_FirstAcademyReset)
if (!m_HadFirstReset)
{
ForcedFullReset();
}

9
com.unity.ml-agents/Runtime/Agent.cs


ResetData();
Initialize();
InitializeSensors();
// The first time the Academy resets, all Agents in the scene will be
// forced to reset through the <see cref="AgentForceReset"/> event.
// To avoid the Agent resetting twice, the Agents will not begin their
// episode when initializing until after the Academy had its first reset.
if (Academy.Instance.TotalStepCount != 0)
{
OnEpisodeBegin();
}
}
/// <summary>

106
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


public int initializeAgentCalls;
public int collectObservationsCalls;
public int collectObservationsCallsSinceLastReset;
public int collectObservationsCallsForEpisode;
public int agentActionCallsSinceLastReset;
public int agentResetCalls;
public int agentActionCallsForEpisode;
public int agentOnEpisodeBeginCalls;
public int heuristicCalls;
public TestSensor sensor1;
public TestSensor sensor2;

public override void CollectObservations(VectorSensor sensor)
{
collectObservationsCalls += 1;
collectObservationsCallsSinceLastReset += 1;
collectObservationsCallsForEpisode += 1;
sensor.AddObservation(0f);
}

agentActionCallsSinceLastReset += 1;
agentActionCallsForEpisode += 1;
agentResetCalls += 1;
collectObservationsCallsSinceLastReset = 0;
agentActionCallsSinceLastReset = 0;
agentOnEpisodeBeginCalls += 1;
collectObservationsCallsForEpisode = 0;
agentActionCallsForEpisode = 0;
}
public override float[] Heuristic()

agentGo2.AddComponent<TestAgent>();
var agent2 = agentGo2.GetComponent<TestAgent>();
Assert.AreEqual(0, agent1.agentResetCalls);
Assert.AreEqual(0, agent2.agentResetCalls);
Assert.AreEqual(0, agent1.agentOnEpisodeBeginCalls);
Assert.AreEqual(0, agent2.agentOnEpisodeBeginCalls);
Assert.AreEqual(0, agent1.initializeAgentCalls);
Assert.AreEqual(0, agent2.initializeAgentCalls);
Assert.AreEqual(0, agent1.agentActionCalls);

// agent1 was not enabled when the academy started
// The agents have been initialized
Assert.AreEqual(0, agent1.agentResetCalls);
Assert.AreEqual(0, agent2.agentResetCalls);
Assert.AreEqual(0, agent1.agentOnEpisodeBeginCalls);
Assert.AreEqual(0, agent2.agentOnEpisodeBeginCalls);
Assert.AreEqual(1, agent1.initializeAgentCalls);
Assert.AreEqual(1, agent2.initializeAgentCalls);
Assert.AreEqual(0, agent1.agentActionCalls);

agent1.LazyInitialize();
var numberAgent1Reset = 0;
var numberAgent1Episodes = 0;
var numberAgent2Episodes = 0;
Assert.AreEqual(numberAgent1Reset, agent1.agentResetCalls);
// Agent2 is never reset since initialized after academy
Assert.AreEqual(0, agent2.agentResetCalls);
Assert.AreEqual(numberAgent1Episodes, agent1.agentOnEpisodeBeginCalls);
Assert.AreEqual(numberAgent2Episodes, agent2.agentOnEpisodeBeginCalls);
Assert.AreEqual(1, agent1.initializeAgentCalls);
Assert.AreEqual(numberAgent2Initialization, agent2.initializeAgentCalls);
Assert.AreEqual(i, agent1.agentActionCalls);

// Agent 1 resets at the first step
// Agent 1 starts a new episode at the first step
numberAgent1Reset += 1;
numberAgent1Episodes += 1;
// Since Agent2 is initialized after the Academy has stepped, its OnEpisodeBegin should be called now.
Assert.AreEqual(0, agent2.agentOnEpisodeBeginCalls);
Assert.AreEqual(1, agent2.agentOnEpisodeBeginCalls);
numberAgent2Episodes += 1;
}
// We are testing request decision and request actions when called

agent2.LazyInitialize();
var numberAgent1Reset = 0;
var numberAgent2Reset = 0;
var numberAgent1Episodes = 0;
var numberAgent2Episodes = 0;
var agent2StepSinceReset = 0;
var agent2StepForEpisode = 0;
for (var i = 0; i < 5000; i++)
{
Assert.AreEqual(acaStepsSinceReset, aca.StepCount);

Assert.AreEqual(agent2StepSinceReset, agent2.StepCount);
Assert.AreEqual(numberAgent1Reset, agent1.agentResetCalls);
Assert.AreEqual(numberAgent2Reset, agent2.agentResetCalls);
Assert.AreEqual(numberAgent2Episodes, agent2.agentOnEpisodeBeginCalls);
Assert.AreEqual(agent2StepForEpisode, agent2.StepCount);
// Agent 2 and academy reset at the first step
// Agent 2 and academy reset at the first step
Assert.AreEqual(numberAgent2Episodes, agent2.agentOnEpisodeBeginCalls);
numberAgent2Reset += 1;
numberAgent2Episodes += 1;
Assert.AreEqual(numberAgent1Episodes, agent1.agentOnEpisodeBeginCalls);
numberAgent1Episodes += 1;
Assert.AreEqual(numberAgent1Episodes, agent1.agentOnEpisodeBeginCalls);
Assert.AreEqual(numberAgent1Episodes, agent1.agentOnEpisodeBeginCalls);
numberAgent1Reset += 1;
numberAgent1Episodes += 1;
Assert.AreEqual(numberAgent1Episodes, agent1.agentOnEpisodeBeginCalls);
// Resetting agent 2 regularly
// Ending the episode for agent 2 regularly
Assert.AreEqual(numberAgent2Episodes, agent2.agentOnEpisodeBeginCalls);
numberAgent2Reset += 1;
agent2StepSinceReset = 0;
numberAgent2Episodes += 1;
agent2StepForEpisode = 0;
Assert.AreEqual(numberAgent2Episodes, agent2.agentOnEpisodeBeginCalls);
}
// Request a decision for agent 2 regularly
if (i % 3 == 2)

}
acaStepsSinceReset += 1;
agent2StepSinceReset += 1;
agent2StepForEpisode += 1;
aca.EnvironmentStep();
}
}

agent1.LazyInitialize();
agent2.SetPolicy(new TestPolicy());
var expectedAgent1ActionSinceReset = 0;
var expectedAgent1ActionForEpisode = 0;
expectedAgent1ActionSinceReset += 1;
if (expectedAgent1ActionSinceReset == agent1.maxStep || i == 0)
expectedAgent1ActionForEpisode += 1;
if (expectedAgent1ActionForEpisode == agent1.maxStep || i == 0)
expectedAgent1ActionSinceReset = 0;
expectedAgent1ActionForEpisode = 0;
Assert.LessOrEqual(Mathf.Abs(expectedAgent1ActionSinceReset * 10.1f - agent1.GetCumulativeReward()), 0.05f);
Assert.LessOrEqual(Mathf.Abs(expectedAgent1ActionForEpisode * 10.1f - agent1.GetCumulativeReward()), 0.05f);
Assert.LessOrEqual(Mathf.Abs(i * 0.1f - agent2.GetCumulativeReward()), 0.05f);
agent1.AddReward(10f);

agent1.LazyInitialize();
var expectedAgentStepCount = 0;
var expectedResets = 0;
var expectedEpisodes = 0;
var expectedAgentActionSinceReset = 0;
var expectedAgentActionForEpisode = 0;
var expectedCollectObsCallsSinceReset = 0;
var expectedCollectObsCallsForEpisode = 0;
expectedAgentActionSinceReset += 1;
expectedAgentActionForEpisode += 1;
expectedCollectObsCallsSinceReset += 1;
expectedCollectObsCallsForEpisode += 1;
expectedResets += 1;
expectedEpisodes += 1;
expectedAgentActionSinceReset = 0;
expectedCollectObsCallsSinceReset = 0;
expectedAgentActionForEpisode = 0;
expectedCollectObsCallsForEpisode = 0;
Assert.AreEqual(expectedResets, agent1.agentResetCalls);
Assert.AreEqual(expectedEpisodes, agent1.agentOnEpisodeBeginCalls);
Assert.AreEqual(expectedAgentActionSinceReset, agent1.agentActionCallsSinceLastReset);
Assert.AreEqual(expectedAgentActionForEpisode, agent1.agentActionCallsForEpisode);
Assert.AreEqual(expectedCollectObsCallsSinceReset, agent1.collectObservationsCallsSinceLastReset);
Assert.AreEqual(expectedCollectObsCallsForEpisode, agent1.collectObservationsCallsForEpisode);
}
}

正在加载...
取消
保存