using UnityEngine; using UnityEngine.Serialization; namespace Unity.MLAgents.Sensors { public class CustomRayPerceptionOutput : RayPerceptionOutput { public CustomRayPerceptionOutput() { CustomObservationSizePerRay = 3; } public override void GetCustomObservationData(RayOutput rayOutput, float[] buffer) { var rb = rayOutput.HitGameObject.GetComponent(); if (rb != null) { buffer[0] = rb.velocity.x; buffer[1] = rb.velocity.z; buffer[2] = rb.mass; } } } /// /// A component for Custom 3D Ray Perception. /// [AddComponentMenu("ML Agents/Custom Ray Perception Sensor 3D", (int)MenuGroup.Sensors)] public class RayPerceptionSensorComponentCustom3D : RayPerceptionSensorComponent3D { public override ISensor[] CreateSensors() { var rayPerceptionInput = GetRayPerceptionInput(); var rayPerceptionOutput = new CustomRayPerceptionOutput(); m_RaySensor = new RayPerceptionSensor(SensorName, rayPerceptionInput, rayPerceptionOutput); if (ObservationStacks != 1) { var stackingSensor = new StackingSensor(m_RaySensor, ObservationStacks); return new ISensor[] { stackingSensor }; } return new ISensor[] { m_RaySensor }; } } }