Add Soft Actor-Critic as trainer option (#2341)

* Add Soft Actor-Critic model, trainer, and policy and sac_trainer_config.yaml
* Add documentation for SAC and tweak PPO documentation to reference the new pages.
* Add tests for SAC, change simple_rl test to run both PPO and SAC.

Committed by GitHub, 5 years ago. Current commit: 6a81a2f4

18 files changed, with 2987 insertions and 131 deletions.
Files changed (lines changed):

- README.md (1)
- docs/Getting-Started-with-Balance-Ball.md (49)
- docs/Readme.md (1)
- docs/Training-ML-Agents.md (47)
- docs/Training-PPO.md (9)
- ml-agents/mlagents/trainers/__init__.py (3)
- ml-agents/mlagents/trainers/tests/test_bcmodule.py (113)
- ml-agents/mlagents/trainers/tests/test_reward_signals.py (123)
- ml-agents/mlagents/trainers/tests/test_simple_rl.py (90)
- ml-agents/mlagents/trainers/trainer_util.py (13)
- config/sac_trainer_config.yaml (276)
- docs/Training-SAC.md (330)
- ml-agents/mlagents/trainers/tests/test_sac.py (399)
- ml-agents/mlagents/trainers/sac/__init__.py (3)
- ml-agents/mlagents/trainers/sac/models.py (1001)
- ml-agents/mlagents/trainers/sac/policy.py (320)
- ml-agents/mlagents/trainers/sac/trainer.py (340)
config/sac_trainer_config.yaml:

default:
    trainer: sac
    batch_size: 128
    buffer_size: 50000
    buffer_init_steps: 0
    hidden_units: 128
    init_entcoef: 1.0
    learning_rate: 3.0e-4
    max_steps: 5.0e4
    memory_size: 256
    normalize: false
    num_update: 1
    train_interval: 1
    num_layers: 2
    time_horizon: 64
    sequence_length: 64
    summary_freq: 1000
    tau: 0.005
    use_recurrent: false
    vis_encode_type: default
    reward_signals:
        extrinsic:
            strength: 1.0
            gamma: 0.99

BananaLearning:
    normalize: false
    batch_size: 256
    buffer_size: 500000
    max_steps: 1.0e5
    init_entcoef: 0.05
    train_interval: 1

VisualBananaLearning:
    beta: 1.0e-2
    gamma: 0.99
    num_epoch: 1
    max_steps: 5.0e5
    summary_freq: 1000

BouncerLearning:
    normalize: true
    beta: 0.0
    max_steps: 5.0e5
    num_layers: 2
    hidden_units: 64
    summary_freq: 1000

PushBlockLearning:
    max_steps: 5.0e4
    init_entcoef: 0.05
    beta: 1.0e-2
    hidden_units: 256
    summary_freq: 2000
    time_horizon: 64
    num_layers: 2

SmallWallJumpLearning:
    max_steps: 1.0e6
    hidden_units: 256
    summary_freq: 2000
    time_horizon: 128
    init_entcoef: 0.1
    num_layers: 2
    normalize: false

BigWallJumpLearning:
    max_steps: 1.0e6
    hidden_units: 256
    summary_freq: 2000
    time_horizon: 128
    num_layers: 2
    init_entcoef: 0.1
    normalize: false

StrikerLearning:
    max_steps: 5.0e5
    learning_rate: 1e-3
    beta: 1.0e-2
    hidden_units: 256
    summary_freq: 2000
    time_horizon: 128
    init_entcoef: 0.1
    num_layers: 2
    normalize: false

GoalieLearning:
    max_steps: 5.0e5
    learning_rate: 1e-3
    beta: 1.0e-2
    hidden_units: 256
    summary_freq: 2000
    time_horizon: 128
    init_entcoef: 0.1
    num_layers: 2
    normalize: false

PyramidsLearning:
    summary_freq: 2000
    time_horizon: 128
    batch_size: 128
    buffer_init_steps: 10000
    buffer_size: 500000
    hidden_units: 256
    num_layers: 2
    init_entcoef: 0.01
    max_steps: 5.0e5
    sequence_length: 16
    tau: 0.01
    use_recurrent: false
    reward_signals:
        extrinsic:
            strength: 2.0
            gamma: 0.99
        gail:
            strength: 0.02
            gamma: 0.99
            encoding_size: 128
            use_actions: true
            demo_path: demos/ExpertPyramid.demo

VisualPyramidsLearning:
    time_horizon: 128
    batch_size: 64
    hidden_units: 256
    buffer_init_steps: 1000
    num_layers: 1
    beta: 1.0e-2
    max_steps: 5.0e5
    buffer_size: 500000
    init_entcoef: 0.01
    tau: 0.01
    reward_signals:
        extrinsic:
            strength: 2.0
            gamma: 0.99
        gail:
            strength: 0.02
            gamma: 0.99
            encoding_size: 128
            use_actions: true
            demo_path: demos/ExpertPyramid.demo

3DBallLearning:
    normalize: true
    batch_size: 64
    buffer_size: 12000
    summary_freq: 1000
    time_horizon: 1000
    hidden_units: 64
    init_entcoef: 0.5
    max_steps: 5.0e5

3DBallHardLearning:
    normalize: true
    batch_size: 256
    summary_freq: 1000
    time_horizon: 1000
    max_steps: 5.0e5

TennisLearning:
    normalize: true
    max_steps: 2e5

CrawlerStaticLearning:
    normalize: true
    time_horizon: 1000
    batch_size: 256
    train_interval: 3
    buffer_size: 500000
    buffer_init_steps: 2000
    max_steps: 5e5
    summary_freq: 3000
    init_entcoef: 1.0
    num_layers: 3
    hidden_units: 512

CrawlerDynamicLearning:
    normalize: true
    time_horizon: 1000
    batch_size: 256
    buffer_size: 500000
    summary_freq: 3000
    train_interval: 3
    num_layers: 3
    max_steps: 5e5
    hidden_units: 512

WalkerLearning:
    normalize: true
    time_horizon: 1000
    batch_size: 256
    buffer_size: 500000
    max_steps: 2e6
    summary_freq: 3000
    num_layers: 3
    train_interval: 3
    hidden_units: 512
    reward_signals:
        extrinsic:
            strength: 1.0
            gamma: 0.995

ReacherLearning:
    normalize: true
    time_horizon: 1000
    batch_size: 128
    buffer_size: 500000
    max_steps: 1e6
    summary_freq: 3000

HallwayLearning:
    use_recurrent: true
    sequence_length: 32
    num_layers: 2
    hidden_units: 128
    memory_size: 256
    beta: 0.0
    init_entcoef: 0.1
    max_steps: 5.0e5
    summary_freq: 1000
    time_horizon: 64

VisualHallwayLearning:
    use_recurrent: true
    sequence_length: 32
    num_layers: 1
    hidden_units: 128
    memory_size: 256
    beta: 1.0e-2
    gamma: 0.99
    batch_size: 64
    max_steps: 5.0e5
    summary_freq: 1000
    time_horizon: 64

VisualPushBlockLearning:
    use_recurrent: true
    sequence_length: 32
    num_layers: 1
    hidden_units: 128
    memory_size: 256
    beta: 1.0e-2
    gamma: 0.99
    buffer_size: 1024
    batch_size: 64
    max_steps: 5.0e5
    summary_freq: 1000
    time_horizon: 64

GridWorldLearning:
    batch_size: 128
    normalize: false
    num_layers: 1
    hidden_units: 128
    init_entcoef: 0.01
    buffer_size: 50000
    max_steps: 5.0e5
    summary_freq: 2000
    time_horizon: 5
    reward_signals:
        extrinsic:
            strength: 1.0
            gamma: 0.9

BasicLearning:
    batch_size: 64
    normalize: false
    num_layers: 2
    init_entcoef: 0.01
    hidden_units: 20
    max_steps: 5.0e5
    summary_freq: 2000
    time_horizon: 10
docs/Training-SAC.md:

# Training with Soft Actor-Critic

In addition to [Proximal Policy Optimization (PPO)](Training-PPO.md), ML-Agents also provides
[Soft Actor-Critic (SAC)](http://bair.berkeley.edu/blog/2018/12/14/sac/) to perform
reinforcement learning.

In contrast with PPO, SAC is _off-policy_, which means it can learn from experiences collected
at any time in the past. As experiences are collected, they are placed in an
experience replay buffer and drawn at random during training. This makes SAC
significantly more sample-efficient, often requiring 5-10 times fewer samples to learn
the same task as PPO. However, SAC tends to require more model updates. SAC is a
good choice for heavier or slower environments (about 0.1 seconds per step or more).
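To make the contrast with PPO concrete, here is a minimal, framework-free sketch of an experience replay buffer: transitions accumulate over the whole run, and each update draws a uniformly random minibatch, old and new alike. The buffer and helper functions below are illustrative stand-ins, not the ML-Agents `Buffer` class.

```python
import random
from collections import deque

# Replay buffer: old and new experiences live side by side, and training
# batches are drawn uniformly at random (unlike PPO's on-policy batches).
replay_buffer = deque(maxlen=50000)   # cf. buffer_size

def store_transition(obs, action, reward, next_obs, done):
    replay_buffer.append((obs, action, reward, next_obs, done))

def sample_minibatch(batch_size=128):
    # Any stored transition may be drawn, no matter how old it is.
    return random.sample(replay_buffer, batch_size)

# Fill the buffer with fake transitions and draw one training batch.
for step in range(1000):
    store_transition(obs=[step], action=0, reward=1.0, next_obs=[step + 1], done=False)
batch = sample_minibatch(batch_size=128)
print(len(batch), "transitions sampled for one update")
```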
SAC is also a "maximum entropy" algorithm, enabling exploration in an intrinsic way.
Read more about maximum entropy RL [here](https://bair.berkeley.edu/blog/2017/10/06/soft-q-learning/).

To train an agent, you will need to provide the agent with one or more reward signals which
the agent should attempt to maximize. See [Reward Signals](Training-RewardSignals.md)
for the available reward signals and the corresponding hyperparameters.

## Best Practices when training with SAC

Successfully training a reinforcement learning model often involves tuning
hyperparameters. This guide contains some best practices for training
when the default parameters don't seem to be giving the level of performance
you would like.

## Hyperparameters

### Reward Signals

In reinforcement learning, the goal is to learn a Policy that maximizes reward.
In the most basic case, the reward is given by the environment. However, we could imagine
rewarding the agent for various other behaviors. For instance, we could reward
the agent for exploring new states rather than relying only on an explicitly defined reward signal.
Furthermore, we could mix reward signals to help the learning process.

`reward_signals` provides a section to define [reward signals](Training-RewardSignals.md).
ML-Agents provides two reward signals by default: the Extrinsic (environment) reward and the
Curiosity reward, which can be used to encourage exploration in sparse extrinsic reward
environments.
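As a rough picture of how mixed reward signals enter training, each signal contributes a reward stream scaled by its `strength`. The toy sketch below only illustrates that scaling; the actual implementation (visible in the SAC trainer code later in this change) evaluates each signal per minibatch and keeps a separate value head and `gamma` per signal.

```python
# Toy illustration: each reward signal contributes strength * raw_reward.
# (In ML-Agents each signal also gets its own value head and gamma.)
reward_signal_config = {
    "extrinsic": {"strength": 1.0},
    "curiosity": {"strength": 0.02},
}

def combined_reward(raw_rewards):
    """raw_rewards: dict mapping signal name -> reward produced by that signal."""
    return sum(
        reward_signal_config[name]["strength"] * value
        for name, value in raw_rewards.items()
    )

print(combined_reward({"extrinsic": 1.0, "curiosity": 0.3}))  # 1.0*1.0 + 0.02*0.3
```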
#### Number of Updates for Reward Signal (Optional)

`reward_signal_num_update` for the reward signals corresponds to the number of mini batches sampled
and used for updating the reward signals during each
update. By default, we update the reward signals once every time the main policy is updated.
However, to imitate the training procedure in certain imitation learning papers (e.g.
[Kostrikov et. al](http://arxiv.org/abs/1809.02925), [Blondé et. al](http://arxiv.org/abs/1809.02064)),
we may want to update the policy N times, then update the reward signal (GAIL) M times.
We can change `train_interval` and `num_update` of SAC to N, as well as `reward_signal_num_update`
under `reward_signals` to M to accomplish this. By default, `reward_signal_num_update` is set to
`num_update`.

Typical Range: `num_update`
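The N-policy-updates / M-reward-signal-updates schedule described above can be pictured as two loops per training event. The sketch below is schematic only; `training_event` and the strings it logs are hypothetical placeholders for the real update calls.

```python
def training_event(num_update, reward_signal_num_update):
    """One SAC training event: N policy updates, then M reward-signal updates."""
    log = []
    for _ in range(num_update):                 # N, set via num_update
        log.append("policy update")             # sample a minibatch, update actor/critics
    for _ in range(reward_signal_num_update):   # M, set via reward_signal_num_update
        log.append("reward signal update")      # e.g. a GAIL discriminator step
    return log

# Default behavior: reward_signal_num_update == num_update (updated in lockstep).
print(training_event(num_update=1, reward_signal_num_update=1))
# Imitating Kostrikov et al.-style schedules: N policy updates, then M GAIL updates.
print(training_event(num_update=4, reward_signal_num_update=2))
```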
### Buffer Size

`buffer_size` corresponds to the maximum number of experiences (agent observations, actions
and rewards obtained) that can be stored in the experience replay buffer. This value should be
large, on the order of thousands of times longer than your episodes, so that SAC
can learn from old as well as new experiences. It should also be much larger than
`batch_size`.

Typical Range: `50000` - `1000000`

### Buffer Init Steps

`buffer_init_steps` is the number of experiences to prefill the buffer with before attempting training.
As the untrained policy is fairly random, prefilling the buffer with random actions is
useful for exploration. Typically, at least several episodes of experiences should be
prefilled.

Typical Range: `1000` - `10000`

### Batch Size

`batch_size` is the number of experiences used for one iteration of a gradient
descent update. If you are using a continuous action space, this value should be large (on the
order of 1000s). If you are using a discrete action space, this value should be
smaller (on the order of 10s).

Typical Range (Continuous): `128` - `1024`

Typical Range (Discrete): `32` - `512`

### Initial Entropy Coefficient

`init_entcoef` refers to the initial entropy coefficient set at the beginning of training. In
SAC, the agent is incentivized to make its actions entropic to facilitate better exploration.
The entropy coefficient weighs the true reward against a bonus entropy reward. The entropy
coefficient is [automatically adjusted](https://arxiv.org/abs/1812.05905) to a preset target
entropy, so `init_entcoef` only corresponds to the starting value of the entropy bonus.
Increase `init_entcoef` to explore more in the beginning; decrease it to converge to a solution faster.

Typical Range (Continuous): `0.5` - `1.0`

Typical Range (Discrete): `0.05` - `0.5`
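For intuition, the sketch below shows the automatic entropy-coefficient adjustment of Haarnoja et al. (2018) as a single plain-NumPy gradient step: the coefficient rises when the policy's entropy drops below the target and falls when it is above, with `init_entcoef` only choosing the starting point. This is a simplified stand-in for the TensorFlow implementation in `sac/models.py`.

```python
import numpy as np

def entropy_coef_step(log_alpha, batch_log_probs, target_entropy, lr=3e-4):
    """One gradient-descent step on J(alpha) = mean(-alpha * (log_pi + target_entropy)).
    If the policy's entropy falls below the target, alpha grows (bigger entropy bonus);
    if entropy is above the target, alpha shrinks."""
    alpha = np.exp(log_alpha)
    # dJ/d(log_alpha) = -alpha * mean(log_pi + target_entropy)
    grad = -alpha * np.mean(batch_log_probs + target_entropy)
    return log_alpha - lr * grad

target_entropy = -2.0                 # continuous control: roughly -dim(action_space)
log_alpha = np.log(1.0)               # init_entcoef only sets this starting value
low_entropy_log_probs = np.array([2.5, 2.0, 2.2])   # near-deterministic policy
log_alpha = entropy_coef_step(log_alpha, low_entropy_log_probs, target_entropy)
print(np.exp(log_alpha))              # nudged above 1.0 to push the agent to explore
```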
### Train Interval

`train_interval` is the number of steps taken between each agent training event. Typically,
we can train after every step, but if your environment's steps are very small and very frequent,
there may not be any new interesting information between steps, and `train_interval` can be increased.

Typical Range: `1` - `5`

### Number of Updates

`num_update` corresponds to the number of mini batches sampled and used for training during each
training event. In SAC, a single "update" corresponds to grabbing a batch of size `batch_size` from the experience
replay buffer, and using this mini batch to update the models. Typically, this can be left at 1.
However, to imitate the training procedure in certain papers (e.g.
[Kostrikov et. al](http://arxiv.org/abs/1809.02925), [Blondé et. al](http://arxiv.org/abs/1809.02064)),
we may want to update N times with different mini batches before grabbing additional samples.
We can change `train_interval` and `num_update` to N to accomplish this.

Typical Range: `1`
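Putting `buffer_init_steps`, `train_interval`, `num_update`, and `batch_size` together, a run follows roughly the schedule sketched below. The function is a counting toy that only illustrates how the scheduling hyperparameters interact; it is not the ML-Agents trainer loop.

```python
def sac_schedule(total_steps, buffer_init_steps, train_interval, num_update):
    """Count how many gradient updates a run performs, to show how the
    scheduling hyperparameters interact."""
    gradient_updates = 0
    for step in range(1, total_steps + 1):
        # One transition is collected per environment step.
        too_early = step < buffer_init_steps          # still prefilling with random actions
        if step % train_interval == 0 and not too_early:
            gradient_updates += num_update            # num_update minibatches of batch_size each
    return gradient_updates

# With defaults (train after every step, one update per event):
print(sac_schedule(total_steps=10000, buffer_init_steps=0, train_interval=1, num_update=1))
# Heavier updates, less frequent training events:
print(sac_schedule(total_steps=10000, buffer_init_steps=1000, train_interval=5, num_update=3))
```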
### Tau

`tau` corresponds to the magnitude of the target Q update during the SAC model update.
In SAC, there are two neural networks: the target and the policy. The target network is
used to bootstrap the policy's estimate of the future rewards at a given state, and is fixed
while the policy is being updated. This target is then slowly updated according to `tau`.
Typically, this value should be left at `0.005`. For simple problems, increasing
`tau` to `0.01` might reduce the time it takes to learn, at the cost of stability.

Typical Range: `0.005` - `0.01`
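Concretely, `tau` performs a Polyak (soft) update of the target network toward the current network after each update, as in this NumPy sketch (simplified relative to the TensorFlow ops in `sac/models.py`):

```python
import numpy as np

def soft_update(target_params, online_params, tau=0.005):
    """Polyak averaging: target <- tau * online + (1 - tau) * target."""
    return [(1.0 - tau) * t + tau * o for t, o in zip(target_params, online_params)]

target = [np.zeros(3)]
online = [np.ones(3)]
for _ in range(100):
    target = soft_update(target, online, tau=0.005)
print(target[0])  # creeps toward the online weights; a larger tau tracks them faster
```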
### Learning Rate

`learning_rate` corresponds to the strength of each gradient descent update
step. This should typically be decreased if training is unstable, and the reward
does not consistently increase.

Typical Range: `1e-5` - `1e-3`

### Time Horizon

`time_horizon` corresponds to how many steps of experience to collect per-agent
before adding it to the experience buffer. This parameter is far less critical
to SAC than to PPO, and can typically be set to approximately your episode length.

Typical Range: `32` - `2048`

### Max Steps

`max_steps` corresponds to how many steps of the simulation (multiplied by
frame-skip) are run during the training process. This value should be increased
for more complex problems.

Typical Range: `5e5` - `1e7`

### Normalize

`normalize` corresponds to whether normalization is applied to the vector
observation inputs. This normalization is based on the running average and
variance of the vector observation. Normalization can be helpful in cases with
complex continuous control problems, but may be harmful with simpler discrete
control problems.
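The normalization referred to here maintains a running mean and variance of the vector observations and whitens incoming observations with them. The sketch below is a generic online (parallel-variance) estimator to illustrate the idea; the actual ML-Agents implementation lives in the model code and may differ in detail.

```python
import numpy as np

class RunningNormalizer:
    """Track a running mean and variance of vector observations and
    use them to whiten new observations."""

    def __init__(self, size):
        self.mean = np.zeros(size)
        self.var = np.ones(size)
        self.count = 1e-4

    def update(self, obs_batch):
        batch_mean = obs_batch.mean(axis=0)
        batch_var = obs_batch.var(axis=0)
        batch_count = obs_batch.shape[0]
        delta = batch_mean - self.mean
        total = self.count + batch_count
        self.mean += delta * batch_count / total
        # Combine the two groups' variances (parallel-variance formula).
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        self.var = (m_a + m_b + delta ** 2 * self.count * batch_count / total) / total
        self.count = total

    def normalize(self, obs):
        return (obs - self.mean) / np.sqrt(self.var + 1e-8)

norm = RunningNormalizer(size=3)
norm.update(np.random.randn(64, 3) * 10 + 5)       # observations with large scale/offset
print(norm.normalize(np.array([5.0, 5.0, 5.0])))   # roughly zero-centered afterwards
```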
### Number of Layers

`num_layers` corresponds to how many hidden layers are present after the
observation input, or after the CNN encoding of the visual observation. For
simple problems, fewer layers are likely to train faster and more efficiently.
More layers may be necessary for more complex control problems.

Typical range: `1` - `3`

### Hidden Units

`hidden_units` corresponds to how many units are in each fully connected layer of
the neural network. For simple problems where the correct action is a
straightforward combination of the observation inputs, this should be small. For
problems where the action is a very complex interaction between the observation
variables, this should be larger.

Typical Range: `32` - `512`

### (Optional) Visual Encoder Type

`vis_encode_type` corresponds to the encoder type for encoding visual observations.
Valid options include:
* `simple` (default): a simple encoder which consists of two convolutional layers
* `nature_cnn`: [CNN implementation proposed by Mnih et al.](https://www.nature.com/articles/nature14236),
  consisting of three convolutional layers
* `resnet`: [IMPALA ResNet implementation](https://arxiv.org/abs/1802.01561),
  consisting of three stacked layers, each with two residual blocks, making a
  much larger network than the other two.

Options: `simple`, `nature_cnn`, `resnet`

## (Optional) Recurrent Neural Network Hyperparameters

The below hyperparameters are only used when `use_recurrent` is set to true.

### Sequence Length

`sequence_length` corresponds to the length of the sequences of experience
passed through the network during training. This should be long enough to
capture whatever information your agent might need to remember over time. For
example, if your agent needs to remember the velocity of objects, then this can
be a small value. If your agent needs to remember a piece of information given
only once at the beginning of an episode, then this should be a larger value.

Typical Range: `4` - `128`

### Memory Size

`memory_size` corresponds to the size of the array of floating point numbers
used to store the hidden state of the recurrent neural network. This value must
be a multiple of 4, and should scale with the amount of information you expect
the agent will need to remember in order to successfully complete the task.

Typical Range: `64` - `512`

### (Optional) Save Replay Buffer

`save_replay_buffer` enables you to save and load the experience replay buffer as well as
the model when quitting and re-starting training. This may help resumes go more smoothly,
as the experiences collected won't be wiped. Note that replay buffers can be very large, and
will take up a considerable amount of disk space. For that reason, we disable this feature by
default.

Default: `False`

## (Optional) Pretraining Using Demonstrations

In some cases, you might want to bootstrap the agent's policy using behavior recorded
from a player. This can help guide the agent towards the reward. Pretraining adds
training operations that mimic a demonstration rather than attempting to maximize reward.
It is essentially equivalent to running [behavioral cloning](./Training-BehavioralCloning.md)
in-line with SAC.

To use pretraining, add a `pretraining` section to the trainer_config. For instance:

```
pretraining:
    demo_path: ./demos/ExpertPyramid.demo
    strength: 0.5
    steps: 10000
```

Below are the available hyperparameters for pretraining.

### Strength

`strength` corresponds to the learning rate of the imitation relative to the learning
rate of SAC, and roughly corresponds to how strongly we allow the behavioral cloning
to influence the policy.

Typical Range: `0.1` - `0.5`

### Demo Path

`demo_path` is the path to your `.demo` file or directory of `.demo` files.
See the [imitation learning guide](Training-Imitation-Learning.md) for more on `.demo` files.

### Steps

During pretraining, it is often desirable to stop using demonstrations after the agent has
"seen" rewards, and allow it to optimize past the available demonstrations and/or generalize
outside of the provided demonstrations. `steps` corresponds to the training steps over which
pretraining is active. The learning rate of the pretrainer will anneal over the steps. Set
the steps to 0 for constant imitation over the entire training run.

### (Optional) Batch Size

`batch_size` is the number of demonstration experiences used for one iteration of a gradient
descent update. If not specified, it will default to the `batch_size` defined for SAC.

Typical Range (Continuous): `512` - `5120`

Typical Range (Discrete): `32` - `512`

## Training Statistics

To view training statistics, use TensorBoard. For information on launching and
using TensorBoard, see
[here](./Getting-Started-with-Balance-Ball.md#observing-training-progress).

### Cumulative Reward

The general trend in reward should consistently increase over time. Small ups
and downs are to be expected. Depending on the complexity of the task, a
significant increase in reward may not present itself until millions of steps
into the training process.

### Entropy Coefficient

SAC is a "maximum entropy" reinforcement learning algorithm, and agents trained using
SAC are incentivized to behave randomly while also solving the problem. The entropy
coefficient balances the incentive to behave randomly against maximizing the reward.
This value is adjusted automatically so that the agent retains some amount of randomness during
training. It should steadily decrease in the beginning of training, and reach some small
value where it will level off. If it decreases too soon or takes too
long to decrease, `init_entcoef` should be adjusted.

### Entropy

This corresponds to how random the decisions of a Brain are. This should
initially increase during training, reach a peak, and should decline along
with the Entropy Coefficient. This is because in the beginning, the agent is
incentivized to be more random for exploration due to a high entropy coefficient.
If it decreases too soon or takes too long to decrease, `init_entcoef` should be adjusted.

### Learning Rate

This will decrease over time on a linear schedule.

### Policy Loss

These values may increase as the agent explores, but should decrease long-term
as the agent learns how to solve the task.

### Value Estimate

These values should increase as the cumulative reward increases. They correspond
to how much future reward the agent predicts itself receiving at any given
point. They may also increase at the beginning as the agent is rewarded for
being random (see: Entropy and Entropy Coefficient), but should decline as the
Entropy Coefficient decreases.

### Value Loss

These values will increase as the reward increases, and then should decrease
once reward becomes stable.
ml-agents/mlagents/trainers/tests/test_sac.py:

import unittest.mock as mock
import pytest
import tempfile
import yaml
import math

import numpy as np
import tensorflow as tf

from mlagents.trainers.sac.models import SACModel
from mlagents.trainers.sac.policy import SACPolicy
from mlagents.trainers.sac.trainer import SACTrainer
from mlagents.trainers.tests.test_simple_rl import Simple1DEnvironment, SimpleEnvManager
from mlagents.trainers.trainer_util import initialize_trainers
from mlagents.envs import UnityEnvironment
from mlagents.envs.mock_communicator import MockCommunicator
from mlagents.trainers.trainer_controller import TrainerController
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs import BrainInfo, AllBrainInfo, BrainParameters
from mlagents.envs.communicator_objects import AgentInfoProto
from mlagents.envs.sampler_class import SamplerManager
from mlagents.trainers.tests import mock_brain as mb


@pytest.fixture
def dummy_config():
    return yaml.load(
        """
        trainer: sac
        batch_size: 32
        buffer_size: 10240
        buffer_init_steps: 0
        hidden_units: 32
        init_entcoef: 0.1
        learning_rate: 3.0e-4
        max_steps: 1024
        memory_size: 8
        normalize: false
        num_update: 1
        train_interval: 1
        num_layers: 1
        time_horizon: 64
        sequence_length: 16
        summary_freq: 1000
        tau: 0.005
        use_recurrent: false
        curiosity_enc_size: 128
        demo_path: None
        vis_encode_type: default
        reward_signals:
            extrinsic:
                strength: 1.0
                gamma: 0.99
        """
    )


VECTOR_ACTION_SPACE = [2]
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 32
NUM_AGENTS = 12


def create_sac_policy_mock(mock_env, dummy_config, use_rnn, use_discrete, use_visual):
    env, mock_brain, _ = mb.setup_mock_env_and_brains(
        mock_env,
        use_discrete,
        use_visual,
        num_agents=NUM_AGENTS,
        vector_action_space=VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
        discrete_action_space=DISCRETE_ACTION_SPACE,
    )

    trainer_parameters = dummy_config
    model_path = env.brain_names[0]
    trainer_parameters["model_path"] = model_path
    trainer_parameters["keep_checkpoints"] = 3
    trainer_parameters["use_recurrent"] = use_rnn
    policy = SACPolicy(0, mock_brain, trainer_parameters, False, False)
    return env, policy


@mock.patch("mlagents.envs.UnityEnvironment")
def test_sac_cc_policy(mock_env, dummy_config):
    # Test evaluate
    tf.reset_default_graph()
    env, policy = create_sac_policy_mock(
        mock_env, dummy_config, use_rnn=False, use_discrete=False, use_visual=False
    )
    brain_infos = env.reset()
    brain_info = brain_infos[env.brain_names[0]]
    run_out = policy.evaluate(brain_info)
    assert run_out["action"].shape == (NUM_AGENTS, VECTOR_ACTION_SPACE[0])

    # Test update
    buffer = mb.simulate_rollout(env, policy, BUFFER_INIT_SAMPLES)
    # Mock out reward signal eval
    buffer.update_buffer["extrinsic_rewards"] = buffer.update_buffer["rewards"]
    policy.update(
        buffer.update_buffer, num_sequences=len(buffer.update_buffer["actions"])
    )
    env.close()


@mock.patch("mlagents.envs.UnityEnvironment")
def test_sac_update_reward_signals(mock_env, dummy_config):
    # Test evaluate
    tf.reset_default_graph()
    # Add a Curiosity module
    dummy_config["reward_signals"]["curiosity"] = {}
    dummy_config["reward_signals"]["curiosity"]["strength"] = 1.0
    dummy_config["reward_signals"]["curiosity"]["gamma"] = 0.99
    dummy_config["reward_signals"]["curiosity"]["encoding_size"] = 128
    env, policy = create_sac_policy_mock(
        mock_env, dummy_config, use_rnn=False, use_discrete=False, use_visual=False
    )

    # Test update
    buffer = mb.simulate_rollout(env, policy, BUFFER_INIT_SAMPLES)
    # Mock out reward signal eval
    buffer.update_buffer["extrinsic_rewards"] = buffer.update_buffer["rewards"]
    buffer.update_buffer["curiosity_rewards"] = buffer.update_buffer["rewards"]
    policy.update_reward_signals(
        {"curiosity": buffer.update_buffer},
        num_sequences=len(buffer.update_buffer["actions"]),
    )
    env.close()


@mock.patch("mlagents.envs.UnityEnvironment")
def test_sac_dc_policy(mock_env, dummy_config):
    # Test evaluate
    tf.reset_default_graph()
    env, policy = create_sac_policy_mock(
        mock_env, dummy_config, use_rnn=False, use_discrete=True, use_visual=False
    )
    brain_infos = env.reset()
    brain_info = brain_infos[env.brain_names[0]]
    run_out = policy.evaluate(brain_info)
    assert run_out["action"].shape == (NUM_AGENTS, len(DISCRETE_ACTION_SPACE))

    # Test update
    buffer = mb.simulate_rollout(env, policy, BUFFER_INIT_SAMPLES)
    # Mock out reward signal eval
    buffer.update_buffer["extrinsic_rewards"] = buffer.update_buffer["rewards"]
    policy.update(
        buffer.update_buffer, num_sequences=len(buffer.update_buffer["actions"])
    )
    env.close()


@mock.patch("mlagents.envs.UnityEnvironment")
def test_sac_visual_policy(mock_env, dummy_config):
    # Test evaluate
    tf.reset_default_graph()
    env, policy = create_sac_policy_mock(
        mock_env, dummy_config, use_rnn=False, use_discrete=True, use_visual=True
    )
    brain_infos = env.reset()
    brain_info = brain_infos[env.brain_names[0]]
    run_out = policy.evaluate(brain_info)
    assert run_out["action"].shape == (NUM_AGENTS, len(DISCRETE_ACTION_SPACE))

    # Test update
    buffer = mb.simulate_rollout(env, policy, BUFFER_INIT_SAMPLES)
    # Mock out reward signal eval
    buffer.update_buffer["extrinsic_rewards"] = buffer.update_buffer["rewards"]
    run_out = policy.update(
        buffer.update_buffer, num_sequences=len(buffer.update_buffer["actions"])
    )
    assert type(run_out) is dict


@mock.patch("mlagents.envs.UnityEnvironment")
def test_sac_rnn_policy(mock_env, dummy_config):
    # Test evaluate
    tf.reset_default_graph()
    env, policy = create_sac_policy_mock(
        mock_env, dummy_config, use_rnn=True, use_discrete=True, use_visual=False
    )
    brain_infos = env.reset()
    brain_info = brain_infos[env.brain_names[0]]
    run_out = policy.evaluate(brain_info)
    assert run_out["action"].shape == (NUM_AGENTS, len(DISCRETE_ACTION_SPACE))

    # Test update
    buffer = mb.simulate_rollout(env, policy, BUFFER_INIT_SAMPLES)
    # Mock out reward signal eval
    buffer.update_buffer["extrinsic_rewards"] = buffer.update_buffer["rewards"]
    policy.update(buffer.update_buffer, num_sequences=2)
    env.close()


@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_sac_model_cc_vector(mock_communicator, mock_launcher):
    tf.reset_default_graph()
    with tf.Session() as sess:
        with tf.variable_scope("FakeGraphScope"):
            mock_communicator.return_value = MockCommunicator(
                discrete_action=False, visual_inputs=0
            )
            env = UnityEnvironment(" ")

            model = SACModel(env.brains["RealFakeBrain"])
            init = tf.global_variables_initializer()
            sess.run(init)

            run_list = [model.output, model.value, model.entropy, model.learning_rate]
            feed_dict = {
                model.batch_size: 2,
                model.sequence_length: 1,
                model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
            }
            sess.run(run_list, feed_dict=feed_dict)
            env.close()


@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_sac_model_cc_visual(mock_communicator, mock_launcher):
    tf.reset_default_graph()
    with tf.Session() as sess:
        with tf.variable_scope("FakeGraphScope"):
            mock_communicator.return_value = MockCommunicator(
                discrete_action=False, visual_inputs=2
            )
            env = UnityEnvironment(" ")

            model = SACModel(env.brains["RealFakeBrain"])
            init = tf.global_variables_initializer()
            sess.run(init)

            run_list = [model.output, model.value, model.entropy, model.learning_rate]
            feed_dict = {
                model.batch_size: 2,
                model.sequence_length: 1,
                model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
                model.visual_in[0]: np.ones([2, 40, 30, 3]),
                model.visual_in[1]: np.ones([2, 40, 30, 3]),
            }
            sess.run(run_list, feed_dict=feed_dict)
            env.close()


@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_sac_model_dc_visual(mock_communicator, mock_launcher):
    tf.reset_default_graph()
    with tf.Session() as sess:
        with tf.variable_scope("FakeGraphScope"):
            mock_communicator.return_value = MockCommunicator(
                discrete_action=True, visual_inputs=2
            )
            env = UnityEnvironment(" ")
            model = SACModel(env.brains["RealFakeBrain"])
            init = tf.global_variables_initializer()
            sess.run(init)

            run_list = [model.output, model.value, model.entropy, model.learning_rate]
            feed_dict = {
                model.batch_size: 2,
                model.sequence_length: 1,
                model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
                model.visual_in[0]: np.ones([2, 40, 30, 3]),
                model.visual_in[1]: np.ones([2, 40, 30, 3]),
                model.action_masks: np.ones([2, 2]),
            }
            sess.run(run_list, feed_dict=feed_dict)
            env.close()


@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_sac_model_dc_vector(mock_communicator, mock_launcher):
    tf.reset_default_graph()
    with tf.Session() as sess:
        with tf.variable_scope("FakeGraphScope"):
            mock_communicator.return_value = MockCommunicator(
                discrete_action=True, visual_inputs=0
            )
            env = UnityEnvironment(" ")
            model = SACModel(env.brains["RealFakeBrain"])
            init = tf.global_variables_initializer()
            sess.run(init)

            run_list = [model.output, model.value, model.entropy, model.learning_rate]
            feed_dict = {
                model.batch_size: 2,
                model.sequence_length: 1,
                model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
                model.action_masks: np.ones([2, 2]),
            }
            sess.run(run_list, feed_dict=feed_dict)
            env.close()


@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_sac_model_dc_vector_rnn(mock_communicator, mock_launcher):
    tf.reset_default_graph()
    with tf.Session() as sess:
        with tf.variable_scope("FakeGraphScope"):
            mock_communicator.return_value = MockCommunicator(
                discrete_action=True, visual_inputs=0
            )
            env = UnityEnvironment(" ")
            memory_size = 128
            model = SACModel(
                env.brains["RealFakeBrain"], use_recurrent=True, m_size=memory_size
            )
            init = tf.global_variables_initializer()
            sess.run(init)

            run_list = [
                model.output,
                model.all_log_probs,
                model.value,
                model.entropy,
                model.learning_rate,
                model.memory_out,
            ]
            feed_dict = {
                model.batch_size: 1,
                model.sequence_length: 2,
                model.prev_action: [[0], [0]],
                model.memory_in: np.zeros((1, memory_size)),
                model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
                model.action_masks: np.ones([1, 2]),
            }
            sess.run(run_list, feed_dict=feed_dict)
            env.close()


@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")
@mock.patch("mlagents.envs.UnityEnvironment.get_communicator")
def test_sac_model_cc_vector_rnn(mock_communicator, mock_launcher):
    tf.reset_default_graph()
    with tf.Session() as sess:
        with tf.variable_scope("FakeGraphScope"):
            mock_communicator.return_value = MockCommunicator(
                discrete_action=False, visual_inputs=0
            )
            env = UnityEnvironment(" ")
            memory_size = 128
            model = SACModel(
                env.brains["RealFakeBrain"], use_recurrent=True, m_size=memory_size
            )
            init = tf.global_variables_initializer()
            sess.run(init)

            run_list = [
                model.output,
                model.all_log_probs,
                model.value,
                model.entropy,
                model.learning_rate,
                model.memory_out,
            ]
            feed_dict = {
                model.batch_size: 1,
                model.sequence_length: 2,
                model.memory_in: np.zeros((1, memory_size)),
                model.vector_in: np.array([[1, 2, 3, 1, 2, 3], [3, 4, 5, 3, 4, 5]]),
            }
            sess.run(run_list, feed_dict=feed_dict)
            env.close()


def test_sac_save_load_buffer(tmpdir):
    env, mock_brain, _ = mb.setup_mock_env_and_brains(
        mock.Mock(),
        False,
        False,
        num_agents=NUM_AGENTS,
        vector_action_space=VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
        discrete_action_space=DISCRETE_ACTION_SPACE,
    )
    trainer_params = dummy_config()
    trainer_params["summary_path"] = str(tmpdir)
    trainer_params["model_path"] = str(tmpdir)
    trainer_params["save_replay_buffer"] = True
    trainer = SACTrainer(mock_brain, 1, trainer_params, True, False, 0, 0)
    trainer.training_buffer = mb.simulate_rollout(
        env, trainer.policy, BUFFER_INIT_SAMPLES
    )
    buffer_len = len(trainer.training_buffer.update_buffer["actions"])
    trainer.save_model()

    # Wipe Trainer and try to load
    trainer2 = SACTrainer(mock_brain, 1, trainer_params, True, True, 0, 0)
    assert len(trainer2.training_buffer.update_buffer["actions"]) == buffer_len


if __name__ == "__main__":
    pytest.main()
ml-agents/mlagents/trainers/sac/__init__.py:

from .models import *
from .trainer import *
from .policy import *
ml-agents/mlagents/trainers/sac/models.py (1001 lines): file diff too large to display.
ml-agents/mlagents/trainers/sac/policy.py:

import logging
from typing import Dict, List, Any
import numpy as np
import tensorflow as tf

from mlagents.envs.timers import timed
from mlagents.trainers import BrainInfo, ActionInfo, BrainParameters
from mlagents.trainers.sac.models import SACModel
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.components.reward_signals.reward_signal_factory import (
    create_reward_signal,
)
from mlagents.trainers.components.reward_signals.reward_signal import RewardSignal
from mlagents.trainers.components.bc import BCModule

logger = logging.getLogger("mlagents.trainers")


class SACPolicy(TFPolicy):
    def __init__(
        self,
        seed: int,
        brain: BrainParameters,
        trainer_params: Dict[str, Any],
        is_training: bool,
        load: bool,
    ) -> None:
        """
        Policy for Soft Actor-Critic networks.
        :param seed: Random seed.
        :param brain: Assigned Brain object.
        :param trainer_params: Defined training parameters.
        :param is_training: Whether the model should be trained.
        :param load: Whether a pre-trained model will be loaded or a new one created.
        """
        super().__init__(seed, brain, trainer_params)

        reward_signal_configs = {}
        for key, rsignal in trainer_params["reward_signals"].items():
            if type(rsignal) is dict:
                reward_signal_configs[key] = rsignal

        self.inference_dict: Dict[str, tf.Tensor] = {}
        self.update_dict: Dict[str, tf.Tensor] = {}
        self.create_model(
            brain, trainer_params, reward_signal_configs, is_training, load, seed
        )
        self.create_reward_signals(reward_signal_configs)

        self.stats_name_to_update_name = {
            "Losses/Value Loss": "value_loss",
            "Losses/Policy Loss": "policy_loss",
            "Losses/Q1 Loss": "q1_loss",
            "Losses/Q2 Loss": "q2_loss",
            "Policy/Entropy Coeff": "entropy_coef",
        }

        with self.graph.as_default():
            # Create pretrainer if needed
            if "pretraining" in trainer_params:
                BCModule.check_config(trainer_params["pretraining"])
                self.bc_module = BCModule(
                    self,
                    policy_learning_rate=trainer_params["learning_rate"],
                    default_batch_size=trainer_params["batch_size"],
                    default_num_epoch=1,
                    samples_per_update=trainer_params["batch_size"],
                    **trainer_params["pretraining"],
                )
                # SAC-specific setting - we don't want to do a whole epoch each update!
                if "samples_per_update" in trainer_params["pretraining"]:
                    logger.warning(
                        "Pretraining: Samples Per Update is not a valid setting for SAC."
                    )
                    self.bc_module.samples_per_update = 1
            else:
                self.bc_module = None

        if load:
            self._load_graph()
        else:
            self._initialize_graph()
            self.sess.run(self.model.target_init_op)

        # Disable terminal states for certain reward signals to avoid survivor bias
        for name, reward_signal in self.reward_signals.items():
            if not reward_signal.use_terminal_states:
                self.sess.run(self.model.disable_use_dones[name])

    def create_model(
        self,
        brain: BrainParameters,
        trainer_params: Dict[str, Any],
        reward_signal_configs: Dict[str, Any],
        is_training: bool,
        load: bool,
        seed: int,
    ) -> None:
        with self.graph.as_default():
            self.model = SACModel(
                brain,
                lr=float(trainer_params["learning_rate"]),
                h_size=int(trainer_params["hidden_units"]),
                init_entcoef=float(trainer_params["init_entcoef"]),
                max_step=float(trainer_params["max_steps"]),
                normalize=trainer_params["normalize"],
                use_recurrent=trainer_params["use_recurrent"],
                num_layers=int(trainer_params["num_layers"]),
                m_size=self.m_size,
                seed=seed,
                stream_names=list(reward_signal_configs.keys()),
                tau=float(trainer_params["tau"]),
                gammas=list(_val["gamma"] for _val in reward_signal_configs.values()),
                vis_encode_type=trainer_params["vis_encode_type"],
            )
            self.model.create_sac_optimizers()

        self.inference_dict.update(
            {
                "action": self.model.output,
                "log_probs": self.model.all_log_probs,
                "value_heads": self.model.value_heads,
                "value": self.model.value,
                "entropy": self.model.entropy,
                "learning_rate": self.model.learning_rate,
            }
        )
        if self.use_continuous_act:
            self.inference_dict["pre_action"] = self.model.output_pre
        if self.use_recurrent:
            self.inference_dict["memory_out"] = self.model.memory_out
        if (
            is_training
            and self.use_vec_obs
            and trainer_params["normalize"]
            and not load
        ):
            self.inference_dict["update_mean"] = self.model.update_normalization

        self.update_dict.update(
            {
                "value_loss": self.model.total_value_loss,
                "policy_loss": self.model.policy_loss,
                "q1_loss": self.model.q1_loss,
                "q2_loss": self.model.q2_loss,
                "entropy_coef": self.model.ent_coef,
                "entropy": self.model.entropy,
                "update_batch": self.model.update_batch_policy,
                "update_value": self.model.update_batch_value,
                "update_entropy": self.model.update_batch_entropy,
            }
        )

    def create_reward_signals(self, reward_signal_configs: Dict[str, Any]) -> None:
        """
        Create reward signals.
        :param reward_signal_configs: Reward signal config.
        """
        self.reward_signals: Dict[str, RewardSignal] = {}
        with self.graph.as_default():
            # Create reward signals
            for reward_signal, config in reward_signal_configs.items():
                if type(config) is dict:
                    self.reward_signals[reward_signal] = create_reward_signal(
                        self, self.model, reward_signal, config
                    )

    def evaluate(self, brain_info: BrainInfo) -> Dict[str, np.ndarray]:
        """
        Evaluates policy for the agent experiences provided.
        :param brain_info: BrainInfo object containing inputs.
        :return: Outputs from network as defined by self.inference_dict.
        """
        feed_dict = {
            self.model.batch_size: len(brain_info.vector_observations),
            self.model.sequence_length: 1,
        }
        if self.use_recurrent:
            if not self.use_continuous_act:
                feed_dict[
                    self.model.prev_action
                ] = brain_info.previous_vector_actions.reshape(
                    [-1, len(self.model.act_size)]
                )
            if brain_info.memories.shape[1] == 0:
                brain_info.memories = self.make_empty_memory(len(brain_info.agents))
            feed_dict[self.model.memory_in] = brain_info.memories

        feed_dict = self.fill_eval_dict(feed_dict, brain_info)
        run_out = self._execute_model(feed_dict, self.inference_dict)
        return run_out

    @timed
    def update(
        self, mini_batch: Dict[str, Any], num_sequences: int, update_target: bool = True
    ) -> Dict[str, float]:
        """
        Updates model using buffer.
        :param mini_batch: Experience batch.
        :param num_sequences: Number of trajectories in batch.
        :param update_target: Whether or not to update the target value network.
        :return: Output from update process.
        """
        feed_dict = self.construct_feed_dict(self.model, mini_batch, num_sequences)
        stats_needed = self.stats_name_to_update_name
        update_stats: Dict[str, float] = {}
        update_vals = self._execute_model(feed_dict, self.update_dict)
        for stat_name, update_name in stats_needed.items():
            update_stats[stat_name] = update_vals[update_name]
        if update_target:
            self.sess.run(self.model.target_update_op)
        return update_stats

    def update_reward_signals(
        self, reward_signal_minibatches: Dict[str, Dict], num_sequences: int
    ) -> Dict[str, float]:
        """
        Only update the reward signals.
        :param reward_signal_minibatches: Minibatches to use for updating the reward signals,
            indexed by name. If none, don't update the reward signals.
        """
        # Collect feed dicts for all reward signals.
        feed_dict: Dict[tf.Tensor, Any] = {}
        update_dict: Dict[str, tf.Tensor] = {}
        update_stats: Dict[str, float] = {}
        stats_needed: Dict[str, str] = {}
        if reward_signal_minibatches:
            self.add_reward_signal_dicts(
                feed_dict,
                update_dict,
                stats_needed,
                reward_signal_minibatches,
                num_sequences,
            )
        update_vals = self._execute_model(feed_dict, update_dict)
        for stat_name, update_name in stats_needed.items():
            update_stats[stat_name] = update_vals[update_name]
        return update_stats

    def add_reward_signal_dicts(
        self,
        feed_dict: Dict[tf.Tensor, Any],
        update_dict: Dict[str, tf.Tensor],
        stats_needed: Dict[str, str],
        reward_signal_minibatches: Dict[str, Dict],
        num_sequences: int,
    ) -> None:
        """
        Adds the items needed for reward signal updates to the feed_dict and stats_needed dict.
        :param feed_dict: Feed dict that needs updating.
        :param update_dict: Update dict that needs updating.
        :param stats_needed: Stats needed to get from the update.
        :param reward_signal_minibatches: Minibatches to use for updating the reward signals,
            indexed by name.
        """
        for name, r_mini_batch in reward_signal_minibatches.items():
            feed_dict.update(
                self.reward_signals[name].prepare_update(
                    self.model, r_mini_batch, num_sequences
                )
            )
            update_dict.update(self.reward_signals[name].update_dict)
            stats_needed.update(self.reward_signals[name].stats_name_to_update_name)

    def construct_feed_dict(
        self, model: SACModel, mini_batch: Dict[str, Any], num_sequences: int
    ) -> Dict[tf.Tensor, Any]:
        """
        Builds the feed dict for updating the SAC model.
        :param model: The model to update. May be different when, e.g., using multi-GPU.
        :param mini_batch: Mini-batch to use to update.
        :param num_sequences: Number of LSTM sequences in mini_batch.
        """
        feed_dict = {
            self.model.batch_size: num_sequences,
            self.model.sequence_length: self.sequence_length,
            self.model.next_sequence_length: self.sequence_length,
            self.model.mask_input: mini_batch["masks"],
        }
        for name in self.reward_signals:
            feed_dict[model.rewards_holders[name]] = mini_batch[
                "{}_rewards".format(name)
            ]

        if self.use_continuous_act:
            feed_dict[model.action_holder] = mini_batch["actions"]
        else:
            feed_dict[model.action_holder] = mini_batch["actions"]
            if self.use_recurrent:
                feed_dict[model.prev_action] = mini_batch["prev_action"]
            feed_dict[model.action_masks] = mini_batch["action_mask"]
        if self.use_vec_obs:
            feed_dict[model.vector_in] = mini_batch["vector_obs"]
            feed_dict[model.next_vector_in] = mini_batch["next_vector_in"]
        if self.model.vis_obs_size > 0:
            for i, _ in enumerate(model.visual_in):
                _obs = mini_batch["visual_obs%d" % i]
                feed_dict[model.visual_in[i]] = _obs
            for i, _ in enumerate(model.next_visual_in):
                _obs = mini_batch["next_visual_obs%d" % i]
                feed_dict[model.next_visual_in[i]] = _obs
        if self.use_recurrent:
            mem_in = [
                mini_batch["memory"][i]
                for i in range(0, len(mini_batch["memory"]), self.sequence_length)
            ]
            # LSTM shouldn't have sequence length < 1, but stop it from going out of the index if true.
            offset = 1 if self.sequence_length > 1 else 0
            next_mem_in = [
                mini_batch["memory"][i][
                    : self.m_size // 4
                ]  # only pass value part of memory to target network
                for i in range(offset, len(mini_batch["memory"]), self.sequence_length)
            ]
            feed_dict[model.memory_in] = mem_in
            feed_dict[model.next_memory_in] = next_mem_in
        feed_dict[model.dones_holder] = mini_batch["done"]
        return feed_dict
ml-agents/mlagents/trainers/sac/trainer.py:

# # Unity ML-Agents Toolkit
# ## ML-Agent Learning (SAC)
# Contains an implementation of SAC as described in https://arxiv.org/abs/1801.01290
# and implemented in https://github.com/hill-a/stable-baselines

import logging
from collections import deque, defaultdict
from typing import List, Any, Dict
import os

import numpy as np
import tensorflow as tf

from mlagents.envs import AllBrainInfo, BrainInfo
from mlagents.envs.action_info import ActionInfoOutputs
from mlagents.envs.timers import timed, hierarchical_timer
from mlagents.trainers.buffer import Buffer
from mlagents.trainers.sac.policy import SACPolicy
from mlagents.trainers.trainer import UnityTrainerException
from mlagents.trainers.rl_trainer import RLTrainer, AllRewardsOutput
from mlagents.trainers.components.reward_signals import RewardSignalResult


LOGGER = logging.getLogger("mlagents.trainers")
BUFFER_TRUNCATE_PERCENT = 0.8


class SACTrainer(RLTrainer):
    """
    The SACTrainer is an implementation of the SAC algorithm, with support
    for discrete actions and recurrent networks.
    """

    def __init__(
        self, brain, reward_buff_cap, trainer_parameters, training, load, seed, run_id
    ):
        """
        Responsible for collecting experiences and training the SAC model.
        :param trainer_parameters: The parameters for the trainer (dictionary).
        :param training: Whether the trainer is set for training.
        :param load: Whether the model should be loaded.
        :param seed: The seed the model will be initialized with.
        :param run_id: The identifier of the current run.
        """
        super().__init__(brain, trainer_parameters, training, run_id, reward_buff_cap)
        self.param_keys = [
            "batch_size",
            "buffer_size",
            "buffer_init_steps",
            "hidden_units",
            "learning_rate",
            "init_entcoef",
            "max_steps",
            "normalize",
            "num_update",
            "num_layers",
            "time_horizon",
            "sequence_length",
            "summary_freq",
            "tau",
            "use_recurrent",
            "summary_path",
            "memory_size",
            "model_path",
            "reward_signals",
            "vis_encode_type",
        ]

        self.check_param_keys()

        self.step = 0
        self.train_interval = (
            trainer_parameters["train_interval"]
            if "train_interval" in trainer_parameters
            else 1
        )
        self.reward_signal_updates_per_train = (
            trainer_parameters["reward_signals"]["reward_signal_num_update"]
            if "reward_signal_num_update" in trainer_parameters["reward_signals"]
            else trainer_parameters["num_update"]
        )

        self.checkpoint_replay_buffer = (
            trainer_parameters["save_replay_buffer"]
            if "save_replay_buffer" in trainer_parameters
            else False
        )
        self.policy = SACPolicy(seed, brain, trainer_parameters, self.is_training, load)

        # Load the replay buffer if load
        if load and self.checkpoint_replay_buffer:
            try:
                self.load_replay_buffer()
            except (AttributeError, FileNotFoundError):
                LOGGER.warning(
                    "Replay buffer was unable to load, starting from scratch."
                )
            LOGGER.debug(
                "Loaded update buffer with {} sequences".format(
                    len(self.training_buffer.update_buffer["actions"])
                )
            )

        for _reward_signal in self.policy.reward_signals.keys():
            self.collected_rewards[_reward_signal] = {}

        self.episode_steps = {}

    def save_model(self) -> None:
        """
        Saves the model. Overrides the default save_model since we want to save
        the replay buffer as well.
        """
        self.policy.save_model(self.get_step)
        if self.checkpoint_replay_buffer:
            self.save_replay_buffer()

    def save_replay_buffer(self) -> None:
        """
        Save the training buffer's update buffer to a pickle file.
        """
        filename = os.path.join(self.policy.model_path, "last_replay_buffer.hdf5")
        LOGGER.info("Saving Experience Replay Buffer to {}".format(filename))
        with open(filename, "wb") as file_object:
            self.training_buffer.update_buffer.save_to_file(file_object)

    def load_replay_buffer(self) -> Buffer:
        """
        Loads the last saved replay buffer from a file.
        """
        filename = os.path.join(self.policy.model_path, "last_replay_buffer.hdf5")
        LOGGER.info("Loading Experience Replay Buffer from {}".format(filename))
        with open(filename, "rb+") as file_object:
            self.training_buffer.update_buffer.load_from_file(file_object)
        LOGGER.info(
            "Experience replay buffer has {} experiences.".format(
                len(self.training_buffer.update_buffer["actions"])
            )
        )

    def add_policy_outputs(
        self, take_action_outputs: ActionInfoOutputs, agent_id: str, agent_idx: int
    ) -> None:
        """
        Takes the output of the last action and stores it into the training buffer.
        """
        actions = take_action_outputs["action"]
        self.training_buffer[agent_id]["actions"].append(actions[agent_idx])

    def add_rewards_outputs(
        self,
        rewards_out: AllRewardsOutput,
        values: Dict[str, np.ndarray],
        agent_id: str,
        agent_idx: int,
        agent_next_idx: int,
    ) -> None:
        """
        Takes the value output of the last action and stores it into the training buffer.
        """
        self.training_buffer[agent_id]["environment_rewards"].append(
            rewards_out.environment[agent_next_idx]
        )

    def process_experiences(
        self, current_info: AllBrainInfo, new_info: AllBrainInfo
    ) -> None:
        """
        Checks agent histories for processing condition, and processes them as necessary.
        :param current_info: Dictionary of all current brains and corresponding BrainInfo.
        :param new_info: Dictionary of all next brains and corresponding BrainInfo.
        """
        info = new_info[self.brain_name]
        for l in range(len(info.agents)):
            agent_actions = self.training_buffer[info.agents[l]]["actions"]
            if (
                info.local_done[l]
                or len(agent_actions) >= self.trainer_parameters["time_horizon"]
            ) and len(agent_actions) > 0:
                agent_id = info.agents[l]

                # Bootstrap using last brain info. Set last element to duplicate obs and remove dones.
                if info.max_reached[l]:
                    bootstrapping_info = self.training_buffer[agent_id].last_brain_info
                    idx = bootstrapping_info.agents.index(agent_id)
                    for i, obs in enumerate(bootstrapping_info.visual_observations):
                        self.training_buffer[agent_id]["next_visual_obs%d" % i][
                            -1
                        ] = obs[idx]
                    if self.policy.use_vec_obs:
                        self.training_buffer[agent_id]["next_vector_in"][
                            -1
                        ] = bootstrapping_info.vector_observations[idx]
                    self.training_buffer[agent_id]["done"][-1] = False

                self.training_buffer.append_update_buffer(
                    agent_id,
                    batch_size=None,
                    training_length=self.policy.sequence_length,
                )

                self.training_buffer[agent_id].reset_agent()
                if info.local_done[l]:
                    self.stats["Environment/Episode Length"].append(
                        self.episode_steps.get(agent_id, 0)
                    )
                    self.episode_steps[agent_id] = 0
                    for name, rewards in self.collected_rewards.items():
                        if name == "environment":
                            self.cumulative_returns_since_policy_update.append(
                                rewards.get(agent_id, 0)
                            )
                            self.stats["Environment/Cumulative Reward"].append(
                                rewards.get(agent_id, 0)
                            )
                            self.reward_buffer.appendleft(rewards.get(agent_id, 0))
                            rewards[agent_id] = 0
                        else:
                            self.stats[
                                self.policy.reward_signals[name].stat_name
                            ].append(rewards.get(agent_id, 0))
                            rewards[agent_id] = 0

    def is_ready_update(self) -> bool:
        """
        Returns whether or not the trainer has enough elements to run update model.
        :return: A boolean corresponding to whether or not update_model() can be run.
        """
        return (
            len(self.training_buffer.update_buffer["actions"])
            >= self.trainer_parameters["batch_size"]
            and self.step >= self.trainer_parameters["buffer_init_steps"]
        )

    @timed
    def update_policy(self) -> None:
        """
        If train_interval is met, update the SAC policy given the current reward signals.
        If reward_signal_train_interval is met, update the reward signals from the buffer.
        """
        if self.step % self.train_interval == 0:
            self.trainer_metrics.start_policy_update_timer(
                number_experiences=len(self.training_buffer.update_buffer["actions"]),
                mean_return=float(np.mean(self.cumulative_returns_since_policy_update)),
            )
            self.update_sac_policy()
            self.update_reward_signals()
            self.trainer_metrics.end_policy_update()

    def update_sac_policy(self) -> None:
        """
        Uses the experience replay buffer to update the policy.
        The reward signal generators are updated using different mini batches.
        If we want to imitate http://arxiv.org/abs/1809.02925 and similar papers, where the policy is updated
        N times, then the reward signals are updated N times, then reward_signal_updates_per_train
        is greater than 1 and the reward signals are not updated in parallel.
        """

        self.cumulative_returns_since_policy_update: List[float] = []
        n_sequences = max(
            int(self.trainer_parameters["batch_size"] / self.policy.sequence_length), 1
        )

        num_updates = self.trainer_parameters["num_update"]
        batch_update_stats: Dict[str, list] = defaultdict(list)
        for _ in range(num_updates):
            LOGGER.debug("Updating SAC policy at step {}".format(self.step))
            buffer = self.training_buffer.update_buffer
            if (
                len(self.training_buffer.update_buffer["actions"])
                >= self.trainer_parameters["batch_size"]
            ):
                sampled_minibatch = buffer.sample_mini_batch(
                    self.trainer_parameters["batch_size"],
                    sequence_length=self.policy.sequence_length,
                )
                # Get rewards for each reward signal
                for name, signal in self.policy.reward_signals.items():
                    sampled_minibatch[
                        "{}_rewards".format(name)
                    ] = signal.evaluate_batch(sampled_minibatch).scaled_reward

                update_stats = self.policy.update(
                    sampled_minibatch, n_sequences, update_target=True
                )
                for stat_name, value in update_stats.items():
                    batch_update_stats[stat_name].append(value)

        # Truncate update buffer if necessary. Truncate more than we need to, to avoid truncating
        # a large buffer at each update.
        if (
            len(self.training_buffer.update_buffer["actions"])
            > self.trainer_parameters["buffer_size"]
        ):
            self.training_buffer.truncate_update_buffer(
                int(self.trainer_parameters["buffer_size"] * BUFFER_TRUNCATE_PERCENT)
            )

        for stat, stat_list in batch_update_stats.items():
            self.stats[stat].append(np.mean(stat_list))

        if self.policy.bc_module:
            update_stats = self.policy.bc_module.update()
            for stat, val in update_stats.items():
                self.stats[stat].append(val)

    def update_reward_signals(self) -> None:
        """
        Iterate through the reward signals and update them. Unlike in PPO,
        do it separately from the policy so that it can be done at a different
        interval.
        This function should only be used to simulate
        http://arxiv.org/abs/1809.02925 and similar papers, where the policy is updated
        N times, then the reward signals are updated N times. Normally, the reward signal
        and policy are updated in parallel.
        """
        buffer = self.training_buffer.update_buffer
        num_updates = self.reward_signal_updates_per_train
        n_sequences = max(
            int(self.trainer_parameters["batch_size"] / self.policy.sequence_length), 1
        )
        batch_update_stats: Dict[str, list] = defaultdict(list)
        for _ in range(num_updates):
            # Get minibatches for reward signal update if needed
            reward_signal_minibatches = {}
            for name, signal in self.policy.reward_signals.items():
                LOGGER.debug("Updating {} at step {}".format(name, self.step))
                # Some signals don't need a minibatch to be sampled - so we don't!
                if signal.update_dict:
                    reward_signal_minibatches[name] = buffer.sample_mini_batch(
                        self.trainer_parameters["batch_size"],
                        sequence_length=self.policy.sequence_length,
                    )
            update_stats = self.policy.update_reward_signals(
                reward_signal_minibatches, n_sequences
            )
            for stat_name, value in update_stats.items():
                batch_update_stats[stat_name].append(value)
        for stat, stat_list in batch_update_stats.items():
            self.stats[stat].append(np.mean(stat_list))