浏览代码

Added Reward Providers for Torch (#4280)

* Added Reward Providers for Torch

* Use NetworkBody to encode state in the reward providers

* Integrating the reward prodiders with ppo and torch

* work in progress, integration with PPO. Not training properly Pyramids at the moment

* Integration in PPO

* Removing duplicate file

* Gail and Curiosity working

* addressing comments

* Enfore float32 for tests

* enfore np.float32 in buffer
/develop/add-fire
GitHub 4 年前
当前提交
7ddfd81f
共有 23 个文件被更改,包括 1111 次插入49 次删除
  1. 2
      ml-agents/mlagents/trainers/buffer.py
  2. 26
      ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
  3. 3
      ml-agents/mlagents/trainers/ppo/optimizer_torch.py
  4. 25
      ml-agents/mlagents/trainers/ppo/trainer.py
  5. 5
      ml-agents/mlagents/trainers/sac/optimizer_torch.py
  6. 49
      ml-agents/mlagents/trainers/sac/trainer.py
  7. 6
      ml-agents/mlagents/trainers/tests/torch/test_utils.py
  8. 2
      ml-agents/mlagents/trainers/torch/distributions.py
  9. 6
      ml-agents/mlagents/trainers/torch/encoders.py
  10. 53
      ml-agents/mlagents/trainers/torch/utils.py
  11. 16
      ml-agents/mlagents/trainers/trainer/rl_trainer.py
  12. 111
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py
  13. 56
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py
  14. 138
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py
  15. 32
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py
  16. 15
      ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py
  17. 72
      ml-agents/mlagents/trainers/torch/components/reward_providers/base_reward_provider.py
  18. 228
      ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py
  19. 15
      ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py
  20. 257
      ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py
  21. 43
      ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py

2
ml-agents/mlagents/trainers/buffer.py


Adds a list of np.arrays to the end of the list of np.arrays.
:param data: The np.array list to append.
"""
self += list(np.array(data))
self += list(np.array(data, dtype=np.float32))
def set(self, data):
"""

26
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py


from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.components.bc.module import BCModule
from mlagents.trainers.components.reward_signals.extrinsic.signal import (
ExtrinsicRewardSignal,
)
from mlagents.trainers.torch.components.reward_providers import create_reward_provider
from mlagents.trainers.settings import TrainerSettings, RewardSignalType
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.torch.utils import ModelUtils

Create reward signals
:param reward_signal_configs: Reward signal config.
"""
extrinsic_signal = ExtrinsicRewardSignal(
self.policy, reward_signal_configs[RewardSignalType.EXTRINSIC]
)
self.reward_signals = {RewardSignalType.EXTRINSIC.value: extrinsic_signal}
# Create reward signals
# for reward_signal, config in reward_signal_configs.items():
# self.reward_signals[reward_signal] = create_reward_signal(
# self.policy, reward_signal, config
# )
# self.update_dict.update(self.reward_signals[reward_signal].update_dict)
for reward_signal, settings in reward_signal_configs.items():
# Name reward signals by string in case we have duplicates later
self.reward_signals[reward_signal.value] = create_reward_provider(
reward_signal, self.policy.behavior_spec, settings
)
def get_value_estimates(
self, decision_requests: DecisionSteps, idx: int, done: bool

# If we're done, reassign all of the value estimates that need terminal states.
if done:
for k in value_estimates:
if self.reward_signals[k].use_terminal_states:
if not self.reward_signals[k].ignore_done:
value_estimates[k] = 0.0
return value_estimates

if done:
for k in next_value_estimate:
if self.reward_signals[k].use_terminal_states:
if not self.reward_signals[k].ignore_done:
next_value_estimate[k] = 0.0
return value_estimates, next_value_estimate

3
ml-agents/mlagents/trainers/ppo/optimizer_torch.py


"Policy/Beta": decay_bet,
}
for reward_provider in self.reward_signals.values():
update_stats.update(reward_provider.update(batch))
return update_stats

25
ml-agents/mlagents/trainers/ppo/trainer.py


TestingConfiguration,
FrameworkType,
)
from mlagents.trainers.components.reward_signals import RewardSignal
try:
from mlagents.trainers.policy.torch_policy import TorchPolicy

for name, v in value_estimates.items():
agent_buffer_trajectory[f"{name}_value_estimates"].extend(v)
self._stats_reporter.add_stat(
self.optimizer.reward_signals[name].value_name, np.mean(v)
)
if isinstance(self.optimizer.reward_signals[name], RewardSignal):
self._stats_reporter.add_stat(
self.optimizer.reward_signals[name].value_name, np.mean(v)
)
else:
self._stats_reporter.add_stat(
f"Policy/{self.optimizer.reward_signals[name].name.capitalize()} Value Estimate",
np.mean(v),
)
# Evaluate all reward functions
self.collected_rewards["environment"][agent_id] += np.sum(

evaluate_result = reward_signal.evaluate_batch(
agent_buffer_trajectory
).scaled_reward
if isinstance(reward_signal, RewardSignal):
evaluate_result = reward_signal.evaluate_batch(
agent_buffer_trajectory
).scaled_reward
else:
evaluate_result = (
reward_signal.evaluate(agent_buffer_trajectory)
* reward_signal.strength
)
agent_buffer_trajectory[f"{name}_rewards"].extend(evaluate_result)
# Report the reward signals
self.collected_rewards[name][agent_id] += np.sum(evaluate_result)

5
ml-agents/mlagents/trainers/sac/optimizer_torch.py


# Use to reduce "survivor bonus" when using Curiosity or GAIL.
self.gammas = [_val.gamma for _val in trainer_params.reward_signals.values()]
self.use_dones_in_backup = {
name: int(self.reward_signals[name].use_terminal_states)
name: int(not self.reward_signals[name].ignore_done)
for name in self.stream_names
}

.numpy(),
"Policy/Learning Rate": decay_lr,
}
for signal in self.reward_signals.values():
signal.update(batch)
return update_stats

49
ml-agents/mlagents/trainers/sac/trainer.py


from mlagents.trainers.trajectory import Trajectory, SplitObservations
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings, SACSettings, FrameworkType
from mlagents.trainers.components.reward_signals import RewardSignal
try:
from mlagents.trainers.policy.torch_policy import TorchPolicy

agent_buffer_trajectory["environment_rewards"]
)
for name, reward_signal in self.optimizer.reward_signals.items():
evaluate_result = reward_signal.evaluate_batch(
agent_buffer_trajectory
).scaled_reward
if isinstance(reward_signal, RewardSignal):
evaluate_result = reward_signal.evaluate_batch(
agent_buffer_trajectory
).scaled_reward
else:
evaluate_result = (
reward_signal.evaluate(agent_buffer_trajectory)
* reward_signal.strength
)
# Report the reward signals
self.collected_rewards[name][agent_id] += np.sum(evaluate_result)

)
for name, v in value_estimates.items():
self._stats_reporter.add_stat(
self.optimizer.reward_signals[name].value_name, np.mean(v)
)
if isinstance(self.optimizer.reward_signals[name], RewardSignal):
self._stats_reporter.add_stat(
self.optimizer.reward_signals[name].value_name, np.mean(v)
)
else:
self._stats_reporter.add_stat(
f"Policy/{self.optimizer.reward_signals[name].name.capitalize()} Value",
np.mean(v),
)
# Bootstrap using the last step rather than the bootstrap step if max step is reached.
# Set last element to duplicate obs and remove dones.

)
# Get rewards for each reward
for name, signal in self.optimizer.reward_signals.items():
sampled_minibatch[f"{name}_rewards"] = signal.evaluate_batch(
sampled_minibatch
).scaled_reward
if isinstance(signal, RewardSignal):
sampled_minibatch[f"{name}_rewards"] = signal.evaluate_batch(
sampled_minibatch
).scaled_reward
else:
sampled_minibatch[f"{name}_rewards"] = (
signal.evaluate(sampled_minibatch) * signal.strength
)
update_stats = self.optimizer.update(sampled_minibatch, n_sequences)
for stat_name, value in update_stats.items():

reward_signal_minibatches = {}
for name, signal in self.optimizer.reward_signals.items():
logger.debug(f"Updating {name} at step {self.step}")
# Some signals don't need a minibatch to be sampled - so we don't!
if signal.update_dict:
reward_signal_minibatches[name] = buffer.sample_mini_batch(
self.hyperparameters.batch_size,
sequence_length=self.policy.sequence_length,
)
if isinstance(signal, RewardSignal):
# Some signals don't need a minibatch to be sampled - so we don't!
if signal.update_dict:
reward_signal_minibatches[name] = buffer.sample_mini_batch(
self.hyperparameters.batch_size,
sequence_length=self.policy.sequence_length,
)
update_stats = self.optimizer.update_reward_signals(
reward_signal_minibatches, n_sequences
)

6
ml-agents/mlagents/trainers/tests/torch/test_utils.py


action_size = [2, 1, 3]
oh_actions = ModelUtils.actions_to_onehot(all_actions, action_size)
expected_result = [
torch.tensor([[0, 1], [0, 1]]),
torch.tensor([[1], [1]]),
torch.tensor([[0, 0, 1], [0, 0, 1]]),
torch.tensor([[0, 1], [0, 1]], dtype=torch.float),
torch.tensor([[1], [1]], dtype=torch.float),
torch.tensor([[0, 0, 1], [0, 0, 1]], dtype=torch.float),
]
for res, exp in zip(oh_actions, expected_result):
assert torch.equal(res, exp)

2
ml-agents/mlagents/trainers/torch/distributions.py


return torch.log(self.probs)
def entropy(self):
return torch.sum(self.probs * torch.log(self.probs), dim=-1)
return -torch.sum(self.probs * torch.log(self.probs), dim=-1)
class GaussianDistribution(nn.Module):

6
ml-agents/mlagents/trainers/torch/encoders.py


return height, width
class SwishLayer(torch.nn.Module):
def forward(self, data: torch.Tensor) -> torch.Tensor:
return torch.mul(data, torch.sigmoid(data))
class VectorEncoder(nn.Module):
def __init__(
self,

self.normalizer: Optional[Normalizer] = None
super().__init__()
self.layers = [nn.Linear(input_size, hidden_size)]
self.layers.append(SwishLayer())
if normalize:
self.normalizer = Normalizer(input_size)

53
ml-agents/mlagents/trainers/torch/utils.py


)
from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.exception import UnityTrainerException
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance

def swish(input_activation: torch.Tensor) -> torch.Tensor:
"""Swish activation function. For more info: https://arxiv.org/abs/1710.05941"""
return torch.mul(input_activation, torch.sigmoid(input_activation))
class SwishLayer(torch.nn.Module):
def forward(self, data: torch.Tensor) -> torch.Tensor:
return torch.mul(data, torch.sigmoid(data))
class ActionFlattener:
def __init__(self, behavior_spec: BehaviorSpec):
self._specs = behavior_spec
@property
def flattened_size(self) -> int:
if self._specs.is_action_continuous():
return self._specs.action_size
else:
return sum(self._specs.discrete_action_branches)
def forward(self, action: torch.Tensor) -> torch.Tensor:
if self._specs.is_action_continuous():
return action
else:
return torch.cat(
ModelUtils.actions_to_onehot(
torch.as_tensor(action, dtype=torch.long),
self._specs.discrete_action_branches,
),
dim=1,
)
@staticmethod
def update_learning_rate(optim: torch.optim.Optimizer, lr: float) -> None:

:return: List of one-hot tensors, one representing each branch.
"""
onehot_branches = [
torch.nn.functional.one_hot(_act.T, action_size[i])
for i, _act in enumerate(discrete_actions.T)
torch.nn.functional.one_hot(_act.T, action_size[i]).float()
for i, _act in enumerate(discrete_actions.long().T)
@staticmethod
def dynamic_partition(
data: torch.Tensor, partitions: torch.Tensor, num_partitions: int
) -> List[torch.Tensor]:
"""
Torch implementation of dynamic_partition :
https://www.tensorflow.org/api_docs/python/tf/dynamic_partition
Splits the data Tensor input into num_partitions Tensors according to the indices in
partitions.
:param data: The Tensor data that will be split into partitions.
:param partitions: An indices tensor that determines in which partition each element
of data will be in.
:param num_partitions: The number of partitions to output. Corresponds to the
maximum possible index in the partitions argument.
:return: A list of Tensor partitions (Their indices correspond to their partition index).
"""
res: List[torch.Tensor] = []
for i in range(num_partitions):
res += [data[(partitions == i).nonzero().squeeze(1)]]
return res
@staticmethod
def get_probs_and_entropy(

16
ml-agents/mlagents/trainers/trainer/rl_trainer.py


from mlagents.trainers.optimizer import Optimizer
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.components.reward_signals import RewardSignalResult
from mlagents.trainers.components.reward_signals import RewardSignalResult, RewardSignal
from mlagents_envs.timers import hierarchical_timer
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.policy.policy import Policy

)
self.framework = self.trainer_settings.framework
logger.debug(f"Using framework {self.framework.value}")
if TestingConfiguration.max_steps > 0:
self.trainer_settings.max_steps = TestingConfiguration.max_steps
self._next_save_step = 0

self.reward_buffer.appendleft(rewards.get(agent_id, 0))
rewards[agent_id] = 0
else:
self.stats_reporter.add_stat(
optimizer.reward_signals[name].stat_name, rewards.get(agent_id, 0)
)
if isinstance(optimizer.reward_signals[name], RewardSignal):
self.stats_reporter.add_stat(
optimizer.reward_signals[name].stat_name,
rewards.get(agent_id, 0),
)
else:
self.stats_reporter.add_stat(
f"Policy/{optimizer.reward_signals[name].name.capitalize()} Reward",
rewards.get(agent_id, 0),
)
rewards[agent_id] = 0
def _clear_update_buffer(self) -> None:

111
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py


import numpy as np
import pytest
import torch
from mlagents.trainers.torch.components.reward_providers import (
CuriosityRewardProvider,
create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents.trainers.settings import CuriositySettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
create_agent_buffer,
)
SEED = [42]
@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:
curiosity_settings = CuriositySettings(32, 0.01)
curiosity_settings.strength = 0.1
curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings)
assert curiosity_rp.strength == 0.1
assert curiosity_rp.name == "Curiosity"
@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,), (64, 66, 1)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:
curiosity_settings = CuriositySettings(32, 0.01)
curiosity_rp = create_reward_provider(
RewardSignalType.CURIOSITY, behavior_spec, curiosity_settings
)
assert curiosity_rp.name == "Curiosity"
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
],
)
def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None:
np.random.seed(seed)
torch.manual_seed(seed)
curiosity_settings = CuriositySettings(32, 0.01)
curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings)
buffer = create_agent_buffer(behavior_spec, 5)
curiosity_rp.update(buffer)
reward_old = curiosity_rp.evaluate(buffer)[0]
for _ in range(10):
curiosity_rp.update(buffer)
reward_new = curiosity_rp.evaluate(buffer)[0]
assert reward_new < reward_old
reward_old = reward_new
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
"behavior_spec", [BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5)]
)
def test_continuous_action_prediction(behavior_spec: BehaviorSpec, seed: int) -> None:
np.random.seed(seed)
torch.manual_seed(seed)
curiosity_settings = CuriositySettings(32, 0.1)
curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings)
buffer = create_agent_buffer(behavior_spec, 5)
for _ in range(200):
curiosity_rp.update(buffer)
prediction = curiosity_rp._network.predict_action(buffer)[0].detach()
target = buffer["actions"][0]
error = float(torch.mean((prediction - target) ** 2))
assert error < 0.001
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,), (64, 66, 3)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
],
)
def test_next_state_prediction(behavior_spec: BehaviorSpec, seed: int) -> None:
np.random.seed(seed)
torch.manual_seed(seed)
curiosity_settings = CuriositySettings(32, 0.1)
curiosity_rp = CuriosityRewardProvider(behavior_spec, curiosity_settings)
buffer = create_agent_buffer(behavior_spec, 5)
for _ in range(100):
curiosity_rp.update(buffer)
prediction = curiosity_rp._network.predict_next_state(buffer)[0]
target = curiosity_rp._network.get_next_state(buffer)[0]
error = float(torch.mean((prediction - target) ** 2).detach())
assert error < 0.001

56
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py


import pytest
from mlagents.trainers.torch.components.reward_providers import (
ExtrinsicRewardProvider,
create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
create_agent_buffer,
)
@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:
settings = RewardSignalSettings()
settings.gamma = 0.2
extrinsic_rp = ExtrinsicRewardProvider(behavior_spec, settings)
assert extrinsic_rp.gamma == 0.2
assert extrinsic_rp.name == "Extrinsic"
@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:
settings = RewardSignalSettings()
extrinsic_rp = create_reward_provider(
RewardSignalType.EXTRINSIC, behavior_spec, settings
)
assert extrinsic_rp.name == "Extrinsic"
@pytest.mark.parametrize("reward", [2.0, 3.0, 4.0])
@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
],
)
def test_reward(behavior_spec: BehaviorSpec, reward: float) -> None:
buffer = create_agent_buffer(behavior_spec, 1000, reward)
settings = RewardSignalSettings()
extrinsic_rp = ExtrinsicRewardProvider(behavior_spec, settings)
generated_rewards = extrinsic_rp.evaluate(buffer)
assert (generated_rewards == reward).all()

138
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py


from typing import Any
import numpy as np
import pytest
from unittest.mock import patch
import torch
import os
from mlagents.trainers.torch.components.reward_providers import (
GAILRewardProvider,
create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents.trainers.settings import GAILSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
create_agent_buffer,
)
from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import (
DiscriminatorNetwork,
)
CONTINUOUS_PATH = (
os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
+ "/test.demo"
)
DISCRETE_PATH = (
os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
+ "/testdcvis.demo"
)
SEED = [42]
@pytest.mark.parametrize(
"behavior_spec", [BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2)]
)
def test_construction(behavior_spec: BehaviorSpec) -> None:
gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH)
gail_rp = GAILRewardProvider(behavior_spec, gail_settings)
assert gail_rp.name == "GAIL"
@pytest.mark.parametrize(
"behavior_spec", [BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2)]
)
def test_factory(behavior_spec: BehaviorSpec) -> None:
gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH)
gail_rp = create_reward_provider(
RewardSignalType.GAIL, behavior_spec, gail_settings
)
assert gail_rp.name == "GAIL"
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(8,), (24, 26, 1)], ActionType.CONTINUOUS, 2),
BehaviorSpec([(50,)], ActionType.DISCRETE, (2, 3, 3, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (20,)),
],
)
@pytest.mark.parametrize("use_actions", [False, True])
@patch(
"mlagents.trainers.torch.components.reward_providers.gail_reward_provider.demo_to_buffer"
)
def test_reward_decreases(
demo_to_buffer: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int
) -> None:
np.random.seed(seed)
torch.manual_seed(seed)
buffer_expert = create_agent_buffer(behavior_spec, 1000)
buffer_policy = create_agent_buffer(behavior_spec, 1000)
demo_to_buffer.return_value = None, buffer_expert
gail_settings = GAILSettings(
demo_path="", learning_rate=0.05, use_vail=False, use_actions=use_actions
)
gail_rp = create_reward_provider(
RewardSignalType.GAIL, behavior_spec, gail_settings
)
init_reward_expert = gail_rp.evaluate(buffer_expert)[0]
init_reward_policy = gail_rp.evaluate(buffer_policy)[0]
for _ in range(10):
gail_rp.update(buffer_policy)
reward_expert = gail_rp.evaluate(buffer_expert)[0]
reward_policy = gail_rp.evaluate(buffer_policy)[0]
assert reward_expert >= 0 # GAIL / VAIL reward always positive
assert reward_policy >= 0
reward_expert = gail_rp.evaluate(buffer_expert)[0]
reward_policy = gail_rp.evaluate(buffer_policy)[0]
assert reward_expert > reward_policy # Expert reward greater than non-expert reward
assert (
reward_expert > init_reward_expert
) # Expert reward getting better as network trains
assert (
reward_policy < init_reward_policy
) # Non-expert reward getting worse as network trains
@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3, 3, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (20,)),
],
)
@pytest.mark.parametrize("use_actions", [False, True])
@patch(
"mlagents.trainers.torch.components.reward_providers.gail_reward_provider.demo_to_buffer"
)
def test_reward_decreases_vail(
demo_to_buffer: Any, use_actions: bool, behavior_spec: BehaviorSpec, seed: int
) -> None:
np.random.seed(seed)
torch.manual_seed(seed)
buffer_expert = create_agent_buffer(behavior_spec, 1000)
buffer_policy = create_agent_buffer(behavior_spec, 1000)
demo_to_buffer.return_value = None, buffer_expert
gail_settings = GAILSettings(
demo_path="", learning_rate=0.005, use_vail=True, use_actions=use_actions
)
DiscriminatorNetwork.initial_beta = 0.0
# we must set the initial value of beta to 0 for testing
# If we do not, the kl-loss will dominate early and will block the estimator
gail_rp = create_reward_provider(
RewardSignalType.GAIL, behavior_spec, gail_settings
)
for _ in range(100):
gail_rp.update(buffer_policy)
reward_expert = gail_rp.evaluate(buffer_expert)[0]
reward_policy = gail_rp.evaluate(buffer_policy)[0]
assert reward_expert >= 0 # GAIL / VAIL reward always positive
assert reward_policy >= 0
reward_expert = gail_rp.evaluate(buffer_expert)[0]
reward_policy = gail_rp.evaluate(buffer_policy)[0]
assert reward_expert > reward_policy # Expert reward greater than non-expert reward

32
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py


import numpy as np
from mlagents.trainers.buffer import AgentBuffer
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.trajectory import SplitObservations
def create_agent_buffer(
behavior_spec: BehaviorSpec, number: int, reward: float = 0.0
) -> AgentBuffer:
buffer = AgentBuffer()
curr_observations = [
np.random.normal(size=shape) for shape in behavior_spec.observation_shapes
]
next_observations = [
np.random.normal(size=shape) for shape in behavior_spec.observation_shapes
]
action = behavior_spec.create_random_action(1)[0, :]
for _ in range(number):
curr_split_obs = SplitObservations.from_observations(curr_observations)
next_split_obs = SplitObservations.from_observations(next_observations)
for i, _ in enumerate(curr_split_obs.visual_observations):
buffer["visual_obs%d" % i].append(curr_split_obs.visual_observations[i])
buffer["next_visual_obs%d" % i].append(
next_split_obs.visual_observations[i]
)
buffer["vector_obs"].append(curr_split_obs.vector_observations)
buffer["next_vector_in"].append(next_split_obs.vector_observations)
buffer["actions"].append(action)
buffer["done"].append(np.zeros(1, dtype=np.float32))
buffer["reward"].append(np.ones(1, dtype=np.float32) * reward)
buffer["masks"].append(np.ones(1, dtype=np.float32))
return buffer

15
ml-agents/mlagents/trainers/torch/components/reward_providers/__init__.py


from mlagents.trainers.torch.components.reward_providers.base_reward_provider import ( # noqa F401
BaseRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.extrinsic_reward_provider import ( # noqa F401
ExtrinsicRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.curiosity_reward_provider import ( # noqa F401
CuriosityRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import ( # noqa F401
GAILRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.reward_provider_factory import ( # noqa F401
create_reward_provider,
)

72
ml-agents/mlagents/trainers/torch/components/reward_providers/base_reward_provider.py


import numpy as np
from abc import ABC, abstractmethod
from typing import Dict
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.settings import RewardSignalSettings
from mlagents_envs.base_env import BehaviorSpec
class BaseRewardProvider(ABC):
def __init__(self, specs: BehaviorSpec, settings: RewardSignalSettings) -> None:
self._policy_specs = specs
self._gamma = settings.gamma
self._strength = settings.strength
self._ignore_done = False
@property
def gamma(self) -> float:
"""
The discount factor for the reward signal
"""
return self._gamma
@property
def strength(self) -> float:
"""
The strength multiplier of the reward provider
"""
return self._strength
@property
def name(self) -> str:
"""
The name of the reward provider. Is used for reporting and identification
"""
class_name = self.__class__.__name__
return class_name.replace("RewardProvider", "")
@property
def ignore_done(self) -> bool:
"""
If true, when the agent is done, the rewards of the next episode must be
used to calculate the return of the current episode.
Is used to mitigate the positive bias in rewards with no natural end.
"""
return self._ignore_done
@abstractmethod
def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
"""
Evaluates the reward for the data present in the Dict mini_batch. Use this when evaluating a reward
function drawn straight from a Buffer.
:param mini_batch: A Dict of numpy arrays (the format used by our Buffer)
when drawing from the update buffer.
:return: a np.ndarray of rewards generated by the reward provider
"""
raise NotImplementedError(
"The reward provider's evaluate method has not been implemented "
)
@abstractmethod
def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
"""
Update the reward for the data present in the Dict mini_batch. Use this when updating a reward
function drawn straight from a Buffer.
:param mini_batch: A Dict of numpy arrays (the format used by our Buffer)
when drawing from the update buffer.
:return: A dictionary from string to stats values
"""
raise NotImplementedError(
"The reward provider's update method has not been implemented "
)

228
ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py


import numpy as np
from typing import Dict
import torch
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
BaseRewardProvider,
)
from mlagents.trainers.settings import CuriositySettings
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.settings import NetworkSettings, EncoderType
class CuriosityRewardProvider(BaseRewardProvider):
beta = 0.2 # Forward vs Inverse loss weight
loss_multiplier = 10.0 # Loss multiplier
def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
super().__init__(specs, settings)
self._ignore_done = True
self._network = CuriosityNetwork(specs, settings)
self.optimizer = torch.optim.Adam(
self._network.parameters(), lr=settings.learning_rate
)
self._has_updated_once = False
def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
with torch.no_grad():
rewards = self._network.compute_reward(mini_batch).detach().cpu().numpy()
rewards = np.minimum(rewards, 1.0 / self.strength)
return rewards * self._has_updated_once
def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
self._has_updated_once = True
forward_loss = self._network.compute_forward_loss(mini_batch)
inverse_loss = self._network.compute_inverse_loss(mini_batch)
loss = self.loss_multiplier * (
self.beta * forward_loss + (1.0 - self.beta) * inverse_loss
)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return {
"Losses/Curiosity Forward Loss": forward_loss.detach().cpu().numpy(),
"Losses/Curiosity Inverse Loss": inverse_loss.detach().cpu().numpy(),
}
class CuriosityNetwork(torch.nn.Module):
EPSILON = 1e-10
def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
super().__init__()
self._policy_specs = specs
state_encoder_settings = NetworkSettings(
normalize=False,
hidden_units=settings.encoding_size,
num_layers=2,
vis_encode_type=EncoderType.SIMPLE,
memory=None,
)
self._state_encoder = NetworkBody(
specs.observation_shapes, state_encoder_settings
)
self._action_flattener = ModelUtils.ActionFlattener(specs)
self.inverse_model_action_predition = torch.nn.Sequential(
torch.nn.Linear(2 * settings.encoding_size, 256),
ModelUtils.SwishLayer(),
torch.nn.Linear(256, self._action_flattener.flattened_size),
)
self.inverse_model_action_predition[0].bias.data.zero_()
self.inverse_model_action_predition[2].bias.data.zero_()
self.forward_model_next_state_prediction = torch.nn.Sequential(
torch.nn.Linear(
settings.encoding_size + self._action_flattener.flattened_size, 256
),
ModelUtils.SwishLayer(),
torch.nn.Linear(256, settings.encoding_size),
)
self.forward_model_next_state_prediction[0].bias.data.zero_()
self.forward_model_next_state_prediction[2].bias.data.zero_()
def get_current_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""
Extracts the current state embedding from a mini_batch.
"""
n_vis = len(self._state_encoder.visual_encoders)
hidden, _ = self._state_encoder.forward(
vec_inputs=[
ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float)
],
vis_inputs=[
ModelUtils.list_to_tensor(
mini_batch["visual_obs%d" % i], dtype=torch.float
)
for i in range(n_vis)
],
)
return hidden
def get_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""
Extracts the next state embedding from a mini_batch.
"""
n_vis = len(self._state_encoder.visual_encoders)
hidden, _ = self._state_encoder.forward(
vec_inputs=[
ModelUtils.list_to_tensor(
mini_batch["next_vector_in"], dtype=torch.float
)
],
vis_inputs=[
ModelUtils.list_to_tensor(
mini_batch["next_visual_obs%d" % i], dtype=torch.float
)
for i in range(n_vis)
],
)
return hidden
def predict_action(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""
In the continuous case, returns the predicted action.
In the discrete case, returns the logits.
"""
inverse_model_input = torch.cat(
(self.get_current_state(mini_batch), self.get_next_state(mini_batch)), dim=1
)
hidden = self.inverse_model_action_predition(inverse_model_input)
if self._policy_specs.is_action_continuous():
return hidden
else:
branches = ModelUtils.break_into_branches(
hidden, self._policy_specs.discrete_action_branches
)
branches = [torch.softmax(b, dim=1) for b in branches]
return torch.cat(branches, dim=1)
def predict_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""
Uses the current state embedding and the action of the mini_batch to predict
the next state embedding.
"""
if self._policy_specs.is_action_continuous():
action = ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.float)
else:
action = torch.cat(
ModelUtils.actions_to_onehot(
ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long),
self._policy_specs.discrete_action_branches,
),
dim=1,
)
forward_model_input = torch.cat(
(self.get_current_state(mini_batch), action), dim=1
)
return self.forward_model_next_state_prediction(forward_model_input)
def compute_inverse_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""
Computes the inverse loss for a mini_batch. Corresponds to the error on the
action prediction (given the current and next state).
"""
predicted_action = self.predict_action(mini_batch)
if self._policy_specs.is_action_continuous():
sq_difference = (
ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.float)
- predicted_action
) ** 2
sq_difference = torch.sum(sq_difference, dim=1)
return torch.mean(
ModelUtils.dynamic_partition(
sq_difference,
ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
2,
)[1]
)
else:
true_action = torch.cat(
ModelUtils.actions_to_onehot(
ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long),
self._policy_specs.discrete_action_branches,
),
dim=1,
)
cross_entropy = torch.sum(
-torch.log(predicted_action + self.EPSILON) * true_action, dim=1
)
return torch.mean(
ModelUtils.dynamic_partition(
cross_entropy,
ModelUtils.list_to_tensor(
mini_batch["masks"], dtype=torch.float
), # use masks not action_masks
2,
)[1]
)
def compute_reward(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""
Calculates the curiosity reward for the mini_batch. Corresponds to the error
between the predicted and actual next state.
"""
predicted_next_state = self.predict_next_state(mini_batch)
target = self.get_next_state(mini_batch)
sq_difference = 0.5 * (target - predicted_next_state) ** 2
sq_difference = torch.sum(sq_difference, dim=1)
return sq_difference
def compute_forward_loss(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""
Computes the loss for the next state prediction
"""
return torch.mean(
ModelUtils.dynamic_partition(
self.compute_reward(mini_batch),
ModelUtils.list_to_tensor(mini_batch["masks"], dtype=torch.float),
2,
)[1]
)

15
ml-agents/mlagents/trainers/torch/components/reward_providers/extrinsic_reward_provider.py


import numpy as np
from typing import Dict
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
BaseRewardProvider,
)
class ExtrinsicRewardProvider(BaseRewardProvider):
def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
return np.array(mini_batch["environment_rewards"], dtype=np.float32)
def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
return {}

257
ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py


from typing import Optional, Dict
import numpy as np
import torch
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
BaseRewardProvider,
)
from mlagents.trainers.settings import GAILSettings
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.torch.networks import NetworkBody
from mlagents.trainers.settings import NetworkSettings, EncoderType
from mlagents.trainers.demo_loader import demo_to_buffer
class GAILRewardProvider(BaseRewardProvider):
def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
super().__init__(specs, settings)
self._ignore_done = True
self._discriminator_network = DiscriminatorNetwork(specs, settings)
_, self._demo_buffer = demo_to_buffer(
settings.demo_path, 1, specs
) # This is supposed to be the sequence length but we do not have access here
params = list(self._discriminator_network.parameters())
self.optimizer = torch.optim.Adam(params, lr=settings.learning_rate)
def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
with torch.no_grad():
estimates, _ = self._discriminator_network.compute_estimate(
mini_batch, use_vail_noise=False
)
return (
-torch.log(
1.0
- estimates.squeeze(dim=1)
* (1.0 - self._discriminator_network.EPSILON)
)
.detach()
.cpu()
.numpy()
)
def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
expert_batch = self._demo_buffer.sample_mini_batch(
mini_batch.num_experiences, 1
)
loss, policy_mean_estimate, expert_mean_estimate, kl_loss = self._discriminator_network.compute_loss(
mini_batch, expert_batch
)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
stats_dict = {
"Losses/GAIL Discriminator Loss": loss.detach().cpu().numpy(),
"Policy/GAIL Policy Estimate": policy_mean_estimate.detach().cpu().numpy(),
"Policy/GAIL Expert Estimate": expert_mean_estimate.detach().cpu().numpy(),
}
if self._discriminator_network.use_vail:
stats_dict["Policy/GAIL Beta"] = (
self._discriminator_network.beta.detach().cpu().numpy()
)
stats_dict["Losses/GAIL KL Loss"] = kl_loss.detach().cpu().numpy()
return stats_dict
class DiscriminatorNetwork(torch.nn.Module):
gradient_penalty_weight = 10.0
z_size = 128
alpha = 0.0005
mutual_information = 0.5
EPSILON = 1e-7
initial_beta = 0.0
def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
super().__init__()
self._policy_specs = specs
self.use_vail = settings.use_vail
self._settings = settings
state_encoder_settings = NetworkSettings(
normalize=False,
hidden_units=settings.encoding_size,
num_layers=2,
vis_encode_type=EncoderType.SIMPLE,
memory=None,
)
self._state_encoder = NetworkBody(
specs.observation_shapes, state_encoder_settings
)
self._action_flattener = ModelUtils.ActionFlattener(specs)
encoder_input_size = settings.encoding_size
if settings.use_actions:
encoder_input_size += (
self._action_flattener.flattened_size + 1
) # + 1 is for done
self.encoder = torch.nn.Sequential(
torch.nn.Linear(encoder_input_size, settings.encoding_size),
ModelUtils.SwishLayer(),
torch.nn.Linear(settings.encoding_size, settings.encoding_size),
ModelUtils.SwishLayer(),
)
torch.nn.init.xavier_normal_(self.encoder[0].weight.data)
torch.nn.init.xavier_normal_(self.encoder[2].weight.data)
self.encoder[0].bias.data.zero_()
self.encoder[2].bias.data.zero_()
estimator_input_size = settings.encoding_size
if settings.use_vail:
estimator_input_size = self.z_size
self.z_sigma = torch.nn.Parameter(
torch.ones((self.z_size), dtype=torch.float), requires_grad=True
)
self.z_mu_layer = torch.nn.Linear(settings.encoding_size, self.z_size)
# self.z_mu_layer.weight.data Needs a variance scale initializer
torch.nn.init.xavier_normal_(self.z_mu_layer.weight.data)
self.z_mu_layer.bias.data.zero_()
self.beta = torch.nn.Parameter(
torch.tensor(self.initial_beta, dtype=torch.float), requires_grad=False
)
self.estimator = torch.nn.Sequential(
torch.nn.Linear(estimator_input_size, 1), torch.nn.Sigmoid()
)
torch.nn.init.xavier_normal_(self.estimator[0].weight.data)
self.estimator[0].bias.data.zero_()
def get_action_input(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""
Creates the action Tensor. In continuous case, corresponds to the action. In
the discrete case, corresponds to the concatenation of one hot action Tensors.
"""
return self._action_flattener.forward(
torch.as_tensor(mini_batch["actions"], dtype=torch.float)
)
def get_state_encoding(self, mini_batch: AgentBuffer) -> torch.Tensor:
"""
Creates the observation input.
"""
n_vis = len(self._state_encoder.visual_encoders)
hidden, _ = self._state_encoder.forward(
vec_inputs=[torch.as_tensor(mini_batch["vector_obs"], dtype=torch.float)],
vis_inputs=[
torch.as_tensor(mini_batch["visual_obs%d" % i], dtype=torch.float)
for i in range(n_vis)
],
)
return hidden
def compute_estimate(
self, mini_batch: AgentBuffer, use_vail_noise: bool = False
) -> torch.Tensor:
"""
Given a mini_batch, computes the estimate (How much the discriminator believes
the data was sampled from the demonstration data).
:param mini_batch: The AgentBuffer of data
:param use_vail_noise: Only when using VAIL : If true, will sample the code, if
false, will return the mean of the code.
"""
encoder_input = self.get_state_encoding(mini_batch)
if self._settings.use_actions:
actions = self.get_action_input(mini_batch)
dones = torch.as_tensor(mini_batch["done"], dtype=torch.float)
encoder_input = torch.cat([encoder_input, actions, dones], dim=1)
hidden = self.encoder(encoder_input)
z_mu: Optional[torch.Tensor] = None
if self._settings.use_vail:
z_mu = self.z_mu_layer(hidden)
hidden = torch.normal(z_mu, self.z_sigma * use_vail_noise)
estimate = self.estimator(hidden)
return estimate, z_mu
def compute_loss(
self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
) -> torch.Tensor:
"""
Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
"""
policy_estimate, policy_mu = self.compute_estimate(
policy_batch, use_vail_noise=True
)
expert_estimate, expert_mu = self.compute_estimate(
expert_batch, use_vail_noise=True
)
loss = -(
torch.log(expert_estimate * (1 - self.EPSILON))
+ torch.log(1.0 - policy_estimate * (1 - self.EPSILON))
).mean()
kl_loss: Optional[torch.Tensor] = None
if self._settings.use_vail:
# KL divergence loss (encourage latent representation to be normal)
kl_loss = torch.mean(
-torch.sum(
1
+ (self.z_sigma ** 2).log()
- 0.5 * expert_mu ** 2
- 0.5 * policy_mu ** 2
- (self.z_sigma ** 2),
dim=1,
)
)
vail_loss = self.beta * (kl_loss - self.mutual_information)
with torch.no_grad():
self.beta.data = torch.max(
self.beta + self.alpha * (kl_loss - self.mutual_information),
torch.tensor(0.0),
)
loss += vail_loss
if self.gradient_penalty_weight > 0.0:
loss += self.gradient_penalty_weight * self.compute_gradient_magnitude(
policy_batch, expert_batch
)
return loss, torch.mean(policy_estimate), torch.mean(expert_estimate), kl_loss
def compute_gradient_magnitude(
self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
) -> torch.Tensor:
"""
Gradient penalty from https://arxiv.org/pdf/1704.00028. Adds stability esp.
for off-policy. Compute gradients w.r.t randomly interpolated input.
"""
policy_obs = self.get_state_encoding(policy_batch)
expert_obs = self.get_state_encoding(expert_batch)
obs_epsilon = torch.rand(policy_obs.shape)
encoder_input = obs_epsilon * policy_obs + (1 - obs_epsilon) * expert_obs
if self._settings.use_actions:
policy_action = self.get_action_input(policy_batch)
expert_action = self.get_action_input(policy_batch)
action_epsilon = torch.rand(policy_action.shape)
policy_dones = torch.as_tensor(policy_batch["done"], dtype=torch.float)
expert_dones = torch.as_tensor(expert_batch["done"], dtype=torch.float)
dones_epsilon = torch.rand(policy_dones.shape)
encoder_input = torch.cat(
[
encoder_input,
action_epsilon * policy_action
+ (1 - action_epsilon) * expert_action,
dones_epsilon * policy_dones + (1 - dones_epsilon) * expert_dones,
],
dim=1,
)
hidden = self.encoder(encoder_input)
if self._settings.use_vail:
use_vail_noise = True
z_mu = self.z_mu_layer(hidden)
hidden = torch.normal(z_mu, self.z_sigma * use_vail_noise)
hidden = self.estimator(hidden)
estimate = torch.mean(torch.sum(hidden, dim=1))
gradient = torch.autograd.grad(estimate, encoder_input)[0]
# Norm's gradient could be NaN at 0. Use our own safe_norm
safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt()
gradient_mag = torch.mean((safe_norm - 1) ** 2)
return gradient_mag

43
ml-agents/mlagents/trainers/torch/components/reward_providers/reward_provider_factory.py


from typing import Dict, Type
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
BaseRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.extrinsic_reward_provider import (
ExtrinsicRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.curiosity_reward_provider import (
CuriosityRewardProvider,
)
from mlagents.trainers.torch.components.reward_providers.gail_reward_provider import (
GAILRewardProvider,
)
from mlagents_envs.base_env import BehaviorSpec
NAME_TO_CLASS: Dict[RewardSignalType, Type[BaseRewardProvider]] = {
RewardSignalType.EXTRINSIC: ExtrinsicRewardProvider,
RewardSignalType.CURIOSITY: CuriosityRewardProvider,
RewardSignalType.GAIL: GAILRewardProvider,
}
def create_reward_provider(
name: RewardSignalType, specs: BehaviorSpec, settings: RewardSignalSettings
) -> BaseRewardProvider:
"""
Creates a reward provider class based on the name and config entry provided as a dict.
:param name: The name of the reward signal
:param specs: The BehaviorSpecs of the policy
:param settings: The RewardSignalSettings for that reward signal
:return: The reward signal class instantiated
"""
rcls = NAME_TO_CLASS.get(name)
if not rcls:
raise UnityTrainerException(f"Unknown reward signal type {name}")
class_inst = rcls(specs, settings)
return class_inst
正在加载...
取消
保存