|
|
|
|
|
|
self.hyperparameters: SACSettings = cast( |
|
|
|
SACSettings, trainer_settings.hyperparameters |
|
|
|
) |
|
|
|
self.step = 0 |
|
|
|
self._step = 0 |
|
|
|
|
|
|
|
# Don't divide by zero |
|
|
|
self.update_steps = 1 |
|
|
|
|
|
|
""" |
|
|
|
return ( |
|
|
|
self.update_buffer.num_experiences >= self.hyperparameters.batch_size |
|
|
|
and self.step >= self.hyperparameters.buffer_init_steps |
|
|
|
and self._step >= self.hyperparameters.buffer_init_steps |
|
|
|
) |
|
|
|
|
|
|
|
@timed |
|
|
|
|
|
|
|
|
|
|
batch_update_stats: Dict[str, list] = defaultdict(list) |
|
|
|
while ( |
|
|
|
self.step - self.hyperparameters.buffer_init_steps |
|
|
|
self._step - self.hyperparameters.buffer_init_steps |
|
|
|
logger.debug(f"Updating SAC policy at step {self.step}") |
|
|
|
logger.debug(f"Updating SAC policy at step {self._step}") |
|
|
|
buffer = self.update_buffer |
|
|
|
if self.update_buffer.num_experiences >= self.hyperparameters.batch_size: |
|
|
|
sampled_minibatch = buffer.sample_mini_batch( |
|
|
|
|
|
|
) |
|
|
|
batch_update_stats: Dict[str, list] = defaultdict(list) |
|
|
|
while ( |
|
|
|
self.step - self.hyperparameters.buffer_init_steps |
|
|
|
self._step - self.hyperparameters.buffer_init_steps |
|
|
|
logger.debug(f"Updating {name} at step {self.step}") |
|
|
|
logger.debug(f"Updating {name} at step {self._step}") |
|
|
|
if name != "extrinsic": |
|
|
|
reward_signal_minibatches[name] = buffer.sample_mini_batch( |
|
|
|
self.hyperparameters.batch_size, |
|
|
|
|
|
|
self.model_saver.initialize_or_load() |
|
|
|
|
|
|
|
# Needed to resume loads properly |
|
|
|
self.step = policy.get_current_step() |
|
|
|
self._step = policy.get_current_step() |
|
|
|
self.update_steps = int(max(1, self.step / self.steps_per_update)) |
|
|
|
self.update_steps = int(max(1, self._step / self.steps_per_update)) |
|
|
|
max(1, self.step / self.reward_signal_steps_per_update) |
|
|
|
max(1, self._step / self.reward_signal_steps_per_update) |
|
|
|
) |
|
|
|
|
|
|
|
def get_policy(self, name_behavior_id: str) -> Policy: |
|
|
|