|
|
|
|
|
|
) |
|
|
|
self.step = 0 |
|
|
|
|
|
|
|
# Don't count buffer_init_steps in steps_per_update ratio, but also don't divide-by-0 |
|
|
|
self.update_steps = max(1, self.hyperparameters.buffer_init_steps) |
|
|
|
self.reward_signal_update_steps = max(1, self.hyperparameters.buffer_init_steps) |
|
|
|
# Don't divide by zero |
|
|
|
self.update_steps = 1 |
|
|
|
self.reward_signal_update_steps = 1 |
|
|
|
|
|
|
|
self.steps_per_update = self.hyperparameters.steps_per_update |
|
|
|
self.reward_signal_steps_per_update = ( |
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
batch_update_stats: Dict[str, list] = defaultdict(list) |
|
|
|
while self.step / self.update_steps > self.steps_per_update: |
|
|
|
while ( |
|
|
|
self.step - self.hyperparameters.buffer_init_steps |
|
|
|
) / self.update_steps > self.steps_per_update: |
|
|
|
logger.debug("Updating SAC policy at step {}".format(self.step)) |
|
|
|
buffer = self.update_buffer |
|
|
|
if self.update_buffer.num_experiences >= self.hyperparameters.batch_size: |
|
|
|
|
|
|
) |
|
|
|
batch_update_stats: Dict[str, list] = defaultdict(list) |
|
|
|
while ( |
|
|
|
self.step / self.reward_signal_update_steps |
|
|
|
> self.reward_signal_steps_per_update |
|
|
|
): |
|
|
|
self.step - self.hyperparameters.buffer_init_steps |
|
|
|
) / self.reward_signal_update_steps > self.reward_signal_steps_per_update: |
|
|
|
# Get minibatches for reward signal update if needed |
|
|
|
reward_signal_minibatches = {} |
|
|
|
for name, signal in self.optimizer.reward_signals.items(): |
|
|
|
|
|
|
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0) |
|
|
|
# Needed to resume loads properly |
|
|
|
self.step = policy.get_current_step() |
|
|
|
# Assume steps were updated at the correct ratio before |
|
|
|
self.update_steps = int(max(1, self.step / self.steps_per_update)) |
|
|
|
self.reward_signal_update_steps = int( |
|
|
|
max(1, self.step / self.reward_signal_steps_per_update) |
|
|
|
) |
|
|
|
self.next_summary_step = self._get_next_summary_step() |
|
|
|
|
|
|
|
def get_policy(self, name_behavior_id: str) -> TFPolicy: |
|
|
|