         self.checkpoint_replay_buffer = (
             trainer_parameters["save_replay_buffer"]
             if "save_replay_buffer" in trainer_parameters
             else False
         )

-        self.policy = SACPolicy(seed, brain, trainer_parameters, self.is_training, load)
+        self.sac_policy = SACPolicy(
+            seed, brain, trainer_parameters, self.is_training, load
+        )
+        self.policy = self.sac_policy

         # Load the replay buffer if load
         if load and self.checkpoint_replay_buffer:
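The change above keeps two references to one policy object: `sac_policy` carries the concrete `SACPolicy` type so SAC-specific methods type-check, while the base-class `policy` field keeps working for trainer-agnostic code. A minimal sketch of that pattern, using simplified stand-ins rather than the real ml-agents classes:

# Sketch of the two-reference typing pattern; Policy and SACPolicy here are
# simplified stand-ins, not the real ml-agents classes.
from typing import Dict


class Policy:
    """Base interface that trainer-agnostic code is written against."""

    sequence_length: int = 1


class SACPolicy(Policy):
    """Adds SAC-only methods that a type checker can only see via this type."""

    def update_reward_signals(
        self, reward_signal_minibatches: Dict, n_sequences: int
    ) -> Dict[str, float]:
        return {"Losses/Q1 Loss": 0.0}


class SACTrainer:
    def __init__(self) -> None:
        # Same object under two names: the concrete type for SAC-specific
        # calls, the base type for code shared across trainers.
        self.sac_policy = SACPolicy()
        self.policy: Policy = self.sac_policy

Code that only needs the shared interface keeps reading `self.policy`; only the SAC-specific call sites below switch to `self.sac_policy`.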
         for stat, stat_list in batch_update_stats.items():
             self.stats[stat].append(np.mean(stat_list))

-        if self.policy.bc_module:
-            update_stats = self.policy.bc_module.update()
+        bc_module = self.sac_policy.bc_module
+        if bc_module:
+            update_stats = bc_module.update()
             for stat, val in update_stats.items():
                 self.stats[stat].append(val)
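Binding `bc_module` to a local before the truthiness check is a common idiom for `Optional` attributes: the type checker narrows the local to non-`None` inside the branch, and the attribute is looked up once. A self-contained sketch of that narrowing, where `BCModule` and the field layout are hypothetical stand-ins:

# BCModule and apply_bc_update are illustrative, not the ml-agents API.
from typing import Dict, List, Optional


class BCModule:
    def update(self) -> Dict[str, float]:
        return {"Losses/Pretraining Loss": 0.0}


class SACPolicy:
    def __init__(self, use_bc: bool) -> None:
        # None when behavioral cloning is not configured for this run.
        self.bc_module: Optional[BCModule] = BCModule() if use_bc else None


def apply_bc_update(policy: SACPolicy, stats: Dict[str, List[float]]) -> None:
    bc_module = policy.bc_module  # local binding lets the checker narrow this
    if bc_module:
        update_stats = bc_module.update()  # safe: narrowed to BCModule here
        for stat, val in update_stats.items():
            stats.setdefault(stat, []).append(val)


stats: Dict[str, List[float]] = {}
apply_bc_update(SACPolicy(use_bc=True), stats)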
                     self.trainer_parameters["batch_size"],
                     sequence_length=self.policy.sequence_length,
                 )
-            update_stats = self.policy.update_reward_signals(
+            update_stats = self.sac_policy.update_reward_signals(
                 reward_signal_minibatches, n_sequences
             )
             for stat_name, value in update_stats.items():
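For context, the loop around this hunk samples a minibatch per reward signal, collects each pass's stats, and finally logs the mean per stat (the `batch_update_stats` / `np.mean` lines in the second hunk). A runnable sketch of that aggregation step, with an illustrative helper name rather than anything from the ml-agents API:

# aggregate_update_stats is a hypothetical helper for illustration only.
from collections import defaultdict
from typing import Dict, List

import numpy as np


def aggregate_update_stats(
    per_update_stats: List[Dict[str, float]]
) -> Dict[str, float]:
    # Collect every value produced across the update passes...
    batch_update_stats: Dict[str, List[float]] = defaultdict(list)
    for update_stats in per_update_stats:
        for stat_name, value in update_stats.items():
            batch_update_stats[stat_name].append(value)
    # ...then report one mean per stat, mirroring the trainer's
    # self.stats[stat].append(np.mean(stat_list)) lines.
    return {stat: float(np.mean(vals)) for stat, vals in batch_update_stats.items()}


print(aggregate_update_stats([{"Losses/Q1 Loss": 0.5}, {"Losses/Q1 Loss": 0.3}]))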