from typing import List, Dict
from collections import defaultdict, Counter

import numpy as np

from mlagents.trainers.trainer import Trainer


class AgentProcessor:
    def __init__(self, trainer: Trainer, policy, max_trajectory_length: int = 0):
        # Holds the in-progress experience sequence for each agent id.
        self.experience_buffers: Dict[str, List] = defaultdict(list)
        # We still need some info from the policy (memories, previous actions)
        # that really should be gathered by the env-manager.
        self.policy = policy
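        # Per-agent bookkeeping: Counter defaults unseen agent ids to 0, and
        # the defaultdict does the same for accumulated float rewards.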
        self.episode_steps: Counter = Counter()
        self.episode_rewards: Dict[str, float] = defaultdict(lambda: 0.0)
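        # Design note: a falsy max_trajectory_length means "no cap", and
        # ignore_max_length records that explicitly so the cutoff check in
        # add_experiences doesn't depend on a magic length value.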
        if max_trajectory_length:
            self.max_trajectory_length = max_trajectory_length
            self.ignore_max_length = False
        else:
            self.max_trajectory_length = 0
            self.ignore_max_length = True
        self.trainer = trainer
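
    # Usage sketch (hypothetical values; trainer and policy are supplied by
    # the surrounding ml-agents code, which this fragment does not show):
    #   processor = AgentProcessor(trainer, policy, max_trajectory_length=64)
    #   processor.ignore_max_length  # False: trajectories are cut at 64 steps
    #   processor = AgentProcessor(trainer, policy)
    #   processor.ignore_max_length  # True: only episode ends cut trajectories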

    def add_experiences(
        self,
        # (remaining parameters elided in the source fragment)
    ):
        # ... (elided code; it supplies the next_info, next_idx, and agent_id
        # referenced below) ...
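        # Close out the trajectory when this agent's episode has ended, or
        # when a length cap is in effect and the buffer has reached it; an
        # empty buffer has nothing to flush.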
        if (
            next_info.local_done[next_idx]
            or (
                not self.ignore_max_length
                and len(self.experience_buffers[agent_id])
                >= self.max_trajectory_length
            )
        ) and len(self.experience_buffers[agent_id]) > 0:
            # Make next AgentExperience
            next_obs = []