浏览代码

Merge pull request #5008 from Unity-Technologies/multi-goal-conditioning

Support multiple goals in networkbody
/goal-conditioning/new
GitHub 3 年前
当前提交
001990af
共有 12 个文件被更改，包括 90 次插入和 124 次删除
  1. 5
      Project/Assets/ML-Agents/Examples/Crawler/Prefabs/DynamicPlatform.prefab
  2. 4
      Project/Assets/ML-Agents/Examples/Crawler/Scenes/CrawlerDynamicTarget.unity
  3. 8
      Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs
  4. 13
      Project/Assets/ML-Agents/Examples/GoalNav/Prefabs/VisualArea.prefab
  5. 8
      Project/Assets/ML-Agents/Examples/GoalNav/Scripts/GoalNavAgent.cs
  6. 6
      Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs
  7. 6
      Project/Assets/ML-Agents/Examples/PushJump/Scripts/WJPBAgent.cs
  8. 6
      com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
  9. 42
      ml-agents/mlagents/trainers/torch/networks.py
  10. 41
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/VectorSensorComponent.cs
  11. 75
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/GoalSensorComponent.cs
  12. 0
      /Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/VectorSensorComponent.cs.meta

5
Project/Assets/ML-Agents/Examples/Crawler/Prefabs/DynamicPlatform.prefab


m_Modification:
m_TransformParent: {fileID: 6810587057221831324}
m_Modifications:
- target: {fileID: 4845971001715176648, guid: 0456c89e8c9c243d595b039fe7aa0bf9,
type: 3}
propertyPath: m_BehaviorType
value: 0
objectReference: {fileID: 0}
- target: {fileID: 4845971001715176651, guid: 0456c89e8c9c243d595b039fe7aa0bf9,
type: 3}
propertyPath: m_LocalPosition.x

4
Project/Assets/ML-Agents/Examples/Crawler/Scenes/CrawlerDynamicTarget.unity


m_Name:
m_EditorClassIdentifier:
target: {fileID: 380947237}
smoothingTime: 0
--- !u!1001 &1481808307
PrefabInstance:
m_ObjectHideFlags: 0

type: 3}
propertyPath: m_Model
value:
objectReference: {fileID: 11400000, guid: 53b79c6f40aeb46e693e7c1822ff1047,
type: 3}
objectReference: {fileID: 0}
- target: {fileID: 6810587057221831324, guid: 0058b366f9d6d44a3ba35beb06b0174b,
type: 3}
propertyPath: m_LocalPosition.x

8
Project/Assets/ML-Agents/Examples/Crawler/Scripts/CrawlerAgent.cs


"Static - The agent will run towards a static target. "
)]
public CrawlerAgentBehaviorType typeOfCrawler;
GoalSensorComponent goalSensor;
VectorSensorComponent goalSensor;
//Crawler Brains
//A different brain will be used depending on the CrawlerAgentBehaviorType selected

/// </summary>
public override void CollectObservations(VectorSensor sensor)
{
goalSensor = this.GetComponent<GoalSensorComponent>();
goalSensor = this.GetComponent<VectorSensorComponent>();
var cubeForward = m_OrientationCube.transform.forward;
//velocity we want to match

//avg body vel relative to cube
sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(avgVel));
//vel goal relative to cube
goalSensor.AddGoal(m_OrientationCube.transform.InverseTransformDirection(velGoal));
goalSensor.sensor.AddObservation(m_OrientationCube.transform.InverseTransformDirection(velGoal));
goalSensor.AddGoal(m_OrientationCube.transform.InverseTransformPoint(m_Target.transform.position));
goalSensor.sensor.AddObservation(m_OrientationCube.transform.InverseTransformPoint(m_Target.transform.position));
RaycastHit hit;
float maxRaycastDist = 10;

13
Project/Assets/ML-Agents/Examples/GoalNav/Prefabs/VisualArea.prefab


VectorActionDescriptions: []
VectorActionSpaceType: 0
hasUpgradedBrainParametersWithActionSpec: 1
m_Model: {fileID: 0}
m_Model: {fileID: 11400000, guid: 3fd57dbc0c08d4cee9ccc6edb9451017, type: 3}
m_InferenceDevice: 0
m_BehaviorType: 0
m_BehaviorName: GoalNav

m_Extent: {x: 0, y: 0, z: 0}
goalObject: {fileID: 704293525323185430}
obstacleObject: {fileID: 7160026636022562747}
useVectorObs: 1
useVectorObs: 0
--- !u!114 &5692574058063671958
MonoBehaviour:
m_ObjectHideFlags: 0

m_Component:
- component: {fileID: 5694061835409238912}
- component: {fileID: 5694061835409238914}
- component: {fileID: 5694061835409238915}
m_Layer: 0
m_Name: Camera
m_TagString: Untagged

m_OcclusionCulling: 1
m_StereoConvergence: 10
m_StereoSeparation: 0.022
--- !u!81 &5694061835409238915
AudioListener:
m_ObjectHideFlags: 0
m_CorrespondingSourceObject: {fileID: 0}
m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 5694061835409238913}
m_Enabled: 1
--- !u!1 &5694061835435845149
GameObject:
m_ObjectHideFlags: 0

8
Project/Assets/ML-Agents/Examples/GoalNav/Scripts/GoalNavAgent.cs


PushBlockSettings m_PushBlockSettings;
GoalSensorComponent goalSensor;
VectorSensorComponent goalSensor;
public GameObject goalObject;
public GameObject obstacleObject;

public override void CollectObservations(VectorSensor sensor)
{
goalSensor = this.GetComponent<GoalSensorComponent>();
goalSensor = this.GetComponent<VectorSensorComponent>();
goalSensor.AddGoal(goalLoc / 10f);
goalSensor.sensor.AddObservation(goalLoc / 10f);
goalSensor.AddGoal(obstacleLoc / 10f);
goalSensor.sensor.AddObservation(obstacleLoc / 10f);
}

6
Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs


"a camera to render before making a decision. Place the agentCam here if using " +
"RenderTexture as observations.")]
public Camera renderCamera;
GoalSensorComponent goalSensor;
VectorSensorComponent goalSensor;
public enum GridGoal
{

{
Array values = Enum.GetValues(typeof(GridGoal));
int goalNum = (int)gridGoal;
goalSensor = this.GetComponent<GoalSensorComponent>();
goalSensor.AddOneHotGoal(goalNum, values.Length);
goalSensor = this.GetComponent<VectorSensorComponent>();
goalSensor.sensor.AddOneHotObservation(goalNum, values.Length);
}
public override void WriteDiscreteActionMask(IDiscreteActionMask actionMask)

6
Project/Assets/ML-Agents/Examples/PushJump/Scripts/WJPBAgent.cs


public GameObject ground;
public GameObject spawnArea;
Bounds m_SpawnAreaBounds;
GoalSensorComponent goalSensor;
VectorSensorComponent goalSensor;
public GameObject wallJumpGoal;
public GameObject pushBlockGoal;

sensor.AddObservation(agentPos / 20f);
sensor.AddObservation(DoGroundCheck(true) ? 1 : 0);
goalSensor = this.GetComponent<GoalSensorComponent>();
goalSensor.AddGoal(m_GoalOneHot);
goalSensor = this.GetComponent<VectorSensorComponent>();
goalSensor.sensor.AddObservation(m_GoalOneHot);
}
/// <summary>

6
com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs


List<float> m_Observations;
int[] m_Shape;
string m_Name;
ObservationType m_ObservationType;
/// <summary>
/// Initializes the sensor.

public VectorSensor(int observationSize, string name = null)
public VectorSensor(int observationSize, string name = null, ObservationType observationType = ObservationType.Default)
{
if (name == null)
{

m_Observations = new List<float>(observationSize);
m_Name = name;
m_Shape = new[] { observationSize };
m_ObservationType = observationType;
}
/// <inheritdoc/>

/// <inheritdoc/>
public virtual ObservationType GetObservationType()
{
return ObservationType.Default;
return m_ObservationType;
}
/// <inheritdoc/>

42
ml-agents/mlagents/trainers/torch/networks.py


memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
encodes = []
goal_signal = None
obs_encodes = []
goal_encodes = []
if (
self.obs_types[idx] == ObservationType.DEFAULT
or self.conditioning_type == ConditioningType.DEFAULT
):
encodes.append(processed_obs)
elif (
self.obs_types[idx] == ObservationType.GOAL
and self.conditioning_type != ConditioningType.DEFAULT
):
if goal_signal is not None:
raise Exception("TODO : Cannot currently handle more than one goal")
goal_signal = processed_obs
if len(encodes) == 0:
if self.obs_types[idx] == ObservationType.DEFAULT:
obs_encodes.append(processed_obs)
elif self.obs_types[idx] == ObservationType.GOAL:
goal_encodes.append(processed_obs)
else:
raise Exception(
"TODO : Something other than a goal or observation was passed to the agent."
)
if self.conditioning_type == ConditioningType.DEFAULT:
obs_encodes = obs_encodes + goal_encodes
goal_encodes = []
if len(obs_encodes) == 0:
inputs = torch.cat(encodes + [actions], dim=-1)
obs_inputs = torch.cat(obs_encodes + [actions], dim=-1)
inputs = torch.cat(encodes, dim=-1)
obs_inputs = torch.cat(obs_encodes, dim=-1)
if goal_signal is None:
encoding = self.linear_encoder(inputs)
if len(goal_encodes) == 0:
encoding = self.linear_encoder(obs_inputs)
encoding = self.linear_encoder(inputs, goal_signal)
goal_inputs = torch.cat(goal_encodes, dim=-1)
encoding = self.linear_encoder(obs_inputs, goal_inputs)
if self.use_lstm:
# Resize to (batch, sequence length, encoding size)

41
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/VectorSensorComponent.cs


using System.Collections.Generic;
using System.Collections.ObjectModel;
using Unity.MLAgents.Sensors;
using UnityEngine;
/// <summary>
/// Sensor component that creates a <see cref="VectorSensor"/> with a configurable
/// size and <see cref="ObservationType"/>. The created sensor is exposed through
/// <see cref="sensor"/> so that an agent can write observations into it.
/// </summary>
public class VectorSensorComponent : SensorComponent
{
    // Unity only persists fields, never properties. Without [SerializeField] these
    // private backing fields would not be saved with a prefab or scene, and the
    // sensor would be created with size 0 and the default observation type.
    // NOTE(review): FormerlySerializedAs lets data saved under a public
    // "observationSize" field keep loading — confirm against existing prefabs.
    [SerializeField]
    [UnityEngine.Serialization.FormerlySerializedAs("observationSize")]
    int m_ObservationSize;

    [SerializeField]
    ObservationType m_ObservationType;

    /// <summary>
    /// Number of floats in the observation produced by the sensor.
    /// </summary>
    public int ObservationSize
    {
        get { return m_ObservationSize; }
        set { m_ObservationSize = value; }
    }

    /// <summary>
    /// The sensor built by the most recent <see cref="CreateSensor"/> call.
    /// It is null until <see cref="CreateSensor"/> has run.
    /// </summary>
    public VectorSensor sensor;

    /// <summary>
    /// The observation type passed to the sensor when it is created.
    /// </summary>
    public ObservationType ObservationType
    {
        get { return m_ObservationType; }
        set { m_ObservationType = value; }
    }

    /// <summary>
    /// Creates a VectorSensor using the configured size and observation type.
    /// </summary>
    /// <returns>The newly created sensor, also stored in <see cref="sensor"/>.</returns>
    public override ISensor CreateSensor()
    {
        sensor = new VectorSensor(m_ObservationSize, observationType: m_ObservationType);
        return sensor;
    }

    /// <inheritdoc/>
    public override int[] GetObservationShape()
    {
        return new[] { m_ObservationSize };
    }
}

75
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/GoalSensorComponent.cs


using System.Collections.Generic;
using System.Collections.ObjectModel;
using Unity.MLAgents.Sensors;
using UnityEngine;
/// <summary>
/// Sensor component that owns a <see cref="GoalSensor"/> and forwards goal
/// values to it. Every Add* method is a no-op until <see cref="CreateSensor"/>
/// has produced the sensor.
/// </summary>
public class GoalSensorComponent : SensorComponent
{
    /// <summary>Number of floats in the goal observation.</summary>
    public int observationSize;

    /// <summary>The sensor built by <see cref="CreateSensor"/>; null before that.</summary>
    public GoalSensor goalSensor;

    /// <summary>
    /// Builds a new GoalSensor of <see cref="observationSize"/> and keeps a reference to it.
    /// </summary>
    /// <returns>The newly created sensor.</returns>
    public override ISensor CreateSensor()
    {
        var created = new GoalSensor(observationSize);
        goalSensor = created;
        return created;
    }

    /// <inheritdoc/>
    public override int[] GetObservationShape()
    {
        int[] shape = { observationSize };
        return shape;
    }

    /// <summary>Forwards a sequence of floats to the goal sensor, if it exists.</summary>
    /// <param name="goal">Values to add as goal observations.</param>
    public void AddGoal(IEnumerable<float> goal)
    {
        goalSensor?.AddObservation(goal);
    }

    /// <summary>Forwards a single float to the goal sensor, if it exists.</summary>
    /// <param name="goal">Value to add as a goal observation.</param>
    public void AddGoal(float goal)
    {
        goalSensor?.AddObservation(goal);
    }

    /// <summary>Forwards a one-hot goal to the goal sensor, if it exists.</summary>
    /// <param name="goal">Value passed as the observation argument of AddOneHotObservation.</param>
    /// <param name="range">Value passed as the range argument of AddOneHotObservation.</param>
    public void AddOneHotGoal(int goal, int range)
    {
        goalSensor?.AddOneHotObservation(goal, range);
    }

    /// <summary>Forwards a Vector3 to the goal sensor, if it exists.</summary>
    /// <param name="goal">Vector to add as a goal observation.</param>
    public void AddGoal(Vector3 goal)
    {
        goalSensor?.AddObservation(goal);
    }
}
/// <summary>
/// A <see cref="VectorSensor"/> whose observation type is reported as
/// <see cref="ObservationType.Goal"/>.
/// </summary>
public class GoalSensor : VectorSensor
{
    /// <summary>
    /// Initializes the sensor.
    /// </summary>
    /// <param name="observationSize">Number of floats in the observation.</param>
    /// <param name="name">
    /// Name of the sensor. When null, "GoalSensor_size{observationSize}" is used.
    /// </param>
    // The name has to be resolved in the base-constructor call: assigning to the
    // parameter inside the body happens after the base class has already stored
    // its own name, so the value would be discarded.
    public GoalSensor(int observationSize, string name = null)
        : base(observationSize, name ?? $"GoalSensor_size{observationSize}")
    {
    }

    /// <inheritdoc/>
    public override ObservationType GetObservationType()
    {
        return ObservationType.Goal;
    }
}

/Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/GoalSensorComponent.cs.meta → /Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/VectorSensorComponent.cs.meta

正在加载...
取消
保存