|
|
|
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
super(RLTrainer, self).__init__(*args, **kwargs) |
|
|
|
self.param_keys: List[str] = [] |
|
|
|
self.cumulative_returns_since_policy_update: List[float] = [] |
|
|
|
self.step: int = 0 |
|
|
|
self.training_start_time = time.time() |
|
|
|
self.summary_freq = self.trainer_parameters["summary_freq"] |
|
|
|
|
|
|
A signal that the Episode has ended. The buffer must be reset. |
|
|
|
Get only called when the academy resets. |
|
|
|
""" |
|
|
|
for agent_id in self.episode_steps: |
|
|
|
self.episode_steps[agent_id] = 0 |
|
|
|
self.episode_steps[agent_id] = 0 |
|
|
|
self.cumulative_returns_since_policy_update.append( |
|
|
|
rewards.get(agent_id, 0) |
|
|
|
) |
|
|
|
self.reward_buffer.appendleft(rewards.get(agent_id, 0)) |
|
|
|
rewards[agent_id] = 0 |
|
|
|
else: |
|
|
|