浏览代码

Improvements for GAIL (#2296)

* Don't 0 value bootstrap for GAIL and Curiosity
* Add gradient penalties to GAN to help with stability
* Add gail_config.yaml with GAIL examples
* Cleaned up trainer_config.yaml and unnecessary gammas
* Documentation updates
* Code cleanup
/develop-generalizationTraining-TrainerController
GitHub 5 年前
当前提交
6a212f73
共有 10 个文件被更改,包括 235 次插入59 次删除
  1. 35
      config/trainer_config.yaml
  2. 62
      docs/Training-Imitation-Learning.md
  3. 6
      ml-agents/mlagents/trainers/components/bc/model.py
  4. 1
      ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py
  5. 61
      ml-agents/mlagents/trainers/components/reward_signals/gail/model.py
  6. 1
      ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py
  7. 3
      ml-agents/mlagents/trainers/components/reward_signals/reward_signal.py
  8. 12
      ml-agents/mlagents/trainers/ppo/policy.py
  9. 7
      ml-agents/mlagents/trainers/tests/test_ppo.py
  10. 106
      config/gail_config.yaml

35
config/trainer_config.yaml


time_horizon: 64
num_layers: 2
SmallWallJumpLearning:
SmallWallJumpLearning:
max_steps: 1.0e6
batch_size: 128
buffer_size: 2048

num_layers: 2
normalize: false
BigWallJumpLearning:
BigWallJumpLearning:
max_steps: 1.0e6
batch_size: 128
buffer_size: 2048

max_steps: 5.0e5
num_epoch: 3
reward_signals:
extrinsic:
extrinsic:
strength: 1.0
gamma: 0.99
curiosity:

VisualPyramidsLearning:
time_horizon: 128
batch_size: 64

max_steps: 5.0e5
num_epoch: 3
reward_signals:
extrinsic:
extrinsic:
strength: 1.0
gamma: 0.99
curiosity:

max_steps: 5.0e5
beta: 0.001
reward_signals:
extrinsic:
extrinsic:
strength: 1.0
gamma: 0.995

num_layers: 3
hidden_units: 512
reward_signals:
extrinsic:
extrinsic:
strength: 1.0
gamma: 0.995

num_layers: 3
hidden_units: 512
reward_signals:
extrinsic:
extrinsic:
strength: 1.0
gamma: 0.995

num_layers: 3
hidden_units: 512
reward_signals:
extrinsic:
extrinsic:
strength: 1.0
gamma: 0.995

time_horizon: 1000
batch_size: 2024
buffer_size: 20240
gamma: 0.995
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.995
HallwayLearning:
use_recurrent: true

hidden_units: 128
memory_size: 256
beta: 1.0e-2
gamma: 0.99
num_epoch: 3
buffer_size: 1024
batch_size: 64

hidden_units: 128
memory_size: 256
beta: 1.0e-2
gamma: 0.99
num_epoch: 3
buffer_size: 1024
batch_size: 64

num_layers: 1
hidden_units: 256
beta: 5.0e-3
gamma: 0.9
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.9
BasicLearning:
batch_size: 32

beta: 5.0e-3
gamma: 0.9
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.9

62
docs/Training-Imitation-Learning.md


Imitation Learning uses pairs of observations and actions from
from a demonstration to learn a policy. [Video Link](https://youtu.be/kpb8ZkMBFYs).
Imitation learning can also be used to help reinforcement learning. Especially in
Imitation learning can also be used to help reinforcement learning. Especially in
it is easier to just show the agent how to achieve the reward. In these cases,
it is easier to just show the agent how to achieve the reward. In these cases,
For instance, on the [Pyramids environment](Learning-Environment-Examples.md#pyramids),
For instance, on the [Pyramids environment](Learning-Environment-Examples.md#pyramids),
just 6 episodes of demonstrations can reduce training steps by more than 4 times.
<p align="center">

</p>
ML-Agents provides several ways to learn from demonstrations. For most situations,
[GAIL](Training-RewardSignals.md#the-gail-reward-signal) is the preferred approach.
ML-Agents provides several ways to learn from demonstrations.
number of demonstrations.
* To help bootstrap reinforcement learning, you can enable
[pretraining](Training-PPO.md#optional-pretraining-using-demonstrations)
on the PPO trainer, in addition to using a small GAIL reward signal.
* To train an agent to exactly mimic demonstrations, you can use the
number of demonstrations.
* To help bootstrap reinforcement learning, you can enable
[pretraining](Training-PPO.md#optional-pretraining-using-demonstrations)
on the PPO trainer, in addition to using a small GAIL reward signal.
* To train an agent to exactly mimic demonstrations, you can use the
### How to Choose
If you want to help your agents learn (especially with environments that have sparse rewards)
using pre-recorded demonstrations, you can generally enable both GAIL and Pretraining.
An example of this is provided for the Pyramids example environment under
`PyramidsLearning` in `config/gail_config.yaml`.
If you want to train purely from demonstrations, GAIL is generally the preferred approach, especially
if you have few (<10) episodes of demonstrations. An example of this is provided for the Crawler example
environment under `CrawlerStaticLearning` in `config/gail_config.yaml`.
If you have plenty of demonstrations and/or a very simple environment, Behavioral Cloning
(online and offline) can be effective and quick. However, it cannot be combined with RL.
It is possible to record demonstrations of agent behavior from the Unity Editor,
and save them as assets. These demonstrations contain information on the
observations, actions, and rewards for a given agent during the recording session.
They can be managed from the Editor, as well as used for training with Offline
It is possible to record demonstrations of agent behavior from the Unity Editor,
and save them as assets. These demonstrations contain information on the
observations, actions, and rewards for a given agent during the recording session.
They can be managed from the Editor, as well as used for training with Offline
In order to record demonstrations from an agent, add the `Demonstration Recorder`
component to a GameObject in the scene which contains an `Agent` component.
Once added, it is possible to name the demonstration that will be recorded
In order to record demonstrations from an agent, add the `Demonstration Recorder`
component to a GameObject in the scene which contains an `Agent` component.
Once added, it is possible to name the demonstration that will be recorded
from the agent.
<p align="center">

</p>
When `Record` is checked, a demonstration will be created whenever the scene
is played from the Editor. Depending on the complexity of the task, anywhere
from a few minutes or a few hours of demonstration data may be necessary to
be useful for imitation learning. When you have recorded enough data, end
the Editor play session, and a `.demo` file will be created in the
`Assets/Demonstrations` folder. This file contains the demonstrations.
Clicking on the file will provide metadata about the demonstration in the
When `Record` is checked, a demonstration will be created whenever the scene
is played from the Editor. Depending on the complexity of the task, anywhere
from a few minutes or a few hours of demonstration data may be necessary to
be useful for imitation learning. When you have recorded enough data, end
the Editor play session, and a `.demo` file will be created in the
`Assets/Demonstrations` folder. This file contains the demonstrations.
Clicking on the file will provide metadata about the demonstration in the
inspector.
<p align="center">

</p>

6
ml-agents/mlagents/trainers/components/bc/model.py


)
self.expert_action = tf.concat(
[
tf.one_hot(
self.action_in_expert[:, i], self.policy_model.act_size[i]
)
for i in range(len(self.policy_model.act_size))
tf.one_hot(self.action_in_expert[:, i], act_size)
for i, act_size in enumerate(self.policy_model.act_size)
],
axis=1,
)

1
ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py


policy.model, encoding_size=encoding_size, learning_rate=learning_rate
)
self.num_epoch = num_epoch
self.use_terminal_states = False
self.update_dict = {
"forward_loss": self.model.forward_loss,
"inverse_loss": self.model.inverse_loss,

61
ml-agents/mlagents/trainers/components/reward_signals/gail/model.py


import tensorflow as tf
from mlagents.trainers.models import LearningModel
EPSILON = 1e-7
class GAILModel(object):
def __init__(

encoding_size: int = 64,
use_actions: bool = False,
use_vail: bool = False,
gradient_penalty_weight: float = 10.0,
):
"""
The initializer for the GAIL reward generator.

self.mutual_information = 0.5
self.policy_model = policy_model
self.encoding_size = encoding_size
self.gradient_penalty_weight = gradient_penalty_weight
self.use_vail = use_vail
self.use_actions = use_actions # True # Not using actions
self.make_beta()

)
self.kl_div_input = tf.placeholder(shape=[], dtype=tf.float32)
new_beta = tf.maximum(
self.beta + self.alpha * (self.kl_div_input - self.mutual_information), 1e-7
self.beta + self.alpha * (self.kl_div_input - self.mutual_information),
EPSILON,
)
self.update_beta = tf.assign(self.beta, new_beta)

)
self.expert_action = tf.concat(
[
tf.one_hot(
self.action_in_expert[:, i], self.policy_model.act_size[i]
)
for i in range(len(self.policy_model.act_size))
tf.one_hot(self.action_in_expert[:, i], act_size)
for i, act_size in enumerate(self.policy_model.act_size)
],
axis=1,
)

def create_encoder(
self, state_in: tf.Tensor, action_in: tf.Tensor, done_in: tf.Tensor, reuse: bool
) -> Tuple[tf.Tensor, tf.Tensor]:
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
"""
Creates the encoder for the discriminator
:param state_in: The encoded observation input

name="d_estimate",
reuse=reuse,
)
return estimate, z_mean
return estimate, z_mean, concat_input
def create_network(self) -> None:
"""

initializer=tf.ones_initializer(),
)
self.z_sigma_sq = self.z_sigma * self.z_sigma
self.z_log_sigma_sq = tf.log(self.z_sigma_sq + 1e-7)
self.z_log_sigma_sq = tf.log(self.z_sigma_sq + EPSILON)
self.expert_estimate, self.z_mean_expert = self.create_encoder(
self.expert_estimate, self.z_mean_expert, _ = self.create_encoder(
self.policy_estimate, self.z_mean_policy = self.create_encoder(
self.policy_estimate, self.z_mean_policy, _ = self.create_encoder(
self.encoded_policy,
self.policy_model.selected_actions,
self.done_policy,

self.policy_estimate, [-1], name="GAIL_reward"
)
self.intrinsic_reward = -tf.log(1.0 - self.discriminator_score + 1e-7)
self.intrinsic_reward = -tf.log(1.0 - self.discriminator_score + EPSILON)
def create_gradient_magnitude(self) -> tf.Tensor:
"""
Gradient penalty from https://arxiv.org/pdf/1704.00028. Adds stability esp.
for off-policy. Compute gradients w.r.t randomly interpolated input.
"""
expert = [self.encoded_expert, self.expert_action, self.done_expert]
policy = [
self.encoded_policy,
self.policy_model.selected_actions,
self.done_policy,
]
interp = []
for _expert_in, _policy_in in zip(expert, policy):
alpha = tf.random_uniform(tf.shape(_expert_in))
interp.append(alpha * _expert_in + (1 - alpha) * _policy_in)
grad_estimate, _, grad_input = self.create_encoder(
interp[0], interp[1], interp[2], reuse=True
)
grad = tf.gradients(grad_estimate, [grad_input])[0]
# Norm's gradient could be NaN at 0. Use our own safe_norm
safe_norm = tf.sqrt(tf.reduce_sum(grad ** 2, axis=-1) + EPSILON)
gradient_mag = tf.reduce_mean(tf.pow(safe_norm - 1, 2))
return gradient_mag
def create_loss(self, learning_rate: float) -> None:
"""

self.mean_policy_estimate = tf.reduce_mean(self.policy_estimate)
self.discriminator_loss = -tf.reduce_mean(
tf.log(self.expert_estimate + 1e-7)
+ tf.log(1.0 - self.policy_estimate + 1e-7)
tf.log(self.expert_estimate + EPSILON)
+ tf.log(1.0 - self.policy_estimate + EPSILON)
)
if self.use_vail:

)
else:
self.loss = self.discriminator_loss
if self.gradient_penalty_weight > 0.0:
self.loss += self.gradient_penalty_weight * self.create_gradient_magnitude()
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
self.update_batch = optimizer.minimize(self.loss)

1
ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py


super().__init__(policy, strength, gamma)
self.num_epoch = num_epoch
self.samples_per_update = samples_per_update
self.use_terminal_states = False
self.model = GAILModel(
policy.model, 128, learning_rate, encoding_size, use_actions, use_vail

3
ml-agents/mlagents/trainers/components/reward_signals/reward_signal.py


short_name = class_name.replace("RewardSignal", "")
self.stat_name = f"Policy/{short_name} Reward"
self.value_name = f"Policy/{short_name} Value Estimate"
# Terminate discounted reward computation at Done. Can disable to mitigate positive bias in rewards with
# no natural end, e.g. GAIL or Curiosity
self.use_terminal_states = True
self.gamma = gamma
self.policy = policy
self.strength = strength

12
ml-agents/mlagents/trainers/ppo/policy.py


:return: The value estimate dictionary with key being the name of the reward signal and the value the
corresponding value estimate.
"""
if done:
return {k: 0.0 for k in self.model.value_heads.keys()}
feed_dict: Dict[tf.Tensor, Any] = {
self.model.batch_size: 1,

].reshape([-1, len(self.model.act_size)])
value_estimates = self.sess.run(self.model.value_heads, feed_dict)
return {k: float(v) for k, v in value_estimates.items()}
value_estimates = {k: float(v) for k, v in value_estimates.items()}
# If we're done, reassign all of the value estimates that need terminal states.
if done:
for k in value_estimates:
if self.reward_signals[k].use_terminal_states:
value_estimates[k] = 0.0
return value_estimates
def get_action(self, brain_info: BrainInfo) -> ActionInfo:
"""

7
ml-agents/mlagents/trainers/tests/test_ppo.py


assert type(key) is str
assert val == 0.0
# Check if we ignore terminal states properly
policy.reward_signals["extrinsic"].use_terminal_states = False
run_out = policy.get_value_estimates(brain_info, 0, done=True)
for key, val in run_out.items():
assert type(key) is str
assert val != 0.0
env.close()

106
config/gail_config.yaml


default:
trainer: ppo
batch_size: 1024
beta: 5.0e-3
buffer_size: 10240
epsilon: 0.2
hidden_units: 128
lambd: 0.95
learning_rate: 3.0e-4
max_steps: 5.0e4
memory_size: 256
normalize: false
num_epoch: 3
num_layers: 2
time_horizon: 64
sequence_length: 64
summary_freq: 1000
use_recurrent: false
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
PyramidsLearning:
summary_freq: 2000
time_horizon: 128
batch_size: 128
buffer_size: 2048
hidden_units: 512
num_layers: 2
beta: 1.0e-2
max_steps: 5.0e5
num_epoch: 3
pretraining:
demo_path: ./demos/ExpertPyramid.demo
strength: 0.5
steps: 10000
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
curiosity:
strength: 0.02
gamma: 0.99
encoding_size: 256
gail:
strength: 0.01
gamma: 0.99
encoding_size: 128
demo_path: demos/ExpertPyramid.demo
CrawlerStaticLearning:
normalize: true
num_epoch: 3
time_horizon: 1000
batch_size: 2024
buffer_size: 20240
max_steps: 1e6
summary_freq: 3000
num_layers: 3
hidden_units: 512
reward_signals:
gail:
strength: 1.0
gamma: 0.99
encoding_size: 128
demo_path: demos/ExpertCrawlerSta.demo
PushBlockLearning:
max_steps: 5.0e4
batch_size: 128
buffer_size: 2048
beta: 1.0e-2
hidden_units: 256
summary_freq: 2000
time_horizon: 64
num_layers: 2
reward_signals:
gail:
strength: 1.0
gamma: 0.99
encoding_size: 128
demo_path: demos/ExpertPush.demo
HallwayLearning:
use_recurrent: true
sequence_length: 64
num_layers: 2
hidden_units: 128
memory_size: 256
beta: 1.0e-2
num_epoch: 3
buffer_size: 1024
batch_size: 128
max_steps: 5.0e5
summary_freq: 1000
time_horizon: 64
reward_signals:
extrinsic:
strength: 1.0
gamma: 0.99
gail:
strength: 0.1
gamma: 0.99
encoding_size: 128
demo_path: demos/ExpertHallway.demo
正在加载...
取消
保存