浏览代码

fix commit

/gridworld-custom-obs
Chris Elion 4 年前
当前提交
ed8a249c
共有 3 个文件被更改,包括 159 次插入20 次删除
  1. 16
      Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
  2. 67
      Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridArea.cs
  3. 96
      Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridSensorComponent.cs

16
Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs


m_ResetParams = Academy.Instance.EnvironmentParameters;
}
(int, int) LocalCoordinates
{
get { return ((int)transform.localPosition.x, (int)transform.localPosition.z); }
}
public override void CollectDiscreteActionMasks(DiscreteActionMasker actionMasker)
{
// Mask the necessary actions if selected by the user.

var positionX = (int)transform.position.x;
var positionZ = (int)transform.position.z;
var (positionX, positionZ) = LocalCoordinates;
var maxPosition = (int)m_ResetParams.GetWithDefault("gridSize", 5f) - 1;
if (positionX == 0)

// to be implemented by the developer
public override void OnActionReceived(float[] vectorAction)
{
var (positionX, positionZ) = LocalCoordinates;
var (oldPositionX, oldPositionZ) = (positionX, positionZ);
AddReward(-0.01f);
var action = Mathf.FloorToInt(vectorAction[0]);

break;
case k_Right:
targetPos = transform.position + new Vector3(1f, 0, 0f);
positionX++;
positionX--;
positionZ++;
positionZ--;
break;
default:
throw new ArgumentException("Invalid action value");

if (hit.Where(col => col.gameObject.CompareTag("wall")).ToArray().Length == 0)
{
transform.position = targetPos;
area.board[oldPositionX, oldPositionZ] = GridArea.CellType.Empty;
area.board[positionX, positionZ] = GridArea.CellType.Agent;
if (hit.Where(col => col.gameObject.CompareTag("goal")).ToArray().Length == 1)
{

67
Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridArea.cs


using UnityEngine;
using System.Linq;
using Unity.MLAgents;
using UnityScript.Lang;
public enum CellType
{
Empty = 0,
Goal = 1,
Pit = 2,
Agent = 3,
}
public int[] players;
public CellType[] players;
[HideInInspector]
public CellType[,] board;
public GameObject trueAgent;

public GameObject pitPref;
GameObject[] m_Objects;
Dictionary<CellType, GameObject> m_Objects;
GameObject m_Plane;
GameObject m_Sn;

{
m_ResetParams = Academy.Instance.EnvironmentParameters;
m_Objects = new[] { goalPref, pitPref };
m_Objects = new Dictionary<CellType, GameObject>
{
{ CellType.Goal, goalPref },
{ CellType.Pit, pitPref }
};
m_AgentCam = transform.Find("agentCam").GetComponent<Camera>();
m_AgentCam = transform.Find("agentCam")?.GetComponent<Camera>();
actorObjs = new List<GameObject>();

void SetEnvironment()
{
transform.position = m_InitialPosition * (m_ResetParams.GetWithDefault("gridSize", 5f) + 1);
var playersList = new List<int>();
var gridSize = (int)m_ResetParams.GetWithDefault("gridSize", 5f);
transform.position = m_InitialPosition * (gridSize + 1);
var playersList = new List<CellType>();
playersList.Add(1);
playersList.Add(CellType.Pit);
playersList.Add(0);
playersList.Add(CellType.Goal);
var gridSize = (int)m_ResetParams.GetWithDefault("gridSize", 5f);
m_Plane.transform.localScale = new Vector3(gridSize / 10.0f, 1f, gridSize / 10.0f);
m_Plane.transform.localPosition = new Vector3((gridSize - 1) / 2f, -0.5f, (gridSize - 1) / 2f);
m_Sn.transform.localScale = new Vector3(1, 1, gridSize + 2);

m_Se.transform.localPosition = new Vector3(gridSize, 0.0f, (gridSize - 1) / 2f);
m_Sw.transform.localPosition = new Vector3(-1, 0.0f, (gridSize - 1) / 2f);
m_AgentCam.orthographicSize = (gridSize) / 2f;
m_AgentCam.transform.localPosition = new Vector3((gridSize - 1) / 2f, gridSize + 1f, (gridSize - 1) / 2f);
if(m_AgentCam != null)
{
m_AgentCam.orthographicSize = (gridSize) / 2f;
m_AgentCam.transform.localPosition = new Vector3((gridSize - 1) / 2f, gridSize + 1f, (gridSize - 1) / 2f);
}
if (board == null)
{
board = new CellType[gridSize, gridSize];
}
else
{
for (var i = 0; i < gridSize; i++)
{
for (var j = 0; j < gridSize; j++)
{
board[i, j] = CellType.Empty;
}
}
}
}
public void AreaReset()

for (var i = 0; i < players.Length; i++)
{
var x = (numbersA[i]) / gridSize;
var y = (numbersA[i]) % gridSize;
var z = (numbersA[i]) % gridSize;
actorObj.transform.localPosition = new Vector3(x, -0.25f, y);
actorObj.transform.localPosition = new Vector3(x, -0.25f, z);
board[x, z] = players[i];
var yA = (numbersA[players.Length]) % gridSize;
trueAgent.transform.localPosition = new Vector3(xA, -0.25f, yA);
var zA = (numbersA[players.Length]) % gridSize;
trueAgent.transform.localPosition = new Vector3(xA, -0.25f, zA);
board[xA, zA] = CellType.Agent;
}
}

96
Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridSensorComponent.cs


using Unity.MLAgents.Sensors;
public class GridSensorComponent
public class GridSensorComponent : SensorComponent
public GridArea gridArea;
int pixelsPerCell = 8;
// TODO use grid size from env parameters
int gridSize = 5;
/// <summary>
/// Creates a BasicSensor.
/// </summary>
/// <returns></returns>
public override ISensor CreateSensor()
{
return new GridSensor(gridArea, gridSize, pixelsPerCell);
}
/// <inheritdoc/>
public override int[] GetObservationShape()
{
return new[] { gridSize * pixelsPerCell, gridSize * pixelsPerCell, 4 };
}
}
public class GridSensor : ISensor
{
GridArea m_GridArea;
int m_PixlesPerCell;
int m_GridSize;
int[] m_Shape;
const int k_NumChannels = 4;
public GridSensor(GridArea gridArea, int gridSize, int pixelsPerCell)
{
m_GridArea = gridArea;
m_GridSize = gridSize;
m_PixlesPerCell = pixelsPerCell;
m_Shape = new []{ gridSize * pixelsPerCell, gridSize * pixelsPerCell, k_NumChannels };
}
/// <inheritdoc/>
public int[] GetObservationShape()
{
return m_Shape;
}
/// <inheritdoc/>
public int Write(ObservationWriter writer)
{
var board = m_GridArea.board;
var height = m_GridSize * m_PixlesPerCell;
var width = m_GridSize * m_PixlesPerCell;
for (var h = 0; h < height; h++)
{
var i = h / m_PixlesPerCell;
for (var w = 0; w < width; w++)
{
var j = w / m_PixlesPerCell;
var cellVal = board[i, j];
for (var c = 0; c < k_NumChannels; c++)
{
writer[h, w, c] = (c == (int)cellVal) ? 1.0f : 0.0f;
}
}
}
var numWritten = height * width * k_NumChannels;
return numWritten;
}
/// <inheritdoc/>
public byte[] GetCompressedObservation()
{
return null;
}
/// <inheritdoc/>
public void Update() { }
/// <inheritdoc/>
public void Reset() { }
/// <inheritdoc/>
public SensorCompressionType GetCompressionType()
{
return SensorCompressionType.None;
}
/// <inheritdoc/>
public string GetName()
{
return "GridSensor";
}
}
}
正在加载...
取消
保存