您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
295 行
11 KiB
295 行
11 KiB
import logging
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import numpy as np
|
|
from mlagents.trainers import tf
|
|
|
|
from mlagents.envs.exception import UnityException
|
|
from mlagents.envs.policy import Policy
|
|
from mlagents.envs.action_info import ActionInfo
|
|
from tensorflow.python.platform import gfile
|
|
from tensorflow.python.framework import graph_util
|
|
from mlagents.trainers import tensorflow_to_barracuda as tf2bc
|
|
from mlagents.envs.brain import BrainInfo
|
|
|
|
|
|
# Module-level logger; messages from this file are emitted under the
# "mlagents.trainers" logger name.
logger = logging.getLogger("mlagents.trainers")
|
|
|
|
|
|
class UnityPolicyException(UnityException):
    """
    Exception type for errors related to the Trainer's policy.

    Raised by the policy code in this module, e.g. for an invalid recurrent
    memory size or when an abstract policy method has not been overridden.
    """
|
|
|
|
|
|
class TFPolicy(Policy):
    """
    Contains a learning model, and the necessary
    functions to interact with it to perform evaluate and updating.

    The policy owns its own ``tf.Graph`` and ``tf.Session``. ``self.model``
    starts out as ``None`` in ``__init__``; the methods that touch
    ``self.model`` expect it to have been assigned beforehand.
    """

    # Names of graph nodes that are exported by export_model() when a node
    # with that exact name exists in the graph (see _process_graph()).
    possible_output_nodes = [
        "action",
        "value_estimate",
        "action_probs",
        "recurrent_out",
        "memory_size",
        "version_number",
        "is_continuous_control",
        "action_output_shape",
    ]

    def __init__(self, seed, brain, trainer_parameters):
        """
        Initializes the policy state and creates its TensorFlow graph and session.

        :param seed: Random seed to use for TensorFlow (only stored on the
            policy here).
        :param brain: The corresponding Brain for this policy. Its
            ``vector_action_space_type`` and ``brain_name`` are read.
        :param trainer_parameters: The trainer parameters. Must contain
            "use_recurrent" and "model_path"; "normalize" (default False) and
            "keep_checkpoints" (default 5) are optional. When "use_recurrent"
            is truthy, "memory_size" and "sequence_length" are required too.
        :raises UnityPolicyException: If the policy is recurrent and the
            memory size is 0 or not divisible by 4.
        """
        self.m_size = None
        self.model = None
        self.inference_dict = {}
        self.update_dict = {}
        self.sequence_length = 1
        self.seed = seed
        self.brain = brain
        self.use_recurrent = trainer_parameters["use_recurrent"]
        # Per-agent recurrent memory, keyed by agent id.
        self.memory_dict: Dict[int, np.ndarray] = {}
        self.normalize = trainer_parameters.get("normalize", False)
        self.use_continuous_act = brain.vector_action_space_type == "continuous"
        self.model_path = trainer_parameters["model_path"]
        self.keep_checkpoints = trainer_parameters.get("keep_checkpoints", 5)
        self.graph = tf.Graph()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        # For multi-GPU training, set allow_soft_placement to True to allow
        # placing the operation into an alternative device automatically
        # to prevent from exceptions if the device doesn't support the operation
        # or the device does not exist
        config.allow_soft_placement = True
        self.sess = tf.Session(config=config, graph=self.graph)
        # The saver is created lazily by _initialize_graph() / _load_graph().
        self.saver = None
        if self.use_recurrent:
            self.m_size = trainer_parameters["memory_size"]
            self.sequence_length = trainer_parameters["sequence_length"]
            if self.m_size == 0:
                raise UnityPolicyException(
                    "The memory size for brain {0} is 0 even "
                    "though the trainer uses recurrent.".format(brain.brain_name)
                )
            elif self.m_size % 4 != 0:
                raise UnityPolicyException(
                    "The memory size for brain {0} is {1} "
                    "but it must be divisible by 4.".format(
                        brain.brain_name, self.m_size
                    )
                )

    def _initialize_graph(self):
        """
        Creates the checkpoint saver and runs the global variable initializer
        inside this policy's graph.
        """
        with self.graph.as_default():
            self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
            init = tf.global_variables_initializer()
            self.sess.run(init)

    def _load_graph(self):
        """
        Creates the checkpoint saver and restores the session from the
        checkpoint state found under ``self.model_path``.

        :raises UnityPolicyException: If no checkpoint state can be found in
            ``self.model_path``.
        """
        with self.graph.as_default():
            self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints)
            logger.info("Loading Model for brain {}".format(self.brain.brain_name))
            ckpt = tf.train.get_checkpoint_state(self.model_path)
            if ckpt is None:
                # Fail with an actionable message instead of falling through
                # to an AttributeError on ``ckpt.model_checkpoint_path``.
                raise UnityPolicyException(
                    "The model {0} could not be found. Make "
                    "sure you specified the right "
                    "--run-id".format(self.model_path)
                )
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)

    def evaluate(self, brain_info: BrainInfo) -> Dict[str, Any]:
        """
        Evaluates policy for the agent experiences provided.
        This base implementation always raises; subclasses must override it.
        :param brain_info: BrainInfo input to network.
        :return: Output from policy based on self.inference_dict.
        :raises UnityPolicyException: Always, in this base implementation.
        """
        raise UnityPolicyException("The evaluate function was not implemented.")

    def get_action(self, brain_info: BrainInfo) -> ActionInfo:
        """
        Decides actions given observations information, and takes them in environment.
        Stored memories of agents flagged as done are dropped before the
        evaluation, and the "memory_out" entry of the evaluation output (if
        any) is stored per agent afterwards.
        :param brain_info: BrainInfo from the environment for the agents to act on.
        :return: an ActionInfo containing action, memories, values and an object
        to be passed to add experiences
        """
        if len(brain_info.agents) == 0:
            return ActionInfo([], [], None)

        finished_agents = [
            agent
            for agent, done in zip(brain_info.agents, brain_info.local_done)
            if done
        ]
        self.remove_memories(finished_agents)
        run_out = self.evaluate(brain_info)  # pylint: disable=assignment-from-no-return
        self.save_memories(brain_info.agents, run_out.get("memory_out"))
        return ActionInfo(
            action=run_out.get("action"), value=run_out.get("value"), outputs=run_out
        )

    def update(self, mini_batch, num_sequences):
        """
        Performs update of the policy.
        This base implementation always raises; subclasses must override it.
        :param num_sequences: Number of experience trajectories in batch.
        :param mini_batch: Batch of experiences.
        :return: Results of update.
        :raises UnityPolicyException: Always, in this base implementation.
        """
        raise UnityPolicyException("The update function was not implemented.")

    def _execute_model(self, feed_dict, out_dict):
        """
        Executes model.
        :param feed_dict: Input dictionary mapping nodes to input data.
        :param out_dict: Output dictionary mapping names to nodes.
        :return: Dictionary mapping names to the data produced by their nodes.
        """
        network_out = self.sess.run(list(out_dict.values()), feed_dict=feed_dict)
        # keys() and values() of an unmodified dict iterate in the same order,
        # so the fetched results line up with the names.
        return dict(zip(out_dict, network_out))

    def fill_eval_dict(self, feed_dict, brain_info):
        """
        Adds the observation (and, for discrete actions, action mask) inputs
        from ``brain_info`` to ``feed_dict``.
        :param feed_dict: Feed dictionary to extend; it is modified in place.
        :param brain_info: BrainInfo providing visual/vector observations and
            action masks.
        :return: The same ``feed_dict`` object.
        """
        for i, visual_obs in enumerate(brain_info.visual_observations):
            feed_dict[self.model.visual_in[i]] = visual_obs
        if self.use_vec_obs:
            feed_dict[self.model.vector_in] = brain_info.vector_observations
        if not self.use_continuous_act:
            feed_dict[self.model.action_masks] = brain_info.action_masks
        return feed_dict

    def make_empty_memory(self, num_agents):
        """
        Creates empty memory for use with RNNs
        :param num_agents: Number of agents.
        :return: Numpy array of zeros with shape (num_agents, self.m_size).
        """
        return np.zeros((num_agents, self.m_size))

    def save_memories(
        self, agent_ids: List[int], memory_matrix: Optional[np.ndarray]
    ) -> None:
        """
        Stores one row of ``memory_matrix`` per agent in ``self.memory_dict``.
        :param agent_ids: Agent ids; the i-th id is paired with the i-th row.
        :param memory_matrix: Memories to store, or None to do nothing.
        """
        if memory_matrix is None:
            return
        for index, agent_id in enumerate(agent_ids):
            self.memory_dict[agent_id] = memory_matrix[index, :]

    def retrieve_memories(self, agent_ids: List[int]) -> np.ndarray:
        """
        Builds a memory matrix for the given agents from ``self.memory_dict``.
        :param agent_ids: Agent ids; the i-th row of the result belongs to the
            i-th id.
        :return: Array of shape (len(agent_ids), self.m_size); rows of agents
            without a stored memory are left as zeros.
        """
        # np.float64 is what the removed ``np.float`` alias resolved to.
        memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float64)
        for index, agent_id in enumerate(agent_ids):
            if agent_id in self.memory_dict:
                memory_matrix[index, :] = self.memory_dict[agent_id]
        return memory_matrix

    def remove_memories(self, agent_ids):
        """
        Drops the stored memories of the given agents; unknown ids are ignored.
        :param agent_ids: Iterable of agent ids to forget.
        """
        for agent_id in agent_ids:
            self.memory_dict.pop(agent_id, None)

    def get_current_step(self):
        """
        Gets current model step.
        :return: current model step.
        """
        return self.sess.run(self.model.global_step)

    def increment_step(self, n_steps):
        """
        Increments model step.
        :param n_steps: Value fed to the model's ``steps_to_increment`` input.
        :return: The "global_step" value fetched in the same session run as
            the increment op.
        """
        out_dict = {
            "global_step": self.model.global_step,
            "increment_step": self.model.increment_step,
        }
        feed_dict = {self.model.steps_to_increment: n_steps}
        return self.sess.run(out_dict, feed_dict=feed_dict)["global_step"]

    def get_inference_vars(self):
        """
        :return:list of inference var names
        """
        return list(self.inference_dict)

    def get_update_vars(self):
        """
        :return:list of update var names
        """
        return list(self.update_dict)

    def save_model(self, steps):
        """
        Saves the model
        Writes a checkpoint named "model-<steps>.cptk" and the binary graph
        definition "raw_graph_def.pb" into ``self.model_path``.
        :param steps: The number of steps the model was trained for
        :return:
        """
        with self.graph.as_default():
            last_checkpoint = self.model_path + "/model-" + str(steps) + ".cptk"
            self.saver.save(self.sess, last_checkpoint)
            tf.train.write_graph(
                self.graph, self.model_path, "raw_graph_def.pb", as_text=False
            )

    def export_model(self):
        """
        Exports latest saved model to .nn format for Unity embedding.
        The frozen graph is written to "<model_path>/frozen_graph_def.pb" and
        then converted to "<model_path>.nn".
        """

        with self.graph.as_default():
            # The names come from possible_output_nodes and contain neither
            # spaces nor commas, so they can be passed on as a list directly.
            target_nodes = self._process_graph()
            graph_def = self.graph.as_graph_def()
            output_graph_def = graph_util.convert_variables_to_constants(
                self.sess, graph_def, target_nodes
            )
            frozen_graph_def_path = self.model_path + "/frozen_graph_def.pb"
            with gfile.GFile(frozen_graph_def_path, "wb") as f:
                f.write(output_graph_def.SerializeToString())
            tf2bc.convert(frozen_graph_def_path, self.model_path + ".nn")
            logger.info("Exported " + self.model_path + ".nn file")

    def _process_graph(self):
        """
        Gets the list of the output nodes present in the graph for inference
        :return: list of node names
        """
        all_nodes = [x.name for x in self.graph.as_graph_def().node]
        nodes = [x for x in all_nodes if x in self.possible_output_nodes]
        logger.info("List of nodes to export for brain :" + self.brain.brain_name)
        for n in nodes:
            logger.info("\t" + n)
        return nodes

    def update_normalization(self, vector_obs: np.ndarray) -> None:
        """
        If this policy normalizes vector observations, this will update the norm values in the graph.
        Does nothing when the model has no vector observations or
        normalization is disabled.
        :param vector_obs: The vector observations to add to the running estimate of the distribution.
        """
        if self.use_vec_obs and self.normalize:
            self.sess.run(
                self.model.update_normalization,
                feed_dict={self.model.vector_in: vector_obs},
            )

    @property
    def vis_obs_size(self):
        """Visual observation size as reported by the model."""
        return self.model.vis_obs_size

    @property
    def vec_obs_size(self):
        """Vector observation size as reported by the model."""
        return self.model.vec_obs_size

    @property
    def use_vis_obs(self):
        """True when the model's visual observation size is greater than 0."""
        return self.model.vis_obs_size > 0

    @property
    def use_vec_obs(self):
        """True when the model's vector observation size is greater than 0."""
        return self.model.vec_obs_size > 0
|