浏览代码

Created exception module for unitytrainers.

/develop-generalizationTraining-TrainerController
Deric Pang 7 年前
当前提交
cd7c854c
共有 4 个文件被更改,包括 27 次插入10 次删除
  1. 9
      python/tests/test_unitytrainers.py
  2. 1
      python/unitytrainers/__init__.py
  3. 12
      python/unitytrainers/curriculum.py
  4. 15
      python/unitytrainers/exception.py

9
python/tests/test_unitytrainers.py


from unitytrainers.models import *
from unitytrainers.ppo.trainer import PPOTrainer
from unitytrainers.bc.trainer import BehavioralCloningTrainer
from unitytrainers import Curriculum
from unityagents import UnityEnvironmentException
from unitytrainers.curriculum import Curriculum
from unitytrainers.exception import CurriculumError
from unityagents.exception import UnityEnvironmentException
from .mock_communicator import MockCommunicator
dummy_start = '''{

with mock.patch(open_name, create=True) as mock_open:
mock_open.return_value = 0
mock_load.return_value = bad_curriculum
with pytest.raises(UnityEnvironmentException):
with pytest.raises(CurriculumError):
with pytest.raises(UnityEnvironmentException):
with pytest.raises(CurriculumError):
Curriculum('tests/test_unityagents.py', {"param1": 1, "param2": 1})
curriculum = Curriculum('tests/test_unityagents.py', {"param1": 1, "param2": 1, "param3": 1})
assert curriculum.get_lesson_number == 0

1
python/unitytrainers/__init__.py


from .bc.trainer import *
from .ppo.models import *
from .ppo.trainer import *
from .exception import *

12
python/unitytrainers/curriculum.py


import json
from unityagents.exception import UnityEnvironmentException
from .exception import CurriculumError
import logging

with open(location) as data_file:
self.data = json.load(data_file)
except IOError:
raise UnityEnvironmentException(
raise CurriculumError(
raise UnityEnvironmentException("There was an error decoding {}".format(location))
raise CurriculumError("There was an error decoding {}".format(location))
raise UnityEnvironmentException("{0} does not contain a "
raise CurriculumError("{0} does not contain a "
"{1} field.".format(location, key))
parameters = self.data['parameters']
self.measure_type = self.data['measure']

raise UnityEnvironmentException(
raise CurriculumError(
raise UnityEnvironmentException(
raise CurriculumError(
"The parameter {0} in Curriculum {1} must have {2} values "
"but {3} were found".format(key, location,
self.max_lesson_number + 1, len(parameters[key])))

15
python/unitytrainers/exception.py


"""
Contains exceptions for the unitytrainers package.
"""
class TrainerError(Exception):
"""
Any error related to the trainers in the ML-Agents Toolkit.
"""
pass
class CurriculumError(TrainerError):
"""
Any error related to training with a curriculum.
"""
pass
正在加载...
取消
保存