using System;
using UnityEngine;
using System.Linq;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using UnityEngine.Serialization;

public class GridAgent : Agent
{
    [FormerlySerializedAs("m_Area")]
    [Header("Specific to GridWorld")]
    public GridArea area;
    public float timeBetweenDecisionsAtInference;
    float m_TimeSinceDecision;

    [Tooltip("Because we want an observation right before making a decision, we can force " +
        "a camera to render before making a decision. Place the agentCam here if using " +
        "RenderTexture as observations.")]
    public Camera renderCamera;

    [Tooltip("Selecting will turn on action masking. Note that a model trained with action " +
        "masking turned on may not behave optimally when action masking is turned off.")]
    public bool maskActions = true;

    GoalSensorComponent goalSensor;

    public GridGoal gridGoal;

    const int k_NoAction = 0;  // do nothing!
    const int k_Up = 1;
    const int k_Down = 2;
    const int k_Left = 3;
    const int k_Right = 4;

    public enum GridGoal
    {
        Plus,
        Cross,
    }

    EnvironmentParameters m_ResetParams;

    public override void Initialize()
    {
        m_ResetParams = Academy.Instance.EnvironmentParameters;
    }

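    public override void CollectObservations(VectorSensor sensor)
    {
        Array values = Enum.GetValues(typeof(GridGoal));
        int goalNum = (int)gridGoal;

        // Look up the goal sensor once and cache it, instead of calling GetComponent every step.
        if (goalSensor == null)
        {
            goalSensor = GetComponent<GoalSensorComponent>();
        }

        // Report the current goal (Plus or Cross) to the goal sensor as a one-hot value.
        goalSensor.AddOneHotGoal(goalNum, values.Length);
    }
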
    public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
    {
        // Mask the necessary actions if selected by the user.
        if (maskActions)
        {
            // Prevents the agent from picking an action that would make it collide with a wall
            var positionX = (int)transform.localPosition.x;
            var positionZ = (int)transform.localPosition.z;
            var maxPosition = (int)m_ResetParams.GetWithDefault("gridSize", 5f) - 1;

            if (positionX == 0)
            {
                actionMask.WriteMask(0, new[] { k_Left });
            }

            if (positionX == maxPosition)
            {
                actionMask.WriteMask(0, new[] { k_Right });
            }

            if (positionZ == 0)
            {
                actionMask.WriteMask(0, new[] { k_Down });
            }

            if (positionZ == maxPosition)
            {
                actionMask.WriteMask(0, new[] { k_Up });
            }
        }
    }

    // Reward +1 if the object the agent reached matches the current goal, -1 otherwise.
    private void ProvideReward(GridGoal hitObject)
    {
        if (gridGoal == hitObject)
        {
            SetReward(1f);
        }
        else
        {
            SetReward(-1f);
        }
    }

    // Move one cell in the chosen direction unless a wall occupies the target cell, then
    // end the episode if the agent lands on an object tagged "goal" (Plus) or "pit" (Cross).
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // Small per-step penalty to encourage reaching a goal quickly.
        AddReward(-0.01f);
        var action = actionBuffers.DiscreteActions[0];

        var targetPos = transform.position;
        switch (action)
        {
            case k_NoAction:
                // do nothing
                break;
            case k_Right:
                targetPos = transform.position + new Vector3(1f, 0f, 0f);
                break;
            case k_Left:
                targetPos = transform.position + new Vector3(-1f, 0f, 0f);
                break;
            case k_Up:
                targetPos = transform.position + new Vector3(0f, 0f, 1f);
                break;
            case k_Down:
                targetPos = transform.position + new Vector3(0f, 0f, -1f);
                break;
            default:
                throw new ArgumentException("Invalid action value");
        }

        // Collect the colliders occupying the target cell before moving.
        var hit = Physics.OverlapBox(
            targetPos, new Vector3(0.3f, 0.3f, 0.3f));
        if (!hit.Any(col => col.gameObject.CompareTag("wall")))
        {
            transform.position = targetPos;

            if (hit.Count(col => col.gameObject.CompareTag("goal")) == 1)
            {
                ProvideReward(GridGoal.Plus);
                EndEpisode();
            }
            else if (hit.Count(col => col.gameObject.CompareTag("pit")) == 1)
            {
                ProvideReward(GridGoal.Cross);
                EndEpisode();
            }
        }
    }

    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var discreteActionsOut = actionsOut.DiscreteActions;
        discreteActionsOut[0] = k_NoAction;
        if (Input.GetKey(KeyCode.D))
        {
            discreteActionsOut[0] = k_Right;
        }
        if (Input.GetKey(KeyCode.W))
        {
            discreteActionsOut[0] = k_Up;
        }
        if (Input.GetKey(KeyCode.A))
        {
            discreteActionsOut[0] = k_Left;
        }
        if (Input.GetKey(KeyCode.S))
        {
            discreteActionsOut[0] = k_Down;
        }
    }

    // Reset the grid and pick a new random goal at the start of every episode.
    public override void OnEpisodeBegin()
    {
        area.AreaReset();
        Array values = Enum.GetValues(typeof(GridGoal));
        gridGoal = (GridGoal)values.GetValue(UnityEngine.Random.Range(0, values.Length));
    }

    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    // With a trainer connected, request a decision every physics step; otherwise (inference),
    // request one only after timeBetweenDecisionsAtInference seconds have elapsed.
    void WaitTimeInference()
    {
        if (renderCamera != null)
        {
            renderCamera.Render();
        }

        if (Academy.Instance.IsCommunicatorOn)
        {
            RequestDecision();
        }
        else
        {
            if (m_TimeSinceDecision >= timeBetweenDecisionsAtInference)
            {
                m_TimeSinceDecision = 0f;
                RequestDecision();
            }
            else
            {
                m_TimeSinceDecision += Time.fixedDeltaTime;
            }
        }
    }
}