浏览代码

Fix model inference issue with Barracuda v1.2.1 (#4766)

Co-authored-by: Ervin T. <ervin@unity3d.com>
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
9fbf56e2
共有 2 个文件被更改,包括 13 次插入9 次删除
  1. 8
      ml-agents/mlagents/trainers/torch/distributions.py
  2. 14
      ml-agents/mlagents/trainers/torch/networks.py

8
ml-agents/mlagents/trainers/torch/distributions.py


log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
else:
# Expand so that entropy matches batch size. Note that we're using
# torch.cat here instead of torch.expand() becuase it is not supported in the
# verified version of Barracuda (1.0.2).
log_sigma = torch.cat([self.log_sigma] * inputs.shape[0], axis=0)
# mu*0 here to get the batch size implicitly since Barracuda 1.2.1
# throws error on runtime broadcasting due to unknown reason. We
# use this to replace torch.expand() becuase it is not supported in
# the verified version of Barracuda (1.0.X).
log_sigma = mu * 0 + self.log_sigma
if self.tanh_squash:
return TanhGaussianDistInstance(mu, torch.exp(log_sigma))
else:

14
ml-agents/mlagents/trainers/torch/networks.py


):
super().__init__()
self.action_spec = action_spec
self.version_number = torch.nn.Parameter(torch.Tensor([2.0]))
self.version_number = torch.nn.Parameter(
torch.Tensor([2.0]), requires_grad=False
)
torch.Tensor([int(self.action_spec.is_continuous())])
torch.Tensor([int(self.action_spec.is_continuous())]), requires_grad=False
)
self.continuous_act_size_vector = torch.nn.Parameter(
torch.Tensor([int(self.action_spec.continuous_size)]), requires_grad=False

self.encoding_size = network_settings.memory.memory_size // 2
else:
self.encoding_size = network_settings.hidden_units
self.memory_size_vector = torch.nn.Parameter(
torch.Tensor([int(self.network_body.memory_size)]), requires_grad=False
)
self.action_model = ActionModel(
self.encoding_size,

disc_action_out,
action_out_deprecated,
) = self.action_model.get_action_out(encoding, masks)
export_out = [
self.version_number,
torch.Tensor([self.network_body.memory_size]),
]
export_out = [self.version_number, self.memory_size_vector]
if self.action_spec.continuous_size > 0:
export_out += [cont_action_out, self.continuous_act_size_vector]
if self.action_spec.discrete_size > 0:

正在加载...
取消
保存