浏览代码
[MLA-1474] detect recursion on Agent methods and throw (#4573)
[MLA-1474] detect recursion on Agent methods and throw (#4573)
* recursion checker proof-of-concept * checkers on agent * cleanup and unit tests * changelog * extra test * update comment/MLA-1734-demo-provider
GitHub
4 年前
当前提交
024bb104
共有 8 个文件被更改,包括 195 次插入 和 28 次删除
-
3com.unity.ml-agents/CHANGELOG.md
-
27com.unity.ml-agents/Runtime/Academy.cs
-
24com.unity.ml-agents/Runtime/Agent.cs
-
56com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
-
35com.unity.ml-agents/Runtime/RecursionChecker.cs
-
3com.unity.ml-agents/Runtime/RecursionChecker.cs.meta
-
72com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs
-
3com.unity.ml-agents/Tests/Editor/RecursionCheckerTests.cs.meta
|
|||
using System; |
|||
|
|||
namespace Unity.MLAgents |
|||
{ |
|||
internal class RecursionChecker : IDisposable |
|||
{ |
|||
private bool m_IsRunning; |
|||
private string m_MethodName; |
|||
|
|||
public RecursionChecker(string methodName) |
|||
{ |
|||
m_MethodName = methodName; |
|||
} |
|||
|
|||
public IDisposable Start() |
|||
{ |
|||
if (m_IsRunning) |
|||
{ |
|||
throw new UnityAgentsException( |
|||
$"{m_MethodName} called recursively. " + |
|||
"This might happen if you call EnvironmentStep() or EndEpisode() from custom " + |
|||
"code such as CollectObservations() or OnActionReceived()." |
|||
); |
|||
} |
|||
m_IsRunning = true; |
|||
return this; |
|||
} |
|||
|
|||
public void Dispose() |
|||
{ |
|||
// Reset the flag when we're done (or if an exception occurred).
|
|||
m_IsRunning = false; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 49ebd06532b24078a6edda823aeff5d2 |
|||
timeCreated: 1602731302 |
|
|||
using System; |
|||
using NUnit.Framework; |
|||
|
|||
namespace Unity.MLAgents.Tests |
|||
{ |
|||
[TestFixture] |
|||
public class RecursionCheckerTests |
|||
{ |
|||
class InfiniteRecurser |
|||
{ |
|||
RecursionChecker m_checker = new RecursionChecker("InfiniteRecurser"); |
|||
public int NumCalls = 0; |
|||
|
|||
public void Implode() |
|||
{ |
|||
NumCalls++; |
|||
using (m_checker.Start()) |
|||
{ |
|||
Implode(); |
|||
} |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestRecursionCheck() |
|||
{ |
|||
var rc = new InfiniteRecurser(); |
|||
Assert.Throws<UnityAgentsException>(() => |
|||
{ |
|||
rc.Implode(); |
|||
}); |
|||
|
|||
// Should increment twice before bailing out.
|
|||
Assert.AreEqual(2, rc.NumCalls); |
|||
} |
|||
|
|||
class OneTimeThrower |
|||
{ |
|||
RecursionChecker m_checker = new RecursionChecker("OneTimeThrower"); |
|||
public int NumCalls; |
|||
|
|||
public void DoStuff() |
|||
{ |
|||
// This method throws from inside the checker the first time.
|
|||
// Later calls do nothing.
|
|||
NumCalls++; |
|||
using (m_checker.Start()) |
|||
{ |
|||
if (NumCalls == 1) |
|||
{ |
|||
throw new ArgumentException("oops"); |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestThrowResetsFlag() |
|||
{ |
|||
var ott = new OneTimeThrower(); |
|||
Assert.Throws<ArgumentException>(() => |
|||
{ |
|||
ott.DoStuff(); |
|||
}); |
|||
|
|||
// Make sure the flag is cleared if we throw in the "using". Should be able to step subsequently.
|
|||
ott.DoStuff(); |
|||
ott.DoStuff(); |
|||
Assert.AreEqual(3, ott.NumCalls); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 5a7183e11dd5434684a4225c80169173 |
|||
timeCreated: 1602781778 |
撰写
预览
正在加载...
取消
保存
Reference in new issue