浏览代码

Fix tests for Barracuda (#2333)

* Removed obsolete 'TestDstWrongShape' test as it does not reflect how Barracuda tensors work

* Added proper test cleanup, to avoid warning messages from finalizer thread.
/develop-generalizationTraining-TrainerController
Ervin T 5 年前
当前提交
fb9dc411
共有 3 个文件被更改,包括 27 次插入33 次删除
  1. 9
      UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs
  2. 32
      UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs
  3. 19
      UnitySDK/Assets/ML-Agents/Editor/Tests/MultinomialTest.cs

9
UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorApplier.cs


public void Contruction()
{
var bp = new BrainParameters();
var tensorGenerator = new TensorApplier(bp, 0, new TensorCachingAllocator());
var alloc = new TensorCachingAllocator();
var tensorGenerator = new TensorApplier(bp, 0, alloc);
alloc.Dispose();
}
[Test]

4f, 5f, 6f, 7f, 8f})
};
var agentInfos = GetFakeAgentInfos();
var applier = new DiscreteActionOutputApplier(new int[]{2, 3}, 0, new TensorCachingAllocator());
var alloc = new TensorCachingAllocator();
var applier = new DiscreteActionOutputApplier(new int[]{2, 3}, 0, alloc);
applier.Apply(inputTensor, agentInfos);
var agents = agentInfos.Keys.ToList();
var agent = agents[0] as TestAgent;

action = agent.GetAction();
Assert.AreEqual(action.vectorActions[0], 1);
Assert.AreEqual(action.vectorActions[1], 2);
alloc.Dispose();
}
[Test]

32
UnitySDK/Assets/ML-Agents/Editor/Tests/EditModeTestInternalBrainTensorGenerator.cs


public void Contruction()
{
var bp = new BrainParameters();
var tensorGenerator = new TensorGenerator(bp, 0, new TensorCachingAllocator());
var alloc = new TensorCachingAllocator();
var tensorGenerator = new TensorGenerator(bp, 0, alloc);
alloc.Dispose();
}
[Test]

var alloc = new TensorCachingAllocator();
var generator = new BatchSizeGenerator(new TensorCachingAllocator());
var generator = new BatchSizeGenerator(alloc);
alloc.Dispose();
}
[Test]

var alloc = new TensorCachingAllocator();
var generator = new SequenceLengthGenerator(new TensorCachingAllocator());
var generator = new SequenceLengthGenerator(alloc);
alloc.Dispose();
}
[Test]

};
var batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var generator = new VectorObservationGenerator(new TensorCachingAllocator());
var alloc = new TensorCachingAllocator();
var generator = new VectorObservationGenerator(alloc);
generator.Generate(inputTensor, batchSize, agentInfos);
Assert.IsNotNull(inputTensor.Data);
Assert.AreEqual(inputTensor.Data[0, 0], 1);

alloc.Dispose();
}
[Test]

};
var batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var generator = new RecurrentInputGenerator(new TensorCachingAllocator());
var alloc = new TensorCachingAllocator();
var generator = new RecurrentInputGenerator(alloc);
generator.Generate(inputTensor, batchSize, agentInfos);
Assert.IsNotNull(inputTensor.Data);
Assert.AreEqual(inputTensor.Data[0, 0], 0);

alloc.Dispose();
}
[Test]

};
var batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var generator = new PreviousActionInputGenerator(new TensorCachingAllocator());
var alloc = new TensorCachingAllocator();
var generator = new PreviousActionInputGenerator(alloc);
generator.Generate(inputTensor, batchSize, agentInfos);
Assert.IsNotNull(inputTensor.Data);

Assert.AreEqual(inputTensor.Data[1, 1], 4);
alloc.Dispose();
}
[Test]

};
var batchSize = 4;
var agentInfos = GetFakeAgentInfos();
var generator = new ActionMaskInputGenerator(new TensorCachingAllocator());
var alloc = new TensorCachingAllocator();
var generator = new ActionMaskInputGenerator(alloc);
generator.Generate(inputTensor, batchSize, agentInfos);
Assert.IsNotNull(inputTensor.Data);
Assert.AreEqual(inputTensor.Data[0, 0], 1);

alloc.Dispose();
}
}
}

19
UnitySDK/Assets/ML-Agents/Editor/Tests/MultinomialTest.cs


}
[Test]
public void TestDstWrongShape()
{
Multinomial m = new Multinomial(2018);
TensorProxy src = new TensorProxy
{
ValueType = TensorProxy.TensorType.FloatingPoint,
Data = new Tensor(0,1)
};
TensorProxy dst = new TensorProxy
{
ValueType = TensorProxy.TensorType.FloatingPoint,
Data = new Tensor(0,2)
};
Assert.Throws<ArgumentException>(() => m.Eval(src, dst));
}
[Test]
public void TestUnequalBatchSize()
{
Multinomial m = new Multinomial(2018);

正在加载...
取消
保存