        self.optimizer: SACOptimizer = None  # type: ignore

        self.step = 0
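        # self.step is the trainer's count of environment steps collected so far;
        # it gates updates via train_interval and the steps-per-update ratios below.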
        self.train_interval = (
            trainer_parameters["train_interval"]
            if "train_interval" in trainer_parameters
            else 1
        )

        # Start at 1 rather than 0 so the step/update ratio checks below never divide by zero.
        self.update_steps = 1
        self.reward_signal_update_steps = 1

        self.steps_per_update = (
            trainer_parameters["steps_per_update"]
            if "steps_per_update" in trainer_parameters
            else 1
        )

        self.reward_signal_updates_per_train = (
            trainer_parameters["reward_signals"]["reward_signal_num_update"]
            if "reward_signal_num_update" in trainer_parameters["reward_signals"]
            else trainer_parameters["num_update"]
        )

        self.reward_signal_steps_per_update = (
            trainer_parameters["reward_signals"]["reward_signal_steps_per_update"]
            if "reward_signal_steps_per_update" in trainer_parameters["reward_signals"]
            else self.steps_per_update
        )
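        # steps_per_update and reward_signal_steps_per_update are target ratios of
        # environment steps to update steps: update_sac_policy() and
        # update_reward_signals() keep performing updates while
        # self.step / update-count exceeds these values.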

        # Whether to also save/load the replay buffer when checkpointing the trainer.
        self.checkpoint_replay_buffer = (
            trainer_parameters["save_replay_buffer"]
            if "save_replay_buffer" in trainer_parameters
            else False
        )

    def _update_policy(self) -> None:
        """
        If train_interval is met, update the SAC policy given the current reward signals.
        If the reward signal update cadence (reward_signal_steps_per_update) is met,
        update the reward signals from the buffer.
        """
        if self.step % self.train_interval == 0:
            self.update_sac_policy()
            self.update_reward_signals()

    def create_policy(self, brain_parameters: BrainParameters) -> TFPolicy:
        policy = NNPolicy(

    def update_sac_policy(self) -> None:
        n_sequences = max(
            int(self.trainer_parameters["batch_size"] / self.policy.sequence_length), 1
        )
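        # n_sequences: the batch is split into sequences of length sequence_length
        # when the policy is recurrent; with sequence_length == 1 this is just the
        # batch size.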

        batch_update_stats: Dict[str, list] = defaultdict(list)
        while self.step / self.update_steps > self.steps_per_update:
            logger.debug("Updating SAC policy at step {}".format(self.step))
            buffer = self.update_buffer
            if (
                buffer.num_experiences
                >= self.trainer_parameters["batch_size"]
            ):

                for stat_name, value in update_stats.items():
                    batch_update_stats[stat_name].append(value)

            self.update_steps += 1

        # Truncate the update buffer if necessary. Truncate more than strictly needed
        # to avoid re-truncating a large buffer at every update.
        if self.update_buffer.num_experiences > self.trainer_parameters["buffer_size"]:
            self.update_buffer.truncate(
                int(self.trainer_parameters["buffer_size"] * BUFFER_TRUNCATE_PERCENT)
            )

    def update_reward_signals(self) -> None:
        """
        Update the reward signals from the update buffer. Normally, the reward signals
        and policy are updated in parallel.
        """
        buffer = self.update_buffer
        batch_update_stats: Dict[str, list] = defaultdict(list)
        while (
            self.step / self.reward_signal_update_steps
            > self.reward_signal_steps_per_update
        ):
            # Get minibatches for reward signal update if needed
            reward_signal_minibatches = {}
            for name, signal in self.optimizer.reward_signals.items():

            )
            for stat_name, value in update_stats.items():
                batch_update_stats[stat_name].append(value)
            self.reward_signal_update_steps += 1

        for stat, stat_list in batch_update_stats.items():
            self._stats_reporter.add_stat(stat, np.mean(stat_list))
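
    # A minimal sketch of the trainer_parameters keys read in this trainer
    # (values are illustrative placeholders only, not project defaults):
    #
    #     {
    #         "batch_size": 128,
    #         "buffer_size": 50000,
    #         "train_interval": 1,
    #         "steps_per_update": 1,
    #         "num_update": 1,
    #         "reward_signals": {
    #             "reward_signal_num_update": 1,
    #             "reward_signal_steps_per_update": 1,
    #         },
    #     }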