浏览代码

add virtual method for custom observations

/develop/custom-raycast
Ruo-Ping Dong 3 年前
当前提交
dc1f22e3
共有 1 个文件被更改,包括 51 次插入9 次删除
  1. 60
      com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs

60
com.unity.ml-agents/Runtime/Sensors/RayPerceptionSensor.cs


public class RayPerceptionSensor : ISensor, IBuiltInSensor
{
float[] m_Observations;
float[] m_SingleRayObservations;
ObservationSpec m_ObservationSpec;
string m_Name;

m_Name = name;
m_RayPerceptionInput = rayInput;
SetNumObservations(rayInput.OutputSize());
SetNumObservations(GetObservationSizePerRay(), GetNumberOfRays());
m_DebugLastFrameCount = Time.frameCount;
m_RayPerceptionOutput = new RayPerceptionOutput();

/// The ray perception input configurations.
/// </summary>
public RayPerceptionInput RayPerceptionInput
{
get { return m_RayPerceptionInput; }
}
/// <summary>
/// The most recent raycast results.
/// </summary>
public RayPerceptionOutput RayPerceptionOutput

void SetNumObservations(int numObservations)
/// <summary>
/// The observation size per ray.
/// Override this method for custom observations.
/// </summary>
public virtual int GetObservationSizePerRay()
m_ObservationSpec = ObservationSpec.Vector(numObservations);
m_Observations = new float[numObservations];
return RayPerceptionInput.DetectableTags.Count + 2;
}
/// <summary>
/// The number of rays in the sensor.
/// Override this method for custom observations.
/// </summary>
public int GetNumberOfRays()
{
return RayPerceptionInput.Angles.Count;
}
void SetNumObservations(int observationsSize, int numRays)
{
m_ObservationSpec = ObservationSpec.Vector(observationsSize * numRays);
m_Observations = new float[observationsSize * numRays];
m_SingleRayObservations = new float[observationsSize];
}
internal void SetRayPerceptionInput(RayPerceptionInput rayInput)

if (m_RayPerceptionInput.OutputSize() != rayInput.OutputSize())
var oldObservationSize = GetObservationSizePerRay();
m_RayPerceptionInput = rayInput;
if (GetObservationSizePerRay() != oldObservationSize)
{
Debug.Log(
"Changing the number of tags or rays at runtime is not " +

// keep this consistent.
SetNumObservations(rayInput.OutputSize());
SetNumObservations(GetObservationSizePerRay(), GetNumberOfRays());
m_RayPerceptionInput = rayInput;
}
public virtual void RayOutputToArray(RayPerceptionOutput.RayOutput rayOutput, int rayIndex, float[] buffer)
{
if (rayOutput.HitTaggedObject)
{
buffer[rayOutput.HitTagIndex] = 1f;
}
var numDetectableTags = RayPerceptionInput.DetectableTags.Count;
buffer[numDetectableTags] = rayOutput.HasHit ? 0f : 1f;
buffer[numDetectableTags + 1] = rayOutput.HitFraction;
}
/// <summary>

{
Array.Clear(m_Observations, 0, m_Observations.Length);
var numRays = m_RayPerceptionInput.Angles.Count;
var numDetectableTags = m_RayPerceptionInput.DetectableTags.Count;
var rayObservartionSize = GetObservationSizePerRay();
m_RayPerceptionOutput.RayOutputs?[rayIndex].ToFloatArray(numDetectableTags, rayIndex, m_Observations);
Array.Clear(m_SingleRayObservations, 0, rayObservartionSize);
RayOutputToArray(m_RayPerceptionOutput.RayOutputs[rayIndex], rayIndex, m_SingleRayObservations);
Array.Copy(m_SingleRayObservations, 0, m_Observations, rayObservartionSize * rayIndex, rayObservartionSize);
}
// Finally, add the observations to the ObservationWriter

正在加载...
取消
保存