Refactor reward signals into separate class (#2144)
* Create new class (RewardSignal) that represents a reward signal.
* Add value heads for each reward signal in the PPO model.
* Make summaries agnostic to the type of reward signals, and log weighted rewards per reward signal.
* Move extrinsic and curiosity rewards into this new structure.
* Allow defining multiple reward signals in the YAML file. Add documentation for this new structure.
GitHub
5 years ago
Current commit 4ac79742
27 files changed, with 1302 insertions and 636 deletions
- config/trainer_config.yaml (56 lines changed)
- docs/Training-PPO.md (45 lines changed)
- ml-agents/mlagents/trainers/bc/policy.py (2 lines changed)
- ml-agents/mlagents/trainers/demo_loader.py (65 lines changed)
- ml-agents/mlagents/trainers/models.py (122 lines changed)
- ml-agents/mlagents/trainers/policy.py (2 lines changed)
- ml-agents/mlagents/trainers/ppo/models.py (222 lines changed)
- ml-agents/mlagents/trainers/ppo/policy.py (138 lines changed)
- ml-agents/mlagents/trainers/ppo/trainer.py (231 lines changed)
- ml-agents/mlagents/trainers/tests/test_ppo.py (158 lines changed)
- ml-agents/mlagents/trainers/trainer.py (22 lines changed)
- docs/Training-RewardSignals.md (114 lines changed)
- ml-agents/mlagents/trainers/tests/mock_brain.py (92 lines changed)
- ml-agents/mlagents/trainers/tests/test_reward_signals.py (156 lines changed)
- ml-agents/mlagents/trainers/components/__init__.py (0 lines changed)
- ml-agents/mlagents/trainers/components/reward_signals/__init__.py (1 line changed)
- ml-agents/mlagents/trainers/components/reward_signals/curiosity/__init__.py (1 line changed)
- ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py (179 lines changed)
- ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py (178 lines changed)
- ml-agents/mlagents/trainers/components/reward_signals/extrinsic/__init__.py (1 line changed)
- ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py (42 lines changed)
- ml-agents/mlagents/trainers/components/reward_signals/reward_signal.py (68 lines changed)
- ml-agents/mlagents/trainers/components/reward_signals/reward_signal_factory.py (43 lines changed)
docs/Training-RewardSignals.md:

# Reward Signals

In reinforcement learning, the end goal for the Agent is to discover a behavior (a Policy)
that maximizes a reward. Typically, a reward is defined by your environment, and corresponds
to reaching some goal. These are what we refer to as "extrinsic" rewards, as they are defined
external to the learning algorithm.

Rewards, however, can also be defined outside of the environment, to encourage the agent to
behave in certain ways, or to aid the learning of the true extrinsic reward. We refer to these
rewards as "intrinsic" reward signals. The total reward that the agent will learn to maximize can
be a mix of extrinsic and intrinsic reward signals.

ML-Agents allows reward signals to be defined in a modular way, and we provide two reward
signals that can be mixed and matched to help shape your agent's behavior. The `extrinsic` Reward
Signal represents the rewards defined in your environment, and is enabled by default.
The `curiosity` reward signal helps your agent explore when extrinsic rewards are sparse.

## Enabling Reward Signals

Reward signals, like other hyperparameters, are defined in the trainer config `.yaml` file. An
example is provided in `config/trainer_config.yaml`. To enable a reward signal, add it to the
`reward_signals:` section under the brain name. For instance, to enable the extrinsic signal
in addition to a small curiosity reward, you would define your `reward_signals` as follows:

```yaml
reward_signals:
    extrinsic:
        strength: 1.0
        gamma: 0.99
    curiosity:
        strength: 0.01
        gamma: 0.99
        encoding_size: 128
```

Each reward signal should define at least two parameters, `strength` and `gamma`, in addition
to any class-specific hyperparameters. Note that to remove a reward signal, you should delete
its entry entirely from `reward_signals`. At least one reward signal should be left defined
at all times.
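
To make the weighting concrete, here is a minimal sketch (not the trainer's actual code) of how
strength-scaled rewards from several signals combine into the total reward the agent learns to
maximize; the helper function and the sample numbers are illustrative only:

```python
# Minimal sketch: mixing reward signals by their configured strength.
# `signals` maps a signal name to (strength, raw_reward_per_agent); this helper
# is illustrative, not part of the ML-Agents API.
import numpy as np

def total_scaled_reward(signals):
    """Sum of strength-weighted rewards across all configured signals."""
    return sum(strength * np.asarray(raw) for strength, raw in signals.values())

rewards = {
    "extrinsic": (1.0, [0.0, 1.0, 0.0]),   # environment reward, strength 1.0
    "curiosity": (0.01, [0.8, 0.2, 0.5]),  # intrinsic reward, strength 0.01
}
print(total_scaled_reward(rewards))  # [0.008 1.002 0.005]
```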

## Reward Signal Types

### The Extrinsic Reward Signal

The `extrinsic` reward signal is simply the reward given by the
[environment](Learning-Environment-Design.md). Remove it to force the agent
to ignore the environment reward.

#### Strength

`strength` is the factor by which to multiply the raw
reward. Typical ranges will vary depending on the reward signal.

Typical Range: `1.0`

#### Gamma

`gamma` corresponds to the discount factor for future rewards. This can be
thought of as how far into the future the agent should care about possible
rewards. In situations where the agent should be acting in the present in order
to prepare for rewards in the distant future, this value should be large. In
cases where rewards are more immediate, it can be smaller.

Typical Range: `0.8` - `0.995`

### The Curiosity Reward Signal

The `curiosity` Reward Signal enables the Intrinsic Curiosity Module. This is an implementation
of the approach described in "Curiosity-driven Exploration by Self-supervised Prediction"
by Pathak, et al. It trains two networks:
* an inverse model, which takes the current and next observation of the agent, encodes them, and
  uses the encoding to predict the action that was taken between the observations
* a forward model, which takes the encoded current observation and action, and predicts the
  next encoded observation.

The loss of the forward model (the difference between the predicted and actual encoded
observations) is used as the intrinsic reward, so the more surprised the model is, the larger
the reward will be.
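
As a minimal numeric sketch of that idea (the real TensorFlow implementation lives in
`CuriosityModel.create_forward_model`, later in this commit), the intrinsic reward for a
transition is half the squared error between the predicted and actual encodings of the next
observation:

```python
import numpy as np

def curiosity_reward(predicted_encoding, actual_encoding):
    """Intrinsic reward = 0.5 * squared prediction error of the forward model."""
    diff = np.asarray(predicted_encoding) - np.asarray(actual_encoding)
    return 0.5 * np.sum(diff ** 2, axis=-1)

# A well-predicted (familiar) transition earns little reward; a surprising one earns more.
print(curiosity_reward([0.1, 0.2], [0.1, 0.2]))   # 0.0
print(curiosity_reward([0.1, 0.2], [0.9, -0.5]))  # 0.5 * (0.64 + 0.49) = 0.565
```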

For more information, see
* https://arxiv.org/abs/1705.05363
* https://pathak22.github.io/noreward-rl/
* https://blogs.unity3d.com/2018/06/26/solving-sparse-reward-tasks-with-curiosity/

#### Strength

In this case, `strength` corresponds to the magnitude of the curiosity reward generated
by the intrinsic curiosity module. This should be scaled so that it is large enough
not to be overwhelmed by the extrinsic reward signal in the environment, but likewise
not so large that it overwhelms the extrinsic reward signal.

Typical Range: `0.001` - `0.1`

#### Gamma

`gamma` corresponds to the discount factor for future rewards.

Typical Range: `0.8` - `0.995`

#### Encoding Size

`encoding_size` corresponds to the size of the encoding used by the intrinsic curiosity model.
This value should be small enough to encourage the ICM to compress the original
observation, but not so small that it cannot learn to distinguish between expected and
actual observations.

Default Value: `64`
Typical Range: `64` - `256`

#### Learning Rate

`learning_rate` is the learning rate used to update the intrinsic curiosity module.
This should typically be decreased if training is unstable and the curiosity loss is unstable.

Default Value: `3e-4`
Typical Range: `1e-5` - `1e-3`
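
For reference, here is a curiosity entry with every hyperparameter documented above, written as
a Python dict in the same shape the test fixtures below use; the particular values are
illustrative, picked from the typical ranges:

```python
# Illustrative reward-signal settings, mirroring the dict form used by the test
# fixtures in test_reward_signals.py. In the trainer config .yaml these would sit
# under the brain name's `reward_signals:` section.
curiosity_config = {
    "curiosity": {
        "strength": 0.01,       # scale of the intrinsic reward
        "gamma": 0.99,          # discount factor for the intrinsic reward
        "encoding_size": 128,   # size of the ICM state encoding
        "learning_rate": 3e-4,  # learning rate of the ICM update
    }
}
```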

ml-agents/mlagents/trainers/tests/mock_brain.py:

import unittest.mock as mock
import pytest

import numpy as np


def create_mock_brainparams(
    number_visual_observations=0,
    num_stacked_vector_observations=1,
    vector_action_space_type="continuous",
    vector_observation_space_size=3,
    vector_action_space_size=None,
):
    """
    Creates a mock BrainParameters object with parameters.
    """
    # Avoid using mutable object as default param
    if vector_action_space_size is None:
        vector_action_space_size = [2]
    mock_brain = mock.Mock()
    mock_brain.return_value.number_visual_observations = number_visual_observations
    mock_brain.return_value.num_stacked_vector_observations = (
        num_stacked_vector_observations
    )
    mock_brain.return_value.vector_action_space_type = vector_action_space_type
    mock_brain.return_value.vector_observation_space_size = (
        vector_observation_space_size
    )
    camrez = {"blackAndWhite": False, "height": 84, "width": 84}
    mock_brain.return_value.camera_resolutions = [camrez] * number_visual_observations
    mock_brain.return_value.vector_action_space_size = vector_action_space_size
    return mock_brain()


def create_mock_braininfo(
    num_agents=1,
    num_vector_observations=0,
    num_vis_observations=0,
    num_vector_acts=2,
    discrete=False,
):
    """
    Creates a mock BrainInfo with observations. Imitates constant
    vector/visual observations, rewards, dones, and agents.

    :int num_agents: Number of "agents" to imitate in your BrainInfo values.
    :int num_vector_observations: Number of vector observations in your observation space
    :int num_vis_observations: Number of visual observations in your observation space
    :int num_vector_acts: Number of actions in your action space
    :bool discrete: Whether or not action space is discrete
    """
    mock_braininfo = mock.Mock()

    mock_braininfo.return_value.visual_observations = num_vis_observations * [
        np.ones((num_agents, 84, 84, 3))
    ]
    mock_braininfo.return_value.vector_observations = np.array(
        num_agents * [num_vector_observations * [1]]
    )
    if discrete:
        mock_braininfo.return_value.previous_vector_actions = np.array(
            num_agents * [1 * [0.5]]
        )
        mock_braininfo.return_value.action_masks = np.array(
            num_agents * [num_vector_acts * [1.0]]
        )
    else:
        mock_braininfo.return_value.previous_vector_actions = np.array(
            num_agents * [num_vector_acts * [0.5]]
        )
    mock_braininfo.return_value.memories = np.ones((num_agents, 8))
    mock_braininfo.return_value.rewards = num_agents * [1.0]
    mock_braininfo.return_value.local_done = num_agents * [False]
    mock_braininfo.return_value.text_observations = num_agents * [""]
    mock_braininfo.return_value.agents = range(0, num_agents)
    return mock_braininfo()


def setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo):
    """
    Takes a mock UnityEnvironment and adds the appropriate properties, defined by the mock
    BrainParameters and BrainInfo.

    :Mock mock_env: A mock UnityEnvironment, usually empty.
    :Mock mock_brain: A mock Brain object that specifies the params of this environment.
    :Mock mock_braininfo: A mock BrainInfo object that will be returned at each step and reset.
    """
    mock_env.return_value.academy_name = "MockAcademy"
    mock_env.return_value.brains = {"MockBrain": mock_brain}
    mock_env.return_value.external_brain_names = ["MockBrain"]
    mock_env.return_value.brain_names = ["MockBrain"]
    mock_env.return_value.reset.return_value = {"MockBrain": mock_braininfo}
    mock_env.return_value.step.return_value = {"MockBrain": mock_braininfo}
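
A compressed sketch of how these helpers are intended to be wired together (it mirrors the
`create_ppo_policy_mock` helper in `test_reward_signals.py` below; the sizes are arbitrary):

```python
# Illustrative only: patch UnityEnvironment, wire up the mocks, and reset.
import unittest.mock as mock
import mlagents.trainers.tests.mock_brain as mb

with mock.patch("mlagents.envs.UnityEnvironment") as mock_env:
    brain = mb.create_mock_brainparams(vector_observation_space_size=8)
    braininfo = mb.create_mock_braininfo(num_agents=12, num_vector_observations=8)
    mb.setup_mock_unityenvironment(mock_env, brain, braininfo)
    env = mock_env()
    brain_infos = env.reset()          # {"MockBrain": braininfo}
    info = brain_infos["MockBrain"]    # constant observations/rewards for 12 agents
```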

ml-agents/mlagents/trainers/tests/test_reward_signals.py:

import unittest.mock as mock
import pytest
import mlagents.trainers.tests.mock_brain as mb

import numpy as np
import tensorflow as tf
import yaml
import os

from mlagents.trainers.ppo.models import PPOModel
from mlagents.trainers.ppo.trainer import discount_rewards
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.envs import UnityEnvironment
from mlagents.envs.mock_communicator import MockCommunicator


@pytest.fixture
def dummy_config():
    return yaml.safe_load(
        """
        trainer: ppo
        batch_size: 32
        beta: 5.0e-3
        buffer_size: 512
        epsilon: 0.2
        hidden_units: 128
        lambd: 0.95
        learning_rate: 3.0e-4
        max_steps: 5.0e4
        normalize: true
        num_epoch: 5
        num_layers: 2
        time_horizon: 64
        sequence_length: 64
        summary_freq: 1000
        use_recurrent: false
        memory_size: 8
        curiosity_strength: 0.0
        curiosity_enc_size: 1
        reward_signals:
            extrinsic:
                strength: 1.0
                gamma: 0.99
        """
    )


@pytest.fixture
def curiosity_dummy_config():
    return {"curiosity": {"strength": 0.1, "gamma": 0.9, "encoding_size": 128}}


def create_ppo_policy_mock(
    mock_env, dummy_config, reward_signal_config, use_rnn, use_discrete, use_visual
):

    if not use_visual:
        mock_brain = mb.create_mock_brainparams(
            vector_action_space_type="discrete" if use_discrete else "continuous",
            vector_action_space_size=[2],
            vector_observation_space_size=8,
        )
        mock_braininfo = mb.create_mock_braininfo(
            num_agents=12,
            num_vector_observations=8,
            num_vector_acts=2,
            discrete=use_discrete,
        )
    else:
        mock_brain = mb.create_mock_brainparams(
            vector_action_space_type="discrete" if use_discrete else "continuous",
            vector_action_space_size=[2],
            vector_observation_space_size=0,
            number_visual_observations=1,
        )
        mock_braininfo = mb.create_mock_braininfo(
            num_agents=12,
            num_vis_observations=1,
            num_vector_acts=2,
            discrete=use_discrete,
        )
    mb.setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo)
    env = mock_env()

    trainer_parameters = dummy_config
    model_path = env.brain_names[0]
    trainer_parameters["model_path"] = model_path
    trainer_parameters["keep_checkpoints"] = 3
    trainer_parameters["reward_signals"].update(reward_signal_config)
    trainer_parameters["use_recurrent"] = use_rnn
    policy = PPOPolicy(0, mock_brain, trainer_parameters, False, False)
    return env, policy


@mock.patch("mlagents.envs.UnityEnvironment")
def test_curiosity_cc_evaluate(mock_env, dummy_config, curiosity_dummy_config):
    env, policy = create_ppo_policy_mock(
        mock_env, dummy_config, curiosity_dummy_config, False, False, False
    )
    brain_infos = env.reset()
    brain_info = brain_infos[env.brain_names[0]]
    next_brain_info = env.step()[env.brain_names[0]]
    scaled_reward, unscaled_reward = policy.reward_signals["curiosity"].evaluate(
        brain_info, next_brain_info
    )
    assert scaled_reward.shape == (12,)
    assert unscaled_reward.shape == (12,)


@mock.patch("mlagents.envs.UnityEnvironment")
def test_curiosity_dc_evaluate(mock_env, dummy_config, curiosity_dummy_config):
    env, policy = create_ppo_policy_mock(
        mock_env, dummy_config, curiosity_dummy_config, False, True, False
    )
    brain_infos = env.reset()
    brain_info = brain_infos[env.brain_names[0]]
    next_brain_info = env.step()[env.brain_names[0]]
    scaled_reward, unscaled_reward = policy.reward_signals["curiosity"].evaluate(
        brain_info, next_brain_info
    )
    assert scaled_reward.shape == (12,)
    assert unscaled_reward.shape == (12,)


@mock.patch("mlagents.envs.UnityEnvironment")
def test_curiosity_visual_evaluate(mock_env, dummy_config, curiosity_dummy_config):
    env, policy = create_ppo_policy_mock(
        mock_env, dummy_config, curiosity_dummy_config, False, False, True
    )
    brain_infos = env.reset()
    brain_info = brain_infos[env.brain_names[0]]
    next_brain_info = env.step()[env.brain_names[0]]
    scaled_reward, unscaled_reward = policy.reward_signals["curiosity"].evaluate(
        brain_info, next_brain_info
    )
    assert scaled_reward.shape == (12,)
    assert unscaled_reward.shape == (12,)


@mock.patch("mlagents.envs.UnityEnvironment")
def test_curiosity_rnn_evaluate(mock_env, dummy_config, curiosity_dummy_config):
    env, policy = create_ppo_policy_mock(
        mock_env, dummy_config, curiosity_dummy_config, True, False, False
    )
    brain_infos = env.reset()
    brain_info = brain_infos[env.brain_names[0]]
    next_brain_info = env.step()[env.brain_names[0]]
    scaled_reward, unscaled_reward = policy.reward_signals["curiosity"].evaluate(
        brain_info, next_brain_info
    )
    assert scaled_reward.shape == (12,)
    assert unscaled_reward.shape == (12,)


if __name__ == "__main__":
    pytest.main()

ml-agents/mlagents/trainers/components/reward_signals/__init__.py:

from .reward_signal import *

ml-agents/mlagents/trainers/components/reward_signals/curiosity/__init__.py:

from .signal import CuriosityRewardSignal

ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py:

import tensorflow as tf
from mlagents.trainers.models import LearningModel


class CuriosityModel(object):
    def __init__(
        self,
        policy_model: LearningModel,
        encoding_size: int = 128,
        learning_rate: float = 3e-4,
    ):
        """
        Creates the curiosity model for the Curiosity reward Generator
        :param policy_model: The model being used by the learning policy
        :param encoding_size: The size of the encoding for the Curiosity module
        :param learning_rate: The learning rate for the curiosity module
        """
        self.encoding_size = encoding_size
        self.policy_model = policy_model
        encoded_state, encoded_next_state = self.create_curiosity_encoders()
        self.create_inverse_model(encoded_state, encoded_next_state)
        self.create_forward_model(encoded_state, encoded_next_state)
        self.create_loss(learning_rate)

    def create_curiosity_encoders(self):
        """
        Creates state encoders for current and future observations.
        Used for implementation of Curiosity-driven Exploration by Self-supervised Prediction
        See https://arxiv.org/abs/1705.05363 for more details.
        :return: current and future state encoder tensors.
        """
        encoded_state_list = []
        encoded_next_state_list = []

        if self.policy_model.vis_obs_size > 0:
            self.next_visual_in = []
            visual_encoders = []
            next_visual_encoders = []
            for i in range(self.policy_model.vis_obs_size):
                # Create input ops for next (t+1) visual observations.
                next_visual_input = LearningModel.create_visual_input(
                    self.policy_model.brain.camera_resolutions[i],
                    name="next_visual_observation_" + str(i),
                )
                self.next_visual_in.append(next_visual_input)

                # Create the encoder ops for current and next visual input.
                # Note that these encoders are siamese.
                encoded_visual = self.policy_model.create_visual_observation_encoder(
                    self.policy_model.visual_in[i],
                    self.encoding_size,
                    LearningModel.swish,
                    1,
                    "stream_{}_visual_obs_encoder".format(i),
                    False,
                )

                encoded_next_visual = self.policy_model.create_visual_observation_encoder(
                    self.next_visual_in[i],
                    self.encoding_size,
                    LearningModel.swish,
                    1,
                    "stream_{}_visual_obs_encoder".format(i),
                    True,
                )
                visual_encoders.append(encoded_visual)
                next_visual_encoders.append(encoded_next_visual)

            hidden_visual = tf.concat(visual_encoders, axis=1)
            hidden_next_visual = tf.concat(next_visual_encoders, axis=1)
            encoded_state_list.append(hidden_visual)
            encoded_next_state_list.append(hidden_next_visual)

        if self.policy_model.vec_obs_size > 0:
            # Create the encoder ops for current and next vector input.
            # Note that these encoders are siamese.
            # Create input op for next (t+1) vector observation.
            self.next_vector_in = tf.placeholder(
                shape=[None, self.policy_model.vec_obs_size],
                dtype=tf.float32,
                name="next_vector_observation",
            )

            encoded_vector_obs = self.policy_model.create_vector_observation_encoder(
                self.policy_model.vector_in,
                self.encoding_size,
                LearningModel.swish,
                2,
                "vector_obs_encoder",
                False,
            )
            encoded_next_vector_obs = self.policy_model.create_vector_observation_encoder(
                self.next_vector_in,
                self.encoding_size,
                LearningModel.swish,
                2,
                "vector_obs_encoder",
                True,
            )
            encoded_state_list.append(encoded_vector_obs)
            encoded_next_state_list.append(encoded_next_vector_obs)

        encoded_state = tf.concat(encoded_state_list, axis=1)
        encoded_next_state = tf.concat(encoded_next_state_list, axis=1)
        return encoded_state, encoded_next_state

    def create_inverse_model(self, encoded_state, encoded_next_state):
        """
        Creates inverse model TensorFlow ops for Curiosity module.
        Predicts action taken given current and future encoded states.
        :param encoded_state: Tensor corresponding to encoded current state.
        :param encoded_next_state: Tensor corresponding to encoded next state.
        """
        combined_input = tf.concat([encoded_state, encoded_next_state], axis=1)
        hidden = tf.layers.dense(combined_input, 256, activation=LearningModel.swish)
        if self.policy_model.brain.vector_action_space_type == "continuous":
            pred_action = tf.layers.dense(
                hidden, self.policy_model.act_size[0], activation=None
            )
            squared_difference = tf.reduce_sum(
                tf.squared_difference(pred_action, self.policy_model.selected_actions),
                axis=1,
            )
            self.inverse_loss = tf.reduce_mean(
                tf.dynamic_partition(squared_difference, self.policy_model.mask, 2)[1]
            )
        else:
            pred_action = tf.concat(
                [
                    tf.layers.dense(
                        hidden, self.policy_model.act_size[i], activation=tf.nn.softmax
                    )
                    for i in range(len(self.policy_model.act_size))
                ],
                axis=1,
            )
            cross_entropy = tf.reduce_sum(
                -tf.log(pred_action + 1e-10) * self.policy_model.selected_actions,
                axis=1,
            )
            self.inverse_loss = tf.reduce_mean(
                tf.dynamic_partition(cross_entropy, self.policy_model.mask, 2)[1]
            )

    def create_forward_model(self, encoded_state, encoded_next_state):
        """
        Creates forward model TensorFlow ops for Curiosity module.
        Predicts encoded future state based on encoded current state and given action.
        :param encoded_state: Tensor corresponding to encoded current state.
        :param encoded_next_state: Tensor corresponding to encoded next state.
        """
        combined_input = tf.concat(
            [encoded_state, self.policy_model.selected_actions], axis=1
        )
        hidden = tf.layers.dense(combined_input, 256, activation=LearningModel.swish)
        pred_next_state = tf.layers.dense(
            hidden,
            self.encoding_size
            * (
                self.policy_model.vis_obs_size + int(self.policy_model.vec_obs_size > 0)
            ),
            activation=None,
        )
        squared_difference = 0.5 * tf.reduce_sum(
            tf.squared_difference(pred_next_state, encoded_next_state), axis=1
        )
        self.intrinsic_reward = squared_difference
        self.forward_loss = tf.reduce_mean(
            tf.dynamic_partition(squared_difference, self.policy_model.mask, 2)[1]
        )

    def create_loss(self, learning_rate):
        """
        Creates the loss node of the model as well as the update_batch optimizer to update the model.
        :param learning_rate: The learning rate for the optimizer.
        """
        self.loss = 10 * (0.2 * self.forward_loss + 0.8 * self.inverse_loss)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        self.update_batch = optimizer.minimize(self.loss)

ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py:

import numpy as np
from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult
from mlagents.trainers.components.reward_signals.curiosity.model import CuriosityModel
from mlagents.trainers.policy import Policy


class CuriosityRewardSignal(RewardSignal):
    def __init__(
        self,
        policy: Policy,
        strength: float,
        gamma: float,
        encoding_size: int = 128,
        learning_rate: float = 3e-4,
        num_epoch: int = 3,
    ):
        """
        Creates the Curiosity reward generator
        :param policy: The Learning Policy
        :param encoding_size: The size of the Curiosity encoding
        :param signal_strength: The scaling parameter for the reward. The scaled reward will be the unscaled
        reward multiplied by the strength parameter
        """
        super().__init__(policy, strength, gamma)
        self.model = CuriosityModel(
            policy.model, encoding_size=encoding_size, learning_rate=learning_rate
        )
        self.num_epoch = num_epoch
        self.update_dict = {
            "forward_loss": self.model.forward_loss,
            "inverse_loss": self.model.inverse_loss,
            "update": self.model.update_batch,
        }
        self.has_updated = False

    def evaluate(self, current_info, next_info):
        """
        Evaluates the reward for the agents present in current_info given the next_info
        :param current_info: The current BrainInfo.
        :param next_info: The BrainInfo from the next timestep.
        :return: a RewardSignalResult of (scaled intrinsic reward, unscaled intrinsic reward) provided by the generator
        """
        if len(current_info.agents) == 0:
            return []

        feed_dict = {
            self.policy.model.batch_size: len(next_info.vector_observations),
            self.policy.model.sequence_length: 1,
        }
        feed_dict = self.policy.fill_eval_dict(feed_dict, brain_info=current_info)
        if self.policy.use_continuous_act:
            feed_dict[
                self.policy.model.selected_actions
            ] = next_info.previous_vector_actions
        else:
            feed_dict[
                self.policy.model.action_holder
            ] = next_info.previous_vector_actions
        for i in range(self.policy.model.vis_obs_size):
            feed_dict[self.model.next_visual_in[i]] = next_info.visual_observations[i]
        if self.policy.use_vec_obs:
            feed_dict[self.model.next_vector_in] = next_info.vector_observations
        if self.policy.use_recurrent:
            if current_info.memories.shape[1] == 0:
                current_info.memories = self.policy.make_empty_memory(
                    len(current_info.agents)
                )
            feed_dict[self.policy.model.memory_in] = current_info.memories
        unscaled_reward = self.policy.sess.run(
            self.model.intrinsic_reward, feed_dict=feed_dict
        )
        scaled_reward = np.clip(
            unscaled_reward * float(self.has_updated) * self.strength, 0, 1
        )
        return RewardSignalResult(scaled_reward, unscaled_reward)

    @classmethod
    def check_config(cls, config_dict):
        """
        Checks the config and throw an exception if a hyperparameter is missing. Curiosity requires strength,
        gamma, and encoding size at minimum.
        """
        param_keys = ["strength", "gamma", "encoding_size"]
        super().check_config(config_dict, param_keys)

    def update(self, update_buffer, num_sequences):
        """
        Updates Curiosity model using training buffer. Divides training buffer into mini batches and performs
        gradient descent.
        :param update_buffer: Update buffer from which to pull data from.
        :param num_sequences: Number of sequences in the update buffer.
        :return: Dict of stats that should be reported to Tensorboard.
        """
        forward_total, inverse_total = [], []
        for _ in range(self.num_epoch):
            update_buffer.shuffle()
            buffer = update_buffer
            for l in range(len(update_buffer["actions"]) // num_sequences):
                start = l * num_sequences
                end = (l + 1) * num_sequences
                run_out_curio = self._update_batch(
                    buffer.make_mini_batch(start, end), num_sequences
                )
                inverse_total.append(run_out_curio["inverse_loss"])
                forward_total.append(run_out_curio["forward_loss"])

        update_stats = {
            "Losses/Curiosity Forward Loss": np.mean(forward_total),
            "Losses/Curiosity Inverse Loss": np.mean(inverse_total),
        }
        return update_stats

    def _update_batch(self, mini_batch, num_sequences):
        """
        Updates model using buffer.
        :param num_sequences: Number of trajectories in batch.
        :param mini_batch: Experience batch.
        :return: Output from update process.
        """
        feed_dict = {
            self.policy.model.batch_size: num_sequences,
            self.policy.model.sequence_length: self.policy.sequence_length,
            self.policy.model.mask_input: mini_batch["masks"].flatten(),
            self.policy.model.advantage: mini_batch["advantages"].reshape([-1, 1]),
            self.policy.model.all_old_log_probs: mini_batch["action_probs"].reshape(
                [-1, sum(self.policy.model.act_size)]
            ),
        }
        if self.policy.use_continuous_act:
            feed_dict[self.policy.model.output_pre] = mini_batch["actions_pre"].reshape(
                [-1, self.policy.model.act_size[0]]
            )
            feed_dict[self.policy.model.epsilon] = mini_batch[
                "random_normal_epsilon"
            ].reshape([-1, self.policy.model.act_size[0]])
        else:
            feed_dict[self.policy.model.action_holder] = mini_batch["actions"].reshape(
                [-1, len(self.policy.model.act_size)]
            )
            if self.policy.use_recurrent:
                feed_dict[self.policy.model.prev_action] = mini_batch[
                    "prev_action"
                ].reshape([-1, len(self.policy.model.act_size)])
            feed_dict[self.policy.model.action_masks] = mini_batch[
                "action_mask"
            ].reshape([-1, sum(self.policy.brain.vector_action_space_size)])
        if self.policy.use_vec_obs:
            feed_dict[self.policy.model.vector_in] = mini_batch["vector_obs"].reshape(
                [-1, self.policy.vec_obs_size]
            )
            feed_dict[self.model.next_vector_in] = mini_batch["next_vector_in"].reshape(
                [-1, self.policy.vec_obs_size]
            )
        if self.policy.model.vis_obs_size > 0:
            for i, _ in enumerate(self.policy.model.visual_in):
                _obs = mini_batch["visual_obs%d" % i]
                if self.policy.sequence_length > 1 and self.policy.use_recurrent:
                    (_batch, _seq, _w, _h, _c) = _obs.shape
                    feed_dict[self.policy.model.visual_in[i]] = _obs.reshape(
                        [-1, _w, _h, _c]
                    )
                else:
                    feed_dict[self.policy.model.visual_in[i]] = _obs
            for i, _ in enumerate(self.policy.model.visual_in):
                _obs = mini_batch["next_visual_obs%d" % i]
                if self.policy.sequence_length > 1 and self.policy.use_recurrent:
                    (_batch, _seq, _w, _h, _c) = _obs.shape
                    feed_dict[self.model.next_visual_in[i]] = _obs.reshape(
                        [-1, _w, _h, _c]
                    )
                else:
                    feed_dict[self.model.next_visual_in[i]] = _obs
        if self.policy.use_recurrent:
            mem_in = mini_batch["memory"][:, 0, :]
            feed_dict[self.policy.model.memory_in] = mem_in
        self.has_updated = True
        run_out = self.policy._execute_model(feed_dict, self.update_dict)
        return run_out

ml-agents/mlagents/trainers/components/reward_signals/extrinsic/__init__.py:

from .signal import ExtrinsicRewardSignal

ml-agents/mlagents/trainers/components/reward_signals/extrinsic/signal.py:

import numpy as np

from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult
from mlagents.trainers.policy import Policy


class ExtrinsicRewardSignal(RewardSignal):
    def __init__(self, policy: Policy, strength: float, gamma: float):
        """
        The extrinsic reward generator. Returns the reward received by the environment
        :param policy: The Policy object (e.g. PPOPolicy) that this Reward Signal will apply to.
        :param strength: The strength of the reward. The reward's raw value will be multiplied by this value.
        :param gamma: The time discounting factor used for this reward.
        :return: An ExtrinsicRewardSignal object.
        """
        super().__init__(policy, strength, gamma)

    @classmethod
    def check_config(cls, config_dict):
        """
        Checks the config and throw an exception if a hyperparameter is missing. Extrinsic requires strength and gamma
        at minimum.
        """
        param_keys = ["strength", "gamma"]
        super().check_config(config_dict, param_keys)

    def evaluate(self, current_info, next_info):
        """
        Evaluates the reward for the agents present in current_info given the next_info
        :param current_info: The current BrainInfo.
        :param next_info: The BrainInfo from the next timestep.
        :return: a RewardSignalResult of (scaled intrinsic reward, unscaled intrinsic reward) provided by the generator
        """
        unscaled_reward = np.array(next_info.rewards)
        scaled_reward = self.strength * unscaled_reward
        return RewardSignalResult(scaled_reward, unscaled_reward)

    def update(self, update_buffer, num_sequences):
        """
        This method does nothing, as there is nothing to update.
        """
        return {}
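
For intuition, a quick sketch of what `evaluate` returns, using a bare `SimpleNamespace` stub in
place of a real BrainInfo and no policy (the extrinsic signal reads only `next_info.rewards`):

```python
# Quick check of the scaling behavior, using a stub instead of a real BrainInfo.
from types import SimpleNamespace

signal = ExtrinsicRewardSignal(policy=None, strength=0.5, gamma=0.99)
next_info = SimpleNamespace(rewards=[1.0, 0.0, -1.0])

scaled, unscaled = signal.evaluate(current_info=None, next_info=next_info)
print(scaled)    # [ 0.5  0.  -0.5]
print(unscaled)  # [ 1.  0. -1.]
```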

ml-agents/mlagents/trainers/components/reward_signals/reward_signal.py:

import logging
from mlagents.trainers.trainer import UnityTrainerException
from mlagents.trainers.policy import Policy
from collections import namedtuple
import numpy as np
import abc

import tensorflow as tf

logger = logging.getLogger("mlagents.trainers")

RewardSignalResult = namedtuple(
    "RewardSignalResult", ["scaled_reward", "unscaled_reward"]
)


class RewardSignal(abc.ABC):
    def __init__(self, policy: Policy, strength: float, gamma: float):
        """
        Initializes a reward signal. At minimum, you must pass in the policy it is being applied to,
        the reward strength, and the gamma (discount factor.)
        :param policy: The Policy object (e.g. PPOPolicy) that this Reward Signal will apply to.
        :param strength: The strength of the reward. The reward's raw value will be multiplied by this value.
        :param gamma: The time discounting factor used for this reward.
        :return: A RewardSignal object.
        """
        class_name = self.__class__.__name__
        short_name = class_name.replace("RewardSignal", "")
        self.stat_name = f"Policy/{short_name} Reward"
        self.value_name = f"Policy/{short_name} Value Estimate"
        self.gamma = gamma
        self.policy = policy
        self.strength = strength

    def evaluate(self, current_info, next_info):
        """
        Evaluates the reward for the agents present in current_info given the next_info
        :param current_info: The current BrainInfo.
        :param next_info: The BrainInfo from the next timestep.
        :return: a RewardSignalResult of (scaled intrinsic reward, unscaled intrinsic reward) provided by the generator
        """
        return (
            self.strength * np.zeros(len(current_info.agents)),
            np.zeros(len(current_info.agents)),
        )

    def update(self, update_buffer, n_sequences):
        """
        If the reward signal has an internal model (e.g. GAIL or Curiosity), update that model.
        :param update_buffer: An AgentBuffer that contains the live data from which to update.
        :param n_sequences: The number of sequences in the training buffer.
        :return: A dict of {"Stat Name": stat} to be added to Tensorboard
        """
        return {}

    @classmethod
    def check_config(cls, config_dict, param_keys=None):
        """
        Check the config dict, and throw an error if there are missing hyperparameters.
        """
        param_keys = param_keys or []
        for k in param_keys:
            if k not in config_dict:
                raise UnityTrainerException(
                    "The hyper-parameter {0} could not be found for {1}.".format(
                        k, cls.__name__
                    )
                )
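
New intrinsic rewards plug into this base class: subclass `RewardSignal`, override `evaluate`
(and `update` if there is a model to train), and declare the required config keys in
`check_config`. A minimal hypothetical example, a constant per-step bonus signal that is not
part of ML-Agents:

```python
import numpy as np


class ConstantRewardSignal(RewardSignal):
    """Hypothetical signal that pays every agent a fixed bonus each step."""

    def __init__(self, policy, strength, gamma, bonus=1.0):
        super().__init__(policy, strength, gamma)
        self.bonus = bonus

    @classmethod
    def check_config(cls, config_dict):
        super().check_config(config_dict, ["strength", "gamma"])

    def evaluate(self, current_info, next_info):
        unscaled = self.bonus * np.ones(len(current_info.agents))
        return RewardSignalResult(self.strength * unscaled, unscaled)
```

To be selectable from the YAML config, such a class would also need an entry in the
`NAME_TO_CLASS` dict of the factory shown next.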

ml-agents/mlagents/trainers/components/reward_signals/reward_signal_factory.py:

import logging
from typing import Any, Dict, Type

from mlagents.trainers.trainer import UnityTrainerException
from mlagents.trainers.components.reward_signals.reward_signal import RewardSignal
from mlagents.trainers.components.reward_signals.extrinsic.signal import (
    ExtrinsicRewardSignal,
)
from mlagents.trainers.components.reward_signals.curiosity.signal import (
    CuriosityRewardSignal,
)
from mlagents.trainers.policy import Policy

logger = logging.getLogger("mlagents.trainers")


NAME_TO_CLASS: Dict[str, Type[RewardSignal]] = {
    "extrinsic": ExtrinsicRewardSignal,
    "curiosity": CuriosityRewardSignal,
}


def create_reward_signal(
    policy: Policy, name: str, config_entry: Dict[str, Any]
) -> RewardSignal:
    """
    Creates a reward signal class based on the name and config entry provided as a dict.
    :param policy: The policy class which the reward will be applied to.
    :param name: The name of the reward signal
    :param config_entry: The config entries for that reward signal
    :return: The reward signal class instantiated
    """
    rcls = NAME_TO_CLASS.get(name)
    if not rcls:
        raise UnityTrainerException("Unknown reward signal type {0}".format(name))
    rcls.check_config(config_entry)
    try:
        class_inst = rcls(policy, **config_entry)
    except TypeError:
        raise UnityTrainerException(
            "Unknown parameters given for reward signal {0}".format(name)
        )
    return class_inst
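
A sketch of how the factory turns the YAML `reward_signals` section into instantiated signal
objects, assuming `policy` is an already-constructed PPOPolicy; each config entry is forwarded
as keyword arguments:

```python
# Illustrative use of the factory; `policy` is assumed to be a built PPOPolicy.
reward_signal_configs = {
    "extrinsic": {"strength": 1.0, "gamma": 0.99},
    "curiosity": {"strength": 0.01, "gamma": 0.99, "encoding_size": 128},
}

reward_signals = {
    name: create_reward_signal(policy, name, entry)
    for name, entry in reward_signal_configs.items()
}
# A misspelled name or a missing key raises UnityTrainerException before training starts.
```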