import logging
from typing import Dict , NamedTuple , List , Any , Optional , Callable , Set , Tuple
import cloudpickle
import enum
from mlagents_envs.exception import UnityCommunicationException , UnityTimeOutException
from mlagents_envs.exception import (
UnityCommunicationException ,
UnityTimeOutException ,
UnityEnvironmentException ,
)
from multiprocessing import Process , Pipe , Queue
from multiprocessing.connection import Connection
from queue import Empty as EmptyQueueException
logger = logging . getLogger ( " mlagents.trainers " )
class EnvironmentCommand(enum.Enum):
    """Closed set of command identifiers exchanged with environment workers.

    A member travels in the ``cmd`` field of a request sent to a worker and
    is echoed in the ``cmd`` field of the worker's response, so both sides
    compare against these members instead of free-form strings.
    """

    # Apply the actions carried in the payload; the result is reported on the
    # step queue.
    STEP = 1
    # Ask the worker for the value of ``external_brains()``.
    EXTERNAL_BRAINS = 2
    # Ask the worker for its reset parameters.
    GET_PROPERTIES = 3
    # Reset the environment; the payload is a mapping of values to apply and
    # the reply carries the resulting step results.
    RESET = 4
    # Ask the worker to shut down.
    CLOSE = 5
    # Sent by a worker that is exiting on an exception; the payload is that
    # exception, which the receiving side re-raises.
    ENV_EXITED = 6
class EnvironmentRequest ( NamedTuple ) :
cmd : EnvironmentCommand
name : str
cmd : EnvironmentCommand
worker_id : int
payload : Any
self . previous_all_action_info : Dict [ str , ActionInfo ] = { }
self . waiting = False
def send ( self , name : str , payload : Any = None ) - > None :
def send ( self , cmd : EnvironmentCommand , payload : Any = None ) - > None :
cmd = EnvironmentCommand ( name , payload )
self . conn . send ( cmd )
req = EnvironmentRequest ( cmd , payload )
self . conn . send ( req )
except ( BrokenPipeError , EOFError ) :
raise UnityCommunicationException ( " UnityEnvironment worker: send failed. " )
if response . cmd == EnvironmentCommand . ENV_EXITED :
env_exception : Exception = response . payload
raise env_exception
return response
except ( BrokenPipeError , EOFError ) :
raise UnityCommunicationException ( " UnityEnvironment worker: recv failed. " )
self . conn . send ( EnvironmentCommand ( " close " ) )
self . conn . send ( EnvironmentRequest ( EnvironmentCommand . CLOSE ) )
except ( BrokenPipeError , EOFError ) :
logger . debug (
f " UnityEnvWorker {self.worker_id} got exception trying to close. "
engine_configuration_channel = EngineConfigurationChannel ( )
engine_configuration_channel . set_configuration ( engine_configuration )
stats_channel = StatsSideChannel ( )
env : BaseEnv = env_factory (
worker_id ,
[ shared_float_properties , engine_configuration_channel , stats_channel ] ,
)
env : BaseEnv = None
def _send_response(cmd_name: EnvironmentCommand, payload: Any) -> None:
    """Send a response for ``cmd_name`` back over the parent connection.

    The response is tagged with this worker's id so the receiver can tell
    which worker it came from.

    :param cmd_name: Command this response answers (or reports, for exits).
    :param payload: Data to hand back; its meaning depends on ``cmd_name``.
    """
    parent_conn.send(EnvironmentResponse(cmd_name, worker_id, payload))
def _generate_all_results ( ) - > AllStepResult :
return result
try :
env = env_factory (
worker_id ,
[ shared_float_properties , engine_configuration_channel , stats_channel ] ,
)
cmd : EnvironmentCommand = parent_conn . recv ( )
if cmd . name == " step " :
all_action_info = cmd . payload
req : EnvironmentRequest = parent_conn . recv ( )
if req . cmd == EnvironmentCommand . STEP :
all_action_info = req . payload
for brain_name , action_info in all_action_info . items ( ) :
if len ( action_info . action ) != 0 :
env . set_actions ( brain_name , action_info . action )
step_response = StepResponse (
all_step_result , get_timer_root ( ) , env_stats
)
step_queue . put ( EnvironmentResponse ( " step " , worker_id , step_response ) )
step_queue . put (
EnvironmentResponse (
EnvironmentCommand . STEP , worker_id , step_response
)
)
elif cmd . name == " external_brains " :
_send_response ( " external_brains " , external_brains ( ) )
elif cmd . name == " get_properties " :
elif req . cmd == EnvironmentCommand . EXTERNAL_BRAINS :
_send_response ( EnvironmentCommand . EXTERNAL_BRAINS , external_brains ( ) )
elif req . cmd == EnvironmentCommand . GET_PROPERTIES :
_send_response ( " get_properties " , reset_params )
elif cmd . name == " reset " :
for k , v in cmd . payload . items ( ) :
_send_response ( EnvironmentCommand . GET_PROPERTIES , reset_params )
elif req . cmd == EnvironmentCommand . RESET :
for k , v in req . payload . items ( ) :
_send_response ( " reset " , all_step_result )
elif cmd . name == " close " :
_send_response ( EnvironmentCommand . RESET , all_step_result )
elif req . cmd == EnvironmentCommand . CLOSE :
except ( KeyboardInterrupt , UnityCommunicationException , UnityTimeOutException ) :
except (
KeyboardInterrupt ,
UnityCommunicationException ,
UnityTimeOutException ,
UnityEnvironmentException ,
) as ex :
step_queue . put ( EnvironmentResponse ( " env_close " , worker_id , None ) )
step_queue . put (
EnvironmentResponse ( EnvironmentCommand . ENV_EXITED , worker_id , ex )
)
_send_response ( EnvironmentCommand . ENV_EXITED , ex )
finally :
# If this worker has put an item in the step queue that hasn't been processed by the EnvManager, the process
# will hang until the item is processed. We avoid this behavior by using Queue.cancel_join_thread()
step_queue . cancel_join_thread ( )
step_queue . close ( )
env . close ( )
if env is not None :
env . close ( )
logger . debug ( f " UnityEnvironment worker {worker_id} done. " )
if not env_worker . waiting :
env_action_info = self . _take_step ( env_worker . previous_step )
env_worker . previous_all_action_info = env_action_info
env_worker . send ( " step " , env_action_info )
env_worker . send ( EnvironmentCommand . STEP , env_action_info )
env_worker . waiting = True
def _step ( self ) - > List [ EnvironmentStep ] :
while len ( worker_steps ) < 1 :
try :
while True :
step = self . step_queue . get_nowait ( )
if step . name == " env_close " :
raise UnityCommunicationException (
" At least one of the environments has closed. "
)
step : EnvironmentResponse = self . step_queue . get_nowait ( )
if step . cmd == EnvironmentCommand . ENV_EXITED :
env_exception : Exception = step . payload
raise env_exception
self . env_workers [ step . worker_id ] . waiting = False
if step . worker_id not in step_workers :
worker_steps . append ( step )
self . env_workers [ step . worker_id ] . waiting = False
# First enqueue reset commands for all workers so that they reset in parallel
for ew in self . env_workers :
ew . send ( " reset " , config )
ew . send ( EnvironmentCommand . RESET , config )
# Next (synchronously) collect the reset observations from each worker in sequence
for ew in self . env_workers :
ew . previous_step = EnvironmentStep ( ew . recv ( ) . payload , ew . worker_id , { } , { } )
def external_brains ( self ) - > Dict [ AgentGroup , BrainParameters ] :
self . env_workers [ 0 ] . send ( " external_brains " )
self . env_workers [ 0 ] . send ( EnvironmentCommand . EXTERNAL_BRAINS )
self . env_workers [ 0 ] . send ( " get_properties " )
self . env_workers [ 0 ] . send ( EnvironmentCommand . GET_PROPERTIES )
return self . env_workers [ 0 ] . recv ( ) . payload
def close ( self ) - > None :