        )

    @timed
    def _update_policy(self) -> bool:
        """
        Update the SAC policy and reward signals.
        :return: Whether or not the policy was updated.
        """
        policy_was_updated = self._update_sac_policy()
        self._update_reward_signals()
        return policy_was_updated
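    # Illustrative note (not part of the trainer): the boolean returned by
    # _update_policy() tells the caller whether an update actually ran, so it can
    # react only when the policy really changed. A hypothetical caller might look like:
    #
    #     if trainer._update_policy():
    #         publish_updated_policy(trainer)  # hypothetical helper: push new weights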
    def create_policy(self, brain_parameters: BrainParameters) -> TFPolicy:
        policy = NNPolicy(
            ...
        )
        return policy
    def _update_sac_policy(self) -> bool:
        has_updated = False
        self.cumulative_returns_since_policy_update.clear()
        n_sequences = max(
            int(self.trainer_parameters["batch_size"] / self.policy.sequence_length), 1
        )

        batch_update_stats: Dict[str, list] = defaultdict(list)
        while self.step / self.update_steps > self.steps_per_update:
            # Sample a minibatch from the update buffer and run one optimizer update;
            # the elided code produces `update_stats`, a dict of training statistics.
            ...
            for stat_name, value in update_stats.items():
                batch_update_stats[stat_name].append(value)

            self.update_steps += 1

            for stat, stat_list in batch_update_stats.items():
                self._stats_reporter.add_stat(stat, np.mean(stat_list))
            has_updated = True

            if self.optimizer.bc_module:
                update_stats = self.optimizer.bc_module.update()
                for stat, val in update_stats.items():
                    self._stats_reporter.add_stat(stat, val)

        # Truncate the update buffer if necessary. Truncate more than we need to, to
        # avoid truncating a large buffer at each update.
        if self.update_buffer.num_experiences > self.trainer_parameters["buffer_size"]:
            self.update_buffer.truncate(
                int(self.trainer_parameters["buffer_size"] * BUFFER_TRUNCATE_PERCENT)
            )
        return has_updated
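    # Illustrative note (not part of the trainer): truncating below the configured
    # capacity leaves headroom so the truncation does not run again on the very next
    # update. For example, with a hypothetical buffer_size of 500000 and
    # BUFFER_TRUNCATE_PERCENT of 0.8, a buffer that has grown past 500000 experiences
    # is cut back to 400000, not to exactly 500000.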
    def _update_reward_signals(self) -> None:
        """
        Iterate through the reward signals and update them. Unlike in PPO,
        do it separate from the policy so that it can be done at a different
        interval.
        """
        n_sequences = max(
            int(self.trainer_parameters["batch_size"] / self.policy.sequence_length), 1
        )
        batch_update_stats: Dict[str, list] = defaultdict(list)
        while (
            self.step / self.reward_signal_update_steps
            > self.reward_signal_steps_per_update
        ):
            # Get minibatches for reward signal update if needed
            reward_signal_minibatches = {}
            for name, signal in self.optimizer.reward_signals.items():
                # Sample a minibatch for each reward signal that needs one.
                ...
            update_stats = self.optimizer.update_reward_signals(
                reward_signal_minibatches, n_sequences
            )
            for stat_name, value in update_stats.items():
                batch_update_stats[stat_name].append(value)
            self.reward_signal_update_steps += 1

        for stat, stat_list in batch_update_stats.items():
            self._stats_reporter.add_stat(stat, np.mean(stat_list))
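    # Illustrative note (not part of the trainer): the while-condition above keeps the
    # ratio of environment steps to reward-signal updates close to
    # reward_signal_steps_per_update. A hypothetical standalone check:
    #
    #     def reward_signal_update_due(step: int, updates_done: int, target_ratio: float) -> bool:
    #         # e.g. step=1000, updates_done=95, target_ratio=10.0 -> 1000 / 95 ~= 10.5 > 10,
    #         # so another update runs; updates continue until the ratio drops to the target.
    #         return step / updates_done > target_ratio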
    def add_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy