浏览代码

backport fix for recursion in user code (#4638)

* backport fix for recursion in user code
/r2v-yamato-linux
GitHub 4 年前
当前提交
5066c28e
共有 8 个文件被更改,包括 223 次插入27 次删除
  1. 3
      com.unity.ml-agents/CHANGELOG.md
  2. 54
      com.unity.ml-agents/Runtime/Academy.cs
  3. 24
      com.unity.ml-agents/Runtime/Agent.cs
  4. 56
      com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
  5. 35
      com.unity.ml-agents/Runtime/RecursionChecker.cs
  6. 3
      com.unity.ml-agents/Runtime/RecursionChecker.cs.meta
  7. 72
      com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs
  8. 3
      com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs.meta

3
com.unity.ml-agents/CHANGELOG.md


### Bug Fixes
#### com.unity.ml-agents (C#)
- Fixed a bug with visual observations using .onnx model files and newer versions of Barracuda (1.1.0 or later). (#4533)
- `Agent.CollectObservations()`, `Agent.EndEpisode()`, and `Academy.EnvironmentStep()` will now throw an exception
if they are called recursively (for example, if they call `Agent.EndEpisode()`).
Previously, this would result in an infinite loop and cause the editor to hang. (#4638)
- Fixed a bug where accessing the Academy outside of play mode would cause the Academy to get stepped multiple times when in play mode. (#4637)
## [1.0.5] - 2020-09-23

54
com.unity.ml-agents/Runtime/Academy.cs


// Flag used to keep track of the first time the Academy is reset.
bool m_HadFirstReset;
// Detect an Academy step called by user code that is also called by the Academy.
private RecursionChecker m_StepRecursionChecker = new RecursionChecker("EnvironmentStep");
// Random seed used for inference.
int m_InferenceSeed;

/// </summary>
public void EnvironmentStep()
{
if (!m_HadFirstReset)
using (m_StepRecursionChecker.Start())
ForcedFullReset();
}
AgentPreStep?.Invoke(m_StepCount);
m_StepCount += 1;
m_TotalStepCount += 1;
AgentIncrementStep?.Invoke();
if (!m_HadFirstReset)
{
ForcedFullReset();
}
using (TimerStack.Instance.Scoped("AgentSendState"))
{
AgentSendState?.Invoke();
}
AgentPreStep?.Invoke(m_StepCount);
m_StepCount += 1;
m_TotalStepCount += 1;
AgentIncrementStep?.Invoke();
using (TimerStack.Instance.Scoped("AgentSendState"))
{
AgentSendState?.Invoke();
}
using (TimerStack.Instance.Scoped("DecideAction"))
{
DecideAction?.Invoke();
}
using (TimerStack.Instance.Scoped("DecideAction"))
{
DecideAction?.Invoke();
}
// If the communicator is not on, we need to clear the SideChannel sending queue
if (!IsCommunicatorOn)
{
SideChannelsManager.GetSideChannelMessage();
}
// If the communicator is not on, we need to clear the SideChannel sending queue
if (!IsCommunicatorOn)
{
SideChannelsManager.GetSideChannelMessage();
}
using (TimerStack.Instance.Scoped("AgentAct"))
{
AgentAct?.Invoke();
using (TimerStack.Instance.Scoped("AgentAct"))
{
AgentAct?.Invoke();
}
}
}

24
com.unity.ml-agents/Runtime/Agent.cs


/// </summary>
internal VectorSensor collectObservationsSensor;
private RecursionChecker m_CollectObservationsChecker = new RecursionChecker("CollectObservations");
private RecursionChecker m_OnEpisodeBeginChecker = new RecursionChecker("OnEpisodeBegin");
/// <summary>
/// Called when the attached [GameObject] becomes enabled and active.
/// [GameObject]: https://docs.unity3d.com/Manual/GameObjects.html

// episode when initializing until after the Academy had its first reset.
if (Academy.Instance.TotalStepCount != 0)
{
OnEpisodeBegin();
using (m_OnEpisodeBeginChecker.Start())
{
OnEpisodeBegin();
}
}
}

{
// Make sure the latest observations are being passed to training.
collectObservationsSensor.Reset();
CollectObservations(collectObservationsSensor);
using (m_CollectObservationsChecker.Start())
{
CollectObservations(collectObservationsSensor);
}
}
// Request the last decision with no callbacks
// We request a decision so Python knows the Agent is done immediately

UpdateSensors();
using (TimerStack.Instance.Scoped("CollectObservations"))
{
CollectObservations(collectObservationsSensor);
using (m_CollectObservationsChecker.Start())
{
CollectObservations(collectObservationsSensor);
}
}
using (TimerStack.Instance.Scoped("CollectDiscreteActionMasks"))
{

{
ResetData();
m_StepCount = 0;
OnEpisodeBegin();
using (m_OnEpisodeBeginChecker.Start())
{
OnEpisodeBegin();
}
}
/// <summary>

56
com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs


_InnerAgentTestOnEnableOverride();
}
}
[TestFixture]
public class AgentRecursionTests
{
[SetUp]
public void SetUp()
{
if (Academy.IsInitialized)
{
Academy.Instance.Dispose();
}
}
class CollectObsEndEpisodeAgent : Agent
{
public override void CollectObservations(VectorSensor sensor)
{
// NEVER DO THIS IN REAL CODE!
EndEpisode();
}
}
class OnEpisodeBeginEndEpisodeAgent : Agent
{
public override void OnEpisodeBegin()
{
// NEVER DO THIS IN REAL CODE!
EndEpisode();
}
}
void TestRecursiveThrows<T>() where T : Agent
{
var gameObj = new GameObject();
var agent = gameObj.AddComponent<T>();
agent.LazyInitialize();
agent.RequestDecision();
Assert.Throws<UnityAgentsException>(() =>
{
Academy.Instance.EnvironmentStep();
});
}
[Test]
public void TestRecursiveCollectObsEndEpisodeThrows()
{
TestRecursiveThrows<CollectObsEndEpisodeAgent>();
}
[Test]
public void TestRecursiveOnEpisodeBeginEndEpisodeThrows()
{
TestRecursiveThrows<OnEpisodeBeginEndEpisodeAgent>();
}
}
}

35
com.unity.ml-agents/Runtime/RecursionChecker.cs


using System;
namespace Unity.MLAgents
{
internal class RecursionChecker : IDisposable
{
private bool m_IsRunning;
private string m_MethodName;
public RecursionChecker(string methodName)
{
m_MethodName = methodName;
}
public IDisposable Start()
{
if (m_IsRunning)
{
throw new UnityAgentsException(
$"{m_MethodName} called recursively. " +
"This might happen if you call EnvironmentStep() or EndEpisode() from custom " +
"code such as CollectObservations() or OnActionReceived()."
);
}
m_IsRunning = true;
return this;
}
public void Dispose()
{
// Reset the flag when we're done (or if an exception occurred).
m_IsRunning = false;
}
}
}

3
com.unity.ml-agents/Runtime/RecursionChecker.cs.meta


fileFormatVersion: 2
guid: 49ebd06532b24078a6edda823aeff5d2
timeCreated: 1602731302

72
com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs


using System;
using NUnit.Framework;
namespace Unity.MLAgents.Tests
{
[TestFixture]
public class RecursionCheckerTests
{
class InfiniteRecurser
{
RecursionChecker m_checker = new RecursionChecker("InfiniteRecurser");
public int NumCalls = 0;
public void Implode()
{
NumCalls++;
using (m_checker.Start())
{
Implode();
}
}
}
[Test]
public void TestRecursionCheck()
{
var rc = new InfiniteRecurser();
Assert.Throws<UnityAgentsException>(() =>
{
rc.Implode();
});
// Should increment twice before bailing out.
Assert.AreEqual(2, rc.NumCalls);
}
class OneTimeThrower
{
RecursionChecker m_checker = new RecursionChecker("OneTimeThrower");
public int NumCalls;
public void DoStuff()
{
// This method throws from inside the checker the first time.
// Later calls do nothing.
NumCalls++;
using (m_checker.Start())
{
if (NumCalls == 1)
{
throw new ArgumentException("oops");
}
}
}
}
[Test]
public void TestThrowResetsFlag()
{
var ott = new OneTimeThrower();
Assert.Throws<ArgumentException>(() =>
{
ott.DoStuff();
});
// Make sure the flag is cleared if we throw in the "using". Should be able to step subsequently.
ott.DoStuff();
ott.DoStuff();
Assert.AreEqual(3, ott.NumCalls);
}
}
}

3
com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs.meta


fileFormatVersion: 2
guid: 5a7183e11dd5434684a4225c80169173
timeCreated: 1602781778
正在加载...
取消
保存