浏览代码

Fix stacked grid sensor (#5335)

/colab-links
GitHub 4 年前
当前提交
5b234d2e
共有 2 个文件被更改,包括 33 次插入18 次删除
  1. 38
      com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs
  2. 13
      com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs

38
com.unity.ml-agents/Runtime/Sensors/GridSensorComponent.cs


using System.Collections.Generic;
using System.Linq;
using UnityEngine;
namespace Unity.MLAgents.Sensors

{
// dummy sensor only used for debug gizmo
GridSensorBase m_DebugSensor;
List<ISensor> m_Sensors;
List<GridSensorBase> m_Sensors;
internal BoxOverlapChecker m_BoxOverlapChecker;
[HideInInspector, SerializeField]

/// <inheritdoc/>
public override ISensor[] CreateSensors()
{
m_Sensors = new List<ISensor>();
m_BoxOverlapChecker = new BoxOverlapChecker(
m_CellScale,
m_GridSize,

m_DebugSensor = new GridSensorBase("DebugGridSensor", m_CellScale, m_GridSize, m_DetectableTags, SensorCompressionType.None);
m_BoxOverlapChecker.RegisterDebugSensor(m_DebugSensor);
var gridSensors = GetGridSensors();
if (gridSensors == null || gridSensors.Length < 1)
m_Sensors = GetGridSensors().ToList();
if (m_Sensors == null || m_Sensors.Count < 1)
foreach (var sensor in gridSensors)
// Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once
m_Sensors[0].m_BoxOverlapChecker = m_BoxOverlapChecker;
foreach (var sensor in m_Sensors)
if (ObservationStacks != 1)
m_BoxOverlapChecker.RegisterSensor(sensor);
}
if (ObservationStacks != 1)
{
var sensors = new ISensor[m_Sensors.Count];
for (var i = 0; i < m_Sensors.Count; i++)
m_Sensors.Add(new StackingSensor(sensor, ObservationStacks));
sensors[i] = new StackingSensor(m_Sensors[i], ObservationStacks);
else
{
m_Sensors.Add(sensor);
}
m_BoxOverlapChecker.RegisterSensor(sensor);
return sensors;
}
else
{
return m_Sensors.ToArray();
// Only one sensor needs to reference the boxOverlapChecker, so that it gets updated exactly once
((GridSensorBase)m_Sensors[0]).m_BoxOverlapChecker = m_BoxOverlapChecker;
return m_Sensors.ToArray();
}
/// <summary>

m_BoxOverlapChecker.ColliderMask = m_ColliderMask;
foreach (var sensor in m_Sensors)
{
((GridSensorBase)sensor).CompressionType = m_CompressionType;
sensor.CompressionType = m_CompressionType;
}
}
}

13
com.unity.ml-agents/Tests/Runtime/Sensor/GridSensorTests.cs


gridSensorComponent.SetComponentParameters(tags, useGridSensorBase: true);
gridSensorComponent.CreateSensors();
var componentSensor = (List<ISensor>)typeof(GridSensorComponent).GetField("m_Sensors",
var componentSensor = (List<GridSensorBase>)typeof(GridSensorComponent).GetField("m_Sensors",
BindingFlags.Instance | BindingFlags.NonPublic).GetValue(gridSensorComponent);
Assert.AreEqual(componentSensor.Count, 1);
}

{
gridSensorComponent.CreateSensors();
});
}
[Test]
public void TestStackedSensors()
{
testGo.tag = k_Tag2;
string[] tags = { k_Tag1, k_Tag2 };
gridSensorComponent.SetComponentParameters(tags, useGridSensorBase: true);
gridSensorComponent.ObservationStacks = 3;
var sensors = gridSensorComponent.CreateSensors();
Assert.IsInstanceOf(typeof(StackingSensor), sensors[0]);
}
}
}
正在加载...
取消
保存