import csv
import logging
from time import time
# Module-level logger registered under the "mlagents.trainers" name (no padding,
# so it takes part in the dotted logger hierarchy).
LOGGER = logging.getLogger("mlagents.trainers")

# Column headers of the metrics CSV. The order matches the values of each
# recorded row: brain name, policy-update time, time since training start,
# experience-collection time, number of experiences, mean return.
FIELD_NAMES = [
    'Brain name',
    'Time to update policy',
    'Time since start of training',
    'Time for last experience collection',
    'Number of experiences used for training',
    'Mean return',
]
class TrainerMetrics:
    """
    Helper class that tracks training timings and writes them to a CSV file.

    "Time since start of training" is measured from the moment an instance
    is constructed.
    """

    def __init__(self, path: str, brain_name: str):
        """
        :str path: Fully qualified path where CSV is stored.
        :str brain_name: Name of the brain these metrics belong to; written
            as the first column of every row.
        """
        self.path = path
        self.brain_name = brain_name
        # Per-instance copy of the CSV header, kept so that code reading
        # ``metrics.FIELD_NAMES`` keeps working.
        self.FIELD_NAMES = list(FIELD_NAMES)
        self.rows = []
        self.time_start_experience_collection = None
        self.time_training_start = time()
        self.time_policy_update_start = None
        self.last_buffer_length = None
        self.last_mean_return = None
        self.delta_last_experience_collection = None
        self.delta_policy_update = None

    def start_experience_collection_timer(self):
        """
        Inform Metrics class that experience collection is starting.

        Calling this while the timer is already running leaves the original
        start time untouched.
        """
        if self.time_start_experience_collection is None:
            self.time_start_experience_collection = time()

    def end_experience_collection_timer(self):
        """
        Inform Metrics class that experience collection is done.

        The elapsed time is added to the running total for the current
        policy-update cycle. Does nothing if the timer was not started.
        """
        if self.time_start_experience_collection is not None:
            curr_delta = time() - self.time_start_experience_collection
            if self.delta_last_experience_collection is None:
                self.delta_last_experience_collection = curr_delta
            else:
                self.delta_last_experience_collection += curr_delta
        self.time_start_experience_collection = None

    def add_delta_step(self, delta: float):
        """
        Inform Metrics class about time to step in environment.

        :float delta: Seconds to add to the experience-collection total.
        """
        if self.delta_last_experience_collection is None:
            self.delta_last_experience_collection = delta
        else:
            self.delta_last_experience_collection += delta

    def start_policy_update_timer(self, number_experiences: int, mean_return: float):
        """
        Inform Metrics class that policy update has started.

        :int number_experiences: Number of experiences used for this update.
        :float mean_return: Mean return reported for this update.
        """
        self.last_buffer_length = number_experiences
        self.last_mean_return = mean_return
        self.time_policy_update_start = time()

    def _add_row(self, delta_train_start: float):
        """
        Append one CSV row for the policy update that just finished.

        :float delta_train_start: Seconds elapsed since training started.
        """
        row = [self.brain_name]
        # Floats are rendered with three decimals; ints and None pass through.
        row.extend(
            format(c, '.3f') if isinstance(c, float) else c
            for c in [self.delta_policy_update, delta_train_start,
                      self.delta_last_experience_collection,
                      self.last_buffer_length, self.last_mean_return]
        )
        # Experience-collection time is accumulated per update cycle.
        self.delta_last_experience_collection = None
        self.rows.append(row)

    def end_policy_update(self):
        """
        Inform Metrics class that policy update has ended.

        Computes the update duration, logs the metrics at debug level and
        records a row for later CSV output.
        """
        if self.time_policy_update_start is not None:
            self.delta_policy_update = time() - self.time_policy_update_start
        else:
            # The update timer was never started; report a zero duration.
            self.delta_policy_update = 0.0
        delta_train_start = time() - self.time_training_start
        # %s placeholders keep the call safe when a metric is still None.
        LOGGER.debug(
            "Policy update metrics for %s: update=%s s, since start=%s s, "
            "experience collection=%s s, experiences=%s, mean return=%s",
            self.brain_name, self.delta_policy_update, delta_train_start,
            self.delta_last_experience_collection,
            self.last_buffer_length, self.last_mean_return)
        self._add_row(delta_train_start)

    def write_training_metrics(self):
        """
        Write the header and all recorded rows to the CSV file at ``self.path``.

        The file is overwritten on every call.
        """
        # newline='' lets the csv module control line endings itself.
        with open(self.path, 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(self.FIELD_NAMES)
            writer.writerows(self.rows)