The Unity Machine Learning Agents Toolkit (ML-Agents) is an open-source project that enables games and simulations to serve as environments for training intelligent agents.
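
Below is GridAgent.cs from the toolkit's GridWorld example: a goal-conditioned agent that moves one cell at a time across a grid using five discrete actions, masks out moves that would collide with the surrounding walls, and is rewarded +1 for reaching the object that matches its current goal (Plus or Cross) and -1 for reaching the other one.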
using System;
using UnityEngine;
using System.Linq;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using UnityEngine.Serialization;

public class GridAgent : Agent
{
    [FormerlySerializedAs("m_Area")]
    [Header("Specific to GridWorld")]
    public GridArea area;
    public float timeBetweenDecisionsAtInference;
    float m_TimeSinceDecision;

    [Tooltip("Because we want an observation right before making a decision, we can force " +
        "a camera to render before making a decision. Place the agentCam here if using " +
        "RenderTexture as observations.")]
    public Camera renderCamera;

    [Tooltip("Selecting will turn on action masking. Note that a model trained with action " +
        "masking turned on may not behave optimally when action masking is turned off.")]
    public bool maskActions = true;

    GoalSensorComponent goalSensor;
    public GridGoal gridGoal;

    // The five discrete actions available to the agent.
    const int k_NoAction = 0;  // do nothing!
    const int k_Up = 1;
    const int k_Down = 2;
    const int k_Left = 3;
    const int k_Right = 4;

    public enum GridGoal
    {
        Plus,
        Cross,
    }

    EnvironmentParameters m_ResetParams;

    public override void Initialize()
    {
        m_ResetParams = Academy.Instance.EnvironmentParameters;
    }

    public override void CollectObservations(VectorSensor sensor)
    {
        // Observe the current goal as a one-hot vector through the goal sensor.
        Array values = Enum.GetValues(typeof(GridGoal));
        int goalNum = (int)gridGoal;
        goalSensor = GetComponent<GoalSensorComponent>();
        goalSensor.AddOneHotGoal(goalNum, values.Length);
    }

    public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
    {
        // Mask the necessary actions if selected by the user.
        if (maskActions)
        {
            // Prevents the agent from picking an action that would make it collide with a wall
            var positionX = (int)transform.localPosition.x;
            var positionZ = (int)transform.localPosition.z;
            var maxPosition = (int)m_ResetParams.GetWithDefault("gridSize", 5f) - 1;

            if (positionX == 0)
            {
                actionMask.WriteMask(0, new[] { k_Left });
            }
            if (positionX == maxPosition)
            {
                actionMask.WriteMask(0, new[] { k_Right });
            }
            if (positionZ == 0)
            {
                actionMask.WriteMask(0, new[] { k_Down });
            }
            if (positionZ == maxPosition)
            {
                actionMask.WriteMask(0, new[] { k_Up });
            }
        }
    }

    // +1 when the object the agent touched matches the current goal, -1 otherwise.
    private void ProvideReward(GridGoal hitObject)
    {
        if (gridGoal == hitObject)
        {
            SetReward(1f);
        }
        else
        {
            SetReward(-1f);
        }
    }

    // to be implemented by the developer
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Small per-step penalty to encourage the agent to reach a goal quickly.
        AddReward(-0.01f);
        var action = actionBuffers.DiscreteActions[0];

        var targetPos = transform.position;
        switch (action)
        {
            case k_NoAction:
                // do nothing
                break;
            case k_Right:
                targetPos = transform.position + new Vector3(1f, 0, 0f);
                break;
            case k_Left:
                targetPos = transform.position + new Vector3(-1f, 0, 0f);
                break;
            case k_Up:
                targetPos = transform.position + new Vector3(0f, 0, 1f);
                break;
            case k_Down:
                targetPos = transform.position + new Vector3(0f, 0, -1f);
                break;
            default:
                throw new ArgumentException("Invalid action value");
        }

        // Move only if the target cell is not blocked by a wall, then check
        // whether the agent landed on a goal or a pit.
        var hit = Physics.OverlapBox(
            targetPos, new Vector3(0.3f, 0.3f, 0.3f));
        if (hit.Where(col => col.gameObject.CompareTag("wall")).ToArray().Length == 0)
        {
            transform.position = targetPos;

            if (hit.Where(col => col.gameObject.CompareTag("goal")).ToArray().Length == 1)
            {
                ProvideReward(GridGoal.Plus);
                EndEpisode();
            }
            else if (hit.Where(col => col.gameObject.CompareTag("pit")).ToArray().Length == 1)
            {
                ProvideReward(GridGoal.Cross);
                EndEpisode();
            }
        }
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        // Manual control: map WASD to the discrete grid actions.
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = k_NoAction;
        if (Input.GetKey(KeyCode.D))
        {
            discreteActionsOut[0] = k_Right;
        }
        if (Input.GetKey(KeyCode.W))
        {
            discreteActionsOut[0] = k_Up;
        }
        if (Input.GetKey(KeyCode.A))
        {
            discreteActionsOut[0] = k_Left;
        }
        if (Input.GetKey(KeyCode.S))
        {
            discreteActionsOut[0] = k_Down;
        }
    }

    // to be implemented by the developer
    public override void OnEpisodeBegin()
    {
        area.AreaReset();
        // Pick a new goal at random for the next episode.
        Array values = Enum.GetValues(typeof(GridGoal));
        gridGoal = (GridGoal)values.GetValue(UnityEngine.Random.Range(0, values.Length));
    }

    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    void WaitTimeInference()
    {
        // Render the observation camera right before a decision is made.
        if (renderCamera != null)
        {
            renderCamera.Render();
        }

        if (Academy.Instance.IsCommunicatorOn)
        {
            // A trainer is connected (training): decide on every fixed update.
            RequestDecision();
        }
        else
        {
            // Inference: throttle decisions so the behavior is watchable.
            if (m_TimeSinceDecision >= timeBetweenDecisionsAtInference)
            {
                m_TimeSinceDecision = 0f;
                RequestDecision();
            }
            else
            {
                m_TimeSinceDecision += Time.fixedDeltaTime;
            }
        }
    }
}
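
GridAgent schedules its own decisions by calling RequestDecision() from FixedUpdate, so that the observation camera can be rendered immediately beforehand and decisions can be throttled at inference time. For agents without those constraints, the ML-Agents package ships a DecisionRequester component that drives the same cadence automatically. A minimal setup sketch, assuming it is attached to the same GameObject as the agent (the helper class name and the 5-step period are illustrative, not part of the example):

using Unity.MLAgents;
using UnityEngine;

// Illustrative alternative to the manual RequestDecision() loop above:
// let the stock DecisionRequester component drive the decision cadence.
public class GridAgentDecisionSetup : MonoBehaviour
{
    void Awake()
    {
        // Request a decision every 5 Academy steps (example value) and
        // repeat the last chosen action on the steps in between.
        var requester = gameObject.AddComponent<DecisionRequester>();
        requester.DecisionPeriod = 5;
        requester.TakeActionsBetweenDecisions = true;
    }
}

The manual scheme in GridAgent remains the better fit here, since it guarantees renderCamera.Render() runs right before every decision when RenderTexture observations are used.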