浏览代码

Set torch device from commandline (#4888)

/MLA-1734-demo-provider
GitHub 3 年前
当前提交
457ed0b8
共有 8 个文件被更改,包括 103 次插入14 次删除
  1. 3
      com.unity.ml-agents/CHANGELOG.md
  2. 10
      docs/Training-ML-Agents.md
  3. 1
      ml-agents/mlagents/torch_utils/__init__.py
  4. 37
      ml-agents/mlagents/torch_utils/torch.py
  5. 15
      ml-agents/mlagents/trainers/cli_utils.py
  6. 1
      ml-agents/mlagents/trainers/learn.py
  7. 9
      ml-agents/mlagents/trainers/settings.py
  8. 41
      ml-agents/mlagents/trainers/tests/test_torch_utils.py

3
com.unity.ml-agents/CHANGELOG.md


`AddList()` is recommended, as it does not generate any additional memory allocations. (#4887)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Added a `--torch-device` commandline option to `mlagents-learn`, which sets the default
[`torch.device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.device) used for training. (#4888)
- The `--cpu` commandline option had no effect and was removed. Use `--torch-device=cpu` to force CPU training. (#4888)
### Bug Fixes
#### com.unity.ml-agents (C#)

10
docs/Training-ML-Agents.md


mlagents-learn --help
```
These additional CLI arguments are grouped into environment, engine and checkpoint. The available settings and example values are shown below.
These additional CLI arguments are grouped into environment, engine, checkpoint and torch.
The available settings and example values are shown below.
#### Environment settings

force: true
train_model: false
inference: false
```
#### Torch settings
```yaml
torch_settings:
  device: cpu
```
### Behavior Configurations

1
ml-agents/mlagents/torch_utils/__init__.py


from mlagents.torch_utils.torch import torch as torch # noqa
from mlagents.torch_utils.torch import nn # noqa
from mlagents.torch_utils.torch import set_torch_config # noqa
from mlagents.torch_utils.torch import default_device # noqa

37
ml-agents/mlagents/torch_utils/torch.py


from distutils.version import LooseVersion
import pkg_resources
from mlagents.torch_utils import cpu_utils
from mlagents.trainers.settings import TorchSettings
from mlagents_envs.logging_util import get_logger
logger = get_logger(__name__)
def assert_torch_installed():

torch.set_num_threads(cpu_utils.get_num_threads_to_use())
os.environ["KMP_BLOCKTIME"] = "0"
if torch.cuda.is_available():
torch.set_default_tensor_type(torch.cuda.FloatTensor)
device = torch.device("cuda")
else:
torch.set_default_tensor_type(torch.FloatTensor)
device = torch.device("cpu")
_device = torch.device("cpu")
def set_torch_config(torch_settings: TorchSettings) -> None:
    """Select the module-wide torch device and the matching default tensor type.

    :param torch_settings: Settings whose ``device`` field names the device to
        use. When it is ``None``, ``"cuda"`` is chosen if
        ``torch.cuda.is_available()`` reports True, otherwise ``"cpu"``.
    """
    global _device

    requested = torch_settings.device
    if requested is None:
        # Nothing was requested explicitly: fall back to CUDA when present.
        requested = "cuda" if torch.cuda.is_available() else "cpu"

    _device = torch.device(requested)

    # Only CUDA devices get the CUDA tensor type; every other device type
    # uses the plain CPU FloatTensor as the default.
    tensor_type = (
        torch.cuda.FloatTensor if _device.type == "cuda" else torch.FloatTensor
    )
    torch.set_default_tensor_type(tensor_type)
    logger.info(f"default Torch device: {_device}")
# Initialize to default settings
set_torch_config(TorchSettings(device=None))
return device
return _device

15
ml-agents/mlagents/trainers/cli_utils.py


action=DetectDefault,
)
argparser.add_argument(
"--cpu",
default=False,
action=DetectDefaultStoreTrue,
help="Forces training using CPU only",
)
argparser.add_argument(
"--torch",
default=False,
action=RaiseRemovedWarning,

action=DetectDefaultStoreTrue,
help="Whether to run the Unity executable in no-graphics mode (i.e. without initializing "
"the graphics driver. Use this only if your agents don't use visual observations.",
)
torch_conf = argparser.add_argument_group(title="Torch Configuration")
torch_conf.add_argument(
"--torch-device",
default=None,
dest="device",
action=DetectDefault,
help='Settings for the default torch.device used in training, for example, "cpu", "cuda", or "cuda:0"',
)
return argparser

1
ml-agents/mlagents/trainers/learn.py


:param run_options: Command line arguments for training.
"""
with hierarchical_timer("run_training.setup"):
torch_utils.set_torch_config(options.torch_settings)
checkpoint_settings = options.checkpoint_settings
env_settings = options.env_settings
engine_settings = options.engine_settings

9
ml-agents/mlagents/trainers/settings.py


@attr.s(auto_attribs=True)
class TorchSettings:
    """Torch-related run options."""

    # Name of the default torch device (e.g. "cpu", "cuda", "cuda:0"), or None
    # when no device was requested.
    # argparse's get_default() is keyed by an argument's ``dest``, not by its
    # flag spelling. The ``--torch-device`` flag is registered with
    # ``dest="device"``, so that is the key to query here.
    # NOTE(review): assumes ``parser`` is the argument parser that registers
    # ``--torch-device`` -- confirm against this module's imports.
    device: Optional[str] = parser.get_default("device")
@attr.s(auto_attribs=True)
class RunOptions(ExportableSettings):
default_settings: Optional[TrainerSettings] = None
behaviors: DefaultDict[str, TrainerSettings] = attr.ib(

engine_settings: EngineSettings = attr.ib(factory=EngineSettings)
environment_parameters: Optional[Dict[str, EnvironmentParameterSettings]] = None
checkpoint_settings: CheckpointSettings = attr.ib(factory=CheckpointSettings)
torch_settings: TorchSettings = attr.ib(factory=TorchSettings)
# These are options that are relevant to the run itself, and not the engine or environment.
# They will be left here.

"checkpoint_settings": {},
"env_settings": {},
"engine_settings": {},
"torch_settings": {},
}
if config_path is not None:
configured_dict.update(load_config(config_path))

configured_dict["env_settings"][key] = val
elif key in attr.fields_dict(EngineSettings):
configured_dict["engine_settings"][key] = val
elif key in attr.fields_dict(TorchSettings):
configured_dict["torch_settings"][key] = val
else: # Base options
configured_dict[key] = val

41
ml-agents/mlagents/trainers/tests/test_torch_utils.py


import pytest
from unittest import mock
import torch # noqa I201
from mlagents.torch_utils import set_torch_config, default_device
from mlagents.trainers.settings import TorchSettings
@pytest.mark.parametrize(
    "device_str, expected_type, expected_index, expected_tensor_type",
    [
        ("cpu", "cpu", None, torch.FloatTensor),
        ("cuda", "cuda", None, torch.cuda.FloatTensor),
        ("cuda:42", "cuda", 42, torch.cuda.FloatTensor),
        ("opengl", "opengl", None, torch.FloatTensor),
    ],
)
@mock.patch.object(torch, "set_default_tensor_type")
def test_set_torch_device(
    mock_set_default_tensor_type,
    device_str,
    expected_type,
    expected_index,
    expected_tensor_type,
):
    """Check that set_torch_config applies the requested device.

    For each device string, the resulting default device must have the
    expected type and index, and ``torch.set_default_tensor_type`` must have
    been called exactly once with the expected tensor type.
    """
    try:
        set_torch_config(TorchSettings(device=device_str))

        device = default_device()
        assert device.type == expected_type
        # A single comparison covers both cases: an expected index of None
        # only matches a device whose index is None.
        assert device.index == expected_index
        mock_set_default_tensor_type.assert_called_once_with(expected_tensor_type)
    finally:
        # Restore the defaults even when an assertion above fails.
        # torch.set_default_tensor_type is still patched at this point.
        set_torch_config(TorchSettings(device=None))
正在加载...
取消
保存