using UnityEngine;
using UnityEngine.Serialization;
using Unity.MLAgents.Sensors;

/// <summary>
/// Test component that exposes a fixed Texture2D as a visual observation,
/// optionally wrapped in a StackingSensor so that several consecutive frames
/// are reported together.
/// </summary>
public class TestTextureSensorComponent : SensorComponent
{
    TestTextureSensor m_Sensor;

    // Texture whose pixels are reported as the observation.
    public Texture2D TestTexture;

    string m_SensorName = "TextureSensor";

    /// <summary>
    /// Name of the generated sensor.
    /// </summary>
    public string SensorName
    {
        get { return m_SensorName; }
        set { m_SensorName = value; }
    }

    // Number of consecutive observations to stack; values <= 1 disable stacking.
    public int ObservationStacks = 4;

    // Compression type passed to TestTextureSensor (PNG by default).
    public SensorCompressionType CompressionType = SensorCompressionType.PNG;

    /// <inheritdoc/>
    public override ISensor CreateSensor()
    {
        m_Sensor = new TestTextureSensor(TestTexture, SensorName, CompressionType);
        // Wrap in a StackingSensor only when more than one frame is requested.
        // Using "> 1" (rather than "!= 1") keeps this consistent with
        // GetObservationShape(), which treats any value <= 1 as "no stacking".
        if (ObservationStacks > 1)
        {
            return new StackingSensor(m_Sensor, ObservationStacks);
        }
        return m_Sensor;
    }

    /// <inheritdoc/>
    public override int[] GetObservationShape()
    {
        var width = TestTexture.width;
        var height = TestTexture.height;
        // Shape is (height, width, channels), with 3 channels per frame.
        var observationShape = new[] { height, width, 3 };

        // Stacked frames are concatenated along the channel dimension.
        var stacks = ObservationStacks > 1 ? ObservationStacks : 1;
        if (stacks > 1)
        {
            observationShape[2] *= stacks;
        }

        return observationShape;
    }
}