浏览代码

[MLA-1880] Raycast sensor interface improvements (#5222)

* WIP

* remove debug info struct

* cleanup + add to test

* changelog

* fix unit tests

* PR feedback
/check-for-ModelOverriders
GitHub 3 年前
当前提交
1b3e0ea3
共有 5 个文件被更改,包括 107 次插入84 次删除
  1. 5
      com.unity.ml-agents/CHANGELOG.md
  2. 141
      com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs
  3. 35
      com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs
  4. 5
      com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs
  5. 5
      com.unity.ml-agents/Tests/Runtime/Sensor/RayPerceptionSensorTests.cs

5
com.unity.ml-agents/CHANGELOG.md


- Make com.unity.modules.physics and com.unity.modules.physics2d optional dependencies. (#5112)
- The default `InferenceDevice` is now `InferenceDevice.Default`, which is equivalent to `InferenceDevice.Burst`. If you
depend on the previous behavior, you can explicitly set the Agent's `InferenceDevice` to `InferenceDevice.CPU`. (#5175)
- Added support for `Goal Signal` as a type of observation. Trainers can now use HyperNetworks to process `Goal Signal`. Trainers with HyperNetworks are more effective at solving multiple tasks. (#5142, #5159, #5149)
- Modified the [GridWorld environment](https://github.com/Unity-Technologies/ml-agents/blob/main/docs/Learning-Environment-Examples.md#gridworld) to use the new `Goal Signal` feature. (#5193)
- Added support for `Goal Signal` as a type of observation. Trainers can now use HyperNetworks to process `Goal Signal`. Trainers with HyperNetworks are more effective at solving multiple tasks. (#5142, #5159, #5149)
- Modified the [GridWorld environment](https://github.com/Unity-Technologies/ml-agents/blob/main/docs/Learning-Environment-Examples.md#gridworld) to use the new `Goal Signal` feature. (#5193)
- `RaycastPerceptionSensor` now caches its raycast results; they can be accessed via `RayPerceptionSensor.RayPerceptionOutput`. (#5222)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Some console output have been moved from `info` to `debug` and will not be printed by default. If you want all messages to be printed, you can run `mlagents-learn` with the `--debug` option or add the line `debug: true` at the top of the yaml config file. (#5211)

141
com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs


/// </summary>
public GameObject HitGameObject;
/// <summary>
/// Start position of the ray in world space.
/// </summary>
public Vector3 StartPositionWorld;
/// <summary>
/// End position of the ray in world space.
/// </summary>
public Vector3 EndPositionWorld;
/// <summary>
/// The scaled length of the ray.
/// </summary>
/// <remarks>
/// If there is non-(1,1,1) scale, |EndPositionWorld - StartPositionWorld| will be different from
/// the input rayLength.
/// </remarks>
public float ScaledRayLength
{
get
{
var rayDirection = EndPositionWorld - StartPositionWorld;
return rayDirection.magnitude;
}
}
/// <summary>
/// The scaled size of the cast.
/// </summary>
/// <remarks>
/// If there is non-(1,1,1) scale, the cast radius will be also be scaled.
/// </remarks>
public float ScaledCastRadius;
/// <summary>
/// Writes the ray output information to a subset of the float array. Each element in the rayAngles array

}
/// <summary>
/// Debug information for the raycast hits. This is used by the RayPerceptionSensorComponent.
/// </summary>
internal class DebugDisplayInfo
{
public struct RayInfo
{
public Vector3 worldStart;
public Vector3 worldEnd;
public float castRadius;
public RayPerceptionOutput.RayOutput rayOutput;
}
public void Reset()
{
m_Frame = Time.frameCount;
}
/// <summary>
/// "Age" of the results in number of frames. This is used to adjust the alpha when drawing.
/// </summary>
public int age
{
get { return Time.frameCount - m_Frame; }
}
public RayInfo[] rayInfos;
int m_Frame;
}
/// <summary>
/// A sensor implementation that supports ray cast-based observations.
/// </summary>
public class RayPerceptionSensor : ISensor, IBuiltInSensor

string m_Name;
RayPerceptionInput m_RayPerceptionInput;
RayPerceptionOutput m_RayPerceptionOutput;
DebugDisplayInfo m_DebugDisplayInfo;
/// <summary>
/// Time.frameCount at the last time Update() was called. This is only used for display in gizmos.
/// </summary>
int m_DebugLastFrameCount;
internal DebugDisplayInfo debugDisplayInfo
internal int DebugLastFrameCount
get { return m_DebugDisplayInfo; }
get { return m_DebugLastFrameCount; }
}
/// <summary>

SetNumObservations(rayInput.OutputSize());
if (Application.isEditor)
{
m_DebugDisplayInfo = new DebugDisplayInfo();
}
m_DebugLastFrameCount = Time.frameCount;
m_RayPerceptionOutput = new RayPerceptionOutput();
}
/// <summary>
/// The most recent raycast results.
/// </summary>
public RayPerceptionOutput RayPerceptionOutput
{
get { return m_RayPerceptionOutput; }
}
void SetNumObservations(int numObservations)

using (TimerStack.Instance.Scoped("RayPerceptionSensor.Perceive"))
{
Array.Clear(m_Observations, 0, m_Observations.Length);
if (m_DebugDisplayInfo != null)
{
// Reset the age information, and resize the buffer if needed.
m_DebugDisplayInfo.Reset();
if (m_DebugDisplayInfo.rayInfos == null || m_DebugDisplayInfo.rayInfos.Length != numRays)
{
m_DebugDisplayInfo.rayInfos = new DebugDisplayInfo.RayInfo[numRays];
}
}
// For each ray, do the casting, and write the information to the observation buffer
// For each ray, write the information to the observation buffer
DebugDisplayInfo.RayInfo debugRay;
var rayOutput = PerceiveSingleRay(m_RayPerceptionInput, rayIndex, out debugRay);
if (m_DebugDisplayInfo != null)
{
m_DebugDisplayInfo.rayInfos[rayIndex] = debugRay;
}
rayOutput.ToFloatArray(numDetectableTags, rayIndex, m_Observations);
m_RayPerceptionOutput.RayOutputs[rayIndex].ToFloatArray(numDetectableTags, rayIndex, m_Observations);
// Finally, add the observations to the ObservationWriter
writer.AddList(m_Observations);
}

/// <inheritdoc/>
public void Update()
{
m_DebugLastFrameCount = Time.frameCount;
var numRays = m_RayPerceptionInput.Angles.Count;
if (m_RayPerceptionOutput.RayOutputs == null || m_RayPerceptionOutput.RayOutputs.Length != numRays)
{
m_RayPerceptionOutput.RayOutputs = new RayPerceptionOutput.RayOutput[numRays];
}
// For each ray, do the casting and save the results.
for (var rayIndex = 0; rayIndex < numRays; rayIndex++)
{
m_RayPerceptionOutput.RayOutputs[rayIndex] = PerceiveSingleRay(m_RayPerceptionInput, rayIndex);
}
}
/// <inheritdoc/>

for (var rayIndex = 0; rayIndex < input.Angles.Count; rayIndex++)
{
DebugDisplayInfo.RayInfo debugRay;
output.RayOutputs[rayIndex] = PerceiveSingleRay(input, rayIndex, out debugRay);
output.RayOutputs[rayIndex] = PerceiveSingleRay(input, rayIndex);
}
return output;

/// </summary>
/// <param name="input"></param>
/// <param name="rayIndex"></param>
/// <param name="debugRayOut"></param>
int rayIndex,
out DebugDisplayInfo.RayInfo debugRayOut
int rayIndex
)
{
var unscaledRayLength = input.RayLength;

HitFraction = hitFraction,
HitTaggedObject = false,
HitTagIndex = -1,
HitGameObject = hitObject
HitGameObject = hitObject,
StartPositionWorld = startPositionWorld,
EndPositionWorld = endPositionWorld,
ScaledCastRadius = scaledCastRadius
};
if (castHit)

}
}
debugRayOut.worldStart = startPositionWorld;
debugRayOut.worldEnd = endPositionWorld;
debugRayOut.rayOutput = rayOutput;
debugRayOut.castRadius = scaledCastRadius;
return rayOutput;
}

35
com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensorComponentBase.cs


}
}
internal int SensorObservationAge()
{
if (m_RaySensor != null)
{
return Time.frameCount - m_RaySensor.DebugLastFrameCount;
}
return 0;
}
if (m_RaySensor?.debugDisplayInfo?.rayInfos != null)
if (m_RaySensor?.RayPerceptionOutput?.RayOutputs != null)
var alpha = Mathf.Pow(.5f, m_RaySensor.debugDisplayInfo.age);
var alpha = Mathf.Pow(.5f, SensorObservationAge());
foreach (var rayInfo in m_RaySensor.debugDisplayInfo.rayInfos)
foreach (var rayInfo in m_RaySensor.RayPerceptionOutput.RayOutputs)
{
DrawRaycastGizmos(rayInfo, alpha);
}

rayInput.DetectableTags = null;
for (var rayIndex = 0; rayIndex < rayInput.Angles.Count; rayIndex++)
{
DebugDisplayInfo.RayInfo debugRay;
RayPerceptionSensor.PerceiveSingleRay(rayInput, rayIndex, out debugRay);
DrawRaycastGizmos(debugRay);
var rayOutput = RayPerceptionSensor.PerceiveSingleRay(rayInput, rayIndex);
DrawRaycastGizmos(rayOutput);
}
}
}

/// </summary>
void DrawRaycastGizmos(DebugDisplayInfo.RayInfo rayInfo, float alpha = 1.0f)
void DrawRaycastGizmos(RayPerceptionOutput.RayOutput rayOutput, float alpha = 1.0f)
var startPositionWorld = rayInfo.worldStart;
var endPositionWorld = rayInfo.worldEnd;
var startPositionWorld = rayOutput.StartPositionWorld;
var endPositionWorld = rayOutput.EndPositionWorld;
rayDirection *= rayInfo.rayOutput.HitFraction;
rayDirection *= rayOutput.HitFraction;
var lerpT = rayInfo.rayOutput.HitFraction * rayInfo.rayOutput.HitFraction;
var lerpT = rayOutput.HitFraction * rayOutput.HitFraction;
var color = Color.Lerp(rayHitColor, rayMissColor, lerpT);
color.a *= alpha;
Gizmos.color = color;

if (rayInfo.rayOutput.HasHit)
if (rayOutput.HasHit)
var hitRadius = Mathf.Max(rayInfo.castRadius, .05f);
var hitRadius = Mathf.Max(rayOutput.ScaledCastRadius, .05f);
Gizmos.DrawWireSphere(startPositionWorld + rayDirection, hitRadius);
}
}

5
com.unity.ml-agents/Tests/Editor/PublicAPI/PublicApiValidation.cs


sensorComponent.ObservationStacks = 2;
sensorComponent.CreateSensors();
var sensor = sensorComponent.RaySensor;
sensor.Update();
var outputs = sensor.RayPerceptionOutput;
Assert.AreEqual(outputs.RayOutputs.Length, 2*sensorComponent.RaysPerDirection + 1);
}
#endif
}

5
com.unity.ml-agents/Tests/Runtime/Sensor/RayPerceptionSensorTests.cs


{
perception.SphereCastRadius = castRadius;
var sensor = perception.CreateSensors()[0];
sensor.Update();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);

perception.DetectableTags.Add(k_SphereTag);
var sensor = perception.CreateSensors()[0];
sensor.Update();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
var outputBuffer = new float[expectedObs];

perception.RayLayerMask = layerMask;
var sensor = perception.CreateSensors()[0];
sensor.Update();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
var outputBuffer = new float[expectedObs];

{
perception.SphereCastRadius = castRadius;
var sensor = perception.CreateSensors()[0];
sensor.Update();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);

// Set the layer mask to either the default, or one that ignores the close cube's layer
var sensor = perception.CreateSensors()[0];
sensor.Update();
var expectedObs = (2 * perception.RaysPerDirection + 1) * (perception.DetectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationSpec().Shape[0], expectedObs);
var outputBuffer = new float[expectedObs];

正在加载...
取消
保存