Unity 机器学习代理工具包（ML-Agents）是一个开源项目，它使游戏和模拟能够用作训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头，可以包含连字符 (-)，并且长度不得超过 35 个字符。
 
 
 
 
 

270 行
11 KiB

from typing import Any, Dict, List
import logging
import numpy as np
import tensorflow as tf
from mlagents.envs.brain import BrainInfo
from mlagents.trainers.buffer import Buffer
from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult
from mlagents.trainers.tf_policy import TFPolicy
from .model import GAILModel
from mlagents.trainers.demo_loader import demo_to_buffer
# Module logger in the "mlagents.trainers" logging namespace; GAIL debug output goes here.
LOGGER = logging.getLogger("mlagents.trainers")
class GAILRewardSignal(RewardSignal):
    def __init__(
        self,
        policy: TFPolicy,
        strength: float,
        gamma: float,
        demo_path: str,
        num_epoch: int = 3,
        encoding_size: int = 64,
        learning_rate: float = 3e-4,
        samples_per_update: int = 0,
        use_actions: bool = False,
        use_vail: bool = False,
    ):
        """
        The GAIL Reward signal generator. https://arxiv.org/abs/1606.03476
        :param policy: The policy of the learning model
        :param strength: The scaling parameter for the reward. The scaled reward will be the unscaled
        reward multiplied by the strength parameter
        :param gamma: The time discounting factor used for this reward.
        :param demo_path: The path to the demonstration file
        :param num_epoch: How many times update() reshuffles both buffers and iterates over
        the mini batches.
        :param encoding_size: The size of the hidden layers of the discriminator
        (forwarded to GAILModel).
        :param learning_rate: The Learning Rate used during GAIL updates.
        :param samples_per_update: The maximum number of samples to update during GAIL updates.
        0 means no limit: every available mini batch is used. A positive value smaller than
        one mini batch still allows a single mini batch per epoch.
        :param use_actions: Whether or not to use the actions for the discriminator.
        :param use_vail: Whether or not to use a variational bottleneck for the discriminator.
        See https://arxiv.org/abs/1810.00821.
        """
        super().__init__(policy, strength, gamma)
        self.num_epoch = num_epoch
        self.samples_per_update = samples_per_update
        # NOTE(review): 128 is passed positionally as GAILModel's second argument
        # (presumably a hidden-layer size) -- confirm against model.py.
        self.model = GAILModel(
            policy.model, 128, learning_rate, encoding_size, use_actions, use_vail
        )
        # demo_to_buffer returns a pair; only the demonstration buffer is kept.
        _, self.demonstration_buffer = demo_to_buffer(demo_path, policy.sequence_length)
        # evaluate() zeroes the scaled reward until update() has trained the
        # discriminator on at least one mini batch.
        self.has_updated = False

    def evaluate(
        self, current_info: BrainInfo, next_info: BrainInfo
    ) -> RewardSignalResult:
        """
        Computes the GAIL reward for the transition from current_info to next_info.
        :param current_info: The BrainInfo whose observations (and memories) are fed to
        the discriminator.
        :param next_info: The following BrainInfo; its previous_vector_actions and
        local_done are fed to the discriminator.
        :return: A RewardSignalResult(scaled_reward, unscaled_reward). The scaled reward
        is zero until update() has run at least one mini batch. When there are no
        agents, both rewards are empty arrays.
        """
        if len(current_info.agents) == 0:
            # Always return a RewardSignalResult (never a bare list) so callers can
            # rely on .scaled_reward / .unscaled_reward being present.
            empty = np.zeros(0, dtype=np.float32)
            return RewardSignalResult(empty, empty.copy())
        feed_dict: Dict[tf.Tensor, Any] = {
            self.policy.model.batch_size: len(next_info.vector_observations),
            self.policy.model.sequence_length: 1,
        }
        if self.model.use_vail:
            # use_noise is 0 when computing rewards and 1 when training (_update_batch).
            feed_dict[self.model.use_noise] = [0]
        feed_dict = self.policy.fill_eval_dict(feed_dict, brain_info=current_info)
        feed_dict[self.model.done_policy] = np.reshape(next_info.local_done, [-1, 1])
        # NOTE: actions are read from next_info.previous_vector_actions, i.e. presumably
        # the actions taken in response to current_info.
        if self.policy.use_continuous_act:
            feed_dict[
                self.policy.model.selected_actions
            ] = next_info.previous_vector_actions
        else:
            feed_dict[
                self.policy.model.action_holder
            ] = next_info.previous_vector_actions
        if self.policy.use_recurrent:
            # Agents without stored memories get an empty memory from the policy.
            if current_info.memories.shape[1] == 0:
                current_info.memories = self.policy.make_empty_memory(
                    len(current_info.agents)
                )
            feed_dict[self.policy.model.memory_in] = current_info.memories
        unscaled_reward = self.policy.sess.run(
            self.model.intrinsic_reward, feed_dict=feed_dict
        )
        scaled_reward = unscaled_reward * float(self.has_updated) * self.strength
        return RewardSignalResult(scaled_reward, unscaled_reward)

    @classmethod
    def check_config(
        cls, config_dict: Dict[str, Any], param_keys: List[str] = None
    ) -> None:
        """
        Checks the config and throws an exception if a hyperparameter is missing. GAIL
        requires strength, gamma and demo_path at minimum.
        :param config_dict: The reward signal configuration to validate.
        :param param_keys: Ignored; GAIL always checks its own required keys.
        """
        param_keys = ["strength", "gamma", "demo_path"]
        super().check_config(config_dict, param_keys)

    def update(self, update_buffer: Buffer, n_sequences: int) -> Dict[str, float]:
        """
        Updates the discriminator using the demonstrations and the policy's trajectories.
        :param update_buffer: The policy buffer containing the trajectories for the current policy.
        :param n_sequences: The number of sequences from demo and policy used in each mini batch.
        :return: {"Losses/GAIL Loss": mean loss} after training, or an empty dict when
        not a single mini batch could be trained (in which case has_updated is left as is).
        """
        # Divide by 2 since we have two buffers, so we have roughly the same batch size
        n_sequences = max(n_sequences // 2, 1)
        demo_update_buffer = self.demonstration_buffer.update_buffer
        possible_demo_batches = len(demo_update_buffer["actions"]) // n_sequences
        possible_policy_batches = len(update_buffer["actions"]) // n_sequences
        possible_batches = min(possible_policy_batches, possible_demo_batches)
        if self.samples_per_update > 0:
            # Clamp to at least one batch: a cap smaller than a single mini batch must
            # not degrade into "no cap at all".
            max_batches = max(self.samples_per_update // n_sequences, 1)
            num_batches = min(possible_batches, max_batches)
        else:
            num_batches = possible_batches

        batch_losses = []
        policy_estimate = []
        expert_estimate = []
        kl_loss = []
        z_log_sigma_sq = []
        z_mean_expert = []
        z_mean_policy = []
        n_epoch = self.num_epoch
        for _epoch in range(n_epoch):
            demo_update_buffer.shuffle()
            update_buffer.shuffle()
            for i in range(num_batches):
                start = i * n_sequences
                end = (i + 1) * n_sequences
                mini_batch_demo = demo_update_buffer.make_mini_batch(start, end)
                mini_batch_policy = update_buffer.make_mini_batch(start, end)
                run_out = self._update_batch(mini_batch_demo, mini_batch_policy)
                batch_losses.append(run_out["gail_loss"])
                policy_estimate.append(run_out["policy_estimate"])
                expert_estimate.append(run_out["expert_estimate"])
                if self.model.use_vail:
                    kl_loss.append(run_out["kl_loss"])
                    z_log_sigma_sq.append(run_out["z_log_sigma_sq"])
                    z_mean_policy.append(run_out["z_mean_policy"])
                    z_mean_expert.append(run_out["z_mean_expert"])

        if not batch_losses:
            # Nothing was trained (too little data, or num_epoch == 0). Do not flip
            # has_updated, and do not report np.mean([]) (nan + RuntimeWarning).
            return {}
        self.has_updated = True

        # Only gather debug statistics (including a session run for beta) when the
        # debug message will actually be emitted.
        if LOGGER.isEnabledFor(logging.DEBUG):
            print_list = ["n_epoch", "beta", "policy_estimate", "expert_estimate"]
            print_vals = [
                n_epoch,
                self.policy.sess.run(self.model.beta),
                np.mean(policy_estimate),
                np.mean(expert_estimate),
            ]
            if self.model.use_vail:
                print_list += [
                    "kl_loss",
                    "z_mean_expert",
                    "z_mean_policy",
                    "z_log_sigma_sq",
                ]
                print_vals += [
                    np.mean(kl_loss),
                    np.mean(z_mean_expert),
                    np.mean(z_mean_policy),
                    np.mean(z_log_sigma_sq),
                ]
            LOGGER.debug(
                "GAIL Debug:\n\t\t"
                + "\n\t\t".join(
                    "{0}: {1}".format(_name, _val)
                    for _name, _val in zip(print_list, print_vals)
                )
            )
        return {"Losses/GAIL Loss": np.mean(batch_losses)}

    @staticmethod
    def _merge_batch_and_sequence(obs: np.ndarray) -> np.ndarray:
        """Reshapes a (batch, seq, w, h, c) visual observation to (batch * seq, w, h, c)."""
        # Unpacking (rather than reshaping blindly) rejects arrays that are not 5-D.
        _batch, _seq, width, height, channels = obs.shape
        return obs.reshape([-1, width, height, channels])

    def _update_batch(
        self,
        mini_batch_demo: Dict[str, np.ndarray],
        mini_batch_policy: Dict[str, np.ndarray],
    ) -> Dict[str, float]:
        """
        Runs one discriminator training step on a pair of mini batches.
        :param mini_batch_demo: A mini batch of expert trajectories
        :param mini_batch_policy: A mini batch of trajectories sampled from the current policy
        :return: The session outputs keyed "gail_loss", "update_batch", "policy_estimate",
        "expert_estimate", plus "kl_loss", "z_log_sigma_sq", "z_mean_expert" and
        "z_mean_policy" when VAIL is enabled.
        """
        feed_dict: Dict[tf.Tensor, Any] = {
            self.model.done_expert: mini_batch_demo["done"].reshape([-1, 1]),
            self.model.done_policy: mini_batch_policy["done"].reshape([-1, 1]),
        }
        if self.model.use_vail:
            feed_dict[self.model.use_noise] = [1]

        # Continuous actions are (batch, act_size[0]); discrete actions are one column
        # per action branch.
        if self.policy.use_continuous_act:
            action_shape = [-1, self.policy.model.act_size[0]]
            policy_action_in = self.policy.model.selected_actions
        else:
            action_shape = [-1, len(self.policy.model.act_size)]
            policy_action_in = self.policy.model.action_holder
        feed_dict[policy_action_in] = mini_batch_policy["actions"].reshape(action_shape)
        feed_dict[self.model.action_in_expert] = mini_batch_demo["actions"].reshape(
            action_shape
        )

        if self.policy.use_vis_obs > 0:
            # With recurrent sequences, visual observations carry an extra sequence
            # axis that is merged into the batch axis before feeding.
            merge_sequences = (
                self.policy.sequence_length > 1 and self.policy.use_recurrent
            )
            for i in range(len(self.policy.model.visual_in)):
                policy_obs = mini_batch_policy["visual_obs%d" % i]
                demo_obs = mini_batch_demo["visual_obs%d" % i]
                if merge_sequences:
                    policy_obs = self._merge_batch_and_sequence(policy_obs)
                    demo_obs = self._merge_batch_and_sequence(demo_obs)
                feed_dict[self.policy.model.visual_in[i]] = policy_obs
                feed_dict[self.model.expert_visual_in[i]] = demo_obs

        if self.policy.use_vec_obs:
            feed_dict[self.policy.model.vector_in] = mini_batch_policy[
                "vector_obs"
            ].reshape([-1, self.policy.vec_obs_size])
            feed_dict[self.model.obs_in_expert] = mini_batch_demo["vector_obs"].reshape(
                [-1, self.policy.vec_obs_size]
            )

        out_dict = {
            "gail_loss": self.model.loss,
            "update_batch": self.model.update_batch,
            "policy_estimate": self.model.policy_estimate,
            "expert_estimate": self.model.expert_estimate,
        }
        if self.model.use_vail:
            out_dict["kl_loss"] = self.model.kl_loss
            out_dict["z_log_sigma_sq"] = self.model.z_log_sigma_sq
            out_dict["z_mean_expert"] = self.model.z_mean_expert
            out_dict["z_mean_policy"] = self.model.z_mean_policy
        run_out = self.policy.sess.run(out_dict, feed_dict=feed_dict)
        if self.model.use_vail:
            self.update_beta(run_out["kl_loss"])
        return run_out

    def update_beta(self, kl_div: float) -> None:
        """
        Updates the Beta parameter with the latest kl_divergence value.
        The larger Beta, the stronger the importance of the kl divergence in the loss function.
        :param kl_div: The KL divergence
        """
        self.policy.sess.run(
            self.model.update_beta, feed_dict={self.model.kl_div_input: kl_div}
        )