"""
The Unity Machine Learning Agents Toolkit (ML-Agents) is an open-source project
that enables games and simulations to serve as environments for training
intelligent agents.
"""
import os
from typing import Any, Dict, Optional, List, Tuple
from mlagents.tf_utils import tf
from mlagents_envs.timers import timed
from mlagents_envs.base_env import DecisionSteps
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.models import EncoderType
from mlagents.trainers.models import ModelUtils
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.distributions import (
    GaussianDistribution,
    MultiCategoricalDistribution,
)
# import tf_slim as slim
EPSILON = 1e-6 # Small value to avoid divide by zero

class GaussianEncoderDistribution:
    def __init__(
        self,
        encoded: tf.Tensor,
        feature_size: int,
        reuse: bool = False,
    ):
        self.mu = tf.layers.dense(
            encoded,
            feature_size,
            activation=None,
            name="mu",
            kernel_initializer=ModelUtils.scaled_init(0.01),
            reuse=reuse,
        )
        self.log_sigma = tf.layers.dense(
            encoded,
            feature_size,
            activation=None,
            name="log_std",
            kernel_initializer=ModelUtils.scaled_init(0.01),
            reuse=reuse,
        )
        self.sigma = tf.exp(self.log_sigma)

    def sample(self):
        epsilon = tf.random_normal(tf.shape(self.mu))
        sampled = self.mu + self.sigma * epsilon
        return sampled
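
    # A note on sample(): this is the reparameterization trick. Drawing
    # epsilon ~ N(0, I) and returning mu + sigma * epsilon yields a sample
    # that stays differentiable with respect to mu and sigma, so gradients
    # can flow back into the encoder that produced them.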

    def kl_standard(self):
        """
        KL divergence with a standard gaussian
        """
        kl = 0.5 * tf.reduce_sum(
            tf.square(self.mu) + tf.square(self.sigma) - 2 * self.log_sigma - 1,
            1,
        )
        return kl
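
# For reference, kl_standard above evaluates the closed form
#     KL(N(mu, diag(sigma^2)) || N(0, I))
#         = 0.5 * sum_i (mu_i^2 + sigma_i^2 - 2 * log(sigma_i) - 1)
# summed over the feature dimension. It is zero exactly when mu = 0 and
# sigma = 1, and grows as the encoded distribution drifts from the prior.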

class TransferPolicy(TFPolicy):
    def __init__(
        self,
        seed: int,
        brain: BrainParameters,
        trainer_params: TrainerSettings,
        is_training: bool,
        model_path: str,
        load: bool,
        tanh_squash: bool = False,
        reparameterize: bool = False,
        condition_sigma_on_obs: bool = True,
        create_tf_graph: bool = True,
    ):
        """
        Policy that uses a multilayer perceptron to map the observations to actions. Could
        also use a CNN to encode visual input prior to the MLP. Supports discrete and
        continuous action spaces, as well as recurrent networks.
        :param seed: Random seed.
        :param brain: Assigned BrainParameters object.
        :param trainer_params: Defined training parameters.
        :param is_training: Whether the model should be trained.
        :param load: Whether a pre-trained model will be loaded or a new one created.
        :param model_path: Path where the model should be saved and loaded.
        :param tanh_squash: Whether to use a tanh function on the continuous output, or a clipped output.
        :param reparameterize: Whether we are using the resampling trick to update the policy in continuous output.
        """
        super().__init__(seed, brain, trainer_params, model_path, load)
        self.grads = None
        self.update_batch: Optional[tf.Operation] = None
        num_layers = self.network_settings.num_layers
        self.h_size = self.network_settings.hidden_units
        if num_layers < 1:
            num_layers = 1
        self.num_layers = num_layers
        self.vis_encode_type = self.network_settings.vis_encode_type
        self.tanh_squash = tanh_squash
        self.reparameterize = reparameterize
        self.condition_sigma_on_obs = condition_sigma_on_obs

        self.trainable_variables: List[tf.Variable] = []
        self.encoder = None
        self.encoder_distribution = None
        self.targ_encoder = None

        # Non-exposed parameters; these aren't exposed because they don't have a
        # good explanation and usually shouldn't be touched.
        self.log_std_min = -20
        self.log_std_max = 2

        if create_tf_graph:
            self.create_tf_graph()

    def get_trainable_variables(self) -> List[tf.Variable]:
        """
        Returns a list of the trainable variables in this policy. If create_tf_graph
        hasn't been called, returns an empty list.
        """
        return self.trainable_variables

    def create_tf_graph(
        self,
        encoder_layers: int = 1,
        policy_layers: int = 1,
        forward_layers: int = 1,
        inverse_layers: int = 1,
        feature_size: int = 16,
        transfer: bool = False,
        separate_train: bool = False,
        var_encoder: bool = False,
        var_predict: bool = False,
        predict_return: bool = False,
        inverse_model: bool = False,
        reuse_encoder: bool = False,
    ) -> None:
        """
        Builds the tensorflow graph needed for this policy.
        """
        self.inverse_model = inverse_model
        self.reuse_encoder = reuse_encoder
        self.feature_size = feature_size

        with self.graph.as_default():
            tf.set_random_seed(self.seed)
            _vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            if len(_vars) > 0:
                # We assume the first thing created in the graph is the Policy. If
                # already populated, don't create more tensors.
                return

            self.create_input_placeholders()
            self.current_action = tf.placeholder(
                shape=[None, sum(self.act_size)],
                dtype=tf.float32,
                name="current_action",
            )
            self.current_reward = tf.placeholder(
                shape=[None], dtype=tf.float32, name="current_reward"
            )
            self.next_visual_in: List[tf.Tensor] = []

            # if var_encoder:
            #     self.encoder, self.targ_encoder, self.encoder_distribution, _ = self.create_encoders(
            #         var_latent=True, reuse_encoder=reuse_encoder
            #     )
            # else:
            #     self.encoder, self.targ_encoder = self.create_encoders(reuse_encoder=reuse_encoder)
            # if not reuse_encoder:
            #     self.targ_encoder = tf.stop_gradient(self.targ_encoder)
            #     self._create_hard_copy()

            if var_encoder:
                self.encoder_distribution, self.encoder = self._create_var_encoder(
                    self.visual_in,
                    self.processed_vector_in,
                    self.h_size,
                    self.feature_size,
                    encoder_layers,
                    self.vis_encode_type,
                )
                _, self.targ_encoder = self._create_var_target_encoder(
                    self.h_size,
                    self.feature_size,
                    encoder_layers,
                    self.vis_encode_type,
                    reuse_encoder,
                )
            else:
                self.encoder = self._create_encoder(
                    self.visual_in,
                    self.processed_vector_in,
                    self.h_size,
                    self.feature_size,
                    encoder_layers,
                    self.vis_encode_type,
                )
                self.targ_encoder = self._create_target_encoder(
                    self.h_size,
                    self.feature_size,
                    encoder_layers,
                    self.vis_encode_type,
                    reuse_encoder,
                )

            if not reuse_encoder:
                self.targ_encoder = tf.stop_gradient(self.targ_encoder)
                self._create_hard_copy()

            if self.inverse_model:
                with tf.variable_scope("inverse"):
                    self.create_inverse_model(
                        self.encoder, self.targ_encoder, inverse_layers
                    )

            with tf.variable_scope("predict"):
                self.create_forward_model(
                    self.encoder,
                    self.targ_encoder,
                    forward_layers,
                    var_predict=var_predict,
                )

            if predict_return:
                with tf.variable_scope("reward"):
                    self.create_reward_model(
                        self.encoder, self.targ_encoder, forward_layers
                    )

            # if var_predict:
            #     self.predict_distribution, self.predict = self._create_var_world_model(
            #         self.encoder,
            #         self.h_size,
            #         self.feature_size,
            #         self.num_layers,
            #         self.vis_encode_type,
            #         predict_return,
            #     )
            # else:
            #     self.predict = self._create_world_model(
            #         self.encoder,
            #         self.h_size,
            #         self.feature_size,
            #         self.num_layers,
            #         self.vis_encode_type,
            #         predict_return,
            #     )

            # if inverse_model:
            #     self._create_inverse_model(self.encoder, self.targ_encoder)

            if self.use_continuous_act:
                self._create_cc_actor(
                    self.encoder,
                    self.h_size,
                    policy_layers,
                    self.tanh_squash,
                    self.reparameterize,
                    self.condition_sigma_on_obs,
                    separate_train,
                )
            else:
                self._create_dc_actor(
                    self.encoder, self.h_size, policy_layers, separate_train
                )

            self.trainable_variables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope="policy"
            )
            self.trainable_variables += tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope="encoding"
            )
            self.trainable_variables += tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope="predict"
            )
            self.trainable_variables += tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope="lstm"
            )  # LSTMs need to be root scope for Barracuda export
            if self.inverse_model:
                self.trainable_variables += tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope="inverse"
                )

            self.inference_dict: Dict[str, tf.Tensor] = {
                "action": self.output,
                "log_probs": self.all_log_probs,
                "entropy": self.entropy,
            }
            if self.use_continuous_act:
                self.inference_dict["pre_action"] = self.output_pre
            if self.use_recurrent:
                self.inference_dict["memory_out"] = self.memory_out

            # We do an initialize to make the Policy usable out of the box. If an optimizer is needed,
            # it will re-load the full graph
            self._initialize_graph()
            # slim.model_analyzer.analyze_vars(self.trainable_variables, print_info=True)

    def load_graph_partial(
        self,
        path: str,
        transfer_type: str = "dynamics",
        load_model: bool = True,
        load_policy: bool = True,
        load_value: bool = True,
    ) -> None:
        """
        Restores a subset of the networks, selected by variable scope, from the
        per-net checkpoints written by save_model.
        """
        load_nets = {"dynamics": [], "observation": ["encoding", "inverse"]}
        if load_model:
            load_nets["dynamics"].append("predict")
        if load_policy:
            load_nets["dynamics"].append("policy")
        if load_value:
            load_nets["dynamics"].append("value")
        if self.inverse_model:
            load_nets["dynamics"].append("inverse")

        with self.graph.as_default():
            for net in load_nets[transfer_type]:
                variables_to_restore = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, net
                )
                partial_saver = tf.train.Saver(variables_to_restore)
                partial_model_checkpoint = os.path.join(path, f"{net}.ckpt")
                partial_saver.restore(self.sess, partial_model_checkpoint)
                print("loaded net", net, "from path", path)
            # variables_to_restore = tf.get_collection(
            #     tf.GraphKeys.TRAINABLE_VARIABLES, "encoding/latent"
            # )
            # partial_saver = tf.train.Saver(variables_to_restore)
            # partial_model_checkpoint = os.path.join(path, "latent.ckpt")
            # partial_saver.restore(self.sess, partial_model_checkpoint)
            # print("loaded net latent from path", path)

        if transfer_type == "observation":
            self.run_hard_copy()
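
    # A minimal sketch of how a partial restore might be driven from outside
    # (the path and flag values here are illustrative assumptions, not
    # defaults of this module):
    #
    #     policy.create_tf_graph(inverse_model=True)
    #     policy.load_graph_partial(
    #         "./results/source_task",
    #         transfer_type="dynamics",
    #         load_model=True,
    #         load_policy=False,
    #         load_value=False,
    #     )
    #
    # With transfer_type="observation", only the "encoding" and "inverse"
    # scopes are restored, and run_hard_copy() re-syncs the target encoder.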

    def _create_world_model(
        self,
        encoder: tf.Tensor,
        h_size: int,
        feature_size: int,
        num_layers: int,
        vis_encode_type: EncoderType,
        predict_return: bool = False,
    ) -> tf.Tensor:
        """
        Builds the world model for state prediction.
        """
        with self.graph.as_default():
            with tf.variable_scope("predict"):
                # self.current_action = tf.placeholder(
                #     shape=[None, sum(self.act_size)], dtype=tf.float32, name="current_action"
                # )
                hidden_stream = ModelUtils.create_vector_observation_encoder(
                    tf.concat([encoder, self.current_action], axis=1),
                    h_size,
                    ModelUtils.swish,
                    num_layers,
                    scope="main_graph",
                    reuse=False,
                )
                if predict_return:
                    predict = tf.layers.dense(
                        hidden_stream, feature_size + 1, name="next_state"
                    )
                else:
                    predict = tf.layers.dense(
                        hidden_stream, feature_size, name="next_state"
                    )
        return predict

    def _create_var_world_model(
        self,
        encoder: tf.Tensor,
        h_size: int,
        feature_size: int,
        num_layers: int,
        vis_encode_type: EncoderType,
        predict_return: bool = False,
    ) -> Tuple[GaussianEncoderDistribution, tf.Tensor]:
        """
        Builds the variational world model for state prediction.
        """
        with self.graph.as_default():
            with tf.variable_scope("predict"):
                hidden_stream = ModelUtils.create_vector_observation_encoder(
                    tf.concat([encoder, self.current_action], axis=1),
                    h_size,
                    ModelUtils.swish,
                    num_layers,
                    scope="main_graph",
                    reuse=False,
                )
                with tf.variable_scope("latent"):
                    if predict_return:
                        # separate prediction of return
                        predict_distribution = GaussianEncoderDistribution(
                            hidden_stream, feature_size + 1
                        )
                    else:
                        predict_distribution = GaussianEncoderDistribution(
                            hidden_stream, feature_size
                        )
                    predict = predict_distribution.sample()
        return predict_distribution, predict

    @timed
    def evaluate(
        self, decision_requests: DecisionSteps, global_agent_ids: List[str]
    ) -> Dict[str, Any]:
        """
        Evaluates policy for the agent experiences provided.
        :param decision_requests: DecisionSteps object containing inputs.
        :param global_agent_ids: The global (with worker ID) agent ids of the data in the batched_step_result.
        :return: Outputs from network as defined by self.inference_dict.
        """
        feed_dict = {
            self.batch_size_ph: len(decision_requests),
            self.sequence_length_ph: 1,
        }
        if self.use_recurrent:
            if not self.use_continuous_act:
                feed_dict[self.prev_action] = self.retrieve_previous_action(
                    global_agent_ids
                )
            feed_dict[self.memory_in] = self.retrieve_memories(global_agent_ids)
        feed_dict = self.fill_eval_dict(feed_dict, decision_requests)
        run_out = self._execute_model(feed_dict, self.inference_dict)
        return run_out

    def _create_target_encoder(
        self,
        h_size: int,
        feature_size: int,
        num_layers: int,
        vis_encode_type: EncoderType,
        reuse_encoder: bool,
    ) -> tf.Tensor:
        if reuse_encoder:
            next_encoder_scope = "encoding"
        else:
            next_encoder_scope = "target_enc"

        self.visual_next = ModelUtils.create_visual_input_placeholders(
            self.brain.camera_resolutions
        )
        self.vector_next = ModelUtils.create_vector_input(self.vec_obs_size)
        # if self.normalize:
        #     self.processed_vector_next = ModelUtils.normalize_vector_obs(
        #         self.vector_next,
        #         self.running_mean,
        #         self.running_variance,
        #         self.normalization_steps,
        #     )
        # else:
        #     self.processed_vector_next = self.vector_next

        with tf.variable_scope(next_encoder_scope):
            hidden_stream_targ = ModelUtils.create_observation_streams(
                self.visual_next,
                self.vector_next,
                1,
                h_size,
                num_layers,
                vis_encode_type,
                reuse=reuse_encoder,
            )[0]
            latent_targ = tf.layers.dense(
                hidden_stream_targ,
                feature_size,
                name="latent",
                reuse=reuse_encoder,
                activation=ModelUtils.swish,
                kernel_initializer=tf.initializers.variance_scaling(1.0),
            )
        return latent_targ
        # return tf.stop_gradient(latent_targ)

    def _create_encoder(
        self,
        visual_in: List[tf.Tensor],
        vector_in: tf.Tensor,
        h_size: int,
        feature_size: int,
        num_layers: int,
        vis_encode_type: EncoderType,
    ) -> tf.Tensor:
        """
        Creates an encoder for visual and vector observations.
        :param h_size: Size of hidden linear layers.
        :param num_layers: Number of hidden linear layers.
        :param vis_encode_type: Type of visual encoder to use if visual input.
        :return: The hidden layer (tf.Tensor) after the encoder.
        """
        with tf.variable_scope("encoding"):
            hidden_stream = ModelUtils.create_observation_streams(
                visual_in,
                vector_in,
                1,
                h_size,
                num_layers,
                vis_encode_type,
            )[0]
            latent = tf.layers.dense(
                hidden_stream,
                feature_size,
                name="latent",
                activation=ModelUtils.swish,
                kernel_initializer=tf.initializers.variance_scaling(1.0),
            )
        return latent

    def _create_var_target_encoder(
        self,
        h_size: int,
        feature_size: int,
        num_layers: int,
        vis_encode_type: EncoderType,
        reuse_encoder: bool,
    ) -> Tuple[GaussianEncoderDistribution, tf.Tensor]:
        if reuse_encoder:
            next_encoder_scope = "encoding"
        else:
            next_encoder_scope = "target_enc"

        self.visual_next = ModelUtils.create_visual_input_placeholders(
            self.brain.camera_resolutions
        )
        self.vector_next = ModelUtils.create_vector_input(self.vec_obs_size)
        # if self.normalize:
        #     self.processed_vector_next = ModelUtils.normalize_vector_obs(
        #         self.vector_next,
        #         self.running_mean,
        #         self.running_variance,
        #         self.normalization_steps,
        #     )
        # else:
        #     self.processed_vector_next = self.vector_next

        with tf.variable_scope(next_encoder_scope):
            hidden_stream_targ = ModelUtils.create_observation_streams(
                self.visual_next,
                self.vector_next,
                1,
                h_size,
                num_layers,
                vis_encode_type,
                reuse=reuse_encoder,
            )[0]
            with tf.variable_scope("latent"):
                latent_targ_distribution = GaussianEncoderDistribution(
                    hidden_stream_targ, feature_size, reuse=reuse_encoder
                )
                latent_targ = latent_targ_distribution.sample()
        return latent_targ_distribution, latent_targ

    def _create_var_encoder(
        self,
        visual_in: List[tf.Tensor],
        vector_in: tf.Tensor,
        h_size: int,
        feature_size: int,
        num_layers: int,
        vis_encode_type: EncoderType,
    ) -> Tuple[GaussianEncoderDistribution, tf.Tensor]:
        """
        Creates a variational encoder for visual and vector observations.
        :param h_size: Size of hidden linear layers.
        :param num_layers: Number of hidden linear layers.
        :param vis_encode_type: Type of visual encoder to use if visual input.
        :return: The latent distribution and a sample drawn from it.
        """
        with tf.variable_scope("encoding"):
            hidden_stream = ModelUtils.create_observation_streams(
                visual_in,
                vector_in,
                1,
                h_size,
                num_layers,
                vis_encode_type,
            )[0]
            with tf.variable_scope("latent"):
                latent_distribution = GaussianEncoderDistribution(
                    hidden_stream, feature_size
                )
                latent = latent_distribution.sample()
        return latent_distribution, latent

    def _create_hard_copy(self) -> None:
        """
        Creates the op that overwrites the "target_enc" variables with their
        "encoding" counterparts.
        """
        t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="target_enc")
        e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="encoding")
        with tf.variable_scope("hard_replacement"):
            self.target_replace_op = [
                tf.assign(t, e) for t, e in zip(t_params, e_params)
            ]

    def run_hard_copy(self) -> None:
        """
        Copies the current encoder weights onto the target encoder.
        """
        self.sess.run(self.target_replace_op)
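
    # _create_hard_copy/run_hard_copy implement a *hard* target-network
    # update: every "target_enc" variable is overwritten by its "encoding"
    # counterpart. A soft update, had one been wanted, would blend the two
    # with a coefficient tau -- a sketch for comparison, not used here:
    #
    #     soft_replace_op = [
    #         tf.assign(t, (1.0 - tau) * t + tau * e)
    #         for t, e in zip(t_params, e_params)
    #     ]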

    def _create_inverse_model(
        self, encoded_state: tf.Tensor, encoded_next_state: tf.Tensor
    ) -> None:
        """
        Creates inverse model TensorFlow ops for Curiosity module.
        Predicts action taken given current and future encoded states.
        :param encoded_state: Tensor corresponding to encoded current state.
        :param encoded_next_state: Tensor corresponding to encoded next state.
        """
        with tf.variable_scope("inverse"):
            combined_input = tf.concat([encoded_state, encoded_next_state], axis=1)
            hidden = tf.layers.dense(
                combined_input, self.h_size, activation=ModelUtils.swish
            )
            if self.brain.vector_action_space_type == "continuous":
                pred_action = tf.layers.dense(
                    hidden, self.act_size[0], activation=None
                )
                squared_difference = tf.reduce_sum(
                    tf.squared_difference(pred_action, self.current_action), axis=1
                )
                self.inverse_loss = tf.reduce_mean(
                    tf.dynamic_partition(squared_difference, self.mask, 2)[1]
                )
            else:
                pred_action = tf.concat(
                    [
                        tf.layers.dense(
                            hidden, self.act_size[i], activation=tf.nn.softmax
                        )
                        for i in range(len(self.act_size))
                    ],
                    axis=1,
                )
                cross_entropy = tf.reduce_sum(
                    -tf.log(pred_action + 1e-10) * self.current_action, axis=1
                )
                self.inverse_loss = tf.reduce_mean(
                    tf.dynamic_partition(cross_entropy, self.mask, 2)[1]
                )

    def _create_cc_actor(
        self,
        encoded: tf.Tensor,
        h_size: int,
        num_layers: int,
        tanh_squash: bool = False,
        reparameterize: bool = False,
        condition_sigma_on_obs: bool = True,
        separate_train: bool = False,
    ) -> None:
        """
        Creates Continuous control actor-critic model.
        :param h_size: Size of hidden linear layers.
        :param num_layers: Number of hidden linear layers.
        :param tanh_squash: Whether to use a tanh function, or a clipped output.
        :param reparameterize: Whether we are using the resampling trick to update the policy.
        """
        if self.use_recurrent:
            self.memory_in = tf.placeholder(
                shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
            )
            hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
                encoded, self.memory_in, self.sequence_length_ph, name="lstm_policy"
            )
            self.memory_out = tf.identity(memory_policy_out, name="recurrent_out")
        else:
            hidden_policy = encoded

        if separate_train:
            hidden_policy = tf.stop_gradient(hidden_policy)

        with tf.variable_scope("policy"):
            hidden_policy = ModelUtils.create_vector_observation_encoder(
                hidden_policy,
                h_size,
                ModelUtils.swish,
                num_layers,
                scope="main_graph",
                reuse=False,
            )
            distribution = GaussianDistribution(
                hidden_policy,
                self.act_size,
                reparameterize=reparameterize,
                tanh_squash=tanh_squash,
                condition_sigma=condition_sigma_on_obs,
            )

        if tanh_squash:
            self.output_pre = distribution.sample
            self.output = tf.identity(self.output_pre, name="action")
        else:
            self.output_pre = distribution.sample
            # Clip and scale output to ensure actions are always within [-1, 1] range.
            output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3
            self.output = tf.identity(output_post, name="action")

        self.selected_actions = tf.stop_gradient(self.output)
        self.all_log_probs = tf.identity(distribution.log_probs, name="action_probs")
        self.entropy = distribution.entropy
        # We keep these tensors the same name, but use new nodes to keep code parallelism with discrete control.
        self.total_log_probs = distribution.total_log_probs

    def _create_dc_actor(
        self,
        encoded: tf.Tensor,
        h_size: int,
        num_layers: int,
        separate_train: bool = False,
    ) -> None:
        """
        Creates Discrete control actor-critic model.
        :param h_size: Size of hidden linear layers.
        :param num_layers: Number of hidden linear layers.
        """
        if self.use_recurrent:
            self.prev_action = tf.placeholder(
                shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action"
            )
            prev_action_oh = tf.concat(
                [
                    tf.one_hot(self.prev_action[:, i], self.act_size[i])
                    for i in range(len(self.act_size))
                ],
                axis=1,
            )
            hidden_policy = tf.concat([encoded, prev_action_oh], axis=1)

            self.memory_in = tf.placeholder(
                shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
            )
            hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
                hidden_policy,
                self.memory_in,
                self.sequence_length_ph,
                name="lstm_policy",
            )
            self.memory_out = tf.identity(memory_policy_out, "recurrent_out")
        else:
            hidden_policy = encoded

        if separate_train:
            hidden_policy = tf.stop_gradient(hidden_policy)

        self.action_masks = tf.placeholder(
            shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks"
        )

        with tf.variable_scope("policy"):
            hidden_policy = ModelUtils.create_vector_observation_encoder(
                hidden_policy,
                h_size,
                ModelUtils.swish,
                num_layers,
                scope="main_graph",
                reuse=False,
            )
            distribution = MultiCategoricalDistribution(
                hidden_policy, self.act_size, self.action_masks
            )

        # It's important that we are able to feed_dict a value into this tensor to get the
        # right one-hot encoding, so we can't do identity on it.
        self.output = distribution.sample
        self.all_log_probs = tf.identity(distribution.log_probs, name="action")
        self.selected_actions = tf.stop_gradient(
            distribution.sample_onehot
        )  # In discrete, these are onehot
        self.entropy = distribution.entropy
        self.total_log_probs = distribution.total_log_probs

    def save_model(self, steps):
        """
        Saves the model
        :param steps: The number of steps the model was trained for
        :return:
        """
        self.get_policy_weights()
        with self.graph.as_default():
            last_checkpoint = os.path.join(self.model_path, f"model-{steps}.ckpt")
            self.saver.save(self.sess, last_checkpoint)
            tf.train.write_graph(
                self.graph, self.model_path, "raw_graph_def.pb", as_text=False
            )

            # save each net separately
            policy_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "policy")
            policy_saver = tf.train.Saver(policy_vars)
            policy_checkpoint = os.path.join(self.model_path, "policy.ckpt")
            policy_saver.save(self.sess, policy_checkpoint)

            encoding_vars = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, "encoding"
            )
            encoding_saver = tf.train.Saver(encoding_vars)
            encoding_checkpoint = os.path.join(self.model_path, "encoding.ckpt")
            encoding_saver.save(self.sess, encoding_checkpoint)

            latent_vars = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, "encoding/latent"
            )
            latent_saver = tf.train.Saver(latent_vars)
            latent_checkpoint = os.path.join(self.model_path, "latent.ckpt")
            latent_saver.save(self.sess, latent_checkpoint)

            predict_vars = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, "predict"
            )
            predict_saver = tf.train.Saver(predict_vars)
            predict_checkpoint = os.path.join(self.model_path, "predict.ckpt")
            predict_saver.save(self.sess, predict_checkpoint)

            value_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "value")
            value_saver = tf.train.Saver(value_vars)
            value_checkpoint = os.path.join(self.model_path, "value.ckpt")
            value_saver.save(self.sess, value_checkpoint)

            if self.inverse_model:
                inverse_vars = tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, "inverse"
                )
                inverse_saver = tf.train.Saver(inverse_vars)
                inverse_checkpoint = os.path.join(self.model_path, "inverse.ckpt")
                inverse_saver.save(self.sess, inverse_checkpoint)
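
    # Besides the full checkpoint, save_model writes one checkpoint per
    # sub-network (policy.ckpt, encoding.ckpt, latent.ckpt, predict.ckpt,
    # value.ckpt, and inverse.ckpt when the inverse model is enabled);
    # load_graph_partial restores from exactly these per-net files.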

    def get_encoder_weights(self):
        """
        Debug helper: prints the bias of the encoder and target-encoder latent
        layers so the two can be compared.
        """
        with self.graph.as_default():
            enc = tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES, "encoding/latent/bias:0"
            )
            targ = tf.get_collection(
                tf.GraphKeys.GLOBAL_VARIABLES, "target_enc/latent/bias:0"
            )
            print("encoding:", self.sess.run(enc))
            print("target:", self.sess.run(targ))

    def get_policy_weights(self):
        """
        Debug helper: prints the bias of the policy's mu layer.
        """
        with self.graph.as_default():
            pol = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "policy/mu/bias:0")
            print("policy:", self.sess.run(pol))

    def create_encoders(
        self, var_latent: bool = False, reuse_encoder: bool = False
    ) -> Tuple[tf.Tensor, ...]:
        encoded_state_list = []
        encoded_next_state_list = []
        if reuse_encoder:
            next_encoder_scope = "encoding"
        else:
            next_encoder_scope = "target_enc"

        if self.vis_obs_size > 0:
            self.next_visual_in = []
            visual_encoders = []
            next_visual_encoders = []
            for i in range(self.vis_obs_size):
                # Create input ops for next (t+1) visual observations.
                next_visual_input = ModelUtils.create_visual_input(
                    self.brain.camera_resolutions[i],
                    name="next_visual_observation_" + str(i),
                )
                self.next_visual_in.append(next_visual_input)

                # Create the encoder ops for current and next visual input.
                # Note that these encoders are siamese.
                with tf.variable_scope("encoding"):
                    encoded_visual = ModelUtils.create_visual_observation_encoder(
                        self.visual_in[i],
                        self.h_size,
                        ModelUtils.swish,
                        self.num_layers,
                        "stream_{}_visual_obs_encoder".format(i),
                        False,
                    )
                with tf.variable_scope(next_encoder_scope):
                    encoded_next_visual = ModelUtils.create_visual_observation_encoder(
                        self.next_visual_in[i],
                        self.h_size,
                        ModelUtils.swish,
                        self.num_layers,
                        "stream_{}_visual_obs_encoder".format(i),
                        reuse_encoder,
                    )
                visual_encoders.append(encoded_visual)
                next_visual_encoders.append(encoded_next_visual)
            hidden_visual = tf.concat(visual_encoders, axis=1)
            hidden_next_visual = tf.concat(next_visual_encoders, axis=1)
            encoded_state_list.append(hidden_visual)
            encoded_next_state_list.append(hidden_next_visual)

        if self.vec_obs_size > 0:
            # Create the encoder ops for current and next vector input.
            # Note that these encoders are siamese.
            # Create input op for next (t+1) vector observation.
            self.next_vector_in = tf.placeholder(
                shape=[None, self.vec_obs_size],
                dtype=tf.float32,
                name="next_vector_observation",
            )
            if self.normalize:
                self.processed_vector_next = ModelUtils.normalize_vector_obs(
                    self.next_vector_in,
                    self.running_mean,
                    self.running_variance,
                    self.normalization_steps,
                )
            else:
                self.processed_vector_next = self.next_vector_in

            with tf.variable_scope("encoding"):
                encoded_vector_obs = ModelUtils.create_vector_observation_encoder(
                    self.vector_in,
                    self.h_size,
                    ModelUtils.swish,
                    self.num_layers,
                    "vector_obs_encoder",
                    False,
                )
            with tf.variable_scope(next_encoder_scope):
                encoded_next_vector_obs = ModelUtils.create_vector_observation_encoder(
                    self.processed_vector_next,
                    self.h_size,
                    ModelUtils.swish,
                    self.num_layers,
                    "vector_obs_encoder",
                    reuse_encoder,
                )
            encoded_state_list.append(encoded_vector_obs)
            encoded_next_state_list.append(encoded_next_vector_obs)

        encoded_state = tf.concat(encoded_state_list, axis=1)
        encoded_next_state = tf.concat(encoded_next_state_list, axis=1)

        if var_latent:
            with tf.variable_scope("encoding/latent"):
                encoded_state_dist = GaussianEncoderDistribution(
                    encoded_state, self.feature_size
                )
                encoded_state = encoded_state_dist.sample()
            with tf.variable_scope(next_encoder_scope + "/latent"):
                encoded_next_state_dist = GaussianEncoderDistribution(
                    encoded_next_state, self.feature_size, reuse=reuse_encoder
                )
                encoded_next_state = encoded_next_state_dist.sample()
            return (
                encoded_state,
                encoded_next_state,
                encoded_state_dist,
                encoded_next_state_dist,
            )
        else:
            with tf.variable_scope("encoding"):
                encoded_state = tf.layers.dense(
                    encoded_state, self.feature_size, name="latent"
                )
            with tf.variable_scope(next_encoder_scope):
                encoded_next_state = tf.layers.dense(
                    encoded_next_state,
                    self.feature_size,
                    name="latent",
                    reuse=reuse_encoder,
                )
            return encoded_state, encoded_next_state

    def create_inverse_model(
        self,
        encoded_state: tf.Tensor,
        encoded_next_state: tf.Tensor,
        inverse_layers: int,
    ) -> None:
        """
        Creates inverse model TensorFlow ops for Curiosity module.
        Predicts action taken given current and future encoded states.
        :param encoded_state: Tensor corresponding to encoded current state.
        :param encoded_next_state: Tensor corresponding to encoded next state.
        """
        combined_input = tf.concat([encoded_state, encoded_next_state], axis=1)
        # hidden = tf.layers.dense(combined_input, 256, activation=ModelUtils.swish)
        hidden = combined_input
        for i in range(inverse_layers - 1):
            hidden = tf.layers.dense(
                hidden,
                self.h_size,
                activation=ModelUtils.swish,
                name="hidden_{}".format(i),
                kernel_initializer=tf.initializers.variance_scaling(1.0),
            )
        if self.brain.vector_action_space_type == "continuous":
            pred_action = tf.layers.dense(
                hidden, self.act_size[0], activation=None, name="pred_action"
            )
            squared_difference = tf.reduce_sum(
                tf.squared_difference(pred_action, self.current_action), axis=1
            )
            self.inverse_loss = tf.reduce_mean(
                tf.dynamic_partition(squared_difference, self.mask, 2)[1]
            )
        else:
            pred_action = tf.concat(
                [
                    tf.layers.dense(
                        hidden,
                        self.act_size[i],
                        activation=tf.nn.softmax,
                        # A unique name per action branch; reusing one name for
                        # every branch would collide in the variable scope.
                        name="pred_action_{}".format(i),
                    )
                    for i in range(len(self.act_size))
                ],
                axis=1,
            )
            cross_entropy = tf.reduce_sum(
                -tf.log(pred_action + 1e-10) * self.current_action, axis=1
            )
            self.inverse_loss = tf.reduce_mean(
                tf.dynamic_partition(cross_entropy, self.mask, 2)[1]
            )

    def create_forward_model(
        self,
        encoded_state: tf.Tensor,
        encoded_next_state: tf.Tensor,
        forward_layers: int,
        var_predict: bool = False,
    ) -> None:
        """
        Creates forward model TensorFlow ops for Curiosity module.
        Predicts encoded future state based on encoded current state and given action.
        :param encoded_state: Tensor corresponding to encoded current state.
        :param encoded_next_state: Tensor corresponding to encoded next state.
        """
        combined_input = tf.concat([encoded_state, self.current_action], axis=1)
        hidden = combined_input
        for i in range(forward_layers):
            hidden = tf.layers.dense(
                hidden,
                self.h_size,
                # * (self.vis_obs_size + int(self.vec_obs_size > 0)),
                name="hidden_{}".format(i),
                # activation=ModelUtils.swish,
                # kernel_initializer=tf.initializers.variance_scaling(1.0),
            )
        if var_predict:
            self.predict_distribution = GaussianEncoderDistribution(
                hidden, self.feature_size
            )
            self.predict = self.predict_distribution.sample()
        else:
            self.predict = tf.layers.dense(
                hidden,
                self.feature_size,
                name="latent",
                # activation=ModelUtils.swish,
                # kernel_initializer=tf.initializers.variance_scaling(1.0),
            )
        squared_difference = 0.5 * tf.reduce_sum(
            tf.squared_difference(self.predict, encoded_next_state), axis=1
        )
        self.forward_loss = tf.reduce_mean(
            tf.dynamic_partition(squared_difference, self.mask, 2)[1]
        )
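
    # In the loss above, tf.dynamic_partition(x, self.mask, 2)[1] keeps only
    # the entries whose mask value is 1, so the mean is taken over valid
    # (non-padded) steps; the inverse loss applies the same masking pattern.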

    def create_reward_model(
        self,
        encoded_state: tf.Tensor,
        encoded_next_state: tf.Tensor,
        forward_layers: int,
    ) -> None:
        """
        Creates ops that predict the immediate reward from the encoded state
        and the current action, trained against self.current_reward.
        """
        combined_input = tf.concat([encoded_state, self.current_action], axis=1)
        hidden = combined_input
        for i in range(forward_layers):
            hidden = tf.layers.dense(
                hidden,
                self.h_size * (self.vis_obs_size + int(self.vec_obs_size > 0)),
                name="hidden_{}".format(i),
                # activation=ModelUtils.swish,
                # kernel_initializer=tf.initializers.variance_scaling(1.0),
            )
        self.pred_reward = tf.layers.dense(
            hidden,
            1,
            name="reward",
            # activation=ModelUtils.swish,
            # kernel_initializer=tf.initializers.variance_scaling(1.0),
        )
        self.reward_loss = tf.reduce_mean(
            tf.squared_difference(self.pred_reward, self.current_reward)
        )
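
# ---------------------------------------------------------------------------
# Example: building a TransferPolicy by hand (a sketch under assumed inputs;
# `brain` and `trainer_params` come from the surrounding trainer code, and
# the keyword values below are illustrative, not defaults of this module):
#
#     policy = TransferPolicy(
#         seed=0,
#         brain=brain,
#         trainer_params=trainer_params,
#         is_training=True,
#         model_path="./results/transfer_run",
#         load=False,
#         create_tf_graph=False,  # defer so transfer options can be passed
#     )
#     policy.create_tf_graph(
#         encoder_layers=2,
#         policy_layers=2,
#         feature_size=16,
#         var_encoder=True,
#         predict_return=True,
#     )
# ---------------------------------------------------------------------------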