Arthur Juliani
5 years ago
Current commit
dc50162d
A total of 4 files changed, with 941 insertions and 118 deletions
- ml-agents/mlagents/trainers/models_torch.py (213 lines changed)
- ml-agents/mlagents/trainers/distributions_torch.py (48 lines changed)
- ml-agents/mlagents/trainers/policy/nn_torch_policy.py (339 lines changed)
- ml-agents/mlagents/trainers/policy/torch_policy.py (459 lines changed)

ml-agents/mlagents/trainers/distributions_torch.py

import torch
from torch import nn
from torch import distributions

EPSILON = 1e-6  # Small value to avoid divide by zero


class GaussianDistribution(nn.Module):
    def __init__(self, hidden_size, num_outputs, **kwargs):
        super(GaussianDistribution, self).__init__(**kwargs)
        self.mu = nn.Linear(hidden_size, num_outputs)
        self.log_sigma_sq = nn.Linear(hidden_size, num_outputs)
        nn.init.xavier_uniform_(self.mu.weight, gain=0.01)
        nn.init.xavier_uniform_(self.log_sigma_sq.weight, gain=0.01)

    def forward(self, inputs):
        mu = self.mu(inputs)
        log_sig = self.log_sigma_sq(inputs)
        # scale = sqrt(exp(log(sigma^2))) = sigma
        return distributions.normal.Normal(loc=mu, scale=torch.sqrt(torch.exp(log_sig)))


class MultiCategoricalDistribution(nn.Module):
    def __init__(self, hidden_size, act_sizes):
        super(MultiCategoricalDistribution, self).__init__()
        self.branches = self.create_policy_branches(hidden_size, act_sizes)

    def create_policy_branches(self, hidden_size, act_sizes):
        # Use an nn.ModuleList so the branch layers are registered as submodules.
        branches = nn.ModuleList()
        for size in act_sizes:
            branch_output_layer = nn.Linear(hidden_size, size)
            nn.init.xavier_uniform_(branch_output_layer.weight, gain=0.01)
            branches.append(branch_output_layer)
        return branches

    def mask_branch(self, logits, mask):
        # Convert logits to probabilities, zero out masked actions, renormalize,
        # and go back to log space. EPSILON guards against log(0) for masked entries.
        raw_probs = torch.softmax(logits, dim=-1) * mask
        normalized_probs = raw_probs / torch.sum(raw_probs, dim=-1, keepdim=True)
        normalized_logits = torch.log(normalized_probs + EPSILON)
        return normalized_logits

    def forward(self, inputs, masks):
        branch_distributions = []
        for idx, branch in enumerate(self.branches):
            logits = branch(inputs)
            norm_logits = self.mask_branch(logits, masks[idx])
            distribution = distributions.categorical.Categorical(logits=norm_logits)
            branch_distributions.append(distribution)
        return branch_distributions
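
A minimal usage sketch of the two distribution heads above (an editorial illustration, not part of the commit), assuming the module as written above; the batch size, hidden size, and branch sizes are arbitrary values chosen for illustration:

import torch
from mlagents.trainers.distributions_torch import (
    GaussianDistribution,
    MultiCategoricalDistribution,
)

hidden = torch.randn(4, 128)  # a batch of 4 encoded observations with hidden size 128

# Continuous control: one Normal distribution over every action dimension.
gauss = GaussianDistribution(hidden_size=128, num_outputs=2)
dist = gauss(hidden)
action = dist.sample()            # shape (4, 2)
log_prob = dist.log_prob(action)  # shape (4, 2)

# Discrete control: one Categorical distribution per action branch, with masks.
multi = MultiCategoricalDistribution(hidden_size=128, act_sizes=[3, 5])
masks = [torch.ones(4, 3), torch.ones(4, 5)]  # 1 = action allowed
branch_dists = multi(hidden, masks)
branch_actions = [d.sample() for d in branch_dists]  # one shape-(4,) tensor per branch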

ml-agents/mlagents/trainers/policy/nn_torch_policy.py

from typing import Any, Dict
import numpy as np
import torch
from mlagents_envs.base_env import DecisionSteps
from torch import nn
from mlagents.tf_utils import tf
from mlagents_envs.timers import timed
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.models import EncoderType
from mlagents.trainers.models_torch import (
    ActionType,
    VectorEncoder,
    SimpleVisualEncoder,
    ValueHeads,
    Normalizer,
)
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.distributions_torch import (
    GaussianDistribution,
    MultiCategoricalDistribution,
)

EPSILON = 1e-7  # Small value to avoid divide by zero


class Critic(nn.Module):
    def __init__(self, stream_names, hidden_size, encoder, **kwargs):
        super(Critic, self).__init__(**kwargs)
        self.stream_names = stream_names
        self.encoder = encoder
        self.value_heads = ValueHeads(stream_names, hidden_size)

    def forward(self, inputs):
        hidden = self.encoder(inputs)
        return self.value_heads(hidden)


class ActorCriticPolicy(nn.Module):
    def __init__(
        self,
        h_size,
        vector_sizes,
        visual_sizes,
        act_size,
        normalize,
        num_layers,
        m_size,
        stream_names,
        vis_encode_type,
        act_type,
        use_lstm,
    ):
        super(ActorCriticPolicy, self).__init__()
        # ModuleLists are used so the encoder and normalizer parameters are registered.
        self.visual_encoders = nn.ModuleList()
        self.vector_encoders = nn.ModuleList()
        self.vector_normalizers = nn.ModuleList()
        self.act_type = act_type
        self.use_lstm = use_lstm
        self.h_size = h_size
        self.normalize = normalize
        for vector_size in vector_sizes:
            self.vector_normalizers.append(Normalizer(vector_size))
            self.vector_encoders.append(VectorEncoder(vector_size, h_size, num_layers))
        for visual_size in visual_sizes:
            self.visual_encoders.append(SimpleVisualEncoder(visual_size))

        if use_lstm:
            self.lstm = nn.LSTM(h_size, h_size, 1)

        if self.act_type == ActionType.CONTINUOUS:
            self.distribution = GaussianDistribution(h_size, act_size)
        else:
            self.distribution = MultiCategoricalDistribution(h_size, act_size)

        self.critic = Critic(
            stream_names, h_size, VectorEncoder(vector_sizes[0], h_size, num_layers)
        )
        self.act_size = act_size

    def clear_memory(self, batch_size):
        self.memory = (
            torch.zeros(1, batch_size, self.h_size),
            torch.zeros(1, batch_size, self.h_size),
        )

    def forward(self, vec_inputs, vis_inputs, masks=None):
        vec_embeds = []
        for idx, encoder in enumerate(self.vector_encoders):
            vec_input = vec_inputs[idx]
            if self.normalize:
                vec_input = self.vector_normalizers[idx](vec_inputs[idx])
            hidden = encoder(vec_input)
            vec_embeds.append(hidden)

        vis_embeds = []
        for idx, encoder in enumerate(self.visual_encoders):
            hidden = encoder(vis_inputs[idx])
            vis_embeds.append(hidden)

        vec_embeds = torch.cat(vec_embeds)
        vis_embeds = torch.cat(vis_embeds)
        embedding = torch.cat([vec_embeds, vis_embeds])
        if self.use_lstm:
            embedding, self.memory = self.lstm(embedding, self.memory)

        if self.act_type == ActionType.CONTINUOUS:
            dist = self.distribution(embedding)
        else:
            dist = self.distribution(embedding, masks=masks)
        return dist

    def update_normalization(self, inputs):
        if self.normalize:
            for normalizer in self.vector_normalizers:
                normalizer.update(inputs)

    def get_values(self, vec_inputs, vis_inputs):
        if self.normalize:
            vec_inputs = self.vector_normalizers[0](vec_inputs)
        return self.critic(vec_inputs)


class NNPolicy(TorchPolicy):
    def __init__(
        self,
        seed: int,
        brain: BrainParameters,
        trainer_params: Dict[str, Any],
        is_training: bool,
        load: bool,
        tanh_squash: bool = False,
        reparameterize: bool = False,
        condition_sigma_on_obs: bool = True,
        create_tf_graph: bool = True,
    ):
        """
        Policy that uses a multilayer perceptron to map the observations to actions. Could
        also use a CNN to encode visual input prior to the MLP. Supports discrete and
        continuous action spaces, as well as recurrent networks.
        :param seed: Random seed.
        :param brain: Assigned BrainParameters object.
        :param trainer_params: Defined training parameters.
        :param is_training: Whether the model should be trained.
        :param load: Whether a pre-trained model will be loaded or a new one created.
        :param tanh_squash: Whether to use a tanh function on the continuous output, or a clipped output.
        :param reparameterize: Whether we are using the resampling trick to update the policy in continuous output.
        """
        super().__init__(seed, brain, trainer_params, load)
        self.grads = None
        num_layers = trainer_params["num_layers"]
        self.h_size = trainer_params["hidden_units"]
        if num_layers < 1:
            num_layers = 1
        self.num_layers = num_layers
        self.vis_encode_type = EncoderType(
            trainer_params.get("vis_encode_type", "simple")
        )
        self.tanh_squash = tanh_squash
        self.reparameterize = reparameterize
        self.condition_sigma_on_obs = condition_sigma_on_obs

        # Non-exposed parameters; these aren't exposed because they don't have a
        # good explanation and usually shouldn't be touched.
        self.log_std_min = -20
        self.log_std_max = 2

        self.inference_dict: Dict[str, tf.Tensor] = {}
        self.update_dict: Dict[str, tf.Tensor] = {}
        # TF defaults to 32-bit, so we use the same here.
        torch.set_default_tensor_type(torch.DoubleTensor)

        reward_signal_configs = trainer_params["reward_signals"]
        self.stats_name_to_update_name = {
            "Losses/Value Loss": "value_loss",
            "Losses/Policy Loss": "policy_loss",
        }

        self.model = ActorCriticPolicy(
            h_size=int(trainer_params["hidden_units"]),
            act_type=ActionType.CONTINUOUS,
            vector_sizes=[brain.vector_observation_space_size],
            act_size=sum(brain.vector_action_space_size),
            normalize=trainer_params["normalize"],
            num_layers=int(trainer_params["num_layers"]),
            m_size=trainer_params["memory_size"],
            use_lstm=self.use_recurrent,
            visual_sizes=brain.camera_resolutions,
            stream_names=list(reward_signal_configs.keys()),
            vis_encode_type=EncoderType(
                trainer_params.get("vis_encode_type", "simple")
            ),
        )

    def split_decision_step(self, decision_requests):
        vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)
        mask = None
        if not self.use_continuous_act:
            mask = np.ones(
                (len(decision_requests), np.sum(self.brain.vector_action_space_size)),
                dtype=np.float32,
            )
            if decision_requests.action_mask is not None:
                mask = 1 - np.concatenate(decision_requests.action_mask, axis=1)
        return vec_vis_obs.vector_observations, vec_vis_obs.visual_observations, mask

    def execute_model(self, vec_obs, vis_obs, masks):
        action_dist = self.model(vec_obs, vis_obs, masks)
        action = action_dist.sample()
        log_probs = action_dist.log_prob(action)
        entropy = action_dist.entropy()
        value_heads = self.model.get_values(vec_obs, vis_obs)
        return action, log_probs, entropy, value_heads

    @timed
    def evaluate(self, decision_requests: DecisionSteps) -> Dict[str, Any]:
        """
        Evaluates policy for the agent experiences provided.
        :param decision_requests: DecisionSteps object containing inputs.
        :return: Outputs from network as defined by self.inference_dict.
        """
        vec_obs, vis_obs, masks = self.split_decision_step(decision_requests)
        run_out = {}
        action, log_probs, entropy, value_heads = self.execute_model(
            vec_obs, vis_obs, masks
        )
        run_out["action"] = np.array(action.detach())
        run_out["log_probs"] = np.array(log_probs.detach())
        run_out["entropy"] = np.array(entropy.detach())
        run_out["value_heads"] = {
            name: np.array(t.detach()) for name, t in value_heads.items()
        }
        run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0)
        run_out["learning_rate"] = 0.0
        self.model.update_normalization(vec_obs)
        return run_out

    # def _create_cc_actor(
    #     self,
    #     encoded: tf.Tensor,
    #     tanh_squash: bool = False,
    #     reparameterize: bool = False,
    #     condition_sigma_on_obs: bool = True,
    # ) -> None:
    #     """
    #     Creates Continuous control actor-critic model.
    #     :param h_size: Size of hidden linear layers.
    #     :param num_layers: Number of hidden linear layers.
    #     :param vis_encode_type: Type of visual encoder to use if visual input.
    #     :param tanh_squash: Whether to use a tanh function, or a clipped output.
    #     :param reparameterize: Whether we are using the resampling trick to update the policy.
    #     """
    #     if self.use_recurrent:
    #         self.memory_in = tf.placeholder(
    #             shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
    #         )
    #         hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
    #             encoded, self.memory_in, self.sequence_length_ph, name="lstm_policy"
    #         )
    #
    #         self.memory_out = tf.identity(memory_policy_out, name="recurrent_out")
    #     else:
    #         hidden_policy = encoded
    #
    #     with tf.variable_scope("policy"):
    #         distribution = GaussianDistribution(
    #             hidden_policy,
    #             self.act_size,
    #             reparameterize=reparameterize,
    #             tanh_squash=tanh_squash,
    #             condition_sigma=condition_sigma_on_obs,
    #         )
    #
    #     if tanh_squash:
    #         self.output_pre = distribution.sample
    #         self.output = tf.identity(self.output_pre, name="action")
    #     else:
    #         self.output_pre = distribution.sample
    #         # Clip and scale output to ensure actions are always within [-1, 1] range.
    #         output_post = tf.clip_by_value(self.output_pre, -3, 3) / 3
    #         self.output = tf.identity(output_post, name="action")
    #
    #     self.selected_actions = tf.stop_gradient(self.output)
    #
    #     self.all_log_probs = tf.identity(distribution.log_probs, name="action_probs")
    #     self.entropy = distribution.entropy
    #
    #     # We keep these tensors the same name, but use new nodes to keep code parallelism with discrete control.
    #     self.total_log_probs = distribution.total_log_probs
    #
    # def _create_dc_actor(self, encoded: tf.Tensor) -> None:
    #     """
    #     Creates Discrete control actor-critic model.
    #     :param h_size: Size of hidden linear layers.
    #     :param num_layers: Number of hidden linear layers.
    #     :param vis_encode_type: Type of visual encoder to use if visual input.
    #     """
    #     if self.use_recurrent:
    #         self.prev_action = tf.placeholder(
    #             shape=[None, len(self.act_size)], dtype=tf.int32, name="prev_action"
    #         )
    #         prev_action_oh = tf.concat(
    #             [
    #                 tf.one_hot(self.prev_action[:, i], self.act_size[i])
    #                 for i in range(len(self.act_size))
    #             ],
    #             axis=1,
    #         )
    #         hidden_policy = tf.concat([encoded, prev_action_oh], axis=1)
    #
    #         self.memory_in = tf.placeholder(
    #             shape=[None, self.m_size], dtype=tf.float32, name="recurrent_in"
    #         )
    #         hidden_policy, memory_policy_out = ModelUtils.create_recurrent_encoder(
    #             hidden_policy,
    #             self.memory_in,
    #             self.sequence_length_ph,
    #             name="lstm_policy",
    #         )
    #
    #         self.memory_out = tf.identity(memory_policy_out, "recurrent_out")
    #     else:
    #         hidden_policy = encoded
    #
    #     self.action_masks = tf.placeholder(
    #         shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks"
    #     )
    #
    #     with tf.variable_scope("policy"):
    #         distribution = MultiCategoricalDistribution(
    #             hidden_policy, self.act_size, self.action_masks
    #         )
    #         # It's important that we are able to feed_dict a value into this tensor to get the
    #         # right one-hot encoding, so we can't do identity on it.
    #         self.output = distribution.sample
    #         self.all_log_probs = tf.identity(distribution.log_probs, name="action")
    #         self.selected_actions = tf.stop_gradient(
    #             distribution.sample_onehot
    #         )  # In discrete, these are onehot
    #         self.entropy = distribution.entropy
    #         self.total_log_probs = distribution.total_log_probs
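
For reference, the action-mask handling in split_decision_step above converts the environment's convention, where a 1 in DecisionSteps.action_mask marks a forbidden action and each branch has its own array, into the network's convention, where a 1 in a single concatenated array marks an allowed action. A small standalone illustration with made-up mask values (not part of the commit):

import numpy as np

# One agent, two branches of size 3 and 2; 1 means "this action is NOT allowed".
action_mask = [np.array([[0, 1, 0]]), np.array([[1, 0]])]
mask = 1 - np.concatenate(action_mask, axis=1)
print(mask)  # [[1 0 1 0 1]] -> 1 now means the action is allowed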

ml-agents/mlagents/trainers/policy/torch_policy.py

from typing import Any, Dict, List, Optional
import numpy as np
from mlagents import tf_utils
from mlagents.tf_utils import tf
from mlagents_envs.exception import UnityException
from mlagents_envs.logging_util import get_logger
from mlagents.trainers.policy import Policy
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.brain_conversion_utils import get_global_agent_id
from mlagents_envs.base_env import DecisionSteps
from mlagents.trainers.models import ModelUtils


logger = get_logger(__name__)


class UnityPolicyException(UnityException):
    """
    Related to errors with the Trainer.
    """

    pass


class TorchPolicy(Policy):
    """
    Contains a learning model, and the necessary
    functions to save/load models and create the input placeholders.
    """

    def __init__(self, seed, brain, trainer_parameters, load=False):
        """
        Initializes the policy.
        :param seed: Random seed to use for TensorFlow.
        :param brain: The corresponding Brain for this policy.
        :param trainer_parameters: The trainer parameters.
        """
        self._version_number_ = 2
        self.m_size = 0

        # for ghost trainer save/load snapshots
        self.assign_phs = []
        self.assign_ops = []

        self.inference_dict = {}
        self.update_dict = {}
        self.sequence_length = 1
        self.seed = seed
        self.brain = brain

        self.act_size = brain.vector_action_space_size
        self.vec_obs_size = brain.vector_observation_space_size
        self.vis_obs_size = brain.number_visual_observations

        self.use_recurrent = trainer_parameters["use_recurrent"]
        self.memory_dict: Dict[str, np.ndarray] = {}
        self.num_branches = len(self.brain.vector_action_space_size)
        self.previous_action_dict: Dict[str, np.array] = {}
        self.normalize = trainer_parameters.get("normalize", False)
        self.use_continuous_act = brain.vector_action_space_type == "continuous"
        if self.use_continuous_act:
            self.num_branches = self.brain.vector_action_space_size[0]
        self.model_path = trainer_parameters["model_path"]
        self.initialize_path = trainer_parameters.get("init_path", None)
        self.keep_checkpoints = trainer_parameters.get("keep_checkpoints", 5)
        self.graph = tf.Graph()
        self.sess = tf.Session(
            config=tf_utils.generate_session_config(), graph=self.graph
        )
        self.saver = None
        if self.use_recurrent:
            self.m_size = trainer_parameters["memory_size"]
            self.sequence_length = trainer_parameters["sequence_length"]
            if self.m_size == 0:
                raise UnityPolicyException(
                    "The memory size for brain {0} is 0 even "
                    "though the trainer uses recurrent.".format(brain.brain_name)
                )
            elif self.m_size % 2 != 0:
                raise UnityPolicyException(
                    "The memory size for brain {0} is {1} "
                    "but it must be divisible by 2.".format(
                        brain.brain_name, self.m_size
                    )
                )
        self._initialize_tensorflow_references()
        self.load = load

    def _initialize_graph(self):
        with self.graph.as_default():
            self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
            init = tf.global_variables_initializer()
            self.sess.run(init)

    def _load_graph(self, model_path: str, reset_global_steps: bool = False) -> None:
        with self.graph.as_default():
            self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
            logger.info(
                "Loading model for brain {} from {}.".format(
                    self.brain.brain_name, model_path
                )
            )
            ckpt = tf.train.get_checkpoint_state(model_path)
            if ckpt is None:
                raise UnityPolicyException(
                    "The model {0} could not be loaded. Make "
                    "sure you specified the right "
                    "--run-id and that the previous run you are loading from had the same "
                    "behavior names.".format(model_path)
                )
            try:
                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
            except tf.errors.NotFoundError:
                raise UnityPolicyException(
                    "The model {0} was found but could not be loaded. Make "
                    "sure the model is from the same version of ML-Agents, has the same behavior parameters, "
                    "and is using the same trainer configuration as the current run.".format(
                        model_path
                    )
                )
            if reset_global_steps:
                logger.info(
                    "Starting training from step 0 and saving to {}.".format(
                        self.model_path
                    )
                )
            else:
                logger.info(
                    "Resuming training from step {}.".format(self.get_current_step())
                )

    def initialize_or_load(self):
        # If there is an initialize path, load from that. Else, load from the set model path.
        # If load is set to True, don't reset steps to 0. Else, do. This allows a user to,
        # e.g., resume from an initialize path.
        reset_steps = not self.load
        if self.initialize_path is not None:
            self._load_graph(self.initialize_path, reset_global_steps=reset_steps)
        elif self.load:
            self._load_graph(self.model_path, reset_global_steps=reset_steps)
        else:
            self._initialize_graph()

    def get_weights(self):
        with self.graph.as_default():
            _vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            values = [v.eval(session=self.sess) for v in _vars]
            return values

    def init_load_weights(self):
        with self.graph.as_default():
            _vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            values = [v.eval(session=self.sess) for v in _vars]
            for var, value in zip(_vars, values):
                assign_ph = tf.placeholder(var.dtype, shape=value.shape)
                self.assign_phs.append(assign_ph)
                self.assign_ops.append(tf.assign(var, assign_ph))

    def load_weights(self, values):
        if len(self.assign_ops) == 0:
            logger.warning(
                "Calling load_weights in tf_policy but assign_ops is empty. Did you forget to call init_load_weights?"
            )
        with self.graph.as_default():
            feed_dict = {}
            for assign_ph, value in zip(self.assign_phs, values):
                feed_dict[assign_ph] = value
            self.sess.run(self.assign_ops, feed_dict=feed_dict)

    def evaluate(self, decision_requests: DecisionSteps) -> Dict[str, Any]:
        """
        Evaluates policy for the agent experiences provided.
        :param decision_requests: DecisionSteps input to network.
        :return: Output from policy based on self.inference_dict.
        """
        raise UnityPolicyException("The evaluate function was not implemented.")

    def get_action(
        self, decision_requests: DecisionSteps, worker_id: int = 0
    ) -> ActionInfo:
        """
        Decides actions given observation information, and takes them in the environment.
        :param decision_requests: A DecisionSteps object from the environment.
        :param worker_id: In parallel environment training, the unique id of the environment worker that
        the DecisionSteps came from. Used to construct a globally unique id for each agent.
        :return: an ActionInfo containing action, memories, values and an object
        to be passed to add experiences
        """
        if len(decision_requests) == 0:
            return ActionInfo.empty()

        global_agent_ids = [
            get_global_agent_id(worker_id, int(agent_id))
            for agent_id in decision_requests.agent_id
        ]  # For 1-D array, the iterator order is correct.

        run_out = self.evaluate(  # pylint: disable=assignment-from-no-return
            decision_requests
        )

        self.save_memories(global_agent_ids, run_out.get("memory_out"))
        return ActionInfo(
            action=run_out.get("action"),
            value=run_out.get("value"),
            outputs=run_out,
            agent_ids=decision_requests.agent_id,
        )

    def update(self, mini_batch, num_sequences):
        """
        Performs update of the policy.
        :param num_sequences: Number of experience trajectories in batch.
        :param mini_batch: Batch of experiences.
        :return: Results of update.
        """
        raise UnityPolicyException("The update function was not implemented.")

    def _execute_model(self, feed_dict, out_dict):
        """
        Executes model.
        :param feed_dict: Input dictionary mapping nodes to input data.
        :param out_dict: Output dictionary mapping names to nodes.
        :return: Dictionary mapping names to output data.
        """
        network_out = self.sess.run(list(out_dict.values()), feed_dict=feed_dict)
        run_out = dict(zip(list(out_dict.keys()), network_out))
        return run_out

    def fill_eval_dict(self, batched_step_result):
        vec_vis_obs = SplitObservations.from_observations(batched_step_result.obs)
        mask = None
        if not self.use_continuous_act:
            mask = np.ones(
                (len(batched_step_result), np.sum(self.brain.vector_action_space_size)),
                dtype=np.float32,
            )
            if batched_step_result.action_mask is not None:
                mask = 1 - np.concatenate(batched_step_result.action_mask, axis=1)
        return vec_vis_obs.vector_observations, vec_vis_obs.visual_observations, mask

    def make_empty_memory(self, num_agents):
        """
        Creates empty memory for use with RNNs.
        :param num_agents: Number of agents.
        :return: Numpy array of zeros.
        """
        return np.zeros((num_agents, self.m_size), dtype=np.float32)

    def save_memories(
        self, agent_ids: List[str], memory_matrix: Optional[np.ndarray]
    ) -> None:
        if memory_matrix is None:
            return
        for index, agent_id in enumerate(agent_ids):
            self.memory_dict[agent_id] = memory_matrix[index, :]

    def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray:
        memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
        for index, agent_id in enumerate(agent_ids):
            if agent_id in self.memory_dict:
                memory_matrix[index, :] = self.memory_dict[agent_id]
        return memory_matrix

    def remove_memories(self, agent_ids):
        for agent_id in agent_ids:
            if agent_id in self.memory_dict:
                self.memory_dict.pop(agent_id)

    def make_empty_previous_action(self, num_agents):
        """
        Creates empty previous action for use with RNNs and discrete control.
        :param num_agents: Number of agents.
        :return: Numpy array of zeros.
        """
        return np.zeros((num_agents, self.num_branches), dtype=np.int)

    def save_previous_action(
        self, agent_ids: List[str], action_matrix: Optional[np.ndarray]
    ) -> None:
        if action_matrix is None:
            return
        for index, agent_id in enumerate(agent_ids):
            self.previous_action_dict[agent_id] = action_matrix[index, :]

    def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray:
        action_matrix = np.zeros((len(agent_ids), self.num_branches), dtype=np.int)
        for index, agent_id in enumerate(agent_ids):
            if agent_id in self.previous_action_dict:
                action_matrix[index, :] = self.previous_action_dict[agent_id]
        return action_matrix

    def remove_previous_action(self, agent_ids):
        for agent_id in agent_ids:
            if agent_id in self.previous_action_dict:
                self.previous_action_dict.pop(agent_id)

    def get_current_step(self):
        """
        Gets current model step.
        :return: current model step.
        """
        step = self.sess.run(self.global_step)
        return step

    def _set_step(self, step: int) -> int:
        """
        Sets current model step to step without creating additional ops.
        :param step: Step to set the current model step to.
        :return: The step the model was set to.
        """
        current_step = self.get_current_step()
        # Increment a positive or negative number of steps.
        return self.increment_step(step - current_step)

    def increment_step(self, n_steps):
        """
        Increments model step.
        """
        out_dict = {
            "global_step": self.global_step,
            "increment_step": self.increment_step_op,
        }
        feed_dict = {self.steps_to_increment: n_steps}
        return self.sess.run(out_dict, feed_dict=feed_dict)["global_step"]

    def get_inference_vars(self):
        """
        :return: list of inference var names
        """
        return list(self.inference_dict.keys())

    def get_update_vars(self):
        """
        :return: list of update var names
        """
        return list(self.update_dict.keys())

    def save_model(self, steps):
        """
        Saves the model.
        :param steps: The number of steps the model was trained for.
        :return:
        """
        with self.graph.as_default():
            last_checkpoint = self.model_path + "/model-" + str(steps) + ".ckpt"
            self.saver.save(self.sess, last_checkpoint)
            tf.train.write_graph(
                self.graph, self.model_path, "raw_graph_def.pb", as_text=False
            )

    def update_normalization(self, vector_obs: np.ndarray) -> None:
        """
        If this policy normalizes vector observations, this will update the norm values in the graph.
        :param vector_obs: The vector observations to add to the running estimate of the distribution.
        """
        if self.use_vec_obs and self.normalize:
            self.sess.run(
                self.update_normalization_op, feed_dict={self.vector_in: vector_obs}
            )

    @property
    def use_vis_obs(self):
        return self.vis_obs_size > 0

    @property
    def use_vec_obs(self):
        return self.vec_obs_size > 0

    def _initialize_tensorflow_references(self):
        self.value_heads: Dict[str, tf.Tensor] = {}
        self.normalization_steps: Optional[tf.Variable] = None
        self.running_mean: Optional[tf.Variable] = None
        self.running_variance: Optional[tf.Variable] = None
        self.update_normalization_op: Optional[tf.Operation] = None
        self.value: Optional[tf.Tensor] = None
        self.all_log_probs: tf.Tensor = None
        self.total_log_probs: Optional[tf.Tensor] = None
        self.entropy: Optional[tf.Tensor] = None
        self.output_pre: Optional[tf.Tensor] = None
        self.output: Optional[tf.Tensor] = None
        self.selected_actions: tf.Tensor = None
        self.action_masks: Optional[tf.Tensor] = None
        self.prev_action: Optional[tf.Tensor] = None
        self.memory_in: Optional[tf.Tensor] = None
        self.memory_out: Optional[tf.Tensor] = None

    def create_input_placeholders(self):
        with self.graph.as_default():
            (
                self.global_step,
                self.increment_step_op,
                self.steps_to_increment,
            ) = ModelUtils.create_global_steps()
            self.visual_in = ModelUtils.create_visual_input_placeholders(
                self.brain.camera_resolutions
            )
            self.vector_in = ModelUtils.create_vector_input(self.vec_obs_size)
            if self.normalize:
                normalization_tensors = ModelUtils.create_normalizer(self.vector_in)
                self.update_normalization_op = normalization_tensors.update_op
                self.normalization_steps = normalization_tensors.steps
                self.running_mean = normalization_tensors.running_mean
                self.running_variance = normalization_tensors.running_variance
                self.processed_vector_in = ModelUtils.normalize_vector_obs(
                    self.vector_in,
                    self.running_mean,
                    self.running_variance,
                    self.normalization_steps,
                )
            else:
                self.processed_vector_in = self.vector_in
                self.update_normalization_op = None

            self.batch_size_ph = tf.placeholder(
                shape=None, dtype=tf.int32, name="batch_size"
            )
            self.sequence_length_ph = tf.placeholder(
                shape=None, dtype=tf.int32, name="sequence_length"
            )
            self.mask_input = tf.placeholder(
                shape=[None], dtype=tf.float32, name="masks"
            )
            # Only needed for PPO, but needed for BC module
            self.epsilon = tf.placeholder(
                shape=[None, self.act_size[0]], dtype=tf.float32, name="epsilon"
            )
            self.mask = tf.cast(self.mask_input, tf.int32)

            tf.Variable(
                int(self.brain.vector_action_space_type == "continuous"),
                name="is_continuous_control",
                trainable=False,
                dtype=tf.int32,
            )
            tf.Variable(
                self._version_number_,
                name="version_number",
                trainable=False,
                dtype=tf.int32,
            )
            tf.Variable(
                self.m_size, name="memory_size", trainable=False, dtype=tf.int32
            )
            if self.brain.vector_action_space_type == "continuous":
                tf.Variable(
                    self.act_size[0],
                    name="action_output_shape",
                    trainable=False,
                    dtype=tf.int32,
                )
            else:
                tf.Variable(
                    sum(self.act_size),
                    name="action_output_shape",
                    trainable=False,
                    dtype=tf.int32,
                )
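
A rough sketch of the per-agent memory bookkeeping in TorchPolicy (save_memories, retrieve_memories, remove_memories). Everything below is illustrative and not part of the commit: the brain object is a made-up stand-in exposing only the attributes the constructor reads, the trainer values are arbitrary, and the snippet assumes TensorFlow is installed since the constructor still builds a tf.Session.

from types import SimpleNamespace
from mlagents.trainers.policy.torch_policy import TorchPolicy

# Hypothetical stand-ins, just enough to construct the policy.
brain = SimpleNamespace(
    brain_name="TestBrain",
    vector_action_space_size=[2],
    vector_observation_space_size=8,
    number_visual_observations=0,
    vector_action_space_type="continuous",
)
trainer_parameters = {"use_recurrent": False, "model_path": "./tmp_model"}

policy = TorchPolicy(seed=0, brain=brain, trainer_parameters=trainer_parameters)
policy.m_size = 4  # pretend this policy carries a 4-unit recurrent memory

mems = policy.make_empty_memory(num_agents=2)             # zeros of shape (2, 4)
policy.save_memories(["agent-0", "agent-1"], mems + 1.0)  # one row stored per agent id
print(policy.retrieve_memories(["agent-1", "agent-9"]))   # unknown ids come back as zeros
policy.remove_memories(["agent-0"])                       # drop state for done agents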