// Policy for C# training: gets actions from the trainer's model runner and
// records each step's observations as transitions in the trainer's replay buffer.

using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Sensors;

namespace Unity.MLAgents.Policies
{
    internal class TrainingPolicy : IPolicy
    {
        protected TrainingModelRunner m_ModelRunner;
        ActionBuffers m_LastActionBuffer;

        int m_AgentId;

        ActionSpec m_ActionSpec;

        string m_BehaviorName;

        AgentInfo m_LastInfo;

        IReadOnlyList<TensorProxy> m_LastObservations;

        ReplayBuffer m_buffer;

        IReadOnlyList<TensorProxy> m_CurrentObservations;

        /// <summary>
        /// Creates a policy that uses the trainer registered for <paramref name="behaviorName"/>,
        /// creating that trainer first if it does not exist yet.
        /// </summary>
        public TrainingPolicy(
            ActionSpec actionSpec,
            string behaviorName,
            NNModel model
            )
        {
            var trainer = Academy.Instance.GetOrCreateTrainer(behaviorName, actionSpec, model);
            m_ModelRunner = trainer.TrainerModelRunner;
            m_buffer = trainer.Buffer;
            m_CurrentObservations = m_ModelRunner.GetInputTensors();
            m_BehaviorName = behaviorName;
            m_ActionSpec = actionSpec;
        }

        /// <inheritdoc />
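        public void RequestDecision(AgentInfo info, List<ISensor> sensors)
        {
            // info.episodeId is the key used to look up this agent's action in DecideAction().
            m_AgentId = info.episodeId;
            m_ModelRunner.PutObservations(info, sensors);
            m_ModelRunner.GetObservationTensors(m_CurrentObservations, info, sensors);

            if (m_LastObservations != null)
            {
                // Record the transition from the previous observations to the current ones.
                m_buffer.Push(m_LastInfo, m_LastObservations, m_CurrentObservations);
            }
            else if (m_buffer.Count == 0)
            {
                // Workaround: seed an empty buffer with a placeholder transition that uses
                // the current observations as both the state and the next state.
                m_buffer.Push(info, m_CurrentObservations, m_CurrentObservations);
            }

            m_LastInfo = info;
            // m_CurrentObservations is the same list object on every call (obtained once in the
            // constructor), so this aliases it rather than copying it. This relies on
            // ReplayBuffer.Push copying the tensor data it is given.
            m_LastObservations = m_CurrentObservations;

            if (info.done)
            {
                // Terminal step: push a final transition with a placeholder next state,
                // then clear the history so the next episode does not link to this one.
                m_buffer.Push(info, m_CurrentObservations, m_CurrentObservations);
                m_LastObservations = null;
            }
        }

        /// <inheritdoc />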
        public ref readonly ActionBuffers DecideAction()
        {
            m_ModelRunner.DecideBatch();
            m_LastActionBuffer = m_ModelRunner.GetAction(m_AgentId);
            return ref m_LastActionBuffer;
        }

        public void Dispose()
        {
        }
    }
}