浏览代码

Fix remaining tests

/MLA-1734-demo-provider
Arthur Juliani 4 年前
当前提交
b074c252
共有 3 个文件被更改,包括 9 次插入8 次删除
  1. 4
      ml-agents/mlagents/trainers/tests/simple_test_envs.py
  2. 9
      ml-agents/mlagents/trainers/tests/tensorflow/test_models.py
  3. 4
      ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py

4
ml-agents/mlagents/trainers/tests/simple_test_envs.py


ActionSpec,
BaseEnv,
BehaviorSpec,
SensorType,
DecisionSteps,
TerminalSteps,
BehaviorMapping,

)
else:
action_spec = ActionSpec.create_continuous(action_size)
self.behavior_spec = BehaviorSpec(self._make_obs_spec(), action_spec)
sensor_type_list = [SensorType.OBSERVATION for i in range(len(self._make_obs_spec()))]
self.behavior_spec = BehaviorSpec(self._make_obs_spec(), sensor_type_list, action_spec)
self.action_size = action_size
self.names = brain_names
self.positions: Dict[str, List[float]] = {}

9
ml-agents/mlagents/trainers/tests/tensorflow/test_models.py


from mlagents.trainers.tf.models import ModelUtils
from mlagents.tf_utils import tf
from mlagents_envs.base_env import BehaviorSpec, ActionSpec
from mlagents_envs.base_env import BehaviorSpec, ActionSpec, SensorType
behavior_spec = BehaviorSpec(
[(84, 84, 3)] * int(num_visual) + [(vector_size,)] * int(num_vector),
ActionSpec.create_discrete((1,)),
)
obs_shapes = [(84, 84, 3)] * int(num_visual) + [(vector_size,)] * int(num_vector)
sensor_types = [SensorType.OBSERVATION for _ in range(len(obs_shapes))]
behavior_spec = BehaviorSpec(obs_shapes, sensor_types, ActionSpec.create_discrete((1,)))
return behavior_spec

4
ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py


from unittest.mock import MagicMock
from mlagents.trainers.settings import TrainerSettings
import numpy as np
from mlagents_envs.base_env import ActionSpec
from mlagents_envs.base_env import ActionSpec, SensorType
dummy_groupspec = BehaviorSpec([(1,)], dummy_actionspec)
dummy_groupspec = BehaviorSpec([(1,)], [SensorType.OBSERVATION], dummy_actionspec)
return dummy_groupspec

正在加载...
取消
保存