浏览代码

fixed the windows ctrl-c bug (#1558)

* Documentation tweaks and updates (#1479)

* Add blurb about using the --load flag in the intro guide, and typo fix.

* Add section in tutorial to create multiple area learning environment.

* Add mention of Done() method in agent design

* fixed the windows ctrl-c bug

* fixed typo

* removed some unnecessary printing

* nothing

* make the import of the win api conditional

* removed the duplicate code

* added the ability to use python debugger on ml-agents

* added newline at the end, changed the import to be complete path

* changed the info.log into policy.export_model, changed the sys.platform to use startswith

* fixed a bug

* remove the printing of the path

* tweaked the info message to notify the user about the expected error message

* removed some logging according to comments

* removed the sys import

* Revert "Documentation tweaks and updates (#1479)"

This reverts commit 84ef07a4525fa8a89f4...
/hotfix-v0.9.2a
GitHub 6 年前
当前提交
cc083fd8
共有 3 个文件被更改,包括 40 次插入15 次删除
  1. 8
      ml-agents/mlagents/trainers/learn.py
  2. 1
      ml-agents/mlagents/trainers/policy.py
  3. 46
      ml-agents/mlagents/trainers/trainer_controller.py

8
ml-agents/mlagents/trainers/learn.py


import numpy as np
from docopt import docopt
from .trainer_controller import TrainerController
from .exception import TrainerError
from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.exception import TrainerError
def run_training(sub_id, run_seed, run_options, process_queue):

# Wait for signal that environment has successfully launched
while process_queue.get() is not True:
continue
# Entry-point guard: lets this module be executed directly as a script,
# for example when launching it under a Python debugger.
if __name__ == "__main__":
    main()

1
ml-agents/mlagents/trainers/policy.py


clear_devices=True, initializer_nodes='', input_saver='',
restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0')
logger.info('Exported ' + self.model_path + '.bytes file')
def _process_graph(self):
"""

46
ml-agents/mlagents/trainers/trainer_controller.py


import glob
import logging
import shutil
import sys
if sys.platform.startswith('win'):
import win32api
import win32con
import yaml
import re

self.keep_checkpoints = keep_checkpoints
self.trainers = {}
self.seed = seed
self.global_step = 0
np.random.seed(self.seed)
tf.set_random_seed(self.seed)
self.env = UnityEnvironment(file_name=env_path,

self.trainers[brain_name].save_model()
self.logger.info('Saved Model')
def _save_model_when_interrupted(self, steps=0):
    """
    Logs that learning was interrupted, then saves the model.

    :param steps: Step count forwarded to _save_model (defaults to 0).
    """
    interrupt_notice = ('Learning was interrupted. Please wait '
                        'while the graph is generated.')
    self.logger.info(interrupt_notice)
    self._save_model(steps)
def _win_handler(self, event):
    """
    This function gets triggered after ctrl-c or ctrl-break is pressed
    under Windows platform. It saves the model, exports the graph and
    then exits the interpreter.

    :param event: Control event code to compare against
        win32con.CTRL_C_EVENT and win32con.CTRL_BREAK_EVENT.
    :return: False when the event is neither ctrl-c nor ctrl-break.
        Handled events produce no return value, because sys.exit()
        raises SystemExit before the function can return.
    """
    if event not in (win32con.CTRL_C_EVENT, win32con.CTRL_BREAK_EVENT):
        # Any other console event is left alone.
        return False
    self._save_model_when_interrupted(self.global_step)
    self._export_graph()
    # sys.exit() raises SystemExit, so no statement placed after it
    # would ever run.
    sys.exit()
def _export_graph(self):
"""
Exports latest saved models to .bytes format for Unity embedding.

self._initialize_trainers(trainer_config)
for _, t in self.trainers.items():
self.logger.info(t)
global_step = 0 # This is only for saving the model
if sys.platform.startswith('win'):
# Add the _win_handler function to the windows console's handler function list
win32api.SetConsoleCtrlHandler(self._win_handler, True)
try:
while any([t.get_step <= t.get_max_steps \
for k, t in self.trainers.items()]) \

# Write training statistics to Tensorboard.
if self.meta_curriculum is not None:
trainer.write_summary(
global_step,
self.global_step,
trainer.write_summary(global_step)
trainer.write_summary(self.global_step)
global_step += 1
if global_step % self.save_freq == 0 and global_step != 0 \
self.global_step += 1
if self.global_step % self.save_freq == 0 and self.global_step != 0 \
self._save_model(steps=global_step)
self._save_model(steps=self.global_step)
if global_step != 0 and self.train_model:
self._save_model(steps=global_step)
if self.global_step != 0 and self.train_model:
self._save_model(steps=self.global_step)
print('--------------------------Now saving model--------------'
'-----------')
self.logger.info('Learning was interrupted. Please wait '
'while the graph is generated.')
self._save_model(steps=global_step)
self._save_model_when_interrupted(steps=self.global_step)
pass
self.env.close()
if self.train_model:
正在加载...
取消
保存