// Buffer for C# training

using System;
using System.Linq;
using Unity.Barracuda;
using System.Collections.Generic;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;

namespace Unity.MLAgents
{
    /// <summary>
    /// One stored experience: the state an action was taken in, the action,
    /// the resulting reward/done flag, and the state that followed.
    /// </summary>
    internal struct Transition
    {
        // NOTE(review): the element type of these lists was lost in the text this file
        // was recovered from; TensorProxy (Unity.MLAgents.Inference) is assumed - confirm.
        public IReadOnlyList<TensorProxy> state;
        public ActionBuffers action;
        public float reward;
        public bool done;
        public IReadOnlyList<TensorProxy> nextState;
    }

    /// <summary>
    /// Fixed-capacity ring buffer of <see cref="Transition"/>s. Once
    /// <c>maxSize</c> entries have been pushed, each new entry overwrites
    /// the oldest one.
    /// </summary>
    internal class ReplayBuffer
    {
        List<Transition> m_Buffer;

        // Slot the next Push will write to; wraps back to 0 at m_MaxSize.
        int m_CurrentIndex;
        int m_MaxSize;

        // One generator per buffer. Creating a new Random on every sampling call can
        // return identical sequences on runtimes that seed from the system clock.
        readonly Random m_Random = new Random();

        /// <summary>
        /// Creates an empty buffer that holds at most <paramref name="maxSize"/> transitions.
        /// </summary>
        /// <param name="maxSize">Maximum number of stored transitions; must be positive.</param>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="maxSize"/> is zero or negative.
        /// </exception>
        public ReplayBuffer(int maxSize)
        {
            if (maxSize <= 0)
            {
                // A zero capacity would otherwise only fail later, inside Push.
                throw new ArgumentOutOfRangeException(nameof(maxSize), maxSize, "Buffer size must be positive.");
            }

            m_Buffer = new List<Transition>(maxSize);
            m_MaxSize = maxSize;
        }

        /// <summary>Number of transitions currently stored.</summary>
        public int Count
        {
            get => m_Buffer.Count;
        }

        /// <summary>
        /// Stores one transition built from <paramref name="info"/> (stored actions,
        /// reward, done) and the given states, overwriting the oldest entry once
        /// the buffer is full.
        /// </summary>
        /// <param name="info">Source of the action, reward and done flag.</param>
        /// <param name="state">State the action was taken in.</param>
        /// <param name="nextState">State that followed the action.</param>
        public void Push(AgentInfo info, IReadOnlyList<TensorProxy> state, IReadOnlyList<TensorProxy> nextState)
        {
            var transition = new Transition()
            {
                state = state,
                action = info.storedActions,
                reward = info.reward,
                done = info.done,
                nextState = nextState
            };

            if (m_Buffer.Count < m_MaxSize)
            {
                // Still filling: m_CurrentIndex equals Count here, so appending
                // writes the same slot the cursor points at.
                m_Buffer.Add(transition);
            }
            else
            {
                m_Buffer[m_CurrentIndex] = transition;
            }

            m_CurrentIndex = (m_CurrentIndex + 1) % m_MaxSize;
        }

        /// <summary>
        /// Returns <paramref name="batchSize"/> transitions drawn at random from
        /// distinct buffer slots.
        /// </summary>
        /// <param name="batchSize">Number of transitions to draw.</param>
        /// <returns>A new list holding the sampled transitions.</returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="batchSize"/> is negative or larger than <see cref="Count"/>.
        /// </exception>
        public List<Transition> SampleBatch(int batchSize)
        {
            if (batchSize < 0 || batchSize > m_Buffer.Count)
            {
                // Asking for more distinct slots than exist could never be satisfied;
                // the index sampler would spin forever.
                throw new ArgumentOutOfRangeException(
                    nameof(batchSize), batchSize,
                    "Batch size must be between 0 and the number of stored transitions.");
            }

            var indexList = SampleIndex(batchSize);
            var samples = new List<Transition>(batchSize);
            for (var i = 0; i < batchSize; i++)
            {
                samples.Add(m_Buffer[indexList[i]]);
            }
            return samples;
        }

        /// <summary>
        /// Returns a batch made of <paramref name="batchSize"/> copies of the most
        /// recently pushed transition. Unlike <see cref="SampleBatch"/>, the batch
        /// may be larger than <see cref="Count"/>.
        /// </summary>
        /// <param name="batchSize">Number of copies to return.</param>
        /// <returns>A new list holding the repeated transition.</returns>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="batchSize"/> is negative.</exception>
        /// <exception cref="InvalidOperationException">
        /// The buffer is empty and <paramref name="batchSize"/> is positive.
        /// </exception>
        public List<Transition> SampleDummyBatch(int batchSize)
        {
            if (batchSize < 0)
            {
                throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "Batch size must not be negative.");
            }

            var samples = new List<Transition>(batchSize);
            if (batchSize == 0)
            {
                return samples;
            }

            if (m_Buffer.Count == 0)
            {
                throw new InvalidOperationException("Cannot build a dummy batch from an empty buffer.");
            }

            // The write cursor is 0 right after it wraps, so step back modulo the
            // capacity instead of computing m_CurrentIndex - 1 (which would be -1).
            var lastWritten = (m_CurrentIndex + m_MaxSize - 1) % m_MaxSize;
            var latest = m_Buffer[lastWritten];
            for (var i = 0; i < batchSize; i++)
            {
                samples.Add(latest);
            }
            return samples;
        }

        /// <summary>
        /// Draws <paramref name="batchSize"/> distinct random indices into the buffer.
        /// Callers must ensure batchSize does not exceed the number of stored transitions.
        /// </summary>
        private List<int> SampleIndex(int batchSize)
        {
            var index = new HashSet<int>();
            while (index.Count < batchSize)
            {
                // Duplicates are rejected by the set, so the loop runs until
                // batchSize distinct indices have been collected.
                index.Add(m_Random.Next(m_Buffer.Count));
            }
            return index.ToList();
        }
    }
}