from typing import Any, Dict, List
import logging
import numpy as np
import tensorflow as tf

from mlagents.envs.brain import BrainInfo
from mlagents.trainers.buffer import Buffer
from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.models import LearningModel
from .model import GAILModel
from mlagents.trainers.demo_loader import demo_to_buffer

LOGGER = logging.getLogger("mlagents.trainers")


class GAILRewardSignal(RewardSignal):
    def __init__(
        self,
        policy: TFPolicy,
        policy_model: LearningModel,
        strength: float,
        gamma: float,
        demo_path: str,
        num_epoch: int = 3,
        encoding_size: int = 64,
        learning_rate: float = 3e-4,
        samples_per_update: int = 0,
        use_actions: bool = False,
        use_vail: bool = False,
    ):
        """
        The GAIL Reward signal generator (https://arxiv.org/abs/1606.03476).
        :param policy: The policy of the learning agent.
        :param policy_model: The model of the learning policy, whose placeholders are used
            to feed policy samples into the discriminator.
        :param strength: The scaling factor applied to the raw GAIL reward.
        :param gamma: The time discounting factor used for this reward.
        :param demo_path: The path to the demonstration file.
        :param num_epoch: The number of epochs to train over the training buffer for the discriminator.
        :param encoding_size: The size of the hidden layers of the discriminator.
        :param learning_rate: The learning rate used during GAIL updates.
        :param samples_per_update: The maximum number of samples to update during GAIL updates.
        :param use_actions: Whether the discriminator receives actions as well as observations.
        :param use_vail: Whether to use a variational bottleneck (VAIL) in the discriminator.
        """
        super().__init__(policy, policy_model, strength, gamma)
        self.num_epoch = num_epoch
        self.samples_per_update = samples_per_update
        self.use_terminal_states = False

        self.model = GAILModel(
            policy_model, 128, learning_rate, encoding_size, use_actions, use_vail
        )
        _, self.demonstration_buffer = demo_to_buffer(
            demo_path, policy.sequence_length
        )
        self.has_updated = False
        self.update_dict: Dict[str, tf.Tensor] = {
            "gail_loss": self.model.loss,
            "gail_update_batch": self.model.update_batch,
            "gail_policy_estimate": self.model.policy_estimate,
            "gail_expert_estimate": self.model.expert_estimate,
        }
        if self.model.use_vail:
            self.update_dict["kl_loss"] = self.model.kl_loss
            self.update_dict["z_log_sigma_sq"] = self.model.z_log_sigma_sq
            self.update_dict["z_mean_expert"] = self.model.z_mean_expert
            self.update_dict["z_mean_policy"] = self.model.z_mean_policy
            self.update_dict["beta_update"] = self.model.update_beta

        self.stats_name_to_update_name = {"Losses/GAIL Loss": "gail_loss"}
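    # The tensors gathered in `update_dict` let an external optimizer fetch the
    # discriminator's loss, train op, and estimates in a single sess.run() call;
    # `stats_name_to_update_name` maps the TensorBoard stat name to the fetched key.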
|
|
|
|
|
|
|
    def evaluate(
        self, current_info: BrainInfo, next_info: BrainInfo
    ) -> RewardSignalResult:
        """
        Evaluates the GAIL intrinsic reward for the agents in next_info by running the
        discriminator over their observations (and actions, if enabled).
        """
        if len(current_info.agents) == 0:
            return RewardSignalResult([], [])
        # Minimal evaluation sketch: assumes GAILModel exposes an `intrinsic_reward`
        # tensor and that TFPolicy provides fill_eval_dict() to populate the
        # observation placeholders.
        feed_dict: Dict[tf.Tensor, Any] = {
            self.policy.model.batch_size: len(next_info.vector_observations),
            self.policy.model.sequence_length: 1,
        }
        feed_dict = self.policy.fill_eval_dict(feed_dict, brain_info=current_info)
        if self.policy.use_continuous_act:
            feed_dict[
                self.policy.model.selected_actions
            ] = next_info.previous_vector_actions
        else:
            feed_dict[
                self.policy.model.action_holder
            ] = next_info.previous_vector_actions
        unscaled_reward = self.policy.sess.run(
            self.model.intrinsic_reward, feed_dict=feed_dict
        )
        # The reward is only meaningful once the discriminator has been updated at least once.
        scaled_reward = unscaled_reward * float(self.has_updated) * self.strength
        return RewardSignalResult(scaled_reward, unscaled_reward)
|
|
    @classmethod
    def check_config(
        cls, config_dict: Dict[str, Any], param_keys: List[str] = None
    ) -> None:
        """
        Checks the config and throws an exception if a hyperparameter is missing.
        GAIL requires strength, gamma, and demo_path at a minimum.
        """
        param_keys = ["strength", "gamma", "demo_path"]
        super().check_config(config_dict, param_keys)
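    # Example of the corresponding reward-signal block in trainer_config.yaml
    # (illustrative values; only strength, gamma, and demo_path are required by
    # check_config, the remaining keys fall back to the defaults in __init__):
    #
    #     reward_signals:
    #       gail:
    #         strength: 0.01
    #         gamma: 0.99
    #         demo_path: demos/ExpertPyramid.demo
    #         encoding_size: 128
    #         use_actions: false
    #         use_vail: false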
|
|
|
|
|
|
|
    def update(self, update_buffer: Buffer, n_sequences: int) -> Dict[str, float]:
        """
        Updates model using buffer.
        :param update_buffer: The policy buffer containing the trajectories for the current policy.
        :param n_sequences: The number of sequences from demo and policy used in each mini batch.
        :return: A dictionary containing the mean GAIL loss of the update.
        """
        batch_losses = []
        # Divide by 2 since we have two buffers, so the total batch size stays roughly the same.
        n_sequences = max(n_sequences // 2, 1)
        possible_demo_batches = (
            len(self.demonstration_buffer.update_buffer["actions"]) // n_sequences
        )
        possible_policy_batches = len(update_buffer["actions"]) // n_sequences
        possible_batches = min(possible_policy_batches, possible_demo_batches)

        max_batches = self.samples_per_update // n_sequences

        kl_loss = []
        policy_estimate = []
        expert_estimate = []
        z_log_sigma_sq = []
        z_mean_expert = []
        z_mean_policy = []

        n_epoch = self.num_epoch
        for _epoch in range(n_epoch):
            self.demonstration_buffer.update_buffer.shuffle(
                sequence_length=self.policy.sequence_length
            )
            update_buffer.shuffle(sequence_length=self.policy.sequence_length)
            if max_batches == 0:
                num_batches = possible_batches
            else:
                num_batches = min(possible_batches, max_batches)
            for i in range(num_batches):
                demo_update_buffer = self.demonstration_buffer.update_buffer
                policy_update_buffer = update_buffer
                start = i * n_sequences
                end = (i + 1) * n_sequences
                mini_batch_demo = demo_update_buffer.make_mini_batch(start, end)
                mini_batch_policy = policy_update_buffer.make_mini_batch(start, end)
                run_out = self._update_batch(mini_batch_demo, mini_batch_policy)
                loss = run_out["gail_loss"]

                policy_estimate.append(run_out["policy_estimate"])
                expert_estimate.append(run_out["expert_estimate"])
                if self.model.use_vail:
                    kl_loss.append(run_out["kl_loss"])
                    z_log_sigma_sq.append(run_out["z_log_sigma_sq"])
                    z_mean_policy.append(run_out["z_mean_policy"])
                    z_mean_expert.append(run_out["z_mean_expert"])

                batch_losses.append(loss)
        self.has_updated = True

        # Log discriminator diagnostics; estimates near 0.5 mean the discriminator
        # cannot tell policy samples from expert samples.
        print_list = ["n_epoch", "beta", "policy_estimate", "expert_estimate"]
        print_vals = [
            n_epoch,
            self.policy.sess.run(self.model.beta),
            np.mean(policy_estimate),
            np.mean(expert_estimate),
        ]
        if self.model.use_vail:
            print_list += [
                "kl_loss",
                "z_mean_expert",
                "z_mean_policy",
                "z_log_sigma_sq",
            ]
            print_vals += [
                np.mean(kl_loss),
                np.mean(z_mean_expert),
                np.mean(z_mean_policy),
                np.mean(z_log_sigma_sq),
            ]
        LOGGER.debug(
            "GAIL Debug:\n\t\t"
            + "\n\t\t".join(
                "{0}: {1}".format(_name, _val)
                for _name, _val in zip(print_list, print_vals)
            )
        )
        update_stats = {"Losses/GAIL Loss": np.mean(batch_losses)}
        return update_stats
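    # _update_batch performs one discriminator update on a pair of equally sized
    # minibatches: one sampled from the expert demonstrations and one from the current
    # policy's buffer. In GAIL, the policy's intrinsic reward reflects how "expert-like"
    # the discriminator judges its transitions to be, so the two networks are trained
    # in tandem.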
|
|
|
|
|
|
|
    def _update_batch(
        self,
        mini_batch_demo: Dict[str, np.ndarray],
        mini_batch_policy: Dict[str, np.ndarray],
    ) -> Dict[str, float]:
        """
        Helper method for update.
        :param mini_batch_demo: A mini batch of expert trajectories.
        :param mini_batch_policy: A mini batch of trajectories sampled from the current policy.
        :return: Output from update process.
        """
        feed_dict: Dict[tf.Tensor, Any] = {
            self.model.done_expert_holder: mini_batch_demo["done"],
            self.model.done_policy_holder: mini_batch_policy["done"],
        }

        feed_dict[self.model.action_in_expert] = np.array(mini_batch_demo["actions"])
        if self.policy.use_continuous_act:
            feed_dict[self.policy.model.selected_actions] = mini_batch_policy["actions"]
        else:
            feed_dict[self.policy.model.action_holder] = mini_batch_policy["actions"]
        for i in range(len(self.policy.model.visual_in)):
            feed_dict[self.policy.model.visual_in[i]] = mini_batch_policy[
                "visual_obs%d" % i
            ]
            feed_dict[self.model.expert_visual_in[i]] = mini_batch_demo[
                "visual_obs%d" % i
            ]
        if self.policy.use_vec_obs:
            feed_dict[self.policy.model.vector_in] = mini_batch_policy["vector_obs"]
            feed_dict[self.model.obs_in_expert] = mini_batch_demo["vector_obs"]

        out_dict = {
            "gail_loss": self.model.loss,
            "update_batch": self.model.update_batch,
            "policy_estimate": self.model.policy_estimate,
            "expert_estimate": self.model.expert_estimate,
        }
        if self.model.use_vail:
            out_dict["kl_loss"] = self.model.kl_loss
            out_dict["z_log_sigma_sq"] = self.model.z_log_sigma_sq
            out_dict["z_mean_expert"] = self.model.z_mean_expert
            out_dict["z_mean_policy"] = self.model.z_mean_policy

        run_out = self.policy.sess.run(out_dict, feed_dict=feed_dict)
        if self.model.use_vail:
            self.update_beta(run_out["kl_loss"])
        return run_out

    def prepare_update(
        self,
        policy_model: LearningModel,
        mini_batch_policy: Dict[str, np.ndarray],
        num_sequences: int,
    ) -> Dict[tf.Tensor, Any]:
        """
        Prepare inputs for update.
        :param policy_model: The policy model whose placeholders are fed with the policy samples.
        :param mini_batch_policy: The mini batch sampled from the policy's update buffer.
        :param num_sequences: The number of sequences in the mini batch.
        :return: Feed_dict for update process.
        """
        max_num_experiences = min(
            len(mini_batch_policy["actions"]),
            len(self.demonstration_buffer.update_buffer["actions"]),
        )
        # If num_sequences is less, we need to shorten the input batch.
        for key, element in mini_batch_policy.items():
            mini_batch_policy[key] = element[:max_num_experiences]
        # Get demo buffer
        self.demonstration_buffer.update_buffer.shuffle(1)
        # TODO: Replace with SAC sample method
        mini_batch_demo = self.demonstration_buffer.update_buffer.make_mini_batch(
            0, len(mini_batch_policy["actions"])
        )

        feed_dict: Dict[tf.Tensor, Any] = {
            self.model.done_expert_holder: mini_batch_demo["done"],
            self.model.done_policy_holder: mini_batch_policy["done"],
        }

        feed_dict[self.model.action_in_expert] = np.array(mini_batch_demo["actions"])
        if self.policy.use_continuous_act:
            feed_dict[policy_model.selected_actions] = mini_batch_policy["actions"]
        else:
            feed_dict[policy_model.action_holder] = mini_batch_policy["actions"]
        for i in range(len(policy_model.visual_in)):
            feed_dict[policy_model.visual_in[i]] = mini_batch_policy[
                "visual_obs%d" % i
            ]
            feed_dict[self.model.expert_visual_in[i]] = mini_batch_demo[
                "visual_obs%d" % i
            ]
        if self.policy.use_vec_obs:
            feed_dict[policy_model.vector_in] = mini_batch_policy["vector_obs"]
            feed_dict[self.model.obs_in_expert] = mini_batch_demo["vector_obs"]

        self.has_updated = True
        return feed_dict
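    # VAIL constrains the discriminator through a variational bottleneck: beta is the
    # Lagrange multiplier on the KL term, and it is adapted (dual gradient ascent) so
    # that the measured KL divergence stays close to the bottleneck target defined in
    # the model.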
|
|
|
|
|
|
|
    def update_beta(self, kl_div: float) -> None:
        """
        Updates the Beta parameter with the latest KL divergence value.
        The larger Beta, the stronger the weight of the KL divergence in the loss function.
        :param kl_div: The KL divergence.
        """
        self.policy.sess.run(
            self.model.update_beta, feed_dict={self.model.kl_div_input: kl_div}
        )
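    # A minimal sketch of how a trainer-side optimizer might drive this signal through
    # the `prepare_update` / `update_dict` path (illustrative only; `policy`,
    # `optimizer_feed_dict`, `mini_batch`, and `num_sequences` are assumed to come from
    # the surrounding trainer code):
    #
    #     gail = GAILRewardSignal(policy, policy.model, strength=0.01, gamma=0.99,
    #                             demo_path="ExpertDemo.demo")
    #     feed_dict = dict(optimizer_feed_dict)
    #     feed_dict.update(gail.prepare_update(policy.model, mini_batch, num_sequences))
    #     run_out = policy.sess.run(gail.update_dict, feed_dict=feed_dict)
    #     stats = {stat: run_out[key] for stat, key in gail.stats_name_to_update_name.items()}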