# ## ML-Agent Learning (SAC)
# Contains an implementation of SAC as described in https://arxiv.org/abs/1801.01290
# and implemented in https://github.com/hill-a/stable-baselines

# NOTE: import paths assume the release_1-era layout of the ml-agents trainers package.
from collections import defaultdict
from typing import Dict

import numpy as np

from mlagents_envs.logging_util import get_logger
from mlagents_envs.timers import timed
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.policy.nn_policy import NNPolicy
from mlagents.trainers.sac.optimizer import SACOptimizer
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers

logger = get_logger(__name__) |
|
|
|
|
|
|
|
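# When the replay buffer grows past "buffer_size", it is truncated down to this
# fraction of the cap, so that a full buffer is not re-truncated on every update.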
BUFFER_TRUNCATE_PERCENT = 0.8 |
|
|
|
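# Default number of environment steps collected per SAC update, used when
# "steps_per_update" is not given in the trainer configuration.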
DEFAULT_STEPS_PER_UPDATE = 1 |
|
|
|
|
|
|
|
|
|
|
|
class SACTrainer(RLTrainer):
    """
    The SACTrainer is an implementation of the SAC algorithm, with support
    for discrete actions and recurrent networks.
    """

    def __init__(
        self,
        brain_name: str,
        reward_buff_cap: int,
        trainer_parameters: dict,
        training: bool,
        load: bool,
        seed: int,
        run_id: str,
    ):
        """
        Responsible for collecting experiences and training the SAC model.
        :param brain_name: The name of the brain associated with trainer config.
        :param reward_buff_cap: Max reward history to track in the reward buffer.
        :param trainer_parameters: The parameters for the trainer (dictionary).
        :param training: Whether the trainer is set for training.
        :param load: Whether the model should be loaded.
        :param seed: The seed the model will be initialized with.
        :param run_id: The identifier of the current run.
        """
        super().__init__(
            brain_name, trainer_parameters, training, run_id, reward_buff_cap
        )

        self.param_keys = [
            "batch_size",
            "buffer_size",
            "buffer_init_steps",
            "hidden_units",
            "learning_rate",
            "init_entcoef",
            "max_steps",
            "normalize",
            "steps_per_update",
            "num_layers",
            "time_horizon",
            "sequence_length",
            "summary_freq",
            "tau",
            "use_recurrent",
            "summary_path",
            "memory_size",
            "model_path",
            "reward_signals",
            "vis_encode_type",
        ]
        self._check_param_keys()
        self.load = load
        self.seed = seed
        self.policy: NNPolicy = None  # type: ignore
        self.optimizer: SACOptimizer = None  # type: ignore
|
|
|
|
|
|
|
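        # Number of environment steps this trainer has processed so far; the
        # steps_per_update ratios below compare against this counter.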
self.step = 0 |
|
|
|
        # Don't count buffer_init_steps in steps_per_update ratio, but also don't divide-by-0
        self.update_steps = max(1, self.trainer_parameters["buffer_init_steps"])
        self.reward_signal_update_steps = max(
            1, self.trainer_parameters["buffer_init_steps"]
        )
|
|
|
|
|
|
|
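        # steps_per_update is the target ratio of environment steps to SAC gradient
        # updates: the update loop keeps running updates until
        # step / update_steps is no longer greater than this value.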
        self.steps_per_update = (
            trainer_parameters["steps_per_update"]
            if "steps_per_update" in trainer_parameters
            else DEFAULT_STEPS_PER_UPDATE
        )
        self.reward_signal_steps_per_update = (
            trainer_parameters["reward_signals"]["reward_signal_steps_per_update"]
            if "reward_signal_steps_per_update" in trainer_parameters["reward_signals"]
            else self.steps_per_update
        )
|
|
|
|
|
|
|
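        # Optionally save the replay buffer to disk alongside the model checkpoint
        # ("save_replay_buffer" in the trainer config) so that off-policy training
        # can resume without starting from an empty buffer.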
        self.checkpoint_replay_buffer = (
            trainer_parameters["save_replay_buffer"]
            if "save_replay_buffer" in trainer_parameters
            else False
        )
|
|
def _is_ready_update(self) -> bool: |
|
|
|
""" |
|
|
|
Returns whether or not the trainer has enough elements to run update model |
|
|
|
:return: A boolean corresponding to whether or not update_model() can be run |
|
|
|
:return: A boolean corresponding to whether or not _update_policy() can be run |
|
|
|
""" |
|
|
|
return ( |
|
|
|
self.update_buffer.num_experiences >= self.trainer_parameters["batch_size"] |
|
|
|
|
|
|
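    # The timed decorator records how long each policy update takes in ml-agents'
    # hierarchical timer tree.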
@timed |
|
|
|
    def _update_policy(self) -> bool:
        """
        Update the SAC policy and reward signals. The reward signal generators are updated using different mini batches.
        By default we imitate http://arxiv.org/abs/1809.02925 and similar papers, where the policy is updated
        N times, then the reward signals are updated N times.
        :return: Whether or not the policy was updated.
        """
        policy_was_updated = self._update_sac_policy()
        self._update_reward_signals()
        return policy_was_updated
|
|
|
|
|
|
|
def create_policy( |
|
|
|
        self, parsed_behavior_id: BehaviorIdentifiers, brain_parameters: BrainParameters
    ) -> TFPolicy:
        # Build the policy network for this behavior. The keyword arguments assume
        # the NNPolicy API of this ml-agents version: SAC uses a tanh-squashed,
        # reparameterized Gaussian policy, and the TF graph is created later by the
        # SAC optimizer.
        policy = NNPolicy(
            self.seed,
            brain_parameters,
            self.trainer_parameters,
            self.is_training,
            self.load,
            tanh_squash=True,
            reparameterize=True,
            create_tf_graph=False,
        )

        return policy
|
|
|
|
|
|
|
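    # Worked example of the steps_per_update bookkeeping (hypothetical numbers):
    # with steps_per_update=2, step=100 and update_steps=48, 100/48 > 2, so the loop
    # below runs one update (update_steps=49), then another (update_steps=50);
    # 100/50 == 2.0 is not greater than 2, so updating pauses until more steps arrive.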
    def _update_sac_policy(self) -> bool:
        """
        Uses update_buffer to update the policy. We sample the update_buffer and update
        until the steps_per_update ratio is met.
        :return: Whether or not the policy was updated.
        """
        has_updated = False
        n_sequences = max(
            int(self.trainer_parameters["batch_size"] / self.policy.sequence_length), 1
        )
        batch_update_stats: Dict[str, list] = defaultdict(list)
        while self.step / self.update_steps > self.steps_per_update:
            logger.debug("Updating SAC policy at step {}".format(self.step))
            buffer = self.update_buffer
            if (
                self.update_buffer.num_experiences
                >= self.trainer_parameters["batch_size"]
            ):
                # Sample a minibatch from the replay buffer and evaluate each reward
                # signal on it, so the optimizer sees this batch's scaled rewards.
                sampled_minibatch = buffer.sample_mini_batch(
                    self.trainer_parameters["batch_size"],
                    sequence_length=self.policy.sequence_length,
                )
                for name, signal in self.optimizer.reward_signals.items():
                    sampled_minibatch[
                        "{}_rewards".format(name)
                    ] = signal.evaluate_batch(sampled_minibatch).scaled_reward

                update_stats = self.optimizer.update(sampled_minibatch, n_sequences)
                for stat_name, value in update_stats.items():
                    batch_update_stats[stat_name].append(value)

                self.update_steps += 1

                for stat, stat_list in batch_update_stats.items():
                    self._stats_reporter.add_stat(stat, np.mean(stat_list))
                has_updated = True

            # Update the behavioral cloning module too, if one is enabled.
            if self.optimizer.bc_module:
                update_stats = self.optimizer.bc_module.update()
                for stat, val in update_stats.items():
                    self._stats_reporter.add_stat(stat, val)

        # Truncate update buffer if necessary. Truncate more than we need to, to avoid
        # truncating a large buffer at each update.
        if self.update_buffer.num_experiences > self.trainer_parameters["buffer_size"]:
            self.update_buffer.truncate(
                int(self.trainer_parameters["buffer_size"] * BUFFER_TRUNCATE_PERCENT)
            )
        return has_updated
|
|
|
|
|
|
|
|
|
|
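    # Reward signal networks (e.g. GAIL or curiosity) are trained on their own
    # cadence, governed by reward_signal_steps_per_update rather than steps_per_update.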
    def _update_reward_signals(self) -> None:
|
|
|
""" |
|
|
|
Iterate through the reward signals and update them. Unlike in PPO, |
|
|
|
do it separate from the policy so that it can be done at a different |
|
|
|
|
|
|
and policy are updated in parallel. |
|
|
|
""" |
|
|
|
buffer = self.update_buffer |
|
|
|
        n_sequences = max(
            int(self.trainer_parameters["batch_size"] / self.policy.sequence_length), 1
        )
        batch_update_stats: Dict[str, list] = defaultdict(list)
        while (
            self.step / self.reward_signal_update_steps
            > self.reward_signal_steps_per_update
        ):
            # Get minibatches for reward signal update if needed
            reward_signal_minibatches = {}
            for name, signal in self.optimizer.reward_signals.items():
                logger.debug("Updating {} at step {}".format(name, self.step))
                # Some signals don't need a minibatch to be sampled - so we don't!
                if signal.update_dict:
                    reward_signal_minibatches[name] = buffer.sample_mini_batch(
                        self.trainer_parameters["batch_size"],
                        sequence_length=self.policy.sequence_length,
                    )
            update_stats = self.optimizer.update_reward_signals(
                reward_signal_minibatches, n_sequences
            )
            for stat_name, value in update_stats.items():
                batch_update_stats[stat_name].append(value)

            self.reward_signal_update_steps += 1

        for stat, stat_list in batch_update_stats.items():
            self._stats_reporter.add_stat(stat, np.mean(stat_list))
|
|
|
|
|
|
|
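    # Called after create_policy(): registers the policy with this trainer and
    # builds the SAC optimizer (and its reward signals) around it.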
def add_policy( |
|
|
|
self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy |
|
|
|