Random Network Distillation for Torch (#4473)
* initial commit
* works with Pyramids
* added unit tests and a separate config file
* Adding first batch of documentation
* adding in the docs that rnd is only for PyTorch
* adding newline at the end of the config files
* adding some docs
* Code comments
* no normalization of the reward
* Fixing the tests
* [skip ci]
* [skip ci] Make sure RND will only work for Torch by editing the config file
* [skip ci] Additional information in the Documentation
* Remove the _has_updated_once flag
GitHub
4 years ago
Current commit
60b76790
10 files changed, with 236 insertions and 1 deletion
- com.unity.ml-agents/CHANGELOG.md (3 changed lines)
- docs/ML-Agents-Overview.md (24 changed lines)
- docs/Training-Configuration-File.md (13 changed lines)
- ml-agents/mlagents/trainers/model_saver/tf_model_saver.py (3 changed lines)
- ml-agents/mlagents/trainers/settings.py (8 changed lines)
- ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py (3 changed lines)
- ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py (4 changed lines)
- config/ppo/PyramidsRND.yaml (32 changed lines)
- ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_rnd.py (69 changed lines)
- ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py (78 changed lines)
config/ppo/PyramidsRND.yaml:

behaviors:
  Pyramids:
    trainer_type: ppo
    hyperparameters:
      batch_size: 128
      buffer_size: 2048
      learning_rate: 0.0003
      beta: 0.01
      epsilon: 0.2
      lambd: 0.95
      num_epoch: 3
      learning_rate_schedule: linear
    network_settings:
      normalize: false
      hidden_units: 512
      num_layers: 2
      vis_encode_type: simple
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
      rnd:
        gamma: 0.99
        strength: 0.01
        encoding_size: 64
        learning_rate: 0.0001
    keep_checkpoints: 5
    max_steps: 3000000
    time_horizon: 128
    summary_freq: 30000
    framework: pytorch
    threaded: true
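For reference, the rnd section above is consumed through the RNDSettings entry added to ml-agents/mlagents/trainers/settings.py in this change. A minimal sketch of the equivalent in code, assuming a keyword constructor and the field names used by the tests and the reward provider below (gamma and strength are the shared reward-signal fields):

from mlagents.trainers.settings import RNDSettings

# Sketch only: build the settings object the "rnd:" block above corresponds to.
# encoding_size and learning_rate are read by RNDRewardProvider; gamma and
# strength come from the common reward-signal settings.
rnd_settings = RNDSettings(encoding_size=64, learning_rate=0.0001)
rnd_settings.gamma = 0.99
rnd_settings.strength = 0.01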
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_rnd.py:

import numpy as np
import pytest
from mlagents.torch_utils import torch
from mlagents.trainers.torch.components.reward_providers import (
    RNDRewardProvider,
    create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents.trainers.settings import RNDSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
    create_agent_buffer,
)

SEED = [42]


@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
    ],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:
    curiosity_settings = RNDSettings(32, 0.01)
    curiosity_settings.strength = 0.1
    curiosity_rp = RNDRewardProvider(behavior_spec, curiosity_settings)
    assert curiosity_rp.strength == 0.1
    assert curiosity_rp.name == "RND"


@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,), (64, 66, 1)], ActionType.DISCRETE, (2, 3)),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
    ],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:
    curiosity_settings = RNDSettings(32, 0.01)
    curiosity_rp = create_reward_provider(
        RewardSignalType.RND, behavior_spec, curiosity_settings
    )
    assert curiosity_rp.name == "RND"


@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
    ],
)
def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    rnd_settings = RNDSettings(32, 0.01)
    rnd_rp = RNDRewardProvider(behavior_spec, rnd_settings)
    buffer = create_agent_buffer(behavior_spec, 5)
    rnd_rp.update(buffer)
    reward_old = rnd_rp.evaluate(buffer)[0]
    for _ in range(100):
        rnd_rp.update(buffer)
    reward_new = rnd_rp.evaluate(buffer)[0]
    assert reward_new < reward_old
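The BehaviorSpec triples used in the parametrization read as (observation shapes, action type, action size or discrete branch sizes). For example, the first spec in test_reward_decreases describes one 10-dimensional vector observation plus two visual observations with a 5-dimensional continuous action. A small sketch, using only the constructor already imported in the test file:

from mlagents_envs.base_env import BehaviorSpec, ActionType

# One (10,) vector observation plus two visual observations (H x W x C),
# with a 5-dimensional continuous action space.
spec = BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ActionType.CONTINUOUS, 5)

# Discrete equivalent: two action branches with 2 and 3 choices respectively.
discrete_spec = BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3))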
ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py:

import numpy as np
from typing import Dict
from mlagents.torch_utils import torch

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
    BaseRewardProvider,
)
from mlagents.trainers.settings import RNDSettings

from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.settings import NetworkSettings, EncoderType


class RNDRewardProvider(BaseRewardProvider):
    """
    Implementation of Random Network Distillation : https://arxiv.org/pdf/1810.12894.pdf
    """

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__(specs, settings)
        self._ignore_done = True
        self._random_network = RNDNetwork(specs, settings)
        self._training_network = RNDNetwork(specs, settings)
        self.optimizer = torch.optim.Adam(
            self._training_network.parameters(), lr=settings.learning_rate
        )

    def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
        with torch.no_grad():
            target = self._random_network(mini_batch)
            prediction = self._training_network(mini_batch)
            rewards = torch.sum((prediction - target) ** 2, dim=1)
        return rewards.detach().cpu().numpy()

    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
        with torch.no_grad():
            target = self._random_network(mini_batch)
        prediction = self._training_network(mini_batch)
        loss = torch.mean(torch.sum((prediction - target) ** 2, dim=1))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"Losses/RND Loss": loss.detach().cpu().numpy()}


class RNDNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
        super().__init__()
        self._policy_specs = specs
        state_encoder_settings = NetworkSettings(
            normalize=True,
            hidden_units=settings.encoding_size,
            num_layers=3,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._encoder = NetworkBody(specs.observation_shapes, state_encoder_settings)

    def forward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        n_vis = len(self._encoder.visual_processors)
        hidden, _ = self._encoder.forward(
            vec_inputs=[
                ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)
            ],
            vis_inputs=[
                ModelUtils.list_to_tensor(
                    mini_batch["visual_obs%d" % i], dtype=torch.float
                )
                for i in range(n_vis)
            ],
        )
        self._encoder.update_normalization(torch.tensor(mini_batch["vector_obs"]))
        return hidden
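Stripped of the ML-Agents plumbing (NetworkBody, AgentBuffer, observation handling), the mechanism above is small: a frozen, randomly initialized target network, a predictor trained to match its output, and an intrinsic reward equal to the squared prediction error, which shrinks for observations the agent keeps revisiting. A minimal standalone sketch in plain PyTorch, with illustrative sizes and network shapes that are not the ones used by the provider:

import torch

obs_size, encoding_size = 10, 64  # illustrative only

# Frozen random target network and trainable predictor network.
target = torch.nn.Linear(obs_size, encoding_size)
predictor = torch.nn.Linear(obs_size, encoding_size)
for p in target.parameters():
    p.requires_grad_(False)  # the target is never trained

optimizer = torch.optim.Adam(predictor.parameters(), lr=1e-4)

def intrinsic_reward(obs: torch.Tensor) -> torch.Tensor:
    # Reward is the per-observation squared prediction error.
    with torch.no_grad():
        return torch.sum((predictor(obs) - target(obs)) ** 2, dim=1)

def update(obs: torch.Tensor) -> float:
    # Train the predictor to imitate the frozen target on observed states.
    loss = torch.mean(torch.sum((predictor(obs) - target(obs)) ** 2, dim=1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

obs = torch.randn(5, obs_size)
before = intrinsic_reward(obs)
for _ in range(100):
    update(obs)
after = intrinsic_reward(obs)  # smaller than `before`, as test_reward_decreases asserts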