Arthur Juliani
5 年前
当前提交
28e095e0
共有 64 个文件被更改,包括 1751 次插入 和 427 次删除
-
1.github/ISSUE_TEMPLATE/bug_report.md
-
23.pre-commit-config.yaml
-
35README.md
-
4com.unity.ml-agents/CHANGELOG.md
-
23com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
-
20com.unity.ml-agents/Editor/BrainParametersDrawer.cs
-
18com.unity.ml-agents/Runtime/Agent.cs
-
43com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
-
43com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
-
6com.unity.ml-agents/Runtime/Policies/BrainParameters.cs
-
24com.unity.ml-agents/Tests/Editor/EditModeTestInternalBrainTensorGenerator.cs
-
65com.unity.ml-agents/Tests/Editor/MLAgentsEditModeTest.cs
-
2com.unity.ml-agents/Tests/Editor/PublicAPI/Unity.ML-Agents.Editor.Tests.PublicAPI.asmdef
-
10com.unity.ml-agents/Tests/Runtime/RuntimeAPITest.cs
-
5docs/Learning-Environment-Create-New.md
-
2docs/Migrating.md
-
17docs/Python-API.md
-
2docs/Training-ML-Agents.md
-
2docs/Using-Tensorboard.md
-
21gym-unity/gym_unity/envs/__init__.py
-
5gym-unity/gym_unity/tests/test_gym.py
-
50ml-agents-envs/mlagents_envs/base_env.py
-
293ml-agents-envs/mlagents_envs/environment.py
-
54ml-agents-envs/mlagents_envs/tests/test_envs.py
-
47ml-agents-envs/mlagents_envs/tests/test_side_channel.py
-
14ml-agents/mlagents/trainers/learn.py
-
1ml-agents/mlagents/trainers/policy/tf_policy.py
-
4ml-agents/mlagents/trainers/ppo/trainer.py
-
8ml-agents/mlagents/trainers/simple_env_manager.py
-
10ml-agents/mlagents/trainers/subprocess_env_manager.py
-
12ml-agents/mlagents/trainers/tests/simple_test_envs.py
-
5ml-agents/mlagents/trainers/tests/test_learn.py
-
4ml-agents/mlagents/trainers/tests/test_nn_policy.py
-
6ml-agents/tests/yamato/check_coverage_percent.py
-
4ml-agents/tests/yamato/scripts/run_llapi.py
-
2ml-agents/tests/yamato/yamato_utils.py
-
6utils/validate_versions.py
-
8com.unity.ml-agents/Runtime/Sensors/Reflection.meta
-
11com.unity.ml-agents/Tests/Editor/Communicator/GrpcExtensionsTests.cs.meta
-
292com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs
-
11com.unity.ml-agents/Tests/Editor/Sensor/ObservableAttributeTests.cs.meta
-
108ml-agents-envs/mlagents_envs/env_utils.py
-
81ml-agents-envs/mlagents_envs/side_channel/side_channel_manager.py
-
64ml-agents-envs/mlagents_envs/tests/test_env_utils.py
-
102ml-agents-envs/mlagents_envs/tests/test_steps.py
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs.meta
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs.meta
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs.meta
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs.meta
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs.meta
-
97com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/ReflectionSensorBase.cs.meta
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs.meta
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs.meta
-
11com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs.meta
-
19com.unity.ml-agents/Runtime/Sensors/Reflection/BoolReflectionSensor.cs
-
19com.unity.ml-agents/Runtime/Sensors/Reflection/FloatReflectionSensor.cs
-
19com.unity.ml-agents/Runtime/Sensors/Reflection/IntReflectionSensor.cs
-
272com.unity.ml-agents/Runtime/Sensors/Reflection/ObservableAttribute.cs
-
22com.unity.ml-agents/Runtime/Sensors/Reflection/QuaternionReflectionSensor.cs
-
20com.unity.ml-agents/Runtime/Sensors/Reflection/Vector2ReflectionSensor.cs
-
21com.unity.ml-agents/Runtime/Sensors/Reflection/Vector3ReflectionSensor.cs
-
22com.unity.ml-agents/Runtime/Sensors/Reflection/Vector4ReflectionSensor.cs
|
|||
fileFormatVersion: 2 |
|||
guid: 08ece3d7e9bb94089a9d59c6f269ab0a |
|||
folderAsset: yes |
|||
DefaultImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: e5e4df2934c014aa3b835b9eb9ad20b3 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using NUnit.Framework; |
|||
using UnityEngine; |
|||
using Unity.MLAgents.Sensors; |
|||
using Unity.MLAgents.Sensors.Reflection; |
|||
|
|||
namespace Unity.MLAgents.Tests |
|||
{ |
|||
[TestFixture] |
|||
public class ObservableAttributeTests |
|||
{ |
|||
class TestClass |
|||
{ |
|||
// Non-observables
|
|||
int m_NonObservableInt; |
|||
float m_NonObservableFloat; |
|||
|
|||
//
|
|||
// Int
|
|||
//
|
|||
[Observable] |
|||
public int m_IntMember; |
|||
|
|||
int m_IntProperty; |
|||
|
|||
[Observable] |
|||
public int IntProperty |
|||
{ |
|||
get => m_IntProperty; |
|||
set => m_IntProperty = value; |
|||
} |
|||
|
|||
//
|
|||
// Float
|
|||
//
|
|||
[Observable("floatMember")] |
|||
public float m_FloatMember; |
|||
|
|||
float m_FloatProperty; |
|||
[Observable("floatProperty")] |
|||
public float FloatProperty |
|||
{ |
|||
get => m_FloatProperty; |
|||
set => m_FloatProperty = value; |
|||
} |
|||
|
|||
//
|
|||
// Bool
|
|||
//
|
|||
[Observable("boolMember")] |
|||
public bool m_BoolMember; |
|||
|
|||
bool m_BoolProperty; |
|||
[Observable("boolProperty")] |
|||
public bool BoolProperty |
|||
{ |
|||
get => m_BoolProperty; |
|||
set => m_BoolProperty = value; |
|||
} |
|||
|
|||
//
|
|||
// Vector2
|
|||
//
|
|||
|
|||
[Observable("vector2Member")] |
|||
public Vector2 m_Vector2Member; |
|||
|
|||
Vector2 m_Vector2Property; |
|||
|
|||
[Observable("vector2Property")] |
|||
public Vector2 Vector2Property |
|||
{ |
|||
get => m_Vector2Property; |
|||
set => m_Vector2Property = value; |
|||
} |
|||
|
|||
//
|
|||
// Vector3
|
|||
//
|
|||
[Observable("vector3Member")] |
|||
public Vector3 m_Vector3Member; |
|||
|
|||
Vector3 m_Vector3Property; |
|||
|
|||
[Observable("vector3Property")] |
|||
public Vector3 Vector3Property |
|||
{ |
|||
get => m_Vector3Property; |
|||
set => m_Vector3Property = value; |
|||
} |
|||
|
|||
//
|
|||
// Vector4
|
|||
//
|
|||
|
|||
[Observable("vector4Member")] |
|||
public Vector4 m_Vector4Member; |
|||
|
|||
Vector4 m_Vector4Property; |
|||
|
|||
[Observable("vector4Property")] |
|||
public Vector4 Vector4Property |
|||
{ |
|||
get => m_Vector4Property; |
|||
set => m_Vector4Property = value; |
|||
} |
|||
|
|||
//
|
|||
// Quaternion
|
|||
//
|
|||
[Observable("quaternionMember")] |
|||
public Quaternion m_QuaternionMember; |
|||
|
|||
Quaternion m_QuaternionProperty; |
|||
|
|||
[Observable("quaternionProperty")] |
|||
public Quaternion QuaternionProperty |
|||
{ |
|||
get => m_QuaternionProperty; |
|||
set => m_QuaternionProperty = value; |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestGetObservableSensors() |
|||
{ |
|||
var testClass = new TestClass(); |
|||
testClass.m_IntMember = 1; |
|||
testClass.IntProperty = 2; |
|||
|
|||
testClass.m_FloatMember = 1.1f; |
|||
testClass.FloatProperty = 1.2f; |
|||
|
|||
testClass.m_BoolMember = true; |
|||
testClass.BoolProperty = true; |
|||
|
|||
testClass.m_Vector2Member = new Vector2(2.0f, 2.1f); |
|||
testClass.Vector2Property = new Vector2(2.2f, 2.3f); |
|||
|
|||
testClass.m_Vector3Member = new Vector3(3.0f, 3.1f, 3.2f); |
|||
testClass.Vector3Property = new Vector3(3.3f, 3.4f, 3.5f); |
|||
|
|||
testClass.m_Vector4Member = new Vector4(4.0f, 4.1f, 4.2f, 4.3f); |
|||
testClass.Vector4Property = new Vector4(4.4f, 4.5f, 4.5f, 4.7f); |
|||
|
|||
testClass.m_Vector4Member = new Vector4(4.0f, 4.1f, 4.2f, 4.3f); |
|||
testClass.Vector4Property = new Vector4(4.4f, 4.5f, 4.5f, 4.7f); |
|||
|
|||
testClass.m_QuaternionMember = new Quaternion(5.0f, 5.1f, 5.2f, 5.3f); |
|||
testClass.QuaternionProperty = new Quaternion(5.4f, 5.5f, 5.5f, 5.7f); |
|||
|
|||
var sensors = ObservableAttribute.CreateObservableSensors(testClass, false); |
|||
|
|||
var sensorsByName = new Dictionary<string, ISensor>(); |
|||
foreach (var sensor in sensors) |
|||
{ |
|||
sensorsByName[sensor.GetName()] = sensor; |
|||
} |
|||
|
|||
SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.m_IntMember"], new[] { 1.0f }); |
|||
SensorTestHelper.CompareObservation(sensorsByName["ObservableAttribute:TestClass.IntProperty"], new[] { 2.0f }); |
|||
|
|||
SensorTestHelper.CompareObservation(sensorsByName["floatMember"], new[] { 1.1f }); |
|||
SensorTestHelper.CompareObservation(sensorsByName["floatProperty"], new[] { 1.2f }); |
|||
|
|||
SensorTestHelper.CompareObservation(sensorsByName["boolMember"], new[] { 1.0f }); |
|||
SensorTestHelper.CompareObservation(sensorsByName["boolProperty"], new[] { 1.0f }); |
|||
|
|||
SensorTestHelper.CompareObservation(sensorsByName["vector2Member"], new[] { 2.0f, 2.1f }); |
|||
SensorTestHelper.CompareObservation(sensorsByName["vector2Property"], new[] { 2.2f, 2.3f }); |
|||
|
|||
SensorTestHelper.CompareObservation(sensorsByName["vector3Member"], new[] { 3.0f, 3.1f, 3.2f }); |
|||
SensorTestHelper.CompareObservation(sensorsByName["vector3Property"], new[] { 3.3f, 3.4f, 3.5f }); |
|||
|
|||
SensorTestHelper.CompareObservation(sensorsByName["vector4Member"], new[] { 4.0f, 4.1f, 4.2f, 4.3f }); |
|||
SensorTestHelper.CompareObservation(sensorsByName["vector4Property"], new[] { 4.4f, 4.5f, 4.5f, 4.7f }); |
|||
|
|||
SensorTestHelper.CompareObservation(sensorsByName["quaternionMember"], new[] { 5.0f, 5.1f, 5.2f, 5.3f }); |
|||
SensorTestHelper.CompareObservation(sensorsByName["quaternionProperty"], new[] { 5.4f, 5.5f, 5.5f, 5.7f }); |
|||
} |
|||
|
|||
[Test] |
|||
public void TestGetTotalObservationSize() |
|||
{ |
|||
var testClass = new TestClass(); |
|||
var errors = new List<string>(); |
|||
var expectedObsSize = 2 * (1 + 1 + 1 + 2 + 3 + 4 + 4); |
|||
Assert.AreEqual(expectedObsSize, ObservableAttribute.GetTotalObservationSize(testClass, false, errors)); |
|||
Assert.AreEqual(0, errors.Count); |
|||
} |
|||
|
|||
class BadClass |
|||
{ |
|||
[Observable] |
|||
double m_Double; |
|||
|
|||
[Observable] |
|||
double DoubleProperty |
|||
{ |
|||
get => m_Double; |
|||
set => m_Double = value; |
|||
} |
|||
|
|||
float m_WriteOnlyProperty; |
|||
|
|||
[Observable] |
|||
// No get property, so we shouldn't be able to make a sensor out of this.
|
|||
public float WriteOnlyProperty |
|||
{ |
|||
set => m_WriteOnlyProperty = value; |
|||
} |
|||
} |
|||
|
|||
[Test] |
|||
public void TestInvalidObservables() |
|||
{ |
|||
var bad = new BadClass(); |
|||
bad.WriteOnlyProperty = 1.0f; |
|||
var errors = new List<string>(); |
|||
Assert.AreEqual(0, ObservableAttribute.GetTotalObservationSize(bad, false, errors)); |
|||
Assert.AreEqual(3, errors.Count); |
|||
|
|||
// Should be able to safely generate sensors (and get nothing back)
|
|||
var sensors = ObservableAttribute.CreateObservableSensors(bad, false); |
|||
Assert.AreEqual(0, sensors.Count); |
|||
} |
|||
|
|||
class StackingClass |
|||
{ |
|||
[Observable(numStackedObservations: 2)] |
|||
public float FloatVal; |
|||
} |
|||
|
|||
[Test] |
|||
public void TestObservableAttributeStacking() |
|||
{ |
|||
var c = new StackingClass(); |
|||
c.FloatVal = 1.0f; |
|||
var sensors = ObservableAttribute.CreateObservableSensors(c, false); |
|||
var sensor = sensors[0]; |
|||
Assert.AreEqual(typeof(StackingSensor), sensor.GetType()); |
|||
SensorTestHelper.CompareObservation(sensor, new[] { 0.0f, 1.0f }); |
|||
|
|||
sensor.Update(); |
|||
c.FloatVal = 3.0f; |
|||
SensorTestHelper.CompareObservation(sensor, new[] { 1.0f, 3.0f }); |
|||
|
|||
var errors = new List<string>(); |
|||
Assert.AreEqual(2, ObservableAttribute.GetTotalObservationSize(c, false, errors)); |
|||
Assert.AreEqual(0, errors.Count); |
|||
} |
|||
|
|||
class BaseClass |
|||
{ |
|||
[Observable("base")] |
|||
public float m_BaseField; |
|||
|
|||
[Observable("private")] |
|||
float m_PrivateField; |
|||
} |
|||
|
|||
class DerivedClass : BaseClass |
|||
{ |
|||
[Observable("derived")] |
|||
float m_DerivedField; |
|||
} |
|||
|
|||
[Test] |
|||
public void TestObservableAttributeExcludeInherited() |
|||
{ |
|||
var d = new DerivedClass(); |
|||
d.m_BaseField = 1.0f; |
|||
|
|||
// excludeInherited=false will get fields in the derived class, plus public and protected inherited fields
|
|||
var sensorAll = ObservableAttribute.CreateObservableSensors(d, false); |
|||
Assert.AreEqual(2, sensorAll.Count); |
|||
// Note - actual order doesn't matter here, we can change this to use a HashSet if neeed.
|
|||
Assert.AreEqual("derived", sensorAll[0].GetName()); |
|||
Assert.AreEqual("base", sensorAll[1].GetName()); |
|||
|
|||
// excludeInherited=true will only get fields in the derived class
|
|||
var sensorsDerivedOnly = ObservableAttribute.CreateObservableSensors(d, true); |
|||
Assert.AreEqual(1, sensorsDerivedOnly.Count); |
|||
Assert.AreEqual("derived", sensorsDerivedOnly[0].GetName()); |
|||
|
|||
var b = new BaseClass(); |
|||
var baseSensors = ObservableAttribute.CreateObservableSensors(b, false); |
|||
Assert.AreEqual(2, baseSensors.Count); |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 33d7912e6b3504412bd261b40e46df32 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
import glob |
|||
import os |
|||
import subprocess |
|||
from sys import platform |
|||
from typing import Optional, List |
|||
from mlagents_envs.logging_util import get_logger |
|||
from mlagents_envs.exception import UnityEnvironmentException |
|||
|
|||
|
|||
def get_platform(): |
|||
""" |
|||
returns the platform of the operating system : linux, darwin or win32 |
|||
""" |
|||
return platform |
|||
|
|||
|
|||
def validate_environment_path(env_path: str) -> Optional[str]: |
|||
""" |
|||
Strip out executable extensions of the env_path |
|||
:param env_path: The path to the executable |
|||
""" |
|||
env_path = ( |
|||
env_path.strip() |
|||
.replace(".app", "") |
|||
.replace(".exe", "") |
|||
.replace(".x86_64", "") |
|||
.replace(".x86", "") |
|||
) |
|||
true_filename = os.path.basename(os.path.normpath(env_path)) |
|||
get_logger(__name__).debug("The true file name is {}".format(true_filename)) |
|||
|
|||
if not (glob.glob(env_path) or glob.glob(env_path + ".*")): |
|||
return None |
|||
|
|||
cwd = os.getcwd() |
|||
launch_string = None |
|||
true_filename = os.path.basename(os.path.normpath(env_path)) |
|||
if get_platform() == "linux" or get_platform() == "linux2": |
|||
candidates = glob.glob(os.path.join(cwd, env_path) + ".x86_64") |
|||
if len(candidates) == 0: |
|||
candidates = glob.glob(os.path.join(cwd, env_path) + ".x86") |
|||
if len(candidates) == 0: |
|||
candidates = glob.glob(env_path + ".x86_64") |
|||
if len(candidates) == 0: |
|||
candidates = glob.glob(env_path + ".x86") |
|||
if len(candidates) > 0: |
|||
launch_string = candidates[0] |
|||
|
|||
elif get_platform() == "darwin": |
|||
candidates = glob.glob( |
|||
os.path.join(cwd, env_path + ".app", "Contents", "MacOS", true_filename) |
|||
) |
|||
if len(candidates) == 0: |
|||
candidates = glob.glob( |
|||
os.path.join(env_path + ".app", "Contents", "MacOS", true_filename) |
|||
) |
|||
if len(candidates) == 0: |
|||
candidates = glob.glob( |
|||
os.path.join(cwd, env_path + ".app", "Contents", "MacOS", "*") |
|||
) |
|||
if len(candidates) == 0: |
|||
candidates = glob.glob( |
|||
os.path.join(env_path + ".app", "Contents", "MacOS", "*") |
|||
) |
|||
if len(candidates) > 0: |
|||
launch_string = candidates[0] |
|||
elif get_platform() == "win32": |
|||
candidates = glob.glob(os.path.join(cwd, env_path + ".exe")) |
|||
if len(candidates) == 0: |
|||
candidates = glob.glob(env_path + ".exe") |
|||
if len(candidates) > 0: |
|||
launch_string = candidates[0] |
|||
return launch_string |
|||
|
|||
|
|||
def launch_executable(file_name: str, args: List[str]) -> subprocess.Popen: |
|||
""" |
|||
Launches a Unity executable and returns the process handle for it. |
|||
:param file_name: the name of the executable |
|||
:param args: List of string that will be passed as command line arguments |
|||
when launching the executable. |
|||
""" |
|||
launch_string = validate_environment_path(file_name) |
|||
if launch_string is None: |
|||
raise UnityEnvironmentException( |
|||
f"Couldn't launch the {file_name} environment. Provided filename does not match any environments." |
|||
) |
|||
else: |
|||
get_logger(__name__).debug("This is the launch string {}".format(launch_string)) |
|||
# Launch Unity environment |
|||
subprocess_args = [launch_string] + args |
|||
try: |
|||
return subprocess.Popen( |
|||
subprocess_args, |
|||
# start_new_session=True means that signals to the parent python process |
|||
# (e.g. SIGINT from keyboard interrupt) will not be sent to the new process on POSIX platforms. |
|||
# This is generally good since we want the environment to have a chance to shutdown, |
|||
# but may be undesirable in come cases; if so, we'll add a command-line toggle. |
|||
# Note that on Windows, the CTRL_C signal will still be sent. |
|||
start_new_session=True, |
|||
) |
|||
except PermissionError as perm: |
|||
# This is likely due to missing read or execute permissions on file. |
|||
raise UnityEnvironmentException( |
|||
f"Error when trying to launch environment - make sure " |
|||
f"permissions are set correctly. For example " |
|||
f'"chmod -R 755 {launch_string}"' |
|||
) from perm |
|
|||
import uuid |
|||
import struct |
|||
from typing import Dict, Optional, List |
|||
from mlagents_envs.side_channel import SideChannel, IncomingMessage |
|||
from mlagents_envs.exception import UnityEnvironmentException |
|||
from mlagents_envs.logging_util import get_logger |
|||
|
|||
|
|||
class SideChannelManager: |
|||
def __init__(self, side_channels=Optional[List[SideChannel]]): |
|||
self._side_channels_dict = self._get_side_channels_dict(side_channels) |
|||
|
|||
def process_side_channel_message(self, data: bytes) -> None: |
|||
""" |
|||
Separates the data received from Python into individual messages for each |
|||
registered side channel and calls on_message_received on them. |
|||
:param data: The packed message sent by Unity |
|||
""" |
|||
offset = 0 |
|||
while offset < len(data): |
|||
try: |
|||
channel_id = uuid.UUID(bytes_le=bytes(data[offset : offset + 16])) |
|||
offset += 16 |
|||
message_len, = struct.unpack_from("<i", data, offset) |
|||
offset = offset + 4 |
|||
message_data = data[offset : offset + message_len] |
|||
offset = offset + message_len |
|||
except (struct.error, ValueError, IndexError): |
|||
raise UnityEnvironmentException( |
|||
"There was a problem reading a message in a SideChannel. " |
|||
"Please make sure the version of MLAgents in Unity is " |
|||
"compatible with the Python version." |
|||
) |
|||
if len(message_data) != message_len: |
|||
raise UnityEnvironmentException( |
|||
"The message received by the side channel {0} was " |
|||
"unexpectedly short. Make sure your Unity Environment " |
|||
"sending side channel data properly.".format(channel_id) |
|||
) |
|||
if channel_id in self._side_channels_dict: |
|||
incoming_message = IncomingMessage(message_data) |
|||
self._side_channels_dict[channel_id].on_message_received( |
|||
incoming_message |
|||
) |
|||
else: |
|||
get_logger(__name__).warning( |
|||
f"Unknown side channel data received. Channel type: {channel_id}." |
|||
) |
|||
|
|||
def generate_side_channel_messages(self) -> bytearray: |
|||
""" |
|||
Gathers the messages that the registered side channels will send to Unity |
|||
and combines them into a single message ready to be sent. |
|||
""" |
|||
result = bytearray() |
|||
for channel_id, channel in self._side_channels_dict.items(): |
|||
for message in channel.message_queue: |
|||
result += channel_id.bytes_le |
|||
result += struct.pack("<i", len(message)) |
|||
result += message |
|||
channel.message_queue = [] |
|||
return result |
|||
|
|||
@staticmethod |
|||
def _get_side_channels_dict( |
|||
side_channels: Optional[List[SideChannel]] |
|||
) -> Dict[uuid.UUID, SideChannel]: |
|||
""" |
|||
Converts a list of side channels into a dictionary of channel_id to SideChannel |
|||
:param side_channels: The list of side channels. |
|||
""" |
|||
side_channels_dict: Dict[uuid.UUID, SideChannel] = {} |
|||
if side_channels is not None: |
|||
for _sc in side_channels: |
|||
if _sc.channel_id in side_channels_dict: |
|||
raise UnityEnvironmentException( |
|||
f"There cannot be two side channels with " |
|||
f"the same channel id {_sc.channel_id}." |
|||
) |
|||
side_channels_dict[_sc.channel_id] = _sc |
|||
return side_channels_dict |
|
|||
from unittest import mock |
|||
import pytest |
|||
from mlagents_envs.env_utils import validate_environment_path, launch_executable |
|||
from mlagents_envs.exception import UnityEnvironmentException |
|||
from mlagents_envs.logging_util import ( |
|||
set_log_level, |
|||
get_logger, |
|||
INFO, |
|||
ERROR, |
|||
FATAL, |
|||
CRITICAL, |
|||
DEBUG, |
|||
) |
|||
|
|||
|
|||
def mock_glob_method(path): |
|||
""" |
|||
Given a path input, returns a list of candidates |
|||
""" |
|||
if ".x86" in path: |
|||
return ["linux"] |
|||
if ".app" in path: |
|||
return ["darwin"] |
|||
if ".exe" in path: |
|||
return ["win32"] |
|||
if "*" in path: |
|||
return "Any" |
|||
return [] |
|||
|
|||
|
|||
@mock.patch("sys.platform") |
|||
@mock.patch("glob.glob") |
|||
def test_validate_path_empty(glob_mock, platform_mock): |
|||
glob_mock.return_value = None |
|||
path = validate_environment_path(" ") |
|||
assert path is None |
|||
|
|||
|
|||
@mock.patch("mlagents_envs.env_utils.get_platform") |
|||
@mock.patch("glob.glob") |
|||
def test_validate_path(glob_mock, platform_mock): |
|||
glob_mock.side_effect = mock_glob_method |
|||
for platform in ["linux", "darwin", "win32"]: |
|||
platform_mock.return_value = platform |
|||
path = validate_environment_path(" ") |
|||
assert path == platform |
|||
|
|||
|
|||
@mock.patch("glob.glob") |
|||
@mock.patch("subprocess.Popen") |
|||
def test_launch_executable(mock_popen, glob_mock): |
|||
with pytest.raises(UnityEnvironmentException): |
|||
launch_executable(" ", []) |
|||
glob_mock.return_value = ["FakeLaunchPath"] |
|||
launch_executable(" ", []) |
|||
mock_popen.side_effect = PermissionError("Fake permission error") |
|||
with pytest.raises(UnityEnvironmentException): |
|||
launch_executable(" ", []) |
|||
|
|||
|
|||
def test_set_logging_level(): |
|||
for level in [INFO, ERROR, FATAL, CRITICAL, DEBUG]: |
|||
set_log_level(level) |
|||
assert get_logger("test").level == level |
|
|||
import pytest |
|||
import numpy as np |
|||
|
|||
from mlagents_envs.base_env import ( |
|||
DecisionSteps, |
|||
TerminalSteps, |
|||
ActionType, |
|||
BehaviorSpec, |
|||
) |
|||
|
|||
|
|||
def test_decision_steps(): |
|||
ds = DecisionSteps( |
|||
obs=[np.array(range(12), dtype=np.float32).reshape(3, 4)], |
|||
reward=np.array(range(3), dtype=np.float32), |
|||
agent_id=np.array(range(10, 13), dtype=np.int32), |
|||
action_mask=[np.zeros((3, 4), dtype=np.bool)], |
|||
) |
|||
|
|||
assert ds.agent_id_to_index[10] == 0 |
|||
assert ds.agent_id_to_index[11] == 1 |
|||
assert ds.agent_id_to_index[12] == 2 |
|||
|
|||
with pytest.raises(KeyError): |
|||
assert ds.agent_id_to_index[-1] == -1 |
|||
|
|||
mask_agent = ds[10].action_mask |
|||
assert isinstance(mask_agent, list) |
|||
assert len(mask_agent) == 1 |
|||
assert np.array_equal(mask_agent[0], np.zeros((4), dtype=np.bool)) |
|||
|
|||
for agent_id in ds: |
|||
assert ds.agent_id_to_index[agent_id] in range(3) |
|||
|
|||
|
|||
def test_empty_decision_steps(): |
|||
specs = BehaviorSpec( |
|||
observation_shapes=[(3, 2), (5,)], |
|||
action_type=ActionType.CONTINUOUS, |
|||
action_shape=3, |
|||
) |
|||
ds = DecisionSteps.empty(specs) |
|||
assert len(ds.obs) == 2 |
|||
assert ds.obs[0].shape == (0, 3, 2) |
|||
assert ds.obs[1].shape == (0, 5) |
|||
|
|||
|
|||
def test_terminal_steps(): |
|||
ts = TerminalSteps( |
|||
obs=[np.array(range(12), dtype=np.float32).reshape(3, 4)], |
|||
reward=np.array(range(3), dtype=np.float32), |
|||
agent_id=np.array(range(10, 13), dtype=np.int32), |
|||
interrupted=np.array([1, 0, 1], dtype=np.bool), |
|||
) |
|||
|
|||
assert ts.agent_id_to_index[10] == 0 |
|||
assert ts.agent_id_to_index[11] == 1 |
|||
assert ts.agent_id_to_index[12] == 2 |
|||
|
|||
assert ts[10].interrupted |
|||
assert not ts[11].interrupted |
|||
assert ts[12].interrupted |
|||
|
|||
with pytest.raises(KeyError): |
|||
assert ts.agent_id_to_index[-1] == -1 |
|||
|
|||
for agent_id in ts: |
|||
assert ts.agent_id_to_index[agent_id] in range(3) |
|||
|
|||
|
|||
def test_empty_terminal_steps(): |
|||
specs = BehaviorSpec( |
|||
observation_shapes=[(3, 2), (5,)], |
|||
action_type=ActionType.CONTINUOUS, |
|||
action_shape=3, |
|||
) |
|||
ts = TerminalSteps.empty(specs) |
|||
assert len(ts.obs) == 2 |
|||
assert ts.obs[0].shape == (0, 3, 2) |
|||
assert ts.obs[1].shape == (0, 5) |
|||
|
|||
|
|||
def test_specs(): |
|||
specs = BehaviorSpec( |
|||
observation_shapes=[(3, 2), (5,)], |
|||
action_type=ActionType.CONTINUOUS, |
|||
action_shape=3, |
|||
) |
|||
assert specs.discrete_action_branches is None |
|||
assert specs.action_size == 3 |
|||
assert specs.create_empty_action(5).shape == (5, 3) |
|||
assert specs.create_empty_action(5).dtype == np.float32 |
|||
|
|||
specs = BehaviorSpec( |
|||
observation_shapes=[(3, 2), (5,)], |
|||
action_type=ActionType.DISCRETE, |
|||
action_shape=(3,), |
|||
) |
|||
assert specs.discrete_action_branches == (3,) |
|||
assert specs.action_size == 1 |
|||
assert specs.create_empty_action(5).shape == (5, 1) |
|||
assert specs.create_empty_action(5).dtype == np.int32 |
|
|||
fileFormatVersion: 2 |
|||
guid: be795c90750a6420d93f569b69ddc1ba |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 51ed837d5b7cd44349287ac8066120fc |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 5cae4c843cc074d11a549aaa3904c898 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: a75086dc66a594baea6b8b2935f5dacf |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: d38241d74074d459bb4590f7f5d16c80 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
using System.Reflection; |
|||
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Construction info for a ReflectionSensorBase.
|
|||
/// </summary>
|
|||
internal struct ReflectionSensorInfo |
|||
{ |
|||
public object Object; |
|||
|
|||
public FieldInfo FieldInfo; |
|||
public PropertyInfo PropertyInfo; |
|||
public ObservableAttribute ObservableAttribute; |
|||
public string SensorName; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Abstract base class for reflection-based sensors.
|
|||
/// </summary>
|
|||
internal abstract class ReflectionSensorBase : ISensor |
|||
{ |
|||
protected object m_Object; |
|||
|
|||
// Exactly one of m_FieldInfo and m_PropertyInfo should be non-null.
|
|||
protected FieldInfo m_FieldInfo; |
|||
protected PropertyInfo m_PropertyInfo; |
|||
|
|||
// Not currently used, but might want later.
|
|||
protected ObservableAttribute m_ObservableAttribute; |
|||
|
|||
// Cached sensor names and shapes.
|
|||
string m_SensorName; |
|||
int[] m_Shape; |
|||
|
|||
public ReflectionSensorBase(ReflectionSensorInfo reflectionSensorInfo, int size) |
|||
{ |
|||
m_Object = reflectionSensorInfo.Object; |
|||
m_FieldInfo = reflectionSensorInfo.FieldInfo; |
|||
m_PropertyInfo = reflectionSensorInfo.PropertyInfo; |
|||
m_ObservableAttribute = reflectionSensorInfo.ObservableAttribute; |
|||
m_SensorName = reflectionSensorInfo.SensorName; |
|||
m_Shape = new[] {size}; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public int[] GetObservationShape() |
|||
{ |
|||
return m_Shape; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public int Write(ObservationWriter writer) |
|||
{ |
|||
WriteReflectedField(writer); |
|||
return m_Shape[0]; |
|||
} |
|||
|
|||
internal abstract void WriteReflectedField(ObservationWriter writer); |
|||
|
|||
/// <summary>
|
|||
/// Get either the reflected field, or return the reflected property.
|
|||
/// This should be used by implementations in their WriteReflectedField() method.
|
|||
/// </summary>
|
|||
/// <returns></returns>
|
|||
protected object GetReflectedValue() |
|||
{ |
|||
return m_FieldInfo != null ? |
|||
m_FieldInfo.GetValue(m_Object) : |
|||
m_PropertyInfo.GetMethod.Invoke(m_Object, null); |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public byte[] GetCompressedObservation() |
|||
{ |
|||
return null; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public void Update() {} |
|||
|
|||
/// <inheritdoc/>
|
|||
public void Reset() {} |
|||
|
|||
/// <inheritdoc/>
|
|||
public SensorCompressionType GetCompressionType() |
|||
{ |
|||
return SensorCompressionType.None; |
|||
} |
|||
|
|||
/// <inheritdoc/>
|
|||
public string GetName() |
|||
{ |
|||
return m_SensorName; |
|||
} |
|||
} |
|||
} |
|
|||
fileFormatVersion: 2 |
|||
guid: 6b68d855fb94a45fbbeb0dbe968a35f8 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: da06ff33f6f2d409cbf240cffa2ba0be |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: e756976ec2a0943cfbc0f97a6550a85b |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
fileFormatVersion: 2 |
|||
guid: 01d93aaa1b42b47b8960d303d7c498d3 |
|||
MonoImporter: |
|||
externalObjects: {} |
|||
serializedVersion: 2 |
|||
defaultReferences: [] |
|||
executionOrder: 0 |
|||
icon: {instanceID: 0} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Sensor that wraps a boolean field or property of an object, and returns
|
|||
/// that as an observation.
|
|||
/// </summary>
|
|||
internal class BoolReflectionSensor : ReflectionSensorBase |
|||
{ |
|||
public BoolReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) |
|||
: base(reflectionSensorInfo, 1) |
|||
{} |
|||
|
|||
internal override void WriteReflectedField(ObservationWriter writer) |
|||
{ |
|||
var boolVal = (System.Boolean)GetReflectedValue(); |
|||
writer[0] = boolVal ? 1.0f : 0.0f; |
|||
} |
|||
} |
|||
} |
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Sensor that wraps a float field or property of an object, and returns
|
|||
/// that as an observation.
|
|||
/// </summary>
|
|||
internal class FloatReflectionSensor : ReflectionSensorBase |
|||
{ |
|||
public FloatReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) |
|||
: base(reflectionSensorInfo, 1) |
|||
{} |
|||
|
|||
internal override void WriteReflectedField(ObservationWriter writer) |
|||
{ |
|||
var floatVal = (System.Single)GetReflectedValue(); |
|||
writer[0] = floatVal; |
|||
} |
|||
} |
|||
} |
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Sensor that wraps an integer field or property of an object, and returns
|
|||
/// that as an observation.
|
|||
/// </summary>
|
|||
internal class IntReflectionSensor : ReflectionSensorBase |
|||
{ |
|||
public IntReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) |
|||
: base(reflectionSensorInfo, 1) |
|||
{} |
|||
|
|||
internal override void WriteReflectedField(ObservationWriter writer) |
|||
{ |
|||
var intVal = (System.Int32)GetReflectedValue(); |
|||
writer[0] = (float)intVal; |
|||
} |
|||
} |
|||
} |
|
|||
using System; |
|||
using System.Collections.Generic; |
|||
using System.Reflection; |
|||
using UnityEngine; |
|||
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Specify that a field or property should be used to generate observations for an Agent.
|
|||
/// For each field or property that uses ObservableAttribute, a corresponding
|
|||
/// <see cref="ISensor"/> will be created during Agent initialization, and this
|
|||
/// sensor will read the values during training and inference.
|
|||
/// </summary>
|
|||
/// <remarks>
|
|||
/// ObservableAttribute is intended to make initial setup of an Agent easier. Because it
|
|||
/// uses reflection to read the values of fields and properties at runtime, this may
|
|||
/// be much slower than reading the values directly. If the performance of
|
|||
/// ObservableAttribute is an issue, you can get the same functionality by overriding
|
|||
/// <see cref="Agent.CollectObservations(VectorSensor)"/> or creating a custom
|
|||
/// <see cref="ISensor"/> implementation to read the values without reflection.
|
|||
///
|
|||
/// Note that you do not need to adjust the VectorObservationSize in
|
|||
/// <see cref="Unity.MLAgents.Policies.BrainParameters"/> when adding ObservableAttribute
|
|||
/// to fields or properties.
|
|||
/// </remarks>
|
|||
/// <example>
|
|||
/// This sample class will produce two observations, one for the m_Health field, and one
|
|||
/// for the HealthPercent property.
|
|||
/// <code>
|
|||
/// using Unity.MLAgents;
|
|||
/// using Unity.MLAgents.Sensors.Reflection;
|
|||
///
|
|||
/// public class MyAgent : Agent
|
|||
/// {
|
|||
/// [Observable]
|
|||
/// int m_Health;
|
|||
///
|
|||
/// [Observable]
|
|||
/// float HealthPercent
|
|||
/// {
|
|||
/// get => return 100.0f * m_Health / float(m_MaxHealth);
|
|||
/// }
|
|||
/// }
|
|||
/// </code>
|
|||
/// </example>
|
|||
[AttributeUsage(AttributeTargets.Field | AttributeTargets.Property)] |
|||
public class ObservableAttribute : Attribute |
|||
{ |
|||
string m_Name; |
|||
int m_NumStackedObservations; |
|||
|
|||
/// <summary>
|
|||
/// Default binding flags used for reflection of members and properties.
|
|||
/// </summary>
|
|||
const BindingFlags k_BindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic; |
|||
|
|||
/// <summary>
|
|||
/// Supported types and their observation sizes and corresponding sensor type.
|
|||
/// </summary>
|
|||
static Dictionary<Type, (int, Type)> s_TypeToSensorInfo = new Dictionary<Type, (int, Type)>() |
|||
{ |
|||
{typeof(int), (1, typeof(IntReflectionSensor))}, |
|||
{typeof(bool), (1, typeof(BoolReflectionSensor))}, |
|||
{typeof(float), (1, typeof(FloatReflectionSensor))}, |
|||
|
|||
{typeof(Vector2), (2, typeof(Vector2ReflectionSensor))}, |
|||
{typeof(Vector3), (3, typeof(Vector3ReflectionSensor))}, |
|||
{typeof(Vector4), (4, typeof(Vector4ReflectionSensor))}, |
|||
{typeof(Quaternion), (4, typeof(QuaternionReflectionSensor))}, |
|||
}; |
|||
|
|||
/// <summary>
|
|||
/// ObservableAttribute constructor.
|
|||
/// </summary>
|
|||
/// <param name="name">Optional override for the sensor name. Note that all sensors for an Agent
|
|||
/// must have a unique name.</param>
|
|||
/// <param name="numStackedObservations">Number of frames to concatenate observations from.</param>
|
|||
public ObservableAttribute(string name = null, int numStackedObservations = 1) |
|||
{ |
|||
m_Name = name; |
|||
m_NumStackedObservations = numStackedObservations; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Returns a FieldInfo for all fields that have an ObservableAttribute
|
|||
/// </summary>
|
|||
/// <param name="o">Object being reflected</param>
|
|||
/// <param name="excludeInherited">Whether to exclude inherited properties or not.</param>
|
|||
/// <returns></returns>
|
|||
static IEnumerable<(FieldInfo, ObservableAttribute)> GetObservableFields(object o, bool excludeInherited) |
|||
{ |
|||
// TODO cache these (and properties) by type, so that we only have to reflect once.
|
|||
var bindingFlags = k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0); |
|||
var fields = o.GetType().GetFields(bindingFlags); |
|||
foreach (var field in fields) |
|||
{ |
|||
var attr = (ObservableAttribute)GetCustomAttribute(field, typeof(ObservableAttribute)); |
|||
if (attr != null) |
|||
{ |
|||
yield return (field, attr); |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Returns a PropertyInfo for all fields that have an ObservableAttribute
|
|||
/// </summary>
|
|||
/// <param name="o">Object being reflected</param>
|
|||
/// <param name="excludeInherited">Whether to exclude inherited properties or not.</param>
|
|||
/// <returns></returns>
|
|||
static IEnumerable<(PropertyInfo, ObservableAttribute)> GetObservableProperties(object o, bool excludeInherited) |
|||
{ |
|||
var bindingFlags = k_BindingFlags | (excludeInherited ? BindingFlags.DeclaredOnly : 0); |
|||
var properties = o.GetType().GetProperties(bindingFlags); |
|||
foreach (var prop in properties) |
|||
{ |
|||
var attr = (ObservableAttribute)GetCustomAttribute(prop, typeof(ObservableAttribute)); |
|||
if (attr != null) |
|||
{ |
|||
yield return (prop, attr); |
|||
} |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Creates sensors for each field and property with ObservableAttribute.
|
|||
/// </summary>
|
|||
/// <param name="o">Object being reflected</param>
|
|||
/// <param name="excludeInherited">Whether to exclude inherited properties or not.</param>
|
|||
/// <returns></returns>
|
|||
internal static List<ISensor> CreateObservableSensors(object o, bool excludeInherited) |
|||
{ |
|||
var sensorsOut = new List<ISensor>(); |
|||
foreach (var(field, attr) in GetObservableFields(o, excludeInherited)) |
|||
{ |
|||
var sensor = CreateReflectionSensor(o, field, null, attr); |
|||
if (sensor != null) |
|||
{ |
|||
sensorsOut.Add(sensor); |
|||
} |
|||
} |
|||
|
|||
foreach (var(prop, attr) in GetObservableProperties(o, excludeInherited)) |
|||
{ |
|||
if (!prop.CanRead) |
|||
{ |
|||
// Skip unreadable properties.
|
|||
continue; |
|||
} |
|||
var sensor = CreateReflectionSensor(o, null, prop, attr); |
|||
if (sensor != null) |
|||
{ |
|||
sensorsOut.Add(sensor); |
|||
} |
|||
} |
|||
|
|||
return sensorsOut; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Create the ISensor for either the field or property on the provided object.
|
|||
/// If the data type is unsupported, or the property is write-only, returns null.
|
|||
/// </summary>
|
|||
/// <param name="o"></param>
|
|||
/// <param name="fieldInfo"></param>
|
|||
/// <param name="propertyInfo"></param>
|
|||
/// <param name="observableAttribute"></param>
|
|||
/// <returns></returns>
|
|||
/// <exception cref="UnityAgentsException"></exception>
|
|||
static ISensor CreateReflectionSensor(object o, FieldInfo fieldInfo, PropertyInfo propertyInfo, ObservableAttribute observableAttribute) |
|||
{ |
|||
string memberName; |
|||
string declaringTypeName; |
|||
Type memberType; |
|||
if (fieldInfo != null) |
|||
{ |
|||
declaringTypeName = fieldInfo.DeclaringType.Name; |
|||
memberName = fieldInfo.Name; |
|||
memberType = fieldInfo.FieldType; |
|||
} |
|||
else |
|||
{ |
|||
declaringTypeName = propertyInfo.DeclaringType.Name; |
|||
memberName = propertyInfo.Name; |
|||
memberType = propertyInfo.PropertyType; |
|||
} |
|||
|
|||
if (!s_TypeToSensorInfo.ContainsKey(memberType)) |
|||
{ |
|||
// For unsupported types, return null and we'll filter them out later.
|
|||
return null; |
|||
} |
|||
|
|||
string sensorName; |
|||
if (string.IsNullOrEmpty(observableAttribute.m_Name)) |
|||
{ |
|||
sensorName = $"ObservableAttribute:{declaringTypeName}.{memberName}"; |
|||
} |
|||
else |
|||
{ |
|||
sensorName = observableAttribute.m_Name; |
|||
} |
|||
|
|||
var reflectionSensorInfo = new ReflectionSensorInfo |
|||
{ |
|||
Object = o, |
|||
FieldInfo = fieldInfo, |
|||
PropertyInfo = propertyInfo, |
|||
ObservableAttribute = observableAttribute, |
|||
SensorName = sensorName |
|||
}; |
|||
|
|||
var (_, sensorType) = s_TypeToSensorInfo[memberType]; |
|||
var sensor = (ISensor) Activator.CreateInstance(sensorType, reflectionSensorInfo); |
|||
|
|||
// Wrap the base sensor in a StackingSensor if we're using stacking.
|
|||
if (observableAttribute.m_NumStackedObservations > 1) |
|||
{ |
|||
return new StackingSensor(sensor, observableAttribute.m_NumStackedObservations); |
|||
} |
|||
|
|||
return sensor; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Gets the sum of the observation sizes of the Observable fields and properties on an object.
|
|||
/// Also appends errors to the errorsOut array.
|
|||
/// </summary>
|
|||
/// <param name="o"></param>
|
|||
/// <param name="excludeInherited"></param>
|
|||
/// <param name="errorsOut"></param>
|
|||
/// <returns></returns>
|
|||
internal static int GetTotalObservationSize(object o, bool excludeInherited, List<string> errorsOut) |
|||
{ |
|||
int sizeOut = 0; |
|||
foreach (var(field, attr) in GetObservableFields(o, excludeInherited)) |
|||
{ |
|||
if (s_TypeToSensorInfo.ContainsKey(field.FieldType)) |
|||
{ |
|||
var (obsSize, _) = s_TypeToSensorInfo[field.FieldType]; |
|||
sizeOut += obsSize * attr.m_NumStackedObservations; |
|||
} |
|||
else |
|||
{ |
|||
errorsOut.Add($"Unsupported Observable type {field.FieldType.Name} on field {field.Name}"); |
|||
} |
|||
} |
|||
|
|||
foreach (var(prop, attr) in GetObservableProperties(o, excludeInherited)) |
|||
{ |
|||
if (s_TypeToSensorInfo.ContainsKey(prop.PropertyType)) |
|||
{ |
|||
if (prop.CanRead) |
|||
{ |
|||
var (obsSize, _) = s_TypeToSensorInfo[prop.PropertyType]; |
|||
sizeOut += obsSize * attr.m_NumStackedObservations; |
|||
} |
|||
else |
|||
{ |
|||
errorsOut.Add($"Observable property {prop.Name} is write-only."); |
|||
} |
|||
} |
|||
else |
|||
{ |
|||
errorsOut.Add($"Unsupported Observable type {prop.PropertyType.Name} on property {prop.Name}"); |
|||
} |
|||
} |
|||
|
|||
return sizeOut; |
|||
} |
|||
} |
|||
} |
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Sensor that wraps a quaternion field or property of an object, and returns
|
|||
/// that as an observation.
|
|||
/// </summary>
|
|||
internal class QuaternionReflectionSensor : ReflectionSensorBase |
|||
{ |
|||
public QuaternionReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) |
|||
: base(reflectionSensorInfo, 4) |
|||
{} |
|||
|
|||
internal override void WriteReflectedField(ObservationWriter writer) |
|||
{ |
|||
var quatVal = (UnityEngine.Quaternion)GetReflectedValue(); |
|||
writer[0] = quatVal.x; |
|||
writer[1] = quatVal.y; |
|||
writer[2] = quatVal.z; |
|||
writer[3] = quatVal.w; |
|||
} |
|||
} |
|||
} |
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Sensor that wraps a Vector2 field or property of an object, and returns
|
|||
/// that as an observation.
|
|||
/// </summary>
|
|||
internal class Vector2ReflectionSensor : ReflectionSensorBase |
|||
{ |
|||
public Vector2ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) |
|||
: base(reflectionSensorInfo, 2) |
|||
{} |
|||
|
|||
internal override void WriteReflectedField(ObservationWriter writer) |
|||
{ |
|||
var vecVal = (UnityEngine.Vector2)GetReflectedValue(); |
|||
writer[0] = vecVal.x; |
|||
writer[1] = vecVal.y; |
|||
} |
|||
} |
|||
} |
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Sensor that wraps a Vector3 field or property of an object, and returns
|
|||
/// that as an observation.
|
|||
/// </summary>
|
|||
internal class Vector3ReflectionSensor : ReflectionSensorBase |
|||
{ |
|||
public Vector3ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) |
|||
: base(reflectionSensorInfo, 3) |
|||
{} |
|||
|
|||
internal override void WriteReflectedField(ObservationWriter writer) |
|||
{ |
|||
var vecVal = (UnityEngine.Vector3)GetReflectedValue(); |
|||
writer[0] = vecVal.x; |
|||
writer[1] = vecVal.y; |
|||
writer[2] = vecVal.z; |
|||
} |
|||
} |
|||
} |
|
|||
namespace Unity.MLAgents.Sensors.Reflection |
|||
{ |
|||
/// <summary>
|
|||
/// Sensor that wraps a Vector4 field or property of an object, and returns
|
|||
/// that as an observation.
|
|||
/// </summary>
|
|||
internal class Vector4ReflectionSensor : ReflectionSensorBase |
|||
{ |
|||
public Vector4ReflectionSensor(ReflectionSensorInfo reflectionSensorInfo) |
|||
: base(reflectionSensorInfo, 4) |
|||
{} |
|||
|
|||
internal override void WriteReflectedField(ObservationWriter writer) |
|||
{ |
|||
var vecVal = (UnityEngine.Vector4)GetReflectedValue(); |
|||
writer[0] = vecVal.x; |
|||
writer[1] = vecVal.y; |
|||
writer[2] = vecVal.z; |
|||
writer[3] = vecVal.w; |
|||
} |
|||
} |
|||
} |
撰写
预览
正在加载...
取消
保存
Reference in new issue