
Add ResNet and distributions

/develop/add-fire
Arthur Juliani, 5 years ago
Current commit: dc50162d
4 changed files, with 941 additions and 118 deletions
1. ml-agents/mlagents/trainers/models_torch.py (213)
2. ml-agents/mlagents/trainers/distributions_torch.py (48)
3. ml-agents/mlagents/trainers/policy/nn_torch_policy.py (339)
4. ml-agents/mlagents/trainers/policy/torch_policy.py (459)

213  ml-agents/mlagents/trainers/models_torch.py


RESNET = "resnet"
class ActionType(Enum):
DISCRETE = "discrete"
CONTINUOUS = "continuous"
class LearningRateSchedule(Enum):
CONSTANT = "constant"
LINEAR = "linear"

running_variance: torch.Tensor
class ObservationNormalizer(nn.Module):
def __init__(self, vec_obs):
vec_size = vec_obs.shape[1]
self.steps = torch.Tensor([0])
self.running_mean = torch.Tensor(vec_size)
self.running_variance = torch.Tensor(vec_size)
class Normalizer(nn.Module):
def __init__(self, vec_obs_size, **kwargs):
super(Normalizer, self).__init__(**kwargs)
print(vec_obs_size)
self.normalization_steps = torch.tensor(1)
self.running_mean = torch.zeros(vec_obs_size)
self.running_variance = torch.ones(vec_obs_size)
def normalize_obs(self, vector_obs):
normalized_obs = torch.clip_by_value(
(vector_obs - self.running_mean)
self.running_variance / (self.steps.astype(torch.float32) + 1)
return normalized_obs
def forward(self, inputs):
inputs = torch.from_numpy(inputs)
normalized_state = torch.clamp(
(inputs - self.running_mean)
self.running_variance / self.normalization_steps.type(torch.float32)
return normalized_state
def update(self, vector_obs):
# Based on Welford's algorithm for running mean and standard deviation, for batch updates. Discussion here:
# https://stackoverflow.com/questions/56402955/whats-the-formula-for-welfords-algorithm-for-variance-std-with-batch-updates
steps_increment = vector_obs.shape[0]
total_new_steps = self.steps + steps_increment
# Compute the incremental update and divide by the number of new steps.
input_to_old_mean = vector_obs - self.running_mean
new_mean = self.running_mean + torch.sum(
input_to_old_mean / total_new_steps.astype(torch.float32), dim=0
)
# Compute difference of input to the new mean for Welford update
input_to_new_mean = vector_obs - new_mean
new_variance = self.running_variance + torch.sum(
input_to_new_mean * input_to_old_mean, dim=0
)
self.steps = total_new_steps
def update(self, vector_input):
vector_input = torch.from_numpy(vector_input)
mean_current_observation = vector_input.mean(0).type(torch.float32)
new_mean = self.running_mean + (
mean_current_observation - self.running_mean
) / (self.normalization_steps + 1).type(torch.float32)
new_variance = self.running_variance + (mean_current_observation - new_mean) * (
mean_current_observation - self.running_mean
)
self.normalization_steps = self.normalization_steps + 1
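For reference, a self-contained sketch of the batched Welford update the comment above describes, keeping the statistics in registered buffers so they are saved with the module; the class and attribute names here are illustrative, not the module's API:

import torch
from torch import nn

class RunningNormalizer(nn.Module):
    """Running mean/variance with Welford-style batch updates (illustrative sketch)."""

    def __init__(self, size: int):
        super().__init__()
        # Buffers are serialized with the module but not treated as trainable parameters.
        self.register_buffer("steps", torch.zeros(1))
        self.register_buffer("running_mean", torch.zeros(size))
        self.register_buffer("running_sum_sq_diff", torch.ones(size))

    def update(self, batch: torch.Tensor) -> None:
        n = batch.shape[0]
        total = self.steps + n
        delta_old = batch - self.running_mean             # x_i - old mean
        new_mean = self.running_mean + delta_old.sum(0) / total
        delta_new = batch - new_mean                      # x_i - new mean
        # M2 accumulates sum((x - old mean) * (x - new mean)); variance = M2 / steps.
        self.running_sum_sq_diff += (delta_old * delta_new).sum(0)
        self.running_mean = new_mean
        self.steps = total

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = self.running_sum_sq_diff / torch.clamp(self.steps, min=1.0)
        return torch.clamp((x - self.running_mean) / torch.sqrt(var + 1e-8), -5.0, 5.0)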
super(ValueHeads, self).__init__()
self.stream_names = stream_names
self.value_heads = {}

value_outputs = {}
for stream_name, _ in self.value_heads.items():
value_outputs[stream_name] = self.value_heads[stream_name](hidden)
return value_outputs, torch.mean(list(value_outputs), dim=0)
return value_outputs, torch.mean(torch.stack(list(value_outputs.values())), dim=0)
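Note that a plain dict of heads does not register their parameters with the parent module; a minimal sketch of the same idea using nn.ModuleDict (illustrative, not the commit's code):

import torch
from torch import nn
from typing import Dict, List, Tuple

class ValueHeadsSketch(nn.Module):
    """One scalar value head per reward stream, registered via nn.ModuleDict."""

    def __init__(self, stream_names: List[str], hidden_size: int):
        super().__init__()
        self.value_heads = nn.ModuleDict(
            {name: nn.Linear(hidden_size, 1) for name in stream_names}
        )

    def forward(self, hidden: torch.Tensor) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
        outputs = {name: head(hidden) for name, head in self.value_heads.items()}
        # Mean across streams, e.g. for reporting a single scalar value estimate.
        mean_value = torch.mean(torch.stack(list(outputs.values())), dim=0)
        return outputs, mean_value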
class VectorEncoder(nn.Module):

class SimpleVisualEncoder(nn.Module):
def __init__(self, visual_obs):
image_depth = visual_obs.shape[-1]
self.conv1 = nn.Conv2d(image_depth, 16, [8, 8], [4, 4])
def __init__(self, initial_channels):
super(SimpleVisualEncoder, self).__init__()
self.conv1 = nn.Conv2d(initial_channels, 16, [8, 8], [4, 4])
# Todo: add vector encoder here?
conv_1 = torch.relu(self.conv1(visual_obs))
conv_2 = torch.relu(self.conv2(conv_1))
return torch.flatten(conv_2)

def __init__(self, visual_obs):
image_depth = visual_obs.shape[-1]
self.conv1 = nn.Conv2d(image_depth, 32, [8, 8], [4, 4])
def __init__(self, initial_channels):
super(NatureVisualEncoder, self).__init__()
self.conv1 = nn.Conv2d(initial_channels, 32, [8, 8], [4, 4])
# Todo: add vector encoder here?
conv_1 = torch.relu(self.conv1(visual_obs))
conv_2 = torch.relu(self.conv2(conv_1))
conv_3 = torch.relu(self.conv3(conv_2))
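As an aside, torch.flatten without start_dim also collapses the batch dimension; below is a sketch of a comparable two-conv encoder that keeps the batch dimension and projects to a hidden size (the constructor arguments, layer sizes, and class name are assumptions for illustration):

import torch
from torch import nn

class SimpleVisualEncoderSketch(nn.Module):
    """Two-conv visual encoder in the spirit of the encoders above (illustrative)."""

    def __init__(self, height: int, width: int, initial_channels: int, h_size: int):
        super().__init__()
        self.conv1 = nn.Conv2d(initial_channels, 16, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=4, stride=2)
        # Infer the flattened conv output size once with a dummy forward pass.
        with torch.no_grad():
            dummy = torch.zeros(1, initial_channels, height, width)
            flat_size = self.conv2(self.conv1(dummy)).flatten(start_dim=1).shape[1]
        self.dense = nn.Linear(flat_size, h_size)

    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        hidden = torch.relu(self.conv1(visual_obs))
        hidden = torch.relu(self.conv2(hidden))
        hidden = hidden.flatten(start_dim=1)  # keep the batch dimension
        return torch.relu(self.dense(hidden))

# e.g. SimpleVisualEncoderSketch(84, 84, 3, 128)(torch.zeros(2, 3, 84, 84)).shape -> (2, 128)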

class DiscreteActionMask(nn.Module):
def __init__(self, action_size):
super(DiscreteActionMask, self).__init__()
@staticmethod
self, concatenated_logits: torch.Tensor, action_size: List[int]
concatenated_logits: torch.Tensor, action_size: List[int]
) -> List[torch.Tensor]:
"""
Takes a concatenated set of logits that represent multiple discrete action branches

)
class GlobalSteps(nn.Module):
def __init__(self):
super(GlobalSteps, self).__init__()
self.global_step = torch.Tensor([0])
def increment(self, value):
self.global_step += value
class LearningRate(nn.Module):
def __init__(self, lr):
# Todo: add learning rate decay
super(LearningRate, self).__init__()
self.learning_rate = torch.Tensor([lr])
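The Todo above, together with the LearningRateSchedule enum at the top of the file, suggests a decay step is still to come; a minimal sketch of a linear schedule (a hypothetical helper, not the commit's code):

def linear_lr(initial_lr: float, current_step: int, max_steps: int, min_lr: float = 1e-10) -> float:
    """Linearly anneal the learning rate from initial_lr towards zero over max_steps."""
    remaining = max(max_steps - current_step, 0) / float(max_steps)
    return max(initial_lr * remaining, min_lr)

# e.g. linear_lr(3e-4, current_step=500_000, max_steps=1_000_000) -> 1.5e-4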
class ResNetVisualEncoder(nn.Module):
def __init__(self, initial_channels):
super(ResNetVisualEncoder, self).__init__()
n_channels = [16, 32, 32] # channel for each stack
n_blocks = 2 # number of residual blocks
self.layers = []
last_channel = initial_channels
for _, channel in enumerate(n_channels):
self.layers.append(nn.Conv2d(last_channel, channel, [3, 3], [1, 1]))
self.layers.append(nn.MaxPool2d([3, 3], [2, 2]))
for _ in range(n_blocks):
self.layers.append(self.make_block(channel))
last_channel = channel
self.layers.append(nn.ReLU())
@staticmethod
def make_block(channel):
block_layers = [
nn.ReLU(),
nn.Conv2d(channel, channel, [3, 3], [1, 1]),
nn.ReLU(),
nn.Conv2d(channel, channel, [3, 3], [1, 1]),
]
return block_layers
@staticmethod
def forward_block(input_hidden, block_layers):
hidden = input_hidden
for layer in block_layers:
hidden = layer(hidden)
return hidden + input_hidden
def forward(self, visual_obs):
hidden = visual_obs
for layer in self.layers:
if isinstance(layer, nn.Module):
hidden = layer(hidden)
elif isinstance(layer, list):
hidden = self.forward_block(hidden, layer)
return hidden.flatten()
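Because self.layers is a plain Python list and make_block returns a list, none of the convolutions above are registered as sub-modules, so an optimizer built from parameters() would not see them. A sketch of the same architecture using nn.ModuleList and nn.Sequential (class names here are illustrative):

import torch
from torch import nn

class ResidualBlockSketch(nn.Module):
    """ReLU -> conv -> ReLU -> conv, added back onto the block input."""

    def __init__(self, channel: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.block(x)

class ResNetEncoderSketch(nn.Module):
    """IMPALA-style conv/pool stacks with residual blocks, all registered as sub-modules."""

    def __init__(self, initial_channels: int):
        super().__init__()
        n_channels = [16, 32, 32]  # channels per stack
        n_blocks = 2               # residual blocks per stack
        layers, last_channel = [], initial_channels
        for channel in n_channels:
            layers.append(nn.Conv2d(last_channel, channel, kernel_size=3, stride=1, padding=1))
            layers.append(nn.MaxPool2d(kernel_size=3, stride=2))
            layers.extend(ResidualBlockSketch(channel) for _ in range(n_blocks))
            last_channel = channel
        layers.append(nn.ReLU())
        self.layers = nn.ModuleList(layers)

    def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
        hidden = visual_obs
        for layer in self.layers:
            hidden = layer(hidden)
        return hidden.flatten(start_dim=1)  # keep the batch dimension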
class ModelUtils:
# Minimum supported side for each encoder type. If refactoring an encoder, please
# adjust these also.

EncoderType.RESNET: 15,
}
class GlobalSteps(nn.Module):
def __init__(self):
self.global_step = torch.Tensor([0])
def increment(self, value):
self.global_step += value
class LearningRate(nn.Module):
def __init__(self, lr):
# Todo: add learning rate decay
self.learning_rate = torch.Tensor([lr])
# @staticmethod
# def scaled_init(scale):
# return tf.initializers.variance_scaling(scale)

"""Swish activation function. For more info: https://arxiv.org/abs/1710.05941"""
return torch.mul(input_activation, torch.sigmoid(input_activation))
# @staticmethod
# def create_resnet_visual_observation_encoder(
# image_input: tf.Tensor,
# h_size: int,
# activation: ActivationFunction,
# num_layers: int,
# scope: str,
# reuse: bool,
# ) -> tf.Tensor:
# """
# Builds a set of resnet visual encoders.
# :param image_input: The placeholder for the image input to use.
# :param h_size: Hidden layer size.
# :param activation: What type of activation function to use for layers.
# :param num_layers: number of hidden layers to create.
# :param scope: The scope of the graph within which to create the ops.
# :param reuse: Whether to re-use the weights within the same scope.
# :return: List of hidden layer tensors.
# """
# n_channels = [16, 32, 32] # channel for each stack
# n_blocks = 2 # number of residual blocks
# with tf.variable_scope(scope):
# hidden = image_input
# for i, ch in enumerate(n_channels):
# hidden = tf.layers.conv2d(
# hidden,
# ch,
# kernel_size=[3, 3],
# strides=[1, 1],
# reuse=reuse,
# name="layer%dconv_1" % i,
# )
# hidden = tf.layers.max_pooling2d(
# hidden, pool_size=[3, 3], strides=[2, 2], padding="same"
# )
# # create residual blocks
# for j in range(n_blocks):
# block_input = hidden
# hidden = tf.nn.relu(hidden)
# hidden = tf.layers.conv2d(
# hidden,
# ch,
# kernel_size=[3, 3],
# strides=[1, 1],
# padding="same",
# reuse=reuse,
# name="layer%d_%d_conv1" % (i, j),
# )
# hidden = tf.nn.relu(hidden)
# hidden = tf.layers.conv2d(
# hidden,
# ch,
# kernel_size=[3, 3],
# strides=[1, 1],
# padding="same",
# reuse=reuse,
# name="layer%d_%d_conv2" % (i, j),
# )
# hidden = tf.add(block_input, hidden)
# hidden = tf.nn.relu(hidden)
# hidden = tf.layers.flatten(hidden)
# with tf.variable_scope(scope + "/" + "flat_encoding"):
# hidden_flat = ModelUtils.create_vector_observation_encoder(
# hidden, h_size, activation, num_layers, scope, reuse
# )
# return hidden_flat
# EncoderType.RESNET: ModelUtils.create_resnet_visual_observation_encoder,
EncoderType.RESNET: ResNetVisualEncoder,
}
return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)

48  ml-agents/mlagents/trainers/distributions_torch.py


import torch
from torch import nn
from torch import distributions
EPSILON = 1e-6 # Small value to avoid divide by zero
class GaussianDistribution(nn.Module):
def __init__(self, hidden_size, num_outputs, **kwargs):
super(GaussianDistribution, self).__init__(**kwargs)
self.mu = nn.Linear(hidden_size, num_outputs)
self.log_sigma_sq = nn.Linear(hidden_size, num_outputs)
nn.init.xavier_uniform_(self.mu.weight, gain=0.01)
nn.init.xavier_uniform_(self.log_sigma_sq.weight, gain=0.01)
def forward(self, inputs):
mu = self.mu(inputs)
log_sig = self.log_sigma_sq(inputs)
return distributions.normal.Normal(loc=mu, scale=torch.sqrt(torch.exp(log_sig)))
class MultiCategoricalDistribution(nn.Module):
def __init__(self, hidden_size, act_sizes):
super(MultiCategoricalDistribution, self).__init__()
self.branches = self.create_policy_branches(hidden_size, act_sizes)
def create_policy_branches(self, hidden_size, act_sizes):
branches = []
for size in act_sizes:
branch_output_layer = nn.Linear(hidden_size, size)
nn.init.xavier_uniform_(branch_output_layer.weight, gain=0.01)
branches.append(branch_output_layer)
return nn.ModuleList(branches)
def mask_branch(self, logits, mask):
raw_probs = torch.softmax(logits, dim=-1) * mask
normalized_probs = raw_probs / torch.sum(raw_probs, dim=-1, keepdim=True)
normalized_logits = torch.log(normalized_probs + EPSILON)
return normalized_logits
def forward(self, inputs, masks):
branch_distributions = []
for idx, branch in enumerate(self.branches):
logits = branch(inputs)
norm_logits = self.mask_branch(logits, masks[idx])
distribution = distributions.categorical.Categorical(logits=norm_logits)
branch_distributions.append(distribution)
return branch_distributions
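A hypothetical usage sketch of the two distribution modules above (batch size, hidden size, and branch sizes are made up for illustration):

import torch
from mlagents.trainers.distributions_torch import (
    GaussianDistribution,
    MultiCategoricalDistribution,
)

hidden = torch.zeros(4, 128)  # a batch of 4 encoded observations

gaussian = GaussianDistribution(hidden_size=128, num_outputs=2)
dist = gaussian(hidden)                # a torch.distributions Normal
action = dist.sample()                 # shape (4, 2)
log_probs = dist.log_prob(action)      # per-dimension log-probabilities
entropy = dist.entropy()

multi = MultiCategoricalDistribution(hidden_size=128, act_sizes=[3, 2])
masks = [torch.ones(4, 3), torch.ones(4, 2)]  # 1.0 = action allowed
branch_dists = multi(hidden, masks)           # one Categorical per branch
branch_actions = [d.sample() for d in branch_dists]
branch_log_probs = [d.log_prob(a) for d, a in zip(branch_dists, branch_actions)]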

339  ml-agents/mlagents/trainers/policy/nn_torch_policy.py


from typing import Any, Dict
import numpy as np
import torch
from mlagents_envs.base_env import DecisionSteps
from torch import nn
from mlagents.tf_utils import tf
from mlagents_envs.timers import timed
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.models import EncoderType
from mlagents.trainers.models_torch import (
ActionType,
VectorEncoder,
SimpleVisualEncoder,
ValueHeads,
Normalizer,
)
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.distributions_torch import (
GaussianDistribution,
MultiCategoricalDistribution,
)
EPSILON = 1e-7 # Small value to avoid divide by zero
class Critic(nn.Module):
def __init__(self, stream_names, hidden_size, encoder, **kwargs):
super(Critic, self).__init__(**kwargs)
self.stream_names = stream_names
self.encoder = encoder
self.value_heads = ValueHeads(stream_names, hidden_size)
def forward(self, inputs):
hidden = self.encoder(inputs)
return self.value_heads(hidden)
class ActorCriticPolicy(nn.Module):
def __init__(
self,
h_size,
vector_sizes,
visual_sizes,
act_size,
normalize,
num_layers,
m_size,
stream_names,
vis_encode_type,
act_type,
use_lstm,
):
super(ActorCriticPolicy, self).__init__()
self.visual_encoders = nn.ModuleList()
self.vector_encoders = nn.ModuleList()
self.vector_normalizers = nn.ModuleList()
self.act_type = act_type
self.use_lstm = use_lstm
self.h_size = h_size
self.normalize = normalize
for vector_size in vector_sizes:
self.vector_normalizers.append(Normalizer(vector_size))
self.vector_encoders.append(VectorEncoder(vector_size, h_size, num_layers))
for visual_size in visual_sizes:
self.visual_encoders.append(SimpleVisualEncoder(visual_size))
if use_lstm:
self.lstm = nn.LSTM(h_size, h_size, 1)
if self.act_type == ActionType.CONTINUOUS:
self.distribution = GaussianDistribution(h_size, act_size)
else:
self.distribution = MultiCategoricalDistribution(h_size, act_size)
self.critic = Critic(
stream_names, h_size, VectorEncoder(vector_sizes[0], h_size, num_layers)
)
self.act_size = act_size
def clear_memory(self, batch_size):
self.memory = (
torch.zeros(1, batch_size, self.h_size),
torch.zeros(1, batch_size, self.h_size),
)
def forward(self, vec_inputs, vis_inputs, masks=None):
vec_embeds = []
for idx, encoder in enumerate(self.vector_encoders):
vec_input = vec_inputs[idx]
if self.normalize:
vec_input = self.vector_normalizers[idx](vec_inputs[idx])
hidden = encoder(vec_input)
vec_embeds.append(hidden)
vis_embeds = []
for idx, encoder in enumerate(self.visual_encoders):
hidden = encoder(vis_inputs[idx])
vis_embeds.append(hidden)
vec_embeds = torch.cat(vec_embeds)
vis_embeds = torch.cat(vis_embeds)
embedding = torch.cat([vec_embeds, vis_embeds])
if self.use_lstm:
embedding, self.memory = self.lstm(embedding, self.memory)
if self.act_type == ActionType.CONTINUOUS:
dist = self.distribution(embedding)
else:
dist = self.distribution(embedding, masks=masks)
return dist
def update_normalization(self, inputs):
if self.normalize:
self.vector_normalizers[0].update(inputs)
def get_values(self, vec_inputs, vis_inputs):
if self.normalize:
vec_inputs = self.vector_normalizers[0](vec_inputs)
return self.critic(vec_inputs)
class NNPolicy(TorchPolicy):
def __init__(
self,
seed: int,
brain: BrainParameters,
trainer_params: Dict[str, Any],
is_training: bool,
load: bool,
tanh_squash: bool = False,
reparameterize: bool = False,
condition_sigma_on_obs: bool = True,
create_tf_graph: bool = True,
):
"""
Policy that uses a multilayer perceptron to map the observations to actions. Could
also use a CNN to encode visual input prior to the MLP. Supports discrete and
continuous action spaces, as well as recurrent networks.
:param seed: Random seed.
:param brain: Assigned BrainParameters object.
:param trainer_params: Defined training parameters.
:param is_training: Whether the model should be trained.
:param load: Whether a pre-trained model will be loaded or a new one created.
:param tanh_squash: Whether to use a tanh function on the continuous output, or a clipped output.
:param reparameterize: Whether we are using the resampling trick to update the policy in continuous output.
"""
super().__init__(seed, brain, trainer_params, load)
self.grads = None
num_layers = trainer_params["num_layers"]
self.h_size = trainer_params["hidden_units"]
if num_layers < 1:
num_layers = 1
self.num_layers = num_layers
self.vis_encode_type = EncoderType(
trainer_params.get("vis_encode_type", "simple")
)
self.tanh_squash = tanh_squash
self.reparameterize = reparameterize
self.condition_sigma_on_obs = condition_sigma_on_obs
# Non-exposed parameters; these aren't exposed because they don't have a
# good explanation and usually shouldn't be touched.
self.log_std_min = -20
self.log_std_max = 2
self.inference_dict: Dict[str, tf.Tensor] = {}
self.update_dict: Dict[str, tf.Tensor] = {}
# TF defaults to 32-bit, so we use the same here.
torch.set_default_tensor_type(torch.FloatTensor)
reward_signal_configs = trainer_params["reward_signals"]
self.stats_name_to_update_name = {
"Losses/Value Loss": "value_loss",
"Losses/Policy Loss": "policy_loss",
}
self.model = ActorCriticPolicy(
h_size=int(trainer_params["hidden_units"]),
act_type=ActionType.CONTINUOUS,
vector_sizes=[brain.vector_observation_space_size],
act_size=sum(brain.vector_action_space_size),
normalize=trainer_params["normalize"],
num_layers=int(trainer_params["num_layers"]),
m_size=trainer_params["memory_size"],
use_lstm=self.use_recurrent,
visual_sizes=brain.camera_resolutions,
stream_names=list(reward_signal_configs.keys()),
vis_encode_type=EncoderType(
trainer_params.get("vis_encode_type", "simple")
),
)
def split_decision_step(self, decision_requests):
vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)
mask = None
if not self.use_continuous_act:
mask = np.ones(
(len(decision_requests), np.sum(self.brain.vector_action_space_size)),
dtype=np.float32,
)
if decision_requests.action_mask is not None:
mask = 1 - np.concatenate(decision_requests.action_mask, axis=1)
return vec_vis_obs.vector_observations, vec_vis_obs.visual_observations, mask
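To make the flip above concrete, a small hypothetical example with two discrete branches of sizes 3 and 2 (True in DecisionSteps.action_mask marks a forbidden action):

import numpy as np

action_mask = [
    np.array([[False, True, False]]),  # branch 0: middle action forbidden
    np.array([[False, False]]),        # branch 1: everything allowed
]
mask = 1 - np.concatenate(action_mask, axis=1)
print(mask)  # [[1 0 1 1 1]] -> 1 means the action may be sampled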
def execute_model(self, vec_obs, vis_obs, masks):
action_dist = self.model(vec_obs, vis_obs, masks)
action = action_dist.sample()
log_probs = action_dist.log_prob(action)
entropy = action_dist.entropy()
value_heads = self.model.get_values(vec_obs, vis_obs)
return action, log_probs, entropy, value_heads
@timed
def evaluate(self, decision_requests: DecisionSteps) -> Dict[str, Any]:
"""
Evaluates policy for the agent experiences provided.
:param decision_step: DecisionStep object containing inputs.
:return: Outputs from network as defined by self.inference_dict.
"""
vec_obs, vis_obs, masks = self.split_decision_step(decision_requests)
run_out = {}
action, log_probs, entropy, value_heads = self.execute_model(
vec_obs, vis_obs, masks
)
run_out["action"] = np.array(action.detach())
run_out["log_probs"] = np.array(log_probs.detach())
run_out["entropy"] = np.array(entropy.detach())
run_out["value_heads"] = {
name: np.array(t.detach()) for name, t in value_heads.items()
}
run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0)
run_out["learning_rate"] = 0.0
self.model.update_normalization(vec_obs)
return run_out
# def _create_cc_actor(
# self,
# encoded: tf.Tensor,
# tanh_squash: bool = False,
# reparameterize: bool = False,
# condition_sigma_on_obs: bool = True,
# ) -> None:
# """
# Creates Continuous control actor-critic model.
# :param h_size: Size of hidden linear layers.
# :param num_layers: Number of hidden linear layers.
# :param vis_encode_type: Type of visual encoder to use if visual input.
# :param tanh_squash: Whether to use a tanh function, or a clipped output.
# :param reparameterize: Whether we are using the resampling trick to update the policy.
# """
# if self.use_recurrent:
# self.memory_in = tf.placeholder(
# shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
# )
# hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
# encoded, self.memory_in, self.sequence_length_ph, name="lstm_policy"
# )
#
# self.memory_out = tf.identity(memory_policy_out, name="recurrent_out")
# else:
# hidden_policy = encoded
#
# with tf.variable_scope("policy"):
# distribution = GaussianDistribution(
# hidden_policy,
# self.act_size,
# reparameterize=reparameterize,
# tanh_squash=tanh_squash,
# condition_sigma=condition_sigma_on_obs,
# )
#
# if tanh_squash:
# self.output_pre = distribution.sample
# self.output = tf.identity(self.output_pre, name="action")
# else:
# self.output_pre = distribution.sample
# # Clip and scale output to ensure actions are always within [-1, 1] range.
# output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3
# self.output = tf.identity(output_post, name="action")
#
# self.selected_actions = tf.stop_gradient(self.output)
#
# self.all_log_probs = tf.identity(distribution.log_probs, name="action_probs")
# self.entropy = distribution.entropy
#
# # We keep these tensors the same name, but use new nodes to keep code parallelism with discrete control.
# self.total_log_probs = distribution.total_log_probs
#
# def _create_dc_actor(self, encoded: tf.Tensor) -> None:
# """
# Creates Discrete control actor-critic model.
# :param h_size: Size of hidden linear layers.
# :param num_layers: Number of hidden linear layers.
# :param vis_encode_type: Type of visual encoder to use if visual input.
# """
# if self.use_recurrent:
# self.prev_action = tf.placeholder(
# shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action"
# )
# prev_action_oh = tf.concat(
# [
# tf.one_hot(self.prev_action[:, i], self.act_size[i])
# for i in range(len(self.act_size))
# ],
# axis=1,
# )
# hidden_policy = tf.concat([encoded, prev_action_oh], axis=1)
#
# self.memory_in = tf.placeholder(
# shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
# )
# hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
# hidden_policy,
# self.memory_in,
# self.sequence_length_ph,
# name="lstm_policy",
# )
#
# self.memory_out = tf.identity(memory_policy_out, "recurrent_out")
# else:
# hidden_policy = encoded
#
# self.action_masks = tf.placeholder(
# shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks"
# )
#
# with tf.variable_scope("policy"):
# distribution = MultiCategoricalDistribution(
# hidden_policy, self.act_size, self.action_masks
# )
# # It's important that we are able to feed_dict a value into this tensor to get the
# # right one-hot encoding, so we can't do identity on it.
# self.output = distribution.sample
# self.all_log_probs = tf.identity(distribution.log_probs, name="action")
# self.selected_actions = tf.stop_gradient(
# distribution.sample_onehot
# ) # In discrete, these are onehot
# self.entropy = distribution.entropy
# self.total_log_probs = distribution.total_log_probs

459  ml-agents/mlagents/trainers/policy/torch_policy.py


from typing import Any, Dict, List, Optional
import numpy as np
from mlagents import tf_utils
from mlagents.tf_utils import tf
from mlagents_envs.exception import UnityException
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.policy import Policy
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.brain_conversion_utils import get_global_agent_id
from mlagents_envs.base_env import DecisionSteps
from mlagents.trainers.models import ModelUtils
logger = get_logger(__name__)
class UnityPolicyException(UnityException):
"""
Related to errors with the Trainer.
"""
pass
class TorchPolicy(Policy):
"""
Contains a learning model, and the necessary
functions to save/load models and create the input placeholders.
"""
def __init__(self, seed, brain, trainer_parameters, load=False):
"""
Initializes the policy.
:param seed: Random seed to use for TensorFlow.
:param brain: The corresponding Brain for this policy.
:param trainer_parameters: The trainer parameters.
"""
self._version_number_ = 2
self.m_size = 0
# for ghost trainer save/load snapshots
self.assign_phs = []
self.assign_ops = []
self.inference_dict = {}
self.update_dict = {}
self.sequence_length = 1
self.seed = seed
self.brain = brain
self.act_size = brain.vector_action_space_size
self.vec_obs_size = brain.vector_observation_space_size
self.vis_obs_size = brain.number_visual_observations
self.use_recurrent = trainer_parameters["use_recurrent"]
self.memory_dict: Dict[str, np.ndarray] = {}
self.num_branches = len(self.brain.vector_action_space_size)
self.previous_action_dict: Dict[str, np.array] = {}
self.normalize = trainer_parameters.get("normalize", False)
self.use_continuous_act = brain.vector_action_space_type == "continuous"
if self.use_continuous_act:
self.num_branches = self.brain.vector_action_space_size[0]
self.model_path = trainer_parameters["model_path"]
self.initialize_path = trainer_parameters.get("init_path", None)
self.keep_checkpoints = trainer_parameters.get("keep_checkpoints", 5)
self.graph = tf.Graph()
self.sess = tf.Session(
config=tf_utils.generate_session_config(), graph=self.graph
)
self.saver = None
self.seed = seed
if self.use_recurrent:
self.m_size = trainer_parameters["memory_size"]
self.sequence_length = trainer_parameters["sequence_length"]
if self.m_size == 0:
raise UnityPolicyException(
"The memory size for brain {0} is 0 even "
"though the trainer uses recurrent.".format(brain.brain_name)
)
elif self.m_size % 2 != 0:
raise UnityPolicyException(
"The memory size for brain {0} is {1} "
"but it must be divisible by 2.".format(
brain.brain_name, self.m_size
)
)
self._initialize_tensorflow_references()
self.load = load
def _initialize_graph(self):
with self.graph.as_default():
self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
init = tf.global_variables_initializer()
self.sess.run(init)
def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None:
with self.graph.as_default():
self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
logger.info(
"Loading model for brain {} from {}.".format(
self.brain.brain_name, model_path
)
)
ckpt = tf.train.get_checkpoint_state(model_path)
if ckpt is None:
raise UnityPolicyException(
"The model {0} could not be loaded. Make "
"sure you specified the right "
"--run-id and that the previous run you are loading from had the same "
"behavior names.".format(model_path)
)
try:
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
except tf.errors.NotFoundError:
raise UnityPolicyException(
"The model {0} was found but could not be loaded. Make "
"sure the model is from the same version of ML-Agents, has the same behavior parameters, "
"and is using the same trainer configuration as the current run.".format(
model_path
)
)
if reset_global_steps:
logger.info(
"Starting training from step 0 and saving to {}.".format(
self.model_path
)
)
else:
logger.info(
"Resuming training from step {}.".format(self.get_current_step())
)
def initialize_or_load(self):
# If there is an initialize path, load from that. Else, load from the set model path.
# If load is set to True, don't reset steps to 0. Else, do. This allows a user to,
# e.g., resume from an initialize path.
reset_steps = not self.load
if self.initialize_path is not None:
self._load_graph(self.initialize_path, reset_global_steps=reset_steps)
elif self.load:
self._load_graph(self.model_path, reset_global_steps=reset_steps)
else:
self._initialize_graph()
def get_weights(self):
with self.graph.as_default():
_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
values = [v.eval(session=self.sess) for v in _vars]
return values
def init_load_weights(self):
with self.graph.as_default():
_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
values = [v.eval(session=self.sess) for v in _vars]
for var, value in zip(_vars, values):
assign_ph = tf.placeholder(var.dtype, shape=value.shape)
self.assign_phs.append(assign_ph)
self.assign_ops.append(tf.assign(var, assign_ph))
def load_weights(self, values):
if len(self.assign_ops) == 0:
logger.warning(
"Calling load_weights in tf_policy but assign_ops is empty. Did you forget to call init_load_weights?"
)
with self.graph.as_default():
feed_dict = {}
for assign_ph, value in zip(self.assign_phs, values):
feed_dict[assign_ph] = value
self.sess.run(self.assign_ops, feed_dict=feed_dict)
def evaluate(self, decision_requests: DecisionSteps) -> Dict[str, Any]:
"""
Evaluates policy for the agent experiences provided.
:param decision_requests: DecisionSteps input to network.
:return: Output from policy based on self.inference_dict.
"""
raise UnityPolicyException("The evaluate function was not implemented.")
def get_action(
self, decision_requests: DecisionSteps, worker_id: int = 0
) -> ActionInfo:
"""
Decides actions given observations information, and takes them in environment.
:param decision_requests: A dictionary of brain names and DecisionSteps from environment.
:param worker_id: In parallel environment training, the unique id of the environment worker that
the DecisionSteps came from. Used to construct a globally unique id for each agent.
:return: an ActionInfo containing action, memories, values and an object
to be passed to add experiences
"""
if len(decision_requests) == 0:
return ActionInfo.empty()
global_agent_ids = [
get_global_agent_id(worker_id, int(agent_id))
for agent_id in decision_requests.agent_id
] # For 1-D array, the iterator order is correct.
run_out = self.evaluate( # pylint: disable=assignment-from-no-return
decision_requests
)
self.save_memories(global_agent_ids, run_out.get("memory_out"))
return ActionInfo(
action=run_out.get("action"),
value=run_out.get("value"),
outputs=run_out,
agent_ids=decision_requests.agent_id,
)
def update(self, mini_batch, num_sequences):
"""
Performs update of the policy.
:param num_sequences: Number of experience trajectories in batch.
:param mini_batch: Batch of experiences.
:return: Results of update.
"""
raise UnityPolicyException("The update function was not implemented.")
def _execute_model(self, feed_dict, out_dict):
"""
Executes model.
:param feed_dict: Input dictionary mapping nodes to input data.
:param out_dict: Output dictionary mapping names to nodes.
:return: Dictionary mapping names to input data.
"""
network_out = self.sess.run(list(out_dict.values()), feed_dict=feed_dict)
run_out = dict(zip(list(out_dict.keys()), network_out))
return run_out
def fill_eval_dict(self, batched_step_result):
vec_vis_obs = SplitObservations.from_observations(batched_step_result.obs)
mask = None
if not self.use_continuous_act:
mask = np.ones(
(len(batched_step_result), np.sum(self.brain.vector_action_space_size)),
dtype=np.float32,
)
if batched_step_result.action_mask is not None:
mask = 1 - np.concatenate(batched_step_result.action_mask, axis=1)
return vec_vis_obs.vector_observations, vec_vis_obs.visual_observations, mask
def make_empty_memory(self, num_agents):
"""
Creates empty memory for use with RNNs
:param num_agents: Number of agents.
:return: Numpy array of zeros.
"""
return np.zeros((num_agents, self.m_size), dtype=np.float32)
def save_memories(
self, agent_ids: List[str], memory_matrix: Optional[np.ndarray]
) -> None:
if memory_matrix is None:
return
for index, agent_id in enumerate(agent_ids):
self.memory_dict[agent_id] = memory_matrix[index, :]
def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray:
memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
for index, agent_id in enumerate(agent_ids):
if agent_id in self.memory_dict:
memory_matrix[index, :] = self.memory_dict[agent_id]
return memory_matrix
def remove_memories(self, agent_ids):
for agent_id in agent_ids:
if agent_id in self.memory_dict:
self.memory_dict.pop(agent_id)
def make_empty_previous_action(self, num_agents):
"""
Creates empty previous action for use with RNNs and discrete control
:param num_agents: Number of agents.
:return: Numpy array of zeros.
"""
return np.zeros((num_agents, self.num_branches), dtype=np.int)
def save_previous_action(
self, agent_ids: List[str], action_matrix: Optional[np.ndarray]
) -> None:
if action_matrix is None:
return
for index, agent_id in enumerate(agent_ids):
self.previous_action_dict[agent_id] = action_matrix[index, :]
def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray:
action_matrix = np.zeros((len(agent_ids), self.num_branches), dtype=np.int)
for index, agent_id in enumerate(agent_ids):
if agent_id in self.previous_action_dict:
action_matrix[index, :] = self.previous_action_dict[agent_id]
return action_matrix
def remove_previous_action(self, agent_ids):
for agent_id in agent_ids:
if agent_id in self.previous_action_dict:
self.previous_action_dict.pop(agent_id)
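A standalone sketch of the retrieve/step/save cycle these helpers support for recurrent policies (agent ids, sizes, and the fake network step are made up):

import numpy as np
from typing import Dict

m_size = 4
memory_dict: Dict[str, np.ndarray] = {}
agent_ids = ["agent-0", "agent-1"]

# Retrieve: unknown agents start from zero memories, as in retrieve_memories above.
memory_in = np.stack(
    [memory_dict.get(a, np.zeros(m_size, dtype=np.float32)) for a in agent_ids]
)

# ... run the recurrent policy; here we fake the new memory state ...
memory_out = memory_in + 1.0

# Save: store each agent's row, as in save_memories above.
for index, agent_id in enumerate(agent_ids):
    memory_dict[agent_id] = memory_out[index, :]

# Drop an agent whose episode ended, as in remove_memories above.
memory_dict.pop("agent-0", None)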
def get_current_step(self):
"""
Gets current model step.
:return: current model step.
"""
step = self.sess.run(self.global_step)
return step
def _set_step(self, step: int) -> int:
"""
Sets current model step to step without creating additional ops.
:param step: Step to set the current model step to.
:return: The step the model was set to.
"""
current_step = self.get_current_step()
# Increment a positive or negative number of steps.
return self.increment_step(step - current_step)
def increment_step(self, n_steps):
"""
Increments model step.
"""
out_dict = {
"global_step": self.global_step,
"increment_step": self.increment_step_op,
}
feed_dict = {self.steps_to_increment: n_steps}
return self.sess.run(out_dict, feed_dict=feed_dict)["global_step"]
def get_inference_vars(self):
"""
:return: list of inference var names
"""
return list(self.inference_dict.keys())
def get_update_vars(self):
"""
:return: list of update var names
"""
return list(self.update_dict.keys())
def save_model(self, steps):
"""
Saves the model
:param steps: The number of steps the model was trained for
:return:
"""
with self.graph.as_default():
last_checkpoint = self.model_path + "/model-" + str(steps) + ".ckpt"
self.saver.save(self.sess, last_checkpoint)
tf.train.write_graph(
self.graph, self.model_path, "raw_graph_def.pb", as_text=False
)
def update_normalization(self, vector_obs: np.ndarray) -> None:
"""
If this policy normalizes vector observations, this will update the norm values in the graph.
:param vector_obs: The vector observations to add to the running estimate of the distribution.
"""
if self.use_vec_obs and self.normalize:
self.sess.run(
self.update_normalization_op, feed_dict={self.vector_in: vector_obs}
)
@property
def use_vis_obs(self):
return self.vis_obs_size > 0
@property
def use_vec_obs(self):
return self.vec_obs_size > 0
def _initialize_tensorflow_references(self):
self.value_heads: Dict[str, tf.Tensor] = {}
self.normalization_steps: Optional[tf.Variable] = None
self.running_mean: Optional[tf.Variable] = None
self.running_variance: Optional[tf.Variable] = None
self.update_normalization_op: Optional[tf.Operation] = None
self.value: Optional[tf.Tensor] = None
self.all_log_probs: tf.Tensor = None
self.total_log_probs: Optional[tf.Tensor] = None
self.entropy: Optional[tf.Tensor] = None
self.output_pre: Optional[tf.Tensor] = None
self.output: Optional[tf.Tensor] = None
self.selected_actions: tf.Tensor = None
self.action_masks: Optional[tf.Tensor] = None
self.prev_action: Optional[tf.Tensor] = None
self.memory_in: Optional[tf.Tensor] = None
self.memory_out: Optional[tf.Tensor] = None
def create_input_placeholders(self):
with self.graph.as_default():
(
self.global_step,
self.increment_step_op,
self.steps_to_increment,
) = ModelUtils.create_global_steps()
self.visual_in = ModelUtils.create_visual_input_placeholders(
self.brain.camera_resolutions
)
self.vector_in = ModelUtils.create_vector_input(self.vec_obs_size)
if self.normalize:
normalization_tensors = ModelUtils.create_normalizer(self.vector_in)
self.update_normalization_op = normalization_tensors.update_op
self.normalization_steps = normalization_tensors.steps
self.running_mean = normalization_tensors.running_mean
self.running_variance = normalization_tensors.running_variance
self.processed_vector_in = ModelUtils.normalize_vector_obs(
self.vector_in,
self.running_mean,
self.running_variance,
self.normalization_steps,
)
else:
self.processed_vector_in = self.vector_in
self.update_normalization_op = None
self.batch_size_ph = tf.placeholder(
shape=None, dtype=tf.int32, name="batch_size"
)
self.sequence_length_ph = tf.placeholder(
shape=None, dtype=tf.int32, name="sequence_length"
)
self.mask_input = tf.placeholder(
shape=[None], dtype=tf.float32, name="masks"
)
# Only needed for PPO, but needed for BC module
self.epsilon = tf.placeholder(
shape=[None, self.act_size[0]], dtype=tf.float32, name="epsilon"
)
self.mask = tf.cast(self.mask_input, tf.int32)
tf.Variable(
int(self.brain.vector_action_space_type == "continuous"),
name="is_continuous_control",
trainable=False,
dtype=tf.int32,
)
tf.Variable(
self._version_number_,
name="version_number",
trainable=False,
dtype=tf.int32,
)
tf.Variable(
self.m_size, name="memory_size", trainable=False, dtype=tf.int32
)
if self.brain.vector_action_space_type == "continuous":
tf.Variable(
self.act_size[0],
name="action_output_shape",
trainable=False,
dtype=tf.int32,
)
else:
tf.Variable(
sum(self.act_size),
name="action_output_shape",
trainable=False,
dtype=tf.int32,
)