Added Reward Providers for Torch (#4280)
* Added Reward Providers for Torch
* Use NetworkBody to encode state in the reward providers
* Integrating the reward providers with PPO and Torch
* Work in progress, integration with PPO. Not training Pyramids properly at the moment
* Integration in PPO
* Removing duplicate file
* GAIL and Curiosity working
* Addressing comments
* Enforce float32 for tests
* Enforce np.float32 in buffer/develop/add-fire
Current commit: 7ddfd81f
23 files changed, 1111 insertions(+), 49 deletions(-)
ml-agents/mlagents/trainers/buffer.py | 2
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py | 26
ml-agents/mlagents/trainers/ppo/optimizer_torch.py | 3
ml-agents/mlagents/trainers/ppo/trainer.py | 25
ml-agents/mlagents/trainers/sac/optimizer_torch.py | 5
ml-agents/mlagents/trainers/sac/trainer.py | 49
ml-agents/mlagents/trainers/tests/torch/test_utils.py | 6
ml-agents/mlagents/trainers/torch/distributions.py | 2
ml-agents/mlagents/trainers/torch/encoders.py | 6
ml-agents/mlagents/trainers/torch/utils.py | 53
ml-agents/mlagents/trainers/trainer/rl_trainer.py | 16
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py | 111
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py | 56
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py | 138
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py | 32
ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py | 15
ml-agents/mlagents/trainers/torch/components/reward_providers/base_reward_provider.py | 72
ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py | 228
ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py | 15
ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py | 257
ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py | 43
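
The new reward-provider files are reproduced in full below. As a quick orientation before reading them, here is a minimal usage sketch, not part of the diff itself: it only uses calls that appear in the files that follow, and the test helper create_agent_buffer stands in for real rollout data.

# Illustrative sketch only: build a reward provider from settings via the
# factory, then update it and evaluate rewards on an AgentBuffer, mirroring
# the tests below.
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents.trainers.settings import CuriositySettings, RewardSignalType
from mlagents.trainers.torch.components.reward_providers import create_reward_provider
from mlagents.trainers.tests.torch.test_reward_providers.utils import create_agent_buffer

# A behavior with one 10-dim vector observation and 5 continuous actions.
behavior_spec = BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5)
# CuriositySettings(encoding_size, learning_rate), as in the tests below.
curiosity_settings = CuriositySettings(32, 0.01)
curiosity_rp = create_reward_provider(
    RewardSignalType.CURIOSITY, behavior_spec, curiosity_settings
)

buffer = create_agent_buffer(behavior_spec, 5)  # stand-in for real rollout data
stats = curiosity_rp.update(buffer)      # trains the curiosity networks, returns loss stats
rewards = curiosity_rp.evaluate(buffer)  # np.ndarray of intrinsic rewards, one per experience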

File: ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py

import numpy as np
import pytest
import torch
from mlagents.trainers.torch.components.reward_providers import (
    CuriosityRewardProvider,
    create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents.trainers.settings import CuriositySettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
    create_agent_buffer,
)

SEED = [42]


@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
    ],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:
    curiosity_settings = CuriositySettings(32, 0.01)
    curiosity_settings.strength = 0.1
    curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings)
    assert curiosity_rp.strength == 0.1
    assert curiosity_rp.name == "Curiosity"


@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,), (64, 66, 1)], ActionType.DISCRETE, (2, 3)),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
    ],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:
    curiosity_settings = CuriositySettings(32, 0.01)
    curiosity_rp = create_reward_provider(
        RewardSignalType.CURIOSITY, behavior_spec, curiosity_settings
    )
    assert curiosity_rp.name == "Curiosity"


@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
    ],
)
def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    curiosity_settings = CuriositySettings(32, 0.01)
    curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings)
    buffer = create_agent_buffer(behavior_spec, 5)
    curiosity_rp.update(buffer)
    reward_old = curiosity_rp.evaluate(buffer)[0]
    for _ in range(10):
        curiosity_rp.update(buffer)
        reward_new = curiosity_rp.evaluate(buffer)[0]
        assert reward_new < reward_old
        reward_old = reward_new


@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
    "behavior_spec", [BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5)]
)
def test_continuous_action_prediction(behavior_spec: BehaviorSpec, seed: int) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    curiosity_settings = CuriositySettings(32, 0.1)
    curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings)
    buffer = create_agent_buffer(behavior_spec, 5)
    for _ in range(200):
        curiosity_rp.update(buffer)
    prediction = curiosity_rp._network.predict_action(buffer)[0].detach()
    target = buffer["actions"][0]
    error = float(torch.mean((prediction - target) ** 2))
    assert error < 0.001


@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,), (64, 66, 3)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
    ],
)
def test_next_state_prediction(behavior_spec: BehaviorSpec, seed: int) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    curiosity_settings = CuriositySettings(32, 0.1)
    curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings)
    buffer = create_agent_buffer(behavior_spec, 5)
    for _ in range(100):
        curiosity_rp.update(buffer)
    prediction = curiosity_rp._network.predict_next_state(buffer)[0]
    target = curiosity_rp._network.get_next_state(buffer)[0]
    error = float(torch.mean((prediction - target) ** 2).detach())
    assert error < 0.001

File: ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py

import pytest
from mlagents.trainers.torch.components.reward_providers import (
    ExtrinsicRewardProvider,
    create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
    create_agent_buffer,
)


@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
    ],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:
    settings = RewardSignalSettings()
    settings.gamma = 0.2
    extrinsic_rp = ExtrinsicRewardProvider(behavior_spec, settings)
    assert extrinsic_rp.gamma == 0.2
    assert extrinsic_rp.name == "Extrinsic"


@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
    ],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:
    settings = RewardSignalSettings()
    extrinsic_rp = create_reward_provider(
        RewardSignalType.EXTRINSIC, behavior_spec, settings
    )
    assert extrinsic_rp.name == "Extrinsic"


@pytest.mark.parametrize("reward", [2.0, 3.0, 4.0])
@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
    ],
)
def test_reward(behavior_spec: BehaviorSpec, reward: float) -> None:
    buffer = create_agent_buffer(behavior_spec, 1000, reward)
    settings = RewardSignalSettings()
    extrinsic_rp = ExtrinsicRewardProvider(behavior_spec, settings)
    generated_rewards = extrinsic_rp.evaluate(buffer)
    assert (generated_rewards == reward).all()

File: ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py

from typing import Any
import numpy as np
import pytest
from unittest.mock import patch
import torch
import os
from mlagents.trainers.torch.components.reward_providers import (
    GAILRewardProvider,
    create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents.trainers.settings import GAILSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
    create_agent_buffer,
)
from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import (
    DiscriminatorNetwork,
)

CONTINUOUS_PATH = (
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
    + "/test.demo"
)
DISCRETE_PATH = (
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
    + "/testdcvis.demo"
)
SEED = [42]


@pytest.mark.parametrize(
    "behavior_spec", [BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2)]
)
def test_construction(behavior_spec: BehaviorSpec) -> None:
    gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH)
    gail_rp = GAILRewardProvider(behavior_spec, gail_settings)
    assert gail_rp.name == "GAIL"


@pytest.mark.parametrize(
    "behavior_spec", [BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2)]
)
def test_factory(behavior_spec: BehaviorSpec) -> None:
    gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH)
    gail_rp = create_reward_provider(
        RewardSignalType.GAIL, behavior_spec, gail_settings
    )
    assert gail_rp.name == "GAIL"


@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(8,), (24, 26, 1)], ActionType.CONTINUOUS, 2),
        BehaviorSpec([(50,)], ActionType.DISCRETE, (2, 3, 3, 3)),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (20,)),
    ],
)
@pytest.mark.parametrize("use_actions", [False, True])
@patch(
    "mlagents.trainers.torch.components.reward_providers.gail_reward_provider.demo_to_buffer"
)
def test_reward_decreases(
    demo_to_buffer: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int
) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    buffer_expert = create_agent_buffer(behavior_spec, 1000)
    buffer_policy = create_agent_buffer(behavior_spec, 1000)
    demo_to_buffer.return_value = None, buffer_expert
    gail_settings = GAILSettings(
        demo_path="", learning_rate=0.05, use_vail=False, use_actions=use_actions
    )
    gail_rp = create_reward_provider(
        RewardSignalType.GAIL, behavior_spec, gail_settings
    )

    init_reward_expert = gail_rp.evaluate(buffer_expert)[0]
    init_reward_policy = gail_rp.evaluate(buffer_policy)[0]

    for _ in range(10):
        gail_rp.update(buffer_policy)
        reward_expert = gail_rp.evaluate(buffer_expert)[0]
        reward_policy = gail_rp.evaluate(buffer_policy)[0]
        assert reward_expert >= 0  # GAIL / VAIL reward always positive
        assert reward_policy >= 0
    reward_expert = gail_rp.evaluate(buffer_expert)[0]
    reward_policy = gail_rp.evaluate(buffer_policy)[0]
    assert reward_expert > reward_policy  # Expert reward greater than non-expert reward
    assert (
        reward_expert > init_reward_expert
    )  # Expert reward getting better as network trains
    assert (
        reward_policy < init_reward_policy
    )  # Non-expert reward getting worse as network trains


@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
    "behavior_spec",
    [
        BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3, 3, 3)),
        BehaviorSpec([(10,)], ActionType.DISCRETE, (20,)),
    ],
)
@pytest.mark.parametrize("use_actions", [False, True])
@patch(
    "mlagents.trainers.torch.components.reward_providers.gail_reward_provider.demo_to_buffer"
)
def test_reward_decreases_vail(
    demo_to_buffer: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int
) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    buffer_expert = create_agent_buffer(behavior_spec, 1000)
    buffer_policy = create_agent_buffer(behavior_spec, 1000)
    demo_to_buffer.return_value = None, buffer_expert
    gail_settings = GAILSettings(
        demo_path="", learning_rate=0.005, use_vail=True, use_actions=use_actions
    )
    DiscriminatorNetwork.initial_beta = 0.0
    # we must set the initial value of beta to 0 for testing
    # If we do not, the kl-loss will dominate early and will block the estimator
    gail_rp = create_reward_provider(
        RewardSignalType.GAIL, behavior_spec, gail_settings
    )

    for _ in range(100):
        gail_rp.update(buffer_policy)
        reward_expert = gail_rp.evaluate(buffer_expert)[0]
        reward_policy = gail_rp.evaluate(buffer_policy)[0]
        assert reward_expert >= 0  # GAIL / VAIL reward always positive
        assert reward_policy >= 0
    reward_expert = gail_rp.evaluate(buffer_expert)[0]
    reward_policy = gail_rp.evaluate(buffer_policy)[0]
    assert reward_expert > reward_policy  # Expert reward greater than non-expert reward

File: ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py

import numpy as np
from mlagents.trainers.buffer import AgentBuffer
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.trajectory import SplitObservations


def create_agent_buffer(
    behavior_spec: BehaviorSpec, number: int, reward: float = 0.0
) -> AgentBuffer:
    buffer = AgentBuffer()
    curr_observations = [
        np.random.normal(size=shape) for shape in behavior_spec.observation_shapes
    ]
    next_observations = [
        np.random.normal(size=shape) for shape in behavior_spec.observation_shapes
    ]
    action = behavior_spec.create_random_action(1)[0, :]
    for _ in range(number):
        curr_split_obs = SplitObservations.from_observations(curr_observations)
        next_split_obs = SplitObservations.from_observations(next_observations)
        for i, _ in enumerate(curr_split_obs.visual_observations):
            buffer["visual_obs%d" % i].append(curr_split_obs.visual_observations[i])
            buffer["next_visual_obs%d" % i].append(
                next_split_obs.visual_observations[i]
            )
        buffer["vector_obs"].append(curr_split_obs.vector_observations)
        buffer["next_vector_in"].append(next_split_obs.vector_observations)
        buffer["actions"].append(action)
        buffer["done"].append(np.zeros(1, dtype=np.float32))
        buffer["reward"].append(np.ones(1, dtype=np.float32) * reward)
        buffer["masks"].append(np.ones(1, dtype=np.float32))
    return buffer

File: ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py

from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (  # noqa F401
    BaseRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.extrinsic_reward_provider import (  # noqa F401
    ExtrinsicRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.curiosity_reward_provider import (  # noqa F401
    CuriosityRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import (  # noqa F401
    GAILRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.reward_provider_factory import (  # noqa F401
    create_reward_provider,
)

File: ml-agents/mlagents/trainers/torch/components/reward_providers/base_reward_provider.py

import numpy as np
from abc import ABC, abstractmethod
from typing import Dict

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.settings import RewardSignalSettings
from mlagents_envs.base_env import BehaviorSpec


class BaseRewardProvider(ABC):
    def __init__(self, specs: BehaviorSpec, settings: RewardSignalSettings) -> None:
        self._policy_specs = specs
        self._gamma = settings.gamma
        self._strength = settings.strength
        self._ignore_done = False

    @property
    def gamma(self) -> float:
        """
        The discount factor for the reward signal
        """
        return self._gamma

    @property
    def strength(self) -> float:
        """
        The strength multiplier of the reward provider
        """
        return self._strength

    @property
    def name(self) -> str:
        """
        The name of the reward provider. Is used for reporting and identification
        """
        class_name = self.__class__.__name__
        return class_name.replace("RewardProvider", "")

    @property
    def ignore_done(self) -> bool:
        """
        If true, when the agent is done, the rewards of the next episode must be
        used to calculate the return of the current episode.
        Is used to mitigate the positive bias in rewards with no natural end.
        """
        return self._ignore_done

    @abstractmethod
    def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
        """
        Evaluates the reward for the data present in the Dict mini_batch. Use this when evaluating a reward
        function drawn straight from a Buffer.
        :param mini_batch: A Dict of numpy arrays (the format used by our Buffer)
        when drawing from the update buffer.
        :return: a np.ndarray of rewards generated by the reward provider
        """
        raise NotImplementedError(
            "The reward provider's evaluate method has not been implemented "
        )

    @abstractmethod
    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
        """
        Update the reward for the data present in the Dict mini_batch. Use this when updating a reward
        function drawn straight from a Buffer.
        :param mini_batch: A Dict of numpy arrays (the format used by our Buffer)
        when drawing from the update buffer.
        :return: A dictionary from string to stats values
        """
        raise NotImplementedError(
            "The reward provider's update method has not been implemented "
        )

File: ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py

import numpy as np
from typing import Dict
import torch

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
    BaseRewardProvider,
)
from mlagents.trainers.settings import CuriositySettings

from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.settings import NetworkSettings, EncoderType


class CuriosityRewardProvider(BaseRewardProvider):
    beta = 0.2  # Forward vs Inverse loss weight
    loss_multiplier = 10.0  # Loss multiplier

    def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
        super().__init__(specs, settings)
        self._ignore_done = True
        self._network = CuriosityNetwork(specs, settings)
        self.optimizer = torch.optim.Adam(
            self._network.parameters(), lr=settings.learning_rate
        )
        self._has_updated_once = False

    def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
        with torch.no_grad():
            rewards = self._network.compute_reward(mini_batch).detach().cpu().numpy()
        rewards = np.minimum(rewards, 1.0 / self.strength)
        return rewards * self._has_updated_once

    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
        self._has_updated_once = True
        forward_loss = self._network.compute_forward_loss(mini_batch)
        inverse_loss = self._network.compute_inverse_loss(mini_batch)

        loss = self.loss_multiplier * (
            self.beta * forward_loss + (1.0 - self.beta) * inverse_loss
        )
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {
            "Losses/Curiosity Forward Loss": forward_loss.detach().cpu().numpy(),
            "Losses/Curiosity Inverse Loss": inverse_loss.detach().cpu().numpy(),
        }


class CuriosityNetwork(torch.nn.Module):
    EPSILON = 1e-10

    def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
        super().__init__()
        self._policy_specs = specs
        state_encoder_settings = NetworkSettings(
            normalize=False,
            hidden_units=settings.encoding_size,
            num_layers=2,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._state_encoder = NetworkBody(
            specs.observation_shapes, state_encoder_settings
        )

        self._action_flattener = ModelUtils.ActionFlattener(specs)

        self.inverse_model_action_predition = torch.nn.Sequential(
            torch.nn.Linear(2 * settings.encoding_size, 256),
            ModelUtils.SwishLayer(),
            torch.nn.Linear(256, self._action_flattener.flattened_size),
        )
        self.inverse_model_action_predition[0].bias.data.zero_()
        self.inverse_model_action_predition[2].bias.data.zero_()

        self.forward_model_next_state_prediction = torch.nn.Sequential(
            torch.nn.Linear(
                settings.encoding_size + self._action_flattener.flattened_size, 256
            ),
            ModelUtils.SwishLayer(),
            torch.nn.Linear(256, settings.encoding_size),
        )
        self.forward_model_next_state_prediction[0].bias.data.zero_()
        self.forward_model_next_state_prediction[2].bias.data.zero_()

    def get_current_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Extracts the current state embedding from a mini_batch.
        """
        n_vis = len(self._state_encoder.visual_encoders)
        hidden, _ = self._state_encoder.forward(
            vec_inputs=[
                ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)
            ],
            vis_inputs=[
                ModelUtils.list_to_tensor(
                    mini_batch["visual_obs%d" % i], dtype=torch.float
                )
                for i in range(n_vis)
            ],
        )
        return hidden

    def get_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Extracts the next state embedding from a mini_batch.
        """
        n_vis = len(self._state_encoder.visual_encoders)
        hidden, _ = self._state_encoder.forward(
            vec_inputs=[
                ModelUtils.list_to_tensor(
                    mini_batch["next_vector_in"], dtype=torch.float
                )
            ],
            vis_inputs=[
                ModelUtils.list_to_tensor(
                    mini_batch["next_visual_obs%d" % i], dtype=torch.float
                )
                for i in range(n_vis)
            ],
        )
        return hidden

    def predict_action(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        In the continuous case, returns the predicted action.
        In the discrete case, returns the logits.
        """
        inverse_model_input = torch.cat(
            (self.get_current_state(mini_batch), self.get_next_state(mini_batch)), dim=1
        )
        hidden = self.inverse_model_action_predition(inverse_model_input)
        if self._policy_specs.is_action_continuous():
            return hidden
        else:
            branches = ModelUtils.break_into_branches(
                hidden, self._policy_specs.discrete_action_branches
            )
            branches = [torch.softmax(b, dim=1) for b in branches]
            return torch.cat(branches, dim=1)

    def predict_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Uses the current state embedding and the action of the mini_batch to predict
        the next state embedding.
        """
        if self._policy_specs.is_action_continuous():
            action = ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.float)
        else:
            action = torch.cat(
                ModelUtils.actions_to_onehot(
                    ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long),
                    self._policy_specs.discrete_action_branches,
                ),
                dim=1,
            )
        forward_model_input = torch.cat(
            (self.get_current_state(mini_batch), action), dim=1
        )

        return self.forward_model_next_state_prediction(forward_model_input)

    def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Computes the inverse loss for a mini_batch. Corresponds to the error on the
        action prediction (given the current and next state).
        """
        predicted_action = self.predict_action(mini_batch)
        if self._policy_specs.is_action_continuous():
            sq_difference = (
                ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.float)
                - predicted_action
            ) ** 2
            sq_difference = torch.sum(sq_difference, dim=1)
            return torch.mean(
                ModelUtils.dynamic_partition(
                    sq_difference,
                    ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
                    2,
                )[1]
            )
        else:
            true_action = torch.cat(
                ModelUtils.actions_to_onehot(
                    ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long),
                    self._policy_specs.discrete_action_branches,
                ),
                dim=1,
            )
            cross_entropy = torch.sum(
                -torch.log(predicted_action + self.EPSILON) * true_action, dim=1
            )
            return torch.mean(
                ModelUtils.dynamic_partition(
                    cross_entropy,
                    ModelUtils.list_to_tensor(
                        mini_batch["masks"], dtype=torch.float
                    ),  # use masks not action_masks
                    2,
                )[1]
            )

    def compute_reward(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Calculates the curiosity reward for the mini_batch. Corresponds to the error
        between the predicted and actual next state.
        """
        predicted_next_state = self.predict_next_state(mini_batch)
        target = self.get_next_state(mini_batch)
        sq_difference = 0.5 * (target - predicted_next_state) ** 2
        sq_difference = torch.sum(sq_difference, dim=1)
        return sq_difference

    def compute_forward_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Computes the loss for the next state prediction
        """
        return torch.mean(
            ModelUtils.dynamic_partition(
                self.compute_reward(mini_batch),
                ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
                2,
            )[1]
        )

File: ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py

import numpy as np
from typing import Dict

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
    BaseRewardProvider,
)


class ExtrinsicRewardProvider(BaseRewardProvider):
    def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
        return np.array(mini_batch["environment_rewards"], dtype=np.float32)

    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
        return {}

File: ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py

from typing import Optional, Dict
import numpy as np
import torch

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
    BaseRewardProvider,
)
from mlagents.trainers.settings import GAILSettings
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.settings import NetworkSettings, EncoderType
from mlagents.trainers.demo_loader import demo_to_buffer


class GAILRewardProvider(BaseRewardProvider):
    def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
        super().__init__(specs, settings)
        self._ignore_done = True
        self._discriminator_network = DiscriminatorNetwork(specs, settings)
        _, self._demo_buffer = demo_to_buffer(
            settings.demo_path, 1, specs
        )  # This is supposed to be the sequence length but we do not have access here
        params = list(self._discriminator_network.parameters())
        self.optimizer = torch.optim.Adam(params, lr=settings.learning_rate)

    def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
        with torch.no_grad():
            estimates, _ = self._discriminator_network.compute_estimate(
                mini_batch, use_vail_noise=False
            )
            return (
                -torch.log(
                    1.0
                    - estimates.squeeze(dim=1)
                    * (1.0 - self._discriminator_network.EPSILON)
                )
                .detach()
                .cpu()
                .numpy()
            )

    def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
        expert_batch = self._demo_buffer.sample_mini_batch(
            mini_batch.num_experiences, 1
        )
        loss, policy_mean_estimate, expert_mean_estimate, kl_loss = self._discriminator_network.compute_loss(
            mini_batch, expert_batch
        )
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        stats_dict = {
            "Losses/GAIL Discriminator Loss": loss.detach().cpu().numpy(),
            "Policy/GAIL Policy Estimate": policy_mean_estimate.detach().cpu().numpy(),
            "Policy/GAIL Expert Estimate": expert_mean_estimate.detach().cpu().numpy(),
        }
        if self._discriminator_network.use_vail:
            stats_dict["Policy/GAIL Beta"] = (
                self._discriminator_network.beta.detach().cpu().numpy()
            )
            stats_dict["Losses/GAIL KL Loss"] = kl_loss.detach().cpu().numpy()
        return stats_dict


class DiscriminatorNetwork(torch.nn.Module):
    gradient_penalty_weight = 10.0
    z_size = 128
    alpha = 0.0005
    mutual_information = 0.5
    EPSILON = 1e-7
    initial_beta = 0.0

    def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
        super().__init__()
        self._policy_specs = specs
        self.use_vail = settings.use_vail
        self._settings = settings

        state_encoder_settings = NetworkSettings(
            normalize=False,
            hidden_units=settings.encoding_size,
            num_layers=2,
            vis_encode_type=EncoderType.SIMPLE,
            memory=None,
        )
        self._state_encoder = NetworkBody(
            specs.observation_shapes, state_encoder_settings
        )

        self._action_flattener = ModelUtils.ActionFlattener(specs)

        encoder_input_size = settings.encoding_size
        if settings.use_actions:
            encoder_input_size += (
                self._action_flattener.flattened_size + 1
            )  # + 1 is for done

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(encoder_input_size, settings.encoding_size),
            ModelUtils.SwishLayer(),
            torch.nn.Linear(settings.encoding_size, settings.encoding_size),
            ModelUtils.SwishLayer(),
        )
        torch.nn.init.xavier_normal_(self.encoder[0].weight.data)
        torch.nn.init.xavier_normal_(self.encoder[2].weight.data)
        self.encoder[0].bias.data.zero_()
        self.encoder[2].bias.data.zero_()

        estimator_input_size = settings.encoding_size
        if settings.use_vail:
            estimator_input_size = self.z_size
            self.z_sigma = torch.nn.Parameter(
                torch.ones((self.z_size), dtype=torch.float), requires_grad=True
            )
            self.z_mu_layer = torch.nn.Linear(settings.encoding_size, self.z_size)
            # self.z_mu_layer.weight.data Needs a variance scale initializer
            torch.nn.init.xavier_normal_(self.z_mu_layer.weight.data)
            self.z_mu_layer.bias.data.zero_()
            self.beta = torch.nn.Parameter(
                torch.tensor(self.initial_beta, dtype=torch.float), requires_grad=False
            )

        self.estimator = torch.nn.Sequential(
            torch.nn.Linear(estimator_input_size, 1), torch.nn.Sigmoid()
        )
        torch.nn.init.xavier_normal_(self.estimator[0].weight.data)
        self.estimator[0].bias.data.zero_()

    def get_action_input(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Creates the action Tensor. In continuous case, corresponds to the action. In
        the discrete case, corresponds to the concatenation of one hot action Tensors.
        """
        return self._action_flattener.forward(
            torch.as_tensor(mini_batch["actions"], dtype=torch.float)
        )

    def get_state_encoding(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Creates the observation input.
        """
        n_vis = len(self._state_encoder.visual_encoders)
        hidden, _ = self._state_encoder.forward(
            vec_inputs=[torch.as_tensor(mini_batch["vector_obs"], dtype=torch.float)],
            vis_inputs=[
                torch.as_tensor(mini_batch["visual_obs%d" % i], dtype=torch.float)
                for i in range(n_vis)
            ],
        )
        return hidden

    def compute_estimate(
        self, mini_batch: AgentBuffer, use_vail_noise: bool = False
    ) -> torch.Tensor:
        """
        Given a mini_batch, computes the estimate (How much the discriminator believes
        the data was sampled from the demonstration data).
        :param mini_batch: The AgentBuffer of data
        :param use_vail_noise: Only when using VAIL : If true, will sample the code, if
        false, will return the mean of the code.
        """
        encoder_input = self.get_state_encoding(mini_batch)
        if self._settings.use_actions:
            actions = self.get_action_input(mini_batch)
            dones = torch.as_tensor(mini_batch["done"], dtype=torch.float)
            encoder_input = torch.cat([encoder_input, actions, dones], dim=1)
        hidden = self.encoder(encoder_input)
        z_mu: Optional[torch.Tensor] = None
        if self._settings.use_vail:
            z_mu = self.z_mu_layer(hidden)
            hidden = torch.normal(z_mu, self.z_sigma * use_vail_noise)
        estimate = self.estimator(hidden)
        return estimate, z_mu

    def compute_loss(
        self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
    ) -> torch.Tensor:
        """
        Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
        """
        policy_estimate, policy_mu = self.compute_estimate(
            policy_batch, use_vail_noise=True
        )
        expert_estimate, expert_mu = self.compute_estimate(
            expert_batch, use_vail_noise=True
        )
        loss = -(
            torch.log(expert_estimate * (1 - self.EPSILON))
            + torch.log(1.0 - policy_estimate * (1 - self.EPSILON))
        ).mean()
        kl_loss: Optional[torch.Tensor] = None
        if self._settings.use_vail:
            # KL divergence loss (encourage latent representation to be normal)
            kl_loss = torch.mean(
                -torch.sum(
                    1
                    + (self.z_sigma ** 2).log()
                    - 0.5 * expert_mu ** 2
                    - 0.5 * policy_mu ** 2
                    - (self.z_sigma ** 2),
                    dim=1,
                )
            )
            vail_loss = self.beta * (kl_loss - self.mutual_information)
            with torch.no_grad():
                self.beta.data = torch.max(
                    self.beta + self.alpha * (kl_loss - self.mutual_information),
                    torch.tensor(0.0),
                )
            loss += vail_loss
        if self.gradient_penalty_weight > 0.0:
            loss += self.gradient_penalty_weight * self.compute_gradient_magnitude(
                policy_batch, expert_batch
            )
        return loss, torch.mean(policy_estimate), torch.mean(expert_estimate), kl_loss

    def compute_gradient_magnitude(
        self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
    ) -> torch.Tensor:
        """
        Gradient penalty from https://arxiv.org/pdf/1704.00028. Adds stability esp.
        for off-policy. Compute gradients w.r.t randomly interpolated input.
        """
        policy_obs = self.get_state_encoding(policy_batch)
        expert_obs = self.get_state_encoding(expert_batch)
        obs_epsilon = torch.rand(policy_obs.shape)
        encoder_input = obs_epsilon * policy_obs + (1 - obs_epsilon) * expert_obs
        if self._settings.use_actions:
            policy_action = self.get_action_input(policy_batch)
            expert_action = self.get_action_input(policy_batch)
            action_epsilon = torch.rand(policy_action.shape)
            policy_dones = torch.as_tensor(policy_batch["done"], dtype=torch.float)
            expert_dones = torch.as_tensor(expert_batch["done"], dtype=torch.float)
            dones_epsilon = torch.rand(policy_dones.shape)
            encoder_input = torch.cat(
                [
                    encoder_input,
                    action_epsilon * policy_action
                    + (1 - action_epsilon) * expert_action,
                    dones_epsilon * policy_dones + (1 - dones_epsilon) * expert_dones,
                ],
                dim=1,
            )
        hidden = self.encoder(encoder_input)
        if self._settings.use_vail:
            use_vail_noise = True
            z_mu = self.z_mu_layer(hidden)
            hidden = torch.normal(z_mu, self.z_sigma * use_vail_noise)
        hidden = self.estimator(hidden)
        estimate = torch.mean(torch.sum(hidden, dim=1))
        gradient = torch.autograd.grad(estimate, encoder_input)[0]
        # Norm's gradient could be NaN at 0. Use our own safe_norm
        safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt()
        gradient_mag = torch.mean((safe_norm - 1) ** 2)
        return gradient_mag

File: ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py

from typing import Dict, Type
from mlagents.trainers.exception import UnityTrainerException

from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType

from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
    BaseRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.extrinsic_reward_provider import (
    ExtrinsicRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.curiosity_reward_provider import (
    CuriosityRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import (
    GAILRewardProvider,
)

from mlagents_envs.base_env import BehaviorSpec

NAME_TO_CLASS: Dict[RewardSignalType, Type[BaseRewardProvider]] = {
    RewardSignalType.EXTRINSIC: ExtrinsicRewardProvider,
    RewardSignalType.CURIOSITY: CuriosityRewardProvider,
    RewardSignalType.GAIL: GAILRewardProvider,
}


def create_reward_provider(
    name: RewardSignalType, specs: BehaviorSpec, settings: RewardSignalSettings
) -> BaseRewardProvider:
    """
    Creates a reward provider class based on the name and config entry provided as a dict.
    :param name: The name of the reward signal
    :param specs: The BehaviorSpecs of the policy
    :param settings: The RewardSignalSettings for that reward signal
    :return: The reward signal class instantiated
    """
    rcls = NAME_TO_CLASS.get(name)
    if not rcls:
        raise UnityTrainerException(f"Unknown reward signal type {name}")

    class_inst = rcls(specs, settings)
    return class_inst
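
The factory above is the piece the trainers hook into (the actual wiring lives in the modified torch_optimizer.py and trainer files, which are not reproduced on this page). As a rough sketch of that integration, under the assumption that the trainer configuration exposes reward signals as a Dict[RewardSignalType, RewardSignalSettings], a trainer-side component could build one provider per configured signal; the helper name make_reward_providers is hypothetical.

# Illustrative sketch only; make_reward_providers is not part of this diff.
from typing import Dict
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType
from mlagents.trainers.torch.components.reward_providers import (
    BaseRewardProvider,
    create_reward_provider,
)


def make_reward_providers(
    specs: BehaviorSpec,
    reward_signal_configs: Dict[RewardSignalType, RewardSignalSettings],
) -> Dict[RewardSignalType, BaseRewardProvider]:
    # One provider per configured reward signal; unknown signal types raise
    # UnityTrainerException inside create_reward_provider.
    return {
        signal_type: create_reward_provider(signal_type, specs, signal_settings)
        for signal_type, signal_settings in reward_signal_configs.items()
    }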