浏览代码

Adding the goal conditioning sensors with the new observation specs

/goal-conditioning/sensors-3-pytest-fix
vincentpierre 3 年前
当前提交
1843345f
共有 12 个文件被更改,包括 169 次插入20 次删除
  1. 1
      com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs
  2. 5
      com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
  3. 14
      com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
  4. 10
      com.unity.ml-agents/Runtime/Sensors/ISensor.cs
  5. 10
      com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
  6. 17
      com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorTest.cs
  7. 14
      com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs
  8. 4
      ml-agents-envs/mlagents_envs/base_env.py
  9. 30
      com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs
  10. 11
      com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs.meta
  11. 62
      com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs
  12. 11
      com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs.meta

1
com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs


EditorGUILayout.PropertyField(so.FindProperty("m_Height"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true);
}
EditorGUI.EndDisabledGroup();
EditorGUILayout.PropertyField(so.FindProperty("m_Compression"), true);

5
com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs


/// <param name="grayscale">Whether to convert the generated image to grayscale or keep color.</param>
/// <param name="name">The name of the camera sensor.</param>
/// <param name="compression">The compression to apply to the generated image.</param>
/// <param name="observationType">The type of observation.</param>
Camera camera, int width, int height, bool grayscale, string name, SensorCompressionType compression)
Camera camera, int width, int height, bool grayscale, string name, SensorCompressionType compression, ObservationType observationType = ObservationType.Default)
{
m_Camera = camera;
m_Width = width;

var channels = grayscale ? 1 : 3;
m_ObservationSpec = ObservationSpec.Visual(height, width, channels);
m_ObservationSpec = ObservationSpec.Visual(height, width, channels, observationType);
m_CompressionType = compression;
}

14
com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs


}
[HideInInspector, SerializeField]
ObservationType m_ObservationType;
/// <summary>
/// The type of the observation.
/// </summary>
public ObservationType SensorObservationType
{
get { return m_ObservationType; }
set { m_ObservationType = value; UpdateSensor(); }
}
[HideInInspector, SerializeField]
[Range(1, 50)]
[Tooltip("Number of camera frames that will be stacked before being fed to the neural network.")]
int m_ObservationStacks = 1;

/// <returns>The created <see cref="CameraSensor"/> object for this component.</returns>
public override ISensor CreateSensor()
{
m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression);
m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression, m_ObservationType);
if (ObservationStacks != 1)
{

10
com.unity.ml-agents/Runtime/Sensors/ISensor.cs


/// Collected observations contain goal information.
/// </summary>
Goal = 1,
/// <summary>
/// Collected observations contain reward information.
/// </summary>
Reward = 2,
/// <summary>
/// Collected observations are messages from other agents.
/// </summary>
Message = 3,
}
/// <summary>

10
com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs


/// </summary>
/// <param name="observationSize">Number of vector observations.</param>
/// <param name="name">Name of the sensor.</param>
public VectorSensor(int observationSize, string name = null)
public VectorSensor(int observationSize, string name = null, ObservationType observationType = ObservationType.Default)
if (name == null)
if (name == null || name == "")
if (observationType != ObservationType.Default)
{
name += "_goal";
}
m_ObservationSpec = ObservationSpec.Vector(observationSize);
m_ObservationSpec = ObservationSpec.Vector(observationSize, observationType);
}
/// <inheritdoc/>

17
com.unity.ml-agents/Tests/Editor/Sensor/CameraSensorTest.cs


}
}
}
[Test]
public void TestObservationType()
{
var width = 24;
var height = 16;
var camera = Camera.main;
var sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None);
var spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.Default);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.Goal);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Goal);
}
}
}

14
com.unity.ml-agents/Tests/Editor/Sensor/VectorSensorTests.cs


}
[Test]
public void TestObservationType()
{
var sensor = new VectorSensor(1);
var spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new VectorSensor(1, observationType: ObservationType.Default);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new VectorSensor(1, observationType: ObservationType.Goal);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Goal);
}
[Test]
public void TestAddObservationInt()
{
var sensor = new VectorSensor(1);

4
ml-agents-envs/mlagents_envs/base_env.py


DEFAULT = 0
# Observation contains goal information for current task.
GOAL = 1
# Observation contains reward information for current task.
REWARD = 2
# Observation contains a message from another agent.
MESSAGE = 3
class ObservationSpec(NamedTuple):

30
com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs


using UnityEditor;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Editor
{
[CustomEditor(typeof(VectorSensorComponent))]
[CanEditMultipleObjects]
internal class VectorSensorComponentEditor : UnityEditor.Editor
{
public override void OnInspectorGUI()
{
var so = serializedObject;
so.Update();
// Drawing the VectorSensorComponent
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
{
// These fields affect the sensor order or observation size,
// So can't be changed at runtime.
EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_observationSize"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true);
}
EditorGUI.EndDisabledGroup();
so.ApplyModifiedProperties();
}
}
}

11
com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs.meta


fileFormatVersion: 2
guid: aa0230c3402f04921acdbbdb61f6ff00
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

62
com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs


using UnityEngine;
using UnityEngine.Serialization;
namespace Unity.MLAgents.Sensors
{
[AddComponentMenu("ML Agents/Vector Sensor", (int)MenuGroup.Sensors)]
public class VectorSensorComponent : SensorComponent
{
/// <summary>
/// Name of the generated <see cref="VectorSensor"/> object.
/// Note that changing this at runtime does not affect how the Agent sorts the sensors.
/// </summary>
public string SensorName
{
get { return m_SensorName; }
set { m_SensorName = value; }
}
[HideInInspector, SerializeField]
private string m_SensorName = "VectorSensor";
public int ObservationSize
{
get { return m_observationSize; }
set { m_observationSize = value; }
}
[HideInInspector, SerializeField]
int m_observationSize;
[HideInInspector, SerializeField]
ObservationType m_ObservationType;
VectorSensor m_sensor;
public ObservationType ObservationType
{
get { return m_ObservationType; }
set { m_ObservationType = value; }
}
/// <summary>
/// Creates a VectorSensor.
/// </summary>
/// <returns></returns>
public override ISensor CreateSensor()
{
m_sensor = new VectorSensor(m_observationSize, m_SensorName, m_ObservationType);
return m_sensor;
}
/// <inheritdoc/>
public override int[] GetObservationShape()
{
return new[] { m_observationSize };
}
public VectorSensor GetSensor()
{
return m_sensor;
}
}
}

11
com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs.meta


fileFormatVersion: 2
guid: 38b7cc1f5819445aa85e9a9b054552dc
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存