    def should_still_train(self) -> bool:
        """
        Returns whether or not the trainer should train. A Trainer could
        stop training if it wasn't training to begin with, or if max_steps
        is reached.
        """
        return self.is_training and self.get_step <= self.get_max_steps

    def _process_trajectory(self, trajectory: Trajectory) -> None:
        super()._process_trajectory(trajectory)
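        # Advance summary writing, model checkpointing, and the step count by the steps in this trajectory.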
        self._maybe_write_summary(self.get_step + len(trajectory.steps))
        self._maybe_save_model(self.get_step + len(trajectory.steps))
        self._increment_step(len(trajectory.steps), self.brain_name)
        last_step = trajectory.steps[-1]
        agent_id = trajectory.agent_id  # All the agents should have the same ID
""" |
|
|
|
has_updated = False |
|
|
|
self.cumulative_returns_since_policy_update.clear() |
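        # Number of sequences of length sequence_length that make up one minibatch of batch_size steps.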
        n_sequences = max(
            int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
        )
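        # Run updates until the ratio of collected steps (past buffer_init_steps) to updates falls to steps_per_update.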
        while (
            self.step - self.hyperparameters.buffer_init_steps
        ) / self.update_steps > self.steps_per_update:
            logger.debug("Updating SAC policy at step {}".format(self.step))
            buffer = self.update_buffer
            if self.update_buffer.num_experiences >= self.hyperparameters.batch_size: