using System;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Extensions.Match3;
namespace Unity.MLAgentsExamples
{
/// <summary>
/// State of the "game" when showing all steps of the simulation. This is only used outside of training.
/// The state advances one transition every <c>Match3Agent.MoveTime</c> seconds.
/// </summary>
/// <remarks>
/// State diagram:
/// <code>
///   +---------------------------------------------------------+
///   |                                                         |
///   v                                                         |
/// +--------+      +-------+      +-----+      +------+        |
/// |Find    | ---> |Clear  | ---> |Drop | ---> |Fill  | -------+
/// |Matches |      |Matched|      |     |      |Empty |
/// +--------+      +-------+      +-----+      +------+
///   |   ^
///   |   |
///   v   |
/// +--------+
/// |Wait for|
/// |Move    |
/// +--------+
/// </code>
/// </remarks>
enum State
{
    /// <summary>
    /// Guard value, should never be reached.
    /// </summary>
    Invalid = -1,

    /// <summary>
    /// Look for matches. If there are matches, the next state is <see cref="ClearMatched"/>,
    /// otherwise <see cref="WaitForMove"/>.
    /// </summary>
    FindMatches = 0,

    /// <summary>
    /// Remove matched cells and replace them with a placeholder value.
    /// </summary>
    ClearMatched = 1,

    /// <summary>
    /// Move cells "down" to fill empty space.
    /// </summary>
    Drop = 2,

    /// <summary>
    /// Replace empty cells with new random values.
    /// </summary>
    FillEmpty = 3,

    /// <summary>
    /// Request a move from the Agent.
    /// </summary>
    WaitForMove = 4,
}
/// <summary>
/// How the heuristic policy chooses among the valid moves currently on the board.
/// </summary>
public enum HeuristicQuality
{
    /// <summary>
    /// Any valid move may be chosen; the pick is made at random.
    /// </summary>
    RandomValidMove = 0,

    /// <summary>
    /// Choose the move worth the most points. Only the immediate move is evaluated;
    /// where cells will fall afterwards is not taken into account.
    /// </summary>
    Greedy = 1,
}
/// <summary>
/// Agent that plays Match3 on the <see cref="Match3Board"/> attached to the same GameObject.
/// When the communicator is on, each FixedUpdate resolves every pending match immediately and then
/// requests a decision. Otherwise the board is stepped through one state transition every
/// <see cref="MoveTime"/> seconds so the cascade can be watched.
/// </summary>
public class Match3Agent : Agent
{
    /// <summary>
    /// Board the agent plays on. Assigned from the sibling component in Awake().
    /// </summary>
    [HideInInspector]
    public Match3Board Board;

    /// <summary>
    /// Seconds between state transitions in the animated (communicator off) update.
    /// </summary>
    public float MoveTime = 1.0f;

    /// <summary>
    /// Number of moves after which the episode is interrupted.
    /// </summary>
    public int MaxMoves = 500;

    /// <summary>
    /// Strategy used by the heuristic policy when choosing a move.
    /// </summary>
    public HeuristicQuality HeuristicQuality = HeuristicQuality.RandomValidMove;

    private State m_CurrentState = State.WaitForMove;
    private float m_TimeUntilMove;
    private int m_MovesMade;
    private System.Random m_Random;

    // Scale applied to the points returned by the board before they are added as reward.
    private const float k_RewardMultiplier = 0.01f;

    private void Awake()
    {
        Board = GetComponent<Match3Board>();

        // Seed the heuristic's RNG from the board's seed (offset by one) when one is set.
        // A RandomSeed of -1 falls back to this GameObject's instance id instead.
        var seed = Board.RandomSeed == -1 ? gameObject.GetInstanceID() : Board.RandomSeed + 1;
        m_Random = new System.Random(seed);
    }

    /// <summary>
    /// Re-initializes the board and the per-episode counters, and restarts the animated
    /// state machine at <see cref="State.FindMatches"/>.
    /// </summary>
    public override void OnEpisodeBegin()
    {
        base.OnEpisodeBegin();

        Board.InitSettled();

        m_CurrentState = State.FindMatches;
        m_TimeUntilMove = MoveTime;
        m_MovesMade = 0;
    }

    private void FixedUpdate()
    {
        if (Academy.Instance.IsCommunicatorOn)
        {
            FastUpdate();
        }
        else
        {
            AnimatedUpdate();
        }

        // We can't use the normal MaxSteps system to decide when to end an episode,
        // since different agents will make moves at different frequencies (depending on the number of
        // chained moves). So track a number of moves per Agent and manually interrupt the episode.
        if (m_MovesMade >= MaxMoves)
        {
            EpisodeInterrupted();
        }
    }

    /// <summary>
    /// Resolves all chained matches in one call, rewarding the points of each.
    /// It then ensures a valid move exists and requests the next decision.
    /// </summary>
    private void FastUpdate()
    {
        while (Board.MarkMatchedCells())
        {
            var pointsEarned = Board.ClearMatchedCells();
            AddReward(k_RewardMultiplier * pointsEarned);
            Board.DropCells();
            Board.FillFromAbove();
        }

        ShuffleUntilValidMoves();
        RequestDecision();
        m_MovesMade++;
    }

    /// <summary>
    /// Advances the <see cref="State"/> machine by one transition each time
    /// <see cref="MoveTime"/> seconds have elapsed, so each stage of match resolution is visible.
    /// </summary>
    private void AnimatedUpdate()
    {
        m_TimeUntilMove -= Time.deltaTime;
        if (m_TimeUntilMove > 0.0f)
        {
            return;
        }

        m_TimeUntilMove = MoveTime;

        State nextState;
        switch (m_CurrentState)
        {
            case State.FindMatches:
                var hasMatched = Board.MarkMatchedCells();
                nextState = hasMatched ? State.ClearMatched : State.WaitForMove;
                if (nextState == State.WaitForMove)
                {
                    // WaitForMove always requests a decision, so the move is counted here.
                    m_MovesMade++;
                }
                break;
            case State.ClearMatched:
                var pointsEarned = Board.ClearMatchedCells();
                AddReward(k_RewardMultiplier * pointsEarned);
                nextState = State.Drop;
                break;
            case State.Drop:
                Board.DropCells();
                nextState = State.FillEmpty;
                break;
            case State.FillEmpty:
                Board.FillFromAbove();
                nextState = State.FindMatches;
                break;
            case State.WaitForMove:
                ShuffleUntilValidMoves();
                RequestDecision();
                nextState = State.FindMatches;
                break;
            default:
                throw new ArgumentOutOfRangeException(nameof(m_CurrentState), m_CurrentState, "Unexpected Match3 state.");
        }

        m_CurrentState = nextState;
    }

    /// <summary>
    /// Returns true if <c>Board.ValidMoves()</c> yields at least one move.
    /// Stops enumerating after the first move is found.
    /// </summary>
    private bool HasValidMoves()
    {
        foreach (var unused in Board.ValidMoves())
        {
            return true;
        }

        return false;
    }

    /// <summary>
    /// Shuffles the board with <c>Board.InitSettled()</c> until at least one valid move exists.
    /// </summary>
    /// <remarks>
    /// The loop is unbounded. NOTE(review): this assumes the board settings can always produce a
    /// board with a valid move; confirm for unusual board sizes or cell-type counts.
    /// </remarks>
    private void ShuffleUntilValidMoves()
    {
        while (!HasValidMoves())
        {
            Board.InitSettled();
        }
    }

    /// <summary>
    /// Heuristic policy: writes the index of the chosen move into discrete action branch 0.
    /// The move is chosen according to <see cref="HeuristicQuality"/>.
    /// </summary>
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActions = actionsOut.DiscreteActions;
        discreteActions[0] = GreedyMove();
    }

    /// <summary>
    /// Picks the <c>MoveIndex</c> of one of <c>Board.ValidMoves()</c>.
    /// With <see cref="HeuristicQuality.Greedy"/>, the moves that score the most points (per EvalMovePoints) win.
    /// Otherwise every valid move scores the same.
    /// Ties are broken uniformly at random using reservoir sampling.
    /// Returns 0 if there are no valid moves.
    /// </summary>
    private int GreedyMove()
    {
        var pointsByType = new[] { Board.BasicCellPoints, Board.SpecialCell1Points, Board.SpecialCell2Points };
        var scoreMoves = HeuristicQuality == HeuristicQuality.Greedy;

        var bestMoveIndex = 0;
        var bestMovePoints = -1;
        var numMovesAtCurrentScore = 0;

        foreach (var move in Board.ValidMoves())
        {
            var movePoints = scoreMoves ? EvalMovePoints(move, pointsByType) : 1;
            if (movePoints < bestMovePoints)
            {
                // Worse, skip
                continue;
            }

            if (movePoints > bestMovePoints)
            {
                // Better, keep
                bestMovePoints = movePoints;
                bestMoveIndex = move.MoveIndex;
                numMovesAtCurrentScore = 1;
            }
            else
            {
                // Tied for best - use reservoir sampling to make sure we select from equal moves uniformly.
                // See https://en.wikipedia.org/wiki/Reservoir_sampling#Simple_algorithm
                numMovesAtCurrentScore++;
                var randVal = m_Random.Next(0, numMovesAtCurrentScore);
                if (randVal == 0)
                {
                    // Keep the new one
                    bestMoveIndex = move.MoveIndex;
                }
            }
        }

        return bestMoveIndex;
    }

    /// <summary>
    /// Counts the expected points for making <paramref name="move"/>.
    /// The two cells are swapped, and each cell is scored in its new position.
    /// </summary>
    /// <param name="move">The move to evaluate.</param>
    /// <param name="pointsByType">Points per special type, indexed by <c>Board.GetSpecialType</c>.</param>
    /// <returns>The sum of both half-move scores.</returns>
    private int EvalMovePoints(Move move, int[] pointsByType)
    {
        var moveVal = Board.GetCellType(move.Row, move.Column);
        var moveSpecial = Board.GetSpecialType(move.Row, move.Column);
        var (otherRow, otherCol) = move.OtherCell();
        var oppositeVal = Board.GetCellType(otherRow, otherCol);
        var oppositeSpecial = Board.GetSpecialType(otherRow, otherCol);

        // The moved cell lands on the other cell's position...
        int movePoints = EvalHalfMove(
            otherRow, otherCol, moveVal, moveSpecial, move.Direction, pointsByType
        );
        // ...and the other cell lands on the moved cell's original position.
        int otherPoints = EvalHalfMove(
            move.Row, move.Column, oppositeVal, oppositeSpecial, move.OtherDirection(), pointsByType
        );
        return movePoints + otherPoints;
    }

    /// <summary>
    /// Scores a cell of type <paramref name="newValue"/> and special type <paramref name="newSpecial"/>
    /// placed at (<paramref name="newRow"/>, <paramref name="newCol"/>).
    /// </summary>
    /// <remarks>
    /// This is essentially a duplicate of AbstractBoard.CheckHalfMove, but it also counts the points for the move.
    /// The scan toward the side the cell arrived from is skipped, because that position now holds the swapped cell.
    /// A row or column counts when it has at least two matching neighbours, i.e. a line of three including the placed cell.
    /// </remarks>
    /// <returns>
    /// The placed cell's points plus the points of every matching line, or 0 if no line is formed.
    /// </returns>
    private int EvalHalfMove(int newRow, int newCol, int newValue, int newSpecial, Direction incomingDirection, int[] pointsByType)
    {
        int matchedLeft = 0, matchedRight = 0, matchedUp = 0, matchedDown = 0;
        int scoreLeft = 0, scoreRight = 0, scoreUp = 0, scoreDown = 0;

        if (incomingDirection != Direction.Right)
        {
            for (var c = newCol - 1; c >= 0; c--)
            {
                if (Board.GetCellType(newRow, c) == newValue)
                {
                    matchedLeft++;
                    scoreLeft += pointsByType[Board.GetSpecialType(newRow, c)];
                }
                else
                {
                    break;
                }
            }
        }

        if (incomingDirection != Direction.Left)
        {
            for (var c = newCol + 1; c < Board.Columns; c++)
            {
                if (Board.GetCellType(newRow, c) == newValue)
                {
                    matchedRight++;
                    scoreRight += pointsByType[Board.GetSpecialType(newRow, c)];
                }
                else
                {
                    break;
                }
            }
        }

        // "Up" scans toward increasing row indices.
        if (incomingDirection != Direction.Down)
        {
            for (var r = newRow + 1; r < Board.Rows; r++)
            {
                if (Board.GetCellType(r, newCol) == newValue)
                {
                    matchedUp++;
                    scoreUp += pointsByType[Board.GetSpecialType(r, newCol)];
                }
                else
                {
                    break;
                }
            }
        }

        if (incomingDirection != Direction.Up)
        {
            for (var r = newRow - 1; r >= 0; r--)
            {
                if (Board.GetCellType(r, newCol) == newValue)
                {
                    matchedDown++;
                    scoreDown += pointsByType[Board.GetSpecialType(r, newCol)];
                }
                else
                {
                    break;
                }
            }
        }

        var verticalMatch = matchedUp + matchedDown >= 2;
        var horizontalMatch = matchedLeft + matchedRight >= 2;
        if (!verticalMatch && !horizontalMatch)
        {
            return 0;
        }

        // It's a match. Start from counting the piece being moved.
        var totalScore = pointsByType[newSpecial];
        if (verticalMatch)
        {
            totalScore += scoreUp + scoreDown;
        }

        if (horizontalMatch)
        {
            totalScore += scoreLeft + scoreRight;
        }

        return totalScore;
    }
}
}