#if MLA_INPUT_SYSTEM
using System;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Policies;
using UnityEngine.InputSystem;
using UnityEngine.Profiling;
namespace Unity.MLAgents.Extensions.Input
{
/// <summary>
/// An <see cref="IActuator"/> that forwards actions coming from the ML-Agents training process, or from
/// neural-network inference, into the Input System. It does this through an <see cref="IRLActionInputAdaptor"/>,
/// which writes into an input event.
/// If the <see cref="BehaviorParameters"/> report that the agent is in heuristic mode, no event is written.
/// Instead, <see cref="Heuristic"/> asks the adaptor to copy the current state of the
/// <see cref="InputAction"/> into the outgoing <see cref="ActionBuffers"/>.
/// </summary>
public class InputActionActuator : IActuator, IBuiltInActuator
{
    readonly BehaviorParameters m_BehaviorParameters;
    readonly InputAction m_Action;
    readonly IRLActionInputAdaptor m_InputAdaptor;
    readonly InputActuatorEventContext m_InputActuatorEventContext;
    readonly InputDevice m_Device;
    // Child control of m_Device whose path equals the action's name; null when no device was supplied.
    readonly InputControl m_Control;

    /// <summary>
    /// Construct an <see cref="InputActionActuator"/> for a single <see cref="InputAction"/>.
    /// </summary>
    /// <param name="inputDevice">The input device this action is bound to. May be null, in which case
    /// no child control is resolved for the action.</param>
    /// <param name="behaviorParameters">Used to determine whether the agent is running in heuristic mode.</param>
    /// <param name="action">The <see cref="InputAction"/> this actuator reads data from and writes data to.</param>
    /// <param name="adaptor">Converts data between ML-Agents and the Input System, and supplies the
    /// <see cref="ActionSpec"/> for <paramref name="action"/>.</param>
    /// <param name="inputActuatorEventContext">Provides the event pointer that actions are written into.</param>
    /// <exception cref="ArgumentNullException"><paramref name="action"/> or <paramref name="adaptor"/> is null.</exception>
    public InputActionActuator(InputDevice inputDevice, BehaviorParameters behaviorParameters,
        InputAction action,
        IRLActionInputAdaptor adaptor,
        InputActuatorEventContext inputActuatorEventContext)
    {
        // Both are dereferenced immediately below; report the bad argument instead of a bare NullReferenceException.
        if (action == null)
        {
            throw new ArgumentNullException(nameof(action));
        }
        if (adaptor == null)
        {
            throw new ArgumentNullException(nameof(adaptor));
        }

        m_BehaviorParameters = behaviorParameters;
        Name = $"InputActionActuator-{action.name}";
        m_Action = action;
        m_InputAdaptor = adaptor;
        m_InputActuatorEventContext = inputActuatorEventContext;
        ActionSpec = adaptor.GetActionSpecForInputAction(m_Action);
        m_Device = inputDevice;
        // The control on the device is looked up by the action's name.
        m_Control = m_Device?.GetChildControl(m_Action.name);
    }

    /// <inheritdoc/>
    public void OnActionReceived(ActionBuffers actionBuffers)
    {
        Profiler.BeginSample("InputActionActuator.OnActionReceived");
        try
        {
            // In heuristic mode, actions are produced by Heuristic() from the InputAction,
            // so nothing is written into an input event here.
            if (!m_BehaviorParameters.IsInHeuristicMode())
            {
                // The returned handle is disposed once the adaptor has finished writing into the event.
                using (m_InputActuatorEventContext.GetEventForFrame(out var eventPtr))
                {
                    m_InputAdaptor.WriteToInputEventForAction(eventPtr, m_Action, m_Control, ActionSpec, actionBuffers);
                }
            }
        }
        finally
        {
            // Close the sample even if the adaptor throws, so the profiler sample stack stays balanced.
            Profiler.EndSample();
        }
    }

    /// <inheritdoc/>
    public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)
    {
        // TODO configure mask from editor UI?
    }

    /// <inheritdoc/>
    public ActionSpec ActionSpec { get; }

    /// <inheritdoc/>
    public string Name { get; }

    /// <inheritdoc/>
    public void ResetData()
    {
        // do nothing for now
    }

    /// <inheritdoc/>
    public void Heuristic(in ActionBuffers actionBuffersOut)
    {
        Profiler.BeginSample("InputActionActuator.Heuristic");
        try
        {
            m_InputAdaptor.WriteToHeuristic(m_Action, actionBuffersOut);
        }
        finally
        {
            // Close the sample even if the adaptor throws, so the profiler sample stack stays balanced.
            Profiler.EndSample();
        }
    }

    /// <inheritdoc/>
    public BuiltInActuatorType GetBuiltInActuatorType() => BuiltInActuatorType.InputActionActuator;
}
}
#endif // MLA_INPUT_SYSTEM