Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题。主题必须以中文、字母或数字开头，可以包含连字符 (-)，并且长度不得超过35个字符。
 
 
 
 
 

139 行
4.6 KiB

using NUnit.Framework;
namespace MLAgents.Tests
{
    /// <summary>
    /// Edit-mode tests for <see cref="ActionMasker"/>.
    /// </summary>
    /// <remarks>
    /// Most tests use three discrete branches of sizes {4, 5, 6}. As the assertions
    /// below establish, the masker exposes them as one flattened bool array of 15
    /// entries:
    /// <list type="bullet">
    /// <item><description>branch 0 at indices 0-3</description></item>
    /// <item><description>branch 1 at indices 4-8</description></item>
    /// <item><description>branch 2 at indices 9-14</description></item>
    /// </list>
    /// An entry is <c>true</c> when that action has been masked.
    /// </remarks>
    public class EditModeTestActionMasker
    {
        /// <summary>Length of the flattened mask for branch sizes {4, 5, 6} (4 + 5 + 6).</summary>
        const int TotalActionCount = 15;

        /// <summary>
        /// Builds discrete-action brain parameters with three branches of sizes {4, 5, 6}.
        /// </summary>
        static BrainParameters CreateThreeBranchDiscreteParameters()
        {
            var bp = new BrainParameters();
            bp.vectorActionSpaceType = SpaceType.discrete;
            bp.vectorActionSize = new int[3] {4, 5, 6};
            return bp;
        }

        /// <summary>
        /// Asserts that <paramref name="mask"/> is the 15-entry flattened mask for branch
        /// sizes {4, 5, 6}, and that exactly the entries listed in
        /// <paramref name="maskedIndices"/> are <c>true</c>. Every other entry must be false.
        /// </summary>
        /// <param name="mask">The array returned by <see cref="ActionMasker.GetMask"/>.</param>
        /// <param name="maskedIndices">Flattened indices expected to be masked; may be empty.</param>
        static void AssertMaskedExactly(bool[] mask, params int[] maskedIndices)
        {
            Assert.IsNotNull(mask);
            // Expected value goes first so NUnit's "Expected/But was" failure text is accurate.
            Assert.AreEqual(TotalActionCount, mask.Length);
            // Check every entry: checking only a few would miss stray masked actions.
            for (var i = 0; i < mask.Length; i++)
            {
                var shouldBeMasked = System.Array.IndexOf(maskedIndices, i) >= 0;
                Assert.AreEqual(shouldBeMasked, mask[i], "Unexpected mask value at index {0}", i);
            }
        }

        /// <summary>A masker can be constructed from default brain parameters.</summary>
        [Test]
        public void Contruction()
        {
            // NOTE(review): "Contruction" is a misspelling of "Construction". The name is
            // kept so this test's identity in existing test reports does not change.
            var bp = new BrainParameters();
            var masker = new ActionMasker(bp);
            Assert.IsNotNull(masker);
        }

        /// <summary>
        /// With a continuous action space, setting a mask does not throw, but reading it
        /// through <see cref="ActionMasker.GetMask"/> must throw.
        /// </summary>
        [Test]
        public void FailsWithContinuous()
        {
            var bp = new BrainParameters();
            bp.vectorActionSpaceType = SpaceType.continuous;
            bp.vectorActionSize = new int[1] {4};
            var masker = new ActionMasker(bp);

            masker.SetActionMask(0, new int[1] {0});

            Assert.Catch<UnityAgentsException>(() => masker.GetMask());
        }

        /// <summary>If no mask has ever been set, <see cref="ActionMasker.GetMask"/> returns null.</summary>
        [Test]
        public void NullMask()
        {
            var bp = new BrainParameters();
            bp.vectorActionSpaceType = SpaceType.discrete;
            var masker = new ActionMasker(bp);

            var mask = masker.GetMask();

            Assert.IsNull(mask);
        }

        /// <summary>Masking actions 1-3 of branch 0 sets flattened indices 1-3 and nothing else.</summary>
        [Test]
        public void FirstBranchMask()
        {
            var masker = new ActionMasker(CreateThreeBranchDiscreteParameters());
            Assert.IsNull(masker.GetMask());

            masker.SetActionMask(0, new int[] {1, 2, 3});

            AssertMaskedExactly(masker.GetMask(), 1, 2, 3);
        }

        /// <summary>
        /// Masking actions 1-3 of branch 1 sets flattened indices 5-7, because branch 1
        /// starts after the 4 actions of branch 0.
        /// </summary>
        [Test]
        public void SecondBranchMask()
        {
            var masker = new ActionMasker(CreateThreeBranchDiscreteParameters());

            masker.SetActionMask(1, new int[] {1, 2, 3});

            AssertMaskedExactly(masker.GetMask(), 5, 6, 7);
        }

        /// <summary>After <see cref="ActionMasker.ResetMask"/>, every entry of the mask is false.</summary>
        [Test]
        public void MaskReset()
        {
            var masker = new ActionMasker(CreateThreeBranchDiscreteParameters());
            masker.SetActionMask(1, new int[3] {1, 2, 3});
            // Confirm something is actually masked, so the reset below is meaningful.
            AssertMaskedExactly(masker.GetMask(), 5, 6, 7);

            masker.ResetMask();

            AssertMaskedExactly(masker.GetMask());
        }

        /// <summary>
        /// Out-of-range action indices and non-existent branches are rejected by
        /// <see cref="ActionMasker.SetActionMask"/>. A mask covering every action of a
        /// branch is rejected when it is read.
        /// </summary>
        [Test]
        public void ThrowsError()
        {
            var masker = new ActionMasker(CreateThreeBranchDiscreteParameters());

            // Action index 5 is out of range for branch 0 (size 4) and branch 1 (size 5)...
            Assert.Catch<UnityAgentsException>(
                () => masker.SetActionMask(0, new int[1] {5}));
            Assert.Catch<UnityAgentsException>(
                () => masker.SetActionMask(1, new int[1] {5}));
            // ...but is valid for branch 2 (size 6).
            masker.SetActionMask(2, new int[1] {5});
            // Only branches 0-2 exist.
            Assert.Catch<UnityAgentsException>(
                () => masker.SetActionMask(3, new int[1] {1}));
            // The rejected calls above must not leave the mask in an invalid state.
            Assert.DoesNotThrow(() => masker.GetMask());

            masker.ResetMask();
            masker.SetActionMask(0, new int[4] {0, 1, 2, 3});

            // Every action of branch 0 is now masked: setting it is allowed, reading it is not.
            Assert.Catch<UnityAgentsException>(
                () => masker.GetMask());
        }

        /// <summary>
        /// Repeated <see cref="ActionMasker.SetActionMask"/> calls accumulate rather than replace.
        /// Both {0, 1} and {3} stay masked on branch 0, and action 1 of branch 2 maps to
        /// flattened index 10 (9 + 1).
        /// </summary>
        [Test]
        public void MultipleMaskEdit()
        {
            var masker = new ActionMasker(CreateThreeBranchDiscreteParameters());

            masker.SetActionMask(0, new int[2] {0, 1});
            masker.SetActionMask(0, new int[1] {3});
            masker.SetActionMask(2, new int[1] {1});

            AssertMaskedExactly(masker.GetMask(), 0, 1, 3, 10);
        }
    }
}