import numpy as np

from mlagents_envs.base_env import BehaviorSpec
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
from mlagents.trainers.buffer import BufferKey, RewardSignalUtil
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.optimizer import Optimizer
from mlagents.trainers.policy import Policy
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.coma.optimizer_torch import TorchCOMAOptimizer

        # Value and baseline estimates for each step of the trajectory, plus
        # the bootstrap value for the step after it ends. The unpacked names
        # are an assumption: they mirror what TorchCOMAOptimizer's
        # get_trajectory_and_group_value_estimates returns.
        (
            value_estimates,
            baseline_estimates,
            value_next,
        ) = self.optimizer.get_trajectory_and_group_value_estimates(
            agent_buffer_trajectory,
            trajectory.next_obs,
            trajectory.next_group_obs,
            trajectory.all_group_dones_reached
            and trajectory.done_reached
            and not trajectory.interrupted,
        )
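        # Downstream (not shown in this fragment), these estimates are assumed
        # to feed the advantage computation: per reward signal, lambda returns
        # are computed with lambda_return() defined below, and the baseline is
        # subtracted. Roughly (local_rewards, v_estimates, bootstrap_value and
        # name are hypothetical placeholders):
        #
        #     lambd_returns = lambda_return(
        #         r=local_rewards,
        #         value_estimates=v_estimates,
        #         gamma=self.optimizer.reward_signals[name].gamma,
        #         lambd=self.hyperparameters.lambd,
        #         value_next=bootstrap_value,
        #     )
        #     local_advantage = lambd_returns - baseline_estimates[name]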

        # If this was a terminal trajectory, append stats and reset reward collection
        if trajectory.done_reached:
            self._update_end_episode_stats(agent_id, self.optimizer)
            # Remove dead agents from group reward recording. Pop with a
            # default, since _update_end_episode_stats above may already have
            # removed this agent's entry.
            self.collected_group_rewards.pop(agent_id, None)

        # If the whole team is done, report the remaining group rewards.
        if trajectory.all_group_dones_reached:
            self.stats_reporter.add_stat(
                "Environment/Group Cumulative Reward",
                self.collected_group_rewards.get(agent_id, 0),
                aggregation=StatsAggregationMethod.HISTOGRAM,
            )
            self.collected_group_rewards.pop(agent_id, None)
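        # For context, collected_group_rewards is assumed (its setup is not
        # shown in this fragment) to be a per-agent running sum of group
        # rewards, maintained roughly like:
        #
        #     self.collected_group_rewards: Dict[str, int] = defaultdict(lambda: 0)
        #     ...
        #     self.collected_group_rewards[agent_id] += np.sum(
        #         agent_buffer_trajectory[BufferKey.GROUP_REWARD]
        #     )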

    def _is_ready_update(self) -> bool:
        """
        Returns whether or not the trainer has enough elements to run update model.
        :return: A boolean corresponding to whether or not update_model() can be run.
        """
        size_of_buffer = self.update_buffer.num_experiences
        return size_of_buffer > self.hyperparameters.buffer_size

    def _update_end_episode_stats(self, agent_id: str, optimizer: Optimizer) -> None:
        super()._update_end_episode_stats(agent_id, optimizer)
        self.stats_reporter.add_stat(
            "Environment/Team Cumulative Reward",
            self.collected_group_rewards.get(agent_id, 0),
            aggregation=StatsAggregationMethod.HISTOGRAM,
        )
        # Pop with a default: the entry may be absent for agents that never
        # received a group reward.
        self.collected_group_rewards.pop(agent_id, None)


def lambda_return(r, value_estimates, gamma=0.99, lambd=0.8, value_next=0.0):
    # TD(lambda) returns, computed backwards through the trajectory and
    # bootstrapped with value_next after the final step:
    #   returns[t] = r[t] + gamma * lambd * returns[t + 1]
    #                + (1 - lambd) * gamma * value_estimates[t + 1]
    returns = np.zeros_like(r)
    returns[-1] = r[-1] + gamma * value_next
    for t in reversed(range(0, r.size - 1)):
        returns[t] = (
            gamma * lambd * returns[t + 1]
            + r[t]
            + (1 - lambd) * gamma * value_estimates[t + 1]
        )
    return returns
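

# A quick worked example of lambda_return with hypothetical numbers (values
# chosen for illustration, not taken from any trainer run):
#
#     r = np.array([1.0, 0.0, 2.0])
#     v = np.array([0.5, 0.4, 0.3])
#     lambda_return(r, v, gamma=0.9, lambd=0.8, value_next=0.25)
#     # -> [2.26432, 1.656, 2.225]
#
# Working backwards with gamma * lambd = 0.72 and (1 - lambd) * gamma = 0.18:
#   returns[2] = 2.0 + 0.9 * 0.25                  = 2.225
#   returns[1] = 0.72 * 2.225 + 0.0 + 0.18 * 0.3   = 1.656
#   returns[0] = 0.72 * 1.656 + 1.0 + 0.18 * 0.4   = 2.26432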