浏览代码

Merge pull request #4733 from Unity-Technologies/gc-food-goals

Adds goal signal to GridWorld environment
/goal-conditioning
GitHub 4 年前
当前提交
76faf383
共有 3 个文件被更改,包括 70 次插入15 次删除
  1. 19
      Project/Assets/ML-Agents/Examples/GridWorld/Prefabs/Area.prefab
  2. 31
      Project/Assets/ML-Agents/Examples/GridWorld/Scenes/GridWorld.unity
  3. 35
      Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs

19
Project/Assets/ML-Agents/Examples/GridWorld/Prefabs/Area.prefab


m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
vectorObservationSize: 0
numStackedVectorObservations: 1
vectorActionSize: 05000000
vectorActionDescriptions: []
vectorActionSpaceType: 0
m_Model: {fileID: 11400000, guid: a812f1ce7763a4a0c912717f3594fe20, type: 3}
VectorObservationSize: 2
NumStackedVectorObservations: 1
VectorActionSize: 05000000
VectorActionDescriptions: []
VectorActionSpaceType: 0
m_Model: {fileID: 0}
m_UseChildActuators: 1
m_ObservableAttributeHandling: 0
--- !u!114 &114650561397225712
MonoBehaviour:
m_ObjectHideFlags: 0

agentParameters:
maxStep: 0
hasUpgradedFromAgentParameters: 1
maxStep: 100
MaxStep: 100
gridGoal: 0
--- !u!114 &114889700908650620
MonoBehaviour:
m_ObjectHideFlags: 0

m_Width: 84
m_Height: 64
m_Grayscale: 0
m_ObservationStacks: 1
m_Compression: 1
--- !u!114 &7980686505185502968
MonoBehaviour:

m_Script: {fileID: 11500000, guid: 3a6da8f78a394c6ab027688eab81e04d, type: 3}
m_Name:
m_EditorClassIdentifier:
debugCommandLineOverride:
--- !u!1 &1625008366184734
GameObject:
m_ObjectHideFlags: 0

31
Project/Assets/ML-Agents/Examples/GridWorld/Scenes/GridWorld.unity


agentParameters:
maxStep: 100
hasUpgradedFromAgentParameters: 1
maxStep: 100
MaxStep: 100
gridGoal: 0
--- !u!65 &125487788
BoxCollider:
m_ObjectHideFlags: 0

m_Name:
m_EditorClassIdentifier:
m_BrainParameters:
vectorObservationSize: 0
numStackedVectorObservations: 1
vectorActionSize: 05000000
vectorActionDescriptions: []
vectorActionSpaceType: 0
VectorObservationSize: 2
NumStackedVectorObservations: 1
VectorActionSize: 05000000
VectorActionDescriptions: []
VectorActionSpaceType: 0
m_Model: {fileID: 11400000, guid: a812f1ce7763a4a0c912717f3594fe20, type: 3}
m_InferenceDevice: 0
m_BehaviorType: 0

m_UseChildActuators: 1
m_ObservableAttributeHandling: 0
--- !u!114 &125487791
MonoBehaviour:
m_ObjectHideFlags: 0

m_RenderTexture: {fileID: 8400000, guid: 114608d5384404f89bff4b6f88432958, type: 2}
m_SensorName: RenderTextureSensor
m_Grayscale: 0
m_ObservationStacks: 1
m_Compression: 1
--- !u!1 &260425459
GameObject:

type: 3}
propertyPath: compression
value: 0
objectReference: {fileID: 0}
- target: {fileID: 114889700908650620, guid: 5c2bd19e4bbda4991b74387ca5d28156,
type: 3}
propertyPath: m_Compression
value: 0
objectReference: {fileID: 0}
- target: {fileID: 114935253044749092, guid: 5c2bd19e4bbda4991b74387ca5d28156,
type: 3}
propertyPath: m_BrainParameters.VectorObservationSize
value: 2
objectReference: {fileID: 0}
- target: {fileID: 114935253044749092, guid: 5c2bd19e4bbda4991b74387ca5d28156,
type: 3}
propertyPath: m_Model
value:
objectReference: {fileID: 0}
m_RemovedComponents: []
m_SourcePrefab: {fileID: 100100000, guid: 5c2bd19e4bbda4991b74387ca5d28156, type: 3}

35
Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs


using UnityEngine;
using System.Linq;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
using UnityEngine.Serialization;

"masking turned on may not behave optimally when action masking is turned off.")]
public bool maskActions = true;
public GridGoal gridGoal;
const int k_NoAction = 0; // do nothing!
const int k_Up = 1;
const int k_Down = 2;

/// <summary>
/// Goal types that can be assigned to the agent. The agent's current goal is
/// observed as a one-hot vector indexed by the integer value of the member,
/// so the number of members here determines the width of that observation.
/// </summary>
public enum GridGoal
{
// NOTE(review): presumably the plus-shaped goal object in the grid — confirm against the area prefab.
Plus,
// NOTE(review): presumably the cross-shaped object in the grid — confirm against the area prefab.
Cross,
}
}
/// <summary>
/// Writes the agent's currently assigned goal to the vector sensor as a
/// one-hot observation with one slot per <see cref="GridGoal"/> member.
/// </summary>
/// <param name="sensor">Vector sensor that receives the one-hot goal observation.</param>
public override void CollectObservations(VectorSensor sensor)
{
    // The one-hot width tracks the enum, so adding a goal widens the observation.
    var goalCount = Enum.GetValues(typeof(GridGoal)).Length;
    sensor.AddOneHotObservation((int)gridGoal, goalCount);
}
public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)

}
}
/// <summary>
/// Sets the agent's reward based on the object it hit: +1 when the object
/// matches the currently assigned goal, -1 otherwise.
/// </summary>
/// <param name="hitObject">Goal type of the object the agent reached.</param>
private void ProvideReward(GridGoal hitObject)
{
    var reward = hitObject == gridGoal ? 1f : -1f;
    SetReward(reward);
}
// to be implemented by the developer
public override void OnActionReceived(ActionBuffers actionBuffers)

if (hit.Where(col => col.gameObject.CompareTag("goal")).ToArray().Length == 1)
{
SetReward(1f);
ProvideReward(GridGoal.Plus);
SetReward(-1f);
ProvideReward(GridGoal.Cross);
EndEpisode();
}
}

/// <summary>
/// Starts a new episode: resets the area, then assigns the agent a goal
/// picked at random from the <see cref="GridGoal"/> members.
/// </summary>
public override void OnEpisodeBegin()
{
    area.AreaReset();

    // Random is fully qualified because the goal lookup relies on System.Enum,
    // and System also declares a Random type.
    var goals = (GridGoal[])Enum.GetValues(typeof(GridGoal));
    var pick = UnityEngine.Random.Range(0, goals.Length);
    gridGoal = goals[pick];
}
public void FixedUpdate()

正在加载...
取消
保存