浏览代码

Improve Gym wrapper compatibility and add Dopamine documentation (#1541)

* Add option to set gym visual observation to uint8

* Add option to flatten branched discrete actions

* Add game_over variable to gym wrapper

* Add guide on how to use Dopamine with the gym wrapper and comparisons with Baselines and PPO
/develop-generalizationTraining-TrainerController
GitHub 5 年前
当前提交
0f1cb4d8
共有 7 个文件被更改,包括 671 次插入28 次删除
  1. 2
      docs/Readme.md
  2. 218
      gym-unity/README.md
  3. 84
      gym-unity/gym_unity/envs/unity_env.py
  4. 26
      gym-unity/tests/test_gym.py
  5. 236
      gym-unity/images/dopamine_gridworld_plot.png
  6. 133
      gym-unity/images/dopamine_visualbanana_plot.png

2
docs/Readme.md


* [API Reference](API-Reference.md)
* [How to use the Python API](Python-API.md)
* [Wrapping Learning Environment as a Gym](../gym-unity/README.md)
* [Wrapping Learning Environment as a Gym (+Baselines/Dopamine Integration)](../gym-unity/README.md)

218
gym-unity/README.md


```python
from gym_unity.envs import UnityEnv
env = UnityEnv(environment_filename, worker_id, default_visual, multiagent)
env = UnityEnv(environment_filename, worker_id, use_visual, uint8_visual, multiagent)
* `environment_filename` refers to the path to the Unity environment.
* `worker_id` refers to the port to use for communication with the environment.
Defaults to `0`.
* `use_visual` refers to whether to use visual observations (True) or vector
observations (False) as the default observation provided by the `reset` and
`step` functions. Defaults to `False`.
* `multiagent` refers to whether you intent to launch an environment which
contains more than one agent. Defaults to `False`.
* `environment_filename` refers to the path to the Unity environment.
* `worker_id` refers to the port to use for communication with the environment.
Defaults to `0`.
* `use_visual` refers to whether to use visual observations (True) or vector
observations (False) as the default observation provided by the `reset` and
`step` functions. Defaults to `False`.
* `uint8_visual` refers to whether to output visual observations as `uint8` values
(0-255). Many common Gym environments (e.g. Atari) do this. By default they
will be floats (0.0-1.0). Defaults to `False`.
* `multiagent` refers to whether you intent to launch an environment which
contains more than one agent. Defaults to `False`.
* `flatten_branched` will flatten a branched discrete action space into a Gym Discrete.
Otherwise, it will be converted into a MultiDiscrete. Defaults to `False`.
The returned environment `env` will function as a gym.

Using the provided Gym wrapper, it is possible to train ML-Agents environments
using these algorithms. This requires the creation of custom training scripts to
launch each algorithm. In most cases these scripts can be created by making
slightly modifications to the ones provided for Atari and Mujoco environments.
slight modifications to the ones provided for Atari and Mujoco environments.
### Example - DQN Baseline

import gym
from baselines import deepq
from gym_unity.envs import UnityEnv
from baselines import logger
from gym_unity.envs.unity_env import UnityEnv
env = UnityEnv("./envs/GridWorld", 0, use_visual=True)
env = UnityEnv("./envs/GridWorld", 0, use_visual=True, uint8_visual=True)
logger.configure('./logs') # Çhange to log in a different directory
"mlp",
lr=1e-3,
total_timesteps=100000,
"cnn", # conv_only is also a good choice for GridWorld
lr=2.5e-4,
total_timesteps=1000000,
exploration_fraction=0.1,
exploration_final_eps=0.02,
print_freq=10
exploration_fraction=0.05,
exploration_final_eps=0.1,
print_freq=20,
train_freq=5,
learning_starts=20000,
target_network_update_freq=50,
gamma=0.99,
prioritized_replay=False,
checkpoint_freq=1000,
checkpoint_path='./logs', # Change to save model in a different directory
dueling=True
To start the training process, run the following from the root of the baselines
repository:
To start the training process, run the following from the directory containing
`train_unity.py`:
```sh
python -m train_unity

"""
def make_env(rank, use_visual=True): # pylint: disable=C0111
def _thunk():
env = UnityEnv(env_directory, rank, use_visual=use_visual)
env = UnityEnv(env_directory, rank, use_visual=use_visual, uint8_visual=True)
env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)))
return env
return _thunk

if __name__ == '__main__':
main()
```
## Run Google Dopamine Algorithms
Google provides a framework [Dopamine](https://github.com/google/dopamine), and
implementations of algorithms, e.g. DQN, Rainbow, and the C51 variant of Rainbow.
Using the Gym wrapper, we can run Unity environments using Dopamine.
First, after installing the Gym wrapper, clone the Dopamine repository.
```
git clone https://github.com/google/dopamine
```
Then, follow the appropriate install instructions as specified on
[Dopamine's homepage](https://github.com/google/dopamine). Note that the Dopamine
guide specifies using a virtualenv. If you choose to do so, make sure your unity_env
package is also installed within the same virtualenv as Dopamine.
### Adapting Dopamine's Scripts
First, open `dopamine/atari/run_experiment.py`. Alternatively, copy the entire `atari`
folder, and name it something else (e.g. `unity`). If you choose the copy approach,
be sure to change the package names in the import statements in `train.py` to your new
directory.
Within `run_experiment.py`, we will need to make changes to which environment is
instantiated, just as in the Baselines example. At the top of the file, insert
```python
from gym_unity.envs import UnityEnv
```
to import the Gym Wrapper. Navigate to the `create_atari_environment` method
in the same file, and switch to instantiating a Unity environment by replacing
the method with the following code.
```python
game_version = 'v0' if sticky_actions else 'v4'
full_game_name = '{}NoFrameskip-{}'.format(game_name, game_version)
env = UnityEnv('./envs/GridWorld', 0, use_visual=True, uint8_visual=True)
return env
```
`./envs/GridWorld` is the path to your built Unity executable. For more information on
building Unity environments, see [here](../docs/Learning-Environment-Executable.md), and note
the Limitations section below.
Note that we are not using the preprocessor from Dopamine,
as it uses many Atari-specific calls. Furthermore, frame-skipping can be done from within Unity,
rather than on the Python side.
### Limitations
Since Dopamine is designed around variants of DQN, it is only compatible
with discrete action spaces, and specifically the Discrete Gym space. For environments
that use branched discrete action spaces (e.g.
[VisualBanana](../docs/Learning-Environment-Examples.md)), you can enable the
`flatten_branched` parameter in `UnityEnv`, which treats each combination of branched
actions as separate actions.
Furthermore, when building your environments, ensure that your
[Learning Brain](../docs/Learning-Environment-Design-Brains.md) is using visual
observations with greyscale enabled, and that the dimensions of the visual observations
is 84 by 84 (matches the parameter found in `dqn_agent.py` and `rainbow_agent.py`).
Dopamine's agents currently do not automatically adapt to the observation
dimensions or number of channels.
### Hyperparameters
The hyperparameters provided by Dopamine are tailored to the Atari games, and you will
likely need to adjust them for ML-Agents environments. Here is a sample
`dopamine/agents/rainbow/configs/rainbow.gin` file that is known to work with
GridWorld.
```python
import dopamine.agents.rainbow.rainbow_agent
import dopamine.unity.run_experiment
import dopamine.replay_memory.prioritized_replay_buffer
import gin.tf.external_configurables
RainbowAgent.num_atoms = 51
RainbowAgent.stack_size = 1
RainbowAgent.vmax = 10.
RainbowAgent.gamma = 0.99
RainbowAgent.update_horizon = 3
RainbowAgent.min_replay_history = 20000 # agent steps
RainbowAgent.update_period = 5
RainbowAgent.target_update_period = 50 # agent steps
RainbowAgent.epsilon_train = 0.1
RainbowAgent.epsilon_eval = 0.01
RainbowAgent.epsilon_decay_period = 50000 # agent steps
RainbowAgent.replay_scheme = 'prioritized'
RainbowAgent.tf_device = '/cpu:0' # use '/cpu:*' for non-GPU version
RainbowAgent.optimizer = @tf.train.AdamOptimizer()
tf.train.AdamOptimizer.learning_rate = 0.00025
tf.train.AdamOptimizer.epsilon = 0.0003125
Runner.game_name = "Unity" # any name can be used here
Runner.sticky_actions = False
Runner.num_iterations = 200
Runner.training_steps = 10000 # agent steps
Runner.evaluation_steps = 500 # agent steps
Runner.max_steps_per_episode = 27000 # agent steps
WrappedPrioritizedReplayBuffer.replay_capacity = 1000000
WrappedPrioritizedReplayBuffer.batch_size = 32
```
This example assumed you copied `atari` to a separate folder named `unity`.
Replace `unity` in `import dopamine.unity.run_experiment` with the folder you
copied your `run_experiment.py` and `trainer.py` files to.
If you directly modified the existing files, then use `atari` here.
### Starting a Run
You can now run Dopamine as you would normally:
```
python -um dopamine.unity.train \
--agent_name=rainbow \
--base_dir=/tmp/dopamine \
--gin_files='dopamine/agents/rainbow/configs/rainbow.gin'
```
Again, we assume that you've copied `atari` into a separate folder.
Remember to replace `unity` with the directory you copied your files into. If you
edited the Atari files directly, this should be `atari`.
### Example: GridWorld
As a baseline, here are rewards over time for the three algorithms provided with
Dopamine as run on the GridWorld example environment. All Dopamine (DQN, Rainbow,
C51) runs were done with the same epsilon, epsilon decay, replay history, training steps,
and buffer settings as specified above. Note that the first 20000 steps are used to pre-fill
the training buffer, and no learning happens.
We provide results from our PPO implementation and the DQN from Baselines as reference.
Note that all runs used the same greyscale GridWorld as Dopamine. For PPO, `num_layers`
was set to 2, and all other hyperparameters are the default for GridWorld in `trainer_config.yaml`.
For Baselines DQN, the provided hyperparameters in the previous section are used. Note
that Baselines implements certain features (e.g. dueling-Q) that are not enabled
in Dopamine DQN.
![Dopamine on GridWorld](images/dopamine_gridworld_plot.png)
### Example: VisualBanana
As an example of using the `flatten_branched` option, we also used the Rainbow
algorithm to train on the VisualBanana environment, and provide the results below.
The same hyperparameters were used as in the GridWorld case, except that
`replay_history` and `epsilon_decay` were increased to 100000.
![Dopamine on VisualBanana](images/dopamine_visualbanana_plot.png)

84
gym-unity/gym_unity/envs/unity_env.py


import logging
import itertools
import gym
import numpy as np
from mlagents.envs import UnityEnvironment

https://github.com/openai/multiagent-particle-envs
"""
def __init__(self, environment_filename: str, worker_id=0, use_visual=False, multiagent=False):
def __init__(self, environment_filename: str, worker_id=0, use_visual=False, uint8_visual=False, multiagent=False, flatten_branched=False):
:param uint8_visual: Return visual observations as uint8 (0-255) matrices instead of float (0.0-1.0).
:param flatten_branched: If True, turn branched discrete action spaces into a Discrete space rather than MultiDiscrete.
"""
self._env = UnityEnvironment(environment_filename, worker_id)
self.name = self._env.academy_name

self._multiagent = multiagent
self._flattener = None
self.game_over = False # Hidden flag used by Atari environments to determine if the game is over
# Check brain configuration
if len(self._env.brains) != 1:

" visual observations as part of this environment.")
self.use_visual = brain.number_visual_observations >= 1 and use_visual
if not use_visual and uint8_visual:
logger.warning("`uint8_visual was set to true, but visual observations are not in use. "
"This setting will not have any effect.")
else:
self.uint8_visual = uint8_visual
if brain.number_visual_observations > 1:
logger.warning("The environment contains more than one visual observation. "
"Please note that only the first will be provided in the observation.")

if len(brain.vector_action_space_size) == 1:
self._action_space = spaces.Discrete(brain.vector_action_space_size[0])
else:
self._action_space = spaces.MultiDiscrete(brain.vector_action_space_size)
if flatten_branched:
self._flattener = ActionFlattener(brain.vector_action_space_size)
self._action_space = self._flattener.action_space
else:
self._action_space = spaces.MultiDiscrete(brain.vector_action_space_size)
if flatten_branched:
logger.warning("The environment has a non-discrete action space. It will "
"not be flattened.")
high = np.array([1] * brain.vector_action_space_size[0])
self._action_space = spaces.Box(-high, high, dtype=np.float32)
high = np.array([np.inf] * brain.vector_observation_space_size)

info = self._env.reset()[self.brain_name]
n_agents = len(info.agents)
self._check_agents(n_agents)
self.game_over = False
if not self._multiagent:
obs, reward, done, info = self._single_step(info)

raise UnityGymException(
"The environment was expecting a list of {} actions.".format(self._n_agents))
else:
if self._flattener is not None:
# Action space is discrete and flattened - we expect a list of scalars
action = [self._flattener.lookup_action(_act) for _act in action]
else:
if self._flattener is not None:
# Translate action into list
action = self._flattener.lookup_action(action)
info = self._env.step(action)[self.brain_name]
n_agents = len(info.agents)

if not self._multiagent:
obs, reward, done, info = self._single_step(info)
self.game_over = done
self.game_over = all(done)
self.visual_obs = info.visual_observations[0][0, :, :, :]
self.visual_obs = self._preprocess_single(info.visual_observations[0][0, :, :, :])
default_observation = self.visual_obs
else:
default_observation = info.vector_observations[0, :]

"brain_info": info}
def _preprocess_single(self, single_visual_obs):
if self.uint8_visual:
return (255.0*single_visual_obs).astype(np.uint8)
else:
return single_visual_obs
self.visual_obs = info.visual_observations
self.visual_obs = self._preprocess_multi(info.visual_observations)
default_observation = self.visual_obs
else:
default_observation = info.vector_observations

def _preprocess_multi(self, multiple_visual_obs):
if self.uint8_visual:
return [(255.0*_visual_obs).astype(np.uint8) for _visual_obs in multiple_visual_obs]
else:
return multiple_visual_obs
def render(self, mode='rgb_array'):
return self.visual_obs

@property
def number_agents(self):
return self._n_agents
class ActionFlattener():
"""
Flattens branched discrete action spaces into single-branch discrete action spaces.
"""
def __init__(self,branched_action_space):
"""
Initialize the flattener.
:param branched_action_space: A List containing the sizes of each branch of the action
space, e.g. [2,3,3] for three branches with size 2, 3, and 3 respectively.
"""
self._action_shape = branched_action_space
self.action_lookup = self._create_lookup(self._action_shape)
self.action_space = spaces.Discrete(len(self.action_lookup))
@classmethod
def _create_lookup(self, branched_action_space):
"""
Creates a Dict that maps discrete actions (scalars) to branched actions (lists).
Each key in the Dict maps to one unique set of branched actions, and each value
contains the List of branched actions.
"""
possible_vals = [range(_num) for _num in branched_action_space]
all_actions = [list(_action) for _action in itertools.product(*possible_vals)]
# Dict should be faster than List for large action spaces
action_lookup = {_scalar: _action for (_scalar, _action) in enumerate(all_actions)}
return action_lookup
def lookup_action(self, action):
"""
Convert a scalar discrete action into a unique set of branched actions.
:param: action: A scalar value representing one of the discrete actions.
:return: The List containing the branched actions.
"""
return self.action_lookup[action]

26
gym-unity/tests/test_gym.py


import pytest
import numpy as np
from gym import spaces
from tests.mock_communicator import MockCommunicator
from mock_communicator import MockCommunicator
@mock.patch('mlagents.envs.UnityEnvironment.executable_launcher')
@mock.patch('mlagents.envs.UnityEnvironment.get_communicator')

assert isinstance(rew, list)
assert isinstance(done, list)
assert isinstance(info, dict)
@mock.patch('gym_unity.envs.unity_env.UnityEnvironment')
def test_branched_flatten(mock_env):
mock_env.return_value.academy_name = 'MockAcademy'
mock_brain = mock.Mock();
mock_brain.return_value.number_visual_observations = 0
mock_brain.return_value.num_stacked_vector_observations = 1
mock_brain.return_value.vector_action_space_type = 'discrete'
mock_brain.return_value.vector_observation_space_size = 1
# Unflattened action space
mock_brain.return_value.vector_action_space_size = [2,2,3]
mock_env.return_value.brains = {'MockBrain':mock_brain()}
mock_env.return_value.external_brain_names = ['MockBrain']
env = UnityEnv(' ', use_visual=False, multiagent=False, flatten_branched=True)
assert isinstance(env.action_space, spaces.Discrete)
assert env.action_space.n==12
assert env._flattener.lookup_action(0)==[0,0,0]
assert env._flattener.lookup_action(11)==[1,1,2]
# Check that False produces a MultiDiscrete
env = UnityEnv(' ', use_visual=False, multiagent=False, flatten_branched=False)
assert isinstance(env.action_space, spaces.MultiDiscrete)

236
gym-unity/images/dopamine_gridworld_plot.png

之前 之后
宽度: 705  |  高度: 448  |  大小: 67 KiB

133
gym-unity/images/dopamine_visualbanana_plot.png

之前 之后
宽度: 704  |  高度: 462  |  大小: 36 KiB
正在加载...
取消
保存