|
|
|
|
|
|
for stat, stat_list in batch_update_stats.items(): |
|
|
|
self._stats_reporter.add_stat(stat, np.mean(stat_list)) |
|
|
|
|
|
|
|
def create_sac_optimizer(self) -> SACOptimizer:
    """Construct a SACOptimizer from this trainer's policy and settings.

    Returns:
        A new SACOptimizer built from ``self.policy`` and
        ``self.trainer_settings``.
    """
    # NOTE(review): `cast` is presumably typing.cast, i.e. a static-typing
    # hint only with no runtime conversion -- confirm against the file imports.
    tf_policy = cast(TFPolicy, self.policy)
    return SACOptimizer(tf_policy, self.trainer_settings)
|
|
|
|
|
|
|
def add_policy( |
|
|
|
self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy |
|
|
|
) -> None: |
|
|
|
|
|
|
) |
|
|
|
self.policy = policy |
|
|
|
self.policies[parsed_behavior_id.behavior_id] = policy |
|
|
|
self.optimizer = SACOptimizer( |
|
|
|
cast(TFPolicy, self.policy), self.trainer_settings |
|
|
|
) |
|
|
|
self.optimizer = self.create_sac_optimizer() |
|
|
|
for _reward_signal in self.optimizer.reward_signals.keys(): |
|
|
|
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0) |
|
|
|
# Needed to resume loads properly |
|
|
|