浏览代码

Fix half of the tests

/MLA-1734-demo-provider
Arthur Juliani 4 年前
当前提交
e331fb63
共有 2 个文件被更改,包括 20 次插入9 次删除
  1. 22
      ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py
  2. 7
      ml-agents-envs/mlagents_envs/tests/test_steps.py

22
ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py


from mlagents_envs.base_env import (
BehaviorSpec,
ActionSpec,
SensorType,
DecisionSteps,
TerminalSteps,
)

def test_batched_step_result_from_proto():
n_agents = 10
shapes = [(3,), (4,)]
spec = BehaviorSpec(shapes, ActionSpec.create_continuous(3))
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
spec = BehaviorSpec(shapes, sensor_type, ActionSpec.create_continuous(3))
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, spec)
for agent_id in range(n_agents):

def test_action_masking_discrete():
n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_discrete((7, 3)))
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
behavior_spec = BehaviorSpec(shapes, sensor_type, ActionSpec.create_discrete((7, 3)))
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)
masks = decision_steps.action_mask

def test_action_masking_discrete_1():
n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_discrete((10,)))
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
behavior_spec = BehaviorSpec(shapes, sensor_type, ActionSpec.create_discrete((10,)))
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)
masks = decision_steps.action_mask

def test_action_masking_discrete_2():
n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_discrete((2, 2, 6)))
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
behavior_spec = BehaviorSpec(shapes, sensor_type, ActionSpec.create_discrete((2, 2, 6)))
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)
masks = decision_steps.action_mask

def test_action_masking_continuous():
n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_continuous(10))
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
behavior_spec = BehaviorSpec(shapes, sensor_type, ActionSpec.create_continuous(10))
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)
masks = decision_steps.action_mask

def test_batched_step_result_from_proto_raises_on_infinite():
n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_continuous(3))
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
behavior_spec = BehaviorSpec(shapes, sensor_type, ActionSpec.create_continuous(3))
ap_list = generate_list_agent_proto(n_agents, shapes, infinite_rewards=True)
with pytest.raises(RuntimeError):
steps_from_proto(ap_list, behavior_spec)

n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_continuous(3))
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
behavior_spec = BehaviorSpec(shapes, sensor_type, ActionSpec.create_continuous(3))
ap_list = generate_list_agent_proto(n_agents, shapes, nan_observations=True)
with pytest.raises(RuntimeError):
steps_from_proto(ap_list, behavior_spec)

7
ml-agents-envs/mlagents_envs/tests/test_steps.py


TerminalSteps,
ActionSpec,
BehaviorSpec,
SensorType,
)

def test_empty_decision_steps():
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
observation_shapes=[(3, 2), (5,)], action_spec=ActionSpec.create_continuous(3)
observation_shapes=[(3, 2), (5,)], sensor_types=sensor_type, action_spec=ActionSpec.create_continuous(3)
)
ds = DecisionSteps.empty(specs)
assert len(ds.obs) == 2

def test_empty_terminal_steps():
sensor_type = [SensorType.OBSERVATION, SensorType.OBSERVATION]
observation_shapes=[(3, 2), (5,)], action_spec=ActionSpec.create_continuous(3)
observation_shapes=[(3, 2), (5,)], sensor_types=sensor_type, action_spec=ActionSpec.create_continuous(3)
)
ts = TerminalSteps.empty(specs)
assert len(ts.obs) == 2

正在加载...
取消
保存