
[MLA-345] float visual observations (#3148)

* pass shape to WriteAdapter

* handle floats on python side

* cleanup

* whitespace

* rename GetFloatObservationShape, support uncompressed in RenderTexture sensor

* numpy float32

* remove unused using

* Float sensor and unit test

* replace asserts with exceptions, docstrings
/asymm-envs
GitHub · 5 years ago
Current commit: a488299f
26 changed files with 655 additions and 407 deletions
  1. UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs (2 changed lines)
  2. UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/StackingSensorTests.cs (2 changed lines)
  3. UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/VectorSensorTests.cs (2 changed lines)
  4. UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/WriterAdapterTests.cs (22 changed lines)
  5. UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scenes/GridWorld.unity (673 changed lines)
  6. UnitySDK/Assets/ML-Agents/Scripts/Agent.cs (9 changed lines)
  7. UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs (8 changed lines)
  8. UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs (2 changed lines)
  9. UnitySDK/Assets/ML-Agents/Scripts/Policy/BarracudaPolicy.cs (4 changed lines)
  10. UnitySDK/Assets/ML-Agents/Scripts/Policy/RemotePolicy.cs (4 changed lines)
  11. UnitySDK/Assets/ML-Agents/Scripts/Sensor/CameraSensor.cs (9 changed lines)
  12. UnitySDK/Assets/ML-Agents/Scripts/Sensor/CameraSensorComponent.cs (3 changed lines)
  13. UnitySDK/Assets/ML-Agents/Scripts/Sensor/ISensor.cs (4 changed lines)
  14. UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensor.cs (2 changed lines)
  15. UnitySDK/Assets/ML-Agents/Scripts/Sensor/RenderTextureSensor.cs (11 changed lines)
  16. UnitySDK/Assets/ML-Agents/Scripts/Sensor/RenderTextureSensorComponent.cs (3 changed lines)
  17. UnitySDK/Assets/ML-Agents/Scripts/Sensor/SensorBase.cs (6 changed lines)
  18. UnitySDK/Assets/ML-Agents/Scripts/Sensor/StackingSensor.cs (7 changed lines)
  19. UnitySDK/Assets/ML-Agents/Scripts/Sensor/VectorSensor.cs (2 changed lines)
  20. UnitySDK/Assets/ML-Agents/Scripts/Sensor/WriteAdapter.cs (53 changed lines)
  21. ml-agents-envs/mlagents_envs/exception.py (8 changed lines)
  22. ml-agents-envs/mlagents_envs/rpc_utils.py (43 changed lines)
  23. ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py (43 changed lines)
  24. ml-agents/mlagents/trainers/brain.py (32 changed lines)
  25. UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/FloatVisualSensorTests.cs (105 changed lines)
  26. UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/FloatVisualSensorTests.cs.meta (3 changed lines)
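
Before the per-file diffs, a minimal sketch (not part of the commit) of how the renamed ISensor.GetObservationShape() and the shape-aware WriteAdapter.SetTarget() overload are meant to be used together; it mirrors the updated tests below, and the VectorSensor size of 4 is just an example.

// Sensors now report their shape via GetObservationShape() (renamed from
// GetFloatObservationShape()), and SetTarget() takes that shape so the adapter
// can flatten multi-dimensional writes into a plain float buffer.
ISensor sensor = new VectorSensor(4);
var buffer = new float[sensor.ObservationSize()];   // product of the shape, here 4
var writer = new WriteAdapter();
writer.SetTarget(buffer, sensor.GetObservationShape(), 0);
sensor.Write(writer);   // fills buffer (all zeros here, since nothing was added)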

2
UnitySDK/Assets/ML-Agents/Editor/Tests/MLAgentsEditModeTest.cs


sensorName = n;
}
public int[] GetFloatObservationShape()
public int[] GetObservationShape()
{
return new[] { 0 };
}

2
UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/StackingSensorTests.cs


ISensor wrapped = new VectorSensor(4);
ISensor sensor = new StackingSensor(wrapped, 4);
Assert.AreEqual("StackingSensor_size4_VectorSensor_size4", sensor.GetName());
Assert.AreEqual(sensor.GetFloatObservationShape(), new [] {16});
Assert.AreEqual(sensor.GetObservationShape(), new [] {16});
}
[Test]

2
UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/VectorSensorTests.cs


Assert.AreEqual(fill, output[0]);
WriteAdapter writer = new WriteAdapter();
writer.SetTarget(output, 0);
writer.SetTarget(output, sensor.GetObservationShape(), 0);
// Make sure WriteAdapter didn't touch anything
Assert.AreEqual(fill, output[0]);

22
UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/WriterAdapterTests.cs


{
WriteAdapter writer = new WriteAdapter();
var buffer = new[] { 0f, 0f, 0f };
var shape = new[] { 3 };
writer.SetTarget(buffer, 0);
writer.SetTarget(buffer, shape, 0);
// Elementwise writes
writer[0] = 1f;
writer[2] = 2f;

writer.SetTarget(buffer, 1);
writer.SetTarget(buffer, shape, 1);
writer.SetTarget(buffer, 0);
writer.SetTarget(buffer, shape, 0);
writer.SetTarget(buffer, 1);
writer.SetTarget(buffer, shape, 1);
writer.AddRange(new [] {6f, 7f});
Assert.AreEqual(new[] { 4f, 6f, 7f }, buffer);
}

valueType = TensorProxy.TensorType.FloatingPoint,
data = new Tensor(2, 3)
};
writer.SetTarget(t, 0, 0);
var shape = new[] { 3 };
writer.SetTarget(t, shape, 0, 0);
writer.SetTarget(t, 1, 1);
writer.SetTarget(t, shape, 1, 1);
writer[0] = 2f;
writer[1] = 3f;
// [0, 0] shouldn't change

data = new Tensor(2, 3)
};
writer.SetTarget(t, 1, 1);
writer.SetTarget(t, shape, 1, 1);
writer.AddRange(new [] {-1f, -2f});
Assert.AreEqual(0f, t.data[0, 0]);
Assert.AreEqual(0f, t.data[0, 1]);

data = new Tensor(2, 2, 2, 3)
};
writer.SetTarget(t, 0, 0);
var shape = new[] { 2, 2, 3 };
writer.SetTarget(t, shape, 0, 0);
writer.SetTarget(t, 0, 1);
writer.SetTarget(t, shape, 0, 1);
writer[1, 0, 0] = 2f;
Assert.AreEqual(2f, t.data[0, 1, 0, 1]);
}

673
UnitySDK/Assets/ML-Agents/Examples/GridWorld/Scenes/GridWorld.unity
File diff suppressed because it is too large to display.

9
UnitySDK/Assets/ML-Agents/Scripts/Agent.cs


Debug.Assert(!sensors[i].GetName().Equals(sensors[i + 1].GetName()), "Sensor names must be unique.");
}
#endif
// Create a buffer for writing vector sensor data too
// Create a buffer for writing uncompressed (i.e. float) sensor data to
int numFloatObservations = 0;
for (var i = 0; i < sensors.Count; i++)
{

var sensor = sensors[i];
if (sensor.GetCompressionType() == SensorCompressionType.None)
{
// only handles 1D
m_WriteAdapter.SetTarget(m_VectorSensorBuffer, floatsWritten);
m_WriteAdapter.SetTarget(m_VectorSensorBuffer, sensor.GetObservationShape(), floatsWritten);
Shape = sensor.GetFloatObservationShape(),
Shape = sensor.GetObservationShape(),
CompressionType = sensor.GetCompressionType()
};
m_Info.observations.Add(floatObs);

var compressedObs = new Observation
{
CompressedData = sensor.GetCompressedObservation(),
Shape = sensor.GetFloatObservationShape(),
Shape = sensor.GetObservationShape(),
CompressionType = sensor.GetCompressionType()
};
m_Info.observations.Add(compressedObs);
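
The counting loop above is truncated by the diff view; its intent, as a hedged reconstruction rather than the verbatim diff, is that every sensor whose compression type is None contributes ObservationSize() floats to one shared buffer, which each uncompressed sensor then writes into at its own offset.

// Assumed sizing step (names follow the hunk, the loop body is reconstructed):
int numFloatObservations = 0;
for (var i = 0; i < sensors.Count; i++)
{
    if (sensors[i].GetCompressionType() == SensorCompressionType.None)
    {
        numFloatObservations += sensors[i].ObservationSize();
    }
}
m_VectorSensorBuffer = new float[numFloatObservations];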

8
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs


// Write each sensor consecutively to the tensor
foreach (var sensorIndex in m_SensorIndices)
{
m_WriteAdapter.SetTarget(tensorProxy, agentIndex, tensorOffset);
var shape = sensor.GetObservationShape();
m_WriteAdapter.SetTarget(tensorProxy, shape, agentIndex, tensorOffset);
var numWritten = sensor.Write(m_WriteAdapter);
tensorOffset += numWritten;
}

var agentIndex = 0;
foreach (var agent in agents)
{
m_WriteAdapter.SetTarget(tensorProxy, agentIndex, 0);
agent.sensors[m_SensorIndex].Write(m_WriteAdapter);
var sensor = agent.sensors[m_SensorIndex];
m_WriteAdapter.SetTarget(tensorProxy, sensor.GetObservationShape(), agentIndex, 0);
sensor.Write(m_WriteAdapter);
agentIndex++;
}
}

2
UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs


for (var sensorIndex = 0; sensorIndex < agent.sensors.Count; sensorIndex++)
{
var sensor = agent.sensors[sensorIndex];
var shape = sensor.GetFloatObservationShape();
var shape = sensor.GetObservationShape();
// TODO generalize - we currently only have vector or visual, but can't handle "2D" observations
var isVectorSensor = (shape.Length == 1);
if (isVectorSensor)

4
UnitySDK/Assets/ML-Agents/Scripts/Policy/BarracudaPolicy.cs


// First agent, save the sensor sizes
foreach (var sensor in agent.sensors)
{
m_SensorShapes.Add(sensor.GetFloatObservationShape());
m_SensorShapes.Add(sensor.GetObservationShape());
}
}
else

for (var i = 0; i < m_SensorShapes.Count; i++)
{
var cachedShape = m_SensorShapes[i];
var sensorShape = agent.sensors[i].GetFloatObservationShape();
var sensorShape = agent.sensors[i].GetObservationShape();
Debug.Assert(cachedShape.Length == sensorShape.Length, "Sensor dimensions must match.");
for (var j = 0; j < cachedShape.Length; j++)
{

4
UnitySDK/Assets/ML-Agents/Scripts/Policy/RemotePolicy.cs


// First agent, save the sensor sizes
foreach (var sensor in agent.sensors)
{
m_SensorShapes.Add(sensor.GetFloatObservationShape());
m_SensorShapes.Add(sensor.GetObservationShape());
}
}
else

for (var i = 0; i < m_SensorShapes.Count; i++)
{
var cachedShape = m_SensorShapes[i];
var sensorShape = agent.sensors[i].GetFloatObservationShape();
var sensorShape = agent.sensors[i].GetObservationShape();
Debug.Assert(cachedShape.Length == sensorShape.Length, "Sensor dimensions must match.");
for (var j = 0; j < cachedShape.Length; j++)
{

9
UnitySDK/Assets/ML-Agents/Scripts/Sensor/CameraSensor.cs


bool m_Grayscale;
string m_Name;
int[] m_Shape;
SensorCompressionType m_CompressionType;
public CameraSensor(Camera camera, int width, int height, bool grayscale, string name)
public CameraSensor(Camera camera, int width, int height, bool grayscale, string name,
SensorCompressionType compression)
{
m_Camera = camera;
m_Width = width;

m_Shape = new[] { height, width, grayscale ? 1 : 3 };
m_CompressionType = compression;
}
public string GetName()

public int[] GetFloatObservationShape()
public int[] GetObservationShape()
{
return m_Shape;
}

public SensorCompressionType GetCompressionType()
{
return SensorCompressionType.PNG;
return m_CompressionType;
}
/// <summary>

3
UnitySDK/Assets/ML-Agents/Scripts/Sensor/CameraSensorComponent.cs


public int width = 84;
public int height = 84;
public bool grayscale;
public SensorCompressionType compression = SensorCompressionType.PNG;
return new CameraSensor(camera, width, height, grayscale, sensorName);
return new CameraSensor(camera, width, height, grayscale, sensorName, compression);
}
public override int[] GetObservationShape()
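
As a usage note (a sketch, not part of the diff): with the extra constructor parameter, a caller such as the component above can request raw float observations instead of PNG bytes.

// Hypothetical caller; 84x84 matches the component defaults shown above,
// and false means "not grayscale".
var camera = GetComponent<Camera>();
var cameraSensor = new CameraSensor(camera, 84, 84, false, "CameraSensor",
    SensorCompressionType.None);
// GetCompressionType() now returns None, so the Agent writes the texture as
// floats through the WriteAdapter instead of attaching compressed PNG bytes.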

4
UnitySDK/Assets/ML-Agents/Scripts/Sensor/ISensor.cs


/// A sensor that returns an RGB image would return new [] {Width, Height, 3}
/// </summary>
/// <returns></returns>
int[] GetFloatObservationShape();
int[] GetObservationShape();
/// <summary>
/// Write the observation data directly to the WriteAdapter.

/// <returns></returns>
public static int ObservationSize(this ISensor sensor)
{
var shape = sensor.GetFloatObservationShape();
var shape = sensor.GetObservationShape();
int count = 1;
for (var i = 0; i < shape.Length; i++)
{

2
UnitySDK/Assets/ML-Agents/Scripts/Sensor/RayPerceptionSensor.cs


{
}
public int[] GetFloatObservationShape()
public int[] GetObservationShape()
{
return m_Shape;
}

11
UnitySDK/Assets/ML-Agents/Scripts/Sensor/RenderTextureSensor.cs


bool m_Grayscale;
string m_Name;
int[] m_Shape;
SensorCompressionType m_CompressionType;
public RenderTextureSensor(RenderTexture renderTexture, bool grayscale, string name)
public RenderTextureSensor(RenderTexture renderTexture, bool grayscale, string name,
SensorCompressionType compressionType)
{
m_RenderTexture = renderTexture;
var width = renderTexture != null ? renderTexture.width : 0;

m_Shape = new[] { height, width, grayscale ? 1 : 3 };
m_CompressionType = compressionType;
}
public string GetName()

public int[] GetFloatObservationShape()
public int[] GetObservationShape()
{
return m_Shape;
}

public SensorCompressionType GetCompressionType()
{
return SensorCompressionType.PNG;
return m_CompressionType;
/// Converts a RenderTexture and correspinding resolution to a 2D texture.
/// Converts a RenderTexture to a 2D texture.
/// </summary>
/// <returns>The 2D texture.</returns>
/// <param name="obsTexture">RenderTexture.</param>

3
UnitySDK/Assets/ML-Agents/Scripts/Sensor/RenderTextureSensorComponent.cs


public RenderTexture renderTexture;
public string sensorName = "RenderTextureSensor";
public bool grayscale;
public SensorCompressionType compression = SensorCompressionType.PNG;
return new RenderTextureSensor(renderTexture, grayscale, sensorName);
return new RenderTextureSensor(renderTexture, grayscale, sensorName, compression);
}
public override int[] GetObservationShape()

6
UnitySDK/Assets/ML-Agents/Scripts/Sensor/SensorBase.cs


{
/// <summary>
/// Write the observations to the output buffer. The size of the buffer will be the product of the sizes returned
/// by GetFloatObservationShape().
/// by GetObservationShape().
public abstract int[] GetFloatObservationShape();
public abstract int[] GetObservationShape();
public abstract string GetName();

/// <param name="adapter"></param>
public virtual int Write(WriteAdapter adapter)
{
// TODO reuse buffer for similar agents, don't call GetFloatObservationShape()
// TODO reuse buffer for similar agents, don't call GetObservationShape()
var numFloats = this.ObservationSize();
float[] buffer = new float[numFloats];
WriteObservation(buffer);
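
For reference, a minimal concrete subclass (an illustration, assuming SensorBase keeps default implementations for the compression-related members in this SDK version): WriteObservation() only has to fill the buffer, and the default Write() above copies it into the adapter.

class ConstantSensor : SensorBase
{
    public override string GetName() { return "ConstantSensor"; }
    public override int[] GetObservationShape() { return new[] { 2 }; }
    public override void WriteObservation(float[] output)
    {
        // output.Length equals the product of GetObservationShape(), here 2.
        output[0] = 1f;
        output[1] = 2f;
    }
}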

7
UnitySDK/Assets/ML-Agents/Scripts/Sensor/StackingSensor.cs


m_Name = $"StackingSensor_size{numStackedObservations}_{wrapped.GetName()}";
var shape = wrapped.GetFloatObservationShape();
var shape = wrapped.GetObservationShape();
m_Shape = new int[shape.Length];
m_UnstackedObservationSize = wrapped.ObservationSize();

public int Write(WriteAdapter adapter)
{
// First, call the wrapped sensor's write method. Make sure to use our own adapter, not the passed one.
m_LocalAdapter.SetTarget(m_StackedObservations[m_CurrentIndex], 0);
var wrappedShape = m_WrappedSensor.GetObservationShape();
m_LocalAdapter.SetTarget(m_StackedObservations[m_CurrentIndex], wrappedShape, 0);
m_WrappedSensor.Write(m_LocalAdapter);
// Now write the saved observations (oldest first)

m_CurrentIndex = (m_CurrentIndex + 1) % m_NumStackedObservations;
}
public int[] GetFloatObservationShape()
public int[] GetObservationShape()
{
return m_Shape;
}

2
UnitySDK/Assets/ML-Agents/Scripts/Sensor/VectorSensor.cs


Clear();
}
public int[] GetFloatObservationShape()
public int[] GetObservationShape()
{
return m_Shape;
}

53
UnitySDK/Assets/ML-Agents/Scripts/Sensor/WriteAdapter.cs


using System;
using System.Collections.Generic;
using MLAgents.InferenceBrain;

TensorProxy m_Proxy;
int m_Batch;
int[] m_Shape;
/// <param name="data"></param>
/// <param name="offset"></param>
public void SetTarget(IList<float> data, int offset)
/// <param name="data">Float array or list that will be written to.</param>
/// <param name="shape">Shape of the observations to be written.</param>
/// <param name="offset">Offset from the start of the float data to write to.</param>
public void SetTarget(IList<float> data, int[] shape, int offset)
m_Batch = -1;
m_Batch = 0;
m_Shape = shape;
/// <param name="tensorProxy"></param>
/// <param name="batchIndex"></param>
/// <param name="channelOffset"></param>
public void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset)
/// <param name="tensorProxy">Tensor proxy that will be written to.</param>
/// <param name="shape">Shape of the observations to be written.</param>
/// <param name="batchIndex">Batch index in the tensor proxy (i.e. the index of the Agent)</param>
/// <param name="channelOffset">Offset from the start of the channel to write to.</param>
public void SetTarget(TensorProxy tensorProxy, int[] shape, int batchIndex, int channelOffset)
m_Shape = shape;
}
/// <summary>

{
set
{
// TODO check shape is 1D?
if (m_Data != null)
{
m_Data[index + m_Offset] = value;

{
set
{
// Only TensorProxy supports 3D access
m_Proxy.data[m_Batch, h, w, ch + m_Offset] = value;
if (m_Data != null)
{
var height = m_Shape[0];
var width = m_Shape[1];
var channels = m_Shape[2];
if (h < 0 || h >= height)
{
throw new IndexOutOfRangeException($"height value {h} must be in range [0, {height-1}]");
}
if (w < 0 || w >= width)
{
throw new IndexOutOfRangeException($"width value {w} must be in range [0, {width-1}]");
}
if (ch < 0 || ch >= channels)
{
throw new IndexOutOfRangeException($"channel value {ch} must be in range [0, {channels-1}]");
}
// Math copied from TensorShape.Index(). Note that m_Batch should always be 0
var index = m_Batch * height * width * channels + h * width * channels + w * channels + ch;
m_Data[index + m_Offset] = value;
}
else
{
m_Proxy.data[m_Batch, h, w, ch + m_Offset] = value;
}
}
}
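
To make the flattening concrete, a small worked example (illustration only, not code from the commit): for an IList target with shape {height, width, channels} and m_Batch fixed at 0, writer[h, w, ch] lands at the flat offset h * width * channels + w * channels + ch.

// Worked example for shape { 2, 2, 3 } (HWC, row-major order):
int height = 2, width = 2, channels = 3;
int FlatIndex(int h, int w, int ch) => h * width * channels + w * channels + ch;
// FlatIndex(0, 1, 2) == 5 and FlatIndex(1, 0, 0) == 6, the same ordering the
// Python side recovers with np.reshape(img, obs.shape) in rpc_utils.py.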

8
ml-agents-envs/mlagents_envs/exception.py


pass
class UnityObservationException(UnityException):
"""
Related to errors with receiving observations.
"""
pass
class UnityActionException(UnityException):
"""
Related to errors with sending actions.

43
ml-agents-envs/mlagents_envs/rpc_utils.py


from mlagents_envs.base_env import AgentGroupSpec, ActionType, BatchedStepResult
from mlagents_envs.exception import UnityObservationException
from mlagents_envs.communicator_objects.observation_pb2 import (
ObservationProto,
NONE as COMPRESSION_NONE,
)
from typing import cast, List, Tuple, Union, Collection
from typing import cast, List, Tuple, Union, Collection, Optional, Iterable
from PIL import Image
logger = logging.getLogger("mlagents_envs")

image = Image.open(io.BytesIO(image_bytearray))
# Normally Image loads lazily, this forces it to do loading in the timer scope.
image.load()
s = np.array(image) / 255.0
s = np.array(image, dtype=np.float32) / 255.0
if gray_scale:
s = np.mean(s, axis=2)
s = np.reshape(s, [s.shape[0], s.shape[1], 1])

@timed
def observation_to_np_array(
obs: ObservationProto, expected_shape: Optional[Iterable[int]] = None
) -> np.ndarray:
"""
Converts observation proto into numpy array of the appropriate size.
:param obs: observation proto to be converted
:param expected_shape: optional shape information, used for sanity checks.
:return: processed numpy array of observation from environment
"""
if expected_shape is not None:
if list(obs.shape) != list(expected_shape):
raise UnityObservationException(
f"Observation did not have the expected shape - got {obs.shape} but expected {expected_shape}"
)
gray_scale = obs.shape[2] == 1
if obs.compression_type == COMPRESSION_NONE:
img = np.array(obs.float_data.data, dtype=np.float32)
img = np.reshape(img, obs.shape)
return img
else:
img = process_pixels(obs.compressed_data, gray_scale)
# Compare decompressed image size to observation shape and make sure they match
if list(obs.shape) != list(img.shape):
raise UnityObservationException(
f"Decompressed observation did not have the expected shape - "
f"decompressed had {img.shape} but expected {obs.shape}"
)
return img
@timed
def _process_visual_observation(
obs_index: int,
shape: Tuple[int, int, int],

if len(agent_info_list) == 0:
return np.zeros((0, shape[0], shape[1], shape[2]), dtype=np.float32)
gray_scale = shape[2] == 1
process_pixels(agent_obs.observations[obs_index].compressed_data, gray_scale)
observation_to_np_array(agent_obs.observations[obs_index], shape)
for agent_obs in agent_info_list
]
return np.array(batched_visual, dtype=np.float32)

43
ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py


import io
import numpy as np
import pytest
from mlagents_envs.communicator_objects.agent_info_pb2 import AgentInfoProto
from mlagents_envs.communicator_objects.observation_pb2 import (
ObservationProto,

from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
import numpy as np
import io
from mlagents_envs.exception import UnityObservationException
from mlagents_envs.rpc_utils import (
agent_group_spec_from_proto,
process_pixels,

return obs_proto
def generate_uncompressed_proto_obs(in_array: np.ndarray) -> ObservationProto:
obs_proto = ObservationProto()
obs_proto.float_data.data.extend(in_array.flatten().tolist())
obs_proto.compression_type = NONE
obs_proto.shape.extend(in_array.shape)
return obs_proto
in_array = np.random.rand(128, 128, 3)
in_array = np.random.rand(128, 64, 3)
assert out_array.shape == (128, 128, 3)
assert out_array.shape == (128, 64, 3)
in_array = np.random.rand(128, 128, 3)
in_array = np.random.rand(128, 64, 3)
assert out_array.shape == (128, 128, 1)
assert out_array.shape == (128, 64, 1)
assert np.mean(in_array.mean(axis=2, keepdims=True) - out_array) < 0.01
assert (in_array.mean(axis=2, keepdims=True) - out_array < 0.01).all()

def test_process_visual_observation():
in_array_1 = np.random.rand(128, 128, 3)
in_array_1 = np.random.rand(128, 64, 3)
in_array_2 = np.random.rand(128, 128, 3)
proto_obs_2 = generate_compressed_proto_obs(in_array_2)
in_array_2 = np.random.rand(128, 64, 3)
proto_obs_2 = generate_uncompressed_proto_obs(in_array_2)
arr = _process_visual_observation(0, (128, 128, 3), ap_list)
assert list(arr.shape) == [2, 128, 128, 3]
arr = _process_visual_observation(0, (128, 64, 3), ap_list)
assert list(arr.shape) == [2, 128, 64, 3]
def test_process_visual_observation_bad_shape():
in_array_1 = np.random.rand(128, 64, 3)
proto_obs_1 = generate_compressed_proto_obs(in_array_1)
ap1 = AgentInfoProto()
ap1.observations.extend([proto_obs_1])
ap_list = [ap1]
with pytest.raises(UnityObservationException):
_process_visual_observation(0, (128, 42, 3), ap_list)
def test_batched_step_result_from_proto():

32
ml-agents/mlagents/trainers/brain.py


import logging
import numpy as np
import io
from mlagents_envs.timers import hierarchical_timer, timed
from mlagents_envs.timers import timed
from mlagents_envs import rpc_utils
from PIL import Image
logger = logging.getLogger("mlagents.trainers")

@staticmethod
@timed
def process_pixels(image_bytes: bytes, gray_scale: bool) -> np.ndarray:
"""
Converts byte array observation image into numpy array, re-sizes it,
and optionally converts it to grey scale
:param gray_scale: Whether to convert the image to grayscale.
:param image_bytes: input byte array corresponding to image
:return: processed numpy array of observation from environment
"""
with hierarchical_timer("image_decompress"):
image_bytearray = bytearray(image_bytes)
image = Image.open(io.BytesIO(image_bytearray))
# Normally Image loads lazily, this forces it to do loading in the timer scope.
image.load()
s = np.array(image) / 255.0
if gray_scale:
s = np.mean(s, axis=2)
s = np.reshape(s, [s.shape[0], s.shape[1], 1])
return s
@staticmethod
@timed
def from_agent_proto(
worker_id: int,
agent_info_list: Collection[

vis_obs: List[np.ndarray] = []
for i in range(brain_params.number_visual_observations):
# TODO check compression type, handle uncompressed visuals
BrainInfo.process_pixels(
agent_obs[i].compressed_data,
brain_params.camera_resolutions[i].gray_scale,
rpc_utils.observation_to_np_array(
agent_obs[i], brain_params.camera_resolutions[i]
)
for agent_obs in visual_observation_protos
]

105
UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/FloatVisualSensorTests.cs


using NUnit.Framework;
using UnityEngine;
using MLAgents.Sensor;
namespace MLAgents.Tests
{
public class Float2DSensor : ISensor
{
public int Width { get; }
public int Height { get; }
string m_Name;
int[] m_Shape;
public float[,] floatData;
public Float2DSensor(int width, int height, string name)
{
Width = width;
Height = height;
m_Name = name;
m_Shape = new[] { height, width, 1 };
floatData = new float[Height, Width];
}
public Float2DSensor(float[,] floatData, string name)
{
this.floatData = floatData;
Height = floatData.GetLength(0);
Width = floatData.GetLength(1);
m_Name = name;
m_Shape = new[] { Height, Width, 1 };
}
public string GetName()
{
return m_Name;
}
public int[] GetObservationShape()
{
return m_Shape;
}
public byte[] GetCompressedObservation()
{
return null;
}
public int Write(WriteAdapter adapter)
{
using (TimerStack.Instance.Scoped("Float2DSensor.Write"))
{
for (var h = 0; h < Height; h++)
{
for (var w = 0; w < Width; w++)
{
adapter[h, w, 0] = floatData[h, w];
}
}
var numWritten = Height * Width;
return numWritten;
}
}
public void Update() { }
public SensorCompressionType GetCompressionType()
{
return SensorCompressionType.None;
}
}
public class FloatVisualSensorTests
{
[Test]
public void TestFloat2DSensorWrite()
{
var sensor = new Float2DSensor(3, 4, "floatsensor");
for (var h = 0; h < 4; h++)
{
for (var w = 0; w < 3; w++)
{
sensor.floatData[h, w] = 3 * h + w;
}
}
var output = new float[12];
var writer = new WriteAdapter();
writer.SetTarget(output, sensor.GetObservationShape(), 0);
sensor.Write(writer);
for (var i = 0; i < 9; i++)
{
Assert.AreEqual(i, output[i]);
}
}
[Test]
public void TestFloat2DSensorExternalData()
{
var data = new float[4, 3];
var sensor = new Float2DSensor(data, "floatsensor");
Assert.AreEqual(sensor.Height, 4);
Assert.AreEqual(sensor.Width, 3);
}
}
}

3
UnitySDK/Assets/ML-Agents/Editor/Tests/Sensor/FloatVisualSensorTests.cs.meta


fileFormatVersion: 2
guid: 49b7da14949a486b803e28ed32d91a09
timeCreated: 1578093005