浏览代码

init sac transfer, and added action encoder to bisim; configs for crawler

/develop/bisim-sac-transfer
yanchaosun 4 年前
当前提交
80bad241
共有 9 个文件被更改,包括 1462 次插入14 次删除
  1. 7
      config/ppo_transfer/CrawlerStatic.yaml
  2. 22
      config/ppo_transfer/TransferCrawlerStatic.yaml
  3. 15
      ml-agents/mlagents/trainers/policy/transfer_policy.py
  4. 6
      ml-agents/mlagents/trainers/tests/transfer_test_envs.py
  5. 0
      ml-agents/mlagents/trainers/sac_transfer/__init__.py
  6. 445
      ml-agents/mlagents/trainers/sac_transfer/network.py
  7. 641
      ml-agents/mlagents/trainers/sac_transfer/optimizer.py
  8. 340
      ml-agents/mlagents/trainers/sac_transfer/trainer.py

7
config/ppo_transfer/CrawlerStatic.yaml


lambd: 0.95
num_epoch: 3
learning_rate_schedule: constant
model_schedule: constant
value_layers: 2
value_layers: 3
action_feature_size: 128
in_batch_alter: true
use_bisim: false
use_bisim: true
separate_value_net: true
network_settings:
normalize: true

22
config/ppo_transfer/TransferCrawlerStatic.yaml


epsilon: 0.2
lambd: 0.95
num_epoch: 3
learning_rate_schedule: linear
encoder_layers: 2
policy_layers: 2
learning_rate_schedule: constant
model_schedule: constant
encoder_layers: 3
policy_layers: 0
forward_layers: 0
in_epoch_alter: true
use_op_buffer: true
action_feature_size: 128
reuse_encoder: true
in_epoch_alter: false
in_batch_alter: true
use_op_buffer: false
with_prior: true
with_prior: false
predict_return: true
use_bisim: true
separate_value_net: true

load_action: true
load_model: true
train_action: true
transfer_path: "results/cs-old-bisim/CrawlerStatic"
transfer_path: "results/csold-bisim/CrawlerStatic"
network_settings:
normalize: true
hidden_units: 512

15
ml-agents/mlagents/trainers/policy/transfer_policy.py


self.h_size,
self.feature_size,
encoder_layers,
action_layers,
self.vis_encode_type,
forward_layers,
var_predict,

h_size: int,
action_feature_size: int,
num_layers: int,
reuse: bool=False
) -> tf.Tensor:
hidden_stream = ModelUtils.create_vector_observation_encoder(

num_layers,
scope="action_enc",
reuse=False
reuse=reuse
)
with tf.variable_scope("action_enc"):

name="latent",
activation=tf.tanh,
kernel_initializer=tf.initializers.variance_scaling(1.0),
reuse=reuse
)
return latent

h_size: int,
feature_size: int,
encoder_layers: int,
action_layers: int,
vis_encode_type: EncoderType,
forward_layers: int,
var_predict: bool,

self.bisim_action = tf.placeholder(
shape=[None, sum(self.act_size)], dtype=tf.float32, name="bisim_action"
)
combined_input = tf.concat([self.bisim_encoder, self.bisim_action], axis=1)
self.bisim_action_encoder = self._create_action_encoder(
self.bisim_action,
self.h_size,
self.action_feature_size,
action_layers,
reuse=True,
)
combined_input = tf.concat([self.bisim_encoder, self.bisim_action_encoder], axis=1)
combined_input = tf.stop_gradient(combined_input)
with tf.variable_scope("predict"):

6
ml-agents/mlagents/trainers/tests/transfer_test_envs.py


for i in self.positions[name]:
obs.append(np.ones((1, self.vec_obs_size), dtype=np.float32) * i)
for _ in range(self.extra_obs_size):
obs.append(np.ones((1, self.vec_obs_size), dtype=np.float32))
obs.append(np.random.randn(1, self.vec_obs_size))
obs.append(np.ones((1, self.vec_obs_size), dtype=np.float32))
obs.append(np.random.randn(1, self.vec_obs_size))
# print(obs)
return obs
@property

0
ml-agents/mlagents/trainers/sac_transfer/__init__.py

445
ml-agents/mlagents/trainers/sac_transfer/network.py


from typing import Dict, Optional
from mlagents.tf_utils import tf
from mlagents.trainers.models import ModelUtils, EncoderType
LOG_STD_MAX = 2
LOG_STD_MIN = -20
EPSILON = 1e-6 # Small value to avoid divide by zero
DISCRETE_TARGET_ENTROPY_SCALE = 0.2 # Roughly equal to e-greedy 0.05
CONTINUOUS_TARGET_ENTROPY_SCALE = 1.0 # TODO: Make these an optional hyperparam.
POLICY_SCOPE = ""
TARGET_SCOPE = "target_network"
class SACNetwork:
"""
Base class for an SAC network. Implements methods for creating the actor and critic heads.
"""
def __init__(
self,
policy=None,
m_size=None,
h_size=128,
normalize=False,
use_recurrent=False,
num_layers=2,
stream_names=None,
vis_encode_type=EncoderType.SIMPLE,
):
self.normalize = normalize
self.use_recurrent = use_recurrent
self.num_layers = num_layers
self.stream_names = stream_names
self.h_size = h_size
self.activ_fn = ModelUtils.swish
self.sequence_length_ph = tf.placeholder(
shape=None, dtype=tf.int32, name="sac_sequence_length"
)
self.policy_memory_in: Optional[tf.Tensor] = None
self.policy_memory_out: Optional[tf.Tensor] = None
self.value_memory_in: Optional[tf.Tensor] = None
self.value_memory_out: Optional[tf.Tensor] = None
self.q1: Optional[tf.Tensor] = None
self.q2: Optional[tf.Tensor] = None
self.q1_p: Optional[tf.Tensor] = None
self.q2_p: Optional[tf.Tensor] = None
self.q1_memory_in: Optional[tf.Tensor] = None
self.q2_memory_in: Optional[tf.Tensor] = None
self.q1_memory_out: Optional[tf.Tensor] = None
self.q2_memory_out: Optional[tf.Tensor] = None
self.prev_action: Optional[tf.Tensor] = None
self.action_masks: Optional[tf.Tensor] = None
self.external_action_in: Optional[tf.Tensor] = None
self.log_sigma_sq: Optional[tf.Tensor] = None
self.entropy: Optional[tf.Tensor] = None
self.deterministic_output: Optional[tf.Tensor] = None
self.normalized_logprobs: Optional[tf.Tensor] = None
self.action_probs: Optional[tf.Tensor] = None
self.output_oh: Optional[tf.Tensor] = None
self.output_pre: Optional[tf.Tensor] = None
self.value_vars = None
self.q_vars = None
self.critic_vars = None
self.policy_vars = None
self.q1_heads: Dict[str, tf.Tensor] = None
self.q2_heads: Dict[str, tf.Tensor] = None
self.q1_pheads: Dict[str, tf.Tensor] = None
self.q2_pheads: Dict[str, tf.Tensor] = None
self.policy = policy
def get_vars(self, scope):
return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope)
def join_scopes(self, scope_1, scope_2):
"""
Joins two scopes. Does so safetly (i.e., if one of the two scopes doesn't
exist, don't add any backslashes)
"""
if not scope_1:
return scope_2
if not scope_2:
return scope_1
else:
return "/".join(filter(None, [scope_1, scope_2]))
def create_value_heads(self, stream_names, hidden_input):
"""
Creates one value estimator head for each reward signal in stream_names.
Also creates the node corresponding to the mean of all the value heads in self.value.
self.value_head is a dictionary of stream name to node containing the value estimator head for that signal.
:param stream_names: The list of reward signal names
:param hidden_input: The last layer of the Critic. The heads will consist of one dense hidden layer on top
of the hidden input.
"""
self.value_heads = {}
for name in stream_names:
value = tf.layers.dense(hidden_input, 1, name="{}_value".format(name))
self.value_heads[name] = value
self.value = tf.reduce_mean(list(self.value_heads.values()), 0)
def _create_cc_critic(self, hidden_value, scope, create_qs=True):
"""
Creates just the critic network
"""
scope = self.join_scopes(scope, "critic")
self.create_sac_value_head(
self.stream_names,
hidden_value,
self.num_layers,
self.h_size,
self.join_scopes(scope, "value"),
)
self.external_action_in = tf.placeholder(
shape=[None, self.policy.act_size[0]],
dtype=tf.float32,
name="external_action_in",
)
self.value_vars = self.get_vars(self.join_scopes(scope, "value"))
if create_qs:
hidden_q = tf.concat([hidden_value, self.external_action_in], axis=-1)
hidden_qp = tf.concat([hidden_value, self.policy.output], axis=-1)
self.q1_heads, self.q2_heads, self.q1, self.q2 = self.create_q_heads(
self.stream_names,
hidden_q,
self.num_layers,
self.h_size,
self.join_scopes(scope, "q"),
)
self.q1_pheads, self.q2_pheads, self.q1_p, self.q2_p = self.create_q_heads(
self.stream_names,
hidden_qp,
self.num_layers,
self.h_size,
self.join_scopes(scope, "q"),
reuse=True,
)
self.q_vars = self.get_vars(self.join_scopes(scope, "q"))
self.critic_vars = self.get_vars(scope)
def _create_dc_critic(self, hidden_value, scope, create_qs=True):
"""
Creates just the critic network
"""
scope = self.join_scopes(scope, "critic")
self.create_sac_value_head(
self.stream_names,
hidden_value,
self.num_layers,
self.h_size,
self.join_scopes(scope, "value"),
)
self.value_vars = self.get_vars("/".join([scope, "value"]))
if create_qs:
self.q1_heads, self.q2_heads, self.q1, self.q2 = self.create_q_heads(
self.stream_names,
hidden_value,
self.num_layers,
self.h_size,
self.join_scopes(scope, "q"),
num_outputs=sum(self.policy.act_size),
)
self.q1_pheads, self.q2_pheads, self.q1_p, self.q2_p = self.create_q_heads(
self.stream_names,
hidden_value,
self.num_layers,
self.h_size,
self.join_scopes(scope, "q"),
reuse=True,
num_outputs=sum(self.policy.act_size),
)
self.q_vars = self.get_vars(scope)
self.critic_vars = self.get_vars(scope)
def create_sac_value_head(
self, stream_names, hidden_input, num_layers, h_size, scope
):
"""
Creates one value estimator head for each reward signal in stream_names.
Also creates the node corresponding to the mean of all the value heads in self.value.
self.value_head is a dictionary of stream name to node containing the value estimator head for that signal.
:param stream_names: The list of reward signal names
:param hidden_input: The last layer of the Critic. The heads will consist of one dense hidden layer on top
of the hidden input.
:param num_layers: Number of hidden layers for value network
:param h_size: size of hidden layers for value network
:param scope: TF scope for value network.
"""
with tf.variable_scope(scope):
value_hidden = ModelUtils.create_vector_observation_encoder(
hidden_input, h_size, self.activ_fn, num_layers, "encoder", False
)
if self.use_recurrent:
value_hidden, memory_out = ModelUtils.create_recurrent_encoder(
value_hidden,
self.value_memory_in,
self.sequence_length_ph,
name="lstm_value",
)
self.value_memory_out = memory_out
self.create_value_heads(stream_names, value_hidden)
def create_q_heads(
self,
stream_names,
hidden_input,
num_layers,
h_size,
scope,
reuse=False,
num_outputs=1,
):
"""
Creates two q heads for each reward signal in stream_names.
Also creates the node corresponding to the mean of all the value heads in self.value.
self.value_head is a dictionary of stream name to node containing the value estimator head for that signal.
:param stream_names: The list of reward signal names
:param hidden_input: The last layer of the Critic. The heads will consist of one dense hidden layer on top
of the hidden input.
:param num_layers: Number of hidden layers for Q network
:param h_size: size of hidden layers for Q network
:param scope: TF scope for Q network.
:param reuse: Whether or not to reuse variables. Useful for creating Q of policy.
:param num_outputs: Number of outputs of each Q function. If discrete, equal to number of actions.
"""
with tf.variable_scope(self.join_scopes(scope, "q1_encoding"), reuse=reuse):
q1_hidden = ModelUtils.create_vector_observation_encoder(
hidden_input, h_size, self.activ_fn, num_layers, "q1_encoder", reuse
)
if self.use_recurrent:
q1_hidden, memory_out = ModelUtils.create_recurrent_encoder(
q1_hidden,
self.q1_memory_in,
self.sequence_length_ph,
name="lstm_q1",
)
self.q1_memory_out = memory_out
q1_heads = {}
for name in stream_names:
_q1 = tf.layers.dense(q1_hidden, num_outputs, name="{}_q1".format(name))
q1_heads[name] = _q1
q1 = tf.reduce_mean(list(q1_heads.values()), axis=0)
with tf.variable_scope(self.join_scopes(scope, "q2_encoding"), reuse=reuse):
q2_hidden = ModelUtils.create_vector_observation_encoder(
hidden_input, h_size, self.activ_fn, num_layers, "q2_encoder", reuse
)
if self.use_recurrent:
q2_hidden, memory_out = ModelUtils.create_recurrent_encoder(
q2_hidden,
self.q2_memory_in,
self.sequence_length_ph,
name="lstm_q2",
)
self.q2_memory_out = memory_out
q2_heads = {}
for name in stream_names:
_q2 = tf.layers.dense(q2_hidden, num_outputs, name="{}_q2".format(name))
q2_heads[name] = _q2
q2 = tf.reduce_mean(list(q2_heads.values()), axis=0)
return q1_heads, q2_heads, q1, q2
class SACTargetNetwork(SACNetwork):
"""
Instantiation for the SAC target network. Only contains a single
value estimator and is updated from the Policy Network.
"""
def __init__(
self,
policy,
m_size=None,
h_size=128,
normalize=False,
use_recurrent=False,
num_layers=2,
stream_names=None,
vis_encode_type=EncoderType.SIMPLE,
):
super().__init__(
policy,
m_size,
h_size,
normalize,
use_recurrent,
num_layers,
stream_names,
vis_encode_type,
)
with tf.variable_scope(TARGET_SCOPE):
self.visual_in = ModelUtils.create_visual_input_placeholders(
policy.brain.camera_resolutions
)
self.vector_in = ModelUtils.create_vector_input(policy.vec_obs_size)
if self.policy.normalize:
normalization_tensors = ModelUtils.create_normalizer(self.vector_in)
self.update_normalization_op = normalization_tensors.update_op
self.normalization_steps = normalization_tensors.steps
self.running_mean = normalization_tensors.running_mean
self.running_variance = normalization_tensors.running_variance
self.processed_vector_in = ModelUtils.normalize_vector_obs(
self.vector_in,
self.running_mean,
self.running_variance,
self.normalization_steps,
)
else:
self.processed_vector_in = self.vector_in
self.update_normalization_op = None
if self.policy.use_recurrent:
self.memory_in = tf.placeholder(
shape=[None, m_size], dtype=tf.float32, name="target_recurrent_in"
)
self.value_memory_in = self.memory_in
hidden_streams = ModelUtils.create_observation_streams(
self.visual_in,
self.processed_vector_in,
1,
self.h_size,
0,
vis_encode_type=vis_encode_type,
stream_scopes=["critic/value/"],
)
if self.policy.use_continuous_act:
self._create_cc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False)
else:
self._create_dc_critic(hidden_streams[0], TARGET_SCOPE, create_qs=False)
if self.use_recurrent:
self.memory_out = tf.concat(
self.value_memory_out, axis=1
) # Needed for Barracuda to work
def copy_normalization(self, mean, variance, steps):
"""
Copies the mean, variance, and steps into the normalizers of the
input of this SACNetwork. Used to copy the normalizer from the policy network
to the target network.
param mean: Tensor containing the mean.
param variance: Tensor containing the variance
param steps: Tensor containing the number of steps.
"""
update_mean = tf.assign(self.running_mean, mean)
update_variance = tf.assign(self.running_variance, variance)
update_norm_step = tf.assign(self.normalization_steps, steps)
return tf.group([update_mean, update_variance, update_norm_step])
class SACPolicyNetwork(SACNetwork):
"""
Instantiation for SAC policy network. Contains a dual Q estimator,
a value estimator, and a reference to the actual policy network.
"""
def __init__(
self,
policy,
m_size=None,
h_size=128,
normalize=False,
use_recurrent=False,
num_layers=2,
stream_names=None,
vis_encode_type=EncoderType.SIMPLE,
):
super().__init__(
policy,
m_size,
h_size,
normalize,
use_recurrent,
num_layers,
stream_names,
vis_encode_type,
)
if self.policy.use_recurrent:
self._create_memory_ins(m_size)
hidden_critic = self._create_observation_in(vis_encode_type)
self.policy.output = self.policy.output
# Use the sequence length of the policy
self.sequence_length_ph = self.policy.sequence_length_ph
if self.policy.use_continuous_act:
self._create_cc_critic(hidden_critic, POLICY_SCOPE)
else:
self._create_dc_critic(hidden_critic, POLICY_SCOPE)
if self.use_recurrent:
mem_outs = [self.value_memory_out, self.q1_memory_out, self.q2_memory_out]
self.memory_out = tf.concat(mem_outs, axis=1)
def _create_memory_ins(self, m_size):
"""
Creates the memory input placeholders for LSTM.
:param m_size: the total size of the memory.
"""
self.memory_in = tf.placeholder(
shape=[None, m_size * 3], dtype=tf.float32, name="value_recurrent_in"
)
# Re-break-up for each network
num_mems = 3
input_size = self.memory_in.get_shape().as_list()[1]
mem_ins = []
for i in range(num_mems):
_start = input_size // num_mems * i
_end = input_size // num_mems * (i + 1)
mem_ins.append(self.memory_in[:, _start:_end])
self.value_memory_in = mem_ins[0]
self.q1_memory_in = mem_ins[1]
self.q2_memory_in = mem_ins[2]
def _create_observation_in(self, vis_encode_type):
"""
Creates the observation inputs, and a CNN if needed,
:param vis_encode_type: Type of CNN encoder.
:param share_ac_cnn: Whether or not to share the actor and critic CNNs.
:return A tuple of (hidden_policy, hidden_critic). We don't save it to self since they're used
once and thrown away.
"""
with tf.variable_scope(POLICY_SCOPE):
hidden_streams = ModelUtils.create_observation_streams(
self.policy.visual_in,
self.policy.processed_vector_in,
1,
self.h_size,
0,
vis_encode_type=vis_encode_type,
stream_scopes=["critic/value/"],
)
hidden_critic = hidden_streams[0]
return hidden_critic

641
ml-agents/mlagents/trainers/sac_transfer/optimizer.py


import numpy as np
from typing import Dict, List, Optional, Any, Mapping, cast
from mlagents.tf_utils import tf
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.sac.network import SACPolicyNetwork, SACTargetNetwork
from mlagents.trainers.models import ModelUtils
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.buffer import AgentBuffer
from mlagents_envs.timers import timed
from mlagents.trainers.settings import TrainerSettings, SACSettings
EPSILON = 1e-6 # Small value to avoid divide by zero
logger = get_logger(__name__)
POLICY_SCOPE = ""
TARGET_SCOPE = "target_network"
class SACOptimizer(TFOptimizer):
def __init__(self, policy: TFPolicy, trainer_params: TrainerSettings):
"""
Takes a Unity environment and model-specific hyper-parameters and returns the
appropriate PPO agent model for the environment.
:param brain: Brain parameters used to generate specific network graph.
:param lr: Learning rate.
:param lr_schedule: Learning rate decay schedule.
:param h_size: Size of hidden layers
:param init_entcoef: Initial value for entropy coefficient. Set lower to learn faster,
set higher to explore more.
:return: a sub-class of PPOAgent tailored to the environment.
:param max_step: Total number of training steps.
:param normalize: Whether to normalize vector observation input.
:param use_recurrent: Whether to use an LSTM layer in the network.
:param num_layers: Number of hidden layers between encoded input and policy & value layers
:param tau: Strength of soft-Q update.
:param m_size: Size of brain memory.
"""
# Create the graph here to give more granular control of the TF graph to the Optimizer.
policy.create_tf_graph()
with policy.graph.as_default():
with tf.variable_scope(""):
super().__init__(policy, trainer_params)
hyperparameters: SACSettings = cast(
SACSettings, trainer_params.hyperparameters
)
lr = hyperparameters.learning_rate
lr_schedule = hyperparameters.learning_rate_schedule
max_step = trainer_params.max_steps
self.tau = hyperparameters.tau
self.init_entcoef = hyperparameters.init_entcoef
self.policy = policy
self.act_size = policy.act_size
policy_network_settings = policy.network_settings
h_size = policy_network_settings.hidden_units
num_layers = policy_network_settings.num_layers
vis_encode_type = policy_network_settings.vis_encode_type
self.tau = hyperparameters.tau
self.burn_in_ratio = 0.0
# Non-exposed SAC parameters
self.discrete_target_entropy_scale = (
0.2
) # Roughly equal to e-greedy 0.05
self.continuous_target_entropy_scale = 1.0
stream_names = list(self.reward_signals.keys())
# Use to reduce "survivor bonus" when using Curiosity or GAIL.
self.gammas = [
_val.gamma for _val in trainer_params.reward_signals.values()
]
self.use_dones_in_backup = {
name: tf.Variable(1.0) for name in stream_names
}
self.disable_use_dones = {
name: self.use_dones_in_backup[name].assign(0.0)
for name in stream_names
}
if num_layers < 1:
num_layers = 1
self.target_init_op: List[tf.Tensor] = []
self.target_update_op: List[tf.Tensor] = []
self.update_batch_policy: Optional[tf.Operation] = None
self.update_batch_value: Optional[tf.Operation] = None
self.update_batch_entropy: Optional[tf.Operation] = None
self.policy_network = SACPolicyNetwork(
policy=self.policy,
m_size=self.policy.m_size, # 3x policy.m_size
h_size=h_size,
normalize=self.policy.normalize,
use_recurrent=self.policy.use_recurrent,
num_layers=num_layers,
stream_names=stream_names,
vis_encode_type=vis_encode_type,
)
self.target_network = SACTargetNetwork(
policy=self.policy,
m_size=self.policy.m_size, # 1x policy.m_size
h_size=h_size,
normalize=self.policy.normalize,
use_recurrent=self.policy.use_recurrent,
num_layers=num_layers,
stream_names=stream_names,
vis_encode_type=vis_encode_type,
)
# The optimizer's m_size is 3 times the policy (Q1, Q2, and Value)
self.m_size = 3 * self.policy.m_size
self._create_inputs_and_outputs()
self.learning_rate = ModelUtils.create_schedule(
lr_schedule,
lr,
self.policy.global_step,
int(max_step),
min_value=1e-10,
)
self._create_losses(
self.policy_network.q1_heads,
self.policy_network.q2_heads,
lr,
int(max_step),
stream_names,
discrete=not self.policy.use_continuous_act,
)
self._create_sac_optimizer_ops()
self.selected_actions = (
self.policy.selected_actions
) # For GAIL and other reward signals
if self.policy.normalize:
target_update_norm = self.target_network.copy_normalization(
self.policy.running_mean,
self.policy.running_variance,
self.policy.normalization_steps,
)
# Update the normalization of the optimizer when the policy does.
self.policy.update_normalization_op = tf.group(
[self.policy.update_normalization_op, target_update_norm]
)
self.policy.initialize_or_load()
self.stats_name_to_update_name = {
"Losses/Value Loss": "value_loss",
"Losses/Policy Loss": "policy_loss",
"Losses/Q1 Loss": "q1_loss",
"Losses/Q2 Loss": "q2_loss",
"Policy/Entropy Coeff": "entropy_coef",
"Policy/Learning Rate": "learning_rate",
}
self.update_dict = {
"value_loss": self.total_value_loss,
"policy_loss": self.policy_loss,
"q1_loss": self.q1_loss,
"q2_loss": self.q2_loss,
"entropy_coef": self.ent_coef,
"update_batch": self.update_batch_policy,
"update_value": self.update_batch_value,
"update_entropy": self.update_batch_entropy,
"learning_rate": self.learning_rate,
}
def _create_inputs_and_outputs(self) -> None:
"""
Assign the higher-level SACModel's inputs and outputs to those of its policy or
target network.
"""
self.vector_in = self.policy.vector_in
self.visual_in = self.policy.visual_in
self.next_vector_in = self.target_network.vector_in
self.next_visual_in = self.target_network.visual_in
self.sequence_length_ph = self.policy.sequence_length_ph
self.next_sequence_length_ph = self.target_network.sequence_length_ph
if not self.policy.use_continuous_act:
self.action_masks = self.policy_network.action_masks
else:
self.output_pre = self.policy_network.output_pre
# Don't use value estimate during inference.
self.value = tf.identity(
self.policy_network.value, name="value_estimate_unused"
)
self.value_heads = self.policy_network.value_heads
self.dones_holder = tf.placeholder(
shape=[None], dtype=tf.float32, name="dones_holder"
)
if self.policy.use_recurrent:
self.memory_in = self.policy_network.memory_in
self.memory_out = self.policy_network.memory_out
if not self.policy.use_continuous_act:
self.prev_action = self.policy_network.prev_action
self.next_memory_in = self.target_network.memory_in
def _create_losses(
self,
q1_streams: Dict[str, tf.Tensor],
q2_streams: Dict[str, tf.Tensor],
lr: tf.Tensor,
max_step: int,
stream_names: List[str],
discrete: bool = False,
) -> None:
"""
Creates training-specific Tensorflow ops for SAC models.
:param q1_streams: Q1 streams from policy network
:param q1_streams: Q2 streams from policy network
:param lr: Learning rate
:param max_step: Total number of training steps.
:param stream_names: List of reward stream names.
:param discrete: Whether or not to use discrete action losses.
"""
if discrete:
self.target_entropy = [
self.discrete_target_entropy_scale * np.log(i).astype(np.float32)
for i in self.act_size
]
discrete_action_probs = tf.exp(self.policy.all_log_probs)
per_action_entropy = discrete_action_probs * self.policy.all_log_probs
else:
self.target_entropy = (
-1
* self.continuous_target_entropy_scale
* np.prod(self.act_size[0]).astype(np.float32)
)
self.rewards_holders = {}
self.min_policy_qs = {}
for name in stream_names:
if discrete:
_branched_mpq1 = ModelUtils.break_into_branches(
self.policy_network.q1_pheads[name] * discrete_action_probs,
self.act_size,
)
branched_mpq1 = tf.stack(
[
tf.reduce_sum(_br, axis=1, keep_dims=True)
for _br in _branched_mpq1
]
)
_q1_p_mean = tf.reduce_mean(branched_mpq1, axis=0)
_branched_mpq2 = ModelUtils.break_into_branches(
self.policy_network.q2_pheads[name] * discrete_action_probs,
self.act_size,
)
branched_mpq2 = tf.stack(
[
tf.reduce_sum(_br, axis=1, keep_dims=True)
for _br in _branched_mpq2
]
)
_q2_p_mean = tf.reduce_mean(branched_mpq2, axis=0)
self.min_policy_qs[name] = tf.minimum(_q1_p_mean, _q2_p_mean)
else:
self.min_policy_qs[name] = tf.minimum(
self.policy_network.q1_pheads[name],
self.policy_network.q2_pheads[name],
)
rewards_holder = tf.placeholder(
shape=[None], dtype=tf.float32, name="{}_rewards".format(name)
)
self.rewards_holders[name] = rewards_holder
q1_losses = []
q2_losses = []
# Multiple q losses per stream
expanded_dones = tf.expand_dims(self.dones_holder, axis=-1)
for i, name in enumerate(stream_names):
_expanded_rewards = tf.expand_dims(self.rewards_holders[name], axis=-1)
q_backup = tf.stop_gradient(
_expanded_rewards
+ (1.0 - self.use_dones_in_backup[name] * expanded_dones)
* self.gammas[i]
* self.target_network.value_heads[name]
)
if discrete:
# We need to break up the Q functions by branch, and update them individually.
branched_q1_stream = ModelUtils.break_into_branches(
self.policy.selected_actions * q1_streams[name], self.act_size
)
branched_q2_stream = ModelUtils.break_into_branches(
self.policy.selected_actions * q2_streams[name], self.act_size
)
# Reduce each branch into scalar
branched_q1_stream = [
tf.reduce_sum(_branch, axis=1, keep_dims=True)
for _branch in branched_q1_stream
]
branched_q2_stream = [
tf.reduce_sum(_branch, axis=1, keep_dims=True)
for _branch in branched_q2_stream
]
q1_stream = tf.reduce_mean(branched_q1_stream, axis=0)
q2_stream = tf.reduce_mean(branched_q2_stream, axis=0)
else:
q1_stream = q1_streams[name]
q2_stream = q2_streams[name]
_q1_loss = 0.5 * tf.reduce_mean(
tf.to_float(self.policy.mask)
* tf.squared_difference(q_backup, q1_stream)
)
_q2_loss = 0.5 * tf.reduce_mean(
tf.to_float(self.policy.mask)
* tf.squared_difference(q_backup, q2_stream)
)
q1_losses.append(_q1_loss)
q2_losses.append(_q2_loss)
self.q1_loss = tf.reduce_mean(q1_losses)
self.q2_loss = tf.reduce_mean(q2_losses)
# Learn entropy coefficient
if discrete:
# Create a log_ent_coef for each branch
self.log_ent_coef = tf.get_variable(
"log_ent_coef",
dtype=tf.float32,
initializer=np.log([self.init_entcoef] * len(self.act_size)).astype(
np.float32
),
trainable=True,
)
else:
self.log_ent_coef = tf.get_variable(
"log_ent_coef",
dtype=tf.float32,
initializer=np.log(self.init_entcoef).astype(np.float32),
trainable=True,
)
self.ent_coef = tf.exp(self.log_ent_coef)
if discrete:
# We also have to do a different entropy and target_entropy per branch.
branched_per_action_ent = ModelUtils.break_into_branches(
per_action_entropy, self.act_size
)
branched_ent_sums = tf.stack(
[
tf.reduce_sum(_lp, axis=1, keep_dims=True) + _te
for _lp, _te in zip(branched_per_action_ent, self.target_entropy)
],
axis=1,
)
self.entropy_loss = -tf.reduce_mean(
tf.to_float(self.policy.mask)
* tf.reduce_mean(
self.log_ent_coef
* tf.squeeze(tf.stop_gradient(branched_ent_sums), axis=2),
axis=1,
)
)
# Same with policy loss, we have to do the loss per branch and average them,
# so that larger branches don't get more weight.
# The equivalent KL divergence from Eq 10 of Haarnoja et al. is also pi*log(pi) - Q
branched_q_term = ModelUtils.break_into_branches(
discrete_action_probs * self.policy_network.q1_p, self.act_size
)
branched_policy_loss = tf.stack(
[
tf.reduce_sum(self.ent_coef[i] * _lp - _qt, axis=1, keep_dims=True)
for i, (_lp, _qt) in enumerate(
zip(branched_per_action_ent, branched_q_term)
)
]
)
self.policy_loss = tf.reduce_mean(
tf.to_float(self.policy.mask) * tf.squeeze(branched_policy_loss)
)
# Do vbackup entropy bonus per branch as well.
branched_ent_bonus = tf.stack(
[
tf.reduce_sum(self.ent_coef[i] * _lp, axis=1, keep_dims=True)
for i, _lp in enumerate(branched_per_action_ent)
]
)
value_losses = []
for name in stream_names:
v_backup = tf.stop_gradient(
self.min_policy_qs[name]
- tf.reduce_mean(branched_ent_bonus, axis=0)
)
value_losses.append(
0.5
* tf.reduce_mean(
tf.to_float(self.policy.mask)
* tf.squared_difference(
self.policy_network.value_heads[name], v_backup
)
)
)
else:
self.entropy_loss = -tf.reduce_mean(
self.log_ent_coef
* tf.to_float(self.policy.mask)
* tf.stop_gradient(
tf.reduce_sum(
self.policy.all_log_probs + self.target_entropy,
axis=1,
keep_dims=True,
)
)
)
batch_policy_loss = tf.reduce_mean(
self.ent_coef * self.policy.all_log_probs - self.policy_network.q1_p,
axis=1,
)
self.policy_loss = tf.reduce_mean(
tf.to_float(self.policy.mask) * batch_policy_loss
)
value_losses = []
for name in stream_names:
v_backup = tf.stop_gradient(
self.min_policy_qs[name]
- tf.reduce_sum(self.ent_coef * self.policy.all_log_probs, axis=1)
)
value_losses.append(
0.5
* tf.reduce_mean(
tf.to_float(self.policy.mask)
* tf.squared_difference(
self.policy_network.value_heads[name], v_backup
)
)
)
self.value_loss = tf.reduce_mean(value_losses)
self.total_value_loss = self.q1_loss + self.q2_loss + self.value_loss
self.entropy = self.policy_network.entropy
def _create_sac_optimizer_ops(self) -> None:
"""
Creates the Adam optimizers and update ops for SAC, including
the policy, value, and entropy updates, as well as the target network update.
"""
policy_optimizer = self.create_optimizer_op(
learning_rate=self.learning_rate, name="sac_policy_opt"
)
entropy_optimizer = self.create_optimizer_op(
learning_rate=self.learning_rate, name="sac_entropy_opt"
)
value_optimizer = self.create_optimizer_op(
learning_rate=self.learning_rate, name="sac_value_opt"
)
self.target_update_op = [
tf.assign(target, (1 - self.tau) * target + self.tau * source)
for target, source in zip(
self.target_network.value_vars, self.policy_network.value_vars
)
]
logger.debug("value_vars")
self.print_all_vars(self.policy_network.value_vars)
logger.debug("targvalue_vars")
self.print_all_vars(self.target_network.value_vars)
logger.debug("critic_vars")
self.print_all_vars(self.policy_network.critic_vars)
logger.debug("q_vars")
self.print_all_vars(self.policy_network.q_vars)
logger.debug("policy_vars")
policy_vars = self.policy.get_trainable_variables()
self.print_all_vars(policy_vars)
self.target_init_op = [
tf.assign(target, source)
for target, source in zip(
self.target_network.value_vars, self.policy_network.value_vars
)
]
self.update_batch_policy = policy_optimizer.minimize(
self.policy_loss, var_list=policy_vars
)
# Make sure policy is updated first, then value, then entropy.
with tf.control_dependencies([self.update_batch_policy]):
self.update_batch_value = value_optimizer.minimize(
self.total_value_loss, var_list=self.policy_network.critic_vars
)
# Add entropy coefficient optimization operation
with tf.control_dependencies([self.update_batch_value]):
self.update_batch_entropy = entropy_optimizer.minimize(
self.entropy_loss, var_list=self.log_ent_coef
)
def print_all_vars(self, variables):
for _var in variables:
logger.debug(_var)
@timed
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
"""
Updates model using buffer.
:param num_sequences: Number of trajectories in batch.
:param batch: Experience mini-batch.
:param update_target: Whether or not to update target value network
:param reward_signal_batches: Minibatches to use for updating the reward signals,
indexed by name. If none, don't update the reward signals.
:return: Output from update process.
"""
feed_dict = self._construct_feed_dict(self.policy, batch, num_sequences)
stats_needed = self.stats_name_to_update_name
update_stats: Dict[str, float] = {}
update_vals = self._execute_model(feed_dict, self.update_dict)
for stat_name, update_name in stats_needed.items():
update_stats[stat_name] = update_vals[update_name]
# Update target network. By default, target update happens at every policy update.
self.sess.run(self.target_update_op)
return update_stats
def update_reward_signals(
self, reward_signal_minibatches: Mapping[str, AgentBuffer], num_sequences: int
) -> Dict[str, float]:
"""
Only update the reward signals.
:param reward_signal_batches: Minibatches to use for updating the reward signals,
indexed by name. If none, don't update the reward signals.
"""
# Collect feed dicts for all reward signals.
feed_dict: Dict[tf.Tensor, Any] = {}
update_dict: Dict[str, tf.Tensor] = {}
update_stats: Dict[str, float] = {}
stats_needed: Dict[str, str] = {}
if reward_signal_minibatches:
self.add_reward_signal_dicts(
feed_dict,
update_dict,
stats_needed,
reward_signal_minibatches,
num_sequences,
)
update_vals = self._execute_model(feed_dict, update_dict)
for stat_name, update_name in stats_needed.items():
update_stats[stat_name] = update_vals[update_name]
return update_stats
def add_reward_signal_dicts(
self,
feed_dict: Dict[tf.Tensor, Any],
update_dict: Dict[str, tf.Tensor],
stats_needed: Dict[str, str],
reward_signal_minibatches: Mapping[str, AgentBuffer],
num_sequences: int,
) -> None:
"""
Adds the items needed for reward signal updates to the feed_dict and stats_needed dict.
:param feed_dict: Feed dict needed update
:param update_dit: Update dict that needs update
:param stats_needed: Stats needed to get from the update.
:param reward_signal_minibatches: Minibatches to use for updating the reward signals,
indexed by name.
"""
for name, r_batch in reward_signal_minibatches.items():
feed_dict.update(
self.reward_signals[name].prepare_update(
self.policy, r_batch, num_sequences
)
)
update_dict.update(self.reward_signals[name].update_dict)
stats_needed.update(self.reward_signals[name].stats_name_to_update_name)
def _construct_feed_dict(
self, policy: TFPolicy, batch: AgentBuffer, num_sequences: int
) -> Dict[tf.Tensor, Any]:
"""
Builds the feed dict for updating the SAC model.
:param model: The model to update. May be different when, e.g. using multi-GPU.
:param batch: Mini-batch to use to update.
:param num_sequences: Number of LSTM sequences in batch.
"""
# Do an optional burn-in for memories
num_burn_in = int(self.burn_in_ratio * self.policy.sequence_length)
burn_in_mask = np.ones((self.policy.sequence_length), dtype=np.float32)
burn_in_mask[range(0, num_burn_in)] = 0
burn_in_mask = np.tile(burn_in_mask, num_sequences)
feed_dict = {
policy.batch_size_ph: num_sequences,
policy.sequence_length_ph: self.policy.sequence_length,
self.next_sequence_length_ph: self.policy.sequence_length,
self.policy.mask_input: batch["masks"] * burn_in_mask,
}
for name in self.reward_signals:
feed_dict[self.rewards_holders[name]] = batch["{}_rewards".format(name)]
if self.policy.use_continuous_act:
feed_dict[self.policy_network.external_action_in] = batch["actions"]
else:
feed_dict[policy.output] = batch["actions"]
if self.policy.use_recurrent:
feed_dict[policy.prev_action] = batch["prev_action"]
feed_dict[policy.action_masks] = batch["action_mask"]
if self.policy.use_vec_obs:
feed_dict[policy.vector_in] = batch["vector_obs"]
feed_dict[self.next_vector_in] = batch["next_vector_in"]
if self.policy.vis_obs_size > 0:
for i, _ in enumerate(policy.visual_in):
_obs = batch["visual_obs%d" % i]
feed_dict[policy.visual_in[i]] = _obs
for i, _ in enumerate(self.next_visual_in):
_obs = batch["next_visual_obs%d" % i]
feed_dict[self.next_visual_in[i]] = _obs
if self.policy.use_recurrent:
feed_dict[policy.memory_in] = [
batch["memory"][i]
for i in range(0, len(batch["memory"]), self.policy.sequence_length)
]
feed_dict[self.policy_network.memory_in] = self._make_zero_mem(
self.m_size, batch.num_experiences
)
feed_dict[self.target_network.memory_in] = self._make_zero_mem(
self.m_size // 3, batch.num_experiences
)
feed_dict[self.dones_holder] = batch["done"]
return feed_dict

340
ml-agents/mlagents/trainers/sac_transfer/trainer.py


# ## ML-Agent Learning (SAC)
# Contains an implementation of SAC as described in https://arxiv.org/abs/1801.01290
# and implemented in https://github.com/hill-a/stable-baselines
from collections import defaultdict
from typing import Dict, cast
import os
import numpy as np
from mlagents_envs.logging_util import get_logger
from mlagents_envs.timers import timed
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.policy.nn_policy import NNPolicy
from mlagents.trainers.sac.optimizer import SACOptimizer
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.trajectory import Trajectory, SplitObservations
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings, SACSettings
logger = get_logger(__name__)
BUFFER_TRUNCATE_PERCENT = 0.8
class SACTrainer(RLTrainer):
"""
The SACTrainer is an implementation of the SAC algorithm, with support
for discrete actions and recurrent networks.
"""
def __init__(
self,
brain_name: str,
reward_buff_cap: int,
trainer_settings: TrainerSettings,
training: bool,
load: bool,
seed: int,
artifact_path: str,
):
"""
Responsible for collecting experiences and training SAC model.
:param brain_name: The name of the brain associated with trainer config
:param reward_buff_cap: Max reward history to track in the reward buffer
:param trainer_settings: The parameters for the trainer.
:param training: Whether the trainer is set for training.
:param load: Whether the model should be loaded.
:param seed: The seed the model will be initialized with
:param artifact_path: The directory within which to store artifacts from this trainer.
"""
super().__init__(
brain_name, trainer_settings, training, artifact_path, reward_buff_cap
)
self.load = load
self.seed = seed
self.policy: NNPolicy = None # type: ignore
self.optimizer: SACOptimizer = None # type: ignore
self.hyperparameters: SACSettings = cast(
SACSettings, trainer_settings.hyperparameters
)
self.step = 0
# Don't divide by zero
self.update_steps = 1
self.reward_signal_update_steps = 1
self.steps_per_update = self.hyperparameters.steps_per_update
self.reward_signal_steps_per_update = (
self.hyperparameters.reward_signal_steps_per_update
)
self.checkpoint_replay_buffer = self.hyperparameters.save_replay_buffer
def save_model(self, name_behavior_id: str) -> None:
"""
Saves the model. Overrides the default save_model since we want to save
the replay buffer as well.
"""
self.policy.save_model(self.get_step)
if self.checkpoint_replay_buffer:
self.save_replay_buffer()
def save_replay_buffer(self) -> None:
"""
Save the training buffer's update buffer to a pickle file.
"""
filename = os.path.join(self.artifact_path, "last_replay_buffer.hdf5")
logger.info("Saving Experience Replay Buffer to {}".format(filename))
with open(filename, "wb") as file_object:
self.update_buffer.save_to_file(file_object)
def load_replay_buffer(self) -> None:
"""
Loads the last saved replay buffer from a file.
"""
filename = os.path.join(self.artifact_path, "last_replay_buffer.hdf5")
logger.info("Loading Experience Replay Buffer from {}".format(filename))
with open(filename, "rb+") as file_object:
self.update_buffer.load_from_file(file_object)
logger.info(
"Experience replay buffer has {} experiences.".format(
self.update_buffer.num_experiences
)
)
def _process_trajectory(self, trajectory: Trajectory) -> None:
"""
Takes a trajectory and processes it, putting it into the replay buffer.
"""
super()._process_trajectory(trajectory)
last_step = trajectory.steps[-1]
agent_id = trajectory.agent_id # All the agents should have the same ID
agent_buffer_trajectory = trajectory.to_agentbuffer()
# Update the normalization
if self.is_training:
self.policy.update_normalization(agent_buffer_trajectory["vector_obs"])
# Evaluate all reward functions for reporting purposes
self.collected_rewards["environment"][agent_id] += np.sum(
agent_buffer_trajectory["environment_rewards"]
)
for name, reward_signal in self.optimizer.reward_signals.items():
evaluate_result = reward_signal.evaluate_batch(
agent_buffer_trajectory
).scaled_reward
# Report the reward signals
self.collected_rewards[name][agent_id] += np.sum(evaluate_result)
# Get all value estimates for reporting purposes
value_estimates, _ = self.optimizer.get_trajectory_value_estimates(
agent_buffer_trajectory, trajectory.next_obs, trajectory.done_reached
)
for name, v in value_estimates.items():
self._stats_reporter.add_stat(
self.optimizer.reward_signals[name].value_name, np.mean(v)
)
# Bootstrap using the last step rather than the bootstrap step if max step is reached.
# Set last element to duplicate obs and remove dones.
if last_step.interrupted:
vec_vis_obs = SplitObservations.from_observations(last_step.obs)
for i, obs in enumerate(vec_vis_obs.visual_observations):
agent_buffer_trajectory["next_visual_obs%d" % i][-1] = obs
if vec_vis_obs.vector_observations.size > 1:
agent_buffer_trajectory["next_vector_in"][
-1
] = vec_vis_obs.vector_observations
agent_buffer_trajectory["done"][-1] = False
# Append to update buffer
agent_buffer_trajectory.resequence_and_append(
self.update_buffer, training_length=self.policy.sequence_length
)
if trajectory.done_reached:
self._update_end_episode_stats(agent_id, self.optimizer)
def _is_ready_update(self) -> bool:
"""
Returns whether or not the trainer has enough elements to run update model
:return: A boolean corresponding to whether or not _update_policy() can be run
"""
return (
self.update_buffer.num_experiences >= self.hyperparameters.batch_size
and self.step >= self.hyperparameters.buffer_init_steps
)
@timed
def _update_policy(self) -> bool:
"""
Update the SAC policy and reward signals. The reward signal generators are updated using different mini batches.
By default we imitate http://arxiv.org/abs/1809.02925 and similar papers, where the policy is updated
N times, then the reward signals are updated N times.
:return: Whether or not the policy was updated.
"""
policy_was_updated = self._update_sac_policy()
self._update_reward_signals()
return policy_was_updated
def create_policy(
self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters
) -> TFPolicy:
policy = NNPolicy(
self.seed,
brain_parameters,
self.trainer_settings,
self.is_training,
self.artifact_path,
self.load,
tanh_squash=True,
reparameterize=True,
create_tf_graph=False,
)
# Load the replay buffer if load
if self.load and self.checkpoint_replay_buffer:
try:
self.load_replay_buffer()
except (AttributeError, FileNotFoundError):
logger.warning(
"Replay buffer was unable to load, starting from scratch."
)
logger.debug(
"Loaded update buffer with {} sequences".format(
self.update_buffer.num_experiences
)
)
return policy
def _update_sac_policy(self) -> bool:
"""
Uses update_buffer to update the policy. We sample the update_buffer and update
until the steps_per_update ratio is met.
"""
has_updated = False
self.cumulative_returns_since_policy_update.clear()
n_sequences = max(
int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
)
batch_update_stats: Dict[str, list] = defaultdict(list)
while (
self.step - self.hyperparameters.buffer_init_steps
) / self.update_steps > self.steps_per_update:
logger.debug("Updating SAC policy at step {}".format(self.step))
buffer = self.update_buffer
if self.update_buffer.num_experiences >= self.hyperparameters.batch_size:
sampled_minibatch = buffer.sample_mini_batch(
self.hyperparameters.batch_size,
sequence_length=self.policy.sequence_length,
)
# Get rewards for each reward
for name, signal in self.optimizer.reward_signals.items():
sampled_minibatch[
"{}_rewards".format(name)
] = signal.evaluate_batch(sampled_minibatch).scaled_reward
update_stats = self.optimizer.update(sampled_minibatch, n_sequences)
for stat_name, value in update_stats.items():
batch_update_stats[stat_name].append(value)
self.update_steps += 1
for stat, stat_list in batch_update_stats.items():
self._stats_reporter.add_stat(stat, np.mean(stat_list))
has_updated = True
if self.optimizer.bc_module:
update_stats = self.optimizer.bc_module.update()
for stat, val in update_stats.items():
self._stats_reporter.add_stat(stat, val)
# Truncate update buffer if neccessary. Truncate more than we need to to avoid truncating
# a large buffer at each update.
if self.update_buffer.num_experiences > self.hyperparameters.buffer_size:
self.update_buffer.truncate(
int(self.hyperparameters.buffer_size * BUFFER_TRUNCATE_PERCENT)
)
return has_updated
def _update_reward_signals(self) -> None:
"""
Iterate through the reward signals and update them. Unlike in PPO,
do it separate from the policy so that it can be done at a different
interval.
This function should only be used to simulate
http://arxiv.org/abs/1809.02925 and similar papers, where the policy is updated
N times, then the reward signals are updated N times. Normally, the reward signal
and policy are updated in parallel.
"""
buffer = self.update_buffer
n_sequences = max(
int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
)
batch_update_stats: Dict[str, list] = defaultdict(list)
while (
self.step - self.hyperparameters.buffer_init_steps
) / self.reward_signal_update_steps > self.reward_signal_steps_per_update:
# Get minibatches for reward signal update if needed
reward_signal_minibatches = {}
for name, signal in self.optimizer.reward_signals.items():
logger.debug("Updating {} at step {}".format(name, self.step))
# Some signals don't need a minibatch to be sampled - so we don't!
if signal.update_dict:
reward_signal_minibatches[name] = buffer.sample_mini_batch(
self.hyperparameters.batch_size,
sequence_length=self.policy.sequence_length,
)
update_stats = self.optimizer.update_reward_signals(
reward_signal_minibatches, n_sequences
)
for stat_name, value in update_stats.items():
batch_update_stats[stat_name].append(value)
self.reward_signal_update_steps += 1
for stat, stat_list in batch_update_stats.items():
self._stats_reporter.add_stat(stat, np.mean(stat_list))
def add_policy(
self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy
) -> None:
"""
Adds policy to trainer.
:param brain_parameters: specifications for policy construction
"""
if self.policy:
logger.warning(
"Your environment contains multiple teams, but {} doesn't support adversarial games. Enable self-play to \
train adversarial games.".format(
self.__class__.__name__
)
)
if not isinstance(policy, NNPolicy):
raise RuntimeError("Non-SACPolicy passed to SACTrainer.add_policy()")
self.policy = policy
self.optimizer = SACOptimizer(self.policy, self.trainer_settings)
for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
# Needed to resume loads properly
self.step = policy.get_current_step()
# Assume steps were updated at the correct ratio before
self.update_steps = int(max(1, self.step / self.steps_per_update))
self.reward_signal_update_steps = int(
max(1, self.step / self.reward_signal_steps_per_update)
)
def get_policy(self, name_behavior_id: str) -> TFPolicy:
"""
Gets policy from trainer associated with name_behavior_id
:param name_behavior_id: full identifier of policy
"""
return self.policy
正在加载...
取消
保存