|
|
|
|
|
|
from mlagents.trainers.agent_processor import AgentManager |
|
|
|
from mlagents.tf_utils.globals import get_rank |
|
|
|
|
|
|
|
try: |
|
|
|
import torch |
|
|
|
except ModuleNotFoundError: |
|
|
|
torch = None # type: ignore |
|
|
|
|
|
|
|
|
|
|
|
class TrainerController: |
|
|
|
def __init__( |
|
|
|
|
|
|
self.kill_trainers = False |
|
|
|
np.random.seed(training_seed) |
|
|
|
tf.set_random_seed(training_seed) |
|
|
|
if torch is not None: |
|
|
|
torch.manual_seed(training_seed) |
|
|
|
self.rank = get_rank() |
|
|
|
|
|
|
|
@timed |
|
|
|