
push transitions into buffer

/ai-hw-2021
Ruo-Ping Dong · 4 years ago
Current commit ed33d74a
5 changed files with 88 additions and 38 deletions
  1. com.unity.ml-agents/Runtime/Academy.cs (4 changes)
  2. com.unity.ml-agents/Runtime/Policies/TrainingModelRunner.cs (40 changes)
  3. com.unity.ml-agents/Runtime/Policies/TrainingPolicy.cs (42 changes)
  4. com.unity.ml-agents/Runtime/ReplayBuffer.cs (23 changes)
  5. com.unity.ml-agents/Runtime/Trainer.cs (17 changes)

com.unity.ml-agents/Runtime/Academy.cs (4 changes)


}
return modelRunner;
}
- internal TrainingModelRunner GetOrCreateTrainingModelRunner(string behaviorName, ActionSpec actionSpec)
+ internal Trainer GetOrCreateTrainer(string behaviorName, ActionSpec actionSpec, NNModel model)
{
var trainer = m_Trainers.Find(x => x.BehaviorName == behaviorName);
if (trainer == null)

}
- return trainer.TrainerModelRunner;
+ return trainer;
}
/// <summary>
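At the Academy level, the per-behavior lookup now returns a Trainer (which owns both the model runner and the replay buffer) instead of a bare TrainingModelRunner. The following is a minimal sketch of just the get-or-create caching pattern visible in this fragment; the types are simplified stand-ins, not the actual ml-agents classes.

using System.Collections.Generic;

// Simplified stand-in for illustration only; the real Trainer also carries an
// ActionSpec, an NNModel, and a TrainingModelRunner.
class Trainer
{
    public string BehaviorName { get; }
    public Trainer(string behaviorName) { BehaviorName = behaviorName; }
}

class TrainerRegistry
{
    readonly List<Trainer> m_Trainers = new List<Trainer>();

    // Same get-or-create pattern as GetOrCreateTrainer above: reuse the trainer
    // registered for this behavior name, otherwise create and cache a new one.
    public Trainer GetOrCreateTrainer(string behaviorName)
    {
        var trainer = m_Trainers.Find(x => x.BehaviorName == behaviorName);
        if (trainer == null)
        {
            trainer = new Trainer(behaviorName);
            m_Trainers.Add(trainer);
        }
        return trainer;
    }
}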

com.unity.ml-agents/Runtime/Policies/TrainingModelRunner.cs (40 changes)


// ModelRunner for C# training.
using System;
using System.Collections.Generic;
using Unity.Barracuda;
using UnityEngine.Profiling;

TensorGenerator m_TensorGenerator;
TensorApplier m_TensorApplier;
- NNModel m_Model;
+ Model m_Model;
NNModel m_TargetModel;
string m_ModelName;
InferenceDevice m_InferenceDevice;

m_InferenceOutputs = new List<TensorProxy>();
}
public InferenceDevice InferenceDevice
{
get { return m_InferenceDevice; }
}
public NNModel Model
{
get { return m_Model; }
}
void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs)
{
m_InputsByName.Clear();

public void PutObservations(AgentInfo info, List<ISensor> sensors)
{
#if DEBUG
m_SensorShapeValidator.ValidateSensors(sensors);
#endif
m_Infos.Add(new AgentInfoSensorsPair
{
agentInfo = info,

}
}
public void GetObservationTensors(IReadOnlyList<TensorProxy> tensors, AgentInfo info, List<ISensor> sensors)
{
if (!m_ObservationsInitialized)
{
// Just grab the first agent in the collection (any will suffice, really).
// We check for an empty Collection above, so this will always return successfully.
m_TensorGenerator.InitializeObservations(sensors, m_TensorAllocator);
m_ObservationsInitialized = true;
}
var infoSensorPair = new AgentInfoSensorsPair
{
agentInfo = info,
sensors = sensors
};
m_TensorGenerator.GenerateTensors(tensors, 1, new List<AgentInfoSensorsPair> { infoSensorPair });
}
public IReadOnlyList<TensorProxy> GetInputTensors()
{
return m_Model.GetInputTensors();
}
public void DecideBatch()
{
var currentBatchSize = m_Infos.Count;

}
Profiler.BeginSample("ModelRunner.DecideAction");
Profiler.BeginSample(m_ModelName);
Profiler.BeginSample($"GenerateTensors");
// Prepare the input tensors to be fed into the engine

m_TensorApplier.ApplyTensors(m_InferenceOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived);
Profiler.EndSample();
Profiler.EndSample(); // end name
Profiler.EndSample(); // end ModelRunner.DecideAction
m_Infos.Clear();
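Two behaviors of the model runner matter for the rest of the diff: decisions are batched (PutObservations queues agent infos, DecideBatch runs the network once over the whole queue and then clears m_Infos), and the new GetObservationTensors generates tensors for a single agent so the training policy can hold on to them as replay-buffer states. Below is a minimal sketch of the queue-then-decide pattern only, with hypothetical stand-in types in place of the Barracuda tensors and the TensorGenerator/TensorApplier used by the real class.

using System.Collections.Generic;

// Hypothetical stand-in for an agent's queued observation.
struct QueuedObservation
{
    public int AgentId;
    public float[] Features;
}

class MiniModelRunner
{
    readonly List<QueuedObservation> m_Infos = new List<QueuedObservation>();
    readonly Dictionary<int, float[]> m_LastActionsReceived = new Dictionary<int, float[]>();

    // Mirrors PutObservations: queue one agent's observation for the next batch.
    public void PutObservations(QueuedObservation obs)
    {
        m_Infos.Add(obs);
    }

    // Mirrors DecideBatch: process everything queued in one pass, then clear.
    public void DecideBatch()
    {
        if (m_Infos.Count == 0)
        {
            return;
        }
        foreach (var obs in m_Infos)
        {
            // Placeholder for the Barracuda forward pass over the whole batch.
            m_LastActionsReceived[obs.AgentId] = new float[] { obs.Features.Length };
        }
        m_Infos.Clear();
    }

    // Mirrors GetAction: return the last action computed for this agent.
    public float[] GetAction(int agentId)
    {
        return m_LastActionsReceived.TryGetValue(agentId, out var action)
            ? action
            : new float[0];
    }
}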

com.unity.ml-agents/Runtime/Policies/TrainingPolicy.cs (42 changes)


ActionSpec m_ActionSpec;
- private string m_BehaviorName;
+ string m_BehaviorName;
AgentInfo m_LastInfo;
IReadOnlyList<TensorProxy> m_LastObservations;
ReplayBuffer m_buffer;
IReadOnlyList<TensorProxy> m_CurrentObservations;
/// <inheritdoc />
public TrainingPolicy(

{
- m_ModelRunner = Academy.Instance.GetOrCreateTrainingModelRunner(behaviorName, actionSpec);
+ var trainer = Academy.Instance.GetOrCreateTrainer(behaviorName, actionSpec, model);
+ m_ModelRunner = trainer.TrainerModelRunner;
+ m_buffer = trainer.Buffer;
m_CurrentObservations = m_ModelRunner.GetInputTensors();
m_BehaviorName = behaviorName;
m_ActionSpec = actionSpec;
}

{
m_AgentId = info.episodeId;
m_ModelRunner?.PutObservations(info, sensors);
}
m_ModelRunner.PutObservations(info, sensors);
m_ModelRunner.GetObservationTensors(m_CurrentObservations, info, sensors);
/// <inheritdoc />
public ref readonly ActionBuffers DecideAction()
{
if (m_ModelRunner == null)
if (m_LastObservations != null)
{
m_buffer.Push(m_LastInfo, m_LastObservations, m_CurrentObservations);
}
if (info.done == true)
m_LastActionBuffer = ActionBuffers.Empty;
m_LastObservations = null;
m_ModelRunner?.DecideBatch();
m_LastActionBuffer = m_ModelRunner.GetAction(m_AgentId);
m_LastInfo = info;
m_LastObservations = m_CurrentObservations;
}
/// <inheritdoc />
public ref readonly ActionBuffers DecideAction()
{
m_ModelRunner.DecideBatch();
m_LastActionBuffer = m_ModelRunner.GetAction(m_AgentId);
return ref m_LastActionBuffer;
}
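The policy is where the commit title happens: the constructor now pulls both the model runner and the shared ReplayBuffer from the Trainer, and each new decision request turns the previous step's info and observations plus the current observations into one transition via m_buffer.Push, resetting m_LastObservations at episode boundaries. A minimal sketch of that bookkeeping, using hypothetical simplified types (the real code stores AgentInfo and TensorProxy lists rather than float arrays):

using System.Collections.Generic;

// Hypothetical, simplified transition; see ReplayBuffer.cs for the real fields.
struct MiniTransition
{
    public float[] State;
    public float[] Action;
    public float Reward;
    public float[] NextState;
}

class MiniTrainingPolicy
{
    readonly List<MiniTransition> m_Buffer;   // stands in for the shared ReplayBuffer
    float[] m_LastObservation;                // mirrors m_LastObservations
    float[] m_LastAction;                     // mirrors the stored actions in m_LastInfo
    float m_LastReward;                       // mirrors the reward in m_LastInfo

    public MiniTrainingPolicy(List<MiniTransition> sharedBuffer)
    {
        m_Buffer = sharedBuffer;
    }

    // Mirrors RequestDecision: when a new observation arrives, the previous
    // observation becomes the transition's state and the new one its nextState.
    public void RequestDecision(float[] observation, float[] action, float reward, bool done)
    {
        if (m_LastObservation != null)
        {
            m_Buffer.Add(new MiniTransition
            {
                State = m_LastObservation,
                Action = m_LastAction,
                Reward = m_LastReward,
                NextState = observation
            });
        }

        if (done)
        {
            // Episode boundary: do not link the next episode's first step back here.
            m_LastObservation = null;
        }
        else
        {
            m_LastObservation = observation;
            m_LastAction = action;
            m_LastReward = reward;
        }
    }
}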

com.unity.ml-agents/Runtime/ReplayBuffer.cs (23 changes)


using Unity.Barracuda;
using System.Collections.Generic;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
- public List<Tensor> state;
+ public IReadOnlyList<TensorProxy> state;
- public List<Tensor> nextState;
+ public IReadOnlyList<TensorProxy> nextState;
}
internal class ReplayBuffer

m_MaxSize = maxSize;
}
- public void Push(AgentInfo info, List<Tensor> state, List<Tensor> nextState)
public int Count
{
get => m_Buffer.Count;
}
+ public void Push(AgentInfo info, IReadOnlyList<TensorProxy> state, IReadOnlyList<TensorProxy> nextState)
- m_Buffer[currentIndex] = new Transition() {state=state, action=info.storedActions, reward=info.reward, nextState=nextState};
+ if (m_Buffer.Count < m_MaxSize)
+ {
+ m_Buffer.Append(new Transition() {state=state, action=info.storedActions, reward=info.reward, nextState=nextState});
+ }
+ else
+ {
+ m_Buffer[currentIndex] = new Transition() {state=state, action=info.storedActions, reward=info.reward, nextState=nextState};
+ }
- public Transition[] Sample(int batchSize)
+ public Transition[] SampleBatch(int batchSize)
{
var indexList = SampleIndex(batchSize);
var samples = new Transition[batchSize];
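The buffer itself is a fixed-capacity store of Transition records: Push appends until m_MaxSize is reached and then overwrites an existing slot, and SampleBatch returns batchSize transitions chosen by SampleIndex. A self-contained sketch of that shape, assuming a ring-style overwrite index and uniform random sampling (neither currentIndex nor SampleIndex is shown in the fragment, so both are assumptions):

using System;
using System.Collections.Generic;

// Hypothetical, simplified transition; the real one holds TensorProxy lists
// and the agent's stored actions.
struct SimpleTransition
{
    public float[] State;
    public float[] Action;
    public float Reward;
    public float[] NextState;
}

class SimpleReplayBuffer
{
    readonly List<SimpleTransition> m_Buffer = new List<SimpleTransition>();
    readonly int m_MaxSize;
    readonly Random m_Random = new Random();
    int m_NextIndex;   // slot to overwrite once the buffer is full (assumed)

    public SimpleReplayBuffer(int maxSize)
    {
        m_MaxSize = maxSize;
    }

    public int Count
    {
        get => m_Buffer.Count;
    }

    // Append while below capacity, then overwrite the oldest slot (ring buffer).
    public void Push(SimpleTransition t)
    {
        if (m_Buffer.Count < m_MaxSize)
        {
            m_Buffer.Add(t);
        }
        else
        {
            m_Buffer[m_NextIndex] = t;
        }
        m_NextIndex = (m_NextIndex + 1) % m_MaxSize;
    }

    // Uniform random minibatch (with replacement); the diff's SampleIndex helper
    // is not shown, so uniform sampling is an assumption.
    public SimpleTransition[] SampleBatch(int batchSize)
    {
        var samples = new SimpleTransition[batchSize];
        for (var i = 0; i < batchSize; i++)
        {
            samples[i] = m_Buffer[m_Random.Next(m_Buffer.Count)];
        }
        return samples;
    }
}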

com.unity.ml-agents/Runtime/Trainer.cs (17 changes)


using System;
using Unity.MLAgents.Actuators;
using Unity.Barracuda;
namespace Unity.MLAgents

ReplayBuffer m_Buffer;
TrainingModelRunner m_ModelRunner;
string m_behaviorName;
- int m_BufferSize;
- int batchSize;
+ int m_BufferSize = 1024;
+ int batchSize = 64;
float GAMMA;
public Trainer(string behaviorName, ActionSpec actionSpec, int seed=0)

get => m_behaviorName;
}
public ReplayBuffer Buffer
{
get => m_Buffer;
}
public TrainingModelRunner TrainerModelRunner
{
get => m_ModelRunner;

public void Update()
{
- var samples = m_Buffer.Sample(batchSize);
+ if (m_Buffer.Count < batchSize * 2)
+ {
+ return;
+ }
+ var samples = m_Buffer.SampleBatch(batchSize);
// states = [s.state for s in samples]
// actions = [s.action for s in samples]
// q_values = policy_net(states).gather(1, action_batch)
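Update only trains once the buffer holds at least two full batches, then draws a minibatch with SampleBatch; the trailing pseudocode comments, together with the GAMMA field, point at a DQN-style update in which the network's Q-values for the taken actions are regressed toward a bootstrapped target. A sketch of that target computation under the assumption of a DQN-style loss, with the network evaluations left as inputs (the real Update would run them through Barracuda):

using System;

static class DqnTargetSketch
{
    // Stands in for the GAMMA discount field declared in the Trainer fragment;
    // the 0.99 value is chosen here only for illustration.
    const float Gamma = 0.99f;

    // One regression target per sampled transition:
    //   y = r + gamma * max_a' Q_target(s', a')
    // Terminal-state handling (y = r when the episode ended) is omitted because
    // the Transition shown in the diff carries no done flag.
    public static float[] ComputeTargets(float[] rewards, float[][] nextStateQValues)
    {
        var targets = new float[rewards.Length];
        for (var i = 0; i < rewards.Length; i++)
        {
            targets[i] = rewards[i] + Gamma * Max(nextStateQValues[i]);
        }
        return targets;
    }

    static float Max(float[] values)
    {
        var best = float.NegativeInfinity;
        foreach (var v in values)
        {
            best = Math.Max(best, v);
        }
        return best;
    }
}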
