import cloudpickle
from mlagents.envs import UnityEnvironment
from multiprocessing import Process , Pipe
from multiprocessing import Process , Pipe , Queue
from queue import Empty as EmptyQueueException
from mlagents.envs.timers import timed , hierarchical_timer
from mlagents.envs import AllBrainInfo , BrainParameters , ActionInfo
from mlagents.envs.timers import timed
from mlagents.envs import BrainParameters , ActionInfo
class EnvironmentCommand ( NamedTuple ) :
self . conn = conn
self . previous_step : StepInfo = StepInfo ( None , { } , None )
self . previous_all_action_info : Dict [ str , ActionInfo ] = { }
self . waiting = False
def send ( self , name : str , payload = None ) :
try :
self . process . join ( )
def worker ( parent_conn : Connection , pickled_env_factory : str , worker_id : int ) :
def worker (
parent_conn : Connection , step_queue : Queue , pickled_env_factory : str , worker_id : int
) :
env_factory : Callable [ [ int ] , UnityEnvironment ] = cloudpickle . loads (
pickled_env_factory
)
cmd : EnvironmentCommand = parent_conn . recv ( )
if cmd . name == " step " :
all_action_info = cmd . payload
# When an environment is "global_done" it means automatic agent reset won't occur, so we need
# to perform an academy reset.
if env . global_done :
all_brain_info = env . reset ( )
else :
texts [ brain_name ] = action_info . text
values [ brain_name ] = action_info . value
all_brain_info = env . step ( actions , memories , texts , values )
_send_r esponse( " step " , all_brain_info )
step_queue . put ( EnvironmentR esponse( " step " , worker_id , all_brain_info ) )
elif cmd . name == " external_brains " :
_send_response ( " external_brains " , env . external_brains )
elif cmd . name == " reset_parameters " :
except KeyboardInterrupt :
print ( " UnityEnvironment worker: keyboard interrupt " )
finally :
step_queue . close ( )
env . close ( )
) :
super ( ) . __init__ ( )
self . env_workers : List [ UnityEnvWorker ] = [ ]
self . step_queue : Queue = Queue ( )
self . env_workers . append ( self . create_worker ( worker_idx , env_factory ) )
def get_last_steps ( self ) :
return [ ew . previous_step for ew in self . env_workers ]
self . env_workers . append (
self . create_worker ( worker_idx , self . step_queue , env_factory )
)
worker_id : int , env_factory : Callable [ [ int ] , BaseUnityEnvironment ]
worker_id : int ,
step_queue : Queue ,
env_factory : Callable [ [ int ] , BaseUnityEnvironment ] ,
) - > UnityEnvWorker :
parent_conn , child_conn = Pipe ( )
child_process = Process (
target = worker , args = ( child_conn , pickled_env_factory , worker_id )
target = worker , args = ( child_conn , step_queue , pickled_env_factory , worker_id )
def _queue_steps ( self ) - > None :
for env_worker in self . env_workers :
if not env_worker . waiting :
env_action_info = self . _take_step ( env_worker . previous_step )
env_worker . previous_all_action_info = env_action_info
env_worker . send ( " step " , env_action_info )
env_worker . waiting = True
for env_worker in self . env_workers :
all_action_info = self . _take_step ( env_worker . previous_step )
env_worker . previous_all_action_info = all_action_info
env_worker . send ( " step " , all_action_info )
# Queue steps for any workers which aren't in the "waiting" state.
self . _queue_steps ( )
worker_steps : List [ EnvironmentResponse ] = [ ]
step_workers : Set [ int ] = set ( )
# Poll the step queue for completed steps from environment workers until we retrieve
# 1 or more, which we will then return as StepInfos
while len ( worker_steps ) < 1 :
try :
while True :
step = self . step_queue . get_nowait ( )
self . env_workers [ step . worker_id ] . waiting = False
if step . worker_id not in step_workers :
worker_steps . append ( step )
step_workers . add ( step . worker_id )
except EmptyQueueException :
pass
with hierarchical_timer ( " recv " ) :
step_brain_infos : List [ AllBrainInfo ] = [
self . env_workers [ i ] . recv ( ) . payload for i in range ( len ( self . env_workers ) )
]
steps = [ ]
for i in range ( len ( step_brain_infos ) ) :
env_worker = self . env_workers [ i ]
step_info = StepInfo (
env_worker . previous_step . current_all_brain_info ,
step_brain_infos [ i ] ,
env_worker . previous_all_action_info ,
)
env_worker . previous_step = step_info
steps . append ( step_info )
return steps
step_infos = self . _postprocess_steps ( worker_steps )
return step_infos
self . _broadcast_message ( " reset " , ( config , train_mode , custom_reset_parameters ) )
reset_results = [
self . env_workers [ i ] . recv ( ) . payload for i in range ( len ( self . env_workers ) )
]
for i in range ( len ( reset_results ) ) :
env_worker = self . env_workers [ i ]
env_worker . previous_step = StepInfo ( None , reset_results [ i ] , None )
while any ( [ ew . waiting for ew in self . env_workers ] ) :
if not self . step_queue . empty ( ) :
step = self . step_queue . get_nowait ( )
self . env_workers [ step . worker_id ] . waiting = False
# First enqueue reset commands for all workers so that they reset in parallel
for ew in self . env_workers :
ew . send ( " reset " , ( config , train_mode , custom_reset_parameters ) )
# Next (synchronously) collect the reset observations from each worker in sequence
for ew in self . env_workers :
ew . previous_step = StepInfo ( None , ew . recv ( ) . payload , None )
return list ( map ( lambda ew : ew . previous_step , self . env_workers ) )
@property
self . env_workers [ 0 ] . send ( " reset_parameters " )
return self . env_workers [ 0 ] . recv ( ) . payload
def close ( self ) :
def close ( self ) - > None :
self . step_queue . close ( )
self . step_queue . join_thread ( )
def _broadcast_message ( self , name : str , payload = None ) :
for env in self . env_workers :
env . send ( name , payload )
def _postprocess_steps (
self , env_steps : List [ EnvironmentResponse ]
) - > List [ StepInfo ] :
step_infos = [ ]
for step in env_steps :
env_worker = self . env_workers [ step . worker_id ]
new_step = StepInfo (
env_worker . previous_step . current_all_brain_info ,
step . payload ,
env_worker . previous_all_action_info ,
)
step_infos . append ( new_step )
env_worker . previous_step = new_step
return step_infos
@timed
def _take_step ( self , last_step : StepInfo ) - > Dict [ str , ActionInfo ] :