Andrew Cohen
5 年前
当前提交
5b0aca29
共有 72 个文件被更改,包括 3343 次插入 和 3825 次删除
-
4com.unity.ml-agents/CHANGELOG.md
-
4com.unity.ml-agents/Editor/DemonstrationImporter.cs
-
24com.unity.ml-agents/Runtime/Academy.cs
-
127com.unity.ml-agents/Runtime/Agent.cs
-
10com.unity.ml-agents/Tests/Editor/DemonstrationTests.cs
-
2config/gail_config.yaml
-
9config/sac_trainer_config.yaml
-
8config/trainer_config.yaml
-
9docs/API-Reference.md
-
1docs/Migrating.md
-
1docs/Training-ML-Agents.md
-
6docs/Training-PPO.md
-
6docs/Training-SAC.md
-
133docs/dox-ml-agents.conf
-
48ml-agents-envs/mlagents_envs/environment.py
-
3ml-agents/mlagents/trainers/agent_processor.py
-
39ml-agents/mlagents/trainers/components/bc/model.py
-
36ml-agents/mlagents/trainers/components/bc/module.py
-
19ml-agents/mlagents/trainers/components/reward_signals/__init__.py
-
74ml-agents/mlagents/trainers/components/reward_signals/curiosity/model.py
-
43ml-agents/mlagents/trainers/components/reward_signals/curiosity/signal.py
-
66ml-agents/mlagents/trainers/components/reward_signals/gail/model.py
-
37ml-agents/mlagents/trainers/components/reward_signals/gail/signal.py
-
10ml-agents/mlagents/trainers/components/reward_signals/reward_signal_factory.py
-
8ml-agents/mlagents/trainers/exception.py
-
14ml-agents/mlagents/trainers/ghost/trainer.py
-
20ml-agents/mlagents/trainers/learn.py
-
270ml-agents/mlagents/trainers/models.py
-
78ml-agents/mlagents/trainers/ppo/trainer.py
-
9ml-agents/mlagents/trainers/rl_trainer.py
-
39ml-agents/mlagents/trainers/sac/trainer.py
-
3ml-agents/mlagents/trainers/tests/mock_brain.py
-
134ml-agents/mlagents/trainers/tests/test_bcmodule.py
-
10ml-agents/mlagents/trainers/tests/test_ghost.py
-
16ml-agents/mlagents/trainers/tests/test_learn.py
-
2ml-agents/mlagents/trainers/tests/test_meta_curriculum.py
-
16ml-agents/mlagents/trainers/tests/test_policy.py
-
403ml-agents/mlagents/trainers/tests/test_ppo.py
-
66ml-agents/mlagents/trainers/tests/test_reward_signals.py
-
257ml-agents/mlagents/trainers/tests/test_sac.py
-
22ml-agents/mlagents/trainers/tests/test_trainer_util.py
-
221ml-agents/mlagents/trainers/tf_policy.py
-
10ml-agents/mlagents/trainers/trainer.py
-
5ml-agents/mlagents/trainers/trainer_util.py
-
2com.unity.ml-agents/Runtime/Demonstrations/Demonstration.cs.meta
-
2com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs.meta
-
6com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs
-
2com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs.meta
-
8com.unity.ml-agents/Runtime/Demonstrations.meta
-
352ml-agents/mlagents/trainers/ppo/optimizer.py
-
447ml-agents/mlagents/trainers/sac/network.py
-
643ml-agents/mlagents/trainers/sac/optimizer.py
-
189ml-agents/mlagents/trainers/tests/test_nn_policy.py
-
179com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs
-
0ml-agents/mlagents/trainers/common/__init__.py
-
393ml-agents/mlagents/trainers/common/nn_policy.py
-
21ml-agents/mlagents/trainers/common/optimizer.py
-
156ml-agents/mlagents/trainers/common/tf_optimizer.py
-
179com.unity.ml-agents/Runtime/DemonstrationRecorder.cs
-
382ml-agents/mlagents/trainers/ppo/models.py
-
219ml-agents/mlagents/trainers/ppo/multi_gpu_policy.py
-
227ml-agents/mlagents/trainers/ppo/policy.py
-
1001ml-agents/mlagents/trainers/sac/models.py
-
315ml-agents/mlagents/trainers/sac/policy.py
-
123ml-agents/mlagents/trainers/tests/test_multigpu.py
-
0/com.unity.ml-agents/Runtime/Demonstrations/Demonstration.cs.meta
-
0/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationRecorder.cs.meta
-
0/com.unity.ml-agents/Runtime/Demonstrations/Demonstration.cs
-
0/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs
-
0/com.unity.ml-agents/Runtime/Demonstrations/DemonstrationWriter.cs.meta
|
|||
fileFormatVersion: 2 |
|||
guid: 85e02c21d231b4f5fa0c5f87e5f907a2 |
|||
folderAsset: yes |
|||
DefaultImporter: |
|||
externalObjects: {} |
|||
userData: |
|||
assetBundleName: |
|||
assetBundleVariant: |
|
|||
import logging |
|||
from typing import Optional, Any, Dict |
|||
|
|||
import numpy as np |
|||
from mlagents.tf_utils import tf |
|||
from mlagents_envs.timers import timed |
|||
from mlagents.trainers.models import ModelUtils, EncoderType, LearningRateSchedule |
|||
from mlagents.trainers.tf_policy import TFPolicy |
|||
from mlagents.trainers.common.tf_optimizer import TFOptimizer |
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
|
|||
|
|||
logger = logging.getLogger("mlagents.trainers") |
|||
|
|||
|
|||
class PPOOptimizer(TFOptimizer): |
|||
def __init__(self, policy: TFPolicy, trainer_params: Dict[str, Any]): |
|||
""" |
|||
Takes a Policy and a Dict of trainer parameters and creates an Optimizer around the policy. |
|||
The PPO optimizer has a value estimator and a loss function. |
|||
:param policy: A TFPolicy object that will be updated by this PPO Optimizer. |
|||
:param trainer_params: Trainer parameters dictionary that specifies the properties of the trainer. |
|||
""" |
|||
# Create the graph here to give more granular control of the TF graph to the Optimizer. |
|||
policy.create_tf_graph() |
|||
|
|||
with policy.graph.as_default(): |
|||
with tf.variable_scope("optimizer/"): |
|||
super().__init__(policy, trainer_params) |
|||
|
|||
lr = float(trainer_params["learning_rate"]) |
|||
lr_schedule = LearningRateSchedule( |
|||
trainer_params.get("learning_rate_schedule", "linear") |
|||
) |
|||
h_size = int(trainer_params["hidden_units"]) |
|||
epsilon = float(trainer_params["epsilon"]) |
|||
beta = float(trainer_params["beta"]) |
|||
max_step = float(trainer_params["max_steps"]) |
|||
num_layers = int(trainer_params["num_layers"]) |
|||
vis_encode_type = EncoderType( |
|||
trainer_params.get("vis_encode_type", "simple") |
|||
) |
|||
self.burn_in_ratio = float(trainer_params.get("burn_in_ratio", 0.0)) |
|||
|
|||
self.stream_names = list(self.reward_signals.keys()) |
|||
|
|||
self.tf_optimizer: Optional[tf.train.AdamOptimizer] = None |
|||
self.grads = None |
|||
self.update_batch: Optional[tf.Operation] = None |
|||
|
|||
self.stats_name_to_update_name = { |
|||
"Losses/Value Loss": "value_loss", |
|||
"Losses/Policy Loss": "policy_loss", |
|||
"Policy/Learning Rate": "learning_rate", |
|||
} |
|||
if self.policy.use_recurrent: |
|||
self.m_size = self.policy.m_size |
|||
self.memory_in = tf.placeholder( |
|||
shape=[None, self.m_size], |
|||
dtype=tf.float32, |
|||
name="recurrent_value_in", |
|||
) |
|||
|
|||
if num_layers < 1: |
|||
num_layers = 1 |
|||
if policy.use_continuous_act: |
|||
self._create_cc_critic(h_size, num_layers, vis_encode_type) |
|||
else: |
|||
self._create_dc_critic(h_size, num_layers, vis_encode_type) |
|||
|
|||
self.learning_rate = ModelUtils.create_learning_rate( |
|||
lr_schedule, lr, self.policy.global_step, int(max_step) |
|||
) |
|||
self._create_losses( |
|||
self.policy.log_probs, |
|||
self.old_log_probs, |
|||
self.value_heads, |
|||
self.policy.entropy, |
|||
beta, |
|||
epsilon, |
|||
lr, |
|||
max_step, |
|||
) |
|||
self._create_ppo_optimizer_ops() |
|||
|
|||
self.update_dict.update( |
|||
{ |
|||
"value_loss": self.value_loss, |
|||
"policy_loss": self.abs_policy_loss, |
|||
"update_batch": self.update_batch, |
|||
"learning_rate": self.learning_rate, |
|||
} |
|||
) |
|||
|
|||
self.policy.initialize_or_load() |
|||
|
|||
def _create_cc_critic( |
|||
self, h_size: int, num_layers: int, vis_encode_type: EncoderType |
|||
) -> None: |
|||
""" |
|||
Creates Continuous control actor-critic model. |
|||
:param h_size: Size of hidden linear layers. |
|||
:param num_layers: Number of hidden linear layers. |
|||
:param vis_encode_type: The type of visual encoder to use. |
|||
""" |
|||
hidden_stream = ModelUtils.create_observation_streams( |
|||
self.policy.visual_in, |
|||
self.policy.processed_vector_in, |
|||
1, |
|||
h_size, |
|||
num_layers, |
|||
vis_encode_type, |
|||
)[0] |
|||
|
|||
if self.policy.use_recurrent: |
|||
hidden_value, memory_value_out = ModelUtils.create_recurrent_encoder( |
|||
hidden_stream, |
|||
self.memory_in, |
|||
self.policy.sequence_length_ph, |
|||
name="lstm_value", |
|||
) |
|||
self.memory_out = memory_value_out |
|||
else: |
|||
hidden_value = hidden_stream |
|||
|
|||
self.value_heads, self.value = ModelUtils.create_value_heads( |
|||
self.stream_names, hidden_value |
|||
) |
|||
self.all_old_log_probs = tf.placeholder( |
|||
shape=[None, 1], dtype=tf.float32, name="old_probabilities" |
|||
) |
|||
|
|||
self.old_log_probs = tf.reduce_sum( |
|||
(tf.identity(self.all_old_log_probs)), axis=1, keepdims=True |
|||
) |
|||
|
|||
def _create_dc_critic( |
|||
self, h_size: int, num_layers: int, vis_encode_type: EncoderType |
|||
) -> None: |
|||
""" |
|||
Creates Discrete control actor-critic model. |
|||
:param h_size: Size of hidden linear layers. |
|||
:param num_layers: Number of hidden linear layers. |
|||
:param vis_encode_type: The type of visual encoder to use. |
|||
""" |
|||
hidden_stream = ModelUtils.create_observation_streams( |
|||
self.policy.visual_in, |
|||
self.policy.processed_vector_in, |
|||
1, |
|||
h_size, |
|||
num_layers, |
|||
vis_encode_type, |
|||
)[0] |
|||
|
|||
if self.policy.use_recurrent: |
|||
hidden_value, memory_value_out = ModelUtils.create_recurrent_encoder( |
|||
hidden_stream, |
|||
self.memory_in, |
|||
self.policy.sequence_length_ph, |
|||
name="lstm_value", |
|||
) |
|||
self.memory_out = memory_value_out |
|||
else: |
|||
hidden_value = hidden_stream |
|||
|
|||
self.value_heads, self.value = ModelUtils.create_value_heads( |
|||
self.stream_names, hidden_value |
|||
) |
|||
|
|||
self.all_old_log_probs = tf.placeholder( |
|||
shape=[None, sum(self.policy.act_size)], |
|||
dtype=tf.float32, |
|||
name="old_probabilities", |
|||
) |
|||
_, _, old_normalized_logits = ModelUtils.create_discrete_action_masking_layer( |
|||
self.all_old_log_probs, self.policy.action_masks, self.policy.act_size |
|||
) |
|||
|
|||
action_idx = [0] + list(np.cumsum(self.policy.act_size)) |
|||
|
|||
self.old_log_probs = tf.reduce_sum( |
|||
( |
|||
tf.stack( |
|||
[ |
|||
-tf.nn.softmax_cross_entropy_with_logits_v2( |
|||
labels=self.policy.action_oh[ |
|||
:, action_idx[i] : action_idx[i + 1] |
|||
], |
|||
logits=old_normalized_logits[ |
|||
:, action_idx[i] : action_idx[i + 1] |
|||
], |
|||
) |
|||
for i in range(len(self.policy.act_size)) |
|||
], |
|||
axis=1, |
|||
) |
|||
), |
|||
axis=1, |
|||
keepdims=True, |
|||
) |
|||
|
|||
def _create_losses( |
|||
self, probs, old_probs, value_heads, entropy, beta, epsilon, lr, max_step |
|||
): |
|||
""" |
|||
Creates training-specific Tensorflow ops for PPO models. |
|||
:param probs: Current policy probabilities |
|||
:param old_probs: Past policy probabilities |
|||
:param value_heads: Value estimate tensors from each value stream |
|||
:param beta: Entropy regularization strength |
|||
:param entropy: Current policy entropy |
|||
:param epsilon: Value for policy-divergence threshold |
|||
:param lr: Learning rate |
|||
:param max_step: Total number of training steps. |
|||
""" |
|||
self.returns_holders = {} |
|||
self.old_values = {} |
|||
for name in value_heads.keys(): |
|||
returns_holder = tf.placeholder( |
|||
shape=[None], dtype=tf.float32, name="{}_returns".format(name) |
|||
) |
|||
old_value = tf.placeholder( |
|||
shape=[None], dtype=tf.float32, name="{}_value_estimate".format(name) |
|||
) |
|||
self.returns_holders[name] = returns_holder |
|||
self.old_values[name] = old_value |
|||
self.advantage = tf.placeholder( |
|||
shape=[None], dtype=tf.float32, name="advantages" |
|||
) |
|||
advantage = tf.expand_dims(self.advantage, -1) |
|||
|
|||
decay_epsilon = tf.train.polynomial_decay( |
|||
epsilon, self.policy.global_step, max_step, 0.1, power=1.0 |
|||
) |
|||
decay_beta = tf.train.polynomial_decay( |
|||
beta, self.policy.global_step, max_step, 1e-5, power=1.0 |
|||
) |
|||
|
|||
value_losses = [] |
|||
for name, head in value_heads.items(): |
|||
clipped_value_estimate = self.old_values[name] + tf.clip_by_value( |
|||
tf.reduce_sum(head, axis=1) - self.old_values[name], |
|||
-decay_epsilon, |
|||
decay_epsilon, |
|||
) |
|||
v_opt_a = tf.squared_difference( |
|||
self.returns_holders[name], tf.reduce_sum(head, axis=1) |
|||
) |
|||
v_opt_b = tf.squared_difference( |
|||
self.returns_holders[name], clipped_value_estimate |
|||
) |
|||
value_loss = tf.reduce_mean( |
|||
tf.dynamic_partition(tf.maximum(v_opt_a, v_opt_b), self.policy.mask, 2)[ |
|||
1 |
|||
] |
|||
) |
|||
value_losses.append(value_loss) |
|||
self.value_loss = tf.reduce_mean(value_losses) |
|||
|
|||
r_theta = tf.exp(probs - old_probs) |
|||
p_opt_a = r_theta * advantage |
|||
p_opt_b = ( |
|||
tf.clip_by_value(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) |
|||
* advantage |
|||
) |
|||
self.policy_loss = -tf.reduce_mean( |
|||
tf.dynamic_partition(tf.minimum(p_opt_a, p_opt_b), self.policy.mask, 2)[1] |
|||
) |
|||
# For cleaner stats reporting |
|||
self.abs_policy_loss = tf.abs(self.policy_loss) |
|||
|
|||
self.loss = ( |
|||
self.policy_loss |
|||
+ 0.5 * self.value_loss |
|||
- decay_beta |
|||
* tf.reduce_mean(tf.dynamic_partition(entropy, self.policy.mask, 2)[1]) |
|||
) |
|||
|
|||
def _create_ppo_optimizer_ops(self): |
|||
self.tf_optimizer = self.create_optimizer_op(self.learning_rate) |
|||
self.grads = self.tf_optimizer.compute_gradients(self.loss) |
|||
self.update_batch = self.tf_optimizer.minimize(self.loss) |
|||
|
|||
@timed |
|||
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: |
|||
""" |
|||
Performs update on model. |
|||
:param mini_batch: Batch of experiences. |
|||
:param num_sequences: Number of sequences to process. |
|||
:return: Results of update. |
|||
""" |
|||
feed_dict = self._construct_feed_dict(batch, num_sequences) |
|||
stats_needed = self.stats_name_to_update_name |
|||
update_stats = {} |
|||
# Collect feed dicts for all reward signals. |
|||
for _, reward_signal in self.reward_signals.items(): |
|||
feed_dict.update( |
|||
reward_signal.prepare_update(self.policy, batch, num_sequences) |
|||
) |
|||
stats_needed.update(reward_signal.stats_name_to_update_name) |
|||
|
|||
update_vals = self._execute_model(feed_dict, self.update_dict) |
|||
for stat_name, update_name in stats_needed.items(): |
|||
update_stats[stat_name] = update_vals[update_name] |
|||
return update_stats |
|||
|
|||
def _construct_feed_dict( |
|||
self, mini_batch: AgentBuffer, num_sequences: int |
|||
) -> Dict[tf.Tensor, Any]: |
|||
# Do an optional burn-in for memories |
|||
num_burn_in = int(self.burn_in_ratio * self.policy.sequence_length) |
|||
burn_in_mask = np.ones((self.policy.sequence_length), dtype=np.float32) |
|||
burn_in_mask[range(0, num_burn_in)] = 0 |
|||
burn_in_mask = np.tile(burn_in_mask, num_sequences) |
|||
feed_dict = { |
|||
self.policy.batch_size_ph: num_sequences, |
|||
self.policy.sequence_length_ph: self.policy.sequence_length, |
|||
self.policy.mask_input: mini_batch["masks"] * burn_in_mask, |
|||
self.advantage: mini_batch["advantages"], |
|||
self.all_old_log_probs: mini_batch["action_probs"], |
|||
} |
|||
for name in self.reward_signals: |
|||
feed_dict[self.returns_holders[name]] = mini_batch[ |
|||
"{}_returns".format(name) |
|||
] |
|||
feed_dict[self.old_values[name]] = mini_batch[ |
|||
"{}_value_estimates".format(name) |
|||
] |
|||
|
|||
if self.policy.output_pre is not None and "actions_pre" in mini_batch: |
|||
feed_dict[self.policy.output_pre] = mini_batch["actions_pre"] |
|||
else: |
|||
feed_dict[self.policy.action_holder] = mini_batch["actions"] |
|||
if self.policy.use_recurrent: |
|||
feed_dict[self.policy.prev_action] = mini_batch["prev_action"] |
|||
feed_dict[self.policy.action_masks] = mini_batch["action_mask"] |
|||
if "vector_obs" in mini_batch: |
|||
feed_dict[self.policy.vector_in] = mini_batch["vector_obs"] |
|||
if self.policy.vis_obs_size > 0: |
|||
for i, _ in enumerate(self.policy.visual_in): |
|||
feed_dict[self.policy.visual_in[i]] = mini_batch["visual_obs%d" % i] |
|||
if self.policy.use_recurrent: |
|||
feed_dict[self.policy.memory_in] = [ |
|||
mini_batch["memory"][i] |
|||
for i in range( |
|||
0, len(mini_batch["memory"]), self.policy.sequence_length |
|||
) |
|||
] |
|||
feed_dict[self.memory_in] = self._make_zero_mem( |
|||
self.m_size, mini_batch.num_experiences |
|||
) |
|||
return feed_dict |
|
|||
import logging |
|||
from typing import Dict, Optional |
|||
|
|||
from mlagents.tf_utils import tf |
|||
|
|||
from mlagents.trainers.models import ModelUtils, EncoderType |
|||
|
|||
LOG_STD_MAX = 2 |
|||
LOG_STD_MIN = -20 |
|||
EPSILON = 1e-6 # Small value to avoid divide by zero |
|||
DISCRETE_TARGET_ENTROPY_SCALE = 0.2 # Roughly equal to e-greedy 0.05 |
|||
CONTINUOUS_TARGET_ENTROPY_SCALE = 1.0 # TODO: Make these an optional hyperparam. |
|||
|
|||
LOGGER = logging.getLogger("mlagents.trainers") |
|||
|
|||
POLICY_SCOPE = "" |
|||
TARGET_SCOPE = "target_network" |
|||
|
|||
|
|||
class SACNetwork: |
|||
""" |
|||
Base class for an SAC network. Implements methods for creating the actor and critic heads. |
|||
""" |
|||
|
|||
def __init__( |
|||
self, |
|||
policy=None, |
|||
m_size=None, |
|||
h_size=128, |
|||
normalize=False, |
|||
use_recurrent=False, |
|||
num_layers=2, |
|||
stream_names=None, |
|||
vis_encode_type=EncoderType.SIMPLE, |
|||
): |
|||
self.normalize = normalize |
|||
self.use_recurrent = use_recurrent |
|||
self.num_layers = num_layers |
|||
self.stream_names = stream_names |
|||
self.h_size = h_size |
|||
self.activ_fn = ModelUtils.swish |
|||
|
|||
self.sequence_length_ph = tf.placeholder( |
|||
shape=None, dtype=tf.int32, name="sac_sequence_length" |
|||
) |
|||
|
|||
self.policy_memory_in: Optional[tf.Tensor] = None |
|||
self.policy_memory_out: Optional[tf.Tensor] = None |
|||
self.value_memory_in: Optional[tf.Tensor] = None |
|||
self.value_memory_out: Optional[tf.Tensor] = None |
|||
self.q1: Optional[tf.Tensor] = None |
|||
self.q2: Optional[tf.Tensor] = None |
|||
self.q1_p: Optional[tf.Tensor] = None |
|||
self.q2_p: Optional[tf.Tensor] = None |
|||
self.q1_memory_in: Optional[tf.Tensor] = None |
|||
self.q2_memory_in: Optional[tf.Tensor] = None |
|||
self.q1_memory_out: Optional[tf.Tensor] = None |
|||
self.q2_memory_out: Optional[tf.Tensor] = None |
|||
self.prev_action: Optional[tf.Tensor] = None |
|||
self.action_masks: Optional[tf.Tensor] = None |
|||
self.external_action_in: Optional[tf.Tensor] = None |
|||
self.log_sigma_sq: Optional[tf.Tensor] = None |
|||
self.entropy: Optional[tf.Tensor] = None |
|||
self.deterministic_output: Optional[tf.Tensor] = None |
|||
self.normalized_logprobs: Optional[tf.Tensor] = None |
|||
self.action_probs: Optional[tf.Tensor] = None |
|||
self.output_oh: Optional[tf.Tensor] = None |
|||
self.output_pre: Optional[tf.Tensor] = None |
|||
|
|||
self.value_vars = None |
|||
self.q_vars = None |
|||
self.critic_vars = None |
|||
self.policy_vars = None |
|||
|
|||
self.q1_heads: Dict[str, tf.Tensor] = None |
|||
self.q2_heads: Dict[str, tf.Tensor] = None |
|||
self.q1_pheads: Dict[str, tf.Tensor] = None |
|||
self.q2_pheads: Dict[str, tf.Tensor] = None |
|||
|
|||
self.policy = policy |
|||
|
|||
def get_vars(self, scope): |
|||
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope) |
|||
|
|||
def join_scopes(self, scope_1, scope_2): |
|||
""" |
|||
Joins two scopes. Does so safetly (i.e., if one of the two scopes doesn't |
|||
exist, don't add any backslashes) |
|||
""" |
|||
if not scope_1: |
|||
return scope_2 |
|||
if not scope_2: |
|||
return scope_1 |
|||
else: |
|||
return "/".join(filter(None, [scope_1, scope_2])) |
|||
|
|||
def create_value_heads(self, stream_names, hidden_input): |
|||
""" |
|||
Creates one value estimator head for each reward signal in stream_names. |
|||
Also creates the node corresponding to the mean of all the value heads in self.value. |
|||
self.value_head is a dictionary of stream name to node containing the value estimator head for that signal. |
|||
:param stream_names: The list of reward signal names |
|||
:param hidden_input: The last layer of the Critic. The heads will consist of one dense hidden layer on top |
|||
of the hidden input. |
|||
""" |
|||
self.value_heads = {} |
|||
for name in stream_names: |
|||
value = tf.layers.dense(hidden_input, 1, name="{}_value".format(name)) |
|||
self.value_heads[name] = value |
|||
self.value = tf.reduce_mean(list(self.value_heads.values()), 0) |
|||
|
|||
def _create_cc_critic(self, hidden_value, scope, create_qs=True): |
|||
""" |
|||
Creates just the critic network |
|||
""" |
|||
scope = self.join_scopes(scope, "critic") |
|||
self.create_sac_value_head( |
|||
self.stream_names, |
|||
hidden_value, |
|||
self.num_layers, |
|||
self.h_size, |
|||
self.join_scopes(scope, "value"), |
|||
) |
|||
|
|||
self.value_vars = self.get_vars(self.join_scopes(scope, "value")) |
|||
if create_qs: |
|||
hidden_q = tf.concat([hidden_value, self.policy.action_holder], axis=-1) |
|||
hidden_qp = tf.concat([hidden_value, self.policy.output], axis=-1) |
|||
self.q1_heads, self.q2_heads, self.q1, self.q2 = self.create_q_heads( |
|||
self.stream_names, |
|||
hidden_q, |
|||
self.num_layers, |
|||
self.h_size, |
|||
self.join_scopes(scope, "q"), |
|||
) |
|||
self.q1_pheads, self.q2_pheads, self.q1_p, self.q2_p = self.create_q_heads( |
|||
self.stream_names, |
|||
hidden_qp, |
|||
self.num_layers, |
|||
self.h_size, |
|||
self.join_scopes(scope, "q"), |
|||
reuse=True, |
|||
) |
|||
self.q_vars = self.get_vars(self.join_scopes(scope, "q")) |
|||
self.critic_vars = self.get_vars(scope) |
|||
|
|||
def _create_dc_critic(self, hidden_value, scope, create_qs=True): |
|||
""" |
|||
Creates just the critic network |
|||
""" |
|||
scope = self.join_scopes(scope, "critic") |
|||
self.create_sac_value_head( |
|||
self.stream_names, |
|||
hidden_value, |
|||
self.num_layers, |
|||
self.h_size, |
|||
self.join_scopes(scope, "value"), |
|||
) |
|||
|
|||
self.value_vars = self.get_vars("/".join([scope, "value"])) |
|||
|
|||
if create_qs: |
|||
self.q1_heads, self.q2_heads, self.q1, self.q2 = self.create_q_heads( |
|||
self.stream_names, |
|||
hidden_value, |
|||
self.num_layers, |
|||
self.h_size, |
|||
self.join_scopes(scope, "q"), |
|||
num_outputs=sum(self.policy.act_size), |
|||
) |
|||
self.q1_pheads, self.q2_pheads, self.q1_p, self.q2_p = self.create_q_heads( |
|||
self.stream_names, |
|||
hidden_value, |
|||
self.num_layers, |
|||
self.h_size, |
|||
self.join_scopes(scope, "q"), |
|||
reuse=True, |
|||
num_outputs=sum(self.policy.act_size), |
|||
) |
|||
self.q_vars = self.get_vars(scope) |
|||
self.critic_vars = self.get_vars(scope) |
|||
|
|||
def create_sac_value_head( |
|||
self, stream_names, hidden_input, num_layers, h_size, scope |
|||
): |
|||
""" |
|||
Creates one value estimator head for each reward signal in stream_names. |
|||
Also creates the node corresponding to the mean of all the value heads in self.value. |
|||
self.value_head is a dictionary of stream name to node containing the value estimator head for that signal. |
|||
:param stream_names: The list of reward signal names |
|||
:param hidden_input: The last layer of the Critic. The heads will consist of one dense hidden layer on top |
|||
of the hidden input. |
|||
:param num_layers: Number of hidden layers for value network |
|||
:param h_size: size of hidden layers for value network |
|||
:param scope: TF scope for value network. |
|||
""" |
|||
with tf.variable_scope(scope): |
|||
value_hidden = ModelUtils.create_vector_observation_encoder( |
|||
hidden_input, h_size, self.activ_fn, num_layers, "encoder", False |
|||
) |
|||
if self.use_recurrent: |
|||
value_hidden, memory_out = ModelUtils.create_recurrent_encoder( |
|||
value_hidden, |
|||
self.value_memory_in, |
|||
self.sequence_length_ph, |
|||
name="lstm_value", |
|||
) |
|||
self.value_memory_out = memory_out |
|||
self.create_value_heads(stream_names, value_hidden) |
|||
|
|||
def create_q_heads( |
|||
self, |
|||
stream_names, |
|||
hidden_input, |
|||
num_layers, |
|||
h_size, |
|||
scope, |
|||
reuse=False, |
|||
num_outputs=1, |
|||
): |
|||
""" |
|||
Creates two q heads for each reward signal in stream_names. |
|||
Also creates the node corresponding to the mean of all the value heads in self.value. |
|||
self.value_head is a dictionary of stream name to node containing the value estimator head for that signal. |
|||
:param stream_names: The list of reward signal names |
|||
:param hidden_input: The last layer of the Critic. The heads will consist of one dense hidden layer on top |
|||
of the hidden input. |
|||
:param num_layers: Number of hidden layers for Q network |
|||
:param h_size: size of hidden layers for Q network |
|||
:param scope: TF scope for Q network. |
|||
:param reuse: Whether or not to reuse variables. Useful for creating Q of policy. |
|||
:param num_outputs: Number of outputs of each Q function. If discrete, equal to number of actions. |
|||
""" |
|||
with tf.variable_scope(self.join_scopes(scope, "q1_encoding"), reuse=reuse): |
|||
q1_hidden = ModelUtils.create_vector_observation_encoder( |
|||
hidden_input, h_size, self.activ_fn, num_layers, "q1_encoder", reuse |
|||
) |
|||
if self.use_recurrent: |
|||
q1_hidden, memory_out = ModelUtils.create_recurrent_encoder( |
|||
q1_hidden, |
|||
self.q1_memory_in, |
|||
self.sequence_length_ph, |
|||
name="lstm_q1", |
|||
) |
|||
self.q1_memory_out = memory_out |
|||
|
|||
q1_heads = {} |
|||
for name in stream_names: |
|||
_q1 = tf.layers.dense(q1_hidden, num_outputs, name="{}_q1".format(name)) |
|||
q1_heads[name] = _q1 |
|||
|
|||
q1 = tf.reduce_mean(list(q1_heads.values()), axis=0) |
|||
with tf.variable_scope(self.join_scopes(scope, "q2_encoding"), reuse=reuse): |
|||
q2_hidden = ModelUtils.create_vector_observation_encoder( |
|||
hidden_input, h_size, self.activ_fn, num_layers, "q2_encoder", reuse |
|||
) |
|||
if self.use_recurrent: |
|||
q2_hidden, memory_out = ModelUtils.create_recurrent_encoder( |
|||
q2_hidden, |
|||
self.q2_memory_in, |
|||
self.sequence_length_ph, |
|||
name="lstm_q2", |
|||
) |
|||
self.q2_memory_out = memory_out |
|||
|
|||
q2_heads = {} |
|||
for name in stream_names: |
|||
_q2 = tf.layers.dense(q2_hidden, num_outputs, name="{}_q2".format(name)) |
|||
q2_heads[name] = _q2 |
|||
|
|||
q2 = tf.reduce_mean(list(q2_heads.values()), axis=0) |
|||
|
|||
return q1_heads, q2_heads, q1, q2 |
|||
|
|||
|
|||
class SACTargetNetwork(SACNetwork): |
|||
""" |
|||
Instantiation for the SAC target network. Only contains a single |
|||
value estimator and is updated from the Policy Network. |
|||
""" |
|||
|
|||
def __init__( |
|||
self, |
|||
policy, |
|||
m_size=None, |
|||
h_size=128, |
|||
normalize=False, |
|||
use_recurrent=False, |
|||
num_layers=2, |
|||
stream_names=None, |
|||
vis_encode_type=EncoderType.SIMPLE, |
|||
): |
|||
super().__init__( |
|||
policy, |
|||
m_size, |
|||
h_size, |
|||
normalize, |
|||
use_recurrent, |
|||
num_layers, |
|||
stream_names, |
|||
vis_encode_type, |
|||
) |
|||
with tf.variable_scope(TARGET_SCOPE): |
|||
self.visual_in = ModelUtils.create_visual_input_placeholders( |
|||
policy.brain.camera_resolutions |
|||
) |
|||
self.vector_in = ModelUtils.create_vector_input(policy.vec_obs_size) |
|||
if self.policy.normalize: |
|||
normalization_tensors = ModelUtils.create_normalizer(self.vector_in) |
|||
self.update_normalization_op = normalization_tensors.update_op |
|||
self.normalization_steps = normalization_tensors.steps |
|||
self.running_mean = normalization_tensors.running_mean |
|||
self.running_variance = normalization_tensors.running_variance |
|||
self.processed_vector_in = ModelUtils.normalize_vector_obs( |
|||
self.vector_in, |
|||
self.running_mean, |
|||
self.running_variance, |
|||
self.normalization_steps, |
|||
) |
|||
else: |
|||
self.processed_vector_in = self.vector_in |
|||
self.update_normalization_op = None |
|||
|
|||
if self.policy.use_recurrent: |
|||
self.memory_in = tf.placeholder( |
|||
shape=[None, m_size], dtype=tf.float32, name="target_recurrent_in" |
|||
) |
|||
self.value_memory_in = self.memory_in |
|||
hidden_streams = ModelUtils.create_observation_streams( |
|||
self.visual_in, |
|||
self.processed_vector_in, |
|||
1, |
|||
self.h_size, |
|||
0, |
|||
vis_encode_type=vis_encode_type, |
|||
stream_scopes=["critic/value/"], |
|||
) |
|||
if self.policy.use_continuous_act: |
|||
self._create_cc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False) |
|||
else: |
|||
self._create_dc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False) |
|||
if self.use_recurrent: |
|||
self.memory_out = tf.concat( |
|||
self.value_memory_out, axis=1 |
|||
) # Needed for Barracuda to work |
|||
|
|||
def copy_normalization(self, mean, variance, steps): |
|||
""" |
|||
Copies the mean, variance, and steps into the normalizers of the |
|||
input of this SACNetwork. Used to copy the normalizer from the policy network |
|||
to the target network. |
|||
param mean: Tensor containing the mean. |
|||
param variance: Tensor containing the variance |
|||
param steps: Tensor containing the number of steps. |
|||
""" |
|||
update_mean = tf.assign(self.running_mean, mean) |
|||
update_variance = tf.assign(self.running_variance, variance) |
|||
update_norm_step = tf.assign(self.normalization_steps, steps) |
|||
return tf.group([update_mean, update_variance, update_norm_step]) |
|||
|
|||
|
|||
class SACPolicyNetwork(SACNetwork): |
|||
""" |
|||
Instantiation for SAC policy network. Contains a dual Q estimator, |
|||
a value estimator, and a reference to the actual policy network. |
|||
""" |
|||
|
|||
def __init__( |
|||
self, |
|||
policy, |
|||
m_size=None, |
|||
h_size=128, |
|||
normalize=False, |
|||
use_recurrent=False, |
|||
num_layers=2, |
|||
stream_names=None, |
|||
vis_encode_type=EncoderType.SIMPLE, |
|||
): |
|||
super().__init__( |
|||
policy, |
|||
m_size, |
|||
h_size, |
|||
normalize, |
|||
use_recurrent, |
|||
num_layers, |
|||
stream_names, |
|||
vis_encode_type, |
|||
) |
|||
if self.policy.use_recurrent: |
|||
self._create_memory_ins(m_size) |
|||
|
|||
hidden_critic = self._create_observation_in(vis_encode_type) |
|||
self.policy.output = self.policy.output |
|||
# Use the sequence length of the policy |
|||
self.sequence_length_ph = self.policy.sequence_length_ph |
|||
|
|||
if self.policy.use_continuous_act: |
|||
self._create_cc_critic(hidden_critic, POLICY_SCOPE) |
|||
|
|||
else: |
|||
self._create_dc_critic(hidden_critic, POLICY_SCOPE) |
|||
|
|||
if self.use_recurrent: |
|||
mem_outs = [self.value_memory_out, self.q1_memory_out, self.q2_memory_out] |
|||
self.memory_out = tf.concat(mem_outs, axis=1) |
|||
|
|||
def _create_memory_ins(self, m_size): |
|||
""" |
|||
Creates the memory input placeholders for LSTM. |
|||
:param m_size: the total size of the memory. |
|||
""" |
|||
self.memory_in = tf.placeholder( |
|||
shape=[None, m_size * 3], dtype=tf.float32, name="value_recurrent_in" |
|||
) |
|||
|
|||
# Re-break-up for each network |
|||
num_mems = 3 |
|||
input_size = self.memory_in.get_shape().as_list()[1] |
|||
mem_ins = [] |
|||
for i in range(num_mems): |
|||
_start = input_size // num_mems * i |
|||
_end = input_size // num_mems * (i + 1) |
|||
mem_ins.append(self.memory_in[:, _start:_end]) |
|||
self.value_memory_in = mem_ins[0] |
|||
self.q1_memory_in = mem_ins[1] |
|||
self.q2_memory_in = mem_ins[2] |
|||
|
|||
def _create_observation_in(self, vis_encode_type): |
|||
""" |
|||
Creates the observation inputs, and a CNN if needed, |
|||
:param vis_encode_type: Type of CNN encoder. |
|||
:param share_ac_cnn: Whether or not to share the actor and critic CNNs. |
|||
:return A tuple of (hidden_policy, hidden_critic). We don't save it to self since they're used |
|||
once and thrown away. |
|||
""" |
|||
with tf.variable_scope(POLICY_SCOPE): |
|||
hidden_streams = ModelUtils.create_observation_streams( |
|||
self.policy.visual_in, |
|||
self.policy.processed_vector_in, |
|||
1, |
|||
self.h_size, |
|||
0, |
|||
vis_encode_type=vis_encode_type, |
|||
stream_scopes=["critic/value/"], |
|||
) |
|||
hidden_critic = hidden_streams[0] |
|||
return hidden_critic |
|
|||
import logging |
|||
import numpy as np |
|||
from typing import Dict, List, Optional, Any, Mapping |
|||
|
|||
from mlagents.tf_utils import tf |
|||
|
|||
from mlagents.trainers.sac.network import SACPolicyNetwork, SACTargetNetwork |
|||
from mlagents.trainers.models import LearningRateSchedule, EncoderType, ModelUtils |
|||
from mlagents.trainers.common.tf_optimizer import TFOptimizer |
|||
from mlagents.trainers.tf_policy import TFPolicy |
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents_envs.timers import timed |
|||
|
|||
EPSILON = 1e-6 # Small value to avoid divide by zero |
|||
|
|||
LOGGER = logging.getLogger("mlagents.trainers") |
|||
|
|||
POLICY_SCOPE = "" |
|||
TARGET_SCOPE = "target_network" |
|||
|
|||
|
|||
class SACOptimizer(TFOptimizer): |
|||
def __init__(self, policy: TFPolicy, trainer_params: Dict[str, Any]): |
|||
""" |
|||
Takes a Unity environment and model-specific hyper-parameters and returns the |
|||
appropriate PPO agent model for the environment. |
|||
:param brain: Brain parameters used to generate specific network graph. |
|||
:param lr: Learning rate. |
|||
:param lr_schedule: Learning rate decay schedule. |
|||
:param h_size: Size of hidden layers |
|||
:param init_entcoef: Initial value for entropy coefficient. Set lower to learn faster, |
|||
set higher to explore more. |
|||
:return: a sub-class of PPOAgent tailored to the environment. |
|||
:param max_step: Total number of training steps. |
|||
:param normalize: Whether to normalize vector observation input. |
|||
:param use_recurrent: Whether to use an LSTM layer in the network. |
|||
:param num_layers: Number of hidden layers between encoded input and policy & value layers |
|||
:param tau: Strength of soft-Q update. |
|||
:param m_size: Size of brain memory. |
|||
""" |
|||
# Create the graph here to give more granular control of the TF graph to the Optimizer. |
|||
policy.create_tf_graph() |
|||
|
|||
with policy.graph.as_default(): |
|||
with tf.variable_scope(""): |
|||
super().__init__(policy, trainer_params) |
|||
lr = float(trainer_params["learning_rate"]) |
|||
lr_schedule = LearningRateSchedule( |
|||
trainer_params.get("learning_rate_schedule", "constant") |
|||
) |
|||
self.policy = policy |
|||
self.act_size = self.policy.act_size |
|||
h_size = int(trainer_params["hidden_units"]) |
|||
max_step = float(trainer_params["max_steps"]) |
|||
num_layers = int(trainer_params["num_layers"]) |
|||
vis_encode_type = EncoderType( |
|||
trainer_params.get("vis_encode_type", "simple") |
|||
) |
|||
self.tau = trainer_params.get("tau", 0.005) |
|||
self.burn_in_ratio = float(trainer_params.get("burn_in_ratio", 0.0)) |
|||
|
|||
# Non-exposed SAC parameters |
|||
self.discrete_target_entropy_scale = ( |
|||
0.2 |
|||
) # Roughly equal to e-greedy 0.05 |
|||
self.continuous_target_entropy_scale = 1.0 |
|||
|
|||
self.init_entcoef = trainer_params.get("init_entcoef", 1.0) |
|||
stream_names = list(self.reward_signals.keys()) |
|||
# Use to reduce "survivor bonus" when using Curiosity or GAIL. |
|||
self.gammas = [ |
|||
_val["gamma"] for _val in trainer_params["reward_signals"].values() |
|||
] |
|||
self.use_dones_in_backup = { |
|||
name: tf.Variable(1.0) for name in stream_names |
|||
} |
|||
self.disable_use_dones = { |
|||
name: self.use_dones_in_backup[name].assign(0.0) |
|||
for name in stream_names |
|||
} |
|||
|
|||
if num_layers < 1: |
|||
num_layers = 1 |
|||
|
|||
self.target_init_op: List[tf.Tensor] = [] |
|||
self.target_update_op: List[tf.Tensor] = [] |
|||
self.update_batch_policy: Optional[tf.Operation] = None |
|||
self.update_batch_value: Optional[tf.Operation] = None |
|||
self.update_batch_entropy: Optional[tf.Operation] = None |
|||
|
|||
self.policy_network = SACPolicyNetwork( |
|||
policy=self.policy, |
|||
m_size=self.policy.m_size, # 3x policy.m_size |
|||
h_size=h_size, |
|||
normalize=self.policy.normalize, |
|||
use_recurrent=self.policy.use_recurrent, |
|||
num_layers=num_layers, |
|||
stream_names=stream_names, |
|||
vis_encode_type=vis_encode_type, |
|||
) |
|||
self.target_network = SACTargetNetwork( |
|||
policy=self.policy, |
|||
m_size=self.policy.m_size, # 1x policy.m_size |
|||
h_size=h_size, |
|||
normalize=self.policy.normalize, |
|||
use_recurrent=self.policy.use_recurrent, |
|||
num_layers=num_layers, |
|||
stream_names=stream_names, |
|||
vis_encode_type=vis_encode_type, |
|||
) |
|||
# The optimizer's m_size is 3 times the policy (Q1, Q2, and Value) |
|||
self.m_size = 3 * self.policy.m_size |
|||
self._create_inputs_and_outputs() |
|||
self.learning_rate = ModelUtils.create_learning_rate( |
|||
lr_schedule, lr, self.policy.global_step, int(max_step) |
|||
) |
|||
self._create_losses( |
|||
self.policy_network.q1_heads, |
|||
self.policy_network.q2_heads, |
|||
lr, |
|||
int(max_step), |
|||
stream_names, |
|||
discrete=not self.policy.use_continuous_act, |
|||
) |
|||
self._create_sac_optimizer_ops() |
|||
|
|||
self.selected_actions = ( |
|||
self.policy.selected_actions |
|||
) # For GAIL and other reward signals |
|||
if self.policy.normalize: |
|||
target_update_norm = self.target_network.copy_normalization( |
|||
self.policy.running_mean, |
|||
self.policy.running_variance, |
|||
self.policy.normalization_steps, |
|||
) |
|||
# Update the normalization of the optimizer when the policy does. |
|||
self.policy.update_normalization_op = tf.group( |
|||
[self.policy.update_normalization_op, target_update_norm] |
|||
) |
|||
|
|||
self.policy.initialize_or_load() |
|||
|
|||
self.stats_name_to_update_name = { |
|||
"Losses/Value Loss": "value_loss", |
|||
"Losses/Policy Loss": "policy_loss", |
|||
"Losses/Q1 Loss": "q1_loss", |
|||
"Losses/Q2 Loss": "q2_loss", |
|||
"Policy/Entropy Coeff": "entropy_coef", |
|||
"Policy/Learning Rate": "learning_rate", |
|||
} |
|||
|
|||
self.update_dict = { |
|||
"value_loss": self.total_value_loss, |
|||
"policy_loss": self.policy_loss, |
|||
"q1_loss": self.q1_loss, |
|||
"q2_loss": self.q2_loss, |
|||
"entropy_coef": self.ent_coef, |
|||
"entropy": self.policy.entropy, |
|||
"update_batch": self.update_batch_policy, |
|||
"update_value": self.update_batch_value, |
|||
"update_entropy": self.update_batch_entropy, |
|||
"learning_rate": self.learning_rate, |
|||
} |
|||
|
|||
def _create_inputs_and_outputs(self) -> None: |
|||
""" |
|||
Assign the higher-level SACModel's inputs and outputs to those of its policy or |
|||
target network. |
|||
""" |
|||
self.vector_in = self.policy.vector_in |
|||
self.visual_in = self.policy.visual_in |
|||
self.next_vector_in = self.target_network.vector_in |
|||
self.next_visual_in = self.target_network.visual_in |
|||
self.action_holder = self.policy.action_holder |
|||
self.sequence_length_ph = self.policy.sequence_length_ph |
|||
self.next_sequence_length_ph = self.target_network.sequence_length_ph |
|||
if not self.policy.use_continuous_act: |
|||
self.action_masks = self.policy_network.action_masks |
|||
else: |
|||
self.output_pre = self.policy_network.output_pre |
|||
|
|||
# Don't use value estimate during inference. TODO: Check why PPO uses value_estimate in inference. |
|||
self.value = tf.identity( |
|||
self.policy_network.value, name="value_estimate_unused" |
|||
) |
|||
self.value_heads = self.policy_network.value_heads |
|||
self.dones_holder = tf.placeholder( |
|||
shape=[None], dtype=tf.float32, name="dones_holder" |
|||
) |
|||
|
|||
if self.policy.use_recurrent: |
|||
self.memory_in = self.policy_network.memory_in |
|||
self.memory_out = self.policy_network.memory_out |
|||
if not self.policy.use_continuous_act: |
|||
self.prev_action = self.policy_network.prev_action |
|||
self.next_memory_in = self.target_network.memory_in |
|||
|
|||
def _create_losses( |
|||
self, |
|||
q1_streams: Dict[str, tf.Tensor], |
|||
q2_streams: Dict[str, tf.Tensor], |
|||
lr: tf.Tensor, |
|||
max_step: int, |
|||
stream_names: List[str], |
|||
discrete: bool = False, |
|||
) -> None: |
|||
""" |
|||
Creates training-specific Tensorflow ops for SAC models. |
|||
:param q1_streams: Q1 streams from policy network |
|||
:param q1_streams: Q2 streams from policy network |
|||
:param lr: Learning rate |
|||
:param max_step: Total number of training steps. |
|||
:param stream_names: List of reward stream names. |
|||
:param discrete: Whether or not to use discrete action losses. |
|||
""" |
|||
|
|||
if discrete: |
|||
self.target_entropy = [ |
|||
self.discrete_target_entropy_scale * np.log(i).astype(np.float32) |
|||
for i in self.act_size |
|||
] |
|||
discrete_action_probs = tf.exp(self.policy.all_log_probs) |
|||
per_action_entropy = discrete_action_probs * self.policy.all_log_probs |
|||
else: |
|||
self.target_entropy = ( |
|||
-1 |
|||
* self.continuous_target_entropy_scale |
|||
* np.prod(self.act_size[0]).astype(np.float32) |
|||
) |
|||
|
|||
self.rewards_holders = {} |
|||
self.min_policy_qs = {} |
|||
|
|||
for name in stream_names: |
|||
if discrete: |
|||
_branched_mpq1 = self._apply_as_branches( |
|||
self.policy_network.q1_pheads[name] * discrete_action_probs |
|||
) |
|||
branched_mpq1 = tf.stack( |
|||
[ |
|||
tf.reduce_sum(_br, axis=1, keep_dims=True) |
|||
for _br in _branched_mpq1 |
|||
] |
|||
) |
|||
_q1_p_mean = tf.reduce_mean(branched_mpq1, axis=0) |
|||
|
|||
_branched_mpq2 = self._apply_as_branches( |
|||
self.policy_network.q2_pheads[name] * discrete_action_probs |
|||
) |
|||
branched_mpq2 = tf.stack( |
|||
[ |
|||
tf.reduce_sum(_br, axis=1, keep_dims=True) |
|||
for _br in _branched_mpq2 |
|||
] |
|||
) |
|||
_q2_p_mean = tf.reduce_mean(branched_mpq2, axis=0) |
|||
|
|||
self.min_policy_qs[name] = tf.minimum(_q1_p_mean, _q2_p_mean) |
|||
else: |
|||
self.min_policy_qs[name] = tf.minimum( |
|||
self.policy_network.q1_pheads[name], |
|||
self.policy_network.q2_pheads[name], |
|||
) |
|||
|
|||
rewards_holder = tf.placeholder( |
|||
shape=[None], dtype=tf.float32, name="{}_rewards".format(name) |
|||
) |
|||
self.rewards_holders[name] = rewards_holder |
|||
|
|||
q1_losses = [] |
|||
q2_losses = [] |
|||
# Multiple q losses per stream |
|||
expanded_dones = tf.expand_dims(self.dones_holder, axis=-1) |
|||
for i, name in enumerate(stream_names): |
|||
_expanded_rewards = tf.expand_dims(self.rewards_holders[name], axis=-1) |
|||
|
|||
q_backup = tf.stop_gradient( |
|||
_expanded_rewards |
|||
+ (1.0 - self.use_dones_in_backup[name] * expanded_dones) |
|||
* self.gammas[i] |
|||
* self.target_network.value_heads[name] |
|||
) |
|||
|
|||
if discrete: |
|||
# We need to break up the Q functions by branch, and update them individually. |
|||
branched_q1_stream = self._apply_as_branches( |
|||
self.policy.action_oh * q1_streams[name] |
|||
) |
|||
branched_q2_stream = self._apply_as_branches( |
|||
self.policy.action_oh * q2_streams[name] |
|||
) |
|||
|
|||
# Reduce each branch into scalar |
|||
branched_q1_stream = [ |
|||
tf.reduce_sum(_branch, axis=1, keep_dims=True) |
|||
for _branch in branched_q1_stream |
|||
] |
|||
branched_q2_stream = [ |
|||
tf.reduce_sum(_branch, axis=1, keep_dims=True) |
|||
for _branch in branched_q2_stream |
|||
] |
|||
|
|||
q1_stream = tf.reduce_mean(branched_q1_stream, axis=0) |
|||
q2_stream = tf.reduce_mean(branched_q2_stream, axis=0) |
|||
|
|||
else: |
|||
q1_stream = q1_streams[name] |
|||
q2_stream = q2_streams[name] |
|||
|
|||
_q1_loss = 0.5 * tf.reduce_mean( |
|||
tf.to_float(self.policy.mask) |
|||
* tf.squared_difference(q_backup, q1_stream) |
|||
) |
|||
|
|||
_q2_loss = 0.5 * tf.reduce_mean( |
|||
tf.to_float(self.policy.mask) |
|||
* tf.squared_difference(q_backup, q2_stream) |
|||
) |
|||
|
|||
q1_losses.append(_q1_loss) |
|||
q2_losses.append(_q2_loss) |
|||
|
|||
self.q1_loss = tf.reduce_mean(q1_losses) |
|||
self.q2_loss = tf.reduce_mean(q2_losses) |
|||
|
|||
# Learn entropy coefficient |
|||
if discrete: |
|||
# Create a log_ent_coef for each branch |
|||
self.log_ent_coef = tf.get_variable( |
|||
"log_ent_coef", |
|||
dtype=tf.float32, |
|||
initializer=np.log([self.init_entcoef] * len(self.act_size)).astype( |
|||
np.float32 |
|||
), |
|||
trainable=True, |
|||
) |
|||
else: |
|||
self.log_ent_coef = tf.get_variable( |
|||
"log_ent_coef", |
|||
dtype=tf.float32, |
|||
initializer=np.log(self.init_entcoef).astype(np.float32), |
|||
trainable=True, |
|||
) |
|||
|
|||
self.ent_coef = tf.exp(self.log_ent_coef) |
|||
if discrete: |
|||
# We also have to do a different entropy and target_entropy per branch. |
|||
branched_per_action_ent = self._apply_as_branches(per_action_entropy) |
|||
branched_ent_sums = tf.stack( |
|||
[ |
|||
tf.reduce_sum(_lp, axis=1, keep_dims=True) + _te |
|||
for _lp, _te in zip(branched_per_action_ent, self.target_entropy) |
|||
], |
|||
axis=1, |
|||
) |
|||
self.entropy_loss = -tf.reduce_mean( |
|||
tf.to_float(self.policy.mask) |
|||
* tf.reduce_mean( |
|||
self.log_ent_coef |
|||
* tf.squeeze(tf.stop_gradient(branched_ent_sums), axis=2), |
|||
axis=1, |
|||
) |
|||
) |
|||
|
|||
# Same with policy loss, we have to do the loss per branch and average them, |
|||
# so that larger branches don't get more weight. |
|||
# The equivalent KL divergence from Eq 10 of Haarnoja et al. is also pi*log(pi) - Q |
|||
branched_q_term = self._apply_as_branches( |
|||
discrete_action_probs * self.policy_network.q1_p |
|||
) |
|||
|
|||
branched_policy_loss = tf.stack( |
|||
[ |
|||
tf.reduce_sum(self.ent_coef[i] * _lp - _qt, axis=1, keep_dims=True) |
|||
for i, (_lp, _qt) in enumerate( |
|||
zip(branched_per_action_ent, branched_q_term) |
|||
) |
|||
] |
|||
) |
|||
self.policy_loss = tf.reduce_mean( |
|||
tf.to_float(self.policy.mask) * tf.squeeze(branched_policy_loss) |
|||
) |
|||
|
|||
# Do vbackup entropy bonus per branch as well. |
|||
branched_ent_bonus = tf.stack( |
|||
[ |
|||
tf.reduce_sum(self.ent_coef[i] * _lp, axis=1, keep_dims=True) |
|||
for i, _lp in enumerate(branched_per_action_ent) |
|||
] |
|||
) |
|||
value_losses = [] |
|||
for name in stream_names: |
|||
v_backup = tf.stop_gradient( |
|||
self.min_policy_qs[name] |
|||
- tf.reduce_mean(branched_ent_bonus, axis=0) |
|||
) |
|||
value_losses.append( |
|||
0.5 |
|||
* tf.reduce_mean( |
|||
tf.to_float(self.policy.mask) |
|||
* tf.squared_difference( |
|||
self.policy_network.value_heads[name], v_backup |
|||
) |
|||
) |
|||
) |
|||
|
|||
else: |
|||
self.entropy_loss = -tf.reduce_mean( |
|||
self.log_ent_coef |
|||
* tf.to_float(self.policy.mask) |
|||
* tf.stop_gradient( |
|||
tf.reduce_sum( |
|||
self.policy.all_log_probs + self.target_entropy, |
|||
axis=1, |
|||
keep_dims=True, |
|||
) |
|||
) |
|||
) |
|||
batch_policy_loss = tf.reduce_mean( |
|||
self.ent_coef * self.policy.all_log_probs - self.policy_network.q1_p, |
|||
axis=1, |
|||
) |
|||
self.policy_loss = tf.reduce_mean( |
|||
tf.to_float(self.policy.mask) * batch_policy_loss |
|||
) |
|||
|
|||
value_losses = [] |
|||
for name in stream_names: |
|||
v_backup = tf.stop_gradient( |
|||
self.min_policy_qs[name] |
|||
- tf.reduce_sum(self.ent_coef * self.policy.all_log_probs, axis=1) |
|||
) |
|||
value_losses.append( |
|||
0.5 |
|||
* tf.reduce_mean( |
|||
tf.to_float(self.policy.mask) |
|||
* tf.squared_difference( |
|||
self.policy_network.value_heads[name], v_backup |
|||
) |
|||
) |
|||
) |
|||
self.value_loss = tf.reduce_mean(value_losses) |
|||
|
|||
self.total_value_loss = self.q1_loss + self.q2_loss + self.value_loss |
|||
|
|||
self.entropy = self.policy_network.entropy |
|||
|
|||
def _apply_as_branches(self, concat_logits: tf.Tensor) -> List[tf.Tensor]: |
|||
""" |
|||
Takes in a concatenated set of logits and breaks it up into a list of non-concatenated logits, one per |
|||
action branch |
|||
""" |
|||
action_idx = [0] + list(np.cumsum(self.act_size)) |
|||
branches_logits = [ |
|||
concat_logits[:, action_idx[i] : action_idx[i + 1]] |
|||
for i in range(len(self.act_size)) |
|||
] |
|||
return branches_logits |
|||
|
|||
def _create_sac_optimizer_ops(self) -> None: |
|||
""" |
|||
Creates the Adam optimizers and update ops for SAC, including |
|||
the policy, value, and entropy updates, as well as the target network update. |
|||
""" |
|||
policy_optimizer = self.create_optimizer_op( |
|||
learning_rate=self.learning_rate, name="sac_policy_opt" |
|||
) |
|||
entropy_optimizer = self.create_optimizer_op( |
|||
learning_rate=self.learning_rate, name="sac_entropy_opt" |
|||
) |
|||
value_optimizer = self.create_optimizer_op( |
|||
learning_rate=self.learning_rate, name="sac_value_opt" |
|||
) |
|||
|
|||
self.target_update_op = [ |
|||
tf.assign(target, (1 - self.tau) * target + self.tau * source) |
|||
for target, source in zip( |
|||
self.target_network.value_vars, self.policy_network.value_vars |
|||
) |
|||
] |
|||
LOGGER.debug("value_vars") |
|||
self.print_all_vars(self.policy_network.value_vars) |
|||
LOGGER.debug("targvalue_vars") |
|||
self.print_all_vars(self.target_network.value_vars) |
|||
LOGGER.debug("critic_vars") |
|||
self.print_all_vars(self.policy_network.critic_vars) |
|||
LOGGER.debug("q_vars") |
|||
self.print_all_vars(self.policy_network.q_vars) |
|||
LOGGER.debug("policy_vars") |
|||
policy_vars = self.policy.get_trainable_variables() |
|||
self.print_all_vars(policy_vars) |
|||
|
|||
self.target_init_op = [ |
|||
tf.assign(target, source) |
|||
for target, source in zip( |
|||
self.target_network.value_vars, self.policy_network.value_vars |
|||
) |
|||
] |
|||
|
|||
self.update_batch_policy = policy_optimizer.minimize( |
|||
self.policy_loss, var_list=policy_vars |
|||
) |
|||
|
|||
# Make sure policy is updated first, then value, then entropy. |
|||
with tf.control_dependencies([self.update_batch_policy]): |
|||
self.update_batch_value = value_optimizer.minimize( |
|||
self.total_value_loss, var_list=self.policy_network.critic_vars |
|||
) |
|||
# Add entropy coefficient optimization operation |
|||
with tf.control_dependencies([self.update_batch_value]): |
|||
self.update_batch_entropy = entropy_optimizer.minimize( |
|||
self.entropy_loss, var_list=self.log_ent_coef |
|||
) |
|||
|
|||
def print_all_vars(self, variables): |
|||
for _var in variables: |
|||
LOGGER.debug(_var) |
|||
|
|||
@timed |
|||
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: |
|||
""" |
|||
Updates model using buffer. |
|||
:param num_sequences: Number of trajectories in batch. |
|||
:param batch: Experience mini-batch. |
|||
:param update_target: Whether or not to update target value network |
|||
:param reward_signal_batches: Minibatches to use for updating the reward signals, |
|||
indexed by name. If none, don't update the reward signals. |
|||
:return: Output from update process. |
|||
""" |
|||
feed_dict = self._construct_feed_dict(self.policy, batch, num_sequences) |
|||
stats_needed = self.stats_name_to_update_name |
|||
update_stats: Dict[str, float] = {} |
|||
update_vals = self._execute_model(feed_dict, self.update_dict) |
|||
for stat_name, update_name in stats_needed.items(): |
|||
update_stats[stat_name] = update_vals[update_name] |
|||
# Update target network. By default, target update happens at every policy update. |
|||
self.sess.run(self.target_update_op) |
|||
return update_stats |
|||
|
|||
def update_reward_signals( |
|||
self, reward_signal_minibatches: Mapping[str, Dict], num_sequences: int |
|||
) -> Dict[str, float]: |
|||
""" |
|||
Only update the reward signals. |
|||
:param reward_signal_batches: Minibatches to use for updating the reward signals, |
|||
indexed by name. If none, don't update the reward signals. |
|||
""" |
|||
# Collect feed dicts for all reward signals. |
|||
feed_dict: Dict[tf.Tensor, Any] = {} |
|||
update_dict: Dict[str, tf.Tensor] = {} |
|||
update_stats: Dict[str, float] = {} |
|||
stats_needed: Dict[str, str] = {} |
|||
if reward_signal_minibatches: |
|||
self.add_reward_signal_dicts( |
|||
feed_dict, |
|||
update_dict, |
|||
stats_needed, |
|||
reward_signal_minibatches, |
|||
num_sequences, |
|||
) |
|||
update_vals = self._execute_model(feed_dict, update_dict) |
|||
for stat_name, update_name in stats_needed.items(): |
|||
update_stats[stat_name] = update_vals[update_name] |
|||
return update_stats |
|||
|
|||
def add_reward_signal_dicts( |
|||
self, |
|||
feed_dict: Dict[tf.Tensor, Any], |
|||
update_dict: Dict[str, tf.Tensor], |
|||
stats_needed: Dict[str, str], |
|||
reward_signal_minibatches: Mapping[str, Dict], |
|||
num_sequences: int, |
|||
) -> None: |
|||
""" |
|||
Adds the items needed for reward signal updates to the feed_dict and stats_needed dict. |
|||
:param feed_dict: Feed dict needed update |
|||
:param update_dit: Update dict that needs update |
|||
:param stats_needed: Stats needed to get from the update. |
|||
:param reward_signal_minibatches: Minibatches to use for updating the reward signals, |
|||
indexed by name. |
|||
""" |
|||
for name, r_batch in reward_signal_minibatches.items(): |
|||
feed_dict.update( |
|||
self.reward_signals[name].prepare_update( |
|||
self.policy, r_batch, num_sequences |
|||
) |
|||
) |
|||
update_dict.update(self.reward_signals[name].update_dict) |
|||
stats_needed.update(self.reward_signals[name].stats_name_to_update_name) |
|||
|
|||
def _construct_feed_dict( |
|||
self, policy: TFPolicy, batch: AgentBuffer, num_sequences: int |
|||
) -> Dict[tf.Tensor, Any]: |
|||
""" |
|||
Builds the feed dict for updating the SAC model. |
|||
:param model: The model to update. May be different when, e.g. using multi-GPU. |
|||
:param batch: Mini-batch to use to update. |
|||
:param num_sequences: Number of LSTM sequences in batch. |
|||
""" |
|||
# Do an optional burn-in for memories |
|||
num_burn_in = int(self.burn_in_ratio * self.policy.sequence_length) |
|||
burn_in_mask = np.ones((self.policy.sequence_length), dtype=np.float32) |
|||
burn_in_mask[range(0, num_burn_in)] = 0 |
|||
burn_in_mask = np.tile(burn_in_mask, num_sequences) |
|||
feed_dict = { |
|||
policy.batch_size_ph: num_sequences, |
|||
policy.sequence_length_ph: self.policy.sequence_length, |
|||
self.next_sequence_length_ph: self.policy.sequence_length, |
|||
self.policy.mask_input: batch["masks"] * burn_in_mask, |
|||
} |
|||
for name in self.reward_signals: |
|||
feed_dict[self.rewards_holders[name]] = batch["{}_rewards".format(name)] |
|||
|
|||
if self.policy.use_continuous_act: |
|||
feed_dict[policy.action_holder] = batch["actions"] |
|||
else: |
|||
feed_dict[policy.action_holder] = batch["actions"] |
|||
if self.policy.use_recurrent: |
|||
feed_dict[policy.prev_action] = batch["prev_action"] |
|||
feed_dict[policy.action_masks] = batch["action_mask"] |
|||
if self.policy.use_vec_obs: |
|||
feed_dict[policy.vector_in] = batch["vector_obs"] |
|||
feed_dict[self.next_vector_in] = batch["next_vector_in"] |
|||
if self.policy.vis_obs_size > 0: |
|||
for i, _ in enumerate(policy.visual_in): |
|||
_obs = batch["visual_obs%d" % i] |
|||
feed_dict[policy.visual_in[i]] = _obs |
|||
for i, _ in enumerate(self.next_visual_in): |
|||
_obs = batch["next_visual_obs%d" % i] |
|||
feed_dict[self.next_visual_in[i]] = _obs |
|||
if self.policy.use_recurrent: |
|||
feed_dict[policy.memory_in] = [ |
|||
batch["memory"][i] |
|||
for i in range(0, len(batch["memory"]), self.policy.sequence_length) |
|||
] |
|||
feed_dict[self.policy_network.memory_in] = self._make_zero_mem( |
|||
self.m_size, batch.num_experiences |
|||
) |
|||
feed_dict[self.target_network.memory_in] = self._make_zero_mem( |
|||
self.m_size // 3, batch.num_experiences |
|||
) |
|||
feed_dict[self.dones_holder] = batch["done"] |
|||
return feed_dict |
|
|||
import pytest |
|||
|
|||
import numpy as np |
|||
from mlagents.tf_utils import tf |
|||
|
|||
import yaml |
|||
|
|||
from mlagents.trainers.common.nn_policy import NNPolicy |
|||
from mlagents.trainers.models import EncoderType, ModelUtils |
|||
from mlagents.trainers.exception import UnityTrainerException |
|||
from mlagents.trainers.brain import BrainParameters, CameraResolution |
|||
from mlagents.trainers.tests import mock_brain as mb |
|||
from mlagents.trainers.tests.test_trajectory import make_fake_trajectory |
|||
|
|||
|
|||
@pytest.fixture |
|||
def dummy_config(): |
|||
return yaml.safe_load( |
|||
""" |
|||
trainer: ppo |
|||
batch_size: 32 |
|||
beta: 5.0e-3 |
|||
buffer_size: 512 |
|||
epsilon: 0.2 |
|||
hidden_units: 128 |
|||
lambd: 0.95 |
|||
learning_rate: 3.0e-4 |
|||
max_steps: 5.0e4 |
|||
normalize: true |
|||
num_epoch: 5 |
|||
num_layers: 2 |
|||
time_horizon: 64 |
|||
sequence_length: 64 |
|||
summary_freq: 1000 |
|||
use_recurrent: false |
|||
normalize: true |
|||
memory_size: 8 |
|||
curiosity_strength: 0.0 |
|||
curiosity_enc_size: 1 |
|||
summary_path: test |
|||
model_path: test |
|||
reward_signals: |
|||
extrinsic: |
|||
strength: 1.0 |
|||
gamma: 0.99 |
|||
""" |
|||
) |
|||
|
|||
|
|||
VECTOR_ACTION_SPACE = [2] |
|||
VECTOR_OBS_SPACE = 8 |
|||
DISCRETE_ACTION_SPACE = [3, 3, 3, 2] |
|||
BUFFER_INIT_SAMPLES = 32 |
|||
NUM_AGENTS = 12 |
|||
|
|||
|
|||
def create_policy_mock(dummy_config, use_rnn, use_discrete, use_visual): |
|||
mock_brain = mb.setup_mock_brain( |
|||
use_discrete, |
|||
use_visual, |
|||
vector_action_space=VECTOR_ACTION_SPACE, |
|||
vector_obs_space=VECTOR_OBS_SPACE, |
|||
discrete_action_space=DISCRETE_ACTION_SPACE, |
|||
) |
|||
|
|||
trainer_parameters = dummy_config |
|||
model_path = "testmodel" |
|||
trainer_parameters["model_path"] = model_path |
|||
trainer_parameters["keep_checkpoints"] = 3 |
|||
trainer_parameters["use_recurrent"] = use_rnn |
|||
policy = NNPolicy(0, mock_brain, trainer_parameters, False, False) |
|||
return policy |
|||
|
|||
|
|||
@pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"]) |
|||
@pytest.mark.parametrize("visual", [True, False], ids=["visual", "vector"]) |
|||
@pytest.mark.parametrize("rnn", [True, False], ids=["rnn", "no_rnn"]) |
|||
def test_policy_evaluate(dummy_config, rnn, visual, discrete): |
|||
# Test evaluate |
|||
tf.reset_default_graph() |
|||
policy = create_policy_mock( |
|||
dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual |
|||
) |
|||
step = mb.create_batchedstep_from_brainparams(policy.brain, num_agents=NUM_AGENTS) |
|||
|
|||
run_out = policy.evaluate(step, list(step.agent_id)) |
|||
if discrete: |
|||
run_out["action"].shape == (NUM_AGENTS, len(DISCRETE_ACTION_SPACE)) |
|||
else: |
|||
assert run_out["action"].shape == (NUM_AGENTS, VECTOR_ACTION_SPACE[0]) |
|||
|
|||
|
|||
def test_normalization(dummy_config): |
|||
brain_params = BrainParameters( |
|||
brain_name="test_brain", |
|||
vector_observation_space_size=1, |
|||
camera_resolutions=[], |
|||
vector_action_space_size=[2], |
|||
vector_action_descriptions=[], |
|||
vector_action_space_type=0, |
|||
) |
|||
dummy_config["summary_path"] = "./summaries/test_trainer_summary" |
|||
dummy_config["model_path"] = "./models/test_trainer_models/TestModel" |
|||
|
|||
time_horizon = 6 |
|||
trajectory = make_fake_trajectory( |
|||
length=time_horizon, |
|||
max_step_complete=True, |
|||
vec_obs_size=1, |
|||
num_vis_obs=0, |
|||
action_space=[2], |
|||
) |
|||
# Change half of the obs to 0 |
|||
for i in range(3): |
|||
trajectory.steps[i].obs[0] = np.zeros(1, dtype=np.float32) |
|||
policy = policy = NNPolicy(0, brain_params, dummy_config, False, False) |
|||
|
|||
trajectory_buffer = trajectory.to_agentbuffer() |
|||
policy.update_normalization(trajectory_buffer["vector_obs"]) |
|||
|
|||
# Check that the running mean and variance is correct |
|||
steps, mean, variance = policy.sess.run( |
|||
[policy.normalization_steps, policy.running_mean, policy.running_variance] |
|||
) |
|||
|
|||
assert steps == 6 |
|||
assert mean[0] == 0.5 |
|||
# Note: variance is divided by number of steps, and initialized to 1 to avoid |
|||
# divide by 0. The right answer is 0.25 |
|||
assert (variance[0] - 1) / steps == 0.25 |
|||
|
|||
# Make another update, this time with all 1's |
|||
time_horizon = 10 |
|||
trajectory = make_fake_trajectory( |
|||
length=time_horizon, |
|||
max_step_complete=True, |
|||
vec_obs_size=1, |
|||
num_vis_obs=0, |
|||
action_space=[2], |
|||
) |
|||
trajectory_buffer = trajectory.to_agentbuffer() |
|||
policy.update_normalization(trajectory_buffer["vector_obs"]) |
|||
|
|||
# Check that the running mean and variance is correct |
|||
steps, mean, variance = policy.sess.run( |
|||
[policy.normalization_steps, policy.running_mean, policy.running_variance] |
|||
) |
|||
|
|||
assert steps == 16 |
|||
assert mean[0] == 0.8125 |
|||
assert (variance[0] - 1) / steps == pytest.approx(0.152, abs=0.01) |
|||
|
|||
|
|||
def test_min_visual_size(): |
|||
# Make sure each EncoderType has an entry in MIS_RESOLUTION_FOR_ENCODER |
|||
assert set(ModelUtils.MIN_RESOLUTION_FOR_ENCODER.keys()) == set(EncoderType) |
|||
|
|||
for encoder_type in EncoderType: |
|||
with tf.Graph().as_default(): |
|||
good_size = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type] |
|||
good_res = CameraResolution( |
|||
width=good_size, height=good_size, num_channels=3 |
|||
) |
|||
vis_input = ModelUtils.create_visual_input(good_res, "test_min_visual_size") |
|||
ModelUtils._check_resolution_for_encoder(vis_input, encoder_type) |
|||
enc_func = ModelUtils.get_encoder_for_type(encoder_type) |
|||
enc_func(vis_input, 32, ModelUtils.swish, 1, "test", False) |
|||
|
|||
# Anything under the min size should raise an exception. If not, decrease the min size! |
|||
with pytest.raises(Exception): |
|||
with tf.Graph().as_default(): |
|||
bad_size = ModelUtils.MIN_RESOLUTION_FOR_ENCODER[encoder_type] - 1 |
|||
bad_res = CameraResolution( |
|||
width=bad_size, height=bad_size, num_channels=3 |
|||
) |
|||
vis_input = ModelUtils.create_visual_input( |
|||
bad_res, "test_min_visual_size" |
|||
) |
|||
|
|||
with pytest.raises(UnityTrainerException): |
|||
# Make sure we'd hit a friendly error during model setup time. |
|||
ModelUtils._check_resolution_for_encoder(vis_input, encoder_type) |
|||
|
|||
enc_func = ModelUtils.get_encoder_for_type(encoder_type) |
|||
enc_func(vis_input, 32, ModelUtils.swish, 1, "test", False) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
pytest.main() |
|
|||
using System.IO.Abstractions; |
|||
using System.Text.RegularExpressions; |
|||
using UnityEngine; |
|||
using System.IO; |
|||
|
|||
namespace MLAgents |
|||
{ |
|||
/// <summary>
|
|||
/// Demonstration Recorder Component.
|
|||
/// </summary>
|
|||
[RequireComponent(typeof(Agent))] |
|||
[AddComponentMenu("ML Agents/Demonstration Recorder", (int)MenuGroup.Default)] |
|||
public class DemonstrationRecorder : MonoBehaviour |
|||
{ |
|||
[Tooltip("Whether or not to record demonstrations.")] |
|||
public bool record; |
|||
|
|||
[Tooltip("Base demonstration file name. Will have numbers appended to make unique.")] |
|||
public string demonstrationName; |
|||
|
|||
[Tooltip("Base directory to write the demo files. If null, will use {Application.dataPath}/Demonstrations.")] |
|||
public string demonstrationDirectory; |
|||
|
|||
DemonstrationWriter m_DemoWriter; |
|||
internal const int MaxNameLength = 16; |
|||
|
|||
const string k_ExtensionType = ".demo"; |
|||
IFileSystem m_FileSystem; |
|||
|
|||
Agent m_Agent; |
|||
|
|||
void OnEnable() |
|||
{ |
|||
m_Agent = GetComponent<Agent>(); |
|||
} |
|||
|
|||
void Update() |
|||
{ |
|||
if (record) |
|||
{ |
|||
LazyInitialize(); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Creates demonstration store for use in recording.
|
|||
/// Has no effect if the demonstration store was already created.
|
|||
/// </summary>
|
|||
internal DemonstrationWriter LazyInitialize(IFileSystem fileSystem = null) |
|||
{ |
|||
if (m_DemoWriter != null) |
|||
{ |
|||
return m_DemoWriter; |
|||
} |
|||
|
|||
if (m_Agent == null) |
|||
{ |
|||
m_Agent = GetComponent<Agent>(); |
|||
} |
|||
|
|||
m_FileSystem = fileSystem ?? new FileSystem(); |
|||
var behaviorParams = GetComponent<BehaviorParameters>(); |
|||
if (string.IsNullOrEmpty(demonstrationName)) |
|||
{ |
|||
demonstrationName = behaviorParams.behaviorName; |
|||
} |
|||
if (string.IsNullOrEmpty(demonstrationDirectory)) |
|||
{ |
|||
demonstrationDirectory = Path.Combine(Application.dataPath, "Demonstrations"); |
|||
} |
|||
|
|||
demonstrationName = SanitizeName(demonstrationName, MaxNameLength); |
|||
var filePath = MakeDemonstrationFilePath(m_FileSystem, demonstrationDirectory, demonstrationName); |
|||
var stream = m_FileSystem.File.Create(filePath); |
|||
m_DemoWriter = new DemonstrationWriter(stream); |
|||
|
|||
m_DemoWriter.Initialize( |
|||
demonstrationName, |
|||
behaviorParams.brainParameters, |
|||
behaviorParams.fullyQualifiedBehaviorName |
|||
); |
|||
|
|||
AddDemonstrationWriterToAgent(m_DemoWriter); |
|||
|
|||
return m_DemoWriter; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Removes all characters except alphanumerics from demonstration name.
|
|||
/// Shorten name if it is longer than the maxNameLength.
|
|||
/// </summary>
|
|||
internal static string SanitizeName(string demoName, int maxNameLength) |
|||
{ |
|||
var rgx = new Regex("[^a-zA-Z0-9 -]"); |
|||
demoName = rgx.Replace(demoName, ""); |
|||
// If the string is too long, it will overflow the metadata.
|
|||
if (demoName.Length > maxNameLength) |
|||
{ |
|||
demoName = demoName.Substring(0, maxNameLength); |
|||
} |
|||
return demoName; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Gets a unique path for the demonstrationName in the demonstrationDirectory.
|
|||
/// </summary>
|
|||
/// <param name="fileSystem"></param>
|
|||
/// <param name="demonstrationDirectory"></param>
|
|||
/// <param name="demonstrationName"></param>
|
|||
/// <returns></returns>
|
|||
internal static string MakeDemonstrationFilePath( |
|||
IFileSystem fileSystem, string demonstrationDirectory, string demonstrationName |
|||
) |
|||
{ |
|||
// Create the directory if it doesn't already exist
|
|||
if (!fileSystem.Directory.Exists(demonstrationDirectory)) |
|||
{ |
|||
fileSystem.Directory.CreateDirectory(demonstrationDirectory); |
|||
} |
|||
|
|||
var literalName = demonstrationName; |
|||
var filePath = Path.Combine(demonstrationDirectory, literalName + k_ExtensionType); |
|||
var uniqueNameCounter = 0; |
|||
while (fileSystem.File.Exists(filePath)) |
|||
{ |
|||
// TODO should we use a timestamp instead of a counter here? This loops an increasing number of times
|
|||
// as the number of demos increases.
|
|||
literalName = demonstrationName + "_" + uniqueNameCounter; |
|||
filePath = Path.Combine(demonstrationDirectory, literalName + k_ExtensionType); |
|||
uniqueNameCounter++; |
|||
} |
|||
|
|||
return filePath; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Close the DemonstrationWriter and remove it from the Agent.
|
|||
/// Has no effect if the DemonstrationWriter is already closed (or wasn't opened)
|
|||
/// </summary>
|
|||
public void Close() |
|||
{ |
|||
if (m_DemoWriter != null) |
|||
{ |
|||
RemoveDemonstrationWriterFromAgent(m_DemoWriter); |
|||
|
|||
m_DemoWriter.Close(); |
|||
m_DemoWriter = null; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Clean up the DemonstrationWriter when shutting down or destroying the Agent.
|
|||
/// </summary>
|
|||
void OnDestroy() |
|||
{ |
|||
Close(); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Add additional DemonstrationWriter to the Agent. It is still up to the user to Close this
|
|||
/// DemonstrationWriters when recording is done.
|
|||
/// </summary>
|
|||
/// <param name="demoWriter"></param>
|
|||
public void AddDemonstrationWriterToAgent(DemonstrationWriter demoWriter) |
|||
{ |
|||
m_Agent.DemonstrationWriters.Add(demoWriter); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Remove additional DemonstrationWriter to the Agent. It is still up to the user to Close this
|
|||
/// DemonstrationWriters when recording is done.
|
|||
/// </summary>
|
|||
/// <param name="demoWriter"></param>
|
|||
public void RemoveDemonstrationWriterFromAgent(DemonstrationWriter demoWriter) |
|||
{ |
|||
m_Agent.DemonstrationWriters.Remove(demoWriter); |
|||
} |
|||
} |
|||
} |
|
|||
import logging |
|||
import numpy as np |
|||
from typing import Any, Dict, Optional, List |
|||
|
|||
from mlagents.tf_utils import tf |
|||
|
|||
from mlagents_envs.timers import timed |
|||
from mlagents_envs.base_env import BatchedStepResult |
|||
from mlagents.trainers.brain import BrainParameters |
|||
from mlagents.trainers.models import EncoderType |
|||
from mlagents.trainers.models import ModelUtils |
|||
from mlagents.trainers.tf_policy import TFPolicy |
|||
|
|||
logger = logging.getLogger("mlagents.trainers") |
|||
|
|||
EPSILON = 1e-6 # Small value to avoid divide by zero |
|||
|
|||
|
|||
class NNPolicy(TFPolicy): |
|||
def __init__( |
|||
self, |
|||
seed: int, |
|||
brain: BrainParameters, |
|||
trainer_params: Dict[str, Any], |
|||
is_training: bool, |
|||
load: bool, |
|||
tanh_squash: bool = False, |
|||
reparameterize: bool = False, |
|||
condition_sigma_on_obs: bool = True, |
|||
create_tf_graph: bool = True, |
|||
): |
|||
""" |
|||
Policy that uses a multilayer perceptron to map the observations to actions. Could |
|||
also use a CNN to encode visual input prior to the MLP. Supports discrete and |
|||
continuous action spaces, as well as recurrent networks. |
|||
:param seed: Random seed. |
|||
:param brain: Assigned BrainParameters object. |
|||
:param trainer_params: Defined training parameters. |
|||
:param is_training: Whether the model should be trained. |
|||
:param load: Whether a pre-trained model will be loaded or a new one created. |
|||
:param tanh_squash: Whether to use a tanh function on the continuous output, or a clipped output. |
|||
:param reparameterize: Whether we are using the resampling trick to update the policy in continuous output. |
|||
""" |
|||
super().__init__(seed, brain, trainer_params, load) |
|||
self.grads = None |
|||
self.update_batch: Optional[tf.Operation] = None |
|||
num_layers = trainer_params["num_layers"] |
|||
self.h_size = trainer_params["hidden_units"] |
|||
if num_layers < 1: |
|||
num_layers = 1 |
|||
self.num_layers = num_layers |
|||
self.vis_encode_type = EncoderType( |
|||
trainer_params.get("vis_encode_type", "simple") |
|||
) |
|||
self.tanh_squash = tanh_squash |
|||
self.reparameterize = reparameterize |
|||
self.condition_sigma_on_obs = condition_sigma_on_obs |
|||
self.trainable_variables: List[tf.Variable] = [] |
|||
|
|||
# Non-exposed parameters; these aren't exposed because they don't have a |
|||
# good explanation and usually shouldn't be touched. |
|||
self.log_std_min = -20 |
|||
self.log_std_max = 2 |
|||
if create_tf_graph: |
|||
self.create_tf_graph() |
|||
|
|||
def get_trainable_variables(self) -> List[tf.Variable]: |
|||
""" |
|||
Returns a List of the trainable variables in this policy. if create_tf_graph hasn't been called, |
|||
returns empty list. |
|||
""" |
|||
return self.trainable_variables |
|||
|
|||
def create_tf_graph(self) -> None: |
|||
""" |
|||
Builds the tensorflow graph needed for this policy. |
|||
""" |
|||
with self.graph.as_default(): |
|||
tf.set_random_seed(self.seed) |
|||
_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) |
|||
if len(_vars) > 0: |
|||
# We assume the first thing created in the graph is the Policy. If |
|||
# already populated, don't create more tensors. |
|||
return |
|||
|
|||
self.create_input_placeholders() |
|||
encoded = self._create_encoder( |
|||
self.visual_in, |
|||
self.processed_vector_in, |
|||
self.h_size, |
|||
self.num_layers, |
|||
self.vis_encode_type, |
|||
) |
|||
if self.use_continuous_act: |
|||
self._create_cc_actor( |
|||
encoded, |
|||
self.tanh_squash, |
|||
self.reparameterize, |
|||
self.condition_sigma_on_obs, |
|||
) |
|||
else: |
|||
self._create_dc_actor(encoded) |
|||
self.trainable_variables = tf.get_collection( |
|||
tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy" |
|||
) |
|||
self.trainable_variables += tf.get_collection( |
|||
tf.GraphKeys.TRAINABLE_VARIABLES, scope="lstm" |
|||
) # LSTMs need to be root scope for Barracuda export |
|||
|
|||
self.inference_dict: Dict[str, tf.Tensor] = { |
|||
"action": self.output, |
|||
"log_probs": self.all_log_probs, |
|||
"entropy": self.entropy, |
|||
} |
|||
if self.use_continuous_act: |
|||
self.inference_dict["pre_action"] = self.output_pre |
|||
if self.use_recurrent: |
|||
self.inference_dict["memory_out"] = self.memory_out |
|||
|
|||
# We do an initialize to make the Policy usable out of the box. If an optimizer is needed, |
|||
# it will re-load the full graph |
|||
self._initialize_graph() |
|||
|
|||
@timed |
|||
def evaluate( |
|||
self, batched_step_result: BatchedStepResult, global_agent_ids: List[str] |
|||
) -> Dict[str, Any]: |
|||
""" |
|||
Evaluates policy for the agent experiences provided. |
|||
:param batched_step_result: BatchedStepResult object containing inputs. |
|||
:param global_agent_ids: The global (with worker ID) agent ids of the data in the batched_step_result. |
|||
:return: Outputs from network as defined by self.inference_dict. |
|||
""" |
|||
feed_dict = { |
|||
self.batch_size_ph: batched_step_result.n_agents(), |
|||
self.sequence_length_ph: 1, |
|||
} |
|||
if self.use_recurrent: |
|||
if not self.use_continuous_act: |
|||
feed_dict[self.prev_action] = self.retrieve_previous_action( |
|||
global_agent_ids |
|||
) |
|||
feed_dict[self.memory_in] = self.retrieve_memories(global_agent_ids) |
|||
feed_dict = self.fill_eval_dict(feed_dict, batched_step_result) |
|||
run_out = self._execute_model(feed_dict, self.inference_dict) |
|||
return run_out |
|||
|
|||
def _create_encoder( |
|||
self, |
|||
visual_in: List[tf.Tensor], |
|||
vector_in: tf.Tensor, |
|||
h_size: int, |
|||
num_layers: int, |
|||
vis_encode_type: EncoderType, |
|||
) -> tf.Tensor: |
|||
""" |
|||
Creates an encoder for visual and vector observations. |
|||
:param h_size: Size of hidden linear layers. |
|||
:param num_layers: Number of hidden linear layers. |
|||
:param vis_encode_type: Type of visual encoder to use if visual input. |
|||
:return: The hidden layer (tf.Tensor) after the encoder. |
|||
""" |
|||
with tf.variable_scope("policy"): |
|||
encoded = ModelUtils.create_observation_streams( |
|||
self.visual_in, |
|||
self.processed_vector_in, |
|||
1, |
|||
h_size, |
|||
num_layers, |
|||
vis_encode_type, |
|||
)[0] |
|||
return encoded |
|||
|
|||
def _create_cc_actor( |
|||
self, |
|||
encoded: tf.Tensor, |
|||
tanh_squash: bool = False, |
|||
reparameterize: bool = False, |
|||
condition_sigma_on_obs: bool = True, |
|||
) -> None: |
|||
""" |
|||
Creates Continuous control actor-critic model. |
|||
:param h_size: Size of hidden linear layers. |
|||
:param num_layers: Number of hidden linear layers. |
|||
:param vis_encode_type: Type of visual encoder to use if visual input. |
|||
:param tanh_squash: Whether to use a tanh function, or a clipped output. |
|||
:param reparameterize: Whether we are using the resampling trick to update the policy. |
|||
""" |
|||
if self.use_recurrent: |
|||
self.memory_in = tf.placeholder( |
|||
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in" |
|||
) |
|||
hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder( |
|||
encoded, self.memory_in, self.sequence_length_ph, name="lstm_policy" |
|||
) |
|||
|
|||
self.memory_out = tf.identity(memory_policy_out, name="recurrent_out") |
|||
else: |
|||
hidden_policy = encoded |
|||
|
|||
with tf.variable_scope("policy"): |
|||
mu = tf.layers.dense( |
|||
hidden_policy, |
|||
self.act_size[0], |
|||
activation=None, |
|||
name="mu", |
|||
kernel_initializer=ModelUtils.scaled_init(0.01), |
|||
reuse=tf.AUTO_REUSE, |
|||
) |
|||
|
|||
# Policy-dependent log_sigma |
|||
if condition_sigma_on_obs: |
|||
log_sigma = tf.layers.dense( |
|||
hidden_policy, |
|||
self.act_size[0], |
|||
activation=None, |
|||
name="log_sigma", |
|||
kernel_initializer=ModelUtils.scaled_init(0.01), |
|||
) |
|||
else: |
|||
log_sigma = tf.get_variable( |
|||
"log_sigma", |
|||
[self.act_size[0]], |
|||
dtype=tf.float32, |
|||
initializer=tf.zeros_initializer(), |
|||
) |
|||
log_sigma = tf.clip_by_value(log_sigma, self.log_std_min, self.log_std_max) |
|||
|
|||
sigma = tf.exp(log_sigma) |
|||
|
|||
epsilon = tf.random_normal(tf.shape(mu)) |
|||
|
|||
sampled_policy = mu + sigma * epsilon |
|||
|
|||
# Stop gradient if we're not doing the resampling trick |
|||
if not reparameterize: |
|||
sampled_policy_probs = tf.stop_gradient(sampled_policy) |
|||
else: |
|||
sampled_policy_probs = sampled_policy |
|||
|
|||
# Compute probability of model output. |
|||
_gauss_pre = -0.5 * ( |
|||
((sampled_policy_probs - mu) / (sigma + EPSILON)) ** 2 |
|||
+ 2 * log_sigma |
|||
+ np.log(2 * np.pi) |
|||
) |
|||
all_probs = _gauss_pre |
|||
all_probs = tf.reduce_sum(_gauss_pre, axis=1, keepdims=True) |
|||
|
|||
if tanh_squash: |
|||
self.output_pre = tf.tanh(sampled_policy) |
|||
|
|||
# Squash correction |
|||
all_probs -= tf.reduce_sum( |
|||
tf.log(1 - self.output_pre ** 2 + EPSILON), axis=1, keepdims=True |
|||
) |
|||
self.output = tf.identity(self.output_pre, name="action") |
|||
else: |
|||
self.output_pre = sampled_policy |
|||
# Clip and scale output to ensure actions are always within [-1, 1] range. |
|||
output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3 |
|||
self.output = tf.identity(output_post, name="action") |
|||
|
|||
self.selected_actions = tf.stop_gradient(self.output) |
|||
|
|||
self.all_log_probs = tf.identity(all_probs, name="action_probs") |
|||
|
|||
single_dim_entropy = 0.5 * tf.reduce_mean( |
|||
tf.log(2 * np.pi * np.e) + 2 * log_sigma |
|||
) |
|||
# Make entropy the right shape |
|||
self.entropy = tf.ones_like(tf.reshape(mu[:, 0], [-1])) * single_dim_entropy |
|||
|
|||
# We keep these tensors the same name, but use new nodes to keep code parallelism with discrete control. |
|||
self.log_probs = tf.reduce_sum( |
|||
(tf.identity(self.all_log_probs)), axis=1, keepdims=True |
|||
) |
|||
|
|||
self.action_holder = tf.placeholder( |
|||
shape=[None, self.act_size[0]], dtype=tf.float32, name="action_holder" |
|||
) |
|||
|
|||
def _create_dc_actor(self, encoded: tf.Tensor) -> None: |
|||
""" |
|||
Creates Discrete control actor-critic model. |
|||
:param h_size: Size of hidden linear layers. |
|||
:param num_layers: Number of hidden linear layers. |
|||
:param vis_encode_type: Type of visual encoder to use if visual input. |
|||
""" |
|||
if self.use_recurrent: |
|||
self.prev_action = tf.placeholder( |
|||
shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action" |
|||
) |
|||
prev_action_oh = tf.concat( |
|||
[ |
|||
tf.one_hot(self.prev_action[:, i], self.act_size[i]) |
|||
for i in range(len(self.act_size)) |
|||
], |
|||
axis=1, |
|||
) |
|||
hidden_policy = tf.concat([encoded, prev_action_oh], axis=1) |
|||
|
|||
self.memory_in = tf.placeholder( |
|||
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in" |
|||
) |
|||
hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder( |
|||
hidden_policy, |
|||
self.memory_in, |
|||
self.sequence_length_ph, |
|||
name="lstm_policy", |
|||
) |
|||
|
|||
self.memory_out = tf.identity(memory_policy_out, "recurrent_out") |
|||
else: |
|||
hidden_policy = encoded |
|||
|
|||
policy_branches = [] |
|||
with tf.variable_scope("policy"): |
|||
for size in self.act_size: |
|||
policy_branches.append( |
|||
tf.layers.dense( |
|||
hidden_policy, |
|||
size, |
|||
activation=None, |
|||
use_bias=False, |
|||
kernel_initializer=ModelUtils.scaled_init(0.01), |
|||
) |
|||
) |
|||
|
|||
raw_log_probs = tf.concat(policy_branches, axis=1, name="action_probs") |
|||
|
|||
self.action_masks = tf.placeholder( |
|||
shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks" |
|||
) |
|||
output, self.action_probs, normalized_logits = ModelUtils.create_discrete_action_masking_layer( |
|||
raw_log_probs, self.action_masks, self.act_size |
|||
) |
|||
|
|||
self.output = tf.identity(output) |
|||
self.all_log_probs = tf.identity(normalized_logits, name="action") |
|||
|
|||
self.action_holder = tf.placeholder( |
|||
shape=[None, len(policy_branches)], dtype=tf.int32, name="action_holder" |
|||
) |
|||
self.action_oh = tf.concat( |
|||
[ |
|||
tf.one_hot(self.action_holder[:, i], self.act_size[i]) |
|||
for i in range(len(self.act_size)) |
|||
], |
|||
axis=1, |
|||
) |
|||
self.selected_actions = tf.stop_gradient(self.action_oh) |
|||
|
|||
action_idx = [0] + list(np.cumsum(self.act_size)) |
|||
|
|||
self.entropy = tf.reduce_sum( |
|||
( |
|||
tf.stack( |
|||
[ |
|||
tf.nn.softmax_cross_entropy_with_logits_v2( |
|||
labels=tf.nn.softmax( |
|||
self.all_log_probs[:, action_idx[i] : action_idx[i + 1]] |
|||
), |
|||
logits=self.all_log_probs[ |
|||
:, action_idx[i] : action_idx[i + 1] |
|||
], |
|||
) |
|||
for i in range(len(self.act_size)) |
|||
], |
|||
axis=1, |
|||
) |
|||
), |
|||
axis=1, |
|||
) |
|||
|
|||
self.log_probs = tf.reduce_sum( |
|||
( |
|||
tf.stack( |
|||
[ |
|||
-tf.nn.softmax_cross_entropy_with_logits_v2( |
|||
labels=self.action_oh[:, action_idx[i] : action_idx[i + 1]], |
|||
logits=normalized_logits[ |
|||
:, action_idx[i] : action_idx[i + 1] |
|||
], |
|||
) |
|||
for i in range(len(self.act_size)) |
|||
], |
|||
axis=1, |
|||
) |
|||
), |
|||
axis=1, |
|||
keepdims=True, |
|||
) |
|
|||
import abc |
|||
from typing import Dict |
|||
|
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
|
|||
|
|||
class Optimizer(abc.ABC): |
|||
""" |
|||
Creates loss functions and auxillary networks (e.g. Q or Value) needed for training. |
|||
Provides methods to update the Policy. |
|||
""" |
|||
|
|||
@abc.abstractmethod |
|||
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: |
|||
""" |
|||
Update the Policy based on the batch that was passed in. |
|||
:param batch: AgentBuffer that contains the minibatch of data used for this update. |
|||
:param num_sequences: Number of recurrent sequences found in the minibatch. |
|||
:return: A Dict containing statistics (name, value) from the update (e.g. loss) |
|||
""" |
|||
pass |
|
|||
from typing import Dict, Any, List, Tuple, Optional |
|||
import numpy as np |
|||
|
|||
from mlagents.tf_utils.tf import tf |
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents.trainers.tf_policy import TFPolicy |
|||
from mlagents.trainers.common.optimizer import Optimizer |
|||
from mlagents.trainers.trajectory import SplitObservations |
|||
from mlagents.trainers.components.reward_signals.reward_signal_factory import ( |
|||
create_reward_signal, |
|||
) |
|||
from mlagents.trainers.components.bc.module import BCModule |
|||
|
|||
|
|||
class TFOptimizer(Optimizer): # pylint: disable=W0223 |
|||
def __init__(self, policy: TFPolicy, trainer_params: Dict[str, Any]): |
|||
self.sess = policy.sess |
|||
self.policy = policy |
|||
self.update_dict: Dict[str, tf.Tensor] = {} |
|||
self.value_heads: Dict[str, tf.Tensor] = {} |
|||
self.create_reward_signals(trainer_params["reward_signals"]) |
|||
self.memory_in: tf.Tensor = None |
|||
self.memory_out: tf.Tensor = None |
|||
self.m_size: int = 0 |
|||
self.bc_module: Optional[BCModule] = None |
|||
# Create pretrainer if needed |
|||
if "behavioral_cloning" in trainer_params: |
|||
BCModule.check_config(trainer_params["behavioral_cloning"]) |
|||
self.bc_module = BCModule( |
|||
self.policy, |
|||
policy_learning_rate=trainer_params["learning_rate"], |
|||
default_batch_size=trainer_params["batch_size"], |
|||
default_num_epoch=3, |
|||
**trainer_params["behavioral_cloning"], |
|||
) |
|||
|
|||
def get_trajectory_value_estimates( |
|||
self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool |
|||
) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]: |
|||
feed_dict: Dict[tf.Tensor, Any] = { |
|||
self.policy.batch_size_ph: batch.num_experiences, |
|||
self.policy.sequence_length_ph: batch.num_experiences, # We want to feed data in batch-wise, not time-wise. |
|||
} |
|||
|
|||
if self.policy.vec_obs_size > 0: |
|||
feed_dict[self.policy.vector_in] = batch["vector_obs"] |
|||
if self.policy.vis_obs_size > 0: |
|||
for i in range(len(self.policy.visual_in)): |
|||
_obs = batch["visual_obs%d" % i] |
|||
feed_dict[self.policy.visual_in[i]] = _obs |
|||
if self.policy.use_recurrent: |
|||
feed_dict[self.policy.memory_in] = [np.zeros((self.policy.m_size))] |
|||
feed_dict[self.memory_in] = [np.zeros((self.m_size))] |
|||
if self.policy.prev_action is not None: |
|||
feed_dict[self.policy.prev_action] = batch["prev_action"] |
|||
|
|||
if self.policy.use_recurrent: |
|||
value_estimates, policy_mem, value_mem = self.sess.run( |
|||
[self.value_heads, self.policy.memory_out, self.memory_out], feed_dict |
|||
) |
|||
prev_action = batch["actions"][-1] |
|||
else: |
|||
value_estimates = self.sess.run(self.value_heads, feed_dict) |
|||
prev_action = None |
|||
policy_mem = None |
|||
value_mem = None |
|||
value_estimates = {k: np.squeeze(v, axis=1) for k, v in value_estimates.items()} |
|||
|
|||
# We do this in a separate step to feed the memory outs - a further optimization would |
|||
# be to append to the obs before running sess.run. |
|||
final_value_estimates = self._get_value_estimates( |
|||
next_obs, done, policy_mem, value_mem, prev_action |
|||
) |
|||
|
|||
return value_estimates, final_value_estimates |
|||
|
|||
def _get_value_estimates( |
|||
self, |
|||
next_obs: List[np.ndarray], |
|||
done: bool, |
|||
policy_memory: np.ndarray = None, |
|||
value_memory: np.ndarray = None, |
|||
prev_action: np.ndarray = None, |
|||
) -> Dict[str, float]: |
|||
""" |
|||
Generates value estimates for bootstrapping. |
|||
:param experience: AgentExperience to be used for bootstrapping. |
|||
:param done: Whether or not this is the last element of the episode, in which case the value estimate will be 0. |
|||
:return: The value estimate dictionary with key being the name of the reward signal and the value the |
|||
corresponding value estimate. |
|||
""" |
|||
|
|||
feed_dict: Dict[tf.Tensor, Any] = { |
|||
self.policy.batch_size_ph: 1, |
|||
self.policy.sequence_length_ph: 1, |
|||
} |
|||
vec_vis_obs = SplitObservations.from_observations(next_obs) |
|||
for i in range(len(vec_vis_obs.visual_observations)): |
|||
feed_dict[self.policy.visual_in[i]] = [vec_vis_obs.visual_observations[i]] |
|||
|
|||
if self.policy.vec_obs_size > 0: |
|||
feed_dict[self.policy.vector_in] = [vec_vis_obs.vector_observations] |
|||
if policy_memory is not None: |
|||
feed_dict[self.policy.memory_in] = policy_memory |
|||
if value_memory is not None: |
|||
feed_dict[self.memory_in] = value_memory |
|||
if prev_action is not None: |
|||
feed_dict[self.policy.prev_action] = [prev_action] |
|||
value_estimates = self.sess.run(self.value_heads, feed_dict) |
|||
|
|||
value_estimates = {k: float(v) for k, v in value_estimates.items()} |
|||
|
|||
# If we're done, reassign all of the value estimates that need terminal states. |
|||
if done: |
|||
for k in value_estimates: |
|||
if self.reward_signals[k].use_terminal_states: |
|||
value_estimates[k] = 0.0 |
|||
|
|||
return value_estimates |
|||
|
|||
def create_reward_signals(self, reward_signal_configs: Dict[str, Any]) -> None: |
|||
""" |
|||
Create reward signals |
|||
:param reward_signal_configs: Reward signal config. |
|||
""" |
|||
self.reward_signals = {} |
|||
# Create reward signals |
|||
for reward_signal, config in reward_signal_configs.items(): |
|||
self.reward_signals[reward_signal] = create_reward_signal( |
|||
self.policy, reward_signal, config |
|||
) |
|||
self.update_dict.update(self.reward_signals[reward_signal].update_dict) |
|||
|
|||
def create_optimizer_op( |
|||
self, learning_rate: tf.Tensor, name: str = "Adam" |
|||
) -> tf.train.Optimizer: |
|||
return tf.train.AdamOptimizer(learning_rate=learning_rate, name=name) |
|||
|
|||
def _execute_model( |
|||
self, feed_dict: Dict[tf.Tensor, np.ndarray], out_dict: Dict[str, tf.Tensor] |
|||
) -> Dict[str, np.ndarray]: |
|||
""" |
|||
Executes model. |
|||
:param feed_dict: Input dictionary mapping nodes to input data. |
|||
:param out_dict: Output dictionary mapping names to nodes. |
|||
:return: Dictionary mapping names to input data. |
|||
""" |
|||
network_out = self.sess.run(list(out_dict.values()), feed_dict=feed_dict) |
|||
run_out = dict(zip(list(out_dict.keys()), network_out)) |
|||
return run_out |
|||
|
|||
def _make_zero_mem(self, m_size: int, length: int) -> List[np.ndarray]: |
|||
return [ |
|||
np.zeros((m_size), dtype=np.float32) |
|||
for i in range(0, length, self.policy.sequence_length) |
|||
] |
|
|||
using System.IO.Abstractions; |
|||
using System.Text.RegularExpressions; |
|||
using UnityEngine; |
|||
using System.IO; |
|||
|
|||
namespace MLAgents |
|||
{ |
|||
/// <summary>
|
|||
/// Demonstration Recorder Component.
|
|||
/// </summary>
|
|||
[RequireComponent(typeof(Agent))] |
|||
[AddComponentMenu("ML Agents/Demonstration Recorder", (int)MenuGroup.Default)] |
|||
public class DemonstrationRecorder : MonoBehaviour |
|||
{ |
|||
[Tooltip("Whether or not to record demonstrations.")] |
|||
public bool record; |
|||
|
|||
[Tooltip("Base demonstration file name. Will have numbers appended to make unique.")] |
|||
public string demonstrationName; |
|||
|
|||
[Tooltip("Base directory to write the demo files. If null, will use {Application.dataPath}/Demonstrations.")] |
|||
public string demonstrationDirectory; |
|||
|
|||
DemonstrationStore m_DemoStore; |
|||
internal const int MaxNameLength = 16; |
|||
|
|||
const string k_ExtensionType = ".demo"; |
|||
IFileSystem m_FileSystem; |
|||
|
|||
Agent m_Agent; |
|||
|
|||
void OnEnable() |
|||
{ |
|||
m_Agent = GetComponent<Agent>(); |
|||
} |
|||
|
|||
void Update() |
|||
{ |
|||
if (record) |
|||
{ |
|||
LazyInitialize(); |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Creates demonstration store for use in recording.
|
|||
/// Has no effect if the demonstration store was already created.
|
|||
/// </summary>
|
|||
internal DemonstrationStore LazyInitialize(IFileSystem fileSystem = null) |
|||
{ |
|||
if (m_DemoStore != null) |
|||
{ |
|||
return m_DemoStore; |
|||
} |
|||
|
|||
if (m_Agent == null) |
|||
{ |
|||
m_Agent = GetComponent<Agent>(); |
|||
} |
|||
|
|||
m_FileSystem = fileSystem ?? new FileSystem(); |
|||
var behaviorParams = GetComponent<BehaviorParameters>(); |
|||
if (string.IsNullOrEmpty(demonstrationName)) |
|||
{ |
|||
demonstrationName = behaviorParams.behaviorName; |
|||
} |
|||
if (string.IsNullOrEmpty(demonstrationDirectory)) |
|||
{ |
|||
demonstrationDirectory = Path.Combine(Application.dataPath, "Demonstrations"); |
|||
} |
|||
|
|||
demonstrationName = SanitizeName(demonstrationName, MaxNameLength); |
|||
var filePath = MakeDemonstrationFilePath(m_FileSystem, demonstrationDirectory, demonstrationName); |
|||
var stream = m_FileSystem.File.Create(filePath); |
|||
m_DemoStore = new DemonstrationStore(stream); |
|||
|
|||
m_DemoStore.Initialize( |
|||
demonstrationName, |
|||
behaviorParams.brainParameters, |
|||
behaviorParams.fullyQualifiedBehaviorName |
|||
); |
|||
|
|||
AddDemonstrationStoreToAgent(m_DemoStore); |
|||
|
|||
return m_DemoStore; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Removes all characters except alphanumerics from demonstration name.
|
|||
/// Shorten name if it is longer than the maxNameLength.
|
|||
/// </summary>
|
|||
internal static string SanitizeName(string demoName, int maxNameLength) |
|||
{ |
|||
var rgx = new Regex("[^a-zA-Z0-9 -]"); |
|||
demoName = rgx.Replace(demoName, ""); |
|||
// If the string is too long, it will overflow the metadata.
|
|||
if (demoName.Length > maxNameLength) |
|||
{ |
|||
demoName = demoName.Substring(0, maxNameLength); |
|||
} |
|||
return demoName; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Gets a unique path for the demonstrationName in the demonstrationDirectory.
|
|||
/// </summary>
|
|||
/// <param name="fileSystem"></param>
|
|||
/// <param name="demonstrationDirectory"></param>
|
|||
/// <param name="demonstrationName"></param>
|
|||
/// <returns></returns>
|
|||
internal static string MakeDemonstrationFilePath( |
|||
IFileSystem fileSystem, string demonstrationDirectory, string demonstrationName |
|||
) |
|||
{ |
|||
// Create the directory if it doesn't already exist
|
|||
if (!fileSystem.Directory.Exists(demonstrationDirectory)) |
|||
{ |
|||
fileSystem.Directory.CreateDirectory(demonstrationDirectory); |
|||
} |
|||
|
|||
var literalName = demonstrationName; |
|||
var filePath = Path.Combine(demonstrationDirectory, literalName + k_ExtensionType); |
|||
var uniqueNameCounter = 0; |
|||
while (fileSystem.File.Exists(filePath)) |
|||
{ |
|||
// TODO should we use a timestamp instead of a counter here? This loops an increasing number of times
|
|||
// as the number of demos increases.
|
|||
literalName = demonstrationName + "_" + uniqueNameCounter; |
|||
filePath = Path.Combine(demonstrationDirectory, literalName + k_ExtensionType); |
|||
uniqueNameCounter++; |
|||
} |
|||
|
|||
return filePath; |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Close the DemonstrationStore and remove it from the Agent.
|
|||
/// Has no effect if the DemonstrationStore is already closed (or wasn't opened)
|
|||
/// </summary>
|
|||
public void Close() |
|||
{ |
|||
if (m_DemoStore != null) |
|||
{ |
|||
RemoveDemonstrationStoreFromAgent(m_DemoStore); |
|||
|
|||
m_DemoStore.Close(); |
|||
m_DemoStore = null; |
|||
} |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Clean up the DemonstrationStore when shutting down or destroying the Agent.
|
|||
/// </summary>
|
|||
void OnDestroy() |
|||
{ |
|||
Close(); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Add additional DemonstrationStore to the Agent. It is still up to the user to Close this
|
|||
/// DemonstrationStores when recording is done.
|
|||
/// </summary>
|
|||
/// <param name="demoStore"></param>
|
|||
public void AddDemonstrationStoreToAgent(DemonstrationStore demoStore) |
|||
{ |
|||
m_Agent.DemonstrationStores.Add(demoStore); |
|||
} |
|||
|
|||
/// <summary>
|
|||
/// Remove additional DemonstrationStore to the Agent. It is still up to the user to Close this
|
|||
/// DemonstrationStores when recording is done.
|
|||
/// </summary>
|
|||
/// <param name="demoStore"></param>
|
|||
public void RemoveDemonstrationStoreFromAgent(DemonstrationStore demoStore) |
|||
{ |
|||
m_Agent.DemonstrationStores.Remove(demoStore); |
|||
} |
|||
} |
|||
} |
|
|||
import logging |
|||
from typing import Optional |
|||
|
|||
import numpy as np |
|||
from mlagents.tf_utils import tf |
|||
from mlagents.trainers.models import LearningModel, EncoderType, LearningRateSchedule |
|||
|
|||
logger = logging.getLogger("mlagents.trainers") |
|||
|
|||
|
|||
class PPOModel(LearningModel): |
|||
def __init__( |
|||
self, |
|||
brain, |
|||
lr=1e-4, |
|||
lr_schedule=LearningRateSchedule.LINEAR, |
|||
h_size=128, |
|||
epsilon=0.2, |
|||
beta=1e-3, |
|||
max_step=5e6, |
|||
normalize=False, |
|||
use_recurrent=False, |
|||
num_layers=2, |
|||
m_size=None, |
|||
seed=0, |
|||
stream_names=None, |
|||
vis_encode_type=EncoderType.SIMPLE, |
|||
): |
|||
""" |
|||
Takes a Unity environment and model-specific hyper-parameters and returns the |
|||
appropriate PPO agent model for the environment. |
|||
:param brain: brain parameters used to generate specific network graph. |
|||
:param lr: Learning rate. |
|||
:param lr_schedule: Learning rate decay schedule. |
|||
:param h_size: Size of hidden layers |
|||
:param epsilon: Value for policy-divergence threshold. |
|||
:param beta: Strength of entropy regularization. |
|||
:param max_step: Total number of training steps. |
|||
:param normalize: Whether to normalize vector observation input. |
|||
:param use_recurrent: Whether to use an LSTM layer in the network. |
|||
:param num_layers Number of hidden layers between encoded input and policy & value layers |
|||
:param m_size: Size of brain memory. |
|||
:param seed: Seed to use for initialization of model. |
|||
:param stream_names: List of names of value streams. Usually, a list of the Reward Signals being used. |
|||
:return: a sub-class of PPOAgent tailored to the environment. |
|||
""" |
|||
LearningModel.__init__( |
|||
self, m_size, normalize, use_recurrent, brain, seed, stream_names |
|||
) |
|||
|
|||
self.optimizer: Optional[tf.train.AdamOptimizer] = None |
|||
self.grads = None |
|||
self.update_batch: Optional[tf.Operation] = None |
|||
|
|||
if num_layers < 1: |
|||
num_layers = 1 |
|||
if brain.vector_action_space_type == "continuous": |
|||
self.create_cc_actor_critic(h_size, num_layers, vis_encode_type) |
|||
self.entropy = tf.ones_like(tf.reshape(self.value, [-1])) * self.entropy |
|||
else: |
|||
self.create_dc_actor_critic(h_size, num_layers, vis_encode_type) |
|||
self.learning_rate = self.create_learning_rate( |
|||
lr_schedule, lr, self.global_step, max_step |
|||
) |
|||
self.create_losses( |
|||
self.log_probs, |
|||
self.old_log_probs, |
|||
self.value_heads, |
|||
self.entropy, |
|||
beta, |
|||
epsilon, |
|||
lr, |
|||
max_step, |
|||
) |
|||
|
|||
def create_cc_actor_critic( |
|||
self, h_size: int, num_layers: int, vis_encode_type: EncoderType |
|||
) -> None: |
|||
""" |
|||
Creates Continuous control actor-critic model. |
|||
:param h_size: Size of hidden linear layers. |
|||
:param num_layers: Number of hidden linear layers. |
|||
""" |
|||
hidden_streams = self.create_observation_streams( |
|||
2, h_size, num_layers, vis_encode_type |
|||
) |
|||
|
|||
if self.use_recurrent: |
|||
self.memory_in = tf.placeholder( |
|||
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in" |
|||
) |
|||
_half_point = int(self.m_size / 2) |
|||
hidden_policy, memory_policy_out = self.create_recurrent_encoder( |
|||
hidden_streams[0], |
|||
self.memory_in[:, :_half_point], |
|||
self.sequence_length, |
|||
name="lstm_policy", |
|||
) |
|||
|
|||
hidden_value, memory_value_out = self.create_recurrent_encoder( |
|||
hidden_streams[1], |
|||
self.memory_in[:, _half_point:], |
|||
self.sequence_length, |
|||
name="lstm_value", |
|||
) |
|||
self.memory_out = tf.concat( |
|||
[memory_policy_out, memory_value_out], axis=1, name="recurrent_out" |
|||
) |
|||
else: |
|||
hidden_policy = hidden_streams[0] |
|||
hidden_value = hidden_streams[1] |
|||
|
|||
mu = tf.layers.dense( |
|||
hidden_policy, |
|||
self.act_size[0], |
|||
activation=None, |
|||
kernel_initializer=LearningModel.scaled_init(0.01), |
|||
reuse=tf.AUTO_REUSE, |
|||
) |
|||
|
|||
self.log_sigma_sq = tf.get_variable( |
|||
"log_sigma_squared", |
|||
[self.act_size[0]], |
|||
dtype=tf.float32, |
|||
initializer=tf.zeros_initializer(), |
|||
) |
|||
|
|||
sigma_sq = tf.exp(self.log_sigma_sq) |
|||
|
|||
self.epsilon = tf.placeholder( |
|||
shape=[None, self.act_size[0]], dtype=tf.float32, name="epsilon" |
|||
) |
|||
# Clip and scale output to ensure actions are always within [-1, 1] range. |
|||
self.output_pre = mu + tf.sqrt(sigma_sq) * self.epsilon |
|||
output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3 |
|||
self.output = tf.identity(output_post, name="action") |
|||
self.selected_actions = tf.stop_gradient(output_post) |
|||
|
|||
# Compute probability of model output. |
|||
all_probs = ( |
|||
-0.5 * tf.square(tf.stop_gradient(self.output_pre) - mu) / sigma_sq |
|||
- 0.5 * tf.log(2.0 * np.pi) |
|||
- 0.5 * self.log_sigma_sq |
|||
) |
|||
|
|||
self.all_log_probs = tf.identity(all_probs, name="action_probs") |
|||
|
|||
self.entropy = 0.5 * tf.reduce_mean( |
|||
tf.log(2 * np.pi * np.e) + self.log_sigma_sq |
|||
) |
|||
|
|||
self.create_value_heads(self.stream_names, hidden_value) |
|||
|
|||
self.all_old_log_probs = tf.placeholder( |
|||
shape=[None, self.act_size[0]], dtype=tf.float32, name="old_probabilities" |
|||
) |
|||
|
|||
# We keep these tensors the same name, but use new nodes to keep code parallelism with discrete control. |
|||
self.log_probs = tf.reduce_sum( |
|||
(tf.identity(self.all_log_probs)), axis=1, keepdims=True |
|||
) |
|||
self.old_log_probs = tf.reduce_sum( |
|||
(tf.identity(self.all_old_log_probs)), axis=1, keepdims=True |
|||
) |
|||
|
|||
def create_dc_actor_critic( |
|||
self, h_size: int, num_layers: int, vis_encode_type: EncoderType |
|||
) -> None: |
|||
""" |
|||
Creates Discrete control actor-critic model. |
|||
:param h_size: Size of hidden linear layers. |
|||
:param num_layers: Number of hidden linear layers. |
|||
""" |
|||
hidden_streams = self.create_observation_streams( |
|||
1, h_size, num_layers, vis_encode_type |
|||
) |
|||
hidden = hidden_streams[0] |
|||
|
|||
if self.use_recurrent: |
|||
self.prev_action = tf.placeholder( |
|||
shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action" |
|||
) |
|||
prev_action_oh = tf.concat( |
|||
[ |
|||
tf.one_hot(self.prev_action[:, i], self.act_size[i]) |
|||
for i in range(len(self.act_size)) |
|||
], |
|||
axis=1, |
|||
) |
|||
hidden = tf.concat([hidden, prev_action_oh], axis=1) |
|||
|
|||
self.memory_in = tf.placeholder( |
|||
shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in" |
|||
) |
|||
hidden, memory_out = self.create_recurrent_encoder( |
|||
hidden, self.memory_in, self.sequence_length |
|||
) |
|||
self.memory_out = tf.identity(memory_out, name="recurrent_out") |
|||
|
|||
policy_branches = [] |
|||
for size in self.act_size: |
|||
policy_branches.append( |
|||
tf.layers.dense( |
|||
hidden, |
|||
size, |
|||
activation=None, |
|||
use_bias=False, |
|||
kernel_initializer=LearningModel.scaled_init(0.01), |
|||
) |
|||
) |
|||
|
|||
self.all_log_probs = tf.concat(policy_branches, axis=1, name="action_probs") |
|||
|
|||
self.action_masks = tf.placeholder( |
|||
shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks" |
|||
) |
|||
output, _, normalized_logits = self.create_discrete_action_masking_layer( |
|||
self.all_log_probs, self.action_masks, self.act_size |
|||
) |
|||
|
|||
self.output = tf.identity(output) |
|||
self.normalized_logits = tf.identity(normalized_logits, name="action") |
|||
|
|||
self.create_value_heads(self.stream_names, hidden) |
|||
|
|||
self.action_holder = tf.placeholder( |
|||
shape=[None, len(policy_branches)], dtype=tf.int32, name="action_holder" |
|||
) |
|||
self.action_oh = tf.concat( |
|||
[ |
|||
tf.one_hot(self.action_holder[:, i], self.act_size[i]) |
|||
for i in range(len(self.act_size)) |
|||
], |
|||
axis=1, |
|||
) |
|||
self.selected_actions = tf.stop_gradient(self.action_oh) |
|||
|
|||
self.all_old_log_probs = tf.placeholder( |
|||
shape=[None, sum(self.act_size)], dtype=tf.float32, name="old_probabilities" |
|||
) |
|||
_, _, old_normalized_logits = self.create_discrete_action_masking_layer( |
|||
self.all_old_log_probs, self.action_masks, self.act_size |
|||
) |
|||
|
|||
action_idx = [0] + list(np.cumsum(self.act_size)) |
|||
|
|||
self.entropy = tf.reduce_sum( |
|||
( |
|||
tf.stack( |
|||
[ |
|||
tf.nn.softmax_cross_entropy_with_logits_v2( |
|||
labels=tf.nn.softmax( |
|||
self.all_log_probs[:, action_idx[i] : action_idx[i + 1]] |
|||
), |
|||
logits=self.all_log_probs[ |
|||
:, action_idx[i] : action_idx[i + 1] |
|||
], |
|||
) |
|||
for i in range(len(self.act_size)) |
|||
], |
|||
axis=1, |
|||
) |
|||
), |
|||
axis=1, |
|||
) |
|||
|
|||
self.log_probs = tf.reduce_sum( |
|||
( |
|||
tf.stack( |
|||
[ |
|||
-tf.nn.softmax_cross_entropy_with_logits_v2( |
|||
labels=self.action_oh[:, action_idx[i] : action_idx[i + 1]], |
|||
logits=normalized_logits[ |
|||
:, action_idx[i] : action_idx[i + 1] |
|||
], |
|||
) |
|||
for i in range(len(self.act_size)) |
|||
], |
|||
axis=1, |
|||
) |
|||
), |
|||
axis=1, |
|||
keepdims=True, |
|||
) |
|||
self.old_log_probs = tf.reduce_sum( |
|||
( |
|||
tf.stack( |
|||
[ |
|||
-tf.nn.softmax_cross_entropy_with_logits_v2( |
|||
labels=self.action_oh[:, action_idx[i] : action_idx[i + 1]], |
|||
logits=old_normalized_logits[ |
|||
:, action_idx[i] : action_idx[i + 1] |
|||
], |
|||
) |
|||
for i in range(len(self.act_size)) |
|||
], |
|||
axis=1, |
|||
) |
|||
), |
|||
axis=1, |
|||
keepdims=True, |
|||
) |
|||
|
|||
def create_losses( |
|||
self, probs, old_probs, value_heads, entropy, beta, epsilon, lr, max_step |
|||
): |
|||
""" |
|||
Creates training-specific Tensorflow ops for PPO models. |
|||
:param probs: Current policy probabilities |
|||
:param old_probs: Past policy probabilities |
|||
:param value_heads: Value estimate tensors from each value stream |
|||
:param beta: Entropy regularization strength |
|||
:param entropy: Current policy entropy |
|||
:param epsilon: Value for policy-divergence threshold |
|||
:param lr: Learning rate |
|||
:param max_step: Total number of training steps. |
|||
""" |
|||
self.returns_holders = {} |
|||
self.old_values = {} |
|||
for name in value_heads.keys(): |
|||
returns_holder = tf.placeholder( |
|||
shape=[None], dtype=tf.float32, name="{}_returns".format(name) |
|||
) |
|||
old_value = tf.placeholder( |
|||
shape=[None], dtype=tf.float32, name="{}_value_estimate".format(name) |
|||
) |
|||
self.returns_holders[name] = returns_holder |
|||
self.old_values[name] = old_value |
|||
self.advantage = tf.placeholder( |
|||
shape=[None], dtype=tf.float32, name="advantages" |
|||
) |
|||
advantage = tf.expand_dims(self.advantage, -1) |
|||
|
|||
decay_epsilon = tf.train.polynomial_decay( |
|||
epsilon, self.global_step, max_step, 0.1, power=1.0 |
|||
) |
|||
decay_beta = tf.train.polynomial_decay( |
|||
beta, self.global_step, max_step, 1e-5, power=1.0 |
|||
) |
|||
|
|||
value_losses = [] |
|||
for name, head in value_heads.items(): |
|||
clipped_value_estimate = self.old_values[name] + tf.clip_by_value( |
|||
tf.reduce_sum(head, axis=1) - self.old_values[name], |
|||
-decay_epsilon, |
|||
decay_epsilon, |
|||
) |
|||
v_opt_a = tf.squared_difference( |
|||
self.returns_holders[name], tf.reduce_sum(head, axis=1) |
|||
) |
|||
v_opt_b = tf.squared_difference( |
|||
self.returns_holders[name], clipped_value_estimate |
|||
) |
|||
value_loss = tf.reduce_mean( |
|||
tf.dynamic_partition(tf.maximum(v_opt_a, v_opt_b), self.mask, 2)[1] |
|||
) |
|||
value_losses.append(value_loss) |
|||
self.value_loss = tf.reduce_mean(value_losses) |
|||
|
|||
r_theta = tf.exp(probs - old_probs) |
|||
p_opt_a = r_theta * advantage |
|||
p_opt_b = ( |
|||
tf.clip_by_value(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) |
|||
* advantage |
|||
) |
|||
self.policy_loss = -tf.reduce_mean( |
|||
tf.dynamic_partition(tf.minimum(p_opt_a, p_opt_b), self.mask, 2)[1] |
|||
) |
|||
# For cleaner stats reporting |
|||
self.abs_policy_loss = tf.abs(self.policy_loss) |
|||
|
|||
self.loss = ( |
|||
self.policy_loss |
|||
+ 0.5 * self.value_loss |
|||
- decay_beta |
|||
* tf.reduce_mean(tf.dynamic_partition(entropy, self.mask, 2)[1]) |
|||
) |
|||
|
|||
def create_ppo_optimizer(self): |
|||
self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate) |
|||
self.grads = self.optimizer.compute_gradients(self.loss) |
|||
self.update_batch = self.optimizer.minimize(self.loss) |
|
|||
import logging |
|||
from typing import Any, Dict, List, Optional |
|||
|
|||
from mlagents.tf_utils import tf |
|||
|
|||
from tensorflow.python.client import device_lib |
|||
from mlagents.trainers.brain import BrainParameters |
|||
from mlagents_envs.timers import timed |
|||
from mlagents.trainers.models import EncoderType, LearningRateSchedule |
|||
from mlagents.trainers.ppo.policy import PPOPolicy |
|||
from mlagents.trainers.ppo.models import PPOModel |
|||
from mlagents.trainers.components.reward_signals import RewardSignal |
|||
from mlagents.trainers.components.reward_signals.reward_signal_factory import ( |
|||
create_reward_signal, |
|||
) |
|||
|
|||
# Variable scope in which created variables will be placed under |
|||
TOWER_SCOPE_NAME = "tower" |
|||
|
|||
logger = logging.getLogger("mlagents.trainers") |
|||
|
|||
|
|||
class MultiGpuPPOPolicy(PPOPolicy): |
|||
def __init__( |
|||
self, |
|||
seed: int, |
|||
brain: BrainParameters, |
|||
trainer_params: Dict[str, Any], |
|||
is_training: bool, |
|||
load: bool, |
|||
): |
|||
self.towers: List[PPOModel] = [] |
|||
self.devices: List[str] = [] |
|||
self.model: Optional[PPOModel] = None |
|||
self.total_policy_loss: Optional[tf.Tensor] = None |
|||
self.reward_signal_towers: List[Dict[str, RewardSignal]] = [] |
|||
self.reward_signals: Dict[str, RewardSignal] = {} |
|||
|
|||
super().__init__(seed, brain, trainer_params, is_training, load) |
|||
|
|||
def create_model( |
|||
self, brain, trainer_params, reward_signal_configs, is_training, load, seed |
|||
): |
|||
""" |
|||
Create PPO models, one on each device |
|||
:param brain: Assigned Brain object. |
|||
:param trainer_params: Defined training parameters. |
|||
:param reward_signal_configs: Reward signal config |
|||
:param seed: Random seed. |
|||
""" |
|||
self.devices = get_devices() |
|||
|
|||
with self.graph.as_default(): |
|||
with tf.variable_scope("", reuse=tf.AUTO_REUSE): |
|||
for device in self.devices: |
|||
with tf.device(device): |
|||
self.towers.append( |
|||
PPOModel( |
|||
brain=brain, |
|||
lr=float(trainer_params["learning_rate"]), |
|||
lr_schedule=LearningRateSchedule( |
|||
trainer_params.get( |
|||
"learning_rate_schedule", "linear" |
|||
) |
|||
), |
|||
h_size=int(trainer_params["hidden_units"]), |
|||
epsilon=float(trainer_params["epsilon"]), |
|||
beta=float(trainer_params["beta"]), |
|||
max_step=float(trainer_params["max_steps"]), |
|||
normalize=trainer_params["normalize"], |
|||
use_recurrent=trainer_params["use_recurrent"], |
|||
num_layers=int(trainer_params["num_layers"]), |
|||
m_size=self.m_size, |
|||
seed=seed, |
|||
stream_names=list(reward_signal_configs.keys()), |
|||
vis_encode_type=EncoderType( |
|||
trainer_params.get("vis_encode_type", "simple") |
|||
), |
|||
) |
|||
) |
|||
self.towers[-1].create_ppo_optimizer() |
|||
self.model = self.towers[0] |
|||
avg_grads = self.average_gradients([t.grads for t in self.towers]) |
|||
update_batch = self.model.optimizer.apply_gradients(avg_grads) |
|||
|
|||
avg_value_loss = tf.reduce_mean( |
|||
tf.stack([model.value_loss for model in self.towers]), 0 |
|||
) |
|||
avg_policy_loss = tf.reduce_mean( |
|||
tf.stack([model.policy_loss for model in self.towers]), 0 |
|||
) |
|||
|
|||
self.inference_dict.update( |
|||
{ |
|||
"action": self.model.output, |
|||
"log_probs": self.model.all_log_probs, |
|||
"value_heads": self.model.value_heads, |
|||
"value": self.model.value, |
|||
"entropy": self.model.entropy, |
|||
"learning_rate": self.model.learning_rate, |
|||
} |
|||
) |
|||
if self.use_continuous_act: |
|||
self.inference_dict["pre_action"] = self.model.output_pre |
|||
if self.use_recurrent: |
|||
self.inference_dict["memory_out"] = self.model.memory_out |
|||
if ( |
|||
is_training |
|||
and self.use_vec_obs |
|||
and trainer_params["normalize"] |
|||
and not load |
|||
): |
|||
self.inference_dict["update_mean"] = self.model.update_normalization |
|||
|
|||
self.total_policy_loss = self.model.abs_policy_loss |
|||
self.update_dict.update( |
|||
{ |
|||
"value_loss": avg_value_loss, |
|||
"policy_loss": avg_policy_loss, |
|||
"update_batch": update_batch, |
|||
} |
|||
) |
|||
|
|||
def create_reward_signals(self, reward_signal_configs): |
|||
""" |
|||
Create reward signals |
|||
:param reward_signal_configs: Reward signal config. |
|||
""" |
|||
with self.graph.as_default(): |
|||
with tf.variable_scope(TOWER_SCOPE_NAME, reuse=tf.AUTO_REUSE): |
|||
for device_id, device in enumerate(self.devices): |
|||
with tf.device(device): |
|||
reward_tower = {} |
|||
for reward_signal, config in reward_signal_configs.items(): |
|||
reward_tower[reward_signal] = create_reward_signal( |
|||
self, self.towers[device_id], reward_signal, config |
|||
) |
|||
for k, v in reward_tower[reward_signal].update_dict.items(): |
|||
self.update_dict[k + "_" + str(device_id)] = v |
|||
self.reward_signal_towers.append(reward_tower) |
|||
for _, reward_tower in self.reward_signal_towers[0].items(): |
|||
for _, update_key in reward_tower.stats_name_to_update_name.items(): |
|||
all_reward_signal_stats = tf.stack( |
|||
[ |
|||
self.update_dict[update_key + "_" + str(i)] |
|||
for i in range(len(self.towers)) |
|||
] |
|||
) |
|||
mean_reward_signal_stats = tf.reduce_mean( |
|||
all_reward_signal_stats, 0 |
|||
) |
|||
self.update_dict.update({update_key: mean_reward_signal_stats}) |
|||
|
|||
self.reward_signals = self.reward_signal_towers[0] |
|||
|
|||
@timed |
|||
def update(self, mini_batch, num_sequences): |
|||
""" |
|||
Updates model using buffer. |
|||
:param n_sequences: Number of trajectories in batch. |
|||
:param mini_batch: Experience batch. |
|||
:return: Output from update process. |
|||
""" |
|||
feed_dict = {} |
|||
stats_needed = self.stats_name_to_update_name |
|||
|
|||
device_batch_size = num_sequences // len(self.devices) |
|||
device_batches = [] |
|||
for i in range(len(self.devices)): |
|||
device_batches.append( |
|||
{ |
|||
k: v[ |
|||
i * device_batch_size : i * device_batch_size |
|||
+ device_batch_size |
|||
] |
|||
for (k, v) in mini_batch.items() |
|||
} |
|||
) |
|||
|
|||
for batch, tower, reward_tower in zip( |
|||
device_batches, self.towers, self.reward_signal_towers |
|||
): |
|||
feed_dict.update(self.construct_feed_dict(tower, batch, num_sequences)) |
|||
stats_needed.update(self.stats_name_to_update_name) |
|||
for _, reward_signal in reward_tower.items(): |
|||
feed_dict.update( |
|||
reward_signal.prepare_update(tower, batch, num_sequences) |
|||
) |
|||
stats_needed.update(reward_signal.stats_name_to_update_name) |
|||
|
|||
update_vals = self._execute_model(feed_dict, self.update_dict) |
|||
update_stats = {} |
|||
for stat_name, update_name in stats_needed.items(): |
|||
update_stats[stat_name] = update_vals[update_name] |
|||
return update_stats |
|||
|
|||
def average_gradients(self, tower_grads): |
|||
""" |
|||
Average gradients from all towers |
|||
:param tower_grads: Gradients from all towers |
|||
""" |
|||
average_grads = [] |
|||
for grad_and_vars in zip(*tower_grads): |
|||
grads = [g for g, _ in grad_and_vars if g is not None] |
|||
if not grads: |
|||
continue |
|||
avg_grad = tf.reduce_mean(tf.stack(grads), 0) |
|||
var = grad_and_vars[0][1] |
|||
average_grads.append((avg_grad, var)) |
|||
return average_grads |
|||
|
|||
|
|||
def get_devices() -> List[str]: |
|||
""" |
|||
Get all available GPU devices |
|||
""" |
|||
local_device_protos = device_lib.list_local_devices() |
|||
devices = [x.name for x in local_device_protos if x.device_type == "GPU"] |
|||
return devices |
|
|||
import logging |
|||
import numpy as np |
|||
from typing import Any, Dict, Optional, List |
|||
|
|||
from mlagents.tf_utils import tf |
|||
|
|||
from mlagents_envs.timers import timed |
|||
from mlagents_envs.base_env import BatchedStepResult |
|||
from mlagents.trainers.brain import BrainParameters |
|||
from mlagents.trainers.models import EncoderType, LearningRateSchedule |
|||
from mlagents.trainers.ppo.models import PPOModel |
|||
from mlagents.trainers.tf_policy import TFPolicy |
|||
from mlagents.trainers.components.reward_signals.reward_signal_factory import ( |
|||
create_reward_signal, |
|||
) |
|||
from mlagents.trainers.components.bc.module import BCModule |
|||
|
|||
logger = logging.getLogger("mlagents.trainers") |
|||
|
|||
|
|||
class PPOPolicy(TFPolicy): |
|||
def __init__( |
|||
self, |
|||
seed: int, |
|||
brain: BrainParameters, |
|||
trainer_params: Dict[str, Any], |
|||
is_training: bool, |
|||
load: bool, |
|||
): |
|||
""" |
|||
Policy for Proximal Policy Optimization Networks. |
|||
:param seed: Random seed. |
|||
:param brain: Assigned Brain object. |
|||
:param trainer_params: Defined training parameters. |
|||
:param is_training: Whether the model should be trained. |
|||
:param load: Whether a pre-trained model will be loaded or a new one created. |
|||
""" |
|||
super().__init__(seed, brain, trainer_params) |
|||
|
|||
reward_signal_configs = trainer_params["reward_signals"] |
|||
self.inference_dict: Dict[str, tf.Tensor] = {} |
|||
self.update_dict: Dict[str, tf.Tensor] = {} |
|||
self.stats_name_to_update_name = { |
|||
"Losses/Value Loss": "value_loss", |
|||
"Losses/Policy Loss": "policy_loss", |
|||
} |
|||
|
|||
self.create_model( |
|||
brain, trainer_params, reward_signal_configs, is_training, load, seed |
|||
) |
|||
self.create_reward_signals(reward_signal_configs) |
|||
|
|||
with self.graph.as_default(): |
|||
self.bc_module: Optional[BCModule] = None |
|||
# Create pretrainer if needed |
|||
if "behavioral_cloning" in trainer_params: |
|||
BCModule.check_config(trainer_params["behavioral_cloning"]) |
|||
self.bc_module = BCModule( |
|||
self, |
|||
policy_learning_rate=trainer_params["learning_rate"], |
|||
default_batch_size=trainer_params["batch_size"], |
|||
default_num_epoch=3, |
|||
**trainer_params["behavioral_cloning"], |
|||
) |
|||
|
|||
if load: |
|||
self._load_graph() |
|||
else: |
|||
self._initialize_graph() |
|||
|
|||
def create_model( |
|||
self, brain, trainer_params, reward_signal_configs, is_training, load, seed |
|||
): |
|||
""" |
|||
Create PPO model |
|||
:param brain: Assigned Brain object. |
|||
:param trainer_params: Defined training parameters. |
|||
:param reward_signal_configs: Reward signal config |
|||
:param seed: Random seed. |
|||
""" |
|||
with self.graph.as_default(): |
|||
self.model = PPOModel( |
|||
brain=brain, |
|||
lr=float(trainer_params["learning_rate"]), |
|||
lr_schedule=LearningRateSchedule( |
|||
trainer_params.get("learning_rate_schedule", "linear") |
|||
), |
|||
h_size=int(trainer_params["hidden_units"]), |
|||
epsilon=float(trainer_params["epsilon"]), |
|||
beta=float(trainer_params["beta"]), |
|||
max_step=float(trainer_params["max_steps"]), |
|||
normalize=trainer_params["normalize"], |
|||
use_recurrent=trainer_params["use_recurrent"], |
|||
num_layers=int(trainer_params["num_layers"]), |
|||
m_size=self.m_size, |
|||
seed=seed, |
|||
stream_names=list(reward_signal_configs.keys()), |
|||
vis_encode_type=EncoderType( |
|||
trainer_params.get("vis_encode_type", "simple") |
|||
), |
|||
) |
|||
self.model.create_ppo_optimizer() |
|||
|
|||
self.inference_dict.update( |
|||
{ |
|||
"action": self.model.output, |
|||
"log_probs": self.model.all_log_probs, |
|||
"entropy": self.model.entropy, |
|||
"learning_rate": self.model.learning_rate, |
|||
} |
|||
) |
|||
if self.use_continuous_act: |
|||
self.inference_dict["pre_action"] = self.model.output_pre |
|||
if self.use_recurrent: |
|||
self.inference_dict["memory_out"] = self.model.memory_out |
|||
|
|||
self.total_policy_loss = self.model.abs_policy_loss |
|||
self.update_dict.update( |
|||
{ |
|||
"value_loss": self.model.value_loss, |
|||
"policy_loss": self.total_policy_loss, |
|||
"update_batch": self.model.update_batch, |
|||
} |
|||
) |
|||
|
|||
def create_reward_signals(self, reward_signal_configs): |
|||
""" |
|||
Create reward signals |
|||
:param reward_signal_configs: Reward signal config. |
|||
""" |
|||
self.reward_signals = {} |
|||
with self.graph.as_default(): |
|||
# Create reward signals |
|||
for reward_signal, config in reward_signal_configs.items(): |
|||
self.reward_signals[reward_signal] = create_reward_signal( |
|||
self, self.model, reward_signal, config |
|||
) |
|||
self.update_dict.update(self.reward_signals[reward_signal].update_dict) |
|||
|
|||
@timed |
|||
def evaluate( |
|||
self, batched_step_result: BatchedStepResult, global_agent_ids: List[str] |
|||
) -> Dict[str, Any]: |
|||
""" |
|||
Evaluates policy for the agent experiences provided. |
|||
:param batched_step_result: BatchedStepResult object containing inputs. |
|||
:param global_agent_ids: The global (with worker ID) agent ids of the data in the batched_step_result. |
|||
:return: Outputs from network as defined by self.inference_dict. |
|||
""" |
|||
feed_dict = { |
|||
self.model.batch_size: batched_step_result.n_agents(), |
|||
self.model.sequence_length: 1, |
|||
} |
|||
epsilon = None |
|||
if self.use_recurrent: |
|||
if not self.use_continuous_act: |
|||
feed_dict[self.model.prev_action] = self.retrieve_previous_action( |
|||
global_agent_ids |
|||
) |
|||
feed_dict[self.model.memory_in] = self.retrieve_memories(global_agent_ids) |
|||
if self.use_continuous_act: |
|||
epsilon = np.random.normal( |
|||
size=(batched_step_result.n_agents(), self.model.act_size[0]) |
|||
) |
|||
feed_dict[self.model.epsilon] = epsilon |
|||
feed_dict = self.fill_eval_dict(feed_dict, batched_step_result) |
|||
run_out = self._execute_model(feed_dict, self.inference_dict) |
|||
return run_out |
|||
|
|||
@timed |
|||
def update(self, mini_batch, num_sequences): |
|||
""" |
|||
Performs update on model. |
|||
:param mini_batch: Batch of experiences. |
|||
:param num_sequences: Number of sequences to process. |
|||
:return: Results of update. |
|||
""" |
|||
feed_dict = self.construct_feed_dict(self.model, mini_batch, num_sequences) |
|||
stats_needed = self.stats_name_to_update_name |
|||
update_stats = {} |
|||
# Collect feed dicts for all reward signals. |
|||
for _, reward_signal in self.reward_signals.items(): |
|||
feed_dict.update( |
|||
reward_signal.prepare_update(self.model, mini_batch, num_sequences) |
|||
) |
|||
stats_needed.update(reward_signal.stats_name_to_update_name) |
|||
|
|||
update_vals = self._execute_model(feed_dict, self.update_dict) |
|||
for stat_name, update_name in stats_needed.items(): |
|||
update_stats[stat_name] = update_vals[update_name] |
|||
return update_stats |
|||
|
|||
def construct_feed_dict(self, model, mini_batch, num_sequences): |
|||
feed_dict = { |
|||
model.batch_size: num_sequences, |
|||
model.sequence_length: self.sequence_length, |
|||
model.mask_input: mini_batch["masks"], |
|||
model.advantage: mini_batch["advantages"], |
|||
model.all_old_log_probs: mini_batch["action_probs"], |
|||
} |
|||
for name in self.reward_signals: |
|||
feed_dict[model.returns_holders[name]] = mini_batch[ |
|||
"{}_returns".format(name) |
|||
] |
|||
feed_dict[model.old_values[name]] = mini_batch[ |
|||
"{}_value_estimates".format(name) |
|||
] |
|||
|
|||
if self.use_continuous_act: |
|||
feed_dict[model.output_pre] = mini_batch["actions_pre"] |
|||
else: |
|||
feed_dict[model.action_holder] = mini_batch["actions"] |
|||
if self.use_recurrent: |
|||
feed_dict[model.prev_action] = mini_batch["prev_action"] |
|||
feed_dict[model.action_masks] = mini_batch["action_mask"] |
|||
if self.use_vec_obs: |
|||
feed_dict[model.vector_in] = mini_batch["vector_obs"] |
|||
if self.model.vis_obs_size > 0: |
|||
for i, _ in enumerate(self.model.visual_in): |
|||
feed_dict[model.visual_in[i]] = mini_batch["visual_obs%d" % i] |
|||
if self.use_recurrent: |
|||
mem_in = [ |
|||
mini_batch["memory"][i] |
|||
for i in range(0, len(mini_batch["memory"]), self.sequence_length) |
|||
] |
|||
feed_dict[model.memory_in] = mem_in |
|||
return feed_dict |
1001
ml-agents/mlagents/trainers/sac/models.py
文件差异内容过多而无法显示
查看文件
文件差异内容过多而无法显示
查看文件
|
|||
import logging |
|||
from typing import Dict, Any, Optional, Mapping, List |
|||
import numpy as np |
|||
from mlagents.tf_utils import tf |
|||
|
|||
from mlagents_envs.timers import timed |
|||
from mlagents_envs.base_env import BatchedStepResult |
|||
from mlagents.trainers.brain import BrainParameters |
|||
from mlagents.trainers.models import EncoderType, LearningRateSchedule |
|||
from mlagents.trainers.sac.models import SACModel |
|||
from mlagents.trainers.tf_policy import TFPolicy |
|||
from mlagents.trainers.components.reward_signals.reward_signal_factory import ( |
|||
create_reward_signal, |
|||
) |
|||
from mlagents.trainers.components.reward_signals import RewardSignal |
|||
from mlagents.trainers.components.bc.module import BCModule |
|||
|
|||
logger = logging.getLogger("mlagents.trainers") |
|||
|
|||
|
|||
class SACPolicy(TFPolicy): |
|||
def __init__( |
|||
self, |
|||
seed: int, |
|||
brain: BrainParameters, |
|||
trainer_params: Dict[str, Any], |
|||
is_training: bool, |
|||
load: bool, |
|||
) -> None: |
|||
""" |
|||
Policy for Proximal Policy Optimization Networks. |
|||
:param seed: Random seed. |
|||
:param brain: Assigned Brain object. |
|||
:param trainer_params: Defined training parameters. |
|||
:param is_training: Whether the model should be trained. |
|||
:param load: Whether a pre-trained model will be loaded or a new one created. |
|||
""" |
|||
super().__init__(seed, brain, trainer_params) |
|||
|
|||
reward_signal_configs = {} |
|||
for key, rsignal in trainer_params["reward_signals"].items(): |
|||
if type(rsignal) is dict: |
|||
reward_signal_configs[key] = rsignal |
|||
|
|||
self.inference_dict: Dict[str, tf.Tensor] = {} |
|||
self.update_dict: Dict[str, tf.Tensor] = {} |
|||
self.create_model( |
|||
brain, trainer_params, reward_signal_configs, is_training, load, seed |
|||
) |
|||
self.create_reward_signals(reward_signal_configs) |
|||
|
|||
self.stats_name_to_update_name = { |
|||
"Losses/Value Loss": "value_loss", |
|||
"Losses/Policy Loss": "policy_loss", |
|||
"Losses/Q1 Loss": "q1_loss", |
|||
"Losses/Q2 Loss": "q2_loss", |
|||
"Policy/Entropy Coeff": "entropy_coef", |
|||
} |
|||
|
|||
with self.graph.as_default(): |
|||
# Create pretrainer if needed |
|||
self.bc_module: Optional[BCModule] = None |
|||
if "behavioral_cloning" in trainer_params: |
|||
BCModule.check_config(trainer_params["behavioral_cloning"]) |
|||
self.bc_module = BCModule( |
|||
self, |
|||
policy_learning_rate=trainer_params["learning_rate"], |
|||
default_batch_size=trainer_params["batch_size"], |
|||
default_num_epoch=1, |
|||
samples_per_update=trainer_params["batch_size"], |
|||
**trainer_params["behavioral_cloning"], |
|||
) |
|||
# SAC-specific setting - we don't want to do a whole epoch each update! |
|||
if "samples_per_update" in trainer_params["behavioral_cloning"]: |
|||
logger.warning( |
|||
"Pretraining: Samples Per Update is not a valid setting for SAC." |
|||
) |
|||
self.bc_module.samples_per_update = 1 |
|||
|
|||
if load: |
|||
self._load_graph() |
|||
else: |
|||
self._initialize_graph() |
|||
self.sess.run(self.model.target_init_op) |
|||
|
|||
# Disable terminal states for certain reward signals to avoid survivor bias |
|||
for name, reward_signal in self.reward_signals.items(): |
|||
if not reward_signal.use_terminal_states: |
|||
self.sess.run(self.model.disable_use_dones[name]) |
|||
|
|||
def create_model( |
|||
self, |
|||
brain: BrainParameters, |
|||
trainer_params: Dict[str, Any], |
|||
reward_signal_configs: Dict[str, Any], |
|||
is_training: bool, |
|||
load: bool, |
|||
seed: int, |
|||
) -> None: |
|||
with self.graph.as_default(): |
|||
self.model = SACModel( |
|||
brain, |
|||
lr=float(trainer_params["learning_rate"]), |
|||
lr_schedule=LearningRateSchedule( |
|||
trainer_params.get("learning_rate_schedule", "constant") |
|||
), |
|||
h_size=int(trainer_params["hidden_units"]), |
|||
init_entcoef=float(trainer_params["init_entcoef"]), |
|||
max_step=float(trainer_params["max_steps"]), |
|||
normalize=trainer_params["normalize"], |
|||
use_recurrent=trainer_params["use_recurrent"], |
|||
num_layers=int(trainer_params["num_layers"]), |
|||
m_size=self.m_size, |
|||
seed=seed, |
|||
stream_names=list(reward_signal_configs.keys()), |
|||
tau=float(trainer_params["tau"]), |
|||
gammas=[_val["gamma"] for _val in reward_signal_configs.values()], |
|||
vis_encode_type=EncoderType( |
|||
trainer_params.get("vis_encode_type", "simple") |
|||
), |
|||
) |
|||
self.model.create_sac_optimizers() |
|||
|
|||
self.inference_dict.update( |
|||
{ |
|||
"action": self.model.output, |
|||
"log_probs": self.model.all_log_probs, |
|||
"entropy": self.model.entropy, |
|||
"learning_rate": self.model.learning_rate, |
|||
} |
|||
) |
|||
if self.use_continuous_act: |
|||
self.inference_dict["pre_action"] = self.model.output_pre |
|||
if self.use_recurrent: |
|||
self.inference_dict["memory_out"] = self.model.memory_out |
|||
|
|||
self.update_dict.update( |
|||
{ |
|||
"value_loss": self.model.total_value_loss, |
|||
"policy_loss": self.model.policy_loss, |
|||
"q1_loss": self.model.q1_loss, |
|||
"q2_loss": self.model.q2_loss, |
|||
"entropy_coef": self.model.ent_coef, |
|||
"entropy": self.model.entropy, |
|||
"update_batch": self.model.update_batch_policy, |
|||
"update_value": self.model.update_batch_value, |
|||
"update_entropy": self.model.update_batch_entropy, |
|||
} |
|||
) |
|||
|
|||
def create_reward_signals(self, reward_signal_configs: Dict[str, Any]) -> None: |
|||
""" |
|||
Create reward signals |
|||
:param reward_signal_configs: Reward signal config. |
|||
""" |
|||
self.reward_signals: Dict[str, RewardSignal] = {} |
|||
with self.graph.as_default(): |
|||
# Create reward signals |
|||
for reward_signal, config in reward_signal_configs.items(): |
|||
if type(config) is dict: |
|||
self.reward_signals[reward_signal] = create_reward_signal( |
|||
self, self.model, reward_signal, config |
|||
) |
|||
|
|||
def evaluate( |
|||
self, batched_step_result: BatchedStepResult, global_agent_ids: List[str] |
|||
) -> Dict[str, np.ndarray]: |
|||
""" |
|||
Evaluates policy for the agent experiences provided. |
|||
:param batched_step_result: BatchedStepResult object containing inputs. |
|||
:return: Outputs from network as defined by self.inference_dict. |
|||
""" |
|||
feed_dict = { |
|||
self.model.batch_size: batched_step_result.n_agents(), |
|||
self.model.sequence_length: 1, |
|||
} |
|||
if self.use_recurrent: |
|||
if not self.use_continuous_act: |
|||
feed_dict[self.model.prev_action] = self.retrieve_previous_action( |
|||
global_agent_ids |
|||
) |
|||
feed_dict[self.model.memory_in] = self.retrieve_memories(global_agent_ids) |
|||
|
|||
feed_dict = self.fill_eval_dict(feed_dict, batched_step_result) |
|||
run_out = self._execute_model(feed_dict, self.inference_dict) |
|||
return run_out |
|||
|
|||
@timed |
|||
def update( |
|||
self, mini_batch: Dict[str, Any], num_sequences: int |
|||
) -> Dict[str, float]: |
|||
""" |
|||
Updates model using buffer. |
|||
:param num_sequences: Number of trajectories in batch. |
|||
:param mini_batch: Experience batch. |
|||
:param update_target: Whether or not to update target value network |
|||
:param reward_signal_mini_batches: Minibatches to use for updating the reward signals, |
|||
indexed by name. If none, don't update the reward signals. |
|||
:return: Output from update process. |
|||
""" |
|||
feed_dict = self.construct_feed_dict(self.model, mini_batch, num_sequences) |
|||
stats_needed = self.stats_name_to_update_name |
|||
update_stats: Dict[str, float] = {} |
|||
update_vals = self._execute_model(feed_dict, self.update_dict) |
|||
for stat_name, update_name in stats_needed.items(): |
|||
update_stats[stat_name] = update_vals[update_name] |
|||
# Update target network. By default, target update happens at every policy update. |
|||
self.sess.run(self.model.target_update_op) |
|||
return update_stats |
|||
|
|||
def update_reward_signals( |
|||
self, reward_signal_minibatches: Mapping[str, Dict], num_sequences: int |
|||
) -> Dict[str, float]: |
|||
""" |
|||
Only update the reward signals. |
|||
:param reward_signal_mini_batches: Minibatches to use for updating the reward signals, |
|||
indexed by name. If none, don't update the reward signals. |
|||
""" |
|||
# Collect feed dicts for all reward signals. |
|||
feed_dict: Dict[tf.Tensor, Any] = {} |
|||
update_dict: Dict[str, tf.Tensor] = {} |
|||
update_stats: Dict[str, float] = {} |
|||
stats_needed: Dict[str, str] = {} |
|||
if reward_signal_minibatches: |
|||
self.add_reward_signal_dicts( |
|||
feed_dict, |
|||
update_dict, |
|||
stats_needed, |
|||
reward_signal_minibatches, |
|||
num_sequences, |
|||
) |
|||
update_vals = self._execute_model(feed_dict, update_dict) |
|||
for stat_name, update_name in stats_needed.items(): |
|||
update_stats[stat_name] = update_vals[update_name] |
|||
return update_stats |
|||
|
|||
def add_reward_signal_dicts( |
|||
self, |
|||
feed_dict: Dict[tf.Tensor, Any], |
|||
update_dict: Dict[str, tf.Tensor], |
|||
stats_needed: Dict[str, str], |
|||
reward_signal_minibatches: Mapping[str, Dict], |
|||
num_sequences: int, |
|||
) -> None: |
|||
""" |
|||
Adds the items needed for reward signal updates to the feed_dict and stats_needed dict. |
|||
:param feed_dict: Feed dict needed update |
|||
:param update_dit: Update dict that needs update |
|||
:param stats_needed: Stats needed to get from the update. |
|||
:param reward_signal_minibatches: Minibatches to use for updating the reward signals, |
|||
indexed by name. |
|||
""" |
|||
for name, r_mini_batch in reward_signal_minibatches.items(): |
|||
feed_dict.update( |
|||
self.reward_signals[name].prepare_update( |
|||
self.model, r_mini_batch, num_sequences |
|||
) |
|||
) |
|||
update_dict.update(self.reward_signals[name].update_dict) |
|||
stats_needed.update(self.reward_signals[name].stats_name_to_update_name) |
|||
|
|||
def construct_feed_dict( |
|||
self, model: SACModel, mini_batch: Dict[str, Any], num_sequences: int |
|||
) -> Dict[tf.Tensor, Any]: |
|||
""" |
|||
Builds the feed dict for updating the SAC model. |
|||
:param model: The model to update. May be different when, e.g. using multi-GPU. |
|||
:param mini_batch: Mini-batch to use to update. |
|||
:param num_sequences: Number of LSTM sequences in mini_batch. |
|||
""" |
|||
feed_dict = { |
|||
self.model.batch_size: num_sequences, |
|||
self.model.sequence_length: self.sequence_length, |
|||
self.model.next_sequence_length: self.sequence_length, |
|||
self.model.mask_input: mini_batch["masks"], |
|||
} |
|||
for name in self.reward_signals: |
|||
feed_dict[model.rewards_holders[name]] = mini_batch[ |
|||
"{}_rewards".format(name) |
|||
] |
|||
|
|||
if self.use_continuous_act: |
|||
feed_dict[model.action_holder] = mini_batch["actions"] |
|||
else: |
|||
feed_dict[model.action_holder] = mini_batch["actions"] |
|||
if self.use_recurrent: |
|||
feed_dict[model.prev_action] = mini_batch["prev_action"] |
|||
feed_dict[model.action_masks] = mini_batch["action_mask"] |
|||
if self.use_vec_obs: |
|||
feed_dict[model.vector_in] = mini_batch["vector_obs"] |
|||
feed_dict[model.next_vector_in] = mini_batch["next_vector_in"] |
|||
if self.model.vis_obs_size > 0: |
|||
for i, _ in enumerate(model.visual_in): |
|||
_obs = mini_batch["visual_obs%d" % i] |
|||
feed_dict[model.visual_in[i]] = _obs |
|||
for i, _ in enumerate(model.next_visual_in): |
|||
_obs = mini_batch["next_visual_obs%d" % i] |
|||
feed_dict[model.next_visual_in[i]] = _obs |
|||
if self.use_recurrent: |
|||
mem_in = [ |
|||
mini_batch["memory"][i] |
|||
for i in range(0, len(mini_batch["memory"]), self.sequence_length) |
|||
] |
|||
# LSTM shouldn't have sequence length <1, but stop it from going out of the index if true. |
|||
offset = 1 if self.sequence_length > 1 else 0 |
|||
next_mem_in = [ |
|||
mini_batch["memory"][i][ |
|||
: self.m_size // 4 |
|||
] # only pass value part of memory to target network |
|||
for i in range(offset, len(mini_batch["memory"]), self.sequence_length) |
|||
] |
|||
feed_dict[model.memory_in] = mem_in |
|||
feed_dict[model.next_memory_in] = next_mem_in |
|||
feed_dict[model.dones_holder] = mini_batch["done"] |
|||
return feed_dict |
|
|||
from unittest import mock |
|||
import pytest |
|||
|
|||
from mlagents.tf_utils import tf |
|||
import yaml |
|||
|
|||
from mlagents.trainers.ppo.multi_gpu_policy import MultiGpuPPOPolicy |
|||
from mlagents.trainers.tests.mock_brain import create_mock_brainparams |
|||
|
|||
|
|||
@pytest.fixture |
|||
def dummy_config(): |
|||
return yaml.safe_load( |
|||
""" |
|||
trainer: ppo |
|||
batch_size: 32 |
|||
beta: 5.0e-3 |
|||
buffer_size: 512 |
|||
epsilon: 0.2 |
|||
hidden_units: 128 |
|||
lambd: 0.95 |
|||
learning_rate: 3.0e-4 |
|||
max_steps: 5.0e4 |
|||
normalize: true |
|||
num_epoch: 5 |
|||
num_layers: 2 |
|||
time_horizon: 64 |
|||
sequence_length: 64 |
|||
summary_freq: 1000 |
|||
use_recurrent: false |
|||
memory_size: 8 |
|||
curiosity_strength: 0.0 |
|||
curiosity_enc_size: 1 |
|||
reward_signals: |
|||
extrinsic: |
|||
strength: 1.0 |
|||
gamma: 0.99 |
|||
""" |
|||
) |
|||
|
|||
|
|||
@mock.patch("mlagents.trainers.ppo.multi_gpu_policy.get_devices") |
|||
def test_create_model(mock_get_devices, dummy_config): |
|||
tf.reset_default_graph() |
|||
mock_get_devices.return_value = [ |
|||
"/device:GPU:0", |
|||
"/device:GPU:1", |
|||
"/device:GPU:2", |
|||
"/device:GPU:3", |
|||
] |
|||
|
|||
trainer_parameters = dummy_config |
|||
trainer_parameters["model_path"] = "" |
|||
trainer_parameters["keep_checkpoints"] = 3 |
|||
brain = create_mock_brainparams() |
|||
|
|||
policy = MultiGpuPPOPolicy(0, brain, trainer_parameters, False, False) |
|||
assert len(policy.towers) == len(mock_get_devices.return_value) |
|||
|
|||
|
|||
@mock.patch("mlagents.trainers.ppo.multi_gpu_policy.get_devices") |
|||
def test_average_gradients(mock_get_devices, dummy_config): |
|||
tf.reset_default_graph() |
|||
mock_get_devices.return_value = [ |
|||
"/device:GPU:0", |
|||
"/device:GPU:1", |
|||
"/device:GPU:2", |
|||
"/device:GPU:3", |
|||
] |
|||
|
|||
trainer_parameters = dummy_config |
|||
trainer_parameters["model_path"] = "" |
|||
trainer_parameters["keep_checkpoints"] = 3 |
|||
brain = create_mock_brainparams() |
|||
with tf.Session() as sess: |
|||
policy = MultiGpuPPOPolicy(0, brain, trainer_parameters, False, False) |
|||
var = tf.Variable(0) |
|||
tower_grads = [ |
|||
[(tf.constant(0.1), var)], |
|||
[(tf.constant(0.2), var)], |
|||
[(tf.constant(0.3), var)], |
|||
[(tf.constant(0.4), var)], |
|||
] |
|||
avg_grads = policy.average_gradients(tower_grads) |
|||
|
|||
init = tf.global_variables_initializer() |
|||
sess.run(init) |
|||
run_out = sess.run(avg_grads) |
|||
assert run_out == [(0.25, 0)] |
|||
|
|||
|
|||
@mock.patch("mlagents.trainers.tf_policy.TFPolicy._execute_model") |
|||
@mock.patch("mlagents.trainers.ppo.policy.PPOPolicy.construct_feed_dict") |
|||
@mock.patch("mlagents.trainers.ppo.multi_gpu_policy.get_devices") |
|||
def test_update( |
|||
mock_get_devices, mock_construct_feed_dict, mock_execute_model, dummy_config |
|||
): |
|||
tf.reset_default_graph() |
|||
mock_get_devices.return_value = ["/device:GPU:0", "/device:GPU:1"] |
|||
mock_construct_feed_dict.return_value = {} |
|||
mock_execute_model.return_value = { |
|||
"value_loss": 0.1, |
|||
"policy_loss": 0.3, |
|||
"update_batch": None, |
|||
} |
|||
|
|||
trainer_parameters = dummy_config |
|||
trainer_parameters["model_path"] = "" |
|||
trainer_parameters["keep_checkpoints"] = 3 |
|||
brain = create_mock_brainparams() |
|||
policy = MultiGpuPPOPolicy(0, brain, trainer_parameters, False, False) |
|||
mock_mini_batch = mock.Mock() |
|||
mock_mini_batch.items.return_value = [("action", [1, 2]), ("value", [3, 4])] |
|||
run_out = policy.update(mock_mini_batch, 1) |
|||
|
|||
assert mock_mini_batch.items.call_count == len(mock_get_devices.return_value) |
|||
assert mock_construct_feed_dict.call_count == len(mock_get_devices.return_value) |
|||
assert run_out["Losses/Value Loss"] == 0.1 |
|||
assert run_out["Losses/Policy Loss"] == 0.3 |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
pytest.main() |
撰写
预览
正在加载...
取消
保存
Reference in new issue