|
|
|
|
|
|
// ModelRunner for C# training.
|
|
|
|
|
|
|
|
using System; |
|
|
|
using System.Collections.Generic; |
|
|
|
using Unity.Barracuda; |
|
|
|
using UnityEngine.Profiling; |
|
|
|
|
|
|
TensorGenerator m_TensorGenerator; |
|
|
|
TensorApplier m_TensorApplier; |
|
|
|
|
|
|
|
NNModel m_Model; |
|
|
|
Model m_Model; |
|
|
|
NNModel m_TargetModel; |
|
|
|
string m_ModelName; |
|
|
|
InferenceDevice m_InferenceDevice; |
|
|
|
|
|
|
m_InferenceOutputs = new List<TensorProxy>(); |
|
|
|
} |
|
|
|
|
|
|
|
public InferenceDevice InferenceDevice |
|
|
|
{ |
|
|
|
get { return m_InferenceDevice; } |
|
|
|
} |
|
|
|
|
|
|
|
public NNModel Model |
|
|
|
{ |
|
|
|
get { return m_Model; } |
|
|
|
} |
|
|
|
|
|
|
|
void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs) |
|
|
|
{ |
|
|
|
m_InputsByName.Clear(); |
|
|
|
|
|
|
|
|
|
|
public void PutObservations(AgentInfo info, List<ISensor> sensors) |
|
|
|
{ |
|
|
|
#if DEBUG
|
|
|
|
m_SensorShapeValidator.ValidateSensors(sensors); |
|
|
|
#endif
|
|
|
|
m_Infos.Add(new AgentInfoSensorsPair |
|
|
|
{ |
|
|
|
agentInfo = info, |
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
public void GetObservationTensors(IReadOnlyList<TensorProxy> tensors, AgentInfo info, List<ISensor> sensors) |
|
|
|
{ |
|
|
|
if (!m_ObservationsInitialized) |
|
|
|
{ |
|
|
|
// Just grab the first agent in the collection (any will suffice, really).
|
|
|
|
// We check for an empty Collection above, so this will always return successfully.
|
|
|
|
m_TensorGenerator.InitializeObservations(sensors, m_TensorAllocator); |
|
|
|
m_ObservationsInitialized = true; |
|
|
|
} |
|
|
|
var infoSensorPair = new AgentInfoSensorsPair |
|
|
|
{ |
|
|
|
agentInfo = info, |
|
|
|
sensors = sensors |
|
|
|
}; |
|
|
|
m_TensorGenerator.GenerateTensors(tensors, 1, new List<AgentInfoSensorsPair> { infoSensorPair }); |
|
|
|
} |
|
|
|
|
|
|
|
public IReadOnlyList<TensorProxy> GetInputTensors() |
|
|
|
{ |
|
|
|
return m_Model.GetInputTensors(); |
|
|
|
} |
|
|
|
|
|
|
|
public void DecideBatch() |
|
|
|
{ |
|
|
|
var currentBatchSize = m_Infos.Count; |
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
Profiler.BeginSample("ModelRunner.DecideAction"); |
|
|
|
Profiler.BeginSample(m_ModelName); |
|
|
|
|
|
|
|
Profiler.BeginSample($"GenerateTensors"); |
|
|
|
// Prepare the input tensors to be feed into the engine
|
|
|
|
|
|
|
m_TensorApplier.ApplyTensors(m_InferenceOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived); |
|
|
|
Profiler.EndSample(); |
|
|
|
|
|
|
|
Profiler.EndSample(); // end name
|
|
|
|
Profiler.EndSample(); // end ModelRunner.DecideAction
|
|
|
|
|
|
|
|
m_Infos.Clear(); |
|
|
|