Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目，它使游戏和模拟能够作为训练智能代理的环境。
您最多选择 25 个主题。主题必须以中文、字母或数字开头，可以包含连字符 (-)，并且长度不得超过 35 个字符。
 
 
 
 
 

84 行
2.4 KiB

// Policy for C# training
using Unity.Barracuda;
using System.Collections.Generic;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Sensors;
namespace Unity.MLAgents.Policies
{
/// <summary>
/// Policy used for in-editor (C#) training. It forwards observations to a
/// <see cref="TrainingModelRunner"/>, asks it for actions, and pushes
/// observation transitions into the trainer's <see cref="ReplayBuffer"/>.
/// </summary>
internal class TrainingPolicy : IPolicy
{
    /// <summary>Model runner obtained from the trainer looked up in the constructor.</summary>
    protected TrainingModelRunner m_ModelRunner;

    /// <summary>Replay buffer obtained from the same trainer as <see cref="m_ModelRunner"/>.</summary>
    ReplayBuffer m_Buffer;

    /// <summary>Tensor list obtained once from the model runner and refilled on every request.</summary>
    IReadOnlyList<TensorProxy> m_CurrentObservations;

    /// <summary>Observations remembered from the previous request; null before the first request and after a terminal step.</summary>
    IReadOnlyList<TensorProxy> m_LastObservations;

    /// <summary>Agent info remembered from the previous request.</summary>
    AgentInfo m_LastInfo;

    /// <summary>Backing storage for the reference returned by <see cref="DecideAction"/>.</summary>
    ActionBuffers m_LastActionBuffer;

    /// <summary>Episode id of the agent that made the most recent request.</summary>
    int m_AgentId;

    // Stored by the constructor; not read anywhere else in this class.
    ActionSpec m_ActionSpec;
    string m_BehaviorName;

    /// <summary>
    /// Builds a policy bound to the trainer that the Academy returns for the given behavior.
    /// </summary>
    /// <param name="actionSpec">Action specification; passed to the trainer lookup and stored.</param>
    /// <param name="behaviorName">Behavior name; passed to the trainer lookup and stored.</param>
    /// <param name="model">Model passed to the trainer lookup.</param>
    public TrainingPolicy(ActionSpec actionSpec, string behaviorName, NNModel model)
    {
        m_ActionSpec = actionSpec;
        m_BehaviorName = behaviorName;

        var trainer = Academy.Instance.GetOrCreateTrainer(behaviorName, actionSpec, model);
        m_ModelRunner = trainer.TrainerModelRunner;
        m_Buffer = trainer.Buffer;

        // The tensor list is fetched once here and reused for every later request.
        m_CurrentObservations = m_ModelRunner.GetInputTensors();
    }

    /// <inheritdoc />
    /// <remarks>
    /// Besides handing the observations to the model runner, this pushes up to two
    /// entries into the replay buffer: one linking the previous request to this one
    /// (or a seeding entry when the buffer is empty), and one extra entry when
    /// <paramref name="info"/> is flagged as done.
    /// </remarks>
    public void RequestDecision(AgentInfo info, List<ISensor> sensors)
    {
        m_AgentId = info.episodeId;
        m_ModelRunner.PutObservations(info, sensors);
        m_ModelRunner.GetObservationTensors(m_CurrentObservations, info, sensors);

        // NOTE(review): m_CurrentObservations is assigned only in the constructor, so once
        // m_LastObservations is set below both fields refer to the same list object.
        // Whether Push copies the tensor contents is not visible from here - confirm.
        var hasPreviousStep = m_LastObservations != null;
        if (hasPreviousStep)
        {
            m_Buffer.Push(m_LastInfo, m_LastObservations, m_CurrentObservations);
        }
        else if (m_Buffer.Count == 0)
        {
            // Hack: with no previous step and an empty buffer, seed the buffer with
            // the current observations on both sides of the transition.
            PushSelfTransition(info);
        }

        m_LastInfo = info;

        if (!info.done)
        {
            m_LastObservations = m_CurrentObservations;
            return;
        }

        // Terminal step: there is no real next state, so the current observations
        // are pushed as a dummy next state, and the remembered step is cleared.
        PushSelfTransition(info);
        m_LastObservations = null;
    }

    /// <inheritdoc />
    /// <remarks>
    /// Runs a batch decision on the model runner, then returns the action it holds
    /// for the episode id recorded by the last <see cref="RequestDecision"/> call.
    /// </remarks>
    public ref readonly ActionBuffers DecideAction()
    {
        m_ModelRunner.DecideBatch();

        // Copy into a field so that a reference to it can be returned.
        m_LastActionBuffer = m_ModelRunner.GetAction(m_AgentId);
        return ref m_LastActionBuffer;
    }

    /// <summary>Intentionally empty: this class releases nothing here.</summary>
    public void Dispose()
    {
    }

    /// <summary>
    /// Pushes an entry whose state and next state are both the current observations.
    /// </summary>
    void PushSelfTransition(AgentInfo info)
    {
        m_Buffer.Push(info, m_CurrentObservations, m_CurrentObservations);
    }
}
}