浏览代码

Allow gridworld agent to have different goals

/MLA-1734-demo-provider
Arthur Juliani 4 年前
当前提交
6b32ff21
共有 1 个文件被更改,包括 34 次插入2 次删除
  1. 36
      Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs

36
Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs


using UnityEngine;
using System.Linq;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using UnityEngine.Serialization;

[Tooltip("Selecting will turn on action masking. Note that a model trained with action " +
"masking turned on may not behave optimally when action masking is turned off.")]
public bool maskActions = true;
public GridGoal gridGoal;
const int k_NoAction = 0; // do nothing!
const int k_Up = 1;

public enum GridGoal
{
Plus,
Cross,
}
}
public override void CollectObservations(VectorSensor sensor)
{
Array values = Enum.GetValues(typeof(GridGoal));
sensor.AddOneHotObservation((int)gridGoal, values.Length);
}
public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)

}
}
private void ProvideReward(GridGoal hitObject)
{
if (gridGoal == hitObject)
{
SetReward(1f);
Debug.Log(1);
}
else
{
SetReward(-1f);
Debug.Log(-1);
}
}
// to be implemented by the developer
public override void OnActionReceived(ActionBuffers actionBuffers)

if (hit.Where(col => col.gameObject.CompareTag("goal")).ToArray().Length == 1)
{
SetReward(1f);
ProvideReward(GridGoal.Plus);
SetReward(-1f);
ProvideReward(GridGoal.Cross);
EndEpisode();
}
}

public override void OnEpisodeBegin()
{
area.AreaReset();
Array values = Enum.GetValues(typeof(GridGoal));
gridGoal = (GridGoal)values.GetValue(UnityEngine.Random.Range(0, values.Length));
}
public void FixedUpdate()

正在加载...
取消
保存