浏览代码

write observations directly to protobuf (#3229)

* write observations directly to protobuf

* docstring and comment about Capacity
/asymm-envs
GitHub 5 年前
当前提交
6451f564
共有 6 个文件被更改,包括 74 次插入55 次删除
  1. 3
      UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs
  2. 23
      UnitySDK/Assets/ML-Agents/Scripts/Agent.cs
  3. 4
      UnitySDK/Assets/ML-Agents/Scripts/DemonstrationRecorder.cs
  4. 12
      UnitySDK/Assets/ML-Agents/Scripts/DemonstrationStore.cs
  5. 58
      UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs
  6. 29
      UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs

3
UnitySDK/Assets/ML-Agents/Editor/Tests/DemonstrationTests.cs


using System.IO.Abstractions.TestingHelpers;
using System.Reflection;
using MLAgents.CommunicatorObjects;
using MLAgents.Sensor;
namespace MLAgents.Tests
{

storedVectorActions = new[] { 0f, 1f },
};
demoStore.Record(agentInfo, new System.Collections.Generic.List<Sensor.Observation>());
demoStore.Record(agentInfo, new System.Collections.Generic.List<ISensor>());
demoStore.Close();
}

23
UnitySDK/Assets/ML-Agents/Scripts/Agent.cs


/// </summary>
public VectorSensor collectObservationsSensor;
/// <summary>
/// Internal buffer used for generating float observations.
/// </summary>
float[] m_VectorSensorBuffer;
WriteAdapter m_WriteAdapter = new WriteAdapter();
/// MonoBehaviour function that is called when the attached GameObject
/// becomes enabled or active.
void OnEnable()

}
m_Info.actionMasks = m_ActionMasker.GetMask();
// var param = m_PolicyFactory.brainParameters; // look, no brain params!
m_Info.reward = m_Reward;
m_Info.done = m_Done;
m_Info.maxStepReached = m_MaxStepReached;

if (m_Recorder != null && m_Recorder.record && Application.isEditor)
{
if (m_VectorSensorBuffer == null)
{
// Create a buffer for writing uncompressed (i.e. float) sensor data to
m_VectorSensorBuffer = new float[sensors.GetSensorFloatObservationSize()];
}
// This is a bit of a hack - if we're in inference mode, observations won't be generated
// But we need these to be generated for the recorder. So generate them here.
var observations = new List<Observation>();
GenerateSensorData(sensors, m_VectorSensorBuffer, m_WriteAdapter, observations);
m_Recorder.WriteExperience(m_Info, observations);
m_Recorder.WriteExperience(m_Info, sensors);
}
}

4
UnitySDK/Assets/ML-Agents/Scripts/DemonstrationRecorder.cs


/// <summary>
/// Forwards AgentInfo to Demonstration Store.
/// </summary>
public void WriteExperience(AgentInfo info, List<Observation> observations)
public void WriteExperience(AgentInfo info, List<ISensor> sensors)
m_DemoStore.Record(info, observations);
m_DemoStore.Record(info, sensors);
}
public void Close()

12
UnitySDK/Assets/ML-Agents/Scripts/DemonstrationStore.cs


DemonstrationMetaData m_MetaData;
Stream m_Writer;
float m_CumulativeReward;
WriteAdapter m_WriteAdapter = new WriteAdapter();
public DemonstrationStore(IFileSystem fileSystem)
{

/// <summary>
/// Write AgentInfo experience to file.
/// </summary>
public void Record(AgentInfo info, List<Observation> observations)
public void Record(AgentInfo info, List<ISensor> sensors)
{
// Increment meta-data counters.
m_MetaData.numberExperiences++;

EndEpisode();
}
// Write AgentInfo to file.
var agentProto = info.ToInfoActionPairProto(observations);
// Generate observations and add AgentInfo to file.
var agentProto = info.ToInfoActionPairProto();
foreach (var sensor in sensors)
{
agentProto.AgentInfo.Observations.Add(sensor.GetObservationProto(m_WriteAdapter));
}
agentProto.WriteDelimitedTo(m_Writer);
}

58
UnitySDK/Assets/ML-Agents/Scripts/Grpc/GrpcExtensions.cs


/// Converts a AgentInfo to a protobuf generated AgentInfoActionPairProto
/// </summary>
/// <returns>The protobuf version of the AgentInfoActionPairProto.</returns>
public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai, List<Observation> observations)
public static AgentInfoActionPairProto ToInfoActionPairProto(this AgentInfo ai)
var agentInfoProto = ai.ToAgentInfoProto(observations);
var agentInfoProto = ai.ToAgentInfoProto();
var agentActionProto = new AgentActionProto
{

/// Converts a AgentInfo to a protobuf generated AgentInfoProto
/// </summary>
/// <returns>The protobuf version of the AgentInfo.</returns>
public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai, List<Observation> observations)
public static AgentInfoProto ToAgentInfoProto(this AgentInfo ai)
{
var agentInfoProto = new AgentInfoProto
{

if (ai.actionMasks != null)
{
agentInfoProto.ActionMask.AddRange(ai.actionMasks);
}
if (observations != null)
{
foreach (var obs in observations)
{
agentInfoProto.Observations.Add(obs.ToProto());
}
}
return agentInfoProto;

obsProto.Shape.AddRange(obs.Shape);
return obsProto;
}
/// <summary>
/// Generate an ObservationProto for the sensor using the provided WriteAdapter.
/// This is equivalent to producing an Observation and calling Observation.ToProto(),
/// but avoid some intermediate memory allocations.
/// </summary>
/// <param name="sensor"></param>
/// <param name="writeAdapter"></param>
/// <returns></returns>
public static ObservationProto GetObservationProto(this ISensor sensor, WriteAdapter writeAdapter)
{
var shape = sensor.GetObservationShape();
ObservationProto observationProto = null;
if (sensor.GetCompressionType() == SensorCompressionType.None)
{
var numFloats = sensor.ObservationSize();
var floatDataProto = new ObservationProto.Types.FloatData();
// Resize the float array
// TODO upgrade protobuf versions so that we can set the Capacity directly - see https://github.com/protocolbuffers/protobuf/pull/6530
for (var i = 0; i < numFloats; i++)
{
floatDataProto.Data.Add(0.0f);
}
writeAdapter.SetTarget(floatDataProto.Data, sensor.GetObservationShape(), 0);
sensor.Write(writeAdapter);
observationProto = new ObservationProto
{
FloatData = floatDataProto,
CompressionType = (CompressionTypeProto)SensorCompressionType.None,
};
}
else
{
observationProto = new ObservationProto
{
CompressedData = ByteString.CopyFrom(sensor.GetCompressedObservation()),
CompressionType = (CompressionTypeProto)sensor.GetCompressionType(),
};
}
observationProto.Shape.AddRange(shape);
return observationProto;
}
}
}

29
UnitySDK/Assets/ML-Agents/Scripts/Grpc/RpcCommunicator.cs


List<string> m_BehaviorNames = new List<string>();
bool m_NeedCommunicateThisStep;
float[] m_VectorObservationBuffer = new float[0];
List<Observation> m_ObservationBuffer = new List<Observation>();
WriteAdapter m_WriteAdapter = new WriteAdapter();
Dictionary<string, SensorShapeValidator> m_SensorShapeValidators = new Dictionary<string, SensorShapeValidator>();
Dictionary<string, List<IdCallbackPair>> m_ActionCallbacks = new Dictionary<string, List<IdCallbackPair>>();

}
/// <summary>
/// Sends the observations of one Agent.
/// Sends the observations of one Agent.
int numFloatObservations = sensors.GetSensorFloatObservationSize();
if (m_VectorObservationBuffer.Length < numFloatObservations)
{
m_VectorObservationBuffer = new float[numFloatObservations];
}
# if DEBUG
if (!m_SensorShapeValidators.ContainsKey(brainKey))
{

#endif
using (TimerStack.Instance.Scoped("GenerateSensorData"))
{
Agent.GenerateSensorData(sensors, m_VectorObservationBuffer, m_WriteAdapter, m_ObservationBuffer);
}
var agentInfoProto = info.ToAgentInfoProto(m_ObservationBuffer);
var agentInfoProto = info.ToAgentInfoProto();
using (TimerStack.Instance.Scoped("GenerateSensorData"))
{
foreach (var sensor in sensors)
{
var obsProto = sensor.GetObservationProto(m_WriteAdapter);
agentInfoProto.Observations.Add(obsProto);
}
}
m_ObservationBuffer.Clear();
m_NeedCommunicateThisStep = true;
if (!m_ActionCallbacks.ContainsKey(brainKey))
{

#region Handling side channels
/// <summary>
/// Registers a side channel to the communicator. The side channel will exchange
/// Registers a side channel to the communicator. The side channel will exchange
/// messages with its Python equivalent.
/// </summary>
/// <param name="sideChannel"> The side channel to be registered.</param>

正在加载...
取消
保存