浏览代码

[bug-fix] Move POCA critic to default device (#5124) (#5131)

* Move critic to default device

* Make sure to clone onto default device

* Add some debug stuff

* Some more debug

* Fix issue

* Fix bool tensor too
/release_15_branch
GitHub 4 年前
当前提交
e6143a83
共有 2 个文件被更改,包括 5 次插入3 次删除
  1. 4
      ml-agents/mlagents/trainers/poca/optimizer_torch.py
  2. 4
      ml-agents/mlagents/trainers/torch/networks.py

4
ml-agents/mlagents/trainers/poca/optimizer_torch.py


)
import numpy as np
import math
from mlagents.torch_utils import torch
from mlagents.torch_utils import torch, default_device
from mlagents.trainers.buffer import (
AgentBuffer,

network_settings=trainer_settings.network_settings,
action_spec=policy.behavior_spec.action_spec,
)
# Move to GPU if needed
self._critic.to(default_device())
params = list(self.policy.actor.parameters()) + list(self.critic.parameters())
self.hyperparameters: POCASettings = cast(

4
ml-agents/mlagents/trainers/torch/networks.py


[_obs.flatten(start_dim=1)[:, 0] for _obs in only_first_obs], dim=1
)
# Get the mask from NaNs
attn_mask = only_first_obs_flat.isnan().type(torch.FloatTensor)
attn_mask = only_first_obs_flat.isnan().float()
return attn_mask
def _copy_and_remove_nans_from_obs(

for obs in single_agent_obs:
new_obs = obs.clone()
new_obs[
attention_mask.type(torch.BoolTensor)[:, i_agent], ::
attention_mask.bool()[:, i_agent], ::
] = 0.0 # Remoove NaNs fast
no_nan_obs.append(new_obs)
obs_with_no_nans.append(no_nan_obs)

正在加载...
取消
保存