
Random Network Distillation for Torch (#4473)

* initial commit

* works with Pyramids

* added unit tests and a separate config file

* Adding first batch of documentation

* adding in the docs that rnd is only for PyTorch

* adding newline at the end of the config files

* adding some docs

* Code comments

* no normalization of the reward

* Fixing the tests

* [skip ci]

* [skip ci] Make sure RND will only work for Torch by editing the config file

* [skip ci] Additional information in the Documentation

* Remove the _has_updated_once flag
/MLA-1734-demo-provider
GitHub, 4 years ago
Current commit: 60b76790
10 files changed, with 236 insertions and 1 deletion (changed lines per file in parentheses):
  1. com.unity.ml-agents/CHANGELOG.md (3)
  2. docs/ML-Agents-Overview.md (24)
  3. docs/Training-Configuration-File.md (13)
  4. ml-agents/mlagents/trainers/model_saver/tf_model_saver.py (3)
  5. ml-agents/mlagents/trainers/settings.py (8)
  6. ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py (3)
  7. ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py (4)
  8. config/ppo/PyramidsRND.yaml (32)
  9. ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_rnd.py (69)
  10. ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py (78)

com.unity.ml-agents/CHANGELOG.md (3 changed lines)


### Major Changes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- Added the Random Network Distillation (RND) intrinsic reward signal to the PyTorch
trainers. To use RND, add a `rnd` section to the `reward_signals` section of your
yaml configuration file. [More information here](https://github.com/Unity-Technologies/ml-agents/blob/master/docs/Training-Configuration-File.md#rnd-intrinsic-reward)
### Minor Changes
#### com.unity.ml-agents (C#)

docs/ML-Agents-Overview.md (24 changed lines)


- [A Quick Note on Reward Signals](#a-quick-note-on-reward-signals)
- [Deep Reinforcement Learning](#deep-reinforcement-learning)
- [Curiosity for Sparse-reward Environments](#curiosity-for-sparse-reward-environments)
- [RND for Sparse-reward Environments](#rnd-for-sparse-reward-environments)
- [Imitation Learning](#imitation-learning)
- [GAIL (Generative Adversarial Imitation Learning)](#gail-generative-adversarial-imitation-learning)
- [Behavioral Cloning (BC)](#behavioral-cloning-bc)

and intrinsic reward signals.
The ML-Agents Toolkit allows reward signals to be defined in a modular way, and
we provide three reward signals that can be mixed and matched to help shape
we provide four reward signals that can be mixed and matched to help shape
your agent's behavior:
- `extrinsic`: represents the rewards defined in your environment, and is

- `curiosity`: represents an intrinsic reward signal that encourages exploration
in sparse-reward environments that is defined by the Curiosity module (see
below).
- `rnd`: represents an intrinsic reward signal that encourages exploration
in sparse-reward environments that is defined by the RND module (see
below). (Not available for TensorFlow trainers)
### Deep Reinforcement Learning

For more information, see our dedicated
[blog post on the Curiosity module](https://blogs.unity3d.com/2018/06/26/solving-sparse-reward-tasks-with-curiosity/).
#### RND for Sparse-reward Environments
Similar to Curiosity, Random Network Distillation (RND) is useful in sparse or rare
reward environments as it helps the Agent explore. The RND Module is implemented following
the paper [Exploration by Random Network Distillation](https://arxiv.org/abs/1810.12894).
RND uses two networks:
- The first is a network with fixed random weights that takes observations as inputs and
  generates an encoding.
- The second is a network with a similar architecture that is trained to predict the
  outputs of the first network, using the observations the Agent collects as training data.
The loss (the squared difference between the predicted and actual encoded observations)
of the trained model is used as the intrinsic reward. The more an Agent visits a state, the
more accurate the predictions and the lower the rewards, which encourages the Agent to
explore new states with higher prediction errors.
__Note:__ RND is not available for TensorFlow trainers (only PyTorch trainers).
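A minimal, self-contained sketch of this computation is shown below. It is illustrative only: the observation size, network shapes, and function names are placeholders, and the trainer's actual implementation is the `RNDRewardProvider` class added in this commit.

    import torch

    # Hypothetical sizes for illustration; the trainer derives these from the
    # Agent's observation spec and the `encoding_size` setting.
    obs_size, encoding_size = 10, 64

    def make_encoder() -> torch.nn.Module:
        return torch.nn.Sequential(
            torch.nn.Linear(obs_size, encoding_size),
            torch.nn.ReLU(),
            torch.nn.Linear(encoding_size, encoding_size),
        )

    target = make_encoder()     # fixed random weights, never trained
    predictor = make_encoder()  # trained to imitate the target
    for p in target.parameters():
        p.requires_grad = False
    optimizer = torch.optim.Adam(predictor.parameters(), lr=1e-4)

    def intrinsic_reward(obs: torch.Tensor) -> torch.Tensor:
        # Squared prediction error per observation: large for rarely visited states.
        with torch.no_grad():
            return ((predictor(obs) - target(obs)) ** 2).sum(dim=1)

    def update(obs: torch.Tensor) -> float:
        # Train the predictor to match the frozen random network on visited states.
        loss = ((predictor(obs) - target(obs)) ** 2).sum(dim=1).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()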
### Imitation Learning

docs/Training-Configuration-File.md (13 changed lines)


- [Extrinsic Rewards](#extrinsic-rewards)
- [Curiosity Intrinsic Reward](#curiosity-intrinsic-reward)
- [GAIL Intrinsic Reward](#gail-intrinsic-reward)
- [RND Intrinsic Reward](#rnd-intrinsic-reward)
- [Reward Signal Settings for SAC](#reward-signal-settings-for-sac)
- [Behavioral Cloning](#behavioral-cloning)
- [Memory-enhanced Agents using Recurrent Neural Networks](#memory-enhanced-agents-using-recurrent-neural-networks)

| `gail -> learning_rate` | (Optional, default = `3e-4`) Learning rate used to update the discriminator. This should typically be decreased if training is unstable, and the GAIL loss is unstable. <br><br>Typical range: `1e-5` - `1e-3` |
| `gail -> use_actions` | (default = `false`) Determines whether the discriminator should discriminate based on both observations and actions, or just observations. Set to True if you want the agent to mimic the actions from the demonstrations, and False if you'd rather have the agent visit the same states as in the demonstrations but with possibly different actions. Setting to False is more likely to be stable, especially with imperfect demonstrations, but may learn slower. |
| `gail -> use_vail` | (default = `false`) Enables a variational bottleneck within the GAIL discriminator. This forces the discriminator to learn a more general representation and reduces its tendency to be "too good" at discriminating, making learning more stable. However, it does increase training time. Enable this if you notice your imitation learning is unstable, or unable to learn the task at hand. |
### RND Intrinsic Reward
Random Network Distillation (RND) is only available for the PyTorch trainers.
To enable RND, provide these settings:
| **Setting** | **Description** |
| :--------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `rnd -> strength` | (default = `1.0`) Magnitude of the RND reward generated by the intrinsic RND module. This should be scaled in order to ensure it is large enough to not be overwhelmed by extrinsic reward signals in the environment. Likewise, it should not be so large that it overwhelms the extrinsic reward signal. <br><br>Typical range: `0.001` - `0.01` |
| `rnd -> gamma` | (default = `0.99`) Discount factor for future rewards. <br><br>Typical range: `0.8` - `0.995` |
| `rnd -> encoding_size` | (default = `64`) Size of the encoding used by the intrinsic RND model. <br><br>Typical range: `64` - `256` |
| `rnd -> learning_rate` | (default = `1e-4`) Learning rate used to update the RND module. This should be large enough for the RND module to quickly learn the state representation, but small enough to allow for stable learning. <br><br>Typical range: `1e-5` - `1e-3` |
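For example, the Pyramids configuration added in this commit (`config/ppo/PyramidsRND.yaml`, shown in full further below) enables RND alongside the extrinsic reward with:

    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
      rnd:
        gamma: 0.99
        strength: 0.01
        encoding_size: 64
        learning_rate: 0.0001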
## Behavioral Cloning

ml-agents/mlagents/trainers/model_saver/tf_model_saver.py (3 changed lines)


        # only on worker-0 if there are multiple workers
        if self.policy and self.policy.rank is not None and self.policy.rank != 0:
            return
        if self.graph is None:
            logger.info("No model to export")
            return
        export_policy_model(
            self.model_path, output_filepath, behavior_name, self.graph, self.sess
        )

ml-agents/mlagents/trainers/settings.py (8 changed lines)


    EXTRINSIC: str = "extrinsic"
    GAIL: str = "gail"
    CURIOSITY: str = "curiosity"
    RND: str = "rnd"

    def to_settings(self) -> type:
        _mapping = {

            RewardSignalType.RND: RNDSettings,
        }
        return _mapping[self]


class CuriositySettings(RewardSignalSettings):
    encoding_size: int = 64
    learning_rate: float = 3e-4


@attr.s(auto_attribs=True)
class RNDSettings(RewardSignalSettings):
    encoding_size: int = 64
    learning_rate: float = 1e-4

# SAMPLERS #############################################################################

ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py (3 changed lines)


from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import (  # noqa F401
    GAILRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.rnd_reward_provider import (  # noqa F401
    RNDRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.reward_provider_factory import (  # noqa F401
    create_reward_provider,
)

ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py (4 changed lines)


from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import (
    GAILRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.rnd_reward_provider import (
    RNDRewardProvider,
)
from mlagents_envs.base_env import BehaviorSpec

    RewardSignalType.GAIL: GAILRewardProvider,
    RewardSignalType.RND: RNDRewardProvider,
}

config/ppo/PyramidsRND.yaml (32 changed lines)


behaviors:
  Pyramids:
    trainer_type: ppo
    hyperparameters:
      batch_size: 128
      buffer_size: 2048
      learning_rate: 0.0003
      beta: 0.01
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: false
      hidden_units: 512
      num_layers: 2
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
      rnd:
        gamma: 0.99
        strength: 0.01
        encoding_size: 64
        learning_rate: 0.0001
    keep_checkpoints: 5
    max_steps: 3000000
    time_horizon: 128
    summary_freq: 30000
    framework: pytorch
    threaded: true
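As with the other example configurations, training with this file can be launched with `mlagents-learn config/ppo/PyramidsRND.yaml --run-id=<run-identifier>`, where the run identifier is any name you choose.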

ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_rnd.py (69 changed lines)


import numpy as np
import pytest
from mlagents.torch_utils import torch
from mlagents.trainers.torch.components.reward_providers import (
    RNDRewardProvider,
    create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents.trainers.settings import RNDSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
    create_agent_buffer,
)

SEED = [42]


@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
    ],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:
    curiosity_settings = RNDSettings(32, 0.01)
    curiosity_settings.strength = 0.1
    curiosity_rp = RNDRewardProvider(behavior_spec, curiosity_settings)
    assert curiosity_rp.strength == 0.1
    assert curiosity_rp.name == "RND"


@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,), (64, 66, 1)], ActionType.DISCRETE, (2, 3)),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
    ],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:
    curiosity_settings = RNDSettings(32, 0.01)
    curiosity_rp = create_reward_provider(
        RewardSignalType.RND, behavior_spec, curiosity_settings
    )
    assert curiosity_rp.name == "RND"


@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
    ],
)
def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    rnd_settings = RNDSettings(32, 0.01)
    rnd_rp = RNDRewardProvider(behavior_spec, rnd_settings)
    buffer = create_agent_buffer(behavior_spec, 5)
    rnd_rp.update(buffer)
    reward_old = rnd_rp.evaluate(buffer)[0]
    for _ in range(100):
        rnd_rp.update(buffer)
    reward_new = rnd_rp.evaluate(buffer)[0]
    assert reward_new < reward_old

ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py (78 changed lines)


import numpy as np
from typing import Dict
from mlagents.torch_utils import torch
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
    BaseRewardProvider,
)
from mlagents.trainers.settings import RNDSettings
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.settings import NetworkSettings, EncoderType


class RNDRewardProvider(BaseRewardProvider):
    """
    Implementation of Random Network Distillation : https://arxiv.org/pdf/1810.12894.pdf
    """

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__(specs, settings)
        self._ignore_done = True
        self._random_network = RNDNetwork(specs, settings)
        self._training_network = RNDNetwork(specs, settings)
        self.optimizer = torch.optim.Adam(
            self._training_network.parameters(), lr=settings.learning_rate
        )

    def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
        # Intrinsic reward is the squared prediction error between the two encodings.
        with torch.no_grad():
            target = self._random_network(mini_batch)
            prediction = self._training_network(mini_batch)
            rewards = torch.sum((prediction - target) ** 2, dim=1)
        return rewards.detach().cpu().numpy()

    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
        # The target encodings come from the fixed random network, so they are
        # computed without gradients; only the training network is updated.
        with torch.no_grad():
            target = self._random_network(mini_batch)
        prediction = self._training_network(mini_batch)
        loss = torch.mean(torch.sum((prediction - target) ** 2, dim=1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"Losses/RND Loss": loss.detach().cpu().numpy()}


class RNDNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__()
        self._policy_specs = specs
        state_encoder_settings = NetworkSettings(
            normalize=True,
            hidden_units=settings.encoding_size,
            num_layers=3,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._encoder = NetworkBody(specs.observation_shapes, state_encoder_settings)

    def forward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        n_vis = len(self._encoder.visual_processors)
        hidden, _ = self._encoder.forward(
            vec_inputs=[
                ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)
            ],
            vis_inputs=[
                ModelUtils.list_to_tensor(
                    mini_batch["visual_obs%d" % i], dtype=torch.float
                )
                for i in range(n_vis)
            ],
        )
        self._encoder.update_normalization(torch.tensor(mini_batch["vector_obs"]))
        return hidden