浏览代码
init sac transfer, and added action encoder to bisim; configs for crawler
/develop/bisim-sac-transfer
init sac transfer, and added action encoder to bisim; configs for crawler
/develop/bisim-sac-transfer
yanchaosun
4 年前
当前提交
80bad241
共有 9 个文件被更改,包括 1462 次插入 和 14 次删除
-
7config/ppo_transfer/CrawlerStatic.yaml
-
22config/ppo_transfer/TransferCrawlerStatic.yaml
-
15ml-agents/mlagents/trainers/policy/transfer_policy.py
-
6ml-agents/mlagents/trainers/tests/transfer_test_envs.py
-
0ml-agents/mlagents/trainers/sac_transfer/__init__.py
-
445ml-agents/mlagents/trainers/sac_transfer/network.py
-
641ml-agents/mlagents/trainers/sac_transfer/optimizer.py
-
340ml-agents/mlagents/trainers/sac_transfer/trainer.py
|
|||
from typing import Dict, Optional |
|||
from mlagents.tf_utils import tf |
|||
from mlagents.trainers.models import ModelUtils, EncoderType |
|||
|
|||
LOG_STD_MAX = 2 |
|||
LOG_STD_MIN = -20 |
|||
EPSILON = 1e-6 # Small value to avoid divide by zero |
|||
DISCRETE_TARGET_ENTROPY_SCALE = 0.2 # Roughly equal to e-greedy 0.05 |
|||
CONTINUOUS_TARGET_ENTROPY_SCALE = 1.0 # TODO: Make these an optional hyperparam. |
|||
POLICY_SCOPE = "" |
|||
TARGET_SCOPE = "target_network" |
|||
|
|||
|
|||
class SACNetwork: |
|||
""" |
|||
Base class for an SAC network. Implements methods for creating the actor and critic heads. |
|||
""" |
|||
|
|||
def __init__( |
|||
self, |
|||
policy=None, |
|||
m_size=None, |
|||
h_size=128, |
|||
normalize=False, |
|||
use_recurrent=False, |
|||
num_layers=2, |
|||
stream_names=None, |
|||
vis_encode_type=EncoderType.SIMPLE, |
|||
): |
|||
self.normalize = normalize |
|||
self.use_recurrent = use_recurrent |
|||
self.num_layers = num_layers |
|||
self.stream_names = stream_names |
|||
self.h_size = h_size |
|||
self.activ_fn = ModelUtils.swish |
|||
|
|||
self.sequence_length_ph = tf.placeholder( |
|||
shape=None, dtype=tf.int32, name="sac_sequence_length" |
|||
) |
|||
|
|||
self.policy_memory_in: Optional[tf.Tensor] = None |
|||
self.policy_memory_out: Optional[tf.Tensor] = None |
|||
self.value_memory_in: Optional[tf.Tensor] = None |
|||
self.value_memory_out: Optional[tf.Tensor] = None |
|||
self.q1: Optional[tf.Tensor] = None |
|||
self.q2: Optional[tf.Tensor] = None |
|||
self.q1_p: Optional[tf.Tensor] = None |
|||
self.q2_p: Optional[tf.Tensor] = None |
|||
self.q1_memory_in: Optional[tf.Tensor] = None |
|||
self.q2_memory_in: Optional[tf.Tensor] = None |
|||
self.q1_memory_out: Optional[tf.Tensor] = None |
|||
self.q2_memory_out: Optional[tf.Tensor] = None |
|||
self.prev_action: Optional[tf.Tensor] = None |
|||
self.action_masks: Optional[tf.Tensor] = None |
|||
self.external_action_in: Optional[tf.Tensor] = None |
|||
self.log_sigma_sq: Optional[tf.Tensor] = None |
|||
self.entropy: Optional[tf.Tensor] = None |
|||
self.deterministic_output: Optional[tf.Tensor] = None |
|||
self.normalized_logprobs: Optional[tf.Tensor] = None |
|||
self.action_probs: Optional[tf.Tensor] = None |
|||
self.output_oh: Optional[tf.Tensor] = None |
|||
self.output_pre: Optional[tf.Tensor] = None |
|||
|
|||
self.value_vars = None |
|||
self.q_vars = None |
|||
self.critic_vars = None |
|||
self.policy_vars = None |
|||
|
|||
self.q1_heads: Dict[str, tf.Tensor] = None |
|||
self.q2_heads: Dict[str, tf.Tensor] = None |
|||
self.q1_pheads: Dict[str, tf.Tensor] = None |
|||
self.q2_pheads: Dict[str, tf.Tensor] = None |
|||
|
|||
self.policy = policy |
|||
|
|||
def get_vars(self, scope): |
|||
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope) |
|||
|
|||
def join_scopes(self, scope_1, scope_2): |
|||
""" |
|||
Joins two scopes. Does so safetly (i.e., if one of the two scopes doesn't |
|||
exist, don't add any backslashes) |
|||
""" |
|||
if not scope_1: |
|||
return scope_2 |
|||
if not scope_2: |
|||
return scope_1 |
|||
else: |
|||
return "/".join(filter(None, [scope_1, scope_2])) |
|||
|
|||
def create_value_heads(self, stream_names, hidden_input): |
|||
""" |
|||
Creates one value estimator head for each reward signal in stream_names. |
|||
Also creates the node corresponding to the mean of all the value heads in self.value. |
|||
self.value_head is a dictionary of stream name to node containing the value estimator head for that signal. |
|||
:param stream_names: The list of reward signal names |
|||
:param hidden_input: The last layer of the Critic. The heads will consist of one dense hidden layer on top |
|||
of the hidden input. |
|||
""" |
|||
self.value_heads = {} |
|||
for name in stream_names: |
|||
value = tf.layers.dense(hidden_input, 1, name="{}_value".format(name)) |
|||
self.value_heads[name] = value |
|||
self.value = tf.reduce_mean(list(self.value_heads.values()), 0) |
|||
|
|||
def _create_cc_critic(self, hidden_value, scope, create_qs=True): |
|||
""" |
|||
Creates just the critic network |
|||
""" |
|||
scope = self.join_scopes(scope, "critic") |
|||
self.create_sac_value_head( |
|||
self.stream_names, |
|||
hidden_value, |
|||
self.num_layers, |
|||
self.h_size, |
|||
self.join_scopes(scope, "value"), |
|||
) |
|||
self.external_action_in = tf.placeholder( |
|||
shape=[None, self.policy.act_size[0]], |
|||
dtype=tf.float32, |
|||
name="external_action_in", |
|||
) |
|||
self.value_vars = self.get_vars(self.join_scopes(scope, "value")) |
|||
if create_qs: |
|||
hidden_q = tf.concat([hidden_value, self.external_action_in], axis=-1) |
|||
hidden_qp = tf.concat([hidden_value, self.policy.output], axis=-1) |
|||
self.q1_heads, self.q2_heads, self.q1, self.q2 = self.create_q_heads( |
|||
self.stream_names, |
|||
hidden_q, |
|||
self.num_layers, |
|||
self.h_size, |
|||
self.join_scopes(scope, "q"), |
|||
) |
|||
self.q1_pheads, self.q2_pheads, self.q1_p, self.q2_p = self.create_q_heads( |
|||
self.stream_names, |
|||
hidden_qp, |
|||
self.num_layers, |
|||
self.h_size, |
|||
self.join_scopes(scope, "q"), |
|||
reuse=True, |
|||
) |
|||
self.q_vars = self.get_vars(self.join_scopes(scope, "q")) |
|||
self.critic_vars = self.get_vars(scope) |
|||
|
|||
def _create_dc_critic(self, hidden_value, scope, create_qs=True): |
|||
""" |
|||
Creates just the critic network |
|||
""" |
|||
scope = self.join_scopes(scope, "critic") |
|||
self.create_sac_value_head( |
|||
self.stream_names, |
|||
hidden_value, |
|||
self.num_layers, |
|||
self.h_size, |
|||
self.join_scopes(scope, "value"), |
|||
) |
|||
|
|||
self.value_vars = self.get_vars("/".join([scope, "value"])) |
|||
|
|||
if create_qs: |
|||
self.q1_heads, self.q2_heads, self.q1, self.q2 = self.create_q_heads( |
|||
self.stream_names, |
|||
hidden_value, |
|||
self.num_layers, |
|||
self.h_size, |
|||
self.join_scopes(scope, "q"), |
|||
num_outputs=sum(self.policy.act_size), |
|||
) |
|||
self.q1_pheads, self.q2_pheads, self.q1_p, self.q2_p = self.create_q_heads( |
|||
self.stream_names, |
|||
hidden_value, |
|||
self.num_layers, |
|||
self.h_size, |
|||
self.join_scopes(scope, "q"), |
|||
reuse=True, |
|||
num_outputs=sum(self.policy.act_size), |
|||
) |
|||
self.q_vars = self.get_vars(scope) |
|||
self.critic_vars = self.get_vars(scope) |
|||
|
|||
def create_sac_value_head( |
|||
self, stream_names, hidden_input, num_layers, h_size, scope |
|||
): |
|||
""" |
|||
Creates one value estimator head for each reward signal in stream_names. |
|||
Also creates the node corresponding to the mean of all the value heads in self.value. |
|||
self.value_head is a dictionary of stream name to node containing the value estimator head for that signal. |
|||
:param stream_names: The list of reward signal names |
|||
:param hidden_input: The last layer of the Critic. The heads will consist of one dense hidden layer on top |
|||
of the hidden input. |
|||
:param num_layers: Number of hidden layers for value network |
|||
:param h_size: size of hidden layers for value network |
|||
:param scope: TF scope for value network. |
|||
""" |
|||
with tf.variable_scope(scope): |
|||
value_hidden = ModelUtils.create_vector_observation_encoder( |
|||
hidden_input, h_size, self.activ_fn, num_layers, "encoder", False |
|||
) |
|||
if self.use_recurrent: |
|||
value_hidden, memory_out = ModelUtils.create_recurrent_encoder( |
|||
value_hidden, |
|||
self.value_memory_in, |
|||
self.sequence_length_ph, |
|||
name="lstm_value", |
|||
) |
|||
self.value_memory_out = memory_out |
|||
self.create_value_heads(stream_names, value_hidden) |
|||
|
|||
def create_q_heads( |
|||
self, |
|||
stream_names, |
|||
hidden_input, |
|||
num_layers, |
|||
h_size, |
|||
scope, |
|||
reuse=False, |
|||
num_outputs=1, |
|||
): |
|||
""" |
|||
Creates two q heads for each reward signal in stream_names. |
|||
Also creates the node corresponding to the mean of all the value heads in self.value. |
|||
self.value_head is a dictionary of stream name to node containing the value estimator head for that signal. |
|||
:param stream_names: The list of reward signal names |
|||
:param hidden_input: The last layer of the Critic. The heads will consist of one dense hidden layer on top |
|||
of the hidden input. |
|||
:param num_layers: Number of hidden layers for Q network |
|||
:param h_size: size of hidden layers for Q network |
|||
:param scope: TF scope for Q network. |
|||
:param reuse: Whether or not to reuse variables. Useful for creating Q of policy. |
|||
:param num_outputs: Number of outputs of each Q function. If discrete, equal to number of actions. |
|||
""" |
|||
with tf.variable_scope(self.join_scopes(scope, "q1_encoding"), reuse=reuse): |
|||
q1_hidden = ModelUtils.create_vector_observation_encoder( |
|||
hidden_input, h_size, self.activ_fn, num_layers, "q1_encoder", reuse |
|||
) |
|||
if self.use_recurrent: |
|||
q1_hidden, memory_out = ModelUtils.create_recurrent_encoder( |
|||
q1_hidden, |
|||
self.q1_memory_in, |
|||
self.sequence_length_ph, |
|||
name="lstm_q1", |
|||
) |
|||
self.q1_memory_out = memory_out |
|||
|
|||
q1_heads = {} |
|||
for name in stream_names: |
|||
_q1 = tf.layers.dense(q1_hidden, num_outputs, name="{}_q1".format(name)) |
|||
q1_heads[name] = _q1 |
|||
|
|||
q1 = tf.reduce_mean(list(q1_heads.values()), axis=0) |
|||
with tf.variable_scope(self.join_scopes(scope, "q2_encoding"), reuse=reuse): |
|||
q2_hidden = ModelUtils.create_vector_observation_encoder( |
|||
hidden_input, h_size, self.activ_fn, num_layers, "q2_encoder", reuse |
|||
) |
|||
if self.use_recurrent: |
|||
q2_hidden, memory_out = ModelUtils.create_recurrent_encoder( |
|||
q2_hidden, |
|||
self.q2_memory_in, |
|||
self.sequence_length_ph, |
|||
name="lstm_q2", |
|||
) |
|||
self.q2_memory_out = memory_out |
|||
|
|||
q2_heads = {} |
|||
for name in stream_names: |
|||
_q2 = tf.layers.dense(q2_hidden, num_outputs, name="{}_q2".format(name)) |
|||
q2_heads[name] = _q2 |
|||
|
|||
q2 = tf.reduce_mean(list(q2_heads.values()), axis=0) |
|||
|
|||
return q1_heads, q2_heads, q1, q2 |
|||
|
|||
|
|||
class SACTargetNetwork(SACNetwork): |
|||
""" |
|||
Instantiation for the SAC target network. Only contains a single |
|||
value estimator and is updated from the Policy Network. |
|||
""" |
|||
|
|||
def __init__( |
|||
self, |
|||
policy, |
|||
m_size=None, |
|||
h_size=128, |
|||
normalize=False, |
|||
use_recurrent=False, |
|||
num_layers=2, |
|||
stream_names=None, |
|||
vis_encode_type=EncoderType.SIMPLE, |
|||
): |
|||
super().__init__( |
|||
policy, |
|||
m_size, |
|||
h_size, |
|||
normalize, |
|||
use_recurrent, |
|||
num_layers, |
|||
stream_names, |
|||
vis_encode_type, |
|||
) |
|||
with tf.variable_scope(TARGET_SCOPE): |
|||
self.visual_in = ModelUtils.create_visual_input_placeholders( |
|||
policy.brain.camera_resolutions |
|||
) |
|||
self.vector_in = ModelUtils.create_vector_input(policy.vec_obs_size) |
|||
if self.policy.normalize: |
|||
normalization_tensors = ModelUtils.create_normalizer(self.vector_in) |
|||
self.update_normalization_op = normalization_tensors.update_op |
|||
self.normalization_steps = normalization_tensors.steps |
|||
self.running_mean = normalization_tensors.running_mean |
|||
self.running_variance = normalization_tensors.running_variance |
|||
self.processed_vector_in = ModelUtils.normalize_vector_obs( |
|||
self.vector_in, |
|||
self.running_mean, |
|||
self.running_variance, |
|||
self.normalization_steps, |
|||
) |
|||
else: |
|||
self.processed_vector_in = self.vector_in |
|||
self.update_normalization_op = None |
|||
|
|||
if self.policy.use_recurrent: |
|||
self.memory_in = tf.placeholder( |
|||
shape=[None, m_size], dtype=tf.float32, name="target_recurrent_in" |
|||
) |
|||
self.value_memory_in = self.memory_in |
|||
hidden_streams = ModelUtils.create_observation_streams( |
|||
self.visual_in, |
|||
self.processed_vector_in, |
|||
1, |
|||
self.h_size, |
|||
0, |
|||
vis_encode_type=vis_encode_type, |
|||
stream_scopes=["critic/value/"], |
|||
) |
|||
if self.policy.use_continuous_act: |
|||
self._create_cc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False) |
|||
else: |
|||
self._create_dc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False) |
|||
if self.use_recurrent: |
|||
self.memory_out = tf.concat( |
|||
self.value_memory_out, axis=1 |
|||
) # Needed for Barracuda to work |
|||
|
|||
def copy_normalization(self, mean, variance, steps): |
|||
""" |
|||
Copies the mean, variance, and steps into the normalizers of the |
|||
input of this SACNetwork. Used to copy the normalizer from the policy network |
|||
to the target network. |
|||
param mean: Tensor containing the mean. |
|||
param variance: Tensor containing the variance |
|||
param steps: Tensor containing the number of steps. |
|||
""" |
|||
update_mean = tf.assign(self.running_mean, mean) |
|||
update_variance = tf.assign(self.running_variance, variance) |
|||
update_norm_step = tf.assign(self.normalization_steps, steps) |
|||
return tf.group([update_mean, update_variance, update_norm_step]) |
|||
|
|||
|
|||
class SACPolicyNetwork(SACNetwork): |
|||
""" |
|||
Instantiation for SAC policy network. Contains a dual Q estimator, |
|||
a value estimator, and a reference to the actual policy network. |
|||
""" |
|||
|
|||
def __init__( |
|||
self, |
|||
policy, |
|||
m_size=None, |
|||
h_size=128, |
|||
normalize=False, |
|||
use_recurrent=False, |
|||
num_layers=2, |
|||
stream_names=None, |
|||
vis_encode_type=EncoderType.SIMPLE, |
|||
): |
|||
super().__init__( |
|||
policy, |
|||
m_size, |
|||
h_size, |
|||
normalize, |
|||
use_recurrent, |
|||
num_layers, |
|||
stream_names, |
|||
vis_encode_type, |
|||
) |
|||
if self.policy.use_recurrent: |
|||
self._create_memory_ins(m_size) |
|||
|
|||
hidden_critic = self._create_observation_in(vis_encode_type) |
|||
self.policy.output = self.policy.output |
|||
# Use the sequence length of the policy |
|||
self.sequence_length_ph = self.policy.sequence_length_ph |
|||
|
|||
if self.policy.use_continuous_act: |
|||
self._create_cc_critic(hidden_critic, POLICY_SCOPE) |
|||
|
|||
else: |
|||
self._create_dc_critic(hidden_critic, POLICY_SCOPE) |
|||
|
|||
if self.use_recurrent: |
|||
mem_outs = [self.value_memory_out, self.q1_memory_out, self.q2_memory_out] |
|||
self.memory_out = tf.concat(mem_outs, axis=1) |
|||
|
|||
def _create_memory_ins(self, m_size): |
|||
""" |
|||
Creates the memory input placeholders for LSTM. |
|||
:param m_size: the total size of the memory. |
|||
""" |
|||
self.memory_in = tf.placeholder( |
|||
shape=[None, m_size * 3], dtype=tf.float32, name="value_recurrent_in" |
|||
) |
|||
|
|||
# Re-break-up for each network |
|||
num_mems = 3 |
|||
input_size = self.memory_in.get_shape().as_list()[1] |
|||
mem_ins = [] |
|||
for i in range(num_mems): |
|||
_start = input_size // num_mems * i |
|||
_end = input_size // num_mems * (i + 1) |
|||
mem_ins.append(self.memory_in[:, _start:_end]) |
|||
self.value_memory_in = mem_ins[0] |
|||
self.q1_memory_in = mem_ins[1] |
|||
self.q2_memory_in = mem_ins[2] |
|||
|
|||
def _create_observation_in(self, vis_encode_type): |
|||
""" |
|||
Creates the observation inputs, and a CNN if needed, |
|||
:param vis_encode_type: Type of CNN encoder. |
|||
:param share_ac_cnn: Whether or not to share the actor and critic CNNs. |
|||
:return A tuple of (hidden_policy, hidden_critic). We don't save it to self since they're used |
|||
once and thrown away. |
|||
""" |
|||
with tf.variable_scope(POLICY_SCOPE): |
|||
hidden_streams = ModelUtils.create_observation_streams( |
|||
self.policy.visual_in, |
|||
self.policy.processed_vector_in, |
|||
1, |
|||
self.h_size, |
|||
0, |
|||
vis_encode_type=vis_encode_type, |
|||
stream_scopes=["critic/value/"], |
|||
) |
|||
hidden_critic = hidden_streams[0] |
|||
return hidden_critic |
|
|||
import numpy as np |
|||
from typing import Dict, List, Optional, Any, Mapping, cast |
|||
|
|||
from mlagents.tf_utils import tf |
|||
|
|||
from mlagents_envs.logging_util import get_logger |
|||
from mlagents.trainers.sac.network import SACPolicyNetwork, SACTargetNetwork |
|||
from mlagents.trainers.models import ModelUtils |
|||
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer |
|||
from mlagents.trainers.policy.tf_policy import TFPolicy |
|||
from mlagents.trainers.buffer import AgentBuffer |
|||
from mlagents_envs.timers import timed |
|||
from mlagents.trainers.settings import TrainerSettings, SACSettings |
|||
|
|||
EPSILON = 1e-6 # Small value to avoid divide by zero |
|||
|
|||
logger = get_logger(__name__) |
|||
|
|||
POLICY_SCOPE = "" |
|||
TARGET_SCOPE = "target_network" |
|||
|
|||
|
|||
class SACOptimizer(TFOptimizer): |
|||
def __init__(self, policy: TFPolicy, trainer_params: TrainerSettings): |
|||
""" |
|||
Takes a Unity environment and model-specific hyper-parameters and returns the |
|||
appropriate PPO agent model for the environment. |
|||
:param brain: Brain parameters used to generate specific network graph. |
|||
:param lr: Learning rate. |
|||
:param lr_schedule: Learning rate decay schedule. |
|||
:param h_size: Size of hidden layers |
|||
:param init_entcoef: Initial value for entropy coefficient. Set lower to learn faster, |
|||
set higher to explore more. |
|||
:return: a sub-class of PPOAgent tailored to the environment. |
|||
:param max_step: Total number of training steps. |
|||
:param normalize: Whether to normalize vector observation input. |
|||
:param use_recurrent: Whether to use an LSTM layer in the network. |
|||
:param num_layers: Number of hidden layers between encoded input and policy & value layers |
|||
:param tau: Strength of soft-Q update. |
|||
:param m_size: Size of brain memory. |
|||
""" |
|||
# Create the graph here to give more granular control of the TF graph to the Optimizer. |
|||
policy.create_tf_graph() |
|||
|
|||
with policy.graph.as_default(): |
|||
with tf.variable_scope(""): |
|||
super().__init__(policy, trainer_params) |
|||
hyperparameters: SACSettings = cast( |
|||
SACSettings, trainer_params.hyperparameters |
|||
) |
|||
lr = hyperparameters.learning_rate |
|||
lr_schedule = hyperparameters.learning_rate_schedule |
|||
max_step = trainer_params.max_steps |
|||
self.tau = hyperparameters.tau |
|||
self.init_entcoef = hyperparameters.init_entcoef |
|||
|
|||
self.policy = policy |
|||
self.act_size = policy.act_size |
|||
policy_network_settings = policy.network_settings |
|||
h_size = policy_network_settings.hidden_units |
|||
num_layers = policy_network_settings.num_layers |
|||
vis_encode_type = policy_network_settings.vis_encode_type |
|||
|
|||
self.tau = hyperparameters.tau |
|||
self.burn_in_ratio = 0.0 |
|||
|
|||
# Non-exposed SAC parameters |
|||
self.discrete_target_entropy_scale = ( |
|||
0.2 |
|||
) # Roughly equal to e-greedy 0.05 |
|||
self.continuous_target_entropy_scale = 1.0 |
|||
|
|||
stream_names = list(self.reward_signals.keys()) |
|||
# Use to reduce "survivor bonus" when using Curiosity or GAIL. |
|||
self.gammas = [ |
|||
_val.gamma for _val in trainer_params.reward_signals.values() |
|||
] |
|||
self.use_dones_in_backup = { |
|||
name: tf.Variable(1.0) for name in stream_names |
|||
} |
|||
self.disable_use_dones = { |
|||
name: self.use_dones_in_backup[name].assign(0.0) |
|||
for name in stream_names |
|||
} |
|||
|
|||
if num_layers < 1: |
|||
num_layers = 1 |
|||
|
|||
self.target_init_op: List[tf.Tensor] = [] |
|||
self.target_update_op: List[tf.Tensor] = [] |
|||
self.update_batch_policy: Optional[tf.Operation] = None |
|||
self.update_batch_value: Optional[tf.Operation] = None |
|||
self.update_batch_entropy: Optional[tf.Operation] = None |
|||
|
|||
self.policy_network = SACPolicyNetwork( |
|||
policy=self.policy, |
|||
m_size=self.policy.m_size, # 3x policy.m_size |
|||
h_size=h_size, |
|||
normalize=self.policy.normalize, |
|||
use_recurrent=self.policy.use_recurrent, |
|||
num_layers=num_layers, |
|||
stream_names=stream_names, |
|||
vis_encode_type=vis_encode_type, |
|||
) |
|||
self.target_network = SACTargetNetwork( |
|||
policy=self.policy, |
|||
m_size=self.policy.m_size, # 1x policy.m_size |
|||
h_size=h_size, |
|||
normalize=self.policy.normalize, |
|||
use_recurrent=self.policy.use_recurrent, |
|||
num_layers=num_layers, |
|||
stream_names=stream_names, |
|||
vis_encode_type=vis_encode_type, |
|||
) |
|||
# The optimizer's m_size is 3 times the policy (Q1, Q2, and Value) |
|||
self.m_size = 3 * self.policy.m_size |
|||
self._create_inputs_and_outputs() |
|||
self.learning_rate = ModelUtils.create_schedule( |
|||
lr_schedule, |
|||
lr, |
|||
self.policy.global_step, |
|||
int(max_step), |
|||
min_value=1e-10, |
|||
) |
|||
self._create_losses( |
|||
self.policy_network.q1_heads, |
|||
self.policy_network.q2_heads, |
|||
lr, |
|||
int(max_step), |
|||
stream_names, |
|||
discrete=not self.policy.use_continuous_act, |
|||
) |
|||
self._create_sac_optimizer_ops() |
|||
|
|||
self.selected_actions = ( |
|||
self.policy.selected_actions |
|||
) # For GAIL and other reward signals |
|||
if self.policy.normalize: |
|||
target_update_norm = self.target_network.copy_normalization( |
|||
self.policy.running_mean, |
|||
self.policy.running_variance, |
|||
self.policy.normalization_steps, |
|||
) |
|||
# Update the normalization of the optimizer when the policy does. |
|||
self.policy.update_normalization_op = tf.group( |
|||
[self.policy.update_normalization_op, target_update_norm] |
|||
) |
|||
|
|||
self.policy.initialize_or_load() |
|||
|
|||
self.stats_name_to_update_name = { |
|||
"Losses/Value Loss": "value_loss", |
|||
"Losses/Policy Loss": "policy_loss", |
|||
"Losses/Q1 Loss": "q1_loss", |
|||
"Losses/Q2 Loss": "q2_loss", |
|||
"Policy/Entropy Coeff": "entropy_coef", |
|||
"Policy/Learning Rate": "learning_rate", |
|||
} |
|||
|
|||
self.update_dict = { |
|||
"value_loss": self.total_value_loss, |
|||
"policy_loss": self.policy_loss, |
|||
"q1_loss": self.q1_loss, |
|||
"q2_loss": self.q2_loss, |
|||
"entropy_coef": self.ent_coef, |
|||
"update_batch": self.update_batch_policy, |
|||
"update_value": self.update_batch_value, |
|||
"update_entropy": self.update_batch_entropy, |
|||
"learning_rate": self.learning_rate, |
|||
} |
|||
|
|||
def _create_inputs_and_outputs(self) -> None: |
|||
""" |
|||
Assign the higher-level SACModel's inputs and outputs to those of its policy or |
|||
target network. |
|||
""" |
|||
self.vector_in = self.policy.vector_in |
|||
self.visual_in = self.policy.visual_in |
|||
self.next_vector_in = self.target_network.vector_in |
|||
self.next_visual_in = self.target_network.visual_in |
|||
self.sequence_length_ph = self.policy.sequence_length_ph |
|||
self.next_sequence_length_ph = self.target_network.sequence_length_ph |
|||
if not self.policy.use_continuous_act: |
|||
self.action_masks = self.policy_network.action_masks |
|||
else: |
|||
self.output_pre = self.policy_network.output_pre |
|||
|
|||
# Don't use value estimate during inference. |
|||
self.value = tf.identity( |
|||
self.policy_network.value, name="value_estimate_unused" |
|||
) |
|||
self.value_heads = self.policy_network.value_heads |
|||
self.dones_holder = tf.placeholder( |
|||
shape=[None], dtype=tf.float32, name="dones_holder" |
|||
) |
|||
|
|||
if self.policy.use_recurrent: |
|||
self.memory_in = self.policy_network.memory_in |
|||
self.memory_out = self.policy_network.memory_out |
|||
if not self.policy.use_continuous_act: |
|||
self.prev_action = self.policy_network.prev_action |
|||
self.next_memory_in = self.target_network.memory_in |
|||
|
|||
def _create_losses( |
|||
self, |
|||
q1_streams: Dict[str, tf.Tensor], |
|||
q2_streams: Dict[str, tf.Tensor], |
|||
lr: tf.Tensor, |
|||
max_step: int, |
|||
stream_names: List[str], |
|||
discrete: bool = False, |
|||
) -> None: |
|||
""" |
|||
Creates training-specific Tensorflow ops for SAC models. |
|||
:param q1_streams: Q1 streams from policy network |
|||
:param q1_streams: Q2 streams from policy network |
|||
:param lr: Learning rate |
|||
:param max_step: Total number of training steps. |
|||
:param stream_names: List of reward stream names. |
|||
:param discrete: Whether or not to use discrete action losses. |
|||
""" |
|||
|
|||
if discrete: |
|||
self.target_entropy = [ |
|||
self.discrete_target_entropy_scale * np.log(i).astype(np.float32) |
|||
for i in self.act_size |
|||
] |
|||
discrete_action_probs = tf.exp(self.policy.all_log_probs) |
|||
per_action_entropy = discrete_action_probs * self.policy.all_log_probs |
|||
else: |
|||
self.target_entropy = ( |
|||
-1 |
|||
* self.continuous_target_entropy_scale |
|||
* np.prod(self.act_size[0]).astype(np.float32) |
|||
) |
|||
|
|||
self.rewards_holders = {} |
|||
self.min_policy_qs = {} |
|||
|
|||
for name in stream_names: |
|||
if discrete: |
|||
_branched_mpq1 = ModelUtils.break_into_branches( |
|||
self.policy_network.q1_pheads[name] * discrete_action_probs, |
|||
self.act_size, |
|||
) |
|||
branched_mpq1 = tf.stack( |
|||
[ |
|||
tf.reduce_sum(_br, axis=1, keep_dims=True) |
|||
for _br in _branched_mpq1 |
|||
] |
|||
) |
|||
_q1_p_mean = tf.reduce_mean(branched_mpq1, axis=0) |
|||
|
|||
_branched_mpq2 = ModelUtils.break_into_branches( |
|||
self.policy_network.q2_pheads[name] * discrete_action_probs, |
|||
self.act_size, |
|||
) |
|||
branched_mpq2 = tf.stack( |
|||
[ |
|||
tf.reduce_sum(_br, axis=1, keep_dims=True) |
|||
for _br in _branched_mpq2 |
|||
] |
|||
) |
|||
_q2_p_mean = tf.reduce_mean(branched_mpq2, axis=0) |
|||
|
|||
self.min_policy_qs[name] = tf.minimum(_q1_p_mean, _q2_p_mean) |
|||
else: |
|||
self.min_policy_qs[name] = tf.minimum( |
|||
self.policy_network.q1_pheads[name], |
|||
self.policy_network.q2_pheads[name], |
|||
) |
|||
|
|||
rewards_holder = tf.placeholder( |
|||
shape=[None], dtype=tf.float32, name="{}_rewards".format(name) |
|||
) |
|||
self.rewards_holders[name] = rewards_holder |
|||
|
|||
q1_losses = [] |
|||
q2_losses = [] |
|||
# Multiple q losses per stream |
|||
expanded_dones = tf.expand_dims(self.dones_holder, axis=-1) |
|||
for i, name in enumerate(stream_names): |
|||
_expanded_rewards = tf.expand_dims(self.rewards_holders[name], axis=-1) |
|||
|
|||
q_backup = tf.stop_gradient( |
|||
_expanded_rewards |
|||
+ (1.0 - self.use_dones_in_backup[name] * expanded_dones) |
|||
* self.gammas[i] |
|||
* self.target_network.value_heads[name] |
|||
) |
|||
|
|||
if discrete: |
|||
# We need to break up the Q functions by branch, and update them individually. |
|||
branched_q1_stream = ModelUtils.break_into_branches( |
|||
self.policy.selected_actions * q1_streams[name], self.act_size |
|||
) |
|||
branched_q2_stream = ModelUtils.break_into_branches( |
|||
self.policy.selected_actions * q2_streams[name], self.act_size |
|||
) |
|||
|
|||
# Reduce each branch into scalar |
|||
branched_q1_stream = [ |
|||
tf.reduce_sum(_branch, axis=1, keep_dims=True) |
|||
for _branch in branched_q1_stream |
|||
] |
|||
branched_q2_stream = [ |
|||
tf.reduce_sum(_branch, axis=1, keep_dims=True) |
|||
for _branch in branched_q2_stream |
|||
] |
|||
|
|||
q1_stream = tf.reduce_mean(branched_q1_stream, axis=0) |
|||
q2_stream = tf.reduce_mean(branched_q2_stream, axis=0) |
|||
|
|||
else: |
|||
q1_stream = q1_streams[name] |
|||
q2_stream = q2_streams[name] |
|||
|
|||
_q1_loss = 0.5 * tf.reduce_mean( |
|||
tf.to_float(self.policy.mask) |
|||
* tf.squared_difference(q_backup, q1_stream) |
|||
) |
|||
|
|||
_q2_loss = 0.5 * tf.reduce_mean( |
|||
tf.to_float(self.policy.mask) |
|||
* tf.squared_difference(q_backup, q2_stream) |
|||
) |
|||
|
|||
q1_losses.append(_q1_loss) |
|||
q2_losses.append(_q2_loss) |
|||
|
|||
self.q1_loss = tf.reduce_mean(q1_losses) |
|||
self.q2_loss = tf.reduce_mean(q2_losses) |
|||
|
|||
# Learn entropy coefficient |
|||
if discrete: |
|||
# Create a log_ent_coef for each branch |
|||
self.log_ent_coef = tf.get_variable( |
|||
"log_ent_coef", |
|||
dtype=tf.float32, |
|||
initializer=np.log([self.init_entcoef] * len(self.act_size)).astype( |
|||
np.float32 |
|||
), |
|||
trainable=True, |
|||
) |
|||
else: |
|||
self.log_ent_coef = tf.get_variable( |
|||
"log_ent_coef", |
|||
dtype=tf.float32, |
|||
initializer=np.log(self.init_entcoef).astype(np.float32), |
|||
trainable=True, |
|||
) |
|||
|
|||
self.ent_coef = tf.exp(self.log_ent_coef) |
|||
if discrete: |
|||
# We also have to do a different entropy and target_entropy per branch. |
|||
branched_per_action_ent = ModelUtils.break_into_branches( |
|||
per_action_entropy, self.act_size |
|||
) |
|||
branched_ent_sums = tf.stack( |
|||
[ |
|||
tf.reduce_sum(_lp, axis=1, keep_dims=True) + _te |
|||
for _lp, _te in zip(branched_per_action_ent, self.target_entropy) |
|||
], |
|||
axis=1, |
|||
) |
|||
self.entropy_loss = -tf.reduce_mean( |
|||
tf.to_float(self.policy.mask) |
|||
* tf.reduce_mean( |
|||
self.log_ent_coef |
|||
* tf.squeeze(tf.stop_gradient(branched_ent_sums), axis=2), |
|||
axis=1, |
|||
) |
|||
) |
|||
|
|||
# Same with policy loss, we have to do the loss per branch and average them, |
|||
# so that larger branches don't get more weight. |
|||
# The equivalent KL divergence from Eq 10 of Haarnoja et al. is also pi*log(pi) - Q |
|||
branched_q_term = ModelUtils.break_into_branches( |
|||
discrete_action_probs * self.policy_network.q1_p, self.act_size |
|||
) |
|||
|
|||
branched_policy_loss = tf.stack( |
|||
[ |
|||
tf.reduce_sum(self.ent_coef[i] * _lp - _qt, axis=1, keep_dims=True) |
|||
for i, (_lp, _qt) in enumerate( |
|||
zip(branched_per_action_ent, branched_q_term) |
|||
) |
|||
] |
|||
) |
|||
self.policy_loss = tf.reduce_mean( |
|||
tf.to_float(self.policy.mask) * tf.squeeze(branched_policy_loss) |
|||
) |
|||
|
|||
# Do vbackup entropy bonus per branch as well. |
|||
branched_ent_bonus = tf.stack( |
|||
[ |
|||
tf.reduce_sum(self.ent_coef[i] * _lp, axis=1, keep_dims=True) |
|||
for i, _lp in enumerate(branched_per_action_ent) |
|||
] |
|||
) |
|||
value_losses = [] |
|||
for name in stream_names: |
|||
v_backup = tf.stop_gradient( |
|||
self.min_policy_qs[name] |
|||
- tf.reduce_mean(branched_ent_bonus, axis=0) |
|||
) |
|||
value_losses.append( |
|||
0.5 |
|||
* tf.reduce_mean( |
|||
tf.to_float(self.policy.mask) |
|||
* tf.squared_difference( |
|||
self.policy_network.value_heads[name], v_backup |
|||
) |
|||
) |
|||
) |
|||
|
|||
else: |
|||
self.entropy_loss = -tf.reduce_mean( |
|||
self.log_ent_coef |
|||
* tf.to_float(self.policy.mask) |
|||
* tf.stop_gradient( |
|||
tf.reduce_sum( |
|||
self.policy.all_log_probs + self.target_entropy, |
|||
axis=1, |
|||
keep_dims=True, |
|||
) |
|||
) |
|||
) |
|||
batch_policy_loss = tf.reduce_mean( |
|||
self.ent_coef * self.policy.all_log_probs - self.policy_network.q1_p, |
|||
axis=1, |
|||
) |
|||
self.policy_loss = tf.reduce_mean( |
|||
tf.to_float(self.policy.mask) * batch_policy_loss |
|||
) |
|||
|
|||
value_losses = [] |
|||
for name in stream_names: |
|||
v_backup = tf.stop_gradient( |
|||
self.min_policy_qs[name] |
|||
- tf.reduce_sum(self.ent_coef * self.policy.all_log_probs, axis=1) |
|||
) |
|||
value_losses.append( |
|||
0.5 |
|||
* tf.reduce_mean( |
|||
tf.to_float(self.policy.mask) |
|||
* tf.squared_difference( |
|||
self.policy_network.value_heads[name], v_backup |
|||
) |
|||
) |
|||
) |
|||
self.value_loss = tf.reduce_mean(value_losses) |
|||
|
|||
self.total_value_loss = self.q1_loss + self.q2_loss + self.value_loss |
|||
|
|||
self.entropy = self.policy_network.entropy |
|||
|
|||
def _create_sac_optimizer_ops(self) -> None: |
|||
""" |
|||
Creates the Adam optimizers and update ops for SAC, including |
|||
the policy, value, and entropy updates, as well as the target network update. |
|||
""" |
|||
policy_optimizer = self.create_optimizer_op( |
|||
learning_rate=self.learning_rate, name="sac_policy_opt" |
|||
) |
|||
entropy_optimizer = self.create_optimizer_op( |
|||
learning_rate=self.learning_rate, name="sac_entropy_opt" |
|||
) |
|||
value_optimizer = self.create_optimizer_op( |
|||
learning_rate=self.learning_rate, name="sac_value_opt" |
|||
) |
|||
|
|||
self.target_update_op = [ |
|||
tf.assign(target, (1 - self.tau) * target + self.tau * source) |
|||
for target, source in zip( |
|||
self.target_network.value_vars, self.policy_network.value_vars |
|||
) |
|||
] |
|||
logger.debug("value_vars") |
|||
self.print_all_vars(self.policy_network.value_vars) |
|||
logger.debug("targvalue_vars") |
|||
self.print_all_vars(self.target_network.value_vars) |
|||
logger.debug("critic_vars") |
|||
self.print_all_vars(self.policy_network.critic_vars) |
|||
logger.debug("q_vars") |
|||
self.print_all_vars(self.policy_network.q_vars) |
|||
logger.debug("policy_vars") |
|||
policy_vars = self.policy.get_trainable_variables() |
|||
self.print_all_vars(policy_vars) |
|||
|
|||
self.target_init_op = [ |
|||
tf.assign(target, source) |
|||
for target, source in zip( |
|||
self.target_network.value_vars, self.policy_network.value_vars |
|||
) |
|||
] |
|||
|
|||
self.update_batch_policy = policy_optimizer.minimize( |
|||
self.policy_loss, var_list=policy_vars |
|||
) |
|||
|
|||
# Make sure policy is updated first, then value, then entropy. |
|||
with tf.control_dependencies([self.update_batch_policy]): |
|||
self.update_batch_value = value_optimizer.minimize( |
|||
self.total_value_loss, var_list=self.policy_network.critic_vars |
|||
) |
|||
# Add entropy coefficient optimization operation |
|||
with tf.control_dependencies([self.update_batch_value]): |
|||
self.update_batch_entropy = entropy_optimizer.minimize( |
|||
self.entropy_loss, var_list=self.log_ent_coef |
|||
) |
|||
|
|||
def print_all_vars(self, variables): |
|||
for _var in variables: |
|||
logger.debug(_var) |
|||
|
|||
@timed |
|||
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: |
|||
""" |
|||
Updates model using buffer. |
|||
:param num_sequences: Number of trajectories in batch. |
|||
:param batch: Experience mini-batch. |
|||
:param update_target: Whether or not to update target value network |
|||
:param reward_signal_batches: Minibatches to use for updating the reward signals, |
|||
indexed by name. If none, don't update the reward signals. |
|||
:return: Output from update process. |
|||
""" |
|||
feed_dict = self._construct_feed_dict(self.policy, batch, num_sequences) |
|||
stats_needed = self.stats_name_to_update_name |
|||
update_stats: Dict[str, float] = {} |
|||
update_vals = self._execute_model(feed_dict, self.update_dict) |
|||
for stat_name, update_name in stats_needed.items(): |
|||
update_stats[stat_name] = update_vals[update_name] |
|||
# Update target network. By default, target update happens at every policy update. |
|||
self.sess.run(self.target_update_op) |
|||
return update_stats |
|||
|
|||
def update_reward_signals( |
|||
self, reward_signal_minibatches: Mapping[str, AgentBuffer], num_sequences: int |
|||
) -> Dict[str, float]: |
|||
""" |
|||
Only update the reward signals. |
|||
:param reward_signal_batches: Minibatches to use for updating the reward signals, |
|||
indexed by name. If none, don't update the reward signals. |
|||
""" |
|||
# Collect feed dicts for all reward signals. |
|||
feed_dict: Dict[tf.Tensor, Any] = {} |
|||
update_dict: Dict[str, tf.Tensor] = {} |
|||
update_stats: Dict[str, float] = {} |
|||
stats_needed: Dict[str, str] = {} |
|||
if reward_signal_minibatches: |
|||
self.add_reward_signal_dicts( |
|||
feed_dict, |
|||
update_dict, |
|||
stats_needed, |
|||
reward_signal_minibatches, |
|||
num_sequences, |
|||
) |
|||
update_vals = self._execute_model(feed_dict, update_dict) |
|||
for stat_name, update_name in stats_needed.items(): |
|||
update_stats[stat_name] = update_vals[update_name] |
|||
return update_stats |
|||
|
|||
def add_reward_signal_dicts( |
|||
self, |
|||
feed_dict: Dict[tf.Tensor, Any], |
|||
update_dict: Dict[str, tf.Tensor], |
|||
stats_needed: Dict[str, str], |
|||
reward_signal_minibatches: Mapping[str, AgentBuffer], |
|||
num_sequences: int, |
|||
) -> None: |
|||
""" |
|||
Adds the items needed for reward signal updates to the feed_dict and stats_needed dict. |
|||
:param feed_dict: Feed dict needed update |
|||
:param update_dit: Update dict that needs update |
|||
:param stats_needed: Stats needed to get from the update. |
|||
:param reward_signal_minibatches: Minibatches to use for updating the reward signals, |
|||
indexed by name. |
|||
""" |
|||
for name, r_batch in reward_signal_minibatches.items(): |
|||
feed_dict.update( |
|||
self.reward_signals[name].prepare_update( |
|||
self.policy, r_batch, num_sequences |
|||
) |
|||
) |
|||
update_dict.update(self.reward_signals[name].update_dict) |
|||
stats_needed.update(self.reward_signals[name].stats_name_to_update_name) |
|||
|
|||
def _construct_feed_dict( |
|||
self, policy: TFPolicy, batch: AgentBuffer, num_sequences: int |
|||
) -> Dict[tf.Tensor, Any]: |
|||
""" |
|||
Builds the feed dict for updating the SAC model. |
|||
:param model: The model to update. May be different when, e.g. using multi-GPU. |
|||
:param batch: Mini-batch to use to update. |
|||
:param num_sequences: Number of LSTM sequences in batch. |
|||
""" |
|||
# Do an optional burn-in for memories |
|||
num_burn_in = int(self.burn_in_ratio * self.policy.sequence_length) |
|||
burn_in_mask = np.ones((self.policy.sequence_length), dtype=np.float32) |
|||
burn_in_mask[range(0, num_burn_in)] = 0 |
|||
burn_in_mask = np.tile(burn_in_mask, num_sequences) |
|||
feed_dict = { |
|||
policy.batch_size_ph: num_sequences, |
|||
policy.sequence_length_ph: self.policy.sequence_length, |
|||
self.next_sequence_length_ph: self.policy.sequence_length, |
|||
self.policy.mask_input: batch["masks"] * burn_in_mask, |
|||
} |
|||
for name in self.reward_signals: |
|||
feed_dict[self.rewards_holders[name]] = batch["{}_rewards".format(name)] |
|||
|
|||
if self.policy.use_continuous_act: |
|||
feed_dict[self.policy_network.external_action_in] = batch["actions"] |
|||
else: |
|||
feed_dict[policy.output] = batch["actions"] |
|||
if self.policy.use_recurrent: |
|||
feed_dict[policy.prev_action] = batch["prev_action"] |
|||
feed_dict[policy.action_masks] = batch["action_mask"] |
|||
if self.policy.use_vec_obs: |
|||
feed_dict[policy.vector_in] = batch["vector_obs"] |
|||
feed_dict[self.next_vector_in] = batch["next_vector_in"] |
|||
if self.policy.vis_obs_size > 0: |
|||
for i, _ in enumerate(policy.visual_in): |
|||
_obs = batch["visual_obs%d" % i] |
|||
feed_dict[policy.visual_in[i]] = _obs |
|||
for i, _ in enumerate(self.next_visual_in): |
|||
_obs = batch["next_visual_obs%d" % i] |
|||
feed_dict[self.next_visual_in[i]] = _obs |
|||
if self.policy.use_recurrent: |
|||
feed_dict[policy.memory_in] = [ |
|||
batch["memory"][i] |
|||
for i in range(0, len(batch["memory"]), self.policy.sequence_length) |
|||
] |
|||
feed_dict[self.policy_network.memory_in] = self._make_zero_mem( |
|||
self.m_size, batch.num_experiences |
|||
) |
|||
feed_dict[self.target_network.memory_in] = self._make_zero_mem( |
|||
self.m_size // 3, batch.num_experiences |
|||
) |
|||
feed_dict[self.dones_holder] = batch["done"] |
|||
return feed_dict |
|
|||
# ## ML-Agent Learning (SAC) |
|||
# Contains an implementation of SAC as described in https://arxiv.org/abs/1801.01290 |
|||
# and implemented in https://github.com/hill-a/stable-baselines |
|||
|
|||
from collections import defaultdict |
|||
from typing import Dict, cast |
|||
import os |
|||
|
|||
import numpy as np |
|||
|
|||
|
|||
from mlagents_envs.logging_util import get_logger |
|||
from mlagents_envs.timers import timed |
|||
from mlagents.trainers.policy.tf_policy import TFPolicy |
|||
from mlagents.trainers.policy.nn_policy import NNPolicy |
|||
from mlagents.trainers.sac.optimizer import SACOptimizer |
|||
from mlagents.trainers.trainer.rl_trainer import RLTrainer |
|||
from mlagents.trainers.trajectory import Trajectory, SplitObservations |
|||
from mlagents.trainers.brain import BrainParameters |
|||
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers |
|||
from mlagents.trainers.settings import TrainerSettings, SACSettings |
|||
|
|||
|
|||
logger = get_logger(__name__) |
|||
|
|||
BUFFER_TRUNCATE_PERCENT = 0.8 |
|||
|
|||
|
|||
class SACTrainer(RLTrainer): |
|||
""" |
|||
The SACTrainer is an implementation of the SAC algorithm, with support |
|||
for discrete actions and recurrent networks. |
|||
""" |
|||
|
|||
def __init__( |
|||
self, |
|||
brain_name: str, |
|||
reward_buff_cap: int, |
|||
trainer_settings: TrainerSettings, |
|||
training: bool, |
|||
load: bool, |
|||
seed: int, |
|||
artifact_path: str, |
|||
): |
|||
""" |
|||
Responsible for collecting experiences and training SAC model. |
|||
:param brain_name: The name of the brain associated with trainer config |
|||
:param reward_buff_cap: Max reward history to track in the reward buffer |
|||
:param trainer_settings: The parameters for the trainer. |
|||
:param training: Whether the trainer is set for training. |
|||
:param load: Whether the model should be loaded. |
|||
:param seed: The seed the model will be initialized with |
|||
:param artifact_path: The directory within which to store artifacts from this trainer. |
|||
""" |
|||
super().__init__( |
|||
brain_name, trainer_settings, training, artifact_path, reward_buff_cap |
|||
) |
|||
|
|||
self.load = load |
|||
self.seed = seed |
|||
self.policy: NNPolicy = None # type: ignore |
|||
self.optimizer: SACOptimizer = None # type: ignore |
|||
self.hyperparameters: SACSettings = cast( |
|||
SACSettings, trainer_settings.hyperparameters |
|||
) |
|||
self.step = 0 |
|||
|
|||
# Don't divide by zero |
|||
self.update_steps = 1 |
|||
self.reward_signal_update_steps = 1 |
|||
|
|||
self.steps_per_update = self.hyperparameters.steps_per_update |
|||
self.reward_signal_steps_per_update = ( |
|||
self.hyperparameters.reward_signal_steps_per_update |
|||
) |
|||
|
|||
self.checkpoint_replay_buffer = self.hyperparameters.save_replay_buffer |
|||
|
|||
def save_model(self, name_behavior_id: str) -> None: |
|||
""" |
|||
Saves the model. Overrides the default save_model since we want to save |
|||
the replay buffer as well. |
|||
""" |
|||
self.policy.save_model(self.get_step) |
|||
if self.checkpoint_replay_buffer: |
|||
self.save_replay_buffer() |
|||
|
|||
def save_replay_buffer(self) -> None: |
|||
""" |
|||
Save the training buffer's update buffer to a pickle file. |
|||
""" |
|||
filename = os.path.join(self.artifact_path, "last_replay_buffer.hdf5") |
|||
logger.info("Saving Experience Replay Buffer to {}".format(filename)) |
|||
with open(filename, "wb") as file_object: |
|||
self.update_buffer.save_to_file(file_object) |
|||
|
|||
def load_replay_buffer(self) -> None: |
|||
""" |
|||
Loads the last saved replay buffer from a file. |
|||
""" |
|||
filename = os.path.join(self.artifact_path, "last_replay_buffer.hdf5") |
|||
logger.info("Loading Experience Replay Buffer from {}".format(filename)) |
|||
with open(filename, "rb+") as file_object: |
|||
self.update_buffer.load_from_file(file_object) |
|||
logger.info( |
|||
"Experience replay buffer has {} experiences.".format( |
|||
self.update_buffer.num_experiences |
|||
) |
|||
) |
|||
|
|||
def _process_trajectory(self, trajectory: Trajectory) -> None: |
|||
""" |
|||
Takes a trajectory and processes it, putting it into the replay buffer. |
|||
""" |
|||
super()._process_trajectory(trajectory) |
|||
last_step = trajectory.steps[-1] |
|||
agent_id = trajectory.agent_id # All the agents should have the same ID |
|||
|
|||
agent_buffer_trajectory = trajectory.to_agentbuffer() |
|||
|
|||
# Update the normalization |
|||
if self.is_training: |
|||
self.policy.update_normalization(agent_buffer_trajectory["vector_obs"]) |
|||
|
|||
# Evaluate all reward functions for reporting purposes |
|||
self.collected_rewards["environment"][agent_id] += np.sum( |
|||
agent_buffer_trajectory["environment_rewards"] |
|||
) |
|||
for name, reward_signal in self.optimizer.reward_signals.items(): |
|||
evaluate_result = reward_signal.evaluate_batch( |
|||
agent_buffer_trajectory |
|||
).scaled_reward |
|||
# Report the reward signals |
|||
self.collected_rewards[name][agent_id] += np.sum(evaluate_result) |
|||
|
|||
# Get all value estimates for reporting purposes |
|||
value_estimates, _ = self.optimizer.get_trajectory_value_estimates( |
|||
agent_buffer_trajectory, trajectory.next_obs, trajectory.done_reached |
|||
) |
|||
for name, v in value_estimates.items(): |
|||
self._stats_reporter.add_stat( |
|||
self.optimizer.reward_signals[name].value_name, np.mean(v) |
|||
) |
|||
|
|||
# Bootstrap using the last step rather than the bootstrap step if max step is reached. |
|||
# Set last element to duplicate obs and remove dones. |
|||
if last_step.interrupted: |
|||
vec_vis_obs = SplitObservations.from_observations(last_step.obs) |
|||
for i, obs in enumerate(vec_vis_obs.visual_observations): |
|||
agent_buffer_trajectory["next_visual_obs%d" % i][-1] = obs |
|||
if vec_vis_obs.vector_observations.size > 1: |
|||
agent_buffer_trajectory["next_vector_in"][ |
|||
-1 |
|||
] = vec_vis_obs.vector_observations |
|||
agent_buffer_trajectory["done"][-1] = False |
|||
|
|||
# Append to update buffer |
|||
agent_buffer_trajectory.resequence_and_append( |
|||
self.update_buffer, training_length=self.policy.sequence_length |
|||
) |
|||
|
|||
if trajectory.done_reached: |
|||
self._update_end_episode_stats(agent_id, self.optimizer) |
|||
|
|||
def _is_ready_update(self) -> bool: |
|||
""" |
|||
Returns whether or not the trainer has enough elements to run update model |
|||
:return: A boolean corresponding to whether or not _update_policy() can be run |
|||
""" |
|||
return ( |
|||
self.update_buffer.num_experiences >= self.hyperparameters.batch_size |
|||
and self.step >= self.hyperparameters.buffer_init_steps |
|||
) |
|||
|
|||
@timed |
|||
def _update_policy(self) -> bool: |
|||
""" |
|||
Update the SAC policy and reward signals. The reward signal generators are updated using different mini batches. |
|||
By default we imitate http://arxiv.org/abs/1809.02925 and similar papers, where the policy is updated |
|||
N times, then the reward signals are updated N times. |
|||
:return: Whether or not the policy was updated. |
|||
""" |
|||
policy_was_updated = self._update_sac_policy() |
|||
self._update_reward_signals() |
|||
return policy_was_updated |
|||
|
|||
def create_policy( |
|||
self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters |
|||
) -> TFPolicy: |
|||
policy = NNPolicy( |
|||
self.seed, |
|||
brain_parameters, |
|||
self.trainer_settings, |
|||
self.is_training, |
|||
self.artifact_path, |
|||
self.load, |
|||
tanh_squash=True, |
|||
reparameterize=True, |
|||
create_tf_graph=False, |
|||
) |
|||
# Load the replay buffer if load |
|||
if self.load and self.checkpoint_replay_buffer: |
|||
try: |
|||
self.load_replay_buffer() |
|||
except (AttributeError, FileNotFoundError): |
|||
logger.warning( |
|||
"Replay buffer was unable to load, starting from scratch." |
|||
) |
|||
logger.debug( |
|||
"Loaded update buffer with {} sequences".format( |
|||
self.update_buffer.num_experiences |
|||
) |
|||
) |
|||
|
|||
return policy |
|||
|
|||
def _update_sac_policy(self) -> bool: |
|||
""" |
|||
Uses update_buffer to update the policy. We sample the update_buffer and update |
|||
until the steps_per_update ratio is met. |
|||
""" |
|||
has_updated = False |
|||
self.cumulative_returns_since_policy_update.clear() |
|||
n_sequences = max( |
|||
int(self.hyperparameters.batch_size / self.policy.sequence_length), 1 |
|||
) |
|||
|
|||
batch_update_stats: Dict[str, list] = defaultdict(list) |
|||
while ( |
|||
self.step - self.hyperparameters.buffer_init_steps |
|||
) / self.update_steps > self.steps_per_update: |
|||
logger.debug("Updating SAC policy at step {}".format(self.step)) |
|||
buffer = self.update_buffer |
|||
if self.update_buffer.num_experiences >= self.hyperparameters.batch_size: |
|||
sampled_minibatch = buffer.sample_mini_batch( |
|||
self.hyperparameters.batch_size, |
|||
sequence_length=self.policy.sequence_length, |
|||
) |
|||
# Get rewards for each reward |
|||
for name, signal in self.optimizer.reward_signals.items(): |
|||
sampled_minibatch[ |
|||
"{}_rewards".format(name) |
|||
] = signal.evaluate_batch(sampled_minibatch).scaled_reward |
|||
|
|||
update_stats = self.optimizer.update(sampled_minibatch, n_sequences) |
|||
for stat_name, value in update_stats.items(): |
|||
batch_update_stats[stat_name].append(value) |
|||
|
|||
self.update_steps += 1 |
|||
|
|||
for stat, stat_list in batch_update_stats.items(): |
|||
self._stats_reporter.add_stat(stat, np.mean(stat_list)) |
|||
has_updated = True |
|||
|
|||
if self.optimizer.bc_module: |
|||
update_stats = self.optimizer.bc_module.update() |
|||
for stat, val in update_stats.items(): |
|||
self._stats_reporter.add_stat(stat, val) |
|||
|
|||
# Truncate update buffer if neccessary. Truncate more than we need to to avoid truncating |
|||
# a large buffer at each update. |
|||
if self.update_buffer.num_experiences > self.hyperparameters.buffer_size: |
|||
self.update_buffer.truncate( |
|||
int(self.hyperparameters.buffer_size * BUFFER_TRUNCATE_PERCENT) |
|||
) |
|||
return has_updated |
|||
|
|||
def _update_reward_signals(self) -> None: |
|||
""" |
|||
Iterate through the reward signals and update them. Unlike in PPO, |
|||
do it separate from the policy so that it can be done at a different |
|||
interval. |
|||
This function should only be used to simulate |
|||
http://arxiv.org/abs/1809.02925 and similar papers, where the policy is updated |
|||
N times, then the reward signals are updated N times. Normally, the reward signal |
|||
and policy are updated in parallel. |
|||
""" |
|||
buffer = self.update_buffer |
|||
n_sequences = max( |
|||
int(self.hyperparameters.batch_size / self.policy.sequence_length), 1 |
|||
) |
|||
batch_update_stats: Dict[str, list] = defaultdict(list) |
|||
while ( |
|||
self.step - self.hyperparameters.buffer_init_steps |
|||
) / self.reward_signal_update_steps > self.reward_signal_steps_per_update: |
|||
# Get minibatches for reward signal update if needed |
|||
reward_signal_minibatches = {} |
|||
for name, signal in self.optimizer.reward_signals.items(): |
|||
logger.debug("Updating {} at step {}".format(name, self.step)) |
|||
# Some signals don't need a minibatch to be sampled - so we don't! |
|||
if signal.update_dict: |
|||
reward_signal_minibatches[name] = buffer.sample_mini_batch( |
|||
self.hyperparameters.batch_size, |
|||
sequence_length=self.policy.sequence_length, |
|||
) |
|||
update_stats = self.optimizer.update_reward_signals( |
|||
reward_signal_minibatches, n_sequences |
|||
) |
|||
for stat_name, value in update_stats.items(): |
|||
batch_update_stats[stat_name].append(value) |
|||
self.reward_signal_update_steps += 1 |
|||
|
|||
for stat, stat_list in batch_update_stats.items(): |
|||
self._stats_reporter.add_stat(stat, np.mean(stat_list)) |
|||
|
|||
def add_policy( |
|||
self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy |
|||
) -> None: |
|||
""" |
|||
Adds policy to trainer. |
|||
:param brain_parameters: specifications for policy construction |
|||
""" |
|||
if self.policy: |
|||
logger.warning( |
|||
"Your environment contains multiple teams, but {} doesn't support adversarial games. Enable self-play to \ |
|||
train adversarial games.".format( |
|||
self.__class__.__name__ |
|||
) |
|||
) |
|||
if not isinstance(policy, NNPolicy): |
|||
raise RuntimeError("Non-SACPolicy passed to SACTrainer.add_policy()") |
|||
self.policy = policy |
|||
self.optimizer = SACOptimizer(self.policy, self.trainer_settings) |
|||
for _reward_signal in self.optimizer.reward_signals.keys(): |
|||
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0) |
|||
# Needed to resume loads properly |
|||
self.step = policy.get_current_step() |
|||
# Assume steps were updated at the correct ratio before |
|||
self.update_steps = int(max(1, self.step / self.steps_per_update)) |
|||
self.reward_signal_update_steps = int( |
|||
max(1, self.step / self.reward_signal_steps_per_update) |
|||
) |
|||
|
|||
def get_policy(self, name_behavior_id: str) -> TFPolicy: |
|||
""" |
|||
Gets policy from trainer associated with name_behavior_id |
|||
:param name_behavior_id: full identifier of policy |
|||
""" |
|||
|
|||
return self.policy |
撰写
预览
正在加载...
取消
保存
Reference in new issue