|
|
|
|
|
|
using UnityEngine; |
|
|
|
using System.Linq; |
|
|
|
using Unity.MLAgents; |
|
|
|
using Unity.MLAgents.Sensors; |
|
|
|
using Unity.MLAgents.Actuators; |
|
|
|
using UnityEngine.Serialization; |
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Inspector-exposed switch for action masking (see the tooltip text below). Defaults to true.
/// NOTE(review): the code that reads this flag is not visible in this excerpt; presumably it is
/// consulted in WriteDiscreteActionMask. Confirm against the full file.
/// </summary>
[Tooltip("Selecting will turn on action masking. Note that a model trained with action " + |
|
|
|
"masking turned on may not behave optimally when action masking is turned off.")] |
|
|
|
public bool maskActions = true; |
|
|
|
|
|
|
|
|
|
|
|
// Goal kind the agent is currently asked to reach. It is re-assigned at random in OnEpisodeBegin,
// written to the observation vector as a one-hot value in CollectObservations, and compared
// against the hit object in ProvideReward to choose the reward.
public GridGoal gridGoal; |
|
|
|
|
|
|
|
// Action code 0: the agent does nothing (per the trailing note on this line).
// NOTE(review): presumably a discrete-action value; the code that consumes it is not visible here.
const int k_NoAction = 0; // do nothing!
|
|
|
|
// NOTE(review): presumably the discrete-action value for moving up; its consumer is not visible
// in this excerpt, so confirm against the action-handling code.
const int k_Up = 1; |
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// The kinds of goal an episode can designate as the agent's target.
/// The underlying integer values are used directly as the one-hot index in
/// <c>CollectObservations</c>, so they must remain contiguous and start at zero.
/// </summary>
public enum GridGoal
{
    /// <summary>The "plus" goal kind.</summary>
    Plus = 0,

    /// <summary>The "cross" goal kind.</summary>
    Cross = 1,
}
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
/// Writes the current goal into the vector observation as a one-hot value
/// with one slot per <see cref="GridGoal"/> member.
/// </summary>
/// <param name="sensor">Vector sensor that receives the one-hot observation.</param>
public override void CollectObservations(VectorSensor sensor)
{
    // The one-hot width follows the enum, so adding a goal kind widens the observation.
    int goalKindCount = Enum.GetValues(typeof(GridGoal)).Length;
    int activeGoalIndex = (int)gridGoal;

    sensor.AddOneHotObservation(activeGoalIndex, goalKindCount);
}
|
|
|
|
|
|
|
public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask) |
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
/// Calls <c>SetReward</c> with +1 when the object that was hit matches the current
/// <c>gridGoal</c> and with -1 otherwise, then logs the same value.
/// </summary>
/// <param name="hitObject">Goal kind of the object the agent reached.</param>
private void ProvideReward(GridGoal hitObject)
{
    // A single signed outcome drives both the reward and the log line.
    int outcome = hitObject == gridGoal ? 1 : -1;

    SetReward(outcome);
    Debug.Log(outcome);
}
|
|
|
|
|
|
|
// to be implemented by the developer
|
|
|
|
public override void OnActionReceived(ActionBuffers actionBuffers) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (hit.Where(col => col.gameObject.CompareTag("goal")).ToArray().Length == 1) |
|
|
|
{ |
|
|
|
SetReward(1f); |
|
|
|
ProvideReward(GridGoal.Plus); |
|
|
|
SetReward(-1f); |
|
|
|
ProvideReward(GridGoal.Cross); |
|
|
|
EndEpisode(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
/// <summary>
/// Starts a new episode: resets the area, then assigns <c>gridGoal</c> a randomly
/// chosen <see cref="GridGoal"/> member.
/// </summary>
public override void OnEpisodeBegin()
{
    // Reset first so the random draw below happens after whatever the area reset does.
    area.AreaReset();

    Array goalKinds = Enum.GetValues(typeof(GridGoal));

    // The integer overload of Random.Range excludes the upper bound, so the index is always valid.
    int pickedIndex = UnityEngine.Random.Range(0, goalKinds.Length);
    gridGoal = (GridGoal)goalKinds.GetValue(pickedIndex);
}
|
|
|
|
|
|
|
public void FixedUpdate() |
|
|
|