
Dynamically construct actor and critic

/develop/add-fire
Arthur Juliani, 4 years ago
Current commit
e166d018
3 files changed, with 106 insertions and 356 deletions
  1. ml-agents/mlagents/trainers/models_torch.py (102 changes)
  2. ml-agents/mlagents/trainers/policy/nn_torch_policy.py (250 changes)
  3. ml-agents/mlagents/trainers/policy/torch_policy.py (110 changes)

ml-agents/mlagents/trainers/models_torch.py (102 changes)


    def forward(self, branches_logits, action_masks):
        branch_masks = self.break_into_branches(action_masks, self.action_size)
        raw_probs = [
-           torch.multiply(
+           torch.mul(
-           torch.divide(raw_probs[k], torch.sum(raw_probs[k], dim=1, keepdims=True))
+           torch.div(raw_probs[k], torch.sum(raw_probs[k], dim=1, keepdims=True))
            for k in range(len(self.action_size))
        ]
        output = torch.cat(
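For context, the two changed calls above sit inside the masked-distribution forward pass: each branch's probabilities are multiplied by its action mask and then renormalized so the remaining actions form a valid distribution. A minimal, self-contained sketch of that step (illustrative only; `masked_branch_probs` is a hypothetical helper, not the ml-agents API):

```python
import torch

EPSILON = 1e-7  # small value to avoid dividing by zero when a branch is heavily masked

def masked_branch_probs(branch_logits: torch.Tensor, branch_mask: torch.Tensor) -> torch.Tensor:
    # branch_logits, branch_mask: [batch, branch_size]; mask is 1 for allowed actions, 0 otherwise
    raw = torch.mul(torch.softmax(branch_logits, dim=-1), branch_mask) + EPSILON
    # Renormalize so the unmasked actions sum to one per batch entry.
    return torch.div(raw, torch.sum(raw, dim=-1, keepdim=True))
```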

        self.layers.append(nn.MaxPool2d([3, 3], [2, 2]))
        for _ in range(n_blocks):
            self.layers.append(self.make_block(channel))
-       self.layers.append(nn.RELU())
+       self.layers.append(nn.ReLU())

    @staticmethod
    def make_block(channel):

EncoderType.RESNET: 15,
}
# @staticmethod
# def scaled_init(scale):
# return tf.initializers.variance_scaling(scale)
@staticmethod
def swish(input_activation: torch.Tensor) -> torch.Tensor:
"""Swish activation function. For more info: https://arxiv.org/abs/1710.05941"""

f"Visual observation resolution ({width}x{height}) is too small for"
f"the provided EncoderType ({vis_encoder_type.value}). The min dimension is {min_res}"
)
# @staticmethod
# def compose_streams(
# visual_in: List[torch.Tensor],
# vector_in: torch.Tensor,
# num_streams: int,
# h_size: int,
# num_layers: int,
# vis_encode_type: EncoderType = EncoderType.SIMPLE,
# stream_scopes: List[str] = None,
# ) -> List[torch.Tensor]:
# """
# Creates encoding stream for observations.
# :param num_streams: Number of streams to create.
# :param h_size: Size of hidden linear layers in stream.
# :param num_layers: Number of hidden linear layers in stream.
# :param stream_scopes: List of strings (length == num_streams), which contains
# the scopes for each of the streams. None if all under the same TF scope.
# :return: List of encoded streams.
# """
# activation_fn = ModelUtils.swish
# vector_observation_input = vector_in
# final_hiddens = []
# for i in range(num_streams):
# # Pick the encoder function based on the EncoderType
# create_encoder_func = ModelUtils.get_encoder_for_type(vis_encode_type)
# visual_encoders = []
# hidden_state, hidden_visual = None, None
# _scope_add = stream_scopes[i] if stream_scopes else ""
# if len(visual_in) > 0:
# for j, vis_in in enumerate(visual_in):
# ModelUtils._check_resolution_for_encoder(vis_in, vis_encode_type)
# encoded_visual = create_encoder_func(
# vis_in,
# h_size,
# activation_fn,
# num_layers,
# f"{_scope_add}main_graph_{i}_encoder{j}", # scope
# False, # reuse
# )
# visual_encoders.append(encoded_visual)
# hidden_visual = torch.cat(visual_encoders, axis=1)
# if vector_in.get_shape()[-1] > 0: # Don't encode 0-shape inputs
# hidden_state = ModelUtils.create_vector_observation_encoder(
# vector_observation_input,
# h_size,
# activation_fn,
# num_layers,
# scope=f"{_scope_add}main_graph_{i}",
# reuse=False,
# )
# if hidden_state is not None and hidden_visual is not None:
# final_hidden = torch.cat([hidden_visual, hidden_state], axis=1)
# elif hidden_state is None and hidden_visual is not None:
# final_hidden = hidden_visual
# elif hidden_state is not None and hidden_visual is None:
# final_hidden = hidden_state
# else:
# raise Exception(
# "No valid network configuration possible. "
# "There are no states or observations in this brain"
# )
# final_hiddens.append(final_hidden)
# return final_hiddens
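The commented-out TF helper above built one observation-encoding stream per output head by concatenating visual and vector encodings. In the torch rewrite this role is taken over by `NetworkBody` in `nn_torch_policy.py` (see below); a rough standalone sketch of the same composition idea, with a hypothetical `ObservationEncoder` standing in for the ml-agents modules:

```python
import torch
from torch import nn

class ObservationEncoder(nn.Module):
    """Concatenate the encodings of all visual and vector observations into one embedding."""

    def __init__(self, visual_encoders: nn.ModuleList, vector_encoders: nn.ModuleList):
        super().__init__()
        self.visual_encoders = visual_encoders  # one encoder per camera
        self.vector_encoders = vector_encoders  # one encoder per vector observation space

    def forward(self, vis_inputs, vec_inputs) -> torch.Tensor:
        embeds = [enc(x) for enc, x in zip(self.visual_encoders, vis_inputs)]
        embeds += [enc(x) for enc, x in zip(self.vector_encoders, vec_inputs)]
        if not embeds:
            raise ValueError("No valid network configuration: no observations to encode")
        return torch.cat(embeds, dim=1)
```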
# @staticmethod
# def create_recurrent_encoder(input_state, memory_in, sequence_length, name="lstm"):
# """
# Builds a recurrent encoder for either state or observations (LSTM).
# :param sequence_length: Length of sequence to unroll.
# :param input_state: The input tensor to the LSTM cell.
# :param memory_in: The input memory to the LSTM cell.
# :param name: The scope of the LSTM cell.
# """
# s_size = input_state.get_shape().as_list()[1]
# m_size = memory_in.get_shape().as_list()[1]
# lstm_input_state = tf.reshape(input_state, shape=[-1, sequence_length, s_size])
# memory_in = tf.reshape(memory_in[:, :], [-1, m_size])
# half_point = int(m_size / 2)
# with tf.variable_scope(name):
# rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(half_point)
# lstm_vector_in = tf.nn.rnn_cell.LSTMStateTuple(
# memory_in[:, :half_point], memory_in[:, half_point:]
# )
# recurrent_output, lstm_state_out = tf.nn.dynamic_rnn(
# rnn_cell, lstm_input_state, initial_state=lstm_vector_in
# )
# recurrent_output = tf.reshape(recurrent_output, shape=[-1, half_point])
# return recurrent_output, tf.concat([lstm_state_out.c, lstm_state_out.h], axis=1)
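The removed TF encoder packs the LSTM state into a single flat memory tensor and splits it into halves on the way in. A comparable torch sketch, assuming an `nn.LSTM` created with `batch_first=True` whose `hidden_size` matches half the memory width, and the same pack-(h, c)-into-one-vector convention (hypothetical helper, not the ml-agents API):

```python
import torch
from torch import nn

def recurrent_encode(lstm: nn.LSTM, input_state: torch.Tensor,
                     memory_in: torch.Tensor, sequence_length: int):
    # input_state: [batch * sequence_length, features]; memory_in: [batch, 2 * hidden_size]
    hidden_size = memory_in.shape[1] // 2
    x = input_state.reshape(-1, sequence_length, input_state.shape[-1])
    h0 = memory_in[:, :hidden_size].unsqueeze(0).contiguous()  # [1, batch, hidden_size]
    c0 = memory_in[:, hidden_size:].unsqueeze(0).contiguous()
    output, (h_n, c_n) = lstm(x, (h0, c0))
    memory_out = torch.cat([h_n.squeeze(0), c_n.squeeze(0)], dim=1)
    return output.reshape(-1, hidden_size), memory_out
```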

ml-agents/mlagents/trainers/policy/nn_torch_policy.py (250 changes)


from mlagents.trainers.models_torch import (
    ActionType,
    VectorEncoder,
    SimpleVisualEncoder,
    ModelUtils,
)
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.distributions_torch import (

EPSILON = 1e-7 # Small value to avoid divide by zero
class Critic(nn.Module):
    def __init__(self, stream_names, hidden_size, encoder, **kwargs):
        super(Critic, self).__init__(**kwargs)
        self.stream_names = stream_names
        self.encoder = encoder
        self.value_heads = ValueHeads(stream_names, hidden_size)

    def forward(self, inputs):
        hidden = self.encoder(inputs)
        return self.value_heads(hidden)
-class ActorCriticPolicy(nn.Module):
+class NetworkBody(nn.Module):
        h_size,
        act_size,
        h_size,
        stream_names,
        act_type,
-       super(ActorCriticPolicy, self).__init__()
+       super(NetworkBody, self).__init__()
        self.normalize = normalize
        self.act_type = act_type
        self.m_size = m_size
        visual_encoder = ModelUtils.get_encoder_for_type(vis_encode_type)
-       self.visual_encoders.append(SimpleVisualEncoder(visual_size))
+       self.visual_encoders.append(visual_encoder(visual_size))
        if self.act_type == ActionType.CONTINUOUS:
            self.distribution = GaussianDistribution(h_size, act_size)
        else:
            self.distribution = MultiCategoricalDistribution(h_size, act_size)
        self.critic = Critic(
            stream_names, h_size, VectorEncoder(vector_sizes[0], h_size, num_layers)
        )
        self.act_size = act_size
-           torch.zeros(1, batch_size, self.h_size),
-           torch.zeros(1, batch_size, self.h_size),
+           torch.zeros(1, batch_size, self.m_size),
+           torch.zeros(1, batch_size, self.m_size),
-   def forward(self, vec_inputs, vis_inputs, masks=None):
    def update_normalization(self, inputs):
        if self.normalize:
            self.normalizer.update(inputs)

+   def forward(self, vec_inputs, vis_inputs):
        vec_embeds = []
        for idx, encoder in enumerate(self.vector_encoders):
            vec_input = vec_inputs[idx]

        embedding = torch.cat([vec_embeds, vis_embeds])
        if self.use_lstm:
            embedding, self.memory = self.lstm(embedding, self.memory)
        return embedding
class Actor(nn.Module):
    def __init__(
        self,
        h_size,
        vector_sizes,
        visual_sizes,
        act_size,
        normalize,
        num_layers,
        m_size,
        vis_encode_type,
        act_type,
        use_lstm,
    ):
        super(Actor, self).__init__()
        self.act_type = act_type
        self.act_size = act_size
        self.network_body = NetworkBody(
            vector_sizes,
            visual_sizes,
            h_size,
            normalize,
            num_layers,
            m_size,
            vis_encode_type,
            use_lstm,
        )
        if self.act_type == ActionType.CONTINUOUS:
            self.distribution = GaussianDistribution(h_size, act_size)
        else:
            self.distribution = MultiCategoricalDistribution(h_size, act_size)

    def forward(self, vec_inputs, vis_inputs, masks=None):
        embedding = self.network_body(vec_inputs, vis_inputs)
        if self.act_type == ActionType.CONTINUOUS:
            dist = self.distribution(embedding)
        else:

    def update_normalization(self, inputs):
        if self.normalize:
            self.normalizer.update(inputs)

    def get_values(self, vec_inputs, vis_inputs):
        if self.normalize:
            vec_inputs = self.normalizer(vec_inputs)
        return self.critic(vec_inputs)
class Critic(nn.Module):
    def __init__(
        self,
        stream_names,
        h_size,
        vector_sizes,
        visual_sizes,
        normalize,
        num_layers,
        m_size,
        vis_encode_type,
        use_lstm,
    ):
        super(Critic, self).__init__()
        self.stream_names = stream_names
        self.network_body = NetworkBody(
            vector_sizes,
            visual_sizes,
            h_size,
            normalize,
            num_layers,
            m_size,
            vis_encode_type,
            use_lstm,
        )
        self.value_heads = ValueHeads(stream_names, h_size)

    def forward(self, vec_inputs, vis_inputs):
        embedding = self.network_body(vec_inputs, vis_inputs)
        return self.value_heads(embedding)
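With the split above, the actor and critic are now two independent modules that each own a `NetworkBody`. A hypothetical construction sketch follows; all argument values are placeholders, `vec_inputs`/`vis_inputs` are assumed to be prepared elsewhere, and `EncoderType.SIMPLE`/`ActionType.DISCRETE` are assumed enum members, not taken from this diff:

```python
# Placeholder sizes; a real policy derives these from BrainParameters / trainer_params.
actor = Actor(
    h_size=128, vector_sizes=[8], visual_sizes=[], act_size=[3, 2],
    normalize=True, num_layers=2, m_size=0,
    vis_encode_type=EncoderType.SIMPLE, act_type=ActionType.DISCRETE, use_lstm=False,
)
critic = Critic(
    stream_names=["extrinsic"], h_size=128, vector_sizes=[8], visual_sizes=[],
    normalize=True, num_layers=2, m_size=0,
    vis_encode_type=EncoderType.SIMPLE, use_lstm=False,
)
dist = actor(vec_inputs, vis_inputs)          # action distribution(s)
value_heads = critic(vec_inputs, vis_inputs)  # one value estimate per reward stream
```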
class NNPolicy(TorchPolicy):

        brain: BrainParameters,
        trainer_params: Dict[str, Any],
        is_training: bool,
        create_tf_graph: bool = True,
    ):
        """
        Policy that uses a multilayer perceptron to map the observations to actions. Could

        :param brain: Assigned BrainParameters object.
        :param trainer_params: Defined training parameters.
        :param is_training: Whether the model should be trained.
        :param load: Whether a pre-trained model will be loaded or a new one created.
        :param tanh_squash: Whether to use a tanh function on the continuous output, or a clipped output.
        :param reparameterize: Whether we are using the resampling trick to update the policy in continuous output.

"Losses/Policy Loss": "policy_loss",
}
self.model = ActorCriticPolicy(
self.model = Actor(
normalize=trainer_params["normalize"],
num_layers=int(trainer_params["num_layers"]),
m_size=trainer_params["memory_size"],
use_lstm=self.use_recurrent,
visual_sizes=brain.camera_resolutions,
vis_encode_type=EncoderType(
trainer_params.get("vis_encode_type", "simple")
),
)
self.critic = Critic(
h_size=int(trainer_params["hidden_units"]),
vector_sizes=[brain.vector_observation_space_size],
normalize=trainer_params["normalize"],
num_layers=int(trainer_params["num_layers"]),
m_size=trainer_params["memory_size"],

    def evaluate(self, decision_requests: DecisionSteps) -> Dict[str, Any]:
        """
        Evaluates policy for the agent experiences provided.
-       :param decision_step: DecisionStep object containing inputs.
+       :param decision_requests: DecisionStep object containing inputs.
        :return: Outputs from network as defined by self.inference_dict.
        """
        vec_obs, vis_obs, masks = self.split_decision_step(decision_requests)

        }
        run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0)
        run_out["learning_rate"] = 0.0
-       self.model.update_normalization(decision_requests.vec_obs)
+       self.model.update_normalization(vec_obs)
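Piecing the `evaluate` fragments together: the policy splits the `DecisionSteps` into vector/visual observations and masks, runs the networks without gradients, averages the per-reward-signal value heads into a single `value`, and updates the observation normalizer. A hedged outline, not the exact ml-agents implementation (the distribution and value-head handling here are assumptions):

```python
import numpy as np
import torch

def evaluate_outline(policy, decision_requests):
    vec_obs, vis_obs, masks = policy.split_decision_step(decision_requests)
    with torch.no_grad():  # inference only
        dist = policy.model(vec_obs, vis_obs, masks=masks)  # Actor forward
        value_heads = policy.critic(vec_obs, vis_obs)       # Critic forward
    run_out = {
        "action": dist.sample().numpy(),
        "value_heads": {name: t.numpy() for name, t in value_heads.items()},
    }
    # Collapse the per-reward-signal heads into one scalar value estimate.
    run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0)
    run_out["learning_rate"] = 0.0
    policy.model.update_normalization(vec_obs)
    return run_out
```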
# def _create_cc_actor(
# self,
# encoded: tf.Tensor,
# tanh_squash: bool = False,
# reparameterize: bool = False,
# condition_sigma_on_obs: bool = True,
# ) -> None:
# """
# Creates Continuous control actor-critic model.
# :param h_size: Size of hidden linear layers.
# :param num_layers: Number of hidden linear layers.
# :param vis_encode_type: Type of visual encoder to use if visual input.
# :param tanh_squash: Whether to use a tanh function, or a clipped output.
# :param reparameterize: Whether we are using the resampling trick to update the policy.
# """
# if self.use_recurrent:
# self.memory_in = tf.placeholder(
# shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
# )
# hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
# encoded, self.memory_in, self.sequence_length_ph, name="lstm_policy"
# )
#
# self.memory_out = tf.identity(memory_policy_out, name="recurrent_out")
# else:
# hidden_policy = encoded
#
# with tf.variable_scope("policy"):
# distribution = GaussianDistribution(
# hidden_policy,
# self.act_size,
# reparameterize=reparameterize,
# tanh_squash=tanh_squash,
# condition_sigma=condition_sigma_on_obs,
# )
#
# if tanh_squash:
# self.output_pre = distribution.sample
# self.output = tf.identity(self.output_pre, name="action")
# else:
# self.output_pre = distribution.sample
# # Clip and scale output to ensure actions are always within [-1, 1] range.
# output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3
# self.output = tf.identity(output_post, name="action")
#
# self.selected_actions = tf.stop_gradient(self.output)
#
# self.all_log_probs = tf.identity(distribution.log_probs, name="action_probs")
# self.entropy = distribution.entropy
#
# # We keep these tensors the same name, but use new nodes to keep code parallelism with discrete control.
# self.total_log_probs = distribution.total_log_probs
#
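The removed TF code above contrasts two ways of keeping continuous actions in [-1, 1]: let a tanh-squashed distribution produce bounded samples directly, or sample an unbounded Gaussian and clip/scale the result. A small illustrative torch sketch of the two options (in the TF version the tanh lives inside the distribution itself; here it is written out explicitly):

```python
import torch

def bound_continuous_action(raw_sample: torch.Tensor, tanh_squash: bool) -> torch.Tensor:
    if tanh_squash:
        # Squash smoothly into [-1, 1]; normally done inside the squashed Gaussian distribution.
        return torch.tanh(raw_sample)
    # Mirror of tf.clip_by_value(self.output_pre, -3, 3) / 3 above.
    return torch.clamp(raw_sample, -3.0, 3.0) / 3.0
```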
# def _create_dc_actor(self, encoded: tf.Tensor) -> None:
# """
# Creates Discrete control actor-critic model.
# :param h_size: Size of hidden linear layers.
# :param num_layers: Number of hidden linear layers.
# :param vis_encode_type: Type of visual encoder to use if visual input.
# """
# if self.use_recurrent:
# self.prev_action = tf.placeholder(
# shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action"
# )
# prev_action_oh = tf.concat(
# [
# tf.one_hot(self.prev_action[:, i], self.act_size[i])
# for i in range(len(self.act_size))
# ],
# axis=1,
# )
# hidden_policy = tf.concat([encoded, prev_action_oh], axis=1)
#
# self.memory_in = tf.placeholder(
# shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
# )
# hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
# hidden_policy,
# self.memory_in,
# self.sequence_length_ph,
# name="lstm_policy",
# )
#
# self.memory_out = tf.identity(memory_policy_out, "recurrent_out")
# else:
# hidden_policy = encoded
#
# self.action_masks = tf.placeholder(
# shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks"
# )
#
# with tf.variable_scope("policy"):
# distribution = MultiCategoricalDistribution(
# hidden_policy, self.act_size, self.action_masks
# )
# # It's important that we are able to feed_dict a value into this tensor to get the
# # right one-hot encoding, so we can't do identity on it.
# self.output = distribution.sample
# self.all_log_probs = tf.identity(distribution.log_probs, name="action")
# self.selected_actions = tf.stop_gradient(
# distribution.sample_onehot
# ) # In discrete, these are onehot
# self.entropy = distribution.entropy
# self.total_log_probs = distribution.total_log_probs

ml-agents/mlagents/trainers/policy/torch_policy.py (110 changes)


from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.brain_conversion_utils import get_global_agent_id
from mlagents_envs.base_env import DecisionSteps
from mlagents.trainers.models import ModelUtils
logger = get_logger(__name__)

self.inference_dict = {}
self.update_dict = {}
self.sequence_length = 1
self.global_step = 0
self.seed = seed
self.brain = brain

brain.brain_name, self.m_size
)
)
self._initialize_tensorflow_references()
self.load = load
def _initialize_graph(self):

        Gets current model step.
        :return: current model step.
        """
-       step = self.sess.run(self.global_step)
-       return step
+       return self.global_step

    def _set_step(self, step: int) -> int:
        """

        """
        Increments model step.
        """
-       out_dict = {
-           "global_step": self.global_step,
-           "increment_step": self.increment_step_op,
-       }
-       feed_dict = {self.steps_to_increment: n_steps}
-       return self.sess.run(out_dict, feed_dict=feed_dict)["global_step"]
+       self.global_step += n_steps
+       return self.global_step
def get_inference_vars(self):
"""

        If this policy normalizes vector observations, this will update the norm values in the graph.
        :param vector_obs: The vector observations to add to the running estimate of the distribution.
        """
        if self.use_vec_obs and self.normalize:
            self.sess.run(
                self.update_normalization_op, feed_dict={self.vector_in: vector_obs}
            )
        return None
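In the torch policies this running normalization lives in a module rather than a TF update op. For reference, here is a minimal sketch of such a running mean/variance normalizer (a hypothetical `RunningNormalizer`, not the ml-agents `Normalizer` class; it folds batch statistics into the running estimate using the standard parallel-variance combination):

```python
import torch
from torch import nn

class RunningNormalizer(nn.Module):
    def __init__(self, vec_obs_size: int):
        super().__init__()
        self.register_buffer("steps", torch.zeros(1))
        self.register_buffer("running_mean", torch.zeros(vec_obs_size))
        self.register_buffer("running_var", torch.ones(vec_obs_size))

    def update(self, vector_obs: torch.Tensor) -> None:
        # Fold a batch of observations into the running mean/variance estimate.
        batch_mean = vector_obs.mean(dim=0)
        batch_var = vector_obs.var(dim=0, unbiased=False)
        batch_count = vector_obs.shape[0]
        total = self.steps + batch_count
        delta = batch_mean - self.running_mean
        self.running_mean += delta * batch_count / total
        self.running_var = (
            self.running_var * self.steps + batch_var * batch_count
            + delta.pow(2) * self.steps * batch_count / total
        ) / total
        self.steps = total

    def forward(self, vector_obs: torch.Tensor) -> torch.Tensor:
        # Normalize and clamp, as is common for vector observations.
        return torch.clamp(
            (vector_obs - self.running_mean) / (self.running_var.sqrt() + 1e-7), -5, 5
        )
```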
@property
def use_vis_obs(self):

def use_vec_obs(self):
return self.vec_obs_size > 0
def _initialize_tensorflow_references(self):
self.value_heads: Dict[str, tf.Tensor] = {}
self.normalization_steps: Optional[tf.Variable] = None
self.running_mean: Optional[tf.Variable] = None
self.running_variance: Optional[tf.Variable] = None
self.update_normalization_op: Optional[tf.Operation] = None
self.value: Optional[tf.Tensor] = None
self.all_log_probs: tf.Tensor = None
self.total_log_probs: Optional[tf.Tensor] = None
self.entropy: Optional[tf.Tensor] = None
self.output_pre: Optional[tf.Tensor] = None
self.output: Optional[tf.Tensor] = None
self.selected_actions: tf.Tensor = None
self.action_masks: Optional[tf.Tensor] = None
self.prev_action: Optional[tf.Tensor] = None
self.memory_in: Optional[tf.Tensor] = None
self.memory_out: Optional[tf.Tensor] = None
def create_input_placeholders(self):
with self.graph.as_default():
(
self.global_step,
self.increment_step_op,
self.steps_to_increment,
) = ModelUtils.create_global_steps()
self.visual_in = ModelUtils.create_visual_input_placeholders(
self.brain.camera_resolutions
)
self.vector_in = ModelUtils.create_vector_input(self.vec_obs_size)
if self.normalize:
normalization_tensors = ModelUtils.create_normalizer(self.vector_in)
self.update_normalization_op = normalization_tensors.update_op
self.normalization_steps = normalization_tensors.steps
self.running_mean = normalization_tensors.running_mean
self.running_variance = normalization_tensors.running_variance
self.processed_vector_in = ModelUtils.normalize_vector_obs(
self.vector_in,
self.running_mean,
self.running_variance,
self.normalization_steps,
)
else:
self.processed_vector_in = self.vector_in
self.update_normalization_op = None
self.batch_size_ph = tf.placeholder(
shape=None, dtype=tf.int32, name="batch_size"
)
self.sequence_length_ph = tf.placeholder(
shape=None, dtype=tf.int32, name="sequence_length"
)
self.mask_input = tf.placeholder(
shape=[None], dtype=tf.float32, name="masks"
)
# Only needed for PPO, but needed for BC module
self.epsilon = tf.placeholder(
shape=[None, self.act_size[0]], dtype=tf.float32, name="epsilon"
)
self.mask = tf.cast(self.mask_input, tf.int32)
tf.Variable(
int(self.brain.vector_action_space_type == "continuous"),
name="is_continuous_control",
trainable=False,
dtype=tf.int32,
)
tf.Variable(
self._version_number_,
name="version_number",
trainable=False,
dtype=tf.int32,
)
tf.Variable(
self.m_size, name="memory_size", trainable=False, dtype=tf.int32
)
if self.brain.vector_action_space_type == "continuous":
tf.Variable(
self.act_size[0],
name="action_output_shape",
trainable=False,
dtype=tf.int32,
)
else:
tf.Variable(
sum(self.act_size),
name="action_output_shape",
trainable=False,
dtype=tf.int32,
)