Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

141 行
4.4 KiB

using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Tests
{
public static class TestGridSensorConfig
{
public static int ObservationSize;
public static bool IsNormalized;
public static bool ParseAllColliders;
public static void SetParameters(int observationSize, bool isNormalized, bool parseAllColliders)
{
ObservationSize = observationSize;
IsNormalized = isNormalized;
ParseAllColliders = parseAllColliders;
}
public static void Reset()
{
ObservationSize = 0;
IsNormalized = false;
ParseAllColliders = false;
}
}
public class SimpleTestGridSensor : GridSensorBase
{
public float[] DummyData;
public SimpleTestGridSensor(
string name,
Vector3 cellScale,
Vector3Int gridSize,
string[] detectableTags,
SensorCompressionType compression
) : base(
name,
cellScale,
gridSize,
detectableTags,
compression)
{ }
protected override int GetCellObservationSize()
{
return TestGridSensorConfig.ObservationSize;
}
protected override bool IsDataNormalized()
{
return TestGridSensorConfig.IsNormalized;
}
protected internal override ProcessCollidersMethod GetProcessCollidersMethod()
{
return TestGridSensorConfig.ParseAllColliders ? ProcessCollidersMethod.ProcessAllColliders : ProcessCollidersMethod.ProcessClosestColliders;
}
protected override void GetObjectData(GameObject detectedObject, int typeIndex, float[] dataBuffer)
{
for (var i = 0; i < DummyData.Length; i++)
{
dataBuffer[i] = DummyData[i];
}
}
}
public class SimpleTestGridSensorComponent : GridSensorComponent
{
bool m_UseOneHotTag;
bool m_UseTestingGridSensor;
bool m_UseGridSensorBase;
protected override GridSensorBase[] GetGridSensors()
{
List<GridSensorBase> sensorList = new List<GridSensorBase>();
if (m_UseOneHotTag)
{
var testSensor = new OneHotGridSensor(
SensorName,
CellScale,
GridSize,
DetectableTags,
CompressionType
);
sensorList.Add(testSensor);
}
if (m_UseGridSensorBase)
{
var testSensor = new GridSensorBase(
SensorName,
CellScale,
GridSize,
DetectableTags,
CompressionType
);
sensorList.Add(testSensor);
}
if (m_UseTestingGridSensor)
{
var testSensor = new SimpleTestGridSensor(
SensorName,
CellScale,
GridSize,
DetectableTags,
CompressionType
);
sensorList.Add(testSensor);
}
return sensorList.ToArray();
}
public void SetComponentParameters(
string[] detectableTags = null,
float cellScaleX = 1f,
float cellScaleZ = 1f,
int gridSizeX = 10,
int gridSizeY = 1,
int gridSizeZ = 10,
int colliderMaskInt = -1,
SensorCompressionType compression = SensorCompressionType.None,
bool rotateWithAgent = false,
bool useOneHotTag = false,
bool useTestingGridSensor = false,
bool useGridSensorBase = false
)
{
DetectableTags = detectableTags;
CellScale = new Vector3(cellScaleX, 0.01f, cellScaleZ);
GridSize = new Vector3Int(gridSizeX, gridSizeY, gridSizeZ);
ColliderMask = colliderMaskInt < 0 ? LayerMask.GetMask("Default") : colliderMaskInt;
RotateWithAgent = rotateWithAgent;
CompressionType = compression;
m_UseOneHotTag = useOneHotTag;
m_UseGridSensorBase = useGridSensorBase;
m_UseTestingGridSensor = useTestingGridSensor;
}
}
}