Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

75 行
2.0 KiB

// Buffer for C# training
using System;
using System.Linq;
using Unity.Barracuda;
using System.Collections.Generic;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
namespace Unity.MLAgents
{
/// <summary>
/// One stored sample of the replay buffer: a state, the action and reward recorded
/// with it, and the state that followed. Instances are created by
/// <c>ReplayBuffer.Push</c> and handed back by <c>ReplayBuffer.SampleBatch</c>.
/// </summary>
internal struct Transition
{
/// <summary>Tensors passed to <c>ReplayBuffer.Push</c> as its <c>state</c> argument.</summary>
public IReadOnlyList<TensorProxy> state;
/// <summary>Action buffers; <c>ReplayBuffer.Push</c> fills this from <c>AgentInfo.storedActions</c>.</summary>
public ActionBuffers action;
/// <summary>Reward value; <c>ReplayBuffer.Push</c> fills this from <c>AgentInfo.reward</c>.</summary>
public float reward;
/// <summary>Tensors passed to <c>ReplayBuffer.Push</c> as its <c>nextState</c> argument.</summary>
public IReadOnlyList<TensorProxy> nextState;
}
/// <summary>
/// Fixed-capacity store of <see cref="Transition"/> entries. Entries are appended until
/// the capacity is reached; from then on each new entry overwrites the slot at the write
/// cursor, which wraps around, so the oldest entries are replaced first.
/// </summary>
internal class ReplayBuffer
{
    List<Transition> m_Buffer;

    // Slot the next Push writes to once the buffer is full; wraps at m_MaxSize.
    int currentIndex;

    int m_MaxSize;

    // A single generator for the lifetime of the buffer. Creating a new Random for every
    // sampling call can yield identical sequences on runtimes whose default seed is
    // derived from the system clock.
    readonly Random m_Random = new Random();

    /// <summary>
    /// Creates an empty buffer that holds at most <paramref name="maxSize"/> transitions.
    /// </summary>
    /// <param name="maxSize">Maximum number of stored transitions; must be positive.</param>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="maxSize"/> is zero or negative.
    /// </exception>
    public ReplayBuffer(int maxSize)
    {
        if (maxSize <= 0)
        {
            // A zero capacity would otherwise only fail later, inside Push.
            throw new ArgumentOutOfRangeException(nameof(maxSize), maxSize, "The buffer capacity must be positive.");
        }

        m_Buffer = new List<Transition>(maxSize);
        m_MaxSize = maxSize;
    }

    /// <summary>Number of transitions currently stored.</summary>
    public int Count
    {
        get => m_Buffer.Count;
    }

    /// <summary>
    /// Stores one transition built from the given states plus the stored actions and
    /// reward of <paramref name="info"/>. When the buffer is full, the entry at the write
    /// cursor is overwritten.
    /// </summary>
    /// <param name="info">Source of the action (<c>storedActions</c>) and <c>reward</c>.</param>
    /// <param name="state">State tensors recorded as <see cref="Transition.state"/>.</param>
    /// <param name="nextState">State tensors recorded as <see cref="Transition.nextState"/>.</param>
    public void Push(AgentInfo info, IReadOnlyList<TensorProxy> state, IReadOnlyList<TensorProxy> nextState)
    {
        var transition = new Transition
        {
            state = state,
            action = info.storedActions,
            reward = info.reward,
            nextState = nextState
        };

        if (m_Buffer.Count < m_MaxSize)
        {
            // List<T>.Add mutates the list. LINQ's Append (used previously) only returns
            // a new lazy sequence and left the buffer permanently empty.
            m_Buffer.Add(transition);
        }
        else
        {
            m_Buffer[currentIndex] = transition;
        }

        currentIndex = (currentIndex + 1) % m_MaxSize;
    }

    /// <summary>
    /// Returns <paramref name="batchSize"/> stored transitions picked at random, with no
    /// stored entry appearing more than once in the batch.
    /// </summary>
    /// <param name="batchSize">Number of transitions to return; between 0 and <see cref="Count"/>.</param>
    /// <returns>An array of length <paramref name="batchSize"/>.</returns>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="batchSize"/> is negative or larger than <see cref="Count"/>.
    /// </exception>
    public Transition[] SampleBatch(int batchSize)
    {
        if (batchSize < 0 || batchSize > m_Buffer.Count)
        {
            // Asking for more distinct entries than are stored can never be satisfied;
            // without this check the index sampling loop would never terminate.
            throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "The batch size must be between 0 and the number of stored transitions.");
        }

        var indexList = SampleIndex(batchSize);
        var samples = new Transition[batchSize];
        for (var i = 0; i < batchSize; i++)
        {
            samples[i] = m_Buffer[indexList[i]];
        }
        return samples;
    }

    /// <summary>
    /// Draws <paramref name="batchSize"/> distinct indices into the buffer by repeatedly
    /// drawing random indices and discarding repeats. The caller must ensure
    /// <paramref name="batchSize"/> does not exceed the number of stored entries.
    /// </summary>
    private List<int> SampleIndex(int batchSize)
    {
        var picked = new HashSet<int>();
        while (picked.Count < batchSize)
        {
            // HashSet.Add ignores duplicates, so the loop runs until enough distinct
            // indices have been collected.
            picked.Add(m_Random.Next(m_Buffer.Count));
        }
        return new List<int>(picked);
    }
}
}