浏览代码

[MLA-1474] detect recursion on Agent methods and throw (#4573)

* recursion checker proof-of-concept

* checkers on agent

* cleanup and unit tests

* changelog

* extra test

* update comment
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
024bb104
共有 8 个文件被更改,包括 195 次插入28 次删除
  1. 3
      com.unity.ml-agents/CHANGELOG.md
  2. 27
      com.unity.ml-agents/Runtime/Academy.cs
  3. 24
      com.unity.ml-agents/Runtime/Agent.cs
  4. 56
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  5. 35
      com.unity.ml-agents/Runtime/RecursionChecker.cs
  6. 3
      com.unity.ml-agents/Runtime/RecursionChecker.cs.meta
  7. 72
      com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs
  8. 3
      com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs.meta

3
com.unity.ml-agents/CHANGELOG.md


### Bug Fixes
#### com.unity.ml-agents (C#)
- `Agent.CollectObservations()` and `Agent.EndEpisode()` will now throw an exception
if they are called recursively (for example, if they call `Agent.EndEpisode()`).
Previously, this would result in an infinite loop and cause the editor to hang. (#4573)
#### ml-agents / ml-agents-envs / gym-unity (Python)

27
com.unity.ml-agents/Runtime/Academy.cs


// Flag used to keep track of the first time the Academy is reset.
bool m_HadFirstReset;
// Whether the Academy is in the middle of a step. This is used to detect and Academy
// step called by user code that is also called by the Academy.
bool m_IsStepping;
// Detect an Academy step called by user code that is also called by the Academy.
private RecursionChecker m_StepRecursionChecker = new RecursionChecker("EnvironmentStep");
// Random seed used for inference.
int m_InferenceSeed;

/// </summary>
public void EnvironmentStep()
{
// Check whether we're already in the middle of a step.
// This shouldn't happen generally, but could happen if user code (e.g. CollectObservations)
// that is called by EnvironmentStep() also calls EnvironmentStep(). This would result
// in an infinite loop and/or stack overflow, so stop it before it happens.
if (m_IsStepping)
{
throw new UnityAgentsException(
"Academy.EnvironmentStep() called recursively. " +
"This might happen if you call EnvironmentStep() from custom code such as " +
"CollectObservations() or OnActionReceived()."
);
}
m_IsStepping = true;
try
using (m_StepRecursionChecker.Start())
{
if (!m_HadFirstReset)
{

{
AgentAct?.Invoke();
}
}
finally
{
// Reset m_IsStepping when we're done (or if an exception occurred).
m_IsStepping = false;
}
}

24
com.unity.ml-agents/Runtime/Agent.cs


/// </summary>
internal VectorSensor collectObservationsSensor;
private RecursionChecker m_CollectObservationsChecker = new RecursionChecker("CollectObservations");
private RecursionChecker m_OnEpisodeBeginChecker = new RecursionChecker("OnEpisodeBegin");
/// <summary>
/// List of IActuators that this Agent will delegate actions to if any exist.
/// </summary>

// episode when initializing until after the Academy had its first reset.
if (Academy.Instance.TotalStepCount != 0)
{
OnEpisodeBegin();
using (m_OnEpisodeBeginChecker.Start())
{
OnEpisodeBegin();
}
}
}

{
// Make sure the latest observations are being passed to training.
collectObservationsSensor.Reset();
CollectObservations(collectObservationsSensor);
using (m_CollectObservationsChecker.Start())
{
CollectObservations(collectObservationsSensor);
}
}
// Request the last decision with no callbacks
// We request a decision so Python knows the Agent is done immediately

UpdateSensors();
using (TimerStack.Instance.Scoped("CollectObservations"))
{
CollectObservations(collectObservationsSensor);
using (m_CollectObservationsChecker.Start())
{
CollectObservations(collectObservationsSensor);
}
}
using (TimerStack.Instance.Scoped("CollectDiscreteActionMasks"))
{

{
ResetData();
m_StepCount = 0;
OnEpisodeBegin();
using (m_OnEpisodeBeginChecker.Start())
{
OnEpisodeBegin();
}
}
/// <summary>

56
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


}
}
}
[TestFixture]
public class AgentRecursionTests
{
[SetUp]
public void SetUp()
{
if (Academy.IsInitialized)
{
Academy.Instance.Dispose();
}
}
class CollectObsEndEpisodeAgent : Agent
{
public override void CollectObservations(VectorSensor sensor)
{
// NEVER DO THIS IN REAL CODE!
EndEpisode();
}
}
class OnEpisodeBeginEndEpisodeAgent : Agent
{
public override void OnEpisodeBegin()
{
// NEVER DO THIS IN REAL CODE!
EndEpisode();
}
}
void TestRecursiveThrows<T>() where T : Agent
{
var gameObj = new GameObject();
var agent = gameObj.AddComponent<T>();
agent.LazyInitialize();
agent.RequestDecision();
Assert.Throws<UnityAgentsException>(() =>
{
Academy.Instance.EnvironmentStep();
});
}
[Test]
public void TestRecursiveCollectObsEndEpisodeThrows()
{
TestRecursiveThrows<CollectObsEndEpisodeAgent>();
}
[Test]
public void TestRecursiveOnEpisodeBeginEndEpisodeThrows()
{
TestRecursiveThrows<OnEpisodeBeginEndEpisodeAgent>();
}
}
}

35
com.unity.ml-agents/Runtime/RecursionChecker.cs


using System;
namespace Unity.MLAgents
{
internal class RecursionChecker : IDisposable
{
private bool m_IsRunning;
private string m_MethodName;
public RecursionChecker(string methodName)
{
m_MethodName = methodName;
}
public IDisposable Start()
{
if (m_IsRunning)
{
throw new UnityAgentsException(
$"{m_MethodName} called recursively. " +
"This might happen if you call EnvironmentStep() or EndEpisode() from custom " +
"code such as CollectObservations() or OnActionReceived()."
);
}
m_IsRunning = true;
return this;
}
public void Dispose()
{
// Reset the flag when we're done (or if an exception occurred).
m_IsRunning = false;
}
}
}

3
com.unity.ml-agents/Runtime/RecursionChecker.cs.meta


fileFormatVersion: 2
guid: 49ebd06532b24078a6edda823aeff5d2
timeCreated: 1602731302

72
com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs


using System;
using NUnit.Framework;
namespace Unity.MLAgents.Tests
{
[TestFixture]
public class RecursionCheckerTests
{
class InfiniteRecurser
{
RecursionChecker m_checker = new RecursionChecker("InfiniteRecurser");
public int NumCalls = 0;
public void Implode()
{
NumCalls++;
using (m_checker.Start())
{
Implode();
}
}
}
[Test]
public void TestRecursionCheck()
{
var rc = new InfiniteRecurser();
Assert.Throws<UnityAgentsException>(() =>
{
rc.Implode();
});
// Should increment twice before bailing out.
Assert.AreEqual(2, rc.NumCalls);
}
class OneTimeThrower
{
RecursionChecker m_checker = new RecursionChecker("OneTimeThrower");
public int NumCalls;
public void DoStuff()
{
// This method throws from inside the checker the first time.
// Later calls do nothing.
NumCalls++;
using (m_checker.Start())
{
if (NumCalls == 1)
{
throw new ArgumentException("oops");
}
}
}
}
[Test]
public void TestThrowResetsFlag()
{
var ott = new OneTimeThrower();
Assert.Throws<ArgumentException>(() =>
{
ott.DoStuff();
});
// Make sure the flag is cleared if we throw in the "using". Should be able to step subsequently.
ott.DoStuff();
ott.DoStuff();
Assert.AreEqual(3, ott.NumCalls);
}
}
}

3
com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs.meta


fileFormatVersion: 2
guid: 5a7183e11dd5434684a4225c80169173
timeCreated: 1602781778
正在加载...
取消
保存