import sys
import numpy as np
from typing import List , Dict , TypeVar , Generic , Tuple , Any , Union
from collections import defaultdict , Counter
import queue
StatsAggregationMethod ,
EnvironmentStats ,
)
from mlagents.trainers.trajectory import Trajectory , AgentExperience
from mlagents.trainers.trajectory import AgentStatus , Trajectory , AgentExperience
from mlagents.trainers.behavior_id_utils import get_global_agent_id
from mlagents.trainers.behavior_id_utils import (
get_global_agent_id ,
get_global_group_id ,
GlobalAgentId ,
GlobalGroupId ,
)
T = TypeVar ( " T " )
: param max_trajectory_length : Maximum length of a trajectory before it is added to the trainer .
: param stats_category : The category under which to write the stats . Usually , this comes from the Trainer.
"""
self . experience_buffers : Dict [ str , List [ AgentExperience ] ] = defaultdict ( list )
self . last_step_result : Dict [ str , Tuple [ DecisionStep , int ] ] = { }
self . _experience_buffers : Dict [
GlobalAgentId , List [ AgentExperience ]
] = defaultdict ( list )
self . _last_step_result : Dict [ GlobalAgentId , Tuple [ DecisionStep , int ] ] = { }
# current_group_obs is used to collect the current (i.e. the most recently seen)
# obs of all the agents in the same group, and assemble the group obs.
# It is a dictionary of GlobalGroupId to dictionaries of GlobalAgentId to observation.
self . _current_group_obs : Dict [
GlobalGroupId , Dict [ GlobalAgentId , List [ np . ndarray ] ]
] = defaultdict ( lambda : defaultdict ( list ) )
# group_status is used to collect the current, most recently seen
# group status of all the agents in the same group, and assemble the group's status.
# It is a dictionary of GlobalGroupId to dictionaries of GlobalAgentId to AgentStatus.
self . _group_status : Dict [
GlobalGroupId , Dict [ GlobalAgentId , AgentStatus ]
] = defaultdict ( lambda : defaultdict ( None ) )
self . last_take_action_outputs : Dict [ str , ActionInfoOutputs ] = { }
self . _last_take_action_outputs : Dict [ GlobalAgentId , ActionInfoOutputs ] = { }
self . _episode_steps : Counter = Counter ( )
self . _episode_rewards : Dict [ GlobalAgentId , float ] = defaultdict ( float )
self . _stats_reporter = stats_reporter
self . _max_trajectory_length = max_trajectory_length
self . _trajectory_queues : List [ AgentManagerQueue [ Trajectory ] ] = [ ]
self . _behavior_id = behavior_id
self . episode_steps : Counter = Counter ( )
self . episode_rewards : Dict [ str , float ] = defaultdict ( float )
self . stats_reporter = stats_reporter
self . max_trajectory_length = max_trajectory_length
self . trajectory_queues : List [ AgentManagerQueue [ Trajectory ] ] = [ ]
self . behavior_id = behavior_id
def add_experiences (
self ,
# NOTE(review): merge/diff residue — the rest of this signature and large parts
# of the body are missing from this chunk; legacy and renamed lines are
# interleaved below. Do not run as-is; reconcile against version control.
take_action_outputs = previous_action . outputs
if take_action_outputs :
for _entropy in take_action_outputs [ " entropy " ] :
# NOTE(review): the next two lines are old/new duplicates of the same stat
# write (`stats_reporter` vs `_stats_reporter`); the second also has a broken
# token (`_ stats_reporter.`). Exactly one should survive.
self . stats_reporter . add_stat ( " Policy/Entropy " , _entropy )
self . _ stats_reporter. add_stat ( " Policy/Entropy " , _entropy )
# Make unique agent_ids that are global across workers
action_global_agent_ids = [
# NOTE(review): L63-64 (legacy `last_step_result`/`last_take_action_outputs`)
# duplicate L65-66 (renamed `_last_step_result`/`_last_take_action_outputs`).
if global_id in self . last_step_result : # Don't store if agent just reset
self . last_take_action_outputs [ global_id ] = take_action_outputs
if global_id in self . _last_step_result : # Don't store if agent just reset
self . _last_take_action_outputs [ global_id ] = take_action_outputs
# Iterate over all the terminal steps
# Iterate over all the terminal steps, first gather all the group obs
# and then create the AgentExperiences/Trajectories. _add_to_group_status
# stores Group statuses in a common data structure self.group_status
for terminal_step in terminal_steps . values ( ) :
self . _add_group_status_and_obs ( terminal_step , worker_id )
# NOTE(review): the next two lines are orphaned call-argument fragments
# (old passes `global_id`, new passes `worker_id`); the enclosing call is
# missing from this chunk.
terminal_step , global_id , terminal_steps . agent_id_to_index [ local_id ]
terminal_step , worker_id , terminal_steps . agent_id_to_index [ local_id ]
# Iterate over all the decision steps
# Clear the last seen group obs when agents die.
self . _clear_group_status_and_obs ( global_id )
# Iterate over all the decision steps, first gather all the group obs
# and then create the trajectories. _add_to_group_status
# stores Group statuses in a common data structure self.group_status
for ongoing_step in decision_steps . values ( ) :
self . _add_group_status_and_obs ( ongoing_step , worker_id )
global_id = get_global_agent_id ( worker_id , local_id )
# NOTE(review): orphaned old/new argument fragments again, as above.
ongoing_step , global_id , decision_steps . agent_id_to_index [ local_id ]
ongoing_step , worker_id , decision_steps . agent_id_to_index [ local_id ]
# NOTE(review): old/new duplicate guards (`last_step_result` vs
# `_last_step_result`); their bodies are missing from this chunk.
if _gid in self . last_step_result :
if _gid in self . _last_step_result :
def _add_group_status_and_obs(
    self, step: Union[TerminalStep, DecisionStep], worker_id: int
) -> None:
    """
    Record the group-mate information carried by a TerminalStep or DecisionStep.

    Stores an AgentStatus (obs, reward, last action, done flag) for this agent
    under its group in self._group_status, so trajectories built later can see
    the status of group mates. Also caches the step's current observation in
    self._current_group_obs for bootstrapping the next group observations.

    :param step: TerminalStep or DecisionStep to record.
    :param worker_id: Worker ID of this environment; used to build global ids.
    """
    agent_gid = get_global_agent_id(worker_id, step.agent_id)
    prev_step, prev_idx = self._last_step_result.get(agent_gid, (None, None))
    prev_outputs = self._last_take_action_outputs.get(agent_gid, None)
    # Nothing to record until we have seen a previous decision and its action.
    if prev_step is None or prev_outputs is None:
        return
    # group_id 0 is the default and means the agent belongs to no group;
    # in that case no group-mate information is stored.
    if step.group_id <= 0:
        return
    group_gid = get_global_group_id(worker_id, step.group_id)
    stored_actions = prev_outputs["action"]
    status = AgentStatus(
        obs=prev_step.obs,
        reward=step.reward,
        action=ActionTuple(
            continuous=stored_actions.continuous[prev_idx],
            discrete=stored_actions.discrete[prev_idx],
        ),
        done=isinstance(step, TerminalStep),
    )
    self._group_status[group_gid][agent_gid] = status
    self._current_group_obs[group_gid][agent_gid] = step.obs
def _clear_group_status_and_obs(self, global_id: GlobalAgentId) -> None:
    """
    Remove every trace of the agent identified by global_id from the
    per-group caches (self._group_status and self._current_group_obs).
    """
    for nested in (self._current_group_obs, self._group_status):
        self._delete_in_nested_dict(nested, global_id)
def _delete_in_nested_dict(self, nested_dict: Dict[str, Any], key: str) -> None:
    """
    Delete key from every inner dict of nested_dict, and prune any inner
    dict that becomes empty as a result. Missing keys are ignored
    (deletion goes through self._safe_delete).
    """
    # Snapshot the items so we can prune entries while iterating.
    for outer_key, inner in list(nested_dict.items()):
        self._safe_delete(inner, key)
        if not inner:  # inner dict emptied — drop the outer entry too
            self._safe_delete(nested_dict, outer_key)
self , step : Union [ TerminalStep , DecisionStep ] , global_id : str , index : int
self , step : Union [ TerminalStep , DecisionStep ] , worker_id : int , index : int
stored_decision_step , idx = self . last_step_result . get ( global_id , ( None , None ) )
stored_take_action_outputs = self . last_take_action_outputs . get ( global_id , None )
global_agent_id = get_global_agent_id ( worker_id , step . agent_id )
global_group_id = get_global_group_id ( worker_id , step . group_id )
stored_decision_step , idx = self . _last_step_result . get (
global_agent_id , ( None , None )
)
stored_take_action_outputs = self . _last_take_action_outputs . get (
global_agent_id , None
)
self . last_step_result [ global_id ] = ( step , index )
self . _last_step_result [ global_agent_id ] = ( step , index )
memory = self . policy . retrieve_previous_memories ( [ global_id ] ) [ 0 , : ]
memory = self . policy . retrieve_previous_memories ( [ global_agent_id ] ) [ 0 , : ]
else :
memory = None
done = terminated # Since this is an ongoing step
discrete = stored_action_probs . discrete [ idx ] ,
)
action_mask = stored_decision_step . action_mask
prev_action = self . policy . retrieve_previous_action ( [ global_id ] ) [ 0 , : ]
prev_action = self . policy . retrieve_previous_action ( [ global_agent_id ] ) [ 0 , : ]
# Assemble teammate_obs. If none saved, then it will be an empty list.
group_statuses = [ ]
for _id , _mate_status in self . _group_status [ global_group_id ] . items ( ) :
if _id != global_agent_id :
group_statuses . append ( _mate_status )
experience = AgentExperience (
obs = obs ,
reward = step . reward ,
prev_action = prev_action ,
interrupted = interrupted ,
memory = memory ,
group_status = group_statuses ,
group_reward = step . group_reward ,
self . experience_buffers [ global_id ] . append ( experience )
self . episode_rewards [ global_id ] + = step . reward
self . _experience_buffers [ global_agent_id ] . append ( experience )
self . _episode_rewards [ global_agent_id ] + = step . reward
self . episode_steps [ global_id ] + = 1
self . _episode_steps [ global_agent_id ] + = 1
len ( self . experience_buffers [ global_id ] ) > = self . max_trajectory_length
len ( self . _experience_buffers [ global_agent_id ] )
> = self . _max_trajectory_length
# Make next AgentExperience
next_group_obs = [ ]
for _id , _obs in self . _current_group_obs [ global_group_id ] . items ( ) :
if _id != global_agent_id :
next_group_obs . append ( _obs )
steps = self . experience_buffers [ global_id ] ,
agent_id = global_id ,
steps = self . _experience_buffers [ global_agent_id ] ,
agent_id = global_agent_id ,
behavior_id = self . behavior_id ,
next_group_obs = next_group_obs ,
behavior_id = self . _behavior_id ,
for traj_queue in self . trajectory_queues :
for traj_queue in self . _trajectory_queues :
self . experience_buffers [ global_id ] = [ ]
self . _experience_buffers [ global_agent_id ] = [ ]
self . stats_reporter . add_stat (
" Environment/Episode Length " , self . episode_steps . get ( global_id , 0 )
self . _stats_reporter . add_stat (
" Environment/Episode Length " ,
self . _episode_steps . get ( global_agent_id , 0 ) ,
self . _clean_agent_data ( global_id )
self . _clean_agent_data ( global_agent_id )
def _clean_agent_data(self, global_id: GlobalAgentId) -> None:
    """
    Remove all cached data for the agent identified by global_id: its
    experience buffer, last action outputs, last step result, episode
    step/reward counters, and its previous action and memories held by
    the policy. Missing entries are ignored via self._safe_delete.

    :param global_id: Global (cross-worker) id of the agent to purge.
    """
    # Fix for merge residue: the original had two consecutive `def` headers
    # (annotated `str` vs `GlobalAgentId`) and deleted both the removed
    # legacy attributes (experience_buffers, ...) and the current
    # underscore-prefixed ones. Keep only the current-API deletions.
    self._safe_delete(self._experience_buffers, global_id)
    self._safe_delete(self._last_take_action_outputs, global_id)
    self._safe_delete(self._last_step_result, global_id)
    self._safe_delete(self._episode_steps, global_id)
    self._safe_delete(self._episode_rewards, global_id)
    self.policy.remove_previous_action([global_id])
    self.policy.remove_memories([global_id])
assembles a Trajectory
: param trajectory_queue : Trajectory queue to publish to .
"""
self . trajectory_queues . append ( trajectory_queue )
self . _ trajectory_queues. append ( trajectory_queue )
def end_episode(self) -> None:
    """
    Delete all cached training data for every agent currently tracked.
    Intended for use when the environment is reset.
    """
    # Fixes merge residue: the docstring opened at the original line was never
    # closed, and the agent-id snapshot appeared twice (legacy
    # `experience_buffers` and current `_experience_buffers`). Keep the
    # current attribute only.
    # Copy the keys first: _clean_agent_data mutates _experience_buffers.
    all_gids = list(self._experience_buffers.keys())
    for _gid in all_gids:
        self._clean_agent_data(_gid)
super ( ) . __init__ ( policy , behavior_id , stats_reporter , max_trajectory_length )
trajectory_queue_len = 20 if threaded else 0
self . trajectory_queue : AgentManagerQueue [ Trajectory ] = AgentManagerQueue (
self . behavior_id , maxlen = trajectory_queue_len
self . _ behavior_id, maxlen = trajectory_queue_len
self . behavior_id , maxlen = 0
self . _ behavior_id, maxlen = 0
)
self . publish_trajectory_queue ( self . trajectory_queue )
for stat_name , value_list in env_stats . items ( ) :
for val , agg_type in value_list :
if agg_type == StatsAggregationMethod . AVERAGE :
self . stats_reporter . add_stat ( stat_name , val , agg_type )
self . _ stats_reporter. add_stat ( stat_name , val , agg_type )
self . stats_reporter . add_stat ( stat_name , val , agg_type )
self . _ stats_reporter. add_stat ( stat_name , val , agg_type )
self . stats_reporter . set_stat ( stat_name , val )
self . _ stats_reporter. set_stat ( stat_name , val )