
Fixing inference for hallway

/develop-newnormalization
GitHub, 5 years ago
Current commit
00107b93
4 files changed, 83 insertions and 0 deletions
  1. UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs (+33)
  2. UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs (+47)
  3. UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs (+1)
  4. UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs (+2)

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/ApplierImpl.cs (+33)


        }
    }

    /// <summary>
    /// The Applier for the Memory output tensor. Tensor is assumed to contain the new
    /// memory data of the agents in the batch.
    /// </summary>
    public class MemoryOutputApplier : TensorApplier.IApplier
    {
        Dictionary<int, List<float>> m_Memories;

        public MemoryOutputApplier(
            Dictionary<int, List<float>> memories)
        {
            m_Memories = memories;
        }

        public void Apply(TensorProxy tensorProxy, IEnumerable<Agent> agents)
        {
            var agentIndex = 0;
            var memorySize = (int)tensorProxy.shape[tensorProxy.shape.Length - 1];
            foreach (var agent in agents)
            {
                List<float> memory = null;
                if (!m_Memories.TryGetValue(agent.Info.id, out memory)
                    || memory.Count < memorySize)
                {
                    memory = new List<float>();
                    memory.AddRange(Enumerable.Repeat(0f, memorySize));
                }
                m_Memories[agent.Info.id] = memory;
                agentIndex++;
            }
        }
    }

    public class BarracudaMemoryOutputApplier : TensorApplier.IApplier
    {
        readonly int m_MemoriesCount;
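The applier above only guarantees that every agent in the batch ends up with a memory buffer of the expected size, keyed by the agent's id in the shared dictionary. The zero-fill guard can be seen in isolation in the following standalone sketch; the agent ids and memory size are made up for illustration, and the ML-Agents types (TensorProxy, Agent, IApplier) are deliberately left out:

    using System;
    using System.Collections.Generic;
    using System.Linq;

    class MemoryInitDemo
    {
        static void Main()
        {
            // Hypothetical agent ids and memory size, for illustration only.
            var memories = new Dictionary<int, List<float>>();
            var agentIds = new[] { 1, 2, 3 };
            const int memorySize = 4;

            foreach (var id in agentIds)
            {
                List<float> memory;
                // Same guard as in MemoryOutputApplier.Apply: a missing or
                // undersized entry is replaced with a zero-filled buffer.
                if (!memories.TryGetValue(id, out memory) || memory.Count < memorySize)
                {
                    memory = Enumerable.Repeat(0f, memorySize).ToList();
                }
                memories[id] = memory;
            }

            Console.WriteLine(string.Join(", ", memories[1])); // 0, 0, 0, 0
        }
    }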

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/GeneratorImpl.cs (+47)


        }
    }

    /// <summary>
    /// Generates the Tensor corresponding to the Recurrent input: a two-dimensional
    /// float array of dimension [batchSize x memorySize].
    /// It will use the Memory data contained in the agentInfo to fill the data
    /// of the tensor.
    /// </summary>
    public class RecurrentInputGenerator : TensorGenerator.IGenerator
    {
        private readonly ITensorAllocator m_Allocator;
        Dictionary<int, List<float>> m_Memories;

        public RecurrentInputGenerator(
            ITensorAllocator allocator,
            Dictionary<int, List<float>> memories)
        {
            m_Allocator = allocator;
            m_Memories = memories;
        }

        public void Generate(
            TensorProxy tensorProxy, int batchSize, IEnumerable<Agent> agents)
        {
            TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);

            var memorySize = tensorProxy.shape[tensorProxy.shape.Length - 1];
            var agentIndex = 0;
            foreach (var agent in agents)
            {
                var info = agent.Info;
                List<float> memory;
                if (!m_Memories.TryGetValue(agent.Info.id, out memory))
                {
                    agentIndex++;
                    continue;
                }
                for (var j = 0; j < Math.Min(memorySize, memory.Count); j++)
                {
                    if (j >= memory.Count)
                    {
                        break;
                    }
                    tensorProxy.data[agentIndex, j] = memory[j];
                }
                agentIndex++;
            }
        }
    }

    public class BarracudaRecurrentInputGenerator : TensorGenerator.IGenerator
    {
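Read together with the applier above, the generator copies whatever memory is stored for each agent into one row of the [batchSize x memorySize] tensor and leaves the rows of agents without stored memory zeroed. A standalone sketch of that copy, with a plain 2-D array standing in for tensorProxy.data and made-up agent ids:

    using System;
    using System.Collections.Generic;

    class RecurrentInputDemo
    {
        static void Main()
        {
            // Stand-ins for illustration: the 2-D array plays the role of
            // tensorProxy.data and the dictionary plays the role of m_Memories.
            const int batchSize = 2;
            const int memorySize = 3;
            var data = new float[batchSize, memorySize];
            var memories = new Dictionary<int, List<float>>
            {
                { 7, new List<float> { 0.1f, 0.2f, 0.3f } } // agent 42 has no entry
            };
            var agentIds = new[] { 7, 42 };

            var agentIndex = 0;
            foreach (var id in agentIds)
            {
                List<float> memory;
                // Agents without stored memory keep their zero-initialized row,
                // mirroring the `continue` in RecurrentInputGenerator.Generate.
                if (!memories.TryGetValue(id, out memory))
                {
                    agentIndex++;
                    continue;
                }
                for (var j = 0; j < Math.Min(memorySize, memory.Count); j++)
                {
                    data[agentIndex, j] = memory[j];
                }
                agentIndex++;
            }

            Console.WriteLine(data[0, 1]); // 0.2
            Console.WriteLine(data[1, 1]); // 0 (agent 42 had no stored memory)
        }
    }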

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorApplier.cs (+1)


                m_Dict[TensorNames.ActionOutput] =
                    new DiscreteActionOutputApplier(bp.vectorActionSize, seed, allocator);
            }
            m_Dict[TensorNames.RecurrentOutput] = new MemoryOutputApplier(memories);

            if (barracudaModel != null)
            {

UnitySDK/Assets/ML-Agents/Scripts/InferenceBrain/TensorGenerator.cs (+2)


                new BatchSizeGenerator(allocator);
            m_Dict[TensorNames.SequenceLengthPlaceholder] =
                new SequenceLengthGenerator(allocator);
            m_Dict[TensorNames.RecurrentInPlaceholder] =
                new RecurrentInputGenerator(allocator, memories);

            if (barracudaModel != null)
            {
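Both registration snippets follow the same pattern: m_Dict maps a tensor name to the object that fills or consumes that tensor at inference time, and the recurrent entries are simply two more registrations, now constructed with the shared memories dictionary. A minimal sketch of that lookup pattern, using simplified stand-in types rather than the ML-Agents interfaces:

    using System;
    using System.Collections.Generic;

    // Simplified stand-ins, not the ML-Agents API: the real keys are the
    // TensorNames constants and the values are IGenerator / IApplier instances.
    interface ITensorHandler { void Run(); }

    class RecurrentInHandler : ITensorHandler
    {
        public void Run() => Console.WriteLine("fill recurrent_in placeholder");
    }

    class RegistrationDemo
    {
        static void Main()
        {
            // Handlers are looked up by tensor name before each inference step.
            var dict = new Dictionary<string, ITensorHandler>
            {
                { "recurrent_in", new RecurrentInHandler() }
            };
            dict["recurrent_in"].Run();
        }
    }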
