
Merge branch 'master' into soccer-fives

/soccer-fives
Andrew Cohen, 4 years ago
Current commit
d40ff1ff
18 files changed: 762 insertions and 376 deletions
1. com.unity.ml-agents/CHANGELOG.md (1)
2. com.unity.ml-agents/Runtime/Agent.cs (24)
3. com.unity.ml-agents/Runtime/Demonstration.cs (4)
4. com.unity.ml-agents/Runtime/DemonstrationRecorder.cs (130)
5. com.unity.ml-agents/Runtime/DemonstrationStore.cs (102)
6. com.unity.ml-agents/Runtime/DemonstrationStore.cs.meta (14)
7. com.unity.ml-agents/Runtime/InferenceBrain/TensorProxy.cs (2)
8. com.unity.ml-agents/Runtime/Sensor/ISensor.cs (4)
9. com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensor.cs (534)
10. com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent2D.cs (4)
11. com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent3D.cs (4)
12. com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponentBase.cs (38)
13. com.unity.ml-agents/Runtime/Sensor/WriteAdapter.cs (6)
14. com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs (44)
15. docs/Migrating.md (3)
16. docs/Training-Imitation-Learning.md (2)
17. gym-unity/gym_unity/envs/__init__.py (145)
18. gym-unity/gym_unity/tests/test_gym.py (77)

com.unity.ml-agents/CHANGELOG.md (1 change)

 - A tutorial on adding custom SideChannels was added (#3391)
 - The stepping logic for the Agent and the Academy has been simplified (#3448)
 - Update Barracuda to 0.6.0-preview
+- The interface for `RayPerceptionSensor.PerceiveStatic()` was changed to take an input class and write to an output class.
 - The checkpoint file suffix was changed from `.cptk` to `.ckpt` (#3470)
 - The command-line argument used to determine the port that an environment will listen on was changed from `--port` to `--mlagents-port`.
 - The method `GetStepCount()` on the Agent class has been replaced with the property getter `StepCount`

com.unity.ml-agents/Runtime/Agent.cs (24 changes)

 /// Struct that contains all the information for an Agent, including its
 /// observations, actions and current status, that is sent to the Brain.
 /// </summary>
-public struct AgentInfo
+internal struct AgentInfo
 {
 /// <summary>
 /// Keeps track of the last vector action taken by the Brain.

 /// Whether or not the agent requests a decision.
 bool m_RequestDecision;

 /// Keeps track of the number of steps taken by the agent in this episode.
 /// Note that this value is different for each agent, and may not overlap
 /// with the step counter in the Academy, since agents reset based on

 ActionMasker m_ActionMasker;

 /// <summary>
-/// Demonstration recorder.
+/// Set of DemonstrationStores that the Agent will write its step information to.
+/// If you use a DemonstrationRecorder component, this will automatically register its DemonstrationStore.
+/// You can also add your own DemonstrationStore by calling DemonstrationRecorder.AddDemonstrationStoreToAgent()
-DemonstrationRecorder m_Recorder;
+internal ISet<DemonstrationStore> DemonstrationStores = new HashSet<DemonstrationStore>();

 /// <summary>
 /// List of sensors used to generate observations.

 // Grab the "static" properties for the Agent.
 m_EpisodeId = EpisodeIdCounter.GetEpisodeId();
 m_PolicyFactory = GetComponent<BehaviorParameters>();
-m_Recorder = GetComponent<DemonstrationRecorder>();
 m_Info = new AgentInfo();
 m_Action = new AgentAction();

 /// becomes disabled or inactive.
 void OnDisable()
 {
+DemonstrationStores.Clear();
 // If Academy.Dispose has already been called, we don't need to unregister with it.
 // We don't want to even try, because this will lazily create a new Academy!
 if (Academy.IsInitialized)

 // We request a decision so Python knows the Agent is done immediately
 m_Brain?.RequestDecision(m_Info, sensors);
-if (m_Recorder != null && m_Recorder.record && Application.isEditor)
-m_Recorder.WriteExperience(m_Info, sensors);
+// We also have to write to any DemonstrationStores so that they get the "done" flag.
+foreach (var demoWriter in DemonstrationStores)
+demoWriter.Record(m_Info, sensors);
 }
 UpdateRewardStats();

 m_Brain.RequestDecision(m_Info, sensors);
-if (m_Recorder != null && m_Recorder.record && Application.isEditor)
-m_Recorder.WriteExperience(m_Info, sensors);
+// If we have any DemonstrationStores, write the AgentInfo and sensors to them.
+foreach (var demoWriter in DemonstrationStores)
+demoWriter.Record(m_Info, sensors);
 }
 }

 sensors[i].Update();
 }
 }

 /// <summary>
 /// Collects the vector observations of the agent.
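The hunks above replace the Agent's single DemonstrationRecorder reference with a set of DemonstrationStores that it writes to on each decision and when an episode ends. A minimal sketch of registering an extra store through the new public methods, assuming a hypothetical helper component placed next to the Agent and its DemonstrationRecorder (the file path is made up, and DemonstrationStore.Close() is assumed to be reachable from user code):

```csharp
using System.IO;
using MLAgents;
using UnityEngine;

// Hypothetical helper; attach alongside an Agent and a DemonstrationRecorder.
public class ExtraDemoStoreSetup : MonoBehaviour
{
    DemonstrationStore m_ExtraStore;

    void Start()
    {
        // We own this stream, so per the docs above we must Close() the store
        // ourselves when recording is done.
        var stream = File.Create("extra.demo"); // hypothetical path
        m_ExtraStore = new DemonstrationStore(stream);
        GetComponent<DemonstrationRecorder>().AddDemonstrationStoreToAgent(m_ExtraStore);
    }

    void OnDestroy()
    {
        m_ExtraStore?.Close();
    }
}
```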

com.unity.ml-agents/Runtime/Demonstration.cs (4 changes)

 /// Used for imitation learning, or other forms of learning from data.
 /// </summary>
 [Serializable]
-public class Demonstration : ScriptableObject
+internal class Demonstration : ScriptableObject
 {
 public DemonstrationMetaData metaData;
 public BrainParameters brainParameters;

 /// Kept in a struct for easy serialization and deserialization.
 /// </summary>
 [Serializable]
-public class DemonstrationMetaData
+internal class DemonstrationMetaData
 {
 public int numberExperiences;
 public int numberEpisodes;

com.unity.ml-agents/Runtime/DemonstrationRecorder.cs (130 changes)

 using System.IO.Abstractions;
 using System.Text.RegularExpressions;
 using UnityEngine;
+using System.Collections.Generic;
+using System.IO;

 namespace MLAgents
 {

 [AddComponentMenu("ML Agents/Demonstration Recorder", (int)MenuGroup.Default)]
 public class DemonstrationRecorder : MonoBehaviour
 {
 [Tooltip("Whether or not to record demonstrations.")]

 [Tooltip("Base demonstration file name. Will have numbers appended to make unique.")]
-string m_FilePath;
+[Tooltip("Base directory to write the demo files. If null, will use {Application.dataPath}/Demonstrations.")]
+public string demonstrationDirectory;

-public const int MaxNameLength = 16;
+internal const int MaxNameLength = 16;
 const string k_ExtensionType = ".demo";
 IFileSystem m_FileSystem;
+Agent m_Agent;

-void Start()
+void OnEnable()
-if (Application.isEditor && record)
-{
-InitializeDemoStore();
-}
+m_Agent = GetComponent<Agent>();

-if (Application.isEditor && record && m_DemoStore == null)
-InitializeDemoStore();
+if (record)
+LazyInitialize();

+/// Has no effect if the demonstration store was already created.
-public void InitializeDemoStore(IFileSystem fileSystem = null)
+internal DemonstrationStore LazyInitialize(IFileSystem fileSystem = null)
-m_DemoStore = new DemonstrationStore(fileSystem);
+if (m_DemoStore != null)
+{
+return m_DemoStore;
+}
+if (m_Agent == null)
+{
+m_Agent = GetComponent<Agent>();
+}
+m_FileSystem = fileSystem ?? new FileSystem();
+if (string.IsNullOrEmpty(demonstrationName))
+{
+demonstrationName = behaviorParams.behaviorName;
+}
+if (string.IsNullOrEmpty(demonstrationDirectory))
+{
+demonstrationDirectory = Path.Combine(Application.dataPath, "Demonstrations");
+}
+var filePath = MakeDemonstrationFilePath(m_FileSystem, demonstrationDirectory, demonstrationName);
+var stream = m_FileSystem.File.Create(filePath);
+m_DemoStore = new DemonstrationStore(stream);

-behaviorParams.fullyQualifiedBehaviorName);
+behaviorParams.fullyQualifiedBehaviorName
+);
+AddDemonstrationStoreToAgent(m_DemoStore);
+return m_DemoStore;
+}

 /// <summary>

-public static string SanitizeName(string demoName, int maxNameLength)
+internal static string SanitizeName(string demoName, int maxNameLength)
 {
 var rgx = new Regex("[^a-zA-Z0-9 -]");
 demoName = rgx.Replace(demoName, "");

 }

 /// <summary>
-/// Forwards AgentInfo to Demonstration Store.
+/// Gets a unique path for the demonstrationName in the demonstrationDirectory.
-public void WriteExperience(AgentInfo info, List<ISensor> sensors)
-m_DemoStore?.Record(info, sensors);
+/// <param name="fileSystem"></param>
+/// <param name="demonstrationDirectory"></param>
+/// <param name="demonstrationName"></param>
+/// <returns></returns>
+internal static string MakeDemonstrationFilePath(
+IFileSystem fileSystem, string demonstrationDirectory, string demonstrationName
+)
+// Create the directory if it doesn't already exist
+if (!fileSystem.Directory.Exists(demonstrationDirectory))
+{
+fileSystem.Directory.CreateDirectory(demonstrationDirectory);
+}
+var literalName = demonstrationName;
+var filePath = Path.Combine(demonstrationDirectory, literalName + k_ExtensionType);
+var uniqueNameCounter = 0;
+while (fileSystem.File.Exists(filePath))
+{
+// TODO should we use a timestamp instead of a counter here? This loops an increasing number of times
+// as the number of demos increases.
+literalName = demonstrationName + "_" + uniqueNameCounter;
+filePath = Path.Combine(demonstrationDirectory, literalName + k_ExtensionType);
+uniqueNameCounter++;
+}
+return filePath;

+/// <summary>
+/// Close the DemonstrationStore and remove it from the Agent.
+/// Has no effect if the DemonstrationStore is already closed (or wasn't opened)
+/// </summary>
+RemoveDemonstrationStoreFromAgent(m_DemoStore);
 m_DemoStore.Close();
 m_DemoStore = null;
 }

-/// Closes Demonstration store.
+/// Clean up the DemonstrationStore when shutting down or destroying the Agent.
-void OnApplicationQuit()
+void OnDestroy()
-if (Application.isEditor && record)
-{
-Close();
-}
+Close();
 }

+/// <summary>
+/// Add additional DemonstrationStore to the Agent. It is still up to the user to Close this
+/// DemonstrationStore when recording is done.
+/// </summary>
+/// <param name="demoStore"></param>
+public void AddDemonstrationStoreToAgent(DemonstrationStore demoStore)
+{
+m_Agent.DemonstrationStores.Add(demoStore);
+}

+/// <summary>
+/// Remove an additional DemonstrationStore from the Agent. It is still up to the user to Close this
+/// DemonstrationStore when recording is done.
+/// </summary>
+/// <param name="demoStore"></param>
+public void RemoveDemonstrationStoreFromAgent(DemonstrationStore demoStore)
+{
+m_Agent.DemonstrationStores.Remove(demoStore);
+}
 }
 }
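For illustration, the collision handling in MakeDemonstrationFilePath appends an incrementing counter; a sketch of the paths it produces (the method is internal, so this is descriptive rather than user-callable):

```csharp
// With demonstrationDirectory = "Assets/Demonstrations" and demonstrationName = "Test":
//   Assets/Demonstrations/Test.demo     // first recording, no collision
//   Assets/Demonstrations/Test_0.demo   // "Test.demo" already exists
//   Assets/Demonstrations/Test_1.demo   // "Test.demo" and "Test_0.demo" exist
```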

com.unity.ml-agents/Runtime/DemonstrationStore.cs (102 changes)

 using System.IO;
-using System.IO.Abstractions;
 using Google.Protobuf;
 using System.Collections.Generic;

-/// Responsible for writing demonstration data to file.
+/// Responsible for writing demonstration data to stream (usually a file stream).

-readonly IFileSystem m_FileSystem;
-const string k_DemoDirectory = "Assets/Demonstrations/";
-const string k_ExtensionType = ".demo";
-string m_FilePath;

-public DemonstrationStore(IFileSystem fileSystem)
+/// <summary>
+/// Create a DemonstrationStore that will write to the specified stream.
+/// The stream must support writes and seeking.
+/// </summary>
+/// <param name="stream"></param>
+public DemonstrationStore(Stream stream)
-if (fileSystem != null)
-{
-m_FileSystem = fileSystem;
-}
-else
-{
-m_FileSystem = new FileSystem();
-}
+m_Writer = stream;

-/// Initializes the Demonstration Store, and writes initial data.
+/// Writes the initial data to the stream.
-CreateDirectory();
-CreateDemonstrationFile(demonstrationName);
-WriteBrainParameters(brainName, brainParameters);
-}
-/// <summary>
-/// Checks for the existence of the Demonstrations directory
-/// and creates it if it does not exist.
-/// </summary>
-void CreateDirectory()
-{
-if (!m_FileSystem.Directory.Exists(k_DemoDirectory))
-m_FileSystem.Directory.CreateDirectory(k_DemoDirectory);
+if (m_Writer == null)
+// Already closed
+return;
+m_MetaData = new DemonstrationMetaData { demonstrationName = demonstrationName };
+var metaProto = m_MetaData.ToProto();
+metaProto.WriteDelimitedTo(m_Writer);
+WriteBrainParameters(brainName, brainParameters);

-/// Creates demonstration file.
+/// Writes meta-data. Note that this is called at the *end* of recording, but writes to the
+/// beginning of the file.
-void CreateDemonstrationFile(string demonstrationName)
+void WriteMetadata()
-// Creates demonstration file.
-var literalName = demonstrationName;
-m_FilePath = k_DemoDirectory + literalName + k_ExtensionType;
-var uniqueNameCounter = 0;
-while (m_FileSystem.File.Exists(m_FilePath))
-literalName = demonstrationName + "_" + uniqueNameCounter;
-m_FilePath = k_DemoDirectory + literalName + k_ExtensionType;
-uniqueNameCounter++;
-m_Writer = m_FileSystem.File.Create(m_FilePath);
-m_MetaData = new DemonstrationMetaData { demonstrationName = demonstrationName };
+if (m_Writer == null)
+// Already closed
+return;
+var metaProtoBytes = metaProto.ToByteArray();
+m_Writer.Write(metaProtoBytes, 0, metaProtoBytes.Length);
+m_Writer.Seek(0, 0);
+metaProto.WriteDelimitedTo(m_Writer);
 }

 void WriteBrainParameters(string brainName, BrainParameters brainParameters)
 {
+if (m_Writer == null)
+{
+// Already closed
+return;
+}
 // Writes BrainParameters to file.
 m_Writer.Seek(MetaDataBytes + 1, 0);
 var brainProto = brainParameters.ToProto(brainName, false);

 /// <summary>
 /// Write AgentInfo experience to file.
 /// </summary>
-public void Record(AgentInfo info, List<ISensor> sensors)
+internal void Record(AgentInfo info, List<ISensor> sensors)
+if (m_Writer == null)
+{
+// Already closed
+return;
+}
 // Increment meta-data counters.
 m_MetaData.numberExperiences++;
 m_CumulativeReward += info.reward;

 agentProto.WriteDelimitedTo(m_Writer);
 }

+if (m_Writer == null)
+{
+// Already closed
+return;
+}
+m_Writer = null;
 }

 /// <summary>

 {
 m_MetaData.numberEpisodes += 1;
 }

-/// <summary>
-/// Writes meta-data.
-/// </summary>
-void WriteMetadata()
-{
-var metaProto = m_MetaData.ToProto();
-var metaProtoBytes = metaProto.ToByteArray();
-m_Writer.Write(metaProtoBytes, 0, metaProtoBytes.Length);
-m_Writer.Seek(0, 0);
-metaProto.WriteDelimitedTo(m_Writer);
-}
 }
 }
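The store writes its metadata twice: a delimited placeholder at Initialize() time, and the final counts at Close() time, when WriteMetadata() seeks back to offset 0. That is why the constructor documents that the stream must support seeking. A minimal sketch, assuming a MemoryStream (visibility of Initialize/Record/Close outside the package is an assumption):

```csharp
using System.IO;
using MLAgents;

static class DemoStoreSketch
{
    static void Example()
    {
        // A MemoryStream is seekable, so it satisfies the constructor's contract;
        // a writable FileStream works the same way.
        var stream = new MemoryStream();
        var store = new DemonstrationStore(stream);
        // Initialize(...) writes placeholder metadata and BrainParameters at the
        // head of the stream; Record(...) appends one experience per agent step;
        // Close() seeks back to offset 0 and rewrites the final metadata.
    }
}
```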

com.unity.ml-agents/Runtime/DemonstrationStore.cs.meta (14 changes)

-fileFormatVersion: 2
-guid: a79c7ccb2cd042b5b1e710b9588d921b
-timeCreated: 1537388072
+fileFormatVersion: 2
+guid: a79c7ccb2cd042b5b1e710b9588d921b
+MonoImporter:
+  externalObjects: {}
+  serializedVersion: 2
+  defaultReferences: []
+  executionOrder: 0
+  icon: {instanceID: 0}
+  userData:
+  assetBundleName:
+  assetBundleVariant:

com.unity.ml-agents/Runtime/InferenceBrain/TensorProxy.cs (2 changes)

 /// allowing the user to specify everything but the data in a graphical way.
 /// </summary>
 [Serializable]
-public class TensorProxy
+internal class TensorProxy
 {
 public enum TensorType
 {

com.unity.ml-agents/Runtime/Sensor/ISensor.cs (4 changes)

 /// Note that this (and GetCompressedObservation) may be called multiple times per agent step, so should not
 /// mutate any internal state.
 /// </summary>
-/// <param name="adapater"></param>
+/// <param name="adapter"></param>
-int Write(WriteAdapter adapater);
+int Write(WriteAdapter adapter);

 /// <summary>
 /// Return a compressed representation of the observation. For small observations, this should generally not be
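As a sketch of the contract described above, a sensor's Write() should produce the same output however many times it runs within a step, and defer state changes to Update(). The interface members other than Write() are assumptions based on the surrounding package, not shown in this diff:

```csharp
using MLAgents.Sensor;

// Hypothetical sensor reporting a fixed observation vector.
class ConstantSensor : ISensor
{
    readonly float[] m_Values = { 0.5f, 1.0f, 1.5f };

    public int[] GetObservationShape() { return new[] { m_Values.Length }; }

    public int Write(WriteAdapter adapter)
    {
        // No internal state is mutated, so repeated calls in one agent step
        // produce identical output.
        adapter.AddRange(m_Values);
        return m_Values.Length;
    }

    public byte[] GetCompressedObservation() { return null; }
    public void Update() { }
    public SensorCompressionType GetCompressionType() { return SensorCompressionType.None; }
    public string GetName() { return "Constant"; }
}
```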

com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensor.cs (534 changes)

 namespace MLAgents
 {
-public class RayPerceptionSensor : ISensor
+/// <summary>
+/// Determines which dimensions the sensor will perform the casts in.
+/// </summary>
+public enum RayPerceptionCastType
+{
+/// Cast in 2 dimensions, using Physics2D.CircleCast or Physics2D.RayCast.
+Cast2D,
+/// Cast in 3 dimensions, using Physics.SphereCast or Physics.RayCast.
+Cast3D,
+}

+public struct RayPerceptionInput
+{
-public enum CastType
-{
-Cast2D,
-Cast3D,
-}
+/// <summary>
+/// Length of the rays to cast. This will be scaled up or down based on the scale of the transform.
+/// </summary>
+public float rayLength;

+/// <summary>
+/// List of tags which correspond to object types agent can see.
+/// </summary>
+public IReadOnlyList<string> detectableTags;

+/// <summary>
+/// List of angles (in degrees) used to define the rays.
+/// 90 degrees is considered "forward" relative to the game object.
+/// </summary>
+public IReadOnlyList<float> angles;

+/// <summary>
+/// Starting height offset of ray from center of agent
+/// </summary>
+public float startOffset;

-float[] m_Observations;
-int[] m_Shape;
-string m_Name;

+/// <summary>
+/// Ending height offset of ray from center of agent.
+/// </summary>
+public float endOffset;

-float m_RayDistance;
-List<string> m_DetectableObjects;
-float[] m_Angles;

+/// <summary>
+/// Radius of the sphere to use for spherecasting.
+/// If 0 or less, rays are used instead - this may be faster, especially for complex environments.
+/// </summary>
+public float castRadius;

-float m_StartOffset;
-float m_EndOffset;
-float m_CastRadius;
-CastType m_CastType;
-Transform m_Transform;
-int m_LayerMask;

+/// <summary>
+/// Transform of the GameObject.
+/// </summary>
+public Transform transform;

-/// Debug information for the raycast hits. This is used by the RayPerceptionSensorComponent.
-public class DebugDisplayInfo
+/// Whether to perform the casts in 2D or 3D.
+public RayPerceptionCastType castType;

+/// <summary>
+/// Filtering options for the casts.
+/// </summary>
+public int layerMask;

+/// <summary>
+/// Returns the expected number of floats in the output.
+/// </summary>
+/// <returns></returns>
+public int OutputSize()
-public struct RayInfo
+return (detectableTags.Count + 2) * angles.Count;
+}

+/// <summary>
+/// Get the cast start and end points for the given ray index.
+/// </summary>
+/// <param name="rayIndex"></param>
+/// <returns>A tuple of the start and end positions in world space.</returns>
+public (Vector3 StartPositionWorld, Vector3 EndPositionWorld) RayExtents(int rayIndex)
+{
+var angle = angles[rayIndex];
+Vector3 startPositionLocal, endPositionLocal;
+if (castType == RayPerceptionCastType.Cast3D)
-public Vector3 localStart;
-public Vector3 localEnd;
-public Vector3 worldStart;
-public Vector3 worldEnd;
-public bool castHit;
-public float hitFraction;
-public float castRadius;
+startPositionLocal = new Vector3(0, startOffset, 0);
+endPositionLocal = PolarToCartesian3D(rayLength, angle);
+endPositionLocal.y += endOffset;
-public void Reset()
+else
-m_Frame = Time.frameCount;
+// Vector2s here get converted to Vector3s (and back to Vector2s for casting)
+startPositionLocal = new Vector2();
+endPositionLocal = PolarToCartesian2D(rayLength, angle);
+var startPositionWorld = transform.TransformPoint(startPositionLocal);
+var endPositionWorld = transform.TransformPoint(endPositionLocal);
+return (StartPositionWorld: startPositionWorld, EndPositionWorld: endPositionWorld);
+}

+/// <summary>
+/// Converts polar coordinate to cartesian coordinate.
+/// </summary>
+static internal Vector3 PolarToCartesian3D(float radius, float angleDegrees)
+{
+var x = radius * Mathf.Cos(Mathf.Deg2Rad * angleDegrees);
+var z = radius * Mathf.Sin(Mathf.Deg2Rad * angleDegrees);
+return new Vector3(x, 0f, z);
+}

+/// <summary>
+/// Converts polar coordinate to cartesian coordinate.
+/// </summary>
+static internal Vector2 PolarToCartesian2D(float radius, float angleDegrees)
+{
+var x = radius * Mathf.Cos(Mathf.Deg2Rad * angleDegrees);
+var y = radius * Mathf.Sin(Mathf.Deg2Rad * angleDegrees);
+return new Vector2(x, y);
+}
+}

+public class RayPerceptionOutput
+{
+public struct RayOutput
+{
-/// "Age" of the results in number of frames. This is used to adjust the alpha when drawing.
+/// Whether or not the ray hit anything.
+/// </summary>
+public bool hasHit;

+/// <summary>
+/// Whether or not the ray hit an object whose tag is in the input's detectableTags list.
+/// </summary>
+public bool hitTaggedObject;

+/// <summary>
+/// The index of the hit object's tag in the detectableTags list, or -1 if there was no hit, or the
+/// hit object has a different tag.
+/// </summary>
+public int hitTagIndex;

+/// <summary>
+/// Normalized distance to the hit object.
+/// </summary>
+public float hitFraction;

+/// <summary>
+/// Writes the ray output information to a subset of the float array. Each element in the rayAngles array
+/// determines a sublist of data to the observation. The sublist contains the observation data for a single cast.
+/// The list is composed of the following:
+/// 1. A one-hot encoding for detectable tags. For example, if detectableTags.Length = n, the
+/// first n elements of the sublist will be a one-hot encoding of the detectableTag that was hit, or
+/// all zeroes otherwise.
+/// 2. The 'numDetectableTags' element of the sublist will be 1 if the ray missed everything, or 0 if it hit
+/// something (detectable or not).
+/// 3. The 'numDetectableTags+1' element of the sublist will contain the normalized distance to the object
+/// hit, or 1.0 if nothing was hit.
-public int age
+/// <param name="numDetectableTags"></param>
+/// <param name="rayIndex"></param>
+/// <param name="buffer">Output buffer. The size must be equal to (numDetectableTags+2) * rayOutputs.Length</param>
+public void ToFloatArray(int numDetectableTags, int rayIndex, float[] buffer)
-get { return Time.frameCount - m_Frame; }
+var bufferOffset = (numDetectableTags + 2) * rayIndex;
+if (hitTaggedObject)
+{
+buffer[bufferOffset + hitTagIndex] = 1f;
+}
+buffer[bufferOffset + numDetectableTags] = hasHit ? 0f : 1f;
+buffer[bufferOffset + numDetectableTags + 1] = hitFraction;
+}

+/// <summary>
+/// RayOutput for each ray that was cast.
+/// </summary>
+public RayOutput[] rayOutputs;
+}

-public RayInfo[] rayInfos;

+/// <summary>
+/// Debug information for the raycast hits. This is used by the RayPerceptionSensorComponent.
+/// </summary>
+internal class DebugDisplayInfo
+{
+public struct RayInfo
+{
+public Vector3 worldStart;
+public Vector3 worldEnd;
+public float castRadius;
+public RayPerceptionOutput.RayOutput rayOutput;
+}

+public void Reset()
+{
+m_Frame = Time.frameCount;
+}

+/// <summary>
+/// "Age" of the results in number of frames. This is used to adjust the alpha when drawing.
+/// </summary>
+public int age
+{
+get { return Time.frameCount - m_Frame; }

+public RayInfo[] rayInfos;

-int m_Frame;
+int m_Frame;
+}

+public class RayPerceptionSensor : ISensor
+{
+float[] m_Observations;
+int[] m_Shape;
+string m_Name;

+RayPerceptionInput m_RayPerceptionInput;

-public DebugDisplayInfo debugDisplayInfo
+internal DebugDisplayInfo debugDisplayInfo

-public RayPerceptionSensor(string name, float rayDistance, List<string> detectableObjects, float[] angles,
-Transform transform, float startOffset, float endOffset, float castRadius, CastType castType,
-int rayLayerMask)
+public RayPerceptionSensor(string name, RayPerceptionInput rayInput)
-var numObservations = (detectableObjects.Count + 2) * angles.Length;
+var numObservations = rayInput.OutputSize();
+m_RayPerceptionInput = rayInput;
-m_RayDistance = rayDistance;
-m_DetectableObjects = detectableObjects;
-// TODO - preprocess angles, save ray directions instead?
-m_Angles = angles;
-m_Transform = transform;
-m_StartOffset = startOffset;
-m_EndOffset = endOffset;
-m_CastRadius = castRadius;
-m_CastType = castType;
-m_LayerMask = rayLayerMask;
 if (Application.isEditor)
 {

 {
 using (TimerStack.Instance.Scoped("RayPerceptionSensor.Perceive"))
 {
-PerceiveStatic(
-m_RayDistance, m_Angles, m_DetectableObjects, m_StartOffset, m_EndOffset,
-m_CastRadius, m_Transform, m_CastType, m_Observations, m_LayerMask,
-m_DebugDisplayInfo
-);
+Array.Clear(m_Observations, 0, m_Observations.Length);
+var numRays = m_RayPerceptionInput.angles.Count;
+var numDetectableTags = m_RayPerceptionInput.detectableTags.Count;
+if (m_DebugDisplayInfo != null)
+{
+// Reset the age information, and resize the buffer if needed.
+m_DebugDisplayInfo.Reset();
+if (m_DebugDisplayInfo.rayInfos == null || m_DebugDisplayInfo.rayInfos.Length != numRays)
+{
+m_DebugDisplayInfo.rayInfos = new DebugDisplayInfo.RayInfo[numRays];
+}
+}
+// For each ray, do the casting, and write the information to the observation buffer
+for (var rayIndex = 0; rayIndex < numRays; rayIndex++)
+{
+DebugDisplayInfo.RayInfo debugRay;
+var rayOutput = PerceiveSingleRay(m_RayPerceptionInput, rayIndex, out debugRay);
+if (m_DebugDisplayInfo != null)
+{
+m_DebugDisplayInfo.rayInfos[rayIndex] = debugRay;
+}
+rayOutput.ToFloatArray(numDetectableTags, rayIndex, m_Observations);
+}
+// Finally, add the observations to the WriteAdapter
+adapter.AddRange(m_Observations);
 }
 return m_Observations.Length;

 }

 /// <summary>
-/// Evaluates a perception vector to be used as part of an observation of an agent.
-/// Each element in the rayAngles array determines a sublist of data to the observation.
-/// The sublist contains the observation data for a single cast. The list is composed of the following:
-/// 1. A one-hot encoding for detectable objects. For example, if detectableObjects.Length = n, the
-/// first n elements of the sublist will be a one-hot encoding of the detectableObject that was hit, or
-/// all zeroes otherwise.
-/// 2. The 'length' element of the sublist will be 1 if the ray missed everything, or 0 if it hit
-/// something (detectable or not).
-/// 3. The 'length+1' element of the sublist will contain the normalised distance to the object hit, or 1 if
-/// nothing was hit.
-///
+/// Evaluates the raycasts to be used as part of an observation of an agent.
-/// <param name="unscaledRayLength"></param>
-/// <param name="rayAngles">List of angles (in degrees) used to define the rays. 90 degrees is considered
-/// "forward" relative to the game object</param>
-/// <param name="detectableObjects">List of tags which correspond to object types agent can see</param>
-/// <param name="startOffset">Starting height offset of ray from center of agent.</param>
-/// <param name="endOffset">Ending height offset of ray from center of agent.</param>
-/// <param name="unscaledCastRadius">Radius of the sphere to use for spherecasting. If 0 or less, rays are used
-/// instead - this may be faster, especially for complex environments.</param>
-/// <param name="transform">Transform of the GameObject</param>
-/// <param name="castType">Whether to perform the casts in 2D or 3D.</param>
-/// <param name="perceptionBuffer">Output array of floats. Must be (num rays) * (num tags + 2) in size.</param>
-/// <param name="layerMask">Filtering options for the casts</param>
-/// <param name="debugInfo">Optional debug information output, only used by RayPerceptionSensor.</param>
+/// <param name="input">Input defining the rays that will be cast.</param>
+/// <param name="output">Output class that will be written to with raycast results.</param>
-public static void PerceiveStatic(float unscaledRayLength,
-IReadOnlyList<float> rayAngles, IReadOnlyList<string> detectableObjects,
-float startOffset, float endOffset, float unscaledCastRadius,
-Transform transform, CastType castType, float[] perceptionBuffer,
-int layerMask = Physics.DefaultRaycastLayers,
-DebugDisplayInfo debugInfo = null)
+public static RayPerceptionOutput PerceiveStatic(RayPerceptionInput input)
-Array.Clear(perceptionBuffer, 0, perceptionBuffer.Length);
-if (debugInfo != null)
+RayPerceptionOutput output = new RayPerceptionOutput();
+output.rayOutputs = new RayPerceptionOutput.RayOutput[input.angles.Count];
+for (var rayIndex = 0; rayIndex < input.angles.Count; rayIndex++)
-debugInfo.Reset();
-if (debugInfo.rayInfos == null || debugInfo.rayInfos.Length != rayAngles.Count)
-{
-debugInfo.rayInfos = new DebugDisplayInfo.RayInfo[rayAngles.Count];
-}
+DebugDisplayInfo.RayInfo debugRay;
+output.rayOutputs[rayIndex] = PerceiveSingleRay(input, rayIndex, out debugRay);

-// For each ray sublist stores categorical information on detected object
-// along with object distance.
-int bufferOffset = 0;
-for (var rayIndex = 0; rayIndex < rayAngles.Count; rayIndex++)
-{
-var angle = rayAngles[rayIndex];
-Vector3 startPositionLocal, endPositionLocal;
-if (castType == CastType.Cast3D)
-{
-startPositionLocal = new Vector3(0, startOffset, 0);
-endPositionLocal = PolarToCartesian3D(unscaledRayLength, angle);
-endPositionLocal.y += endOffset;
-}
-else
-{
-// Vector2s here get converted to Vector3s (and back to Vector2s for casting)
-startPositionLocal = new Vector2();
-endPositionLocal = PolarToCartesian2D(unscaledRayLength, angle);
-}
+return output;
+}

-var startPositionWorld = transform.TransformPoint(startPositionLocal);
-var endPositionWorld = transform.TransformPoint(endPositionLocal);

+/// <summary>
+/// Evaluate the raycast results of a single ray from the RayPerceptionInput.
+/// </summary>
+/// <param name="input"></param>
+/// <param name="rayIndex"></param>
+/// <param name="debugRayOut"></param>
+/// <returns></returns>
+static RayPerceptionOutput.RayOutput PerceiveSingleRay(
+RayPerceptionInput input,
+int rayIndex,
+out DebugDisplayInfo.RayInfo debugRayOut
+)
+{
+var unscaledRayLength = input.rayLength;
+var unscaledCastRadius = input.castRadius;

-var rayDirection = endPositionWorld - startPositionWorld;
-// If there is non-unity scale, |rayDirection| will be different from rayLength.
-// We want to use this transformed ray length for determining cast length, hit fraction etc.
-// We also use it to scale up or down the sphere or circle radii.
-var scaledRayLength = rayDirection.magnitude;
-// Avoid 0/0 if unscaledRayLength is 0
-var scaledCastRadius = unscaledRayLength > 0 ? unscaledCastRadius * scaledRayLength / unscaledRayLength : unscaledCastRadius;

+var extents = input.RayExtents(rayIndex);
+var startPositionWorld = extents.StartPositionWorld;
+var endPositionWorld = extents.EndPositionWorld;

-// Do the cast and assign the hit information for each detectable object.
-// sublist[0 ] <- did hit detectableObjects[0]
-// ...
-// sublist[numObjects-1] <- did hit detectableObjects[numObjects-1]
-// sublist[numObjects ] <- 1 if missed else 0
-// sublist[numObjects+1] <- hit fraction (or 1 if no hit)

+var rayDirection = endPositionWorld - startPositionWorld;
+// If there is non-unity scale, |rayDirection| will be different from rayLength.
+// We want to use this transformed ray length for determining cast length, hit fraction etc.
+// We also use it to scale up or down the sphere or circle radii.
+var scaledRayLength = rayDirection.magnitude;
+// Avoid 0/0 if unscaledRayLength is 0
+var scaledCastRadius = unscaledRayLength > 0 ?
+unscaledCastRadius * scaledRayLength / unscaledRayLength :
+unscaledCastRadius;

-bool castHit;
-float hitFraction;
-GameObject hitObject;

+// Do the cast and assign the hit information for each detectable tag.
+bool castHit;
+float hitFraction;
+GameObject hitObject;

-if (castType == CastType.Cast3D)
+if (input.castType == RayPerceptionCastType.Cast3D)
 {
 RaycastHit rayHit;
 if (scaledCastRadius > 0f)
 {
-castHit = Physics.SphereCast(startPositionWorld, scaledCastRadius, rayDirection, out rayHit,
-scaledRayLength, layerMask);
-}
-else
-{
-castHit = Physics.Raycast(startPositionWorld, rayDirection, out rayHit,
-scaledRayLength, layerMask);
-}
-// If scaledRayLength is 0, we still could have a hit with sphere casts (maybe?).
-// To avoid 0/0, set the fraction to 0.
-hitFraction = castHit ? (scaledRayLength > 0 ? rayHit.distance / scaledRayLength : 0.0f) : 1.0f;
-hitObject = castHit ? rayHit.collider.gameObject : null;
+castHit = Physics.SphereCast(startPositionWorld, scaledCastRadius, rayDirection, out rayHit,
+scaledRayLength, input.layerMask);

-RaycastHit2D rayHit;
-if (scaledCastRadius > 0f)
-{
-rayHit = Physics2D.CircleCast(startPositionWorld, scaledCastRadius, rayDirection,
-scaledRayLength, layerMask);
-}
-else
-{
-rayHit = Physics2D.Raycast(startPositionWorld, rayDirection, scaledRayLength, layerMask);
-}
-castHit = rayHit;
-hitFraction = castHit ? rayHit.fraction : 1.0f;
-hitObject = castHit ? rayHit.collider.gameObject : null;

+castHit = Physics.Raycast(startPositionWorld, rayDirection, out rayHit,
+scaledRayLength, input.layerMask);
-if (debugInfo != null)
+// If scaledRayLength is 0, we still could have a hit with sphere casts (maybe?).
+// To avoid 0/0, set the fraction to 0.
+hitFraction = castHit ? (scaledRayLength > 0 ? rayHit.distance / scaledRayLength : 0.0f) : 1.0f;
+hitObject = castHit ? rayHit.collider.gameObject : null;
 }
 else
 {
+RaycastHit2D rayHit;
+if (scaledCastRadius > 0f)
-debugInfo.rayInfos[rayIndex].localStart = startPositionLocal;
-debugInfo.rayInfos[rayIndex].localEnd = endPositionLocal;
-debugInfo.rayInfos[rayIndex].worldStart = startPositionWorld;
-debugInfo.rayInfos[rayIndex].worldEnd = endPositionWorld;
-debugInfo.rayInfos[rayIndex].castHit = castHit;
-debugInfo.rayInfos[rayIndex].hitFraction = hitFraction;
-debugInfo.rayInfos[rayIndex].castRadius = scaledCastRadius;
+rayHit = Physics2D.CircleCast(startPositionWorld, scaledCastRadius, rayDirection,
+scaledRayLength, input.layerMask);
-else if (Application.isEditor)
+else
-// Legacy drawing
-Debug.DrawRay(startPositionWorld, rayDirection, Color.black, 0.01f, true);
+rayHit = Physics2D.Raycast(startPositionWorld, rayDirection, scaledRayLength, input.layerMask);
-if (castHit)
+castHit = rayHit;
+hitFraction = castHit ? rayHit.fraction : 1.0f;
+hitObject = castHit ? rayHit.collider.gameObject : null;
 }

+var rayOutput = new RayPerceptionOutput.RayOutput
+{
+hasHit = castHit,
+hitFraction = hitFraction,
+hitTaggedObject = false,
+hitTagIndex = -1
+};

+if (castHit)
+{
+// Find the index of the tag of the object that was hit.
+for (var i = 0; i < input.detectableTags.Count; i++)
-bool hitTaggedObject = false;
-for (var i = 0; i < detectableObjects.Count; i++)
+if (hitObject.CompareTag(input.detectableTags[i]))
-if (hitObject.CompareTag(detectableObjects[i]))
 {
-perceptionBuffer[bufferOffset + i] = 1;
-perceptionBuffer[bufferOffset + detectableObjects.Count + 1] = hitFraction;
-hitTaggedObject = true;
-break;
+rayOutput.hitTaggedObject = true;
+rayOutput.hitTagIndex = i;
+break;
 }
 }

-if (!hitTaggedObject)
-{
-// Something was hit but not on the list. Still set the hit fraction.
-perceptionBuffer[bufferOffset + detectableObjects.Count + 1] = hitFraction;
-}
-else
-{
-perceptionBuffer[bufferOffset + detectableObjects.Count] = 1f;
-// Nothing was hit, so there's full clearance in front of the agent.
-perceptionBuffer[bufferOffset + detectableObjects.Count + 1] = 1.0f;
-}
 }

-bufferOffset += detectableObjects.Count + 2;
-}
-}

+debugRayOut.worldStart = startPositionWorld;
+debugRayOut.worldEnd = endPositionWorld;
+debugRayOut.rayOutput = rayOutput;
+debugRayOut.castRadius = scaledCastRadius;

-/// <summary>
-/// Converts polar coordinate to cartesian coordinate.
-/// </summary>
-static Vector3 PolarToCartesian3D(float radius, float angleDegrees)
-{
-var x = radius * Mathf.Cos(Mathf.Deg2Rad * angleDegrees);
-var z = radius * Mathf.Sin(Mathf.Deg2Rad * angleDegrees);
-return new Vector3(x, 0f, z);
-}

+return rayOutput;

-/// <summary>
-/// Converts polar coordinate to cartesian coordinate.
-/// </summary>
-static Vector2 PolarToCartesian2D(float radius, float angleDegrees)
-{
-var x = radius * Mathf.Cos(Mathf.Deg2Rad * angleDegrees);
-var y = radius * Mathf.Sin(Mathf.Deg2Rad * angleDegrees);
-return new Vector2(x, y);
-}
 }
 }
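To make the per-ray layout concrete: with detectableTags = { "wall", "ball" } and three rays, OutputSize() is (2 + 2) * 3 = 12, and each ray's four-float sublist is [wall one-hot, ball one-hot, miss flag, normalized distance]. A small worked example of ToFloatArray with made-up hit values:

```csharp
// One ray (index 0) that hit a "ball" (tag index 1) at 40% of the ray length.
var rayOutput = new RayPerceptionOutput.RayOutput
{
    hasHit = true,
    hitTaggedObject = true,
    hitTagIndex = 1,
    hitFraction = 0.4f
};

var buffer = new float[(2 + 2) * 3]; // 2 tags, 3 rays
rayOutput.ToFloatArray(2, 0, buffer);
// buffer[0..3] is now [0, 1, 0, 0.4]:
//   buffer[0] = 0    -> not a "wall"
//   buffer[1] = 1    -> one-hot for "ball"
//   buffer[2] = 0    -> miss flag; 0 because hasHit is true
//   buffer[3] = 0.4  -> normalized hit distance
```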

com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent2D.cs (4 changes)

 rayLayerMask = Physics2D.DefaultRaycastLayers;
 }

-public override RayPerceptionSensor.CastType GetCastType()
+public override RayPerceptionCastType GetCastType()
-return RayPerceptionSensor.CastType.Cast2D;
+return RayPerceptionCastType.Cast2D;
 }
 }
 }

com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponent3D.cs (4 changes)

 [Tooltip("Ray end is offset up or down by this amount.")]
 public float endVerticalOffset;

-public override RayPerceptionSensor.CastType GetCastType()
+public override RayPerceptionCastType GetCastType()
-return RayPerceptionSensor.CastType.Cast3D;
+return RayPerceptionCastType.Cast3D;
 }

 public override float GetStartVerticalOffset()

com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponentBase.cs (38 changes)

 [Header("Debug Gizmos", order = 999)]
 public Color rayHitColor = Color.red;
 public Color rayMissColor = Color.white;
 [Tooltip("Whether to draw the raycasts in the world space of when they happened, or using the Agent's current transform.")]
 public bool useWorldPositions = true;

-public abstract RayPerceptionSensor.CastType GetCastType();
+public abstract RayPerceptionCastType GetCastType();

 public virtual float GetStartVerticalOffset()
 {

 public override ISensor CreateSensor()
 {
 var rayAngles = GetRayAngles(raysPerDirection, maxRayDegrees);
-m_RaySensor = new RayPerceptionSensor(sensorName, rayLength, detectableTags, rayAngles,
-transform, GetStartVerticalOffset(), GetEndVerticalOffset(), sphereCastRadius, GetCastType(),
-rayLayerMask
-);
+var rayPerceptionInput = new RayPerceptionInput();
+rayPerceptionInput.rayLength = rayLength;
+rayPerceptionInput.detectableTags = detectableTags;
+rayPerceptionInput.angles = rayAngles;
+rayPerceptionInput.startOffset = GetStartVerticalOffset();
+rayPerceptionInput.endOffset = GetEndVerticalOffset();
+rayPerceptionInput.castRadius = sphereCastRadius;
+rayPerceptionInput.transform = transform;
+rayPerceptionInput.castType = GetCastType();
+rayPerceptionInput.layerMask = rayLayerMask;
+m_RaySensor = new RayPerceptionSensor(sensorName, rayPerceptionInput);
 if (observationStacks != 1)
 {

 public override int[] GetObservationShape()
 {
 var numRays = 2 * raysPerDirection + 1;
-var numTags = detectableTags == null ? 0 : detectableTags.Count;
+var numTags = detectableTags?.Count ?? 0;
 var obsSize = (numTags + 2) * numRays;
 var stacks = observationStacks > 1 ? observationStacks : 1;
 return new[] { obsSize * stacks };

 foreach (var rayInfo in debugInfo.rayInfos)
 {
-// Either use the original world-space coordinates of the raycast, or transform the agent-local
-// coordinates of the rays to the current transform of the agent. If the agent acts every frame,
-// these should be the same.
-if (!useWorldPositions)
-{
-startPositionWorld = transform.TransformPoint(rayInfo.localStart);
-endPositionWorld = transform.TransformPoint(rayInfo.localEnd);
-}

-rayDirection *= rayInfo.hitFraction;
+rayDirection *= rayInfo.rayOutput.hitFraction;
-var lerpT = rayInfo.hitFraction * rayInfo.hitFraction;
+var lerpT = rayInfo.rayOutput.hitFraction * rayInfo.rayOutput.hitFraction;
 var color = Color.Lerp(rayHitColor, rayMissColor, lerpT);
 color.a *= alpha;
 Gizmos.color = color;

-if (rayInfo.castHit)
+if (rayInfo.rayOutput.hasHit)
 {
 var hitRadius = Mathf.Max(rayInfo.castRadius, .05f);
 Gizmos.DrawWireSphere(startPositionWorld + rayDirection, hitRadius);
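A worked example of the GetObservationShape() arithmetic above, with hypothetical inspector settings:

```csharp
// raysPerDirection = 1              -> numRays = 2 * 1 + 1 = 3
// detectableTags = { "wall", "ball" } -> numTags = 2
// observationStacks = 2             -> stacks = 2
// obsSize = (numTags + 2) * numRays = (2 + 2) * 3 = 12
// GetObservationShape() returns { obsSize * stacks } = { 24 }
```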

com.unity.ml-agents/Runtime/Sensor/WriteAdapter.cs (6 changes)

 TensorShape m_TensorShape;

+internal WriteAdapter() { }

 /// <summary>
 /// Set the adapter to write to an IList at the given channelOffset.
 /// </summary>

-public void SetTarget(IList<float> data, int[] shape, int offset)
+internal void SetTarget(IList<float> data, int[] shape, int offset)
 {
 m_Data = data;
 m_Offset = offset;

 /// <param name="tensorProxy">Tensor proxy that will be written to.</param>
 /// <param name="batchIndex">Batch index in the tensor proxy (i.e. the index of the Agent)</param>
 /// <param name="channelOffset">Offset from the start of the channel to write to.</param>
-public void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset)
+internal void SetTarget(TensorProxy tensorProxy, int batchIndex, int channelOffset)
 {
 m_Proxy = tensorProxy;
 m_Batch = batchIndex;

com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs (44 changes)

 [TestFixture]
 public class DemonstrationTests : MonoBehaviour
 {
-const string k_DemoDirecory = "Assets/Demonstrations/";
+const string k_DemoDirectory = "Assets/Demonstrations/";
 const string k_ExtensionType = ".demo";
 const string k_DemoName = "Test";

 public void TestStoreInitalize()
 {
 var fileSystem = new MockFileSystem();
-var demoStore = new DemonstrationStore(fileSystem);
-Assert.IsFalse(fileSystem.Directory.Exists(k_DemoDirecory));
+var gameobj = new GameObject("gameObj");

-var brainParameters = new BrainParameters
-{
-vectorObservationSize = 3,
-numStackedVectorObservations = 2,
-vectorActionDescriptions = new[] { "TestActionA", "TestActionB" },
-vectorActionSize = new[] { 2, 2 },
-vectorActionSpaceType = SpaceType.Discrete
-};
+var bp = gameobj.AddComponent<BehaviorParameters>();
+bp.brainParameters.vectorObservationSize = 3;
+bp.brainParameters.numStackedVectorObservations = 2;
+bp.brainParameters.vectorActionDescriptions = new[] { "TestActionA", "TestActionB" };
+bp.brainParameters.vectorActionSize = new[] { 2, 2 };
+bp.brainParameters.vectorActionSpaceType = SpaceType.Discrete;

-demoStore.Initialize(k_DemoName, brainParameters, "TestBrain");
+var agent = gameobj.AddComponent<TestAgent>();
-Assert.IsTrue(fileSystem.Directory.Exists(k_DemoDirecory));
-Assert.IsTrue(fileSystem.FileExists(k_DemoDirecory + k_DemoName + k_ExtensionType));
+Assert.IsFalse(fileSystem.Directory.Exists(k_DemoDirectory));

+var demoRec = gameobj.AddComponent<DemonstrationRecorder>();
+demoRec.record = true;
+demoRec.demonstrationName = k_DemoName;
+demoRec.demonstrationDirectory = k_DemoDirectory;
+var demoStore = demoRec.LazyInitialize(fileSystem);
+Assert.IsTrue(fileSystem.Directory.Exists(k_DemoDirectory));
+Assert.IsTrue(fileSystem.FileExists(k_DemoDirectory + k_DemoName + k_ExtensionType));

 var agentInfo = new AgentInfo
 {

 storedVectorActions = new[] { 0f, 1f },
 };

+demoRec.Close();
+// Make sure close can be called multiple times
+demoRec.Close();
+// Make sure trying to write after closing doesn't raise an error.
+demoStore.Record(agentInfo, new System.Collections.Generic.List<ISensor>());
 }

 public class ObservationAgent : TestAgent

 agentGo1.AddComponent<DemonstrationRecorder>();
 var demoRecorder = agentGo1.GetComponent<DemonstrationRecorder>();
 var fileSystem = new MockFileSystem();
+demoRecorder.demonstrationDirectory = k_DemoDirectory;
-demoRecorder.InitializeDemoStore(fileSystem);
+demoRecorder.LazyInitialize(fileSystem);
 var agentEnableMethod = typeof(Agent).GetMethod("OnEnable",
 BindingFlags.Instance | BindingFlags.NonPublic);

docs/Migrating.md (3 changes)

 ### Important changes
 * The `Agent.CollectObservations()` virtual method now takes as input a `VectorSensor` sensor as argument. The `Agent.AddVectorObs()` methods were removed.
+* The interface for `RayPerceptionSensor.PerceiveStatic()` was changed to take an input class and write to an output class.
 * The `SetActionMask` method must now be called on the optional `ActionMasker` argument of the `CollectObservations` method. (We now consider an action mask as a type of observation)
 * The method `GetStepCount()` on the Agent class has been replaced with the property getter `StepCount`

+* If you call `RayPerceptionSensor.PerceiveStatic()` manually, add your inputs to a `RayPerceptionInput`. To get the previous float array output, use `RayPerceptionOutput.ToFloatArray()` (see the sketch after this list).
 * Re-import all of your `*.NN` files to work with the updated Barracuda package.
 * Replace all calls to `Agent.GetStepCount()` with `Agent.StepCount`
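A migration sketch for code that called the old static API directly, using only the types introduced in this diff (the field values are placeholders):

```csharp
var input = new RayPerceptionInput
{
    rayLength = 20f,
    detectableTags = new[] { "wall", "ball" },
    angles = new[] { 60f, 90f, 120f },
    startOffset = 0f,
    endOffset = 0f,
    castRadius = 0.5f,
    transform = transform,
    castType = RayPerceptionCastType.Cast3D,
    layerMask = Physics.DefaultRaycastLayers
};

RayPerceptionOutput output = RayPerceptionSensor.PerceiveStatic(input);

// Recover the old flat float-array form with ToFloatArray().
var buffer = new float[input.OutputSize()];
for (var i = 0; i < output.rayOutputs.Length; i++)
{
    output.rayOutputs[i].ToFloatArray(input.detectableTags.Count, i, buffer);
}
```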

docs/Training-Imitation-Learning.md (2 changes)

 from a few minutes or a few hours of demonstration data may be necessary to
 be useful for imitation learning. When you have recorded enough data, end
 the Editor play session, and a `.demo` file will be created in the
-`Assets/Demonstrations` folder. This file contains the demonstrations.
+`Assets/Demonstrations` folder (by default). This file contains the demonstrations.
 Clicking on the file will provide metadata about the demonstration in the
 inspector.

gym-unity/gym_unity/envs/__init__.py (145 changes)

 import logging
 import itertools
 import numpy as np
-from typing import Any, Dict, List, Optional, Tuple, Union, Set
+from typing import Any, Dict, List, Optional, Tuple, Union
 import gym
 from gym import error, spaces

 self.visual_obs = None
 self._n_agents = -1
-self._done_agents: Set[int] = set()
+self.agent_mapper = AgentIdIndexMapper()
 # Save the step result from the last time all Agents requested decisions.
 self._previous_step_result: BatchedStepResult = None
 self._multiagent = multiagent

 step_result = self._env.get_step_result(self.brain_name)
 self._check_agents(step_result.n_agents())
 self._previous_step_result = step_result
+self.agent_mapper.set_initial_agents(list(self._previous_step_result.agent_id))
 # Set observation and action spaces
 if self.group_spec.is_action_discrete():

 "The number of agents in the scene does not match the expected number."
 )

-# remove the done Agents
-indices_to_keep: List[int] = []
-for index, is_done in enumerate(step_result.done):
-if not is_done:
-indices_to_keep.append(index)
+if step_result.n_agents() - sum(step_result.done) != self._n_agents:
+raise UnityGymException(
+"The number of agents in the scene does not match the expected number."
+)
+for index, agent_id in enumerate(step_result.agent_id):
+if step_result.done[index]:
+self.agent_mapper.mark_agent_done(agent_id, step_result.reward[index])

 # Set the new AgentDone flags to True
 # Note that the corresponding agent_id that gets marked done will be different

 if not self._previous_step_result.contains_agent(agent_id):
 step_result.done[index] = True
-if agent_id in self._done_agents:
+# Register this agent, and get the reward of the previous agent that
+# was in its index, so that we can return it to the gym.
+last_reward = self.agent_mapper.register_new_agent_id(agent_id)
-self._done_agents = set()
+step_result.reward[index] = last_reward

+# Get a permutation of the agent IDs so that a given ID stays in the same
+# index as where it was first seen.
+new_id_order = self.agent_mapper.get_id_permutation(list(step_result.agent_id))

-_mask.append(step_result.action_mask[mask_index][indices_to_keep])
+_mask.append(step_result.action_mask[mask_index][new_id_order])

-new_obs.append(step_result.obs[obs_index][indices_to_keep])
+new_obs.append(step_result.obs[obs_index][new_id_order])

-reward=step_result.reward[indices_to_keep],
-done=step_result.done[indices_to_keep],
-max_step=step_result.max_step[indices_to_keep],
-agent_id=step_result.agent_id[indices_to_keep],
+reward=step_result.reward[new_id_order],
+done=step_result.done[new_id_order],
+max_step=step_result.max_step[new_id_order],
+agent_id=step_result.agent_id[new_id_order],

 if self._previous_step_result.n_agents() == self._n_agents:
 return action

-input_index = 0
-for index in range(self._previous_step_result.n_agents()):
+for index, agent_id in enumerate(self._previous_step_result.agent_id):
-sanitized_action[index, :] = action[input_index, :]
-input_index = input_index + 1
+array_index = self.agent_mapper.get_gym_index(agent_id)
+sanitized_action[index, :] = action[array_index, :]
 return sanitized_action

 def _step(self, needs_reset: bool = False) -> BatchedStepResult:

 "The environment does not have the expected amount of agents."
 + "Some agents did not request decisions at the same time."
 )
-self._done_agents.update(list(info.agent_id))
+for agent_id, reward in zip(info.agent_id, info.reward):
+self.agent_mapper.mark_agent_done(agent_id, reward)
 self._env.step()
 info = self._env.get_step_result(self.brain_name)
 return self._sanitize_info(info)

 :return: The List containing the branched actions.
 """
 return self.action_lookup[action]

+class AgentIdIndexMapper:
+def __init__(self) -> None:
+self._agent_id_to_gym_index: Dict[int, int] = {}
+self._done_agents_index_to_last_reward: Dict[int, float] = {}

+def set_initial_agents(self, agent_ids: List[int]) -> None:
+"""
+Provide the initial list of agent ids for the mapper
+"""
+for idx, agent_id in enumerate(agent_ids):
+self._agent_id_to_gym_index[agent_id] = idx

+def mark_agent_done(self, agent_id: int, reward: float) -> None:
+"""
+Declare the agent done with the corresponding final reward.
+"""
+gym_index = self._agent_id_to_gym_index.pop(agent_id)
+self._done_agents_index_to_last_reward[gym_index] = reward

+def register_new_agent_id(self, agent_id: int) -> float:
+"""
+Adds the new agent ID and returns the reward to use for the previous agent in this index
+"""
+# Any free index is OK here.
+free_index, last_reward = self._done_agents_index_to_last_reward.popitem()
+self._agent_id_to_gym_index[agent_id] = free_index
+return last_reward

+def get_id_permutation(self, agent_ids: List[int]) -> List[int]:
+"""
+Get the permutation from new agent ids to the order that preserves the positions of previous agents.
+The result is a list with each integer from 0 to len(agent_ids)-1 appearing exactly once.
+"""
+# Map the new agent ids to their index
+new_agent_ids_to_index = {
+agent_id: idx for idx, agent_id in enumerate(agent_ids)
+}
+# Make the output list. We don't write to it sequentially, so start with dummy values.
+new_permutation = [-1] * len(agent_ids)
+# For each agent ID, find the new index of the agent, and write it in the original index.
+for agent_id, original_index in self._agent_id_to_gym_index.items():
+new_permutation[original_index] = new_agent_ids_to_index[agent_id]
+return new_permutation

+def get_gym_index(self, agent_id: int) -> int:
+"""
+Get the gym index for the current agent.
+"""
+return self._agent_id_to_gym_index[agent_id]

+class AgentIdIndexMapperSlow:
+"""
+Reference implementation of AgentIdIndexMapper.
+The operations are O(N^2) so it shouldn't be used for large numbers of agents.
+See AgentIdIndexMapper for method descriptions
+"""
+def __init__(self) -> None:
+self._gym_id_order: List[int] = []
+self._done_agents_index_to_last_reward: Dict[int, float] = {}

+def set_initial_agents(self, agent_ids: List[int]) -> None:
+self._gym_id_order = list(agent_ids)

+def mark_agent_done(self, agent_id: int, reward: float) -> None:
+gym_index = self._gym_id_order.index(agent_id)
+self._done_agents_index_to_last_reward[gym_index] = reward
+self._gym_id_order[gym_index] = -1

+def register_new_agent_id(self, agent_id: int) -> float:
+original_index = self._gym_id_order.index(-1)
+self._gym_id_order[original_index] = agent_id
+reward = self._done_agents_index_to_last_reward.pop(original_index)
+return reward

+def get_id_permutation(self, agent_ids):
+new_id_order = []
+for agent_id in self._gym_id_order:
+new_id_order.append(agent_ids.index(agent_id))
+return new_id_order

+def get_gym_index(self, agent_id: int) -> int:
+return self._gym_id_order.index(agent_id)
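A small usage sketch of the mapper, mirroring the flow in _sanitize_info (the IDs and reward are made up):

```python
mapper = AgentIdIndexMapper()
mapper.set_initial_agents([1001, 1002, 1003])  # gym indices 0, 1, 2

# Agent 1002 finishes: its gym slot (index 1) is freed and its reward parked.
mapper.mark_agent_done(1002, 3.5)

# A new agent takes over the free slot; the parked reward is returned so it
# can still be reported to gym for that slot.
last_reward = mapper.register_new_agent_id(2001)
assert last_reward == 3.5
assert mapper.get_gym_index(2001) == 1

# get_id_permutation reorders a fresh agent_id list so ids keep their slots.
assert mapper.get_id_permutation([1001, 1003, 2001]) == [0, 2, 1]
```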

gym-unity/gym_unity/tests/test_gym.py (77 changes)

 import numpy as np
 from gym import spaces
-from gym_unity.envs import UnityEnv, UnityGymException
+from gym_unity.envs import (
+UnityEnv,
+UnityGymException,
+AgentIdIndexMapper,
+AgentIdIndexMapperSlow,
+)
 from mlagents_envs.base_env import AgentGroupSpec, ActionType, BatchedStepResult

 assert isinstance(info, dict)

+@mock.patch("gym_unity.envs.UnityEnvironment")
+def test_sanitize_action_shuffled_id(mock_env):
+mock_spec = create_mock_group_spec(
+vector_action_space_type="discrete", vector_action_space_size=[2, 2, 3]
+)
+mock_step = create_mock_vector_step_result(num_agents=5)
+mock_step.agent_id = np.array(range(5))
+setup_mock_unityenvironment(mock_env, mock_spec, mock_step)
+env = UnityEnv(" ", use_visual=False, multiagent=True)

+shuffled_step_result = create_mock_vector_step_result(num_agents=5)
+shuffled_order = [4, 2, 3, 1, 0]
+shuffled_step_result.reward = np.array(shuffled_order)
+shuffled_step_result.agent_id = np.array(shuffled_order)
+sanitized_result = env._sanitize_info(shuffled_step_result)
+for expected_reward, reward in zip(range(5), sanitized_result.reward):
+assert expected_reward == reward
+for expected_agent_id, agent_id in zip(range(5), sanitized_result.agent_id):
+assert expected_agent_id == agent_id

+@mock.patch("gym_unity.envs.UnityEnvironment")
+def test_sanitize_action_one_agent_done(mock_env):
+mock_spec = create_mock_group_spec(
+vector_action_space_type="discrete", vector_action_space_size=[2, 2, 3]
+)
+mock_step = create_mock_vector_step_result(num_agents=5)
+mock_step.agent_id = np.array(range(5))
+setup_mock_unityenvironment(mock_env, mock_spec, mock_step)
+env = UnityEnv(" ", use_visual=False, multiagent=True)

+received_step_result = create_mock_vector_step_result(num_agents=6)
+received_step_result.agent_id = np.array(range(6))
+# agent #3 (id = 2) is Done
+received_step_result.done = np.array([False] * 2 + [True] + [False] * 3)
+sanitized_result = env._sanitize_info(received_step_result)
+for expected_agent_id, agent_id in zip([0, 1, 5, 3, 4], sanitized_result.agent_id):
+assert expected_agent_id == agent_id

 # Helper methods

 mock_env.return_value.get_agent_groups.return_value = ["MockBrain"]
 mock_env.return_value.get_agent_group_spec.return_value = mock_spec
 mock_env.return_value.get_step_result.return_value = mock_result

+@pytest.mark.parametrize("mapper_cls", [AgentIdIndexMapper, AgentIdIndexMapperSlow])
+def test_agent_id_index_mapper(mapper_cls):
+mapper = mapper_cls()
+initial_agent_ids = [1001, 1002, 1003, 1004]
+mapper.set_initial_agents(initial_agent_ids)

+# Mark some agents as done with their last rewards.
+mapper.mark_agent_done(1001, 42.0)
+mapper.mark_agent_done(1004, 1337.0)

+# Now add new agents, and get the rewards of the agents they replaced.
+old_reward1 = mapper.register_new_agent_id(2001)
+old_reward2 = mapper.register_new_agent_id(2002)

+# The order of the rewards doesn't matter.
+assert {old_reward1, old_reward2} == {42.0, 1337.0}

+new_agent_ids = [1002, 1003, 2001, 2002]
+permutation = mapper.get_id_permutation(new_agent_ids)
+# Make sure it's actually a permutation - needs to contain 0..N-1 with no repeats.
+assert set(permutation) == set(range(0, 4))

+# Agents that were in the initial group need to stay in the same slot.
+# Agents that were added later can appear in any free slot.
+permuted_ids = [new_agent_ids[i] for i in permutation]
+for idx, agent_id in enumerate(initial_agent_ids):
+if agent_id in permuted_ids:
+assert permuted_ids[idx] == agent_id