浏览代码

Make RayPerception sensor work better with transforms that have scale (#3334)

* handle non-1 scale, handle 0 length

* draw scaled spheres correctly, add tests
/asymm-envs
GitHub 5 年前
当前提交
b12b906f
共有 3 个文件被更改,包括 111 次插入17 次删除
  1. 40
      com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensor.cs
  2. 4
      com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponentBase.cs
  3. 84
      com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs

40
com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensor.cs


public Vector3 worldEnd;
public bool castHit;
public float hitFraction;
public float castRadius;
}
public void Reset()

/// nothing was hit.
///
/// </summary>
/// <param name="rayLength"></param>
/// <param name="unscaledRayLength"></param>
/// <param name="castRadius">Radius of the sphere to use for spherecasting. If 0 or less, rays are used
/// <param name="unscaledCastRadius">Radius of the sphere to use for spherecasting. If 0 or less, rays are used
/// instead - this may be faster, especially for complex environments.</param>
/// <param name="transform">Transform of the GameObject</param>
/// <param name="castType">Whether to perform the casts in 2D or 3D.</param>

public static void PerceiveStatic(float rayLength,
public static void PerceiveStatic(float unscaledRayLength,
float startOffset, float endOffset, float castRadius,
float startOffset, float endOffset, float unscaledCastRadius,
Transform transform, CastType castType, float[] perceptionBuffer,
int layerMask = Physics.DefaultRaycastLayers,
DebugDisplayInfo debugInfo = null)

if (castType == CastType.Cast3D)
{
startPositionLocal = new Vector3(0, startOffset, 0);
endPositionLocal = PolarToCartesian3D(rayLength, angle);
endPositionLocal = PolarToCartesian3D(unscaledRayLength, angle);
endPositionLocal.y += endOffset;
}
else

endPositionLocal = PolarToCartesian2D(rayLength, angle);
endPositionLocal = PolarToCartesian2D(unscaledRayLength, angle);
}
var startPositionWorld = transform.TransformPoint(startPositionLocal);

// If there is non-unity scale, |rayDirection| will be different from rayLength.
// We want to use this transformed ray length for determining cast length, hit fraction etc.
// We also it to scale up or down the sphere or circle radii
var scaledRayLength = rayDirection.magnitude;
// Avoid 0/0 if unscaledRayLength is 0
var scaledCastRadius = unscaledRayLength > 0 ? unscaledCastRadius * scaledRayLength / unscaledRayLength : unscaledCastRadius;
// Do the cast and assign the hit information for each detectable object.
// sublist[0 ] <- did hit detectableObjects[0]

if (castType == CastType.Cast3D)
{
RaycastHit rayHit;
if (castRadius > 0f)
if (scaledCastRadius > 0f)
castHit = Physics.SphereCast(startPositionWorld, castRadius, rayDirection, out rayHit,
rayLength, layerMask);
castHit = Physics.SphereCast(startPositionWorld, scaledCastRadius, rayDirection, out rayHit,
scaledRayLength, layerMask);
rayLength, layerMask);
scaledRayLength, layerMask);
hitFraction = castHit ? rayHit.distance / rayLength : 1.0f;
// If scaledRayLength is 0, we still could have a hit with sphere casts (maybe?).
// To avoid 0/0, set the fraction to 0.
hitFraction = castHit ? (scaledRayLength > 0 ? rayHit.distance / scaledRayLength : 0.0f) : 1.0f;
if (castRadius > 0f)
if (scaledCastRadius > 0f)
rayHit = Physics2D.CircleCast(startPositionWorld, castRadius, rayDirection,
rayLength, layerMask);
rayHit = Physics2D.CircleCast(startPositionWorld, scaledCastRadius, rayDirection,
scaledRayLength, layerMask);
rayHit = Physics2D.Raycast(startPositionWorld, rayDirection, rayLength, layerMask);
rayHit = Physics2D.Raycast(startPositionWorld, rayDirection, scaledRayLength, layerMask);
}
castHit = rayHit;

debugInfo.rayInfos[rayIndex].worldEnd = endPositionWorld;
debugInfo.rayInfos[rayIndex].castHit = castHit;
debugInfo.rayInfos[rayIndex].hitFraction = hitFraction;
debugInfo.rayInfos[rayIndex].castRadius = scaledCastRadius;
}
else if (Application.isEditor)
{

4
com.unity.ml-agents/Runtime/Sensor/RayPerceptionSensorComponentBase.cs


// hit fraction ^2 will shift "far" hits closer to the hit color
var lerpT = rayInfo.hitFraction * rayInfo.hitFraction;
var color = Color.Lerp(rayHitColor, rayMissColor, lerpT);
color.a = alpha;
color.a *= alpha;
Gizmos.color = color;
Gizmos.DrawRay(startPositionWorld, rayDirection);

var hitRadius = Mathf.Max(sphereCastRadius, .05f);
var hitRadius = Mathf.Max(rayInfo.castRadius, .05f);
Gizmos.DrawWireSphere(startPositionWorld + rayDirection, hitRadius);
}
}

84
com.unity.ml-agents/Tests/Editor/Sensor/RayPerceptionSensorTests.cs


}
}
}
[Test]
public void TestRaycastsScaled()
{
SetupScene();
var obj = new GameObject("agent");
var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();
obj.transform.localScale = new Vector3(2, 2,2 );
perception.raysPerDirection = 0;
perception.maxRayDegrees = 45;
perception.rayLength = 20;
perception.detectableTags = new List<string>();
perception.detectableTags.Add(k_CubeTag);
var radii = new[] { 0f, .5f };
foreach (var castRadius in radii)
{
perception.sphereCastRadius = castRadius;
var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.raysPerDirection + 1) * (perception.detectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
var outputBuffer = new float[expectedObs];
WriteAdapter writer = new WriteAdapter();
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);
// Expected hits:
// ray 0 should hit the cube at roughly 1/4 way
//
Assert.AreEqual(1.0f, outputBuffer[0]); // hit cube
Assert.AreEqual(0.0f, outputBuffer[1]); // missed unknown tag
// Hit is at z=9.0 in world space, ray length was 20
// But scale increases the cast size and the ray length
var scaledRayLength = 2 * perception.rayLength;
var scaledCastRadius = 2 * castRadius;
Assert.That(
outputBuffer[2], Is.EqualTo((9.5f - scaledCastRadius) / scaledRayLength).Within(.0005f)
);
}
}
[Test]
public void TestRayZeroLength()
{
// Place the cube touching the origin
var cube = GameObject.CreatePrimitive(PrimitiveType.Cube);
cube.transform.position = new Vector3(0, 0, .5f);
cube.tag = k_CubeTag;
Physics.SyncTransforms();
var obj = new GameObject("agent");
var perception = obj.AddComponent<RayPerceptionSensorComponent3D>();
perception.raysPerDirection = 0;
perception.rayLength = 0.0f;
perception.sphereCastRadius = .5f;
perception.detectableTags = new List<string>();
perception.detectableTags.Add(k_CubeTag);
{
// Set the layer mask to either the default, or one that ignores the close cube's layer
var sensor = perception.CreateSensor();
var expectedObs = (2 * perception.raysPerDirection + 1) * (perception.detectableTags.Count + 2);
Assert.AreEqual(sensor.GetObservationShape()[0], expectedObs);
var outputBuffer = new float[expectedObs];
WriteAdapter writer = new WriteAdapter();
writer.SetTarget(outputBuffer, sensor.GetObservationShape(), 0);
var numWritten = sensor.Write(writer);
Assert.AreEqual(numWritten, expectedObs);
// hit fraction is arbitrary but should be finite in [0,1]
Assert.GreaterOrEqual(outputBuffer[2], 0.0f);
Assert.LessOrEqual(outputBuffer[2], 1.0f);
}
}
}
}
正在加载...
取消
保存