浏览代码

[bug-fix] Fix ONNX export/Barracuda import for continuous actions (#4608)

* Use torch.ones rather than expand

* use tf.cat instead of tf.expand for Barracuda

* fix error msg handling

* handle loading exception

Co-authored-by: Chris Elion <chris.elion@unity3d.com>
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
a4ba90ec
共有 2 个文件被更改,包括 21 次插入7 次删除
  1. 22
      Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs
  2. 6
      ml-agents/mlagents/trainers/torch/distributions.py

22
Project/Assets/ML-Agents/Examples/SharedAssets/Scripts/ModelOverrider.cs


var bp = m_Agent.GetComponent<BehaviorParameters>();
var behaviorName = bp.BehaviorName;
var nnModel = GetModelForBehaviorName(behaviorName);
NNModel nnModel = null;
try
{
nnModel = GetModelForBehaviorName(behaviorName);
}
catch (Exception e)
{
overrideError = $"Exception calling GetModelForBehaviorName: {e}";
}
overrideError =
$"Didn't find a model for behaviorName {behaviorName}. Make " +
$"sure the behaviorName is set correctly in the commandline " +
$"and that the model file exists";
if (string.IsNullOrEmpty(overrideError))
{
overrideError =
$"Didn't find a model for behaviorName {behaviorName}. Make " +
"sure the behaviorName is set correctly in the commandline " +
"and that the model file exists";
}
}
else
{

6
ml-agents/mlagents/trainers/torch/distributions.py


if self.conditional_sigma:
log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
else:
# Expand so that entropy matches batch size
log_sigma = self.log_sigma.expand(inputs.shape[0], -1)
# Expand so that entropy matches batch size. Note that we're using
# torch.cat here instead of torch.expand() becuase it is not supported in the
# verified version of Barracuda (1.0.2).
log_sigma = torch.cat([self.log_sigma] * inputs.shape[0], axis=0)
if self.tanh_squash:
return [TanhGaussianDistInstance(mu, torch.exp(log_sigma))]
else:

正在加载...
取消
保存