浏览代码

Add Soft Actor-Critic as trainer option (#2341)

* Add Soft Actor-Critic model, trainer, and policy and sac_trainer_config.yaml
* Add documentation for SAC and tweak PPO documentation to reference the new pages.
* Add tests for SAC, change simple_rl test to run both PPO and SAC.
/develop-gpu-test
GitHub 5 年前
当前提交
6a81a2f4
共有 18 个文件被更改,包括 2987 次插入131 次删除
  1. 1
      README.md
  2. 49
      docs/Getting-Started-with-Balance-Ball.md
  3. 1
      docs/Readme.md
  4. 47
      docs/Training-ML-Agents.md
  5. 9
      docs/Training-PPO.md
  6. 3
      ml-agents/mlagents/trainers/__init__.py
  7. 113
      ml-agents/mlagents/trainers/tests/test_bcmodule.py
  8. 123
      ml-agents/mlagents/trainers/tests/test_reward_signals.py
  9. 90
      ml-agents/mlagents/trainers/tests/test_simple_rl.py
  10. 13
      ml-agents/mlagents/trainers/trainer_util.py
  11. 276
      config/sac_trainer_config.yaml
  12. 330
      docs/Training-SAC.md
  13. 399
      ml-agents/mlagents/trainers/tests/test_sac.py
  14. 3
      ml-agents/mlagents/trainers/sac/__init__.py
  15. 1001
      ml-agents/mlagents/trainers/sac/models.py
  16. 320
      ml-agents/mlagents/trainers/sac/policy.py
  17. 340
      ml-agents/mlagents/trainers/sac/trainer.py

1
README.md


* Unity environment control from Python
* 10+ sample Unity environments
* Two deep reinforcement learning algorithms, [Proximal Policy Optimization](docs/Training-PPO.md) (PPO) and [Soft Actor-Critic](docs/Training-SAC.md) (SAC)
* Support for multiple environment configurations and training scenarios
* Train memory-enhanced agents using deep reinforcement learning
* Easily definable Curriculum Learning and Generalization scenarios

49
docs/Getting-Started-with-Balance-Ball.md


The Academy object for the scene is placed on the Ball3DAcademy GameObject. When
you look at an Academy component in the inspector, you can see several
properties that control how the environment works.
properties that control how the environment works.
The **Training Configuration** and **Inference Configuration** properties
set the graphics and timescale properties for the Unity application.
The **Training Configuration** and **Inference Configuration** properties
set the graphics and timescale properties for the Unity application.
**Inference Configuration** when not training. (*Inference* means that the
Agent is using a trained model or heuristics or direct control — in other
**Inference Configuration** when not training. (*Inference* means that the
Agent is using a trained model or heuristics or direct control — in other
words, whenever **not** training.)
Typically, you would set a low graphics quality and timescale to greater `1.0` for the **Training
Configuration** and a high graphics quality and timescale to `1.0` for the

Brain, but would act independently. The Brain settings tell you quite a bit about how
an Agent works.
You can create new Brain assets by selecting `Assets ->
Create -> ML-Agents -> Brain`. There are 3 types of Brains.
You can create new Brain assets by selecting `Assets ->
Create -> ML-Agents -> Brain`. There are 3 types of Brains.
The **Learning Brain** is a Brain that uses a trained neural network to make decisions.
When the `Control` box is checked in the Brains property under the **Broadcast Hub** in the Academy, the external process that is training the neural network will take over decision making for the agents
and ultimately generate a trained neural network. You can also use the

Now that we have an environment, we can perform the training.
### Training with PPO
### Training with Deep Reinforcement Learning
In order to train an agent to correctly balance the ball, we provide two
deep reinforcement learning algorithms.
The default algorithm is Proximal Policy Optimization (PPO). This
is a method that has been shown to be more general purpose and stable
than many other RL algorithms. For more information on PPO, OpenAI
has a [blog post](https://blog.openai.com/openai-baselines-ppo/)
explaining it, and [our page](Training-PPO.md) for how to use it in training.
In order to train an agent to correctly balance the ball, we will use a
Reinforcement Learning algorithm called Proximal Policy Optimization (PPO). This
is a method that has been shown to be safe, efficient, and more general purpose
than many other RL algorithms, as such we have chosen it as the example
algorithm for use with ML-Agents toolkit. For more information on PPO, OpenAI
has a recent [blog post](https://blog.openai.com/openai-baselines-ppo/)
explaining it.
We also provide Soft-Actor Critic, an off-policy algorithm that
has been shown to be both stable and sample-efficient.
For more information on SAC, see UC Berkeley's
[blog post](https://bair.berkeley.edu/blog/2018/12/14/sac/) and
[our page](Training-SAC.md) for more guidance on when to use SAC vs. PPO. To
use SAC to train Balance Ball, replace all references to `config/trainer_config.yaml`
with `config/sac_trainer_config.yaml` below.
To train the agents within the Ball Balance environment, we will be using the
Python package. We have provided a convenient command called `mlagents-learn`
To train the agents within the Balance Ball environment, we will be using the
ML-Agents Python package. We have provided a convenient command called `mlagents-learn`
which accepts arguments used to configure both training and inference phases.
We can use `run_id` to identify the experiment and create a folder where the

Once the training process completes, and the training process saves the model
(denoted by the `Saved Model` message) you can add it to the Unity project and
use it with Agents having a **Learning Brain**.
__Note:__ Do not just close the Unity Window once the `Saved Model` message appears.
Either wait for the training process to close the window or press Ctrl+C at the
command-line prompt. If you close the window manually, the `.nn` file
__Note:__ Do not just close the Unity Window once the `Saved Model` message appears.
Either wait for the training process to close the window or press Ctrl+C at the
command-line prompt. If you close the window manually, the `.nn` file
containing the trained model is not exported into the ml-agents folder.
### Embedding the trained model into Unity

1
docs/Readme.md


* [Training ML-Agents](Training-ML-Agents.md)
* [Training with Proximal Policy Optimization](Training-PPO.md)
* [Training with Soft Actor-Critic](Training-SAC.md)
* [Training with Curriculum Learning](Training-Curriculum-Learning.md)
* [Training with Imitation Learning](Training-Imitation-Learning.md)
* [Training with LSTM](Feature-Memory.md)

47
docs/Training-ML-Agents.md


### Training Config File
The training config files `config/trainer_config.yaml`,
`config/online_bc_config.yaml` and `config/offline_bc_config.yaml` specifies the
training method, the hyperparameters, and a few additional values to use during
training with PPO, online and offline BC. These files are divided into sections.
The training config files `config/trainer_config.yaml`, `config/sac_trainer_config.yaml`,
`config/gail_config.yaml`, `config/online_bc_config.yaml` and `config/offline_bc_config.yaml`
specifies the training method, the hyperparameters, and a few additional values to use when
training with PPO, SAC, GAIL (with PPO), and online and offline BC. These files are divided into sections.
The **default** section defines the default values for all the available
settings. You can also add new sections to override these defaults to train
specific Brains. Name each of these override sections after the GameObject

| **Setting** | **Description** | **Applies To Trainer\*** |
| :------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :----------------------- |
| batch_size | The number of experiences in each iteration of gradient descent. | PPO, BC |
| batch_size | The number of experiences in each iteration of gradient descent. | PPO, SAC, BC |
| buffer_size | The number of experiences to collect before updating the policy model. | PPO |
| buffer_size | The number of experiences to collect before updating the policy model. In SAC, the max size of the experience buffer. | PPO, SAC |
| buffer_init_steps | The number of experiences to collect into the buffer before updating the policy model. | SAC |
| hidden_units | The number of units in the hidden layers of the neural network. | PPO, BC |
| hidden_units | The number of units in the hidden layers of the neural network. | PPO, SAC, BC |
| init_entcoef | How much the agent should explore in the beginning of training. | SAC |
| learning_rate | The initial learning rate for gradient descent. | PPO, BC |
| max_steps | The maximum number of simulation steps to run during a training session. | PPO, BC |
| memory_size | The size of the memory an agent must keep. Used for training with a recurrent neural network. See [Using Recurrent Neural Networks](Feature-Memory.md). | PPO, BC |
| normalize | Whether to automatically normalize observations. | PPO |
| learning_rate | The initial learning rate for gradient descent. | PPO, SAC, BC |
| max_steps | The maximum number of simulation steps to run during a training session. | PPO, SAC, BC |
| memory_size | The size of the memory an agent must keep. Used for training with a recurrent neural network. See [Using Recurrent Neural Networks](Feature-Memory.md). | PPO, SAC, BC |
| normalize | Whether to automatically normalize observations. | PPO, SAC |
| num_layers | The number of hidden layers in the neural network. | PPO, BC |
| pretraining | Use demonstrations to bootstrap the policy neural network. See [Pretraining Using Demonstrations](Training-PPO.md#optional-pretraining-using-demonstrations). | PPO |
| reward_signals | The reward signals used to train the policy. Enable Curiosity and GAIL here. See [Reward Signals](Reward-Signals.md) for configuration options. | PPO |
| sequence_length | Defines how long the sequences of experiences must be while training. Only used for training with a recurrent neural network. See [Using Recurrent Neural Networks](Feature-Memory.md). | PPO, BC |
| summary_freq | How often, in steps, to save training statistics. This determines the number of data points shown by TensorBoard. | PPO, BC |
| time_horizon | How many steps of experience to collect per-agent before adding it to the experience buffer. | PPO, (online)BC |
| trainer | The type of training to perform: "ppo", "offline_bc" or "online_bc". | PPO, BC |
| use_recurrent | Train using a recurrent neural network. See [Using Recurrent Neural Networks](Feature-Memory.md). | PPO, BC |
| num_layers | The number of hidden layers in the neural network. | PPO, SAC, BC |
| pretraining | Use demonstrations to bootstrap the policy neural network. See [Pretraining Using Demonstrations](Training-PPO.md#optional-pretraining-using-demonstrations). | PPO, SAC |
| reward_signals | The reward signals used to train the policy. Enable Curiosity and GAIL here. See [Reward Signals](Reward-Signals.md) for configuration options. | PPO, SAC, BC |
| save_replay_buffer | Saves the replay buffer when exiting training, and loads it on resume. | SAC |
| sequence_length | Defines how long the sequences of experiences must be while training. Only used for training with a recurrent neural network. See [Using Recurrent Neural Networks](Feature-Memory.md). | PPO, SAC, BC |
| summary_freq | How often, in steps, to save training statistics. This determines the number of data points shown by TensorBoard. | PPO, SAC, BC |
| tau | How aggressively to update the target network used for bootstrapping value estimation in SAC. | SAC |
| time_horizon | How many steps of experience to collect per-agent before adding it to the experience buffer. | PPO, SAC, (online)BC |
| trainer | The type of training to perform: "ppo", "sac", "offline_bc" or "online_bc". | PPO, SAC, BC |
| train_interval | How often to update the agent. | SAC |
| num_update | Number of mini-batches to update the agent with during each update. | SAC |
| use_recurrent | Train using a recurrent neural network. See [Using Recurrent Neural Networks](Feature-Memory.md). | PPO, SAC, BC |
\*PPO = Proximal Policy Optimization, BC = Behavioral Cloning (Imitation)
\*PPO = Proximal Policy Optimization, SAC = Soft Actor-Critic, BC = Behavioral Cloning (Imitation)
* [Training with SAC](Training-SAC.md)
* [Using Recurrent Neural Networks](Feature-Memory.md)
* [Training with Curriculum Learning](Training-Curriculum-Learning.md)
* [Training with Imitation Learning](Training-Imitation-Learning.md)

9
docs/Training-PPO.md


# Training with Proximal Policy Optimization
ML-Agents uses a reinforcement learning technique called
ML-Agents provides an implementation of a reinforcement learning algorithm called
ML-Agents also provides an implementation of
[Soft Actor-Critic (SAC)](https://bair.berkeley.edu/blog/2018/12/14/sac/). SAC tends
to be more _sample-efficient_, i.e. require fewer environment steps,
than PPO, but may spend more time performing model updates. This can produce a large
speedup on heavy or slow environments. Check out how to train with
SAC [here](Training-SAC.md).
To train an agent, you will need to provide the agent one or more reward signals which
the agent should attempt to maximize. See [Reward Signals](Reward-Signals.md)

3
ml-agents/mlagents/trainers/__init__.py


from .ppo.models import *
from .ppo.trainer import *
from .ppo.policy import *
from .sac.models import *
from .sac.trainer import *
from .sac.policy import *
from .exception import *
from .demo_loader import *

113
ml-agents/mlagents/trainers/tests/test_bcmodule.py


import os
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.sac.policy import SACPolicy
@pytest.fixture
def dummy_config():
def ppo_dummy_config():
return yaml.safe_load(
"""
trainer: ppo

)
def create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, use_rnn, demo_file
def sac_dummy_config():
return yaml.safe_load(
"""
trainer: sac
batch_size: 128
buffer_size: 50000
buffer_init_steps: 0
hidden_units: 128
init_entcoef: 1.0
learning_rate: 3.0e-4
max_steps: 5.0e4
memory_size: 256
normalize: false
num_update: 1
train_interval: 1
num_layers: 2
time_horizon: 64
sequence_length: 64
summary_freq: 1000
tau: 0.005
use_recurrent: false
vis_encode_type: default
pretraining:
demo_path: ./demos/ExpertPyramid.demo
strength: 1.0
steps: 10000000
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
"""
)
def create_policy_with_bc_mock(
mock_env, mock_brain, trainer_config, use_rnn, demo_file
trainer_parameters = dummy_config
trainer_parameters["model_path"] = model_path
trainer_parameters["keep_checkpoints"] = 3
trainer_parameters["use_recurrent"] = use_rnn
trainer_parameters["pretraining"]["demo_path"] = (
trainer_config["model_path"] = model_path
trainer_config["keep_checkpoints"] = 3
trainer_config["use_recurrent"] = use_rnn
trainer_config["pretraining"]["demo_path"] = (
policy = PPOPolicy(0, mock_brain, trainer_parameters, False, False)
policy = (
PPOPolicy(0, mock_brain, trainer_config, False, False)
if trainer_config["trainer"] == "ppo"
else SACPolicy(0, mock_brain, trainer_config, False, False)
)
def test_bcmodule_defaults(mock_env, dummy_config):
def test_bcmodule_defaults(mock_env):
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, False, "test.demo"
trainer_config = ppo_dummy_config()
env, policy = create_policy_with_bc_mock(
mock_env, mock_brain, trainer_config, False, "test.demo"
assert policy.bc_module.num_epoch == dummy_config["num_epoch"]
assert policy.bc_module.batch_size == dummy_config["batch_size"]
assert policy.bc_module.num_epoch == trainer_config["num_epoch"]
assert policy.bc_module.batch_size == trainer_config["batch_size"]
dummy_config["pretraining"]["num_epoch"] = 100
dummy_config["pretraining"]["batch_size"] = 10000
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, False, "test.demo"
trainer_config["pretraining"]["num_epoch"] = 100
trainer_config["pretraining"]["batch_size"] = 10000
env, policy = create_policy_with_bc_mock(
mock_env, mock_brain, trainer_config, False, "test.demo"
)
assert policy.bc_module.num_epoch == 100
assert policy.bc_module.batch_size == 10000

# Test with continuous control env and vector actions
@pytest.mark.parametrize(
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
)
def test_bcmodule_update(mock_env, dummy_config):
def test_bcmodule_update(mock_env, trainer_config):
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, False, "test.demo"
env, policy = create_policy_with_bc_mock(
mock_env, mock_brain, trainer_config, False, "test.demo"
)
stats = policy.bc_module.update()
for _, item in stats.items():

# Test with RNN
@pytest.mark.parametrize(
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
)
def test_bcmodule_rnn_update(mock_env, dummy_config):
def test_bcmodule_rnn_update(mock_env, trainer_config):
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, True, "test.demo"
env, policy = create_policy_with_bc_mock(
mock_env, mock_brain, trainer_config, True, "test.demo"
)
stats = policy.bc_module.update()
for _, item in stats.items():

# Test with discrete control and visual observations
@pytest.mark.parametrize(
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
)
def test_bcmodule_dc_visual_update(mock_env, dummy_config):
def test_bcmodule_dc_visual_update(mock_env, trainer_config):
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, False, "testdcvis.demo"
env, policy = create_policy_with_bc_mock(
mock_env, mock_brain, trainer_config, False, "testdcvis.demo"
)
stats = policy.bc_module.update()
for _, item in stats.items():

# Test with discrete control, visual observations and RNN
@pytest.mark.parametrize(
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
)
def test_bcmodule_rnn_dc_update(mock_env, dummy_config):
def test_bcmodule_rnn_dc_update(mock_env, trainer_config):
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, True, "testdcvis.demo"
env, policy = create_policy_with_bc_mock(
mock_env, mock_brain, trainer_config, True, "testdcvis.demo"
)
stats = policy.bc_module.update()
for _, item in stats.items():

123
ml-agents/mlagents/trainers/tests/test_reward_signals.py


from mlagents.trainers.ppo.models import PPOModel
from mlagents.trainers.ppo.trainer import discount_rewards
from mlagents.trainers.ppo.policy import PPOPolicy
from mlagents.trainers.sac.policy import SACPolicy
@pytest.fixture
def dummy_config():
def ppo_dummy_config():
return yaml.safe_load(
"""
trainer: ppo

summary_freq: 1000
use_recurrent: false
memory_size: 8
curiosity_strength: 0.0
curiosity_enc_size: 1
reward_signals:
extrinsic:
strength: 1.0

def sac_dummy_config():
return yaml.safe_load(
"""
trainer: sac
batch_size: 128
buffer_size: 50000
buffer_init_steps: 0
hidden_units: 128
init_entcoef: 1.0
learning_rate: 3.0e-4
max_steps: 5.0e4
memory_size: 256
normalize: false
num_update: 1
train_interval: 1
num_layers: 2
time_horizon: 64
sequence_length: 64
summary_freq: 1000
tau: 0.005
use_recurrent: false
vis_encode_type: default
pretraining:
demo_path: ./demos/ExpertPyramid.demo
strength: 1.0
steps: 10000000
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
"""
)
@pytest.fixture
def gail_dummy_config():
return {

NUM_AGENTS = 12
def create_ppo_policy_mock(
mock_env, dummy_config, reward_signal_config, use_rnn, use_discrete, use_visual
def create_policy_mock(
mock_env, trainer_config, reward_signal_config, use_rnn, use_discrete, use_visual
):
env, mock_brain, _ = mb.setup_mock_env_and_brains(
mock_env,

discrete_action_space=DISCRETE_ACTION_SPACE,
)
trainer_parameters = dummy_config
trainer_parameters = trainer_config
policy = PPOPolicy(0, mock_brain, trainer_parameters, False, False)
if trainer_config["trainer"] == "ppo":
policy = PPOPolicy(0, mock_brain, trainer_parameters, False, False)
else:
policy = SACPolicy(0, mock_brain, trainer_parameters, False, False)
return env, policy

assert type(out) is dict
@pytest.mark.parametrize(
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
)
def test_gail_cc(mock_env, dummy_config, gail_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, gail_dummy_config, False, False, False
def test_gail_cc(mock_env, trainer_config, gail_dummy_config):
env, policy = create_policy_mock(
mock_env, trainer_config, gail_dummy_config, False, False, False
@pytest.mark.parametrize(
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
)
def test_gail_dc_visual(mock_env, dummy_config, gail_dummy_config):
def test_gail_dc_visual(mock_env, trainer_config, gail_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, gail_dummy_config, False, True, True
env, policy = create_policy_mock(
mock_env, trainer_config, gail_dummy_config, False, True, True
@pytest.mark.parametrize(
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
)
def test_gail_rnn(mock_env, dummy_config, gail_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, gail_dummy_config, True, False, False
def test_gail_rnn(mock_env, trainer_config, gail_dummy_config):
env, policy = create_policy_mock(
mock_env, trainer_config, gail_dummy_config, True, False, False
@pytest.mark.parametrize(
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
)
def test_curiosity_cc(mock_env, dummy_config, curiosity_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, curiosity_dummy_config, False, False, False
def test_curiosity_cc(mock_env, trainer_config, curiosity_dummy_config):
env, policy = create_policy_mock(
mock_env, trainer_config, curiosity_dummy_config, False, False, False
@pytest.mark.parametrize(
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
)
def test_curiosity_dc(mock_env, dummy_config, curiosity_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, curiosity_dummy_config, False, True, False
def test_curiosity_dc(mock_env, trainer_config, curiosity_dummy_config):
env, policy = create_policy_mock(
mock_env, trainer_config, curiosity_dummy_config, False, True, False
@pytest.mark.parametrize(
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
)
def test_curiosity_visual(mock_env, dummy_config, curiosity_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, curiosity_dummy_config, False, False, True
def test_curiosity_visual(mock_env, trainer_config, curiosity_dummy_config):
env, policy = create_policy_mock(
mock_env, trainer_config, curiosity_dummy_config, False, False, True
@pytest.mark.parametrize(
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
)
def test_curiosity_rnn(mock_env, dummy_config, curiosity_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, curiosity_dummy_config, True, False, False
def test_curiosity_rnn(mock_env, trainer_config, curiosity_dummy_config):
env, policy = create_policy_mock(
mock_env, trainer_config, curiosity_dummy_config, True, False, False
@pytest.mark.parametrize(
"trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
)
def test_extrinsic(mock_env, dummy_config, curiosity_dummy_config):
env, policy = create_ppo_policy_mock(
mock_env, dummy_config, curiosity_dummy_config, False, False, False
def test_extrinsic(mock_env, trainer_config, curiosity_dummy_config):
env, policy = create_policy_mock(
mock_env, trainer_config, curiosity_dummy_config, False, False, False
)
reward_signal_eval(env, policy, "extrinsic")
reward_signal_update(env, policy, "extrinsic")

90
ml-agents/mlagents/trainers/tests/test_simple_rl.py


pass
def _check_environment_trains(env):
config = """
default:
trainer: ppo
batch_size: 16
beta: 5.0e-3
buffer_size: 64
epsilon: 0.2
hidden_units: 128
lambd: 0.95
learning_rate: 5.0e-3
max_steps: 2500
memory_size: 256
normalize: false
num_epoch: 3
num_layers: 2
time_horizon: 64
sequence_length: 64
summary_freq: 500
use_recurrent: false
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
PPO_CONFIG = """
default:
trainer: ppo
batch_size: 16
beta: 5.0e-3
buffer_size: 64
epsilon: 0.2
hidden_units: 128
lambd: 0.95
learning_rate: 5.0e-3
max_steps: 2500
memory_size: 256
normalize: false
num_epoch: 3
num_layers: 2
time_horizon: 64
sequence_length: 64
summary_freq: 500
use_recurrent: false
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
SAC_CONFIG = """
default:
trainer: sac
batch_size: 32
buffer_size: 10240
buffer_init_steps: 1000
hidden_units: 64
init_entcoef: 0.01
learning_rate: 5.0e-3
max_steps: 2000
memory_size: 256
normalize: false
num_update: 1
train_interval: 1
num_layers: 1
time_horizon: 64
sequence_length: 64
summary_freq: 500
tau: 0.005
use_recurrent: false
curiosity_enc_size: 128
demo_path: None
vis_encode_type: default
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
"""
def _check_environment_trains(env, config):
# Create controller and begin training.
with tempfile.TemporaryDirectory() as dir:
run_id = "id"

meta_curriculum=None,
multi_gpu=False,
)
print(trainers)
tc = TrainerController(
trainers=trainers,

@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_rl(use_discrete):
def test_simple_ppo(use_discrete):
_check_environment_trains(env)
_check_environment_trains(env, PPO_CONFIG)
@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_sac(use_discrete):
env = Simple1DEnvironment(use_discrete=use_discrete)
_check_environment_trains(env, SAC_CONFIG)

13
ml-agents/mlagents/trainers/trainer_util.py


from mlagents.trainers import Trainer
from mlagents.envs.brain import BrainParameters
from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.sac.trainer import SACTrainer
from mlagents.trainers.bc.offline_trainer import OfflineBCTrainer
from mlagents.trainers.bc.online_trainer import OnlineBCTrainer

seed,
run_id,
multi_gpu,
)
elif trainer_parameters_dict[brain_name]["trainer"] == "sac":
trainers[brain_name] = SACTrainer(
external_brains[brain_name],
meta_curriculum.brains_to_curriculums[brain_name].min_lesson_length
if meta_curriculum
else 1,
trainer_parameters_dict[brain_name],
train_model,
load_model,
seed,
run_id,
)
else:
raise UnityEnvironmentException(

276
config/sac_trainer_config.yaml


default:
trainer: sac
batch_size: 128
buffer_size: 50000
buffer_init_steps: 0
hidden_units: 128
init_entcoef: 1.0
learning_rate: 3.0e-4
max_steps: 5.0e4
memory_size: 256
normalize: false
num_update: 1
train_interval: 1
num_layers: 2
time_horizon: 64
sequence_length: 64
summary_freq: 1000
tau: 0.005
use_recurrent: false
vis_encode_type: default
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
BananaLearning:
normalize: false
batch_size: 256
buffer_size: 500000
max_steps: 1.0e5
init_entcoef: 0.05
train_interval: 1
VisualBananaLearning:
beta: 1.0e-2
gamma: 0.99
num_epoch: 1
max_steps: 5.0e5
summary_freq: 1000
BouncerLearning:
normalize: true
beta: 0.0
max_steps: 5.0e5
num_layers: 2
hidden_units: 64
summary_freq: 1000
PushBlockLearning:
max_steps: 5.0e4
init_entcoef: 0.05
beta: 1.0e-2
hidden_units: 256
summary_freq: 2000
time_horizon: 64
num_layers: 2
SmallWallJumpLearning:
max_steps: 1.0e6
hidden_units: 256
summary_freq: 2000
time_horizon: 128
init_entcoef: 0.1
num_layers: 2
normalize: false
BigWallJumpLearning:
max_steps: 1.0e6
hidden_units: 256
summary_freq: 2000
time_horizon: 128
num_layers: 2
init_entcoef: 0.1
normalize: false
StrikerLearning:
max_steps: 5.0e5
learning_rate: 1e-3
beta: 1.0e-2
hidden_units: 256
summary_freq: 2000
time_horizon: 128
init_entcoef: 0.1
num_layers: 2
normalize: false
GoalieLearning:
max_steps: 5.0e5
learning_rate: 1e-3
beta: 1.0e-2
hidden_units: 256
summary_freq: 2000
time_horizon: 128
init_entcoef: 0.1
num_layers: 2
normalize: false
PyramidsLearning:
summary_freq: 2000
time_horizon: 128
batch_size: 128
buffer_init_steps: 10000
buffer_size: 500000
hidden_units: 256
num_layers: 2
init_entcoef: 0.01
max_steps: 5.0e5
sequence_length: 16
tau: 0.01
use_recurrent: false
reward_signals:
extrinsic:
strength: 2.0
gamma: 0.99
gail:
strength: 0.02
gamma: 0.99
encoding_size: 128
use_actions: true
demo_path: demos/ExpertPyramid.demo
VisualPyramidsLearning:
time_horizon: 128
batch_size: 64
hidden_units: 256
buffer_init_steps: 1000
num_layers: 1
beta: 1.0e-2
max_steps: 5.0e5
buffer_size: 500000
init_entcoef: 0.01
tau: 0.01
reward_signals:
extrinsic:
strength: 2.0
gamma: 0.99
gail:
strength: 0.02
gamma: 0.99
encoding_size: 128
use_actions: true
demo_path: demos/ExpertPyramid.demo
3DBallLearning:
normalize: true
batch_size: 64
buffer_size: 12000
summary_freq: 1000
time_horizon: 1000
hidden_units: 64
init_entcoef: 0.5
max_steps: 5.0e5
3DBallHardLearning:
normalize: true
batch_size: 256
summary_freq: 1000
time_horizon: 1000
max_steps: 5.0e5
TennisLearning:
normalize: true
max_steps: 2e5
CrawlerStaticLearning:
normalize: true
time_horizon: 1000
batch_size: 256
train_interval: 3
buffer_size: 500000
buffer_init_steps: 2000
max_steps: 5e5
summary_freq: 3000
init_entcoef: 1.0
num_layers: 3
hidden_units: 512
CrawlerDynamicLearning:
normalize: true
time_horizon: 1000
batch_size: 256
buffer_size: 500000
summary_freq: 3000
train_interval: 3
num_layers: 3
max_steps: 5e5
hidden_units: 512
WalkerLearning:
normalize: true
time_horizon: 1000
batch_size: 256
buffer_size: 500000
max_steps: 2e6
summary_freq: 3000
num_layers: 3
train_interval: 3
hidden_units: 512
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.995
ReacherLearning:
normalize: true
time_horizon: 1000
batch_size: 128
buffer_size: 500000
max_steps: 1e6
summary_freq: 3000
HallwayLearning:
use_recurrent: true
sequence_length: 32
num_layers: 2
hidden_units: 128
memory_size: 256
beta: 0.0
init_entcoef: 0.1
max_steps: 5.0e5
summary_freq: 1000
time_horizon: 64
use_recurrent: true
VisualHallwayLearning:
use_recurrent: true
sequence_length: 32
num_layers: 1
hidden_units: 128
memory_size: 256
beta: 1.0e-2
gamma: 0.99
batch_size: 64
max_steps: 5.0e5
summary_freq: 1000
time_horizon: 64
use_recurrent: true
VisualPushBlockLearning:
use_recurrent: true
sequence_length: 32
num_layers: 1
hidden_units: 128
memory_size: 256
beta: 1.0e-2
gamma: 0.99
buffer_size: 1024
batch_size: 64
max_steps: 5.0e5
summary_freq: 1000
time_horizon: 64
GridWorldLearning:
batch_size: 128
normalize: false
num_layers: 1
hidden_units: 128
init_entcoef: 0.01
buffer_size: 50000
max_steps: 5.0e5
summary_freq: 2000
time_horizon: 5
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.9
BasicLearning:
batch_size: 64
normalize: false
num_layers: 2
init_entcoef: 0.01
hidden_units: 20
max_steps: 5.0e5
summary_freq: 2000
time_horizon: 10

330
docs/Training-SAC.md


# Training with Soft-Actor Critic
In addition to [Proximal Policy Optimization (PPO)](Training-PPO.md), ML-Agents also provides
[Soft Actor-Critic](http://bair.berkeley.edu/blog/2018/12/14/sac/) to perform
reinforcement learning.
In contrast with PPO, SAC is _off-policy_, which means it can learn from experiences collected
at any time during the past. As experiences are collected, they are placed in an
experience replay buffer and randomly drawn during training. This makes SAC
significantly more sample-efficient, often requiring 5-10 times less samples to learn
the same task as PPO. However, SAC tends to require more model updates. SAC is a
good choice for heavier or slower environments (about 0.1 seconds per step or more).
SAC is also a "maximum entropy" algorithm, and enables exploration in an intrinsic way.
Read more about maximum entropy RL [here](https://bair.berkeley.edu/blog/2017/10/06/soft-q-learning/).
To train an agent, you will need to provide the agent one or more reward signals which
the agent should attempt to maximize. See [Reward Signals](Training-RewardSignals.md)
for the available reward signals and the corresponding hyperparameters.
## Best Practices when training with SAC
Successfully training a reinforcement learning model often involves tuning
hyperparameters. This guide contains some best practices for training
when the default parameters don't seem to be giving the level of performance
you would like.
## Hyperparameters
### Reward Signals
In reinforcement learning, the goal is to learn a Policy that maximizes reward.
In the most basic case, the reward is given by the environment. However, we could imagine
rewarding the agent for various different behaviors. For instance, we could reward
the agent for exploring new states, rather than explicitly defined reward signals.
Furthermore, we could mix reward signals to help the learning process.
`reward_signals` provides a section to define [reward signals.](Training-RewardSignals.md)
ML-Agents provides two reward signals by default, the Extrinsic (environment) reward, and the
Curiosity reward, which can be used to encourage exploration in sparse extrinsic reward
environments.
#### Number of Updates for Reward Signal (Optional)
`reward_signal_num_update` for the reward signals corresponds to the number of mini batches sampled
and used for updating the reward signals during each
update. By default, we update the reward signals once every time the main policy is updated.
However, to imitate the training procedure in certain imitation learning papers (e.g.
[Kostrikov et. al](http://arxiv.org/abs/1809.02925), [Blondé et. al](http://arxiv.org/abs/1809.02064)),
we may want to update the policy N times, then update the reward signal (GAIL) M times.
We can change `train_interval` and `num_update` of SAC to N, as well as `reward_signal_num_update`
under `reward_signals` to M to accomplish this. By default, `reward_signal_num_update` is set to
`num_update`.
Typical Range: `num_update`
### Buffer Size
`buffer_size` corresponds the maximum number of experiences (agent observations, actions
and rewards obtained) that can be stored in the experience replay buffer. This value should be
large, on the order of thousands of times longer than your episodes, so that SAC
can learn from old as well as new experiences. It should also be much larger than
`batch_size`.
Typical Range: `50000` - `1000000`
### Buffer Init Steps
`buffer_init_steps` is the number of experiences to prefill the buffer with before attempting training.
As the untrained policy is fairly random, prefilling the buffer with random actions is
useful for exploration. Typically, at least several episodes of experiences should be
prefilled.
Typical Range: `1000` - `10000`
### Batch Size
`batch_size` is the number of experiences used for one iteration of a gradient
descent update. If
you are using a continuous action space, this value should be large (in the
order of 1000s). If you are using a discrete action space, this value should be
smaller (in order of 10s).
Typical Range (Continuous): `128` - `1024`
Typical Range (Discrete): `32` - `512`
### Initial Entropy Coefficient
`init_entcoef` refers to the initial entropy coefficient set at the beginning of training. In
SAC, the agent is incentivized to make its actions entropic to facilitate better exploration.
The entropy coefficient weighs the true reward with a bonus entropy reward. The entropy
coefficient is [automatically adjusted](https://arxiv.org/abs/1812.05905) to a preset target
entropy, so the `init_entcoef` only corresponds to the starting value of the entropy bonus.
Increase `init_entcoef` to explore more in the beginning, decrease to converge to a solution faster.
Typical Range (Continuous): `0.5` - `1.0`
Typical Range (Discrete): `0.05` - `0.5`
### Train Interval
`train_interval` is the number of steps taken between each agent training event. Typically,
we can train after every step, but if your environment's steps are very small and very frequent,
there may not be any new interesting information between steps, and `train_interval` can be increased.
Typical Range: `1` - `5`
### Number of Updates
`num_update` corresponds to the number of mini batches sampled and used for training during each
training event. In SAC, a single "update" corresponds to grabbing a batch of size `batch_size` from the experience
replay buffer, and using this mini batch to update the models. Typically, this can be left at 1.
However, to imitate the training procedure in certain papers (e.g.
[Kostrikov et. al](http://arxiv.org/abs/1809.02925), [Blondé et. al](http://arxiv.org/abs/1809.02064)),
we may want to update N times with different mini batches before grabbing additional samples.
We can change `train_interval` and `num_update` to N to accomplish this.
Typical Range: `1`
### Tau
`tau` corresponds to the magnitude of the target Q update during the SAC model update.
In SAC, there are two neural networks: the target and the policy. The target network is
used to bootstrap the policy's estimate of the future rewards at a given state, and is fixed
while the policy is being updated. This target is then slowly updated according to `tau`.
Typically, this value should be left at `0.005`. For simple problems, increasing
`tau` to `0.01` might reduce the time it takes to learn, at the cost of stability.
Typical Range: `0.005` - `0.01`
### Learning Rate
`learning_rate` corresponds to the strength of each gradient descent update
step. This should typically be decreased if training is unstable, and the reward
does not consistently increase.
Typical Range: `1e-5` - `1e-3`
### Time Horizon
`time_horizon` corresponds to how many steps of experience to collect per-agent
before adding it to the experience buffer. This parameter is a lot less critical
to SAC than PPO, and can typically be set to approximately your episode length.
Typical Range: `32` - `2048`
### Max Steps
`max_steps` corresponds to how many steps of the simulation (multiplied by
frame-skip) are run during the training process. This value should be increased
for more complex problems.
Typical Range: `5e5` - `1e7`
### Normalize
`normalize` corresponds to whether normalization is applied to the vector
observation inputs. This normalization is based on the running average and
variance of the vector observation. Normalization can be helpful in cases with
complex continuous control problems, but may be harmful with simpler discrete
control problems.
### Number of Layers
`num_layers` corresponds to how many hidden layers are present after the
observation input, or after the CNN encoding of the visual observation. For
simple problems, fewer layers are likely to train faster and more efficiently.
More layers may be necessary for more complex control problems.
Typical range: `1` - `3`
### Hidden Units
`hidden_units` correspond to how many units are in each fully connected layer of
the neural network. For simple problems where the correct action is a
straightforward combination of the observation inputs, this should be small. For
problems where the action is a very complex interaction between the observation
variables, this should be larger.
Typical Range: `32` - `512`
### (Optional) Visual Encoder Type
`vis_encode_type` corresponds to the encoder type for encoding visual observations.
Valid options include:
* `simple` (default): a simple encoder which consists of two convolutional layers
* `nature_cnn`: CNN implementation proposed by Mnih et al.(https://www.nature.com/articles/nature14236),
consisting of three convolutional layers
* `resnet`: IMPALA Resnet implementation (https://arxiv.org/abs/1802.01561),
consisting of three stacked layers, each with two risidual blocks, making a
much larger network than the other two.
Options: `simple`, `nature_cnn`, `resnet`
## (Optional) Recurrent Neural Network Hyperparameters
The below hyperparameters are only used when `use_recurrent` is set to true.
### Sequence Length
`sequence_length` corresponds to the length of the sequences of experience
passed through the network during training. This should be long enough to
capture whatever information your agent might need to remember over time. For
example, if your agent needs to remember the velocity of objects, then this can
be a small value. If your agent needs to remember a piece of information given
only once at the beginning of an episode, then this should be a larger value.
Typical Range: `4` - `128`
### Memory Size
`memory_size` corresponds to the size of the array of floating point numbers
used to store the hidden state of the recurrent neural network. This value must
be a multiple of 4, and should scale with the amount of information you expect
the agent will need to remember in order to successfully complete the task.
Typical Range: `64` - `512`
### (Optional) Save Replay Buffer
`save_replay_buffer` enables you to save and load the experience replay buffer as well as
the model when quitting and re-starting training. This may help resumes go more smoothly,
as the experiences collected won't be wiped. Note that replay buffers can be very large, and
will take up a considerable amount of disk space. For that reason, we disable this feature by
default.
Default: `False`
## (Optional) Pretraining Using Demonstrations
In some cases, you might want to bootstrap the agent's policy using behavior recorded
from a player. This can help guide the agent towards the reward. Pretraining adds
training operations that mimic a demonstration rather than attempting to maximize reward.
It is essentially equivalent to running [behavioral cloning](./Training-BehavioralCloning.md)
in-line with SAC.
To use pretraining, add a `pretraining` section to the trainer_config. For instance:
```
pretraining:
demo_path: ./demos/ExpertPyramid.demo
strength: 0.5
steps: 10000
```
Below are the avaliable hyperparameters for pretraining.
### Strength
`strength` corresponds to the learning rate of the imitation relative to the learning
rate of SAC, and roughly corresponds to how strongly we allow the behavioral cloning
to influence the policy.
Typical Range: `0.1` - `0.5`
### Demo Path
`demo_path` is the path to your `.demo` file or directory of `.demo` files.
See the [imitation learning guide](Training-Imitation-Learning.md) for more on `.demo` files.
### Steps
During pretraining, it is often desirable to stop using demonstrations after the agent has
"seen" rewards, and allow it to optimize past the available demonstrations and/or generalize
outside of the provided demonstrations. `steps` corresponds to the training steps over which
pretraining is active. The learning rate of the pretrainer will anneal over the steps. Set
the steps to 0 for constant imitation over the entire training run.
### (Optional) Batch Size
`batch_size` is the number of demonstration experiences used for one iteration of a gradient
descent update. If not specified, it will default to the `batch_size` defined for SAC.
Typical Range (Continuous): `512` - `5120`
Typical Range (Discrete): `32` - `512`
## Training Statistics
To view training statistics, use TensorBoard. For information on launching and
using TensorBoard, see
[here](./Getting-Started-with-Balance-Ball.md#observing-training-progress).
### Cumulative Reward
The general trend in reward should consistently increase over time. Small ups
and downs are to be expected. Depending on the complexity of the task, a
significant increase in reward may not present itself until millions of steps
into the training process.
### Entropy Coefficient
SAC is a "maximum entropy" reinforcement learning algorithm, and agents trained using
SAC are incentivized to behave randomly while also solving the problem. The entropy
coefficient balances the incentive to behave randomly vs. maximizing the reward.
This value is adjusted automatically so that the agent retains some amount of randomness during
training. It should steadily decrease in the beginning of training, and reach some small
value where it will level off. If it decreases too soon or takes too
long to decrease, `init_entcoef` should be adjusted.
### Entropy
This corresponds to how random the decisions of a Brain are. This should
initially increase during training, reach a peak, and should decline along
with the Entropy Coefficient. This is because in the beginning, the agent is
incentivised to be more random for exploration due to a high entropy coefficient.
If it decreases too soon or takes too long to decrease, `init_entcoef` should be adjusted.
### Learning Rate
This will decrease over time on a linear schedule.
### Policy Loss
These values may increase as the agent explores, but should decrease longterm
as the agent learns how to solve the task.
### Value Estimate
These values should increase as the cumulative reward increases. They correspond
to how much future reward the agent predicts itself receiving at any given
point. They may also increase at the beginning as the agent is rewarded for
being random (see: Entropy and Entropy Coefficient), but should decline as
Entropy Coefficient decreases.
### Value Loss
These values will increase as the reward increases, and then should decrease
once reward becomes stable.

399
ml-agents/mlagents/trainers/tests/test_sac.py


import unittest.mock as mock
import pytest
import tempfile
import yaml
import math
import numpy as np
import tensorflow as tf
from mlagents.trainers.sac.models import SACModel
from mlagents.trainers.sac.policy import SACPolicy
from mlagents.trainers.sac.trainer import SACTrainer
from mlagents.trainers.tests.test_simple_rl import Simple1DEnvironment, SimpleEnvManager
from mlagents.trainers.trainer_util import initialize_trainers
from mlagents.envs import UnityEnvironment
from mlagents.envs.mock_communicator import MockCommunicator
from mlagents.trainers.trainer_controller import TrainerController
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs import BrainInfo, AllBrainInfo, BrainParameters
from mlagents.envs.communicator_objects import AgentInfoProto
from mlagents.envs.sampler_class import SamplerManager
from mlagents.trainers.tests import mock_brain as mb
@pytest.fixture
def dummy_config():
return yaml.load(
"""
trainer: sac
batch_size: 32
buffer_size: 10240
buffer_init_steps: 0
hidden_units: 32
init_entcoef: 0.1
learning_rate: 3.0e-4
max_steps: 1024
memory_size: 8
normalize: false
num_update: 1
train_interval: 1
num_layers: 1
time_horizon: 64
sequence_length: 16
summary_freq: 1000
tau: 0.005
use_recurrent: false
curiosity_enc_size: 128
demo_path: None
vis_encode_type: default
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
"""
)
VECTOR_ACTION_SPACE = [2]
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 32
NUM_AGENTS = 12
def create_sac_policy_mock(mock_env, dummy_config, use_rnn, use_discrete, use_visual):
env, mock_brain, _ = mb.setup_mock_env_and_brains(
mock_env,
use_discrete,
use_visual,
num_agents=NUM_AGENTS,
vector_action_space=VECTOR_ACTION_SPACE,
vector_obs_space=VECTOR_OBS_SPACE,
discrete_action_space=DISCRETE_ACTION_SPACE,
)
trainer_parameters = dummy_config
model_path = env.brain_names[0]
trainer_parameters["model_path"] = model_path
trainer_parameters["keep_checkpoints"] = 3
trainer_parameters["use_recurrent"] = use_rnn
policy = SACPolicy(0, mock_brain, trainer_parameters, False, False)
return env, policy
@mock.patch("mlagents.envs.UnityEnvironment")
def test_sac_cc_policy(mock_env, dummy_config):
# Test evaluate
tf.reset_default_graph()
env, policy = create_sac_policy_mock(
mock_env, dummy_config, use_rnn=False, use_discrete=False, use_visual=False
)
brain_infos = env.reset()
brain_info = brain_infos[env.brain_names[0]]
run_out = policy.evaluate(brain_info)
assert run_out["action"].shape == (NUM_AGENTS, VECTOR_ACTION_SPACE[0])
# Test update
buffer = mb.simulate_rollout(env, policy, BUFFER_INIT_SAMPLES)
# Mock out reward signal eval
buffer.update_buffer["extrinsic_rewards"] = buffer.update_buffer["rewards"]
policy.update(
buffer.update_buffer, num_sequences=len(buffer.update_buffer["actions"])
)
env.close()
@mock.patch("mlagents.envs.UnityEnvironment")
def test_sac_update_reward_signals(mock_env, dummy_config):
# Test evaluate
tf.reset_default_graph()
# Add a Curiosity module
dummy_config["reward_signals"]["curiosity"] = {}
dummy_config["reward_signals"]["curiosity"]["strength"] = 1.0
dummy_config["reward_signals"]["curiosity"]["gamma"] = 0.99
dummy_config["reward_signals"]["curiosity"]["encoding_size"] = 128
env, policy = create_sac_policy_mock(
mock_env, dummy_config, use_rnn=False, use_discrete=False, use_visual=False
)
# Test update
buffer = mb.simulate_rollout(env, policy, BUFFER_INIT_SAMPLES)
# Mock out reward signal eval
buffer.update_buffer["extrinsic_rewards"] = buffer.update_buffer["rewards"]
buffer.update_buffer["curiosity_rewards"] = buffer.update_buffer["rewards"]
policy.update_reward_signals(
{"curiosity": buffer.update_buffer},
num_sequences=len(buffer.update_buffer["actions"]),
)
env.close()
@mock.patch("mlagents.envs.UnityEnvironment")
def test_sac_dc_policy(mock_env, dummy_config):
# Test evaluate
tf.reset_default_graph()
env, policy = create_sac_policy_mock(
mock_env, dummy_config, use_rnn=False, use_discrete=True, use_visual=False
)
brain_infos = env.reset()
brain_info = brain_infos[env.brain_names[0]]
run_out = policy.evaluate(brain_info)
assert run_out["action"].shape == (NUM_AGENTS, len(DISCRETE_ACTION_SPACE))
# Test update
buffer = mb.simulate_rollout(env, policy, BUFFER_INIT_SAMPLES)
# Mock out reward signal eval
buffer.update_buffer["extrinsic_rewards"] = buffer.update_buffer["rewards"]
policy.update(
buffer.update_buffer, num_sequences=len(buffer.update_buffer["actions"])
)
env.close()
@mock.patch("mlagents.envs.UnityEnvironment")
def test_sac_visual_policy(mock_env, dummy_config):
# Test evaluate
tf.reset_default_graph()
env, policy = create_sac_policy_mock(
mock_env, dummy_config, use_rnn=False, use_discrete=True, use_visual=True
)
brain_infos = env.reset()
brain_info = brain_infos[env.brain_names[0]]
run_out = policy.evaluate(brain_info)
assert run_out["action"].shape == (NUM_AGENTS, len(DISCRETE_ACTION_SPACE))
# Test update
buffer = mb.simulate_rollout(env, policy, BUFFER_INIT_SAMPLES)
# Mock out reward signal eval
buffer.update_buffer["extrinsic_rewards"] = buffer.update_buffer["rewards"]
run_out = policy.update(
buffer.update_buffer, num_sequences=len(buffer.update_buffer["actions"])
)
assert type(run_out) is dict
@mock.patch("mlagents.envs.UnityEnvironment")
def test_sac_rnn_policy(mock_env, dummy_config):
# Test evaluate
tf.reset_default_graph()
env, policy = create_sac_policy_mock(
mock_env, dummy_config, use_rnn=True, use_discrete=True, use_visual=False
)
brain_infos = env.reset()
brain_info = brain_infos[env.brain_names[0]]
run_out = policy.evaluate(brain_info)
assert run_out["action"].shape == (NUM_AGENTS, len(DISCRETE_ACTION_SPACE))
# Test update
buffer = mb.simulate_rollout(env, policy, BUFFER_INIT_SAMPLES)
# Mock out reward signal eval
buffer.update_buffer["extrinsic_rewards"] = buffer.update_buffer["rewards"]
policy.update(buffer.update_buffer, num_sequences=2)
env.close()
@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_sac_model_cc_vector(mock_communicator, mock_launcher):
tf.reset_default_graph()
with tf.Session() as sess:
with tf.variable_scope("FakeGraphScope"):
mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=0
)
env = UnityEnvironment(" ")
model = SACModel(env.brains["RealFakeBrain"])
init = tf.global_variables_initializer()
sess.run(init)
run_list = [model.output, model.value, model.entropy, model.learning_rate]
feed_dict = {
model.batch_size: 2,
model.sequence_length: 1,
model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
}
sess.run(run_list, feed_dict=feed_dict)
env.close()
@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_sac_model_cc_visual(mock_communicator, mock_launcher):
tf.reset_default_graph()
with tf.Session() as sess:
with tf.variable_scope("FakeGraphScope"):
mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=2
)
env = UnityEnvironment(" ")
model = SACModel(env.brains["RealFakeBrain"])
init = tf.global_variables_initializer()
sess.run(init)
run_list = [model.output, model.value, model.entropy, model.learning_rate]
feed_dict = {
model.batch_size: 2,
model.sequence_length: 1,
model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
model.visual_in[0]: np.ones([2, 40, 30, 3]),
model.visual_in[1]: np.ones([2, 40, 30, 3]),
}
sess.run(run_list, feed_dict=feed_dict)
env.close()
@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_sac_model_dc_visual(mock_communicator, mock_launcher):
tf.reset_default_graph()
with tf.Session() as sess:
with tf.variable_scope("FakeGraphScope"):
mock_communicator.return_value = MockCommunicator(
discrete_action=True, visual_inputs=2
)
env = UnityEnvironment(" ")
model = SACModel(env.brains["RealFakeBrain"])
init = tf.global_variables_initializer()
sess.run(init)
run_list = [model.output, model.value, model.entropy, model.learning_rate]
feed_dict = {
model.batch_size: 2,
model.sequence_length: 1,
model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
model.visual_in[0]: np.ones([2, 40, 30, 3]),
model.visual_in[1]: np.ones([2, 40, 30, 3]),
model.action_masks: np.ones([2, 2]),
}
sess.run(run_list, feed_dict=feed_dict)
env.close()
@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_sac_model_dc_vector(mock_communicator, mock_launcher):
tf.reset_default_graph()
with tf.Session() as sess:
with tf.variable_scope("FakeGraphScope"):
mock_communicator.return_value = MockCommunicator(
discrete_action=True, visual_inputs=0
)
env = UnityEnvironment(" ")
model = SACModel(env.brains["RealFakeBrain"])
init = tf.global_variables_initializer()
sess.run(init)
run_list = [model.output, model.value, model.entropy, model.learning_rate]
feed_dict = {
model.batch_size: 2,
model.sequence_length: 1,
model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
model.action_masks: np.ones([2, 2]),
}
sess.run(run_list, feed_dict=feed_dict)
env.close()
@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_sac_model_dc_vector_rnn(mock_communicator, mock_launcher):
tf.reset_default_graph()
with tf.Session() as sess:
with tf.variable_scope("FakeGraphScope"):
mock_communicator.return_value = MockCommunicator(
discrete_action=True, visual_inputs=0
)
env = UnityEnvironment(" ")
memory_size = 128
model = SACModel(
env.brains["RealFakeBrain"], use_recurrent=True, m_size=memory_size
)
init = tf.global_variables_initializer()
sess.run(init)
run_list = [
model.output,
model.all_log_probs,
model.value,
model.entropy,
model.learning_rate,
model.memory_out,
]
feed_dict = {
model.batch_size: 1,
model.sequence_length: 2,
model.prev_action: [[0], [0]],
model.memory_in: np.zeros((1, memory_size)),
model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
model.action_masks: np.ones([1, 2]),
}
sess.run(run_list, feed_dict=feed_dict)
env.close()
@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_sac_model_cc_vector_rnn(mock_communicator, mock_launcher):
tf.reset_default_graph()
with tf.Session() as sess:
with tf.variable_scope("FakeGraphScope"):
mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=0
)
env = UnityEnvironment(" ")
memory_size = 128
model = SACModel(
env.brains["RealFakeBrain"], use_recurrent=True, m_size=memory_size
)
init = tf.global_variables_initializer()
sess.run(init)
run_list = [
model.output,
model.all_log_probs,
model.value,
model.entropy,
model.learning_rate,
model.memory_out,
]
feed_dict = {
model.batch_size: 1,
model.sequence_length: 2,
model.memory_in: np.zeros((1, memory_size)),
model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
}
sess.run(run_list, feed_dict=feed_dict)
env.close()
def test_sac_save_load_buffer(tmpdir):
env, mock_brain, _ = mb.setup_mock_env_and_brains(
mock.Mock(),
False,
False,
num_agents=NUM_AGENTS,
vector_action_space=VECTOR_ACTION_SPACE,
vector_obs_space=VECTOR_OBS_SPACE,
discrete_action_space=DISCRETE_ACTION_SPACE,
)
trainer_params = dummy_config()
trainer_params["summary_path"] = str(tmpdir)
trainer_params["model_path"] = str(tmpdir)
trainer_params["save_replay_buffer"] = True
trainer = SACTrainer(mock_brain, 1, trainer_params, True, False, 0, 0)
trainer.training_buffer = mb.simulate_rollout(
env, trainer.policy, BUFFER_INIT_SAMPLES
)
buffer_len = len(trainer.training_buffer.update_buffer["actions"])
trainer.save_model()
# Wipe Trainer and try to load
trainer2 = SACTrainer(mock_brain, 1, trainer_params, True, True, 0, 0)
assert len(trainer2.training_buffer.update_buffer["actions"]) == buffer_len
if __name__ == "__main__":
pytest.main()

3
ml-agents/mlagents/trainers/sac/__init__.py


from .models import *
from .trainer import *
from .policy import *

1001
ml-agents/mlagents/trainers/sac/models.py
文件差异内容过多而无法显示
查看文件

320
ml-agents/mlagents/trainers/sac/policy.py


import logging
from typing import Dict, List, Any
import numpy as np
import tensorflow as tf
from mlagents.envs.timers import timed
from mlagents.trainers import BrainInfo, ActionInfo, BrainParameters
from mlagents.trainers.sac.models import SACModel
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.components.reward_signals.reward_signal_factory import (
create_reward_signal,
)
from mlagents.trainers.components.reward_signals.reward_signal import RewardSignal
from mlagents.trainers.components.bc import BCModule
logger = logging.getLogger("mlagents.trainers")
class SACPolicy(TFPolicy):
def __init__(
self,
seed: int,
brain: BrainParameters,
trainer_params: Dict[str, Any],
is_training: bool,
load: bool,
) -> None:
"""
Policy for Proximal Policy Optimization Networks.
:param seed: Random seed.
:param brain: Assigned Brain object.
:param trainer_params: Defined training parameters.
:param is_training: Whether the model should be trained.
:param load: Whether a pre-trained model will be loaded or a new one created.
"""
super().__init__(seed, brain, trainer_params)
reward_signal_configs = {}
for key, rsignal in trainer_params["reward_signals"].items():
if type(rsignal) is dict:
reward_signal_configs[key] = rsignal
self.inference_dict: Dict[str, tf.Tensor] = {}
self.update_dict: Dict[str, tf.Tensor] = {}
self.create_model(
brain, trainer_params, reward_signal_configs, is_training, load, seed
)
self.create_reward_signals(reward_signal_configs)
self.stats_name_to_update_name = {
"Losses/Value Loss": "value_loss",
"Losses/Policy Loss": "policy_loss",
"Losses/Q1 Loss": "q1_loss",
"Losses/Q2 Loss": "q2_loss",
"Policy/Entropy Coeff": "entropy_coef",
}
with self.graph.as_default():
# Create pretrainer if needed
if "pretraining" in trainer_params:
BCModule.check_config(trainer_params["pretraining"])
self.bc_module = BCModule(
self,
policy_learning_rate=trainer_params["learning_rate"],
default_batch_size=trainer_params["batch_size"],
default_num_epoch=1,
samples_per_update=trainer_params["batch_size"],
**trainer_params["pretraining"],
)
# SAC-specific setting - we don't want to do a whole epoch each update!
if "samples_per_update" in trainer_params["pretraining"]:
logger.warning(
"Pretraining: Samples Per Update is not a valid setting for SAC."
)
self.bc_module.samples_per_update = 1
else:
self.bc_module = None
if load:
self._load_graph()
else:
self._initialize_graph()
self.sess.run(self.model.target_init_op)
# Disable terminal states for certain reward signals to avoid survivor bias
for name, reward_signal in self.reward_signals.items():
if not reward_signal.use_terminal_states:
self.sess.run(self.model.disable_use_dones[name])
def create_model(
self,
brain: BrainParameters,
trainer_params: Dict[str, Any],
reward_signal_configs: Dict[str, Any],
is_training: bool,
load: bool,
seed: int,
) -> None:
with self.graph.as_default():
self.model = SACModel(
brain,
lr=float(trainer_params["learning_rate"]),
h_size=int(trainer_params["hidden_units"]),
init_entcoef=float(trainer_params["init_entcoef"]),
max_step=float(trainer_params["max_steps"]),
normalize=trainer_params["normalize"],
use_recurrent=trainer_params["use_recurrent"],
num_layers=int(trainer_params["num_layers"]),
m_size=self.m_size,
seed=seed,
stream_names=list(reward_signal_configs.keys()),
tau=float(trainer_params["tau"]),
gammas=list(_val["gamma"] for _val in reward_signal_configs.values()),
vis_encode_type=trainer_params["vis_encode_type"],
)
self.model.create_sac_optimizers()
self.inference_dict.update(
{
"action": self.model.output,
"log_probs": self.model.all_log_probs,
"value_heads": self.model.value_heads,
"value": self.model.value,
"entropy": self.model.entropy,
"learning_rate": self.model.learning_rate,
}
)
if self.use_continuous_act:
self.inference_dict["pre_action"] = self.model.output_pre
if self.use_recurrent:
self.inference_dict["memory_out"] = self.model.memory_out
if (
is_training
and self.use_vec_obs
and trainer_params["normalize"]
and not load
):
self.inference_dict["update_mean"] = self.model.update_normalization
self.update_dict.update(
{
"value_loss": self.model.total_value_loss,
"policy_loss": self.model.policy_loss,
"q1_loss": self.model.q1_loss,
"q2_loss": self.model.q2_loss,
"entropy_coef": self.model.ent_coef,
"entropy": self.model.entropy,
"update_batch": self.model.update_batch_policy,
"update_value": self.model.update_batch_value,
"update_entropy": self.model.update_batch_entropy,
}
)
def create_reward_signals(self, reward_signal_configs: Dict[str, Any]) -> None:
"""
Create reward signals
:param reward_signal_configs: Reward signal config.
"""
self.reward_signals: Dict[str, RewardSignal] = {}
with self.graph.as_default():
# Create reward signals
for reward_signal, config in reward_signal_configs.items():
if type(config) is dict:
self.reward_signals[reward_signal] = create_reward_signal(
self, self.model, reward_signal, config
)
def evaluate(self, brain_info: BrainInfo) -> Dict[str, np.ndarray]:
"""
Evaluates policy for the agent experiences provided.
:param brain_info: BrainInfo object containing inputs.
:return: Outputs from network as defined by self.inference_dict.
"""
feed_dict = {
self.model.batch_size: len(brain_info.vector_observations),
self.model.sequence_length: 1,
}
if self.use_recurrent:
if not self.use_continuous_act:
feed_dict[
self.model.prev_action
] = brain_info.previous_vector_actions.reshape(
[-1, len(self.model.act_size)]
)
if brain_info.memories.shape[1] == 0:
brain_info.memories = self.make_empty_memory(len(brain_info.agents))
feed_dict[self.model.memory_in] = brain_info.memories
feed_dict = self.fill_eval_dict(feed_dict, brain_info)
run_out = self._execute_model(feed_dict, self.inference_dict)
return run_out
@timed
def update(
self, mini_batch: Dict[str, Any], num_sequences: int, update_target: bool = True
) -> Dict[str, float]:
"""
Updates model using buffer.
:param num_sequences: Number of trajectories in batch.
:param mini_batch: Experience batch.
:param update_target: Whether or not to update target value network
:param reward_signal_mini_batches: Minibatches to use for updating the reward signals,
indexed by name. If none, don't update the reward signals.
:return: Output from update process.
"""
feed_dict = self.construct_feed_dict(self.model, mini_batch, num_sequences)
stats_needed = self.stats_name_to_update_name
update_stats: Dict[str, float] = {}
update_vals = self._execute_model(feed_dict, self.update_dict)
for stat_name, update_name in stats_needed.items():
update_stats[stat_name] = update_vals[update_name]
if update_target:
self.sess.run(self.model.target_update_op)
return update_stats
def update_reward_signals(
self, reward_signal_minibatches: Dict[str, Dict], num_sequences: int
) -> Dict[str, float]:
"""
Only update the reward signals.
:param reward_signal_mini_batches: Minibatches to use for updating the reward signals,
indexed by name. If none, don't update the reward signals.
"""
# Collect feed dicts for all reward signals.
feed_dict: Dict[tf.Tensor, Any] = {}
update_dict: Dict[str, tf.Tensor] = {}
update_stats: Dict[str, float] = {}
stats_needed: Dict[str, str] = {}
if reward_signal_minibatches:
self.add_reward_signal_dicts(
feed_dict,
update_dict,
stats_needed,
reward_signal_minibatches,
num_sequences,
)
update_vals = self._execute_model(feed_dict, update_dict)
for stat_name, update_name in stats_needed.items():
update_stats[stat_name] = update_vals[update_name]
return update_stats
def add_reward_signal_dicts(
self,
feed_dict: Dict[tf.Tensor, Any],
update_dict: Dict[str, tf.Tensor],
stats_needed: Dict[str, str],
reward_signal_minibatches: Dict[str, Dict],
num_sequences: int,
) -> None:
"""
Adds the items needed for reward signal updates to the feed_dict and stats_needed dict.
:param feed_dict: Feed dict needed update
:param update_dit: Update dict that needs update
:param stats_needed: Stats needed to get from the update.
:param reward_signal_minibatches: Minibatches to use for updating the reward signals,
indexed by name.
"""
for name, r_mini_batch in reward_signal_minibatches.items():
feed_dict.update(
self.reward_signals[name].prepare_update(
self.model, r_mini_batch, num_sequences
)
)
update_dict.update(self.reward_signals[name].update_dict)
stats_needed.update(self.reward_signals[name].stats_name_to_update_name)
def construct_feed_dict(
self, model: SACModel, mini_batch: Dict[str, Any], num_sequences: int
) -> Dict[tf.Tensor, Any]:
"""
Builds the feed dict for updating the SAC model.
:param model: The model to update. May be different when, e.g. using multi-GPU.
:param mini_batch: Mini-batch to use to update.
:param num_sequences: Number of LSTM sequences in mini_batch.
"""
feed_dict = {
self.model.batch_size: num_sequences,
self.model.sequence_length: self.sequence_length,
self.model.next_sequence_length: self.sequence_length,
self.model.mask_input: mini_batch["masks"],
}
for name in self.reward_signals:
feed_dict[model.rewards_holders[name]] = mini_batch[
"{}_rewards".format(name)
]
if self.use_continuous_act:
feed_dict[model.action_holder] = mini_batch["actions"]
else:
feed_dict[model.action_holder] = mini_batch["actions"]
if self.use_recurrent:
feed_dict[model.prev_action] = mini_batch["prev_action"]
feed_dict[model.action_masks] = mini_batch["action_mask"]
if self.use_vec_obs:
feed_dict[model.vector_in] = mini_batch["vector_obs"]
feed_dict[model.next_vector_in] = mini_batch["next_vector_in"]
if self.model.vis_obs_size > 0:
for i, _ in enumerate(model.visual_in):
_obs = mini_batch["visual_obs%d" % i]
feed_dict[model.visual_in[i]] = _obs
for i, _ in enumerate(model.next_visual_in):
_obs = mini_batch["next_visual_obs%d" % i]
feed_dict[model.next_visual_in[i]] = _obs
if self.use_recurrent:
mem_in = [
mini_batch["memory"][i]
for i in range(0, len(mini_batch["memory"]), self.sequence_length)
]
# LSTM shouldn't have sequence length <1, but stop it from going out of the index if true.
offset = 1 if self.sequence_length > 1 else 0
next_mem_in = [
mini_batch["memory"][i][
: self.m_size // 4
] # only pass value part of memory to target network
for i in range(offset, len(mini_batch["memory"]), self.sequence_length)
]
feed_dict[model.memory_in] = mem_in
feed_dict[model.next_memory_in] = next_mem_in
feed_dict[model.dones_holder] = mini_batch["done"]
return feed_dict

340
ml-agents/mlagents/trainers/sac/trainer.py


# # Unity ML-Agents Toolkit
# ## ML-Agent Learning (SAC)
# Contains an implementation of SAC as described in https://arxiv.org/abs/1801.01290
# and implemented in https://github.com/hill-a/stable-baselines
import logging
from collections import deque, defaultdict
from typing import List, Any, Dict
import os
import numpy as np
import tensorflow as tf
from mlagents.envs import AllBrainInfo, BrainInfo
from mlagents.envs.action_info import ActionInfoOutputs
from mlagents.envs.timers import timed, hierarchical_timer
from mlagents.trainers.buffer import Buffer
from mlagents.trainers.sac.policy import SACPolicy
from mlagents.trainers.trainer import UnityTrainerException
from mlagents.trainers.rl_trainer import RLTrainer, AllRewardsOutput
from mlagents.trainers.components.reward_signals import RewardSignalResult
LOGGER = logging.getLogger("mlagents.trainers")
BUFFER_TRUNCATE_PERCENT = 0.8
class SACTrainer(RLTrainer):
"""
The SACTrainer is an implementation of the SAC algorithm, with support
for discrete actions and recurrent networks.
"""
def __init__(
self, brain, reward_buff_cap, trainer_parameters, training, load, seed, run_id
):
"""
Responsible for collecting experiences and training SAC model.
:param trainer_parameters: The parameters for the trainer (dictionary).
:param training: Whether the trainer is set for training.
:param load: Whether the model should be loaded.
:param seed: The seed the model will be initialized with
:param run_id: The The identifier of the current run
"""
super().__init__(brain, trainer_parameters, training, run_id, reward_buff_cap)
self.param_keys = [
"batch_size",
"buffer_size",
"buffer_init_steps",
"hidden_units",
"learning_rate",
"init_entcoef",
"max_steps",
"normalize",
"num_update",
"num_layers",
"time_horizon",
"sequence_length",
"summary_freq",
"tau",
"use_recurrent",
"summary_path",
"memory_size",
"model_path",
"reward_signals",
"vis_encode_type",
]
self.check_param_keys()
self.step = 0
self.train_interval = (
trainer_parameters["train_interval"]
if "train_interval" in trainer_parameters
else 1
)
self.reward_signal_updates_per_train = (
trainer_parameters["reward_signals"]["reward_signal_num_update"]
if "reward_signal_num_update" in trainer_parameters["reward_signals"]
else trainer_parameters["num_update"]
)
self.checkpoint_replay_buffer = (
trainer_parameters["save_replay_buffer"]
if "save_replay_buffer" in trainer_parameters
else False
)
self.policy = SACPolicy(seed, brain, trainer_parameters, self.is_training, load)
# Load the replay buffer if load
if load and self.checkpoint_replay_buffer:
try:
self.load_replay_buffer()
except (AttributeError, FileNotFoundError):
LOGGER.warning(
"Replay buffer was unable to load, starting from scratch."
)
LOGGER.debug(
"Loaded update buffer with {} sequences".format(
len(self.training_buffer.update_buffer["actions"])
)
)
for _reward_signal in self.policy.reward_signals.keys():
self.collected_rewards[_reward_signal] = {}
self.episode_steps = {}
def save_model(self) -> None:
"""
Saves the model. Overrides the default save_model since we want to save
the replay buffer as well.
"""
self.policy.save_model(self.get_step)
if self.checkpoint_replay_buffer:
self.save_replay_buffer()
def save_replay_buffer(self) -> None:
"""
Save the training buffer's update buffer to a pickle file.
"""
filename = os.path.join(self.policy.model_path, "last_replay_buffer.hdf5")
LOGGER.info("Saving Experience Replay Buffer to {}".format(filename))
with open(filename, "wb") as file_object:
self.training_buffer.update_buffer.save_to_file(file_object)
def load_replay_buffer(self) -> Buffer:
"""
Loads the last saved replay buffer from a file.
"""
filename = os.path.join(self.policy.model_path, "last_replay_buffer.hdf5")
LOGGER.info("Loading Experience Replay Buffer from {}".format(filename))
with open(filename, "rb+") as file_object:
self.training_buffer.update_buffer.load_from_file(file_object)
LOGGER.info(
"Experience replay buffer has {} experiences.".format(
len(self.training_buffer.update_buffer["actions"])
)
)
def add_policy_outputs(
self, take_action_outputs: ActionInfoOutputs, agent_id: str, agent_idx: int
) -> None:
"""
Takes the output of the last action and store it into the training buffer.
"""
actions = take_action_outputs["action"]
self.training_buffer[agent_id]["actions"].append(actions[agent_idx])
def add_rewards_outputs(
self,
rewards_out: AllRewardsOutput,
values: Dict[str, np.ndarray],
agent_id: str,
agent_idx: int,
agent_next_idx: int,
) -> None:
"""
Takes the value output of the last action and store it into the training buffer.
"""
self.training_buffer[agent_id]["environment_rewards"].append(
rewards_out.environment[agent_next_idx]
)
def process_experiences(
self, current_info: AllBrainInfo, new_info: AllBrainInfo
) -> None:
"""
Checks agent histories for processing condition, and processes them as necessary.
:param current_info: Dictionary of all current brains and corresponding BrainInfo.
:param new_info: Dictionary of all next brains and corresponding BrainInfo.
"""
info = new_info[self.brain_name]
for l in range(len(info.agents)):
agent_actions = self.training_buffer[info.agents[l]]["actions"]
if (
info.local_done[l]
or len(agent_actions) >= self.trainer_parameters["time_horizon"]
) and len(agent_actions) > 0:
agent_id = info.agents[l]
# Bootstrap using last brain info. Set last element to duplicate obs and remove dones.
if info.max_reached[l]:
bootstrapping_info = self.training_buffer[agent_id].last_brain_info
idx = bootstrapping_info.agents.index(agent_id)
for i, obs in enumerate(bootstrapping_info.visual_observations):
self.training_buffer[agent_id]["next_visual_obs%d" % i][
-1
] = obs[idx]
if self.policy.use_vec_obs:
self.training_buffer[agent_id]["next_vector_in"][
-1
] = bootstrapping_info.vector_observations[idx]
self.training_buffer[agent_id]["done"][-1] = False
self.training_buffer.append_update_buffer(
agent_id,
batch_size=None,
training_length=self.policy.sequence_length,
)
self.training_buffer[agent_id].reset_agent()
if info.local_done[l]:
self.stats["Environment/Episode Length"].append(
self.episode_steps.get(agent_id, 0)
)
self.episode_steps[agent_id] = 0
for name, rewards in self.collected_rewards.items():
if name == "environment":
self.cumulative_returns_since_policy_update.append(
rewards.get(agent_id, 0)
)
self.stats["Environment/Cumulative Reward"].append(
rewards.get(agent_id, 0)
)
self.reward_buffer.appendleft(rewards.get(agent_id, 0))
rewards[agent_id] = 0
else:
self.stats[
self.policy.reward_signals[name].stat_name
].append(rewards.get(agent_id, 0))
rewards[agent_id] = 0
def is_ready_update(self) -> bool:
"""
Returns whether or not the trainer has enough elements to run update model
:return: A boolean corresponding to whether or not update_model() can be run
"""
return (
len(self.training_buffer.update_buffer["actions"])
>= self.trainer_parameters["batch_size"]
and self.step >= self.trainer_parameters["buffer_init_steps"]
)
@timed
def update_policy(self) -> None:
"""
If train_interval is met, update the SAC policy given the current reward signals.
If reward_signal_train_interval is met, update the reward signals from the buffer.
"""
if self.step % self.train_interval == 0:
self.trainer_metrics.start_policy_update_timer(
number_experiences=len(self.training_buffer.update_buffer["actions"]),
mean_return=float(np.mean(self.cumulative_returns_since_policy_update)),
)
self.update_sac_policy()
self.update_reward_signals()
self.trainer_metrics.end_policy_update()
def update_sac_policy(self) -> None:
"""
Uses demonstration_buffer to update the policy.
The reward signal generators are updated using different mini batches.
If we want to imitate http://arxiv.org/abs/1809.02925 and similar papers, where the policy is updated
N times, then the reward signals are updated N times, then reward_signal_updates_per_train
is greater than 1 and the reward signals are not updated in parallel.
"""
self.cumulative_returns_since_policy_update: List[float] = []
n_sequences = max(
int(self.trainer_parameters["batch_size"] / self.policy.sequence_length), 1
)
num_updates = self.trainer_parameters["num_update"]
batch_update_stats: Dict[str, list] = defaultdict(list)
for _ in range(num_updates):
LOGGER.debug("Updating SAC policy at step {}".format(self.step))
buffer = self.training_buffer.update_buffer
if (
len(self.training_buffer.update_buffer["actions"])
>= self.trainer_parameters["batch_size"]
):
sampled_minibatch = buffer.sample_mini_batch(
self.trainer_parameters["batch_size"],
sequence_length=self.policy.sequence_length,
)
# Get rewards for each reward
for name, signal in self.policy.reward_signals.items():
sampled_minibatch[
"{}_rewards".format(name)
] = signal.evaluate_batch(sampled_minibatch).scaled_reward
update_stats = self.policy.update(
sampled_minibatch, n_sequences, update_target=True
)
for stat_name, value in update_stats.items():
batch_update_stats[stat_name].append(value)
# Truncate update buffer if neccessary. Truncate more than we need to to avoid truncating
# a large buffer at each update.
if (
len(self.training_buffer.update_buffer["actions"])
> self.trainer_parameters["buffer_size"]
):
self.training_buffer.truncate_update_buffer(
int(self.trainer_parameters["buffer_size"] * BUFFER_TRUNCATE_PERCENT)
)
for stat, stat_list in batch_update_stats.items():
self.stats[stat].append(np.mean(stat_list))
if self.policy.bc_module:
update_stats = self.policy.bc_module.update()
for stat, val in update_stats.items():
self.stats[stat].append(val)
def update_reward_signals(self) -> None:
"""
Iterate through the reward signals and update them. Unlike in PPO,
do it separate from the policy so that it can be done at a different
interval.
This function should only be used to simulate
http://arxiv.org/abs/1809.02925 and similar papers, where the policy is updated
N times, then the reward signals are updated N times. Normally, the reward signal
and policy are updated in parallel.
"""
buffer = self.training_buffer.update_buffer
num_updates = self.reward_signal_updates_per_train
n_sequences = max(
int(self.trainer_parameters["batch_size"] / self.policy.sequence_length), 1
)
batch_update_stats: Dict[str, list] = defaultdict(list)
for _ in range(num_updates):
# Get minibatches for reward signal update if needed
reward_signal_minibatches = {}
for name, signal in self.policy.reward_signals.items():
LOGGER.debug("Updating {} at step {}".format(name, self.step))
# Some signals don't need a minibatch to be sampled - so we don't!
if signal.update_dict:
reward_signal_minibatches[name] = buffer.sample_mini_batch(
self.trainer_parameters["batch_size"],
sequence_length=self.policy.sequence_length,
)
update_stats = self.policy.update_reward_signals(
reward_signal_minibatches, n_sequences
)
for stat_name, value in update_stats.items():
batch_update_stats[stat_name].append(value)
for stat, stat_list in batch_update_stats.items():
self.stats[stat].append(np.mean(stat_list))
正在加载...
取消
保存