浏览代码

addressing comments

/MLA-1734-demo-provider
vincentpierre 4 年前
当前提交
b7892849
共有 3 个文件被更改,包括 35 次插入23 次删除
  1. 26
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  2. 31
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  3. 1
      com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs

26
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


for (var sensorIndex = 0; sensorIndex < sensorComponents.Length; sensorIndex++)
{
var sensor = sensorComponents[sensorIndex];
if (sensor.GetObservationShape().Length != 3)
if (sensor.GetObservationShape().Length == 3)
continue;
if (!tensorsNames.Contains(
TensorNames.VisualObservationPlaceholderPrefix + visObsIndex))
{
failedModelChecks.Add(
"The model does not contain a Visual Observation Placeholder Input " +
$"for sensor component {visObsIndex} ({sensor.GetType().Name}).");
}
visObsIndex++;
if (!tensorsNames.Contains(
TensorNames.VisualObservationPlaceholderPrefix + visObsIndex))
if (sensor.GetObservationShape().Length == 2)
failedModelChecks.Add(
"The model does not contain a Visual Observation Placeholder Input " +
$"for sensor component {visObsIndex} ({sensor.GetType().Name}).");
if (!tensorsNames.Contains(
TensorNames.ObservationPlaceholderPrefix + sensorIndex))
{
failedModelChecks.Add(
"The model does not contain an Observation Placeholder Input " +
$"for sensor component {sensorIndex} ({sensor.GetType().Name}).");
}
visObsIndex++;
visObsIndex++;
}
var expectedVisualObs = model.GetNumVisualInputs();

31
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


{
var sensor = sensors[sensorIndex];
var shape = sensor.GetObservationShape();
// TODO generalize - we currently only have vector or visual, but can't handle "2D" observations
ObservationGenerator obsGen = null;
string obsGenName = null;
switch (rank)
{
case 1:

}
vecObsGen.AddSensorIndex(sensorIndex);
obsGen = vecObsGen;
obsGenName = TensorNames.VectorObservationPlaceholder;
var gen = new ObservationGenerator(allocator);
gen.AddSensorIndex(sensorIndex);
m_Dict[TensorNames.ObservationPlaceholderPrefix + sensorIndex] = gen;
// If the tensor is of rank 2, we use the index of the sensor
// to create the name
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.ObservationPlaceholderPrefix + sensorIndex;
var visgen = new ObservationGenerator(allocator);
visgen.AddSensorIndex(sensorIndex);
m_Dict[TensorNames.VisualObservationPlaceholderPrefix + visIndex] = visgen;
// If the tensor is of rank 3, we use the "visual observation
// index", which only counts the rank 3 sensors
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.VisualObservationPlaceholderPrefix + visIndex;
break;
throw new UnityAgentsException(
$"Sensor {sensor.GetName()} have an invalid rank {rank}");
}
if (vecObsGen != null)
{
m_Dict[TensorNames.VectorObservationPlaceholder] = vecObsGen;
obsGen.AddSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;
}
}

1
com.unity.ml-agents/Runtime/Sensors/SensorComponent.cs


/// Whether the observation is vector or not.
/// </summary>
/// <returns>True if the observation is vector, false otherwise.</returns>
[Obsolete("IsVisual is deprecated, please use GetObservationShape() instead.")]
public virtual bool IsVector()
{
var shape = GetObservationShape();

正在加载...
取消
保存