Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

316 行
13 KiB

import logging
from typing import Dict, Any, Optional
import numpy as np
import tensorflow as tf
from mlagents.envs.timers import timed
from mlagents.envs.brain import BrainInfo, BrainParameters
from mlagents.trainers.models import EncoderType, LearningRateSchedule
from mlagents.trainers.sac.models import SACModel
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.components.reward_signals.reward_signal_factory import (
create_reward_signal,
)
from mlagents.trainers.components.reward_signals import RewardSignal
from mlagents.trainers.components.bc.module import BCModule
logger = logging.getLogger("mlagents.trainers")
class SACPolicy(TFPolicy):
def __init__(
self,
seed: int,
brain: BrainParameters,
trainer_params: Dict[str, Any],
is_training: bool,
load: bool,
) -> None:
"""
Policy for Proximal Policy Optimization Networks.
:param seed: Random seed.
:param brain: Assigned Brain object.
:param trainer_params: Defined training parameters.
:param is_training: Whether the model should be trained.
:param load: Whether a pre-trained model will be loaded or a new one created.
"""
super().__init__(seed, brain, trainer_params)
reward_signal_configs = {}
for key, rsignal in trainer_params["reward_signals"].items():
if type(rsignal) is dict:
reward_signal_configs[key] = rsignal
self.inference_dict: Dict[str, tf.Tensor] = {}
self.update_dict: Dict[str, tf.Tensor] = {}
self.create_model(
brain, trainer_params, reward_signal_configs, is_training, load, seed
)
self.create_reward_signals(reward_signal_configs)
self.stats_name_to_update_name = {
"Losses/Value Loss": "value_loss",
"Losses/Policy Loss": "policy_loss",
"Losses/Q1 Loss": "q1_loss",
"Losses/Q2 Loss": "q2_loss",
"Policy/Entropy Coeff": "entropy_coef",
}
with self.graph.as_default():
# Create pretrainer if needed
self.bc_module: Optional[BCModule] = None
if "pretraining" in trainer_params:
BCModule.check_config(trainer_params["pretraining"])
self.bc_module = BCModule(
self,
policy_learning_rate=trainer_params["learning_rate"],
default_batch_size=trainer_params["batch_size"],
default_num_epoch=1,
samples_per_update=trainer_params["batch_size"],
**trainer_params["pretraining"],
)
# SAC-specific setting - we don't want to do a whole epoch each update!
if "samples_per_update" in trainer_params["pretraining"]:
logger.warning(
"Pretraining: Samples Per Update is not a valid setting for SAC."
)
self.bc_module.samples_per_update = 1
if load:
self._load_graph()
else:
self._initialize_graph()
self.sess.run(self.model.target_init_op)
# Disable terminal states for certain reward signals to avoid survivor bias
for name, reward_signal in self.reward_signals.items():
if not reward_signal.use_terminal_states:
self.sess.run(self.model.disable_use_dones[name])
def create_model(
self,
brain: BrainParameters,
trainer_params: Dict[str, Any],
reward_signal_configs: Dict[str, Any],
is_training: bool,
load: bool,
seed: int,
) -> None:
with self.graph.as_default():
self.model = SACModel(
brain,
lr=float(trainer_params["learning_rate"]),
lr_schedule=LearningRateSchedule(
trainer_params.get("learning_rate_schedule", "constant")
),
h_size=int(trainer_params["hidden_units"]),
init_entcoef=float(trainer_params["init_entcoef"]),
max_step=float(trainer_params["max_steps"]),
normalize=trainer_params["normalize"],
use_recurrent=trainer_params["use_recurrent"],
num_layers=int(trainer_params["num_layers"]),
m_size=self.m_size,
seed=seed,
stream_names=list(reward_signal_configs.keys()),
tau=float(trainer_params["tau"]),
gammas=[_val["gamma"] for _val in reward_signal_configs.values()],
vis_encode_type=EncoderType(
trainer_params.get("vis_encode_type", "simple")
),
)
self.model.create_sac_optimizers()
self.inference_dict.update(
{
"action": self.model.output,
"log_probs": self.model.all_log_probs,
"value_heads": self.model.value_heads,
"value": self.model.value,
"entropy": self.model.entropy,
"learning_rate": self.model.learning_rate,
}
)
if self.use_continuous_act:
self.inference_dict["pre_action"] = self.model.output_pre
if self.use_recurrent:
self.inference_dict["memory_out"] = self.model.memory_out
self.update_dict.update(
{
"value_loss": self.model.total_value_loss,
"policy_loss": self.model.policy_loss,
"q1_loss": self.model.q1_loss,
"q2_loss": self.model.q2_loss,
"entropy_coef": self.model.ent_coef,
"entropy": self.model.entropy,
"update_batch": self.model.update_batch_policy,
"update_value": self.model.update_batch_value,
"update_entropy": self.model.update_batch_entropy,
}
)
def create_reward_signals(self, reward_signal_configs: Dict[str, Any]) -> None:
"""
Create reward signals
:param reward_signal_configs: Reward signal config.
"""
self.reward_signals: Dict[str, RewardSignal] = {}
with self.graph.as_default():
# Create reward signals
for reward_signal, config in reward_signal_configs.items():
if type(config) is dict:
self.reward_signals[reward_signal] = create_reward_signal(
self, self.model, reward_signal, config
)
def evaluate(self, brain_info: BrainInfo) -> Dict[str, np.ndarray]:
"""
Evaluates policy for the agent experiences provided.
:param brain_info: BrainInfo object containing inputs.
:return: Outputs from network as defined by self.inference_dict.
"""
feed_dict = {
self.model.batch_size: len(brain_info.vector_observations),
self.model.sequence_length: 1,
}
if self.use_recurrent:
if not self.use_continuous_act:
feed_dict[
self.model.prev_action
] = brain_info.previous_vector_actions.reshape(
[-1, len(self.model.act_size)]
)
feed_dict[self.model.memory_in] = self.retrieve_memories(brain_info.agents)
feed_dict = self.fill_eval_dict(feed_dict, brain_info)
run_out = self._execute_model(feed_dict, self.inference_dict)
return run_out
@timed
def update(
self, mini_batch: Dict[str, Any], num_sequences: int, update_target: bool = True
) -> Dict[str, float]:
"""
Updates model using buffer.
:param num_sequences: Number of trajectories in batch.
:param mini_batch: Experience batch.
:param update_target: Whether or not to update target value network
:param reward_signal_mini_batches: Minibatches to use for updating the reward signals,
indexed by name. If none, don't update the reward signals.
:return: Output from update process.
"""
feed_dict = self.construct_feed_dict(self.model, mini_batch, num_sequences)
stats_needed = self.stats_name_to_update_name
update_stats: Dict[str, float] = {}
update_vals = self._execute_model(feed_dict, self.update_dict)
for stat_name, update_name in stats_needed.items():
update_stats[stat_name] = update_vals[update_name]
if update_target:
self.sess.run(self.model.target_update_op)
return update_stats
def update_reward_signals(
self, reward_signal_minibatches: Dict[str, Dict], num_sequences: int
) -> Dict[str, float]:
"""
Only update the reward signals.
:param reward_signal_mini_batches: Minibatches to use for updating the reward signals,
indexed by name. If none, don't update the reward signals.
"""
# Collect feed dicts for all reward signals.
feed_dict: Dict[tf.Tensor, Any] = {}
update_dict: Dict[str, tf.Tensor] = {}
update_stats: Dict[str, float] = {}
stats_needed: Dict[str, str] = {}
if reward_signal_minibatches:
self.add_reward_signal_dicts(
feed_dict,
update_dict,
stats_needed,
reward_signal_minibatches,
num_sequences,
)
update_vals = self._execute_model(feed_dict, update_dict)
for stat_name, update_name in stats_needed.items():
update_stats[stat_name] = update_vals[update_name]
return update_stats
def add_reward_signal_dicts(
self,
feed_dict: Dict[tf.Tensor, Any],
update_dict: Dict[str, tf.Tensor],
stats_needed: Dict[str, str],
reward_signal_minibatches: Dict[str, Dict],
num_sequences: int,
) -> None:
"""
Adds the items needed for reward signal updates to the feed_dict and stats_needed dict.
:param feed_dict: Feed dict needed update
:param update_dit: Update dict that needs update
:param stats_needed: Stats needed to get from the update.
:param reward_signal_minibatches: Minibatches to use for updating the reward signals,
indexed by name.
"""
for name, r_mini_batch in reward_signal_minibatches.items():
feed_dict.update(
self.reward_signals[name].prepare_update(
self.model, r_mini_batch, num_sequences
)
)
update_dict.update(self.reward_signals[name].update_dict)
stats_needed.update(self.reward_signals[name].stats_name_to_update_name)
def construct_feed_dict(
self, model: SACModel, mini_batch: Dict[str, Any], num_sequences: int
) -> Dict[tf.Tensor, Any]:
"""
Builds the feed dict for updating the SAC model.
:param model: The model to update. May be different when, e.g. using multi-GPU.
:param mini_batch: Mini-batch to use to update.
:param num_sequences: Number of LSTM sequences in mini_batch.
"""
feed_dict = {
self.model.batch_size: num_sequences,
self.model.sequence_length: self.sequence_length,
self.model.next_sequence_length: self.sequence_length,
self.model.mask_input: mini_batch["masks"],
}
for name in self.reward_signals:
feed_dict[model.rewards_holders[name]] = mini_batch[
"{}_rewards".format(name)
]
if self.use_continuous_act:
feed_dict[model.action_holder] = mini_batch["actions"]
else:
feed_dict[model.action_holder] = mini_batch["actions"]
if self.use_recurrent:
feed_dict[model.prev_action] = mini_batch["prev_action"]
feed_dict[model.action_masks] = mini_batch["action_mask"]
if self.use_vec_obs:
feed_dict[model.vector_in] = mini_batch["vector_obs"]
feed_dict[model.next_vector_in] = mini_batch["next_vector_in"]
if self.model.vis_obs_size > 0:
for i, _ in enumerate(model.visual_in):
_obs = mini_batch["visual_obs%d" % i]
feed_dict[model.visual_in[i]] = _obs
for i, _ in enumerate(model.next_visual_in):
_obs = mini_batch["next_visual_obs%d" % i]
feed_dict[model.next_visual_in[i]] = _obs
if self.use_recurrent:
mem_in = [
mini_batch["memory"][i]
for i in range(0, len(mini_batch["memory"]), self.sequence_length)
]
# LSTM shouldn't have sequence length <1, but stop it from going out of the index if true.
offset = 1 if self.sequence_length > 1 else 0
next_mem_in = [
mini_batch["memory"][i][
: self.m_size // 4
] # only pass value part of memory to target network
for i in range(offset, len(mini_batch["memory"]), self.sequence_length)
]
feed_dict[model.memory_in] = mem_in
feed_dict[model.next_memory_in] = next_mem_in
feed_dict[model.dones_holder] = mini_batch["done"]
return feed_dict