浏览代码

Fix subprocess model saving on Windows

On Windows the interrupt for subprocesses works in a different
way from OSX/Linux. The result is that child subprocesses and
their pipes may close while the parent process is still running
during a keyboard (ctrl+C) interrupt.

To handle this, this change catches EOFError and
BrokenPipeError exceptions when interacting with subprocess
environments. In addition, when parallel runs are started with
the "num-runs" option, the threads for each run are now joined
and any KeyboardInterrupt raised while waiting on them is
handled.

These changes made the "_win_handler" we used to specially
manage interrupts on Windows unnecessary, so it has been
removed.
/develop-generalizationTraining-TrainerController
Jonathan Harper 6 年前
当前提交
7a0d1531
共有 4 个文件被更改,包括 29 次插入 和 29 次删除
  1. 21
      ml-agents-envs/mlagents/envs/subprocess_environment.py
  2. 15
      ml-agents/mlagents/trainers/learn.py
  3. 4
      ml-agents/mlagents/trainers/tests/test_learn.py
  4. 18
      ml-agents/mlagents/trainers/trainer_controller.py

21
ml-agents-envs/mlagents/envs/subprocess_environment.py


conn: Connection
def send(self, name: str, payload=None):
cmd = EnvironmentCommand(name, payload)
self.conn.send(cmd)
try:
cmd = EnvironmentCommand(name, payload)
self.conn.send(cmd)
except (BrokenPipeError, EOFError):
raise KeyboardInterrupt
response: EnvironmentResponse = self.conn.recv()
return response
try:
response: EnvironmentResponse = self.conn.recv()
return response
except (BrokenPipeError, EOFError):
raise KeyboardInterrupt
self.conn.send(EnvironmentCommand("close"))
try:
self.conn.send(EnvironmentCommand('close'))
except (BrokenPipeError, EOFError):
pass
self.process.join()

env_factory: Callable[[int], BaseUnityEnvironment]
) -> UnityEnvWorker:
parent_conn, child_conn = Pipe()
child_process = Process(target=worker, args=(child_conn, pickled_env_factory, worker_id))
child_process.start()
return UnityEnvWorker(child_process, worker_id, parent_conn)

15
ml-agents/mlagents/trainers/learn.py


trainer_config_path = run_options['<trainer-config-path>']
# Recognize and use docker volume if one is passed as an argument
if not docker_target_name:
model_path = './models/{run_id}'.format(run_id=run_id)
model_path = './models/{run_id}-{sub_id}'.format(run_id=run_id, sub_id=sub_id)
summaries_dir = './summaries'
else:
trainer_config_path = \

'/{docker_target_name}/{curriculum_folder}'.format(
docker_target_name=docker_target_name,
curriculum_folder=curriculum_folder)
model_path = '/{docker_target_name}/models/{run_id}'.format(
model_path = '/{docker_target_name}/models/{run_id}-{sub_id}'.format(
run_id=run_id)
run_id=run_id,
sub_id=sub_id)
summaries_dir = '/{docker_target_name}/summaries'.format(
docker_target_name=docker_target_name)

# Wait for signal that environment has successfully launched
while process_queue.get() is not True:
continue
# Wait for jobs to complete. Otherwise we'll have an extra
# unhandled KeyboardInterrupt if we end early.
try:
for job in jobs:
job.join()
except KeyboardInterrupt:
pass
# For python debugger to directly run this script
if __name__ == "__main__":

4
ml-agents/mlagents/trainers/tests/test_learn.py


with patch.object(TrainerController, "start_learning", MagicMock()):
learn.run_training(0, 0, basic_options(), MagicMock())
mock_init.assert_called_once_with(
'./models/ppo',
'./models/ppo-0',
'./summaries',
'ppo-0',
50000,

with patch.object(TrainerController, "start_learning", MagicMock()):
learn.run_training(0, 0, options_with_docker_target, MagicMock())
mock_init.assert_called_once()
assert(mock_init.call_args[0][0] == '/dockertarget/models/ppo')
assert(mock_init.call_args[0][0] == '/dockertarget/models/ppo-0')
assert(mock_init.call_args[0][1] == '/dockertarget/summaries')

18
ml-agents/mlagents/trainers/trainer_controller.py


import logging
import shutil
import sys
if sys.platform.startswith('win'):
import win32api
import win32con
from typing import *
import numpy as np

'while the graph is generated.')
self._save_model(steps)
def _win_handler(self, event):
"""
This function gets triggered after ctrl-c or ctrl-break is pressed
under Windows platform.
"""
if event in (win32con.CTRL_C_EVENT, win32con.CTRL_BREAK_EVENT):
self._save_model_when_interrupted(self.global_step)
self._export_graph()
sys.exit()
return True
return False
def _write_training_metrics(self):
"""
Write all CSV metrics

for brain_name, trainer in self.trainers.items():
trainer.write_tensorboard_text('Hyperparameters',
trainer.parameters)
if sys.platform.startswith('win'):
# Add the _win_handler function to the windows console's handler function list
win32api.SetConsoleCtrlHandler(self._win_handler, True)
try:
curr_info = self._reset_env(env)
while any([t.get_step <= t.get_max_steps \

正在加载...
取消
保存