浏览代码

Adding the goal conditioning sensors with the new observation specs (#5159)

* Fixing networks.py for the merge

* fix compile error

* Adding the goal conditioning sensors with the new observation specs

* addressing feedback

* I forgot to change the m_observationType

* Renaming Goal to GoalSignal (#5190)

* Renaming GOAL to GOAL_SIGNAL

* VectorSensorComponent to use new API

* Adding docstrings

* verbose pytest on github action

Co-authored-by: Chris Elion <chris.elion@unity3d.com>
/check-for-ModelOverriders
GitHub 3 年前
当前提交
c37cfac1
共有 19 个文件被更改,包括 194 次插入55 次删除
  1. 2
      .github/workflows/pytest.yml
  2. 1
      com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs
  3. 12
      com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs
  4. 5
      com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs
  5. 14
      com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs
  6. 12
      com.unity.ml-agents/Runtime/Sensors/ISensor.cs
  7. 10
      com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs
  8. 17
      com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs
  9. 14
      com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs
  10. 6
      ml-agents-envs/mlagents_envs/base_env.py
  11. 18
      ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py
  12. 8
      ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi
  13. 2
      ml-agents/mlagents/trainers/tests/simple_test_envs.py
  14. 2
      ml-agents/mlagents/trainers/torch/networks.py
  15. 6
      protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto
  16. 30
      com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs
  17. 11
      com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs.meta
  18. 68
      com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs
  19. 11
      com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs.meta

2
.github/workflows/pytest.yml


pip freeze > pip_versions-${{ matrix.python-version }}.txt
cat pip_versions-${{ matrix.python-version }}.txt
- name: Run pytest
run: pytest --cov=ml-agents --cov=ml-agents-envs --cov=gym-unity --cov-report html --junitxml=junit/test-results-${{ matrix.python-version }}.xml -p no:warnings
run: pytest --cov=ml-agents --cov=ml-agents-envs --cov=gym-unity --cov-report html --junitxml=junit/test-results-${{ matrix.python-version }}.xml -p no:warnings -v
- name: Upload pytest test results
uses: actions/upload-artifact@v2
with:

1
com.unity.ml-agents/Editor/CameraSensorComponentEditor.cs


EditorGUILayout.PropertyField(so.FindProperty("m_Height"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_Grayscale"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationStacks"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true);
}
EditorGUI.EndDisabledGroup();
EditorGUILayout.PropertyField(so.FindProperty("m_Compression"), true);

12
com.unity.ml-agents/Runtime/Grpc/CommunicatorObjects/Observation.cs


"b25fdHlwZRgHIAEoDjIqLmNvbW11bmljYXRvcl9vYmplY3RzLk9ic2VydmF0",
"aW9uVHlwZVByb3RvEgwKBG5hbWUYCCABKAkaGQoJRmxvYXREYXRhEgwKBGRh",
"dGEYASADKAJCEgoQb2JzZXJ2YXRpb25fZGF0YSopChRDb21wcmVzc2lvblR5",
"cGVQcm90bxIICgROT05FEAASBwoDUE5HEAEqRgoUT2JzZXJ2YXRpb25UeXBl",
"UHJvdG8SCwoHREVGQVVMVBAAEggKBEdPQUwQARIKCgZSRVdBUkQQAhILCgdN",
"RVNTQUdFEANCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVj",
"dHNiBnByb3RvMw=="));
"cGVQcm90bxIICgROT05FEAASBwoDUE5HEAEqQAoUT2JzZXJ2YXRpb25UeXBl",
"UHJvdG8SCwoHREVGQVVMVBAAEg8KC0dPQUxfU0lHTkFMEAEiBAgCEAIiBAgD",
"EANCJaoCIlVuaXR5Lk1MQWdlbnRzLkNvbW11bmljYXRvck9iamVjdHNiBnBy",
"b3RvMw=="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(new[] {typeof(global::Unity.MLAgents.CommunicatorObjects.CompressionTypeProto), typeof(global::Unity.MLAgents.CommunicatorObjects.ObservationTypeProto), }, new pbr::GeneratedClrTypeInfo[] {

internal enum ObservationTypeProto {
[pbr::OriginalName("DEFAULT")] Default = 0,
[pbr::OriginalName("GOAL")] Goal = 1,
[pbr::OriginalName("REWARD")] Reward = 2,
[pbr::OriginalName("MESSAGE")] Message = 3,
[pbr::OriginalName("GOAL_SIGNAL")] GoalSignal = 1,
}
#endregion

5
com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs


/// <param name="grayscale">Whether to convert the generated image to grayscale or keep color.</param>
/// <param name="name">The name of the camera sensor.</param>
/// <param name="compression">The compression to apply to the generated image.</param>
/// <param name="observationType">The type of observation.</param>
Camera camera, int width, int height, bool grayscale, string name, SensorCompressionType compression)
Camera camera, int width, int height, bool grayscale, string name, SensorCompressionType compression, ObservationType observationType = ObservationType.Default)
{
m_Camera = camera;
m_Width = width;

var channels = grayscale ? 1 : 3;
m_ObservationSpec = ObservationSpec.Visual(height, width, channels);
m_ObservationSpec = ObservationSpec.Visual(height, width, channels, observationType);
m_CompressionType = compression;
}

14
com.unity.ml-agents/Runtime/Sensors/CameraSensorComponent.cs


}
[HideInInspector, SerializeField]
ObservationType m_ObservationType;
/// <summary>
/// The type of the observation.
/// </summary>
public ObservationType ObservationType
{
get { return m_ObservationType; }
set { m_ObservationType = value; UpdateSensor(); }
}
[HideInInspector, SerializeField]
[Range(1, 50)]
[Tooltip("Number of camera frames that will be stacked before being fed to the neural network.")]
int m_ObservationStacks = 1;

/// <returns>The created <see cref="CameraSensor"/> object for this component.</returns>
public override ISensor[] CreateSensors()
{
m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression);
m_Sensor = new CameraSensor(m_Camera, m_Width, m_Height, Grayscale, m_SensorName, m_Compression, m_ObservationType);
if (ObservationStacks != 1)
{

12
com.unity.ml-agents/Runtime/Sensors/ISensor.cs


/// <summary>
/// Collected observations contain goal information.
/// </summary>
Goal = 1,
/// <summary>
/// Collected observations contain reward information.
/// </summary>
Reward = 2,
/// <summary>
/// Collected observations are messages from other agents.
/// </summary>
Message = 3,
GoalSignal = 1,
}
/// <summary>

10
com.unity.ml-agents/Runtime/Sensors/VectorSensor.cs


/// </summary>
/// <param name="observationSize">Number of vector observations.</param>
/// <param name="name">Name of the sensor.</param>
public VectorSensor(int observationSize, string name = null)
public VectorSensor(int observationSize, string name = null, ObservationType observationType = ObservationType.Default)
if (name == null)
if (string.IsNullOrEmpty(name))
if (observationType != ObservationType.Default)
{
name += $"_{observationType.ToString()}";
}
m_ObservationSpec = ObservationSpec.Vector(observationSize);
m_ObservationSpec = ObservationSpec.Vector(observationSize, observationType);
}
/// <inheritdoc/>

17
com.unity.ml-agents/Tests/Runtime/Sensor/CameraSensorTest.cs


}
}
}
[Test]
public void TestObservationType()
{
var width = 24;
var height = 16;
var camera = Camera.main;
var sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None);
var spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.Default);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new CameraSensor(camera, width, height, true, "TestCameraSensor", SensorCompressionType.None, ObservationType.GoalSignal);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.GoalSignal);
}
}
}

14
com.unity.ml-agents/Tests/Runtime/Sensor/VectorSensorTests.cs


}
[Test]
public void TestObservationType()
{
var sensor = new VectorSensor(1);
var spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new VectorSensor(1, observationType: ObservationType.Default);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.Default);
sensor = new VectorSensor(1, observationType: ObservationType.GoalSignal);
spec = sensor.GetObservationSpec();
Assert.AreEqual((int)spec.ObservationType, (int)ObservationType.GoalSignal);
}
[Test]
public void TestAddObservationInt()
{
var sensor = new VectorSensor(1);

6
ml-agents-envs/mlagents_envs/base_env.py


# Observation information is generic.
DEFAULT = 0
# Observation contains goal information for current task.
GOAL = 1
# Observation contains reward information for current task.
REWARD = 2
# Observation contains a message from another agent.
MESSAGE = 3
GOAL_SIGNAL = 1
class ObservationSpec(NamedTuple):

18
ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.py


name='mlagents_envs/communicator_objects/observation.proto',
package='communicator_objects',
syntax='proto3',
serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\x8f\x03\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x12\x1c\n\x14\x64imension_properties\x18\x06 \x03(\x05\x12\x44\n\x10observation_type\x18\x07 \x01(\x0e\x32*.communicator_objects.ObservationTypeProto\x12\x0c\n\x04name\x18\x08 \x01(\t\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01*F\n\x14ObservationTypeProto\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x08\n\x04GOAL\x10\x01\x12\n\n\x06REWARD\x10\x02\x12\x0b\n\x07MESSAGE\x10\x03\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
serialized_pb=_b('\n4mlagents_envs/communicator_objects/observation.proto\x12\x14\x63ommunicator_objects\"\x8f\x03\n\x10ObservationProto\x12\r\n\x05shape\x18\x01 \x03(\x05\x12\x44\n\x10\x63ompression_type\x18\x02 \x01(\x0e\x32*.communicator_objects.CompressionTypeProto\x12\x19\n\x0f\x63ompressed_data\x18\x03 \x01(\x0cH\x00\x12\x46\n\nfloat_data\x18\x04 \x01(\x0b\x32\x30.communicator_objects.ObservationProto.FloatDataH\x00\x12\"\n\x1a\x63ompressed_channel_mapping\x18\x05 \x03(\x05\x12\x1c\n\x14\x64imension_properties\x18\x06 \x03(\x05\x12\x44\n\x10observation_type\x18\x07 \x01(\x0e\x32*.communicator_objects.ObservationTypeProto\x12\x0c\n\x04name\x18\x08 \x01(\t\x1a\x19\n\tFloatData\x12\x0c\n\x04\x64\x61ta\x18\x01 \x03(\x02\x42\x12\n\x10observation_data*)\n\x14\x43ompressionTypeProto\x12\x08\n\x04NONE\x10\x00\x12\x07\n\x03PNG\x10\x01*@\n\x14ObservationTypeProto\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x0f\n\x0bGOAL_SIGNAL\x10\x01\"\x04\x08\x02\x10\x02\"\x04\x08\x03\x10\x03\x42%\xaa\x02\"Unity.MLAgents.CommunicatorObjectsb\x06proto3')
)
_COMPRESSIONTYPEPROTO = _descriptor.EnumDescriptor(

options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='GOAL', index=1, number=1,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='REWARD', index=2, number=2,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='MESSAGE', index=3, number=3,
name='GOAL_SIGNAL', index=1, number=1,
options=None,
type=None),
],

serialized_end=593,
serialized_end=587,
)
_sym_db.RegisterEnumDescriptor(_OBSERVATIONTYPEPROTO)

DEFAULT = 0
GOAL = 1
REWARD = 2
MESSAGE = 3
GOAL_SIGNAL = 1

8
ml-agents-envs/mlagents_envs/communicator_objects/observation_pb2.pyi


@classmethod
def items(cls) -> typing___List[typing___Tuple[builtin___str, 'ObservationTypeProto']]: ...
DEFAULT = typing___cast('ObservationTypeProto', 0)
GOAL = typing___cast('ObservationTypeProto', 1)
REWARD = typing___cast('ObservationTypeProto', 2)
MESSAGE = typing___cast('ObservationTypeProto', 3)
GOAL_SIGNAL = typing___cast('ObservationTypeProto', 1)
GOAL = typing___cast('ObservationTypeProto', 1)
REWARD = typing___cast('ObservationTypeProto', 2)
MESSAGE = typing___cast('ObservationTypeProto', 3)
GOAL_SIGNAL = typing___cast('ObservationTypeProto', 1)
class ObservationProto(google___protobuf___message___Message):
DESCRIPTOR: google___protobuf___descriptor___Descriptor = ...

2
ml-agents/mlagents/trainers/tests/simple_test_envs.py


obs_spec[i] = ObservationSpec(
shape=obs_spec[i].shape,
dimension_property=obs_spec[i].dimension_property,
observation_type=ObservationType.GOAL,
observation_type=ObservationType.GOAL_SIGNAL,
name=obs_spec[i].name,
)
return obs_spec

2
ml-agents/mlagents/trainers/torch/networks.py


self._total_goal_enc_size = 0
self._goal_processor_indices: List[int] = []
for i in range(len(observation_specs)):
if observation_specs[i].observation_type == ObservationType.GOAL:
if observation_specs[i].observation_type == ObservationType.GOAL_SIGNAL:
self._total_goal_enc_size += self.embedding_sizes[i]
self._goal_processor_indices.append(i)

6
protobuf-definitions/proto/mlagents_envs/communicator_objects/observation.proto


enum ObservationTypeProto {
DEFAULT = 0;
GOAL = 1;
REWARD = 2;
MESSAGE = 3;
GOAL_SIGNAL = 1;
reserved 2; // Reserved for potential "reward" type
reserved 3; // Reserved for potential "message" type
}
message ObservationProto {

30
com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs


using UnityEditor;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Editor
{
[CustomEditor(typeof(VectorSensorComponent))]
[CanEditMultipleObjects]
internal class VectorSensorComponentEditor : UnityEditor.Editor
{
public override void OnInspectorGUI()
{
var so = serializedObject;
so.Update();
// Drawing the VectorSensorComponent
EditorGUI.BeginDisabledGroup(!EditorUtilities.CanUpdateModelProperties());
{
// These fields affect the sensor order or observation size,
// So can't be changed at runtime.
EditorGUILayout.PropertyField(so.FindProperty("m_SensorName"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationSize"), true);
EditorGUILayout.PropertyField(so.FindProperty("m_ObservationType"), true);
}
EditorGUI.EndDisabledGroup();
so.ApplyModifiedProperties();
}
}
}

11
com.unity.ml-agents/Editor/VectorSensorComponentEditor.cs.meta


fileFormatVersion: 2
guid: aa0230c3402f04921acdbbdb61f6ff00
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

68
com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs


using UnityEngine;
using UnityEngine.Serialization;
namespace Unity.MLAgents.Sensors
{
/// <summary>
/// A SensorComponent that creates a <see cref="VectorSensor"/>.
/// </summary>
[AddComponentMenu("ML Agents/Vector Sensor", (int)MenuGroup.Sensors)]
public class VectorSensorComponent : SensorComponent
{
/// <summary>
/// Name of the generated <see cref="VectorSensor"/> object.
/// Note that changing this at runtime does not affect how the Agent sorts the sensors.
/// </summary>
public string SensorName
{
get { return m_SensorName; }
set { m_SensorName = value; }
}
[HideInInspector, SerializeField]
private string m_SensorName = "VectorSensor";
/// <summary>
/// The number of float observations in the VectorSensor
/// </summary>
public int ObservationSize
{
get { return m_ObservationSize; }
set { m_ObservationSize = value; }
}
[HideInInspector, SerializeField]
int m_ObservationSize;
[HideInInspector, SerializeField]
ObservationType m_ObservationType;
VectorSensor m_Sensor;
/// <summary>
/// The type of the observation.
/// </summary>
public ObservationType ObservationType
{
get { return m_ObservationType; }
set { m_ObservationType = value; }
}
/// <summary>
/// Creates a VectorSensor.
/// </summary>
/// <returns></returns>
public override ISensor[] CreateSensors()
{
m_Sensor = new VectorSensor(m_ObservationSize, m_SensorName, m_ObservationType);
return new ISensor[] { m_Sensor };
}
/// <summary>
/// Returns the underlying VectorSensor
/// </summary>
public VectorSensor GetSensor()
{
return m_Sensor;
}
}
}

11
com.unity.ml-agents/Runtime/Sensors/VectorSensorComponent.cs.meta


fileFormatVersion: 2
guid: 38b7cc1f5819445aa85e9a9b054552dc
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存