浏览代码

Add torch_utils class, auto-detect CUDA availability (#4403)

* Add torch_utils

* Use torch from torch_utils

* Add torch to banned modules in CI

* Better import error handling

* Fix flake8 errors

* Address comments

* Move networks to GPU if enabled

* Switch to torch_utils

* More flake8 problems

* Move reward providers to GPU/CPU

* Remove another set default tensor

* Fix banned import in test
/MLA-1734-demo-provider
GitHub 4 年前
当前提交
6f534366
共有 36 个文件被更改,包括 95 次插入和 55 次删除
  1. 2
      .pylintrc
  2. 2
      ml-agents/mlagents/trainers/model_saver/torch_model_saver.py
  3. 2
      ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
  4. 6
      ml-agents/mlagents/trainers/policy/torch_policy.py
  5. 2
      ml-agents/mlagents/trainers/ppo/optimizer_torch.py
  6. 5
      ml-agents/mlagents/trainers/ppo/trainer.py
  7. 9
      ml-agents/mlagents/trainers/sac/optimizer_torch.py
  8. 5
      ml-agents/mlagents/trainers/sac/trainer.py
  9. 2
      ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py
  10. 2
      ml-agents/mlagents/trainers/tests/torch/test_decoders.py
  11. 2
      ml-agents/mlagents/trainers/tests/torch/test_distributions.py
  12. 2
      ml-agents/mlagents/trainers/tests/torch/test_encoders.py
  13. 2
      ml-agents/mlagents/trainers/tests/torch/test_layers.py
  14. 2
      ml-agents/mlagents/trainers/tests/torch/test_networks.py
  15. 2
      ml-agents/mlagents/trainers/tests/torch/test_policy.py
  16. 2
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py
  17. 2
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py
  18. 2
      ml-agents/mlagents/trainers/tests/torch/test_sac.py
  19. 2
      ml-agents/mlagents/trainers/tests/torch/test_utils.py
  20. 2
      ml-agents/mlagents/trainers/torch/components/bc/module.py
  21. 2
      ml-agents/mlagents/trainers/torch/components/reward_providers/base_reward_provider.py
  22. 4
      ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
  23. 3
      ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
  24. 3
      ml-agents/mlagents/trainers/torch/decoders.py
  25. 3
      ml-agents/mlagents/trainers/torch/distributions.py
  26. 4
      ml-agents/mlagents/trainers/torch/encoders.py
  27. 2
      ml-agents/mlagents/trainers/torch/layers.py
  28. 2
      ml-agents/mlagents/trainers/torch/model_serialization.py
  29. 3
      ml-agents/mlagents/trainers/torch/networks.py
  30. 3
      ml-agents/mlagents/trainers/torch/utils.py
  31. 17
      ml-agents/mlagents/trainers/trainer/rl_trainer.py
  32. 10
      ml-agents/mlagents/trainers/trainer_controller.py
  33. 1
      setup.cfg
  34. 4
      ml-agents/mlagents/torch_utils/__init__.py
  35. 32
      ml-agents/mlagents/torch_utils/torch.py

2
.pylintrc


# Add files or directories to the ignore list. They should be base names, not
# paths.
ignore=CVS
generated-members=torch.*
[MESSAGES CONTROL]
#enable=

2
ml-agents/mlagents/trainers/model_saver/torch_model_saver.py


import os
import shutil
import torch
from mlagents.torch_utils import torch
from typing import Dict, Union, Optional, cast
from mlagents_envs.exception import UnityPolicyException
from mlagents_envs.logging_util import get_logger

2
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py


from typing import Dict, Optional, Tuple, List
import torch
from mlagents.torch_utils import torch
import numpy as np
from mlagents.trainers.buffer import AgentBuffer

6
ml-agents/mlagents/trainers/policy/torch_policy.py


from typing import Any, Dict, List, Tuple, Optional
import numpy as np
import torch
from mlagents.torch_utils import torch, default_device
import copy
from mlagents.trainers.action_info import ActionInfo

) # could be much simpler if TorchPolicy is nn.Module
self.grads = None
torch.set_default_tensor_type(torch.FloatTensor)
reward_signal_configs = trainer_settings.reward_signals
reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]

# m_size needed for training is determined by network, not trainer settings
self.m_size = self.actor_critic.memory_size
self.actor_critic.to("cpu")
self.actor_critic.to(default_device())
@property
def export_memory_size(self) -> int:

2
ml-agents/mlagents/trainers/ppo/optimizer_torch.py


from typing import Dict, cast
import torch
from mlagents.torch_utils import torch
from mlagents.trainers.buffer import AgentBuffer

5
ml-agents/mlagents/trainers/ppo/trainer.py


from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings, PPOSettings, FrameworkType
from mlagents.trainers.components.reward_signals import RewardSignal
from mlagents import torch_utils
try:
if torch_utils.is_available():
except ModuleNotFoundError:
else:
TorchPolicy = None # type: ignore
TorchPPOOptimizer = None # type: ignore

9
ml-agents/mlagents/trainers/sac/optimizer_torch.py


import numpy as np
from typing import Dict, List, Mapping, cast, Tuple, Optional
import torch
from torch import nn
from mlagents.torch_utils import torch, nn, default_device
from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import ActionType

self.entropy_optimizer = torch.optim.Adam(
[self._log_ent_coef], lr=hyperparameters.learning_rate
)
self._move_to_device(default_device())
def _move_to_device(self, device: torch.device) -> None:
self._log_ent_coef.to(device)
self.target_network.to(device)
self.value_network.to(device)
def sac_q_loss(
self,

5
ml-agents/mlagents/trainers/sac/trainer.py


from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings, SACSettings, FrameworkType
from mlagents.trainers.components.reward_signals import RewardSignal
from mlagents import torch_utils
try:
if torch_utils.is_available():
except ModuleNotFoundError:
else:
TorchPolicy = None # type: ignore
TorchSACOptimizer = None # type: ignore

2
ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py


import os
import numpy as np
import torch
from mlagents.torch_utils import torch
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver

2
ml-agents/mlagents/trainers/tests/torch/test_decoders.py


import pytest
import torch
from mlagents.torch_utils import torch
from mlagents.trainers.torch.decoders import ValueHeads

2
ml-agents/mlagents/trainers/tests/torch/test_distributions.py


import pytest
import torch
from mlagents.torch_utils import torch
from mlagents.trainers.torch.distributions import (
GaussianDistribution,

2
ml-agents/mlagents/trainers/tests/torch/test_encoders.py


import torch
from mlagents.torch_utils import torch
from unittest import mock
import pytest

2
ml-agents/mlagents/trainers/tests/torch/test_layers.py


import torch
from mlagents.torch_utils import torch
from mlagents.trainers.torch.layers import (
Swish,

2
ml-agents/mlagents/trainers/tests/torch/test_networks.py


import pytest
import torch
from mlagents.torch_utils import torch
from mlagents.trainers.torch.networks import (
NetworkBody,
ValueNetwork,

2
ml-agents/mlagents/trainers/tests/torch/test_policy.py


import pytest
import torch
from mlagents.torch_utils import torch
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.settings import TrainerSettings, NetworkSettings

2
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py


import numpy as np
import pytest
import torch
from mlagents.torch_utils import torch
from mlagents.trainers.torch.components.reward_providers import (
CuriosityRewardProvider,
create_reward_provider,

2
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py


import numpy as np
import pytest
from unittest.mock import patch
import torch
from mlagents.torch_utils import torch
import os
from mlagents.trainers.torch.components.reward_providers import (
GAILRewardProvider,

2
ml-agents/mlagents/trainers/tests/torch/test_sac.py


import pytest
import copy
import torch
from mlagents.torch_utils import torch
from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
from mlagents.trainers.policy.torch_policy import TorchPolicy

2
ml-agents/mlagents/trainers/tests/torch/test_utils.py


import pytest
import torch
from mlagents.torch_utils import torch
import numpy as np
from mlagents.trainers.settings import EncoderType, ScheduleType

2
ml-agents/mlagents/trainers/torch/components/bc/module.py


from typing import Dict
import numpy as np
import torch
from mlagents.torch_utils import torch
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.demo_loader import demo_to_buffer

2
ml-agents/mlagents/trainers/torch/components/reward_providers/base_reward_provider.py


import numpy as np
import torch
from mlagents.torch_utils import torch
from abc import ABC, abstractmethod
from typing import Dict

4
ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py


import numpy as np
from typing import Dict
import torch
from mlagents.torch_utils import torch, default_device
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (

super().__init__(specs, settings)
self._ignore_done = True
self._network = CuriosityNetwork(specs, settings)
self._network.to(default_device())
self.optimizer = torch.optim.Adam(
self._network.parameters(), lr=settings.learning_rate
)

3
ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py


from typing import Optional, Dict
import numpy as np
import torch
from mlagents.torch_utils import torch, default_device
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (

super().__init__(specs, settings)
self._ignore_done = True
self._discriminator_network = DiscriminatorNetwork(specs, settings)
self._discriminator_network.to(default_device())
_, self._demo_buffer = demo_to_buffer(
settings.demo_path, 1, specs
) # This is supposed to be the sequence length but we do not have access here

3
ml-agents/mlagents/trainers/torch/decoders.py


from typing import List, Dict
import torch
from torch import nn
from mlagents.torch_utils import torch, nn
from mlagents.trainers.torch.layers import linear_layer

3
ml-agents/mlagents/trainers/torch/distributions.py


import abc
from typing import List
import torch
from torch import nn
from mlagents.torch_utils import torch, nn
import numpy as np
import math
from mlagents.trainers.torch.layers import linear_layer, Initialization

4
ml-agents/mlagents/trainers/torch/encoders.py


from mlagents.trainers.torch.layers import linear_layer, Initialization, Swish
import torch
from torch import nn
from mlagents.torch_utils import torch, nn
class Normalizer(nn.Module):

2
ml-agents/mlagents/trainers/torch/layers.py


import torch
from mlagents.torch_utils import torch
import abc
from typing import Tuple
from enum import Enum

2
ml-agents/mlagents/trainers/torch/model_serialization.py


import os
import torch
from mlagents.torch_utils import torch
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.settings import SerializationSettings

3
ml-agents/mlagents/trainers/torch/networks.py


from typing import Callable, List, Dict, Tuple, Optional
import abc
import torch
from torch import nn
from mlagents.torch_utils import torch, nn
from mlagents_envs.base_env import ActionType
from mlagents.trainers.torch.distributions import (

3
ml-agents/mlagents/trainers/torch/utils.py


from typing import List, Optional, Tuple
import torch
from mlagents.torch_utils import torch, nn
from torch import nn
from mlagents.trainers.torch.encoders import (
SimpleVisualEncoder,

17
ml-agents/mlagents/trainers/trainer/rl_trainer.py


from mlagents.trainers.model_saver.model_saver import BaseModelSaver
from mlagents.trainers.model_saver.tf_model_saver import TFModelSaver
from mlagents.trainers.exception import UnityTrainerException
from mlagents import torch_utils
try:
if torch_utils.is_available():
except ModuleNotFoundError:
else:
TorchSaver = None # type: ignore
RewardSignalResults = Dict[str, RewardSignalResult]

StatsPropertyType.HYPERPARAMETERS, self.trainer_settings.as_dict()
)
self.framework = self.trainer_settings.framework
if self.framework == FrameworkType.PYTORCH and not torch_utils.is_available():
raise UnityTrainerException(
"To use the experimental PyTorch backend, install the PyTorch Python package first."
)
logger.debug(f"Using framework {self.framework.value}")
self._next_save_step = 0

behavior_spec: BehaviorSpec,
create_graph: bool = False,
) -> Policy:
if self.framework == FrameworkType.PYTORCH and TorchPolicy is None:
raise UnityTrainerException(
"To use the experimental PyTorch backend, install the PyTorch Python package first."
)
elif self.framework == FrameworkType.PYTORCH:
if self.framework == FrameworkType.PYTORCH:
return self.create_torch_policy(parsed_behavior_id, behavior_spec)
else:
return self.create_tf_policy(

10
ml-agents/mlagents/trainers/trainer_controller.py


from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.agent_processor import AgentManager
from mlagents.tf_utils.globals import get_rank
try:
import torch
except ModuleNotFoundError:
torch = None # type: ignore
from mlagents import torch_utils
class TrainerController:

self.kill_trainers = False
np.random.seed(training_seed)
tf.set_random_seed(training_seed)
if torch is not None:
torch.manual_seed(training_seed)
if torch_utils.is_available():
torch_utils.torch.manual_seed(training_seed)
self.rank = get_rank()
@timed

1
setup.cfg


banned-modules = tensorflow = use mlagents.tf_utils instead (it handles tf2 compat).
logging = use mlagents_envs.logging_util instead
torch = use mlagents.torch_utils instead (handles GPU detection).

4
ml-agents/mlagents/torch_utils/__init__.py


from mlagents.torch_utils.torch import torch as torch # noqa
from mlagents.torch_utils.torch import nn # noqa
from mlagents.torch_utils.torch import is_available # noqa
from mlagents.torch_utils.torch import default_device # noqa

32
ml-agents/mlagents/torch_utils/torch.py


# Central place for torch detection: torch is imported exactly once here so
# that the flake8 banned-modules rule can forbid direct imports everywhere else.
# NOTE: the try/except is temporary until torch becomes a required dependency.
try:
    import torch  # noqa I201

    # pylint: disable=E1101
    # (PyLint has known trouble with PyTorch's C extensions, see
    # https://github.com/pytorch/pytorch/issues/701)
    _cuda = torch.cuda.is_available()
    # Pick the default tensor type and device once, at import time.
    torch.set_default_tensor_type(
        torch.cuda.FloatTensor if _cuda else torch.FloatTensor
    )
    device = torch.device("cuda" if _cuda else "cpu")
    nn = torch.nn
except ImportError:
    # torch is not installed: expose None placeholders so callers can probe
    # availability via is_available() instead of crashing on import.
    torch = None
    nn = None
    device = None


def default_device():
    # Device chosen at import time ("cuda" when a GPU was detected, else "cpu");
    # None when torch itself is unavailable.
    return device


def is_available():
    """
    Returns whether Torch is available in this Python environment
    """
    return torch is not None
正在加载...
取消
保存