Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

198 行
5.8 KiB

using System.Collections.Generic;
using NUnit.Framework;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Integrations.Match3;
using UnityEngine;
namespace Unity.MLAgents.Tests.Integrations.Match3
{
internal class SimpleBoard : AbstractBoard
{
public int Rows;
public int Columns;
public int NumCellTypes;
public int NumSpecialTypes;
public int LastMoveIndex;
public bool MovesAreValid = true;
public bool CallbackCalled;
public override BoardSize GetMaxBoardSize()
{
return new BoardSize
{
Rows = Rows,
Columns = Columns,
NumCellTypes = NumCellTypes,
NumSpecialTypes = NumSpecialTypes
};
}
public override int GetCellType(int row, int col)
{
return 0;
}
public override int GetSpecialType(int row, int col)
{
return 0;
}
public override bool IsMoveValid(Move m)
{
return MovesAreValid;
}
public override bool MakeMove(Move m)
{
LastMoveIndex = m.MoveIndex;
return MovesAreValid;
}
public void Callback()
{
CallbackCalled = true;
}
}
public class Match3ActuatorTests
{
[SetUp]
public void SetUp()
{
if (Academy.IsInitialized)
{
Academy.Instance.Dispose();
}
}
[TestCase(true)]
[TestCase(false)]
public void TestValidMoves(bool movesAreValid)
{
// Check that a board with no valid moves doesn't raise an exception.
var gameObj = new GameObject();
var board = gameObj.AddComponent<SimpleBoard>();
var agent = gameObj.AddComponent<Agent>();
gameObj.AddComponent<Match3ActuatorComponent>();
board.Rows = 5;
board.Columns = 5;
board.NumCellTypes = 5;
board.NumSpecialTypes = 0;
board.MovesAreValid = movesAreValid;
board.OnNoValidMovesAction = board.Callback;
board.LastMoveIndex = -1;
agent.LazyInitialize();
agent.RequestDecision();
Academy.Instance.EnvironmentStep();
if (movesAreValid)
{
Assert.IsFalse(board.CallbackCalled);
}
else
{
Assert.IsTrue(board.CallbackCalled);
}
Assert.AreNotEqual(-1, board.LastMoveIndex);
}
[Test]
public void TestActionSpec()
{
var gameObj = new GameObject();
var board = gameObj.AddComponent<SimpleBoard>();
var actuator = gameObj.AddComponent<Match3ActuatorComponent>();
board.Rows = 5;
board.Columns = 5;
board.NumCellTypes = 5;
board.NumSpecialTypes = 0;
var actionSpec = actuator.ActionSpec;
Assert.AreEqual(1, actionSpec.NumDiscreteActions);
Assert.AreEqual(board.NumMoves(), actionSpec.BranchSizes[0]);
}
[Test]
public void TestActionSpecNullBoard()
{
var gameObj = new GameObject();
var actuator = gameObj.AddComponent<Match3ActuatorComponent>();
var actionSpec = actuator.ActionSpec;
Assert.AreEqual(0, actionSpec.NumDiscreteActions);
Assert.AreEqual(0, actionSpec.NumContinuousActions);
}
public class HashSetActionMask : IDiscreteActionMask
{
public HashSet<int>[] HashSets;
public HashSetActionMask(ActionSpec spec)
{
HashSets = new HashSet<int>[spec.NumDiscreteActions];
for (var i = 0; i < spec.NumDiscreteActions; i++)
{
HashSets[i] = new HashSet<int>();
}
}
public void SetActionEnabled(int branch, int actionIndex, bool isEnabled)
{
var hashSet = HashSets[branch];
if (isEnabled)
{
hashSet.Remove(actionIndex);
}
else
{
hashSet.Add(actionIndex);
}
}
}
[TestCase(true, TestName = "Full Board")]
[TestCase(false, TestName = "Small Board")]
public void TestMasking(bool fullBoard)
{
var gameObj = new GameObject("board");
var board = gameObj.AddComponent<StringBoard>();
var boardString =
@"0105
1024
0203
2022";
board.SetBoard(boardString);
var boardSize = board.GetMaxBoardSize();
if (!fullBoard)
{
board.CurrentRows -= 1;
}
var validMoves = AbstractBoardTests.GetValidMoves4x4(fullBoard, boardSize);
var actuatorComponent = gameObj.AddComponent<Match3ActuatorComponent>();
var actuator = actuatorComponent.CreateActuators()[0];
var masks = new HashSetActionMask(actuator.ActionSpec);
actuator.WriteDiscreteActionMask(masks);
// Run through all moves and make sure those are the only valid ones
HashSet<int> validIndices = new HashSet<int>();
foreach (var m in validMoves)
{
validIndices.Add(m.MoveIndex);
}
// Valid moves and masked moves should be disjoint
Assert.IsFalse(validIndices.Overlaps(masks.HashSets[0]));
// And they should add up to all the potential moves
Assert.AreEqual(validIndices.Count + masks.HashSets[0].Count, board.NumMoves());
}
}
}