您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
125 行
4.9 KiB
125 行
4.9 KiB
from typing import Dict, Iterator, Any, List
|
|
from collections.abc import Mapping
|
|
from mlagents_envs.registry.base_registry_entry import BaseRegistryEntry
|
|
from mlagents_envs.registry.binary_utils import (
|
|
load_local_manifest,
|
|
load_remote_manifest,
|
|
)
|
|
from mlagents_envs.registry.remote_registry_entry import RemoteRegistryEntry
|
|
|
|
|
|
class UnityEnvRegistry(Mapping):
    """
    ### UnityEnvRegistry
    Provides a library of Unity environments that can be launched without the need
    to download the Unity Editor.
    The UnityEnvRegistry implements a (read-only) Mapping; to access an entry of the
    Registry, use:
    ```python
    registry = UnityEnvRegistry()
    entry = registry[<environment_identifier>]
    ```
    An entry has the following properties :
    * `identifier` : Uniquely identifies this environment
    * `expected_reward` : Corresponds to the reward an agent must obtain for the task
    to be considered completed.
    * `description` : A human readable description of the environment.

    To launch a Unity environment from a registry entry, use the `make` method:
    ```python
    registry = UnityEnvRegistry()
    env = registry[<environment_identifier>].make()
    ```
    """

    def __init__(self):
        # identifier -> entry; filled directly by `register` and lazily from the
        # queued manifests by `_load_all_manifests`.
        self._REGISTERED_ENVS: Dict[str, BaseRegistryEntry] = {}
        # Manifest paths / URLs queued by `register_from_yaml`, not yet loaded.
        self._manifests: List[str] = []
        # True when no queued manifest is waiting to be loaded.
        self._sync = True

    def register(self, new_entry: BaseRegistryEntry) -> None:
        """
        Registers a new BaseRegistryEntry to the registry. The
        BaseRegistryEntry.identifier value will be used as indexing key.
        If two or more environments are registered under the same key, the most
        recently added will replace the others.
        :param new_entry: The entry to add to the registry.
        """
        self._REGISTERED_ENVS[new_entry.identifier] = new_entry

    def register_from_yaml(self, path_to_yaml: str) -> None:
        """
        Registers the environments listed in a yaml file (either local or remote). Note
        that the entries are registered lazily: the file is only read when the
        registry is next indexed, measured with `len` or iterated over.
        A path starting with `http://` or `https://` is fetched remotely; any other
        path is read from the local file system.
        The yaml file must have the following format :
        ```yaml
        environments:
        - <identifier of the first environment>:
            expected_reward: <expected reward of the environment>
            description: | <a multi line description of the environment>
              <continued multi line description>
            linux_url: <The url for the Linux executable zip file>
            darwin_url: <The url for the OSX executable zip file>
            win_url: <The url for the Windows executable zip file>

        - <identifier of the second environment>:
            expected_reward: <expected reward of the environment>
            description: | <a multi line description of the environment>
              <continued multi line description>
            linux_url: <The url for the Linux executable zip file>
            darwin_url: <The url for the OSX executable zip file>
            win_url: <The url for the Windows executable zip file>

        - ...
        ```
        :param path_to_yaml: A local path or url to the yaml file
        """
        self._manifests.append(path_to_yaml)
        self._sync = False

    @staticmethod
    def _is_remote_manifest(path_to_yaml: str) -> bool:
        """Returns True when `path_to_yaml` must be fetched over HTTP(S)."""
        # Match the full scheme rather than a bare "http" prefix so that local
        # paths such as "http_envs/manifest.yaml" are still read from disk.
        return path_to_yaml.startswith(("http://", "https://"))

    def _load_all_manifests(self) -> None:
        """
        Registers the entries of every manifest queued by `register_from_yaml`.
        Manifests are processed in the order they were queued, so an entry from a
        later manifest replaces an entry with the same identifier from an earlier
        one. If loading a manifest raises, that manifest and the ones after it stay
        queued and will be retried on the next access.
        """
        if self._sync:
            return
        while self._manifests:
            path_to_yaml = self._manifests[0]
            if self._is_remote_manifest(path_to_yaml):
                manifest = load_remote_manifest(path_to_yaml)
            else:
                manifest = load_local_manifest(path_to_yaml)
            # An empty `environments:` section parses to None; treat it as empty.
            for env in manifest["environments"] or []:
                # Each list item maps the environment identifier to the keyword
                # arguments of its RemoteRegistryEntry; only the first key is used.
                identifier, entry_args = next(iter(env.items()))
                # Build a fresh kwargs dict instead of mutating the parsed manifest;
                # the mapping key always wins over any "identifier" field.
                self.register(
                    RemoteRegistryEntry(**{**entry_args, "identifier": identifier})
                )
            # Dequeue only once the whole manifest is registered, so a failure in a
            # later manifest does not force this one to be loaded again.
            self._manifests.pop(0)
        self._sync = True

    def clear(self) -> None:
        """
        Deletes all entries in the registry, including the manifests queued by
        `register_from_yaml` that have not been loaded yet.
        """
        self._REGISTERED_ENVS.clear()
        self._manifests = []
        self._sync = True

    def __getitem__(self, identifier: str) -> BaseRegistryEntry:
        """
        Returns the BaseRegistryEntry with the provided identifier. BaseRegistryEntry
        can then be used to make a Unity Environment.
        Any queued manifests are loaded first.
        :param identifier: The identifier of the BaseRegistryEntry
        :returns: The associated BaseRegistryEntry
        :raises KeyError: If no entry is registered under `identifier`.
        """
        self._load_all_manifests()
        try:
            return self._REGISTERED_ENVS[identifier]
        except KeyError:
            raise KeyError(
                f"The entry {identifier} is not present in the registry."
            ) from None

    def __len__(self) -> int:
        """Returns the number of registered entries, loading queued manifests first."""
        self._load_all_manifests()
        return len(self._REGISTERED_ENVS)

    def __iter__(self) -> Iterator[Any]:
        """Yields the registered identifiers, loading queued manifests first."""
        self._load_all_manifests()
        yield from self._REGISTERED_ENVS
|
|
|
|
|
|
# Module-level registry that knows about the hosted ML-Agents test environments.
# `register_from_yaml` only queues the manifest URL; it is downloaded on the first
# lookup, `len` or iteration of `default_registry`.
default_registry = UnityEnvRegistry()
default_registry.register_from_yaml(
    "https://storage.googleapis.com/mlagents-test-environments/1.0.0/manifest.yaml"  # noqa: E501
)
|