
fix mlagents-envs tests

Branch: /develop/action-spec-gym
Andrew Cohen, 4 years ago
Commit: b40e7793
6 changed files, with 56 additions and 66 deletions
1. ml-agents-envs/mlagents_envs/base_env.py (6 changes)
2. ml-agents-envs/mlagents_envs/tests/test_envs.py (8 changes)
3. ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py (30 changes)
4. ml-agents-envs/mlagents_envs/tests/test_steps.py (36 changes)
5. ml-agents/mlagents/trainers/tests/torch/test_networks.py (40 changes)
6. ml-agents/mlagents/trainers/torch/networks.py (2 changes)
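
Note: all six files track a single API migration: a BehaviorSpec is no longer constructed from an ActionType enum plus an action shape, but from a single ActionSpec(continuous_size, discrete_branches), and the action queries move onto that object. A minimal before/after sketch, using only constructor forms that appear in the diffs below:

    from mlagents_envs.base_env import ActionSpec, BehaviorSpec

    shapes = [(3,), (4,)]  # observation shapes, as in the tests below

    # Old API (removed by this commit):
    #   BehaviorSpec(shapes, ActionType.CONTINUOUS, 3)
    #   BehaviorSpec(shapes, ActionType.DISCRETE, (7, 3))

    # New API: one ActionSpec describes both action families.
    continuous = BehaviorSpec(shapes, ActionSpec(3, ()))    # 3 continuous dimensions
    discrete = BehaviorSpec(shapes, ActionSpec(0, (7, 3)))  # two branches of 7 and 3 choices

    # Queries move from the BehaviorSpec itself onto its action_spec:
    assert discrete.action_spec.is_action_discrete()
    assert discrete.action_spec.discrete_action_branches == (7, 3)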

ml-agents-envs/mlagents_envs/base_env.py (6 changes)


@@ ... @@
         return self.continuous_action_size > 0

     @property
-    def discrete_action_branches(self) -> Optional[Tuple[int, ...]]:
+    def discrete_action_branches(self) -> Tuple[int, ...]:
         return self.discrete_branch_sizes  # type: ignore

     @property
     def action_size(self) -> int:
         return self.discrete_action_size + self.continuous_action_size

+    @property
+    def total_action_size(self) -> int:
+        return sum(self.discrete_action_branches) + self.continuous_action_size
+
     def create_empty_action(self, n_agents: int) -> np.ndarray:
         if self.is_action_continuous():
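
Note: two behavioral changes in this hunk. discrete_action_branches now always returns a tuple (an empty one for purely continuous specs, per test_steps.py below) instead of Optional, and the new total_action_size sums the branch sizes rather than counting branches. A worked example under those definitions; the (5, 4) branch sizes are illustrative:

    from mlagents_envs.base_env import ActionSpec

    spec = ActionSpec(0, (5, 4))        # purely discrete: branches of 5 and 4 choices
    assert spec.discrete_action_branches == (5, 4)
    assert spec.action_size == 2        # number of branches plus continuous dims
    assert spec.total_action_size == 9  # 5 + 4 + 0
    assert ActionSpec(3, ()).discrete_action_branches == ()  # empty tuple, no longer None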

ml-agents-envs/mlagents_envs/tests/test_envs.py (8 changes)


@@ ... @@
     decision_steps, terminal_steps = env.get_steps("RealFakeBrain")
     n_agents = len(decision_steps)
     env.set_actions(
-        "RealFakeBrain", np.zeros((n_agents, spec.action_size), dtype=np.float32)
+        "RealFakeBrain",
+        np.zeros((n_agents, spec.action_spec.action_size), dtype=np.float32),
@@ ... @@
-        np.zeros((n_agents - 1, spec.action_size), dtype=np.float32),
+        np.zeros((n_agents - 1, spec.action_spec.action_size), dtype=np.float32),
@@ ... @@
-        "RealFakeBrain", -1 * np.ones((n_agents, spec.action_size), dtype=np.float32)
+        "RealFakeBrain",
+        -1 * np.ones((n_agents, spec.action_spec.action_size), dtype=np.float32),
     )
     env.step()

ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py (30 changes)


@@ ... @@
 from mlagents_envs.communicator_objects.agent_action_pb2 import AgentActionProto
 from mlagents_envs.base_env import (
     BehaviorSpec,
-    ActionType,
+    ActionSpec,
     DecisionSteps,
     TerminalSteps,
 )
@@ ... @@
 def test_batched_step_result_from_proto():
     n_agents = 10
     shapes = [(3,), (4,)]
-    spec = BehaviorSpec(shapes, ActionType.CONTINUOUS, 3)
+    spec = BehaviorSpec(shapes, ActionSpec(3, ()))
     ap_list = generate_list_agent_proto(n_agents, shapes)
     decision_steps, terminal_steps = steps_from_proto(ap_list, spec)
     for agent_id in range(n_agents):
@@ ... @@
 def test_action_masking_discrete():
     n_agents = 10
     shapes = [(3,), (4,)]
-    behavior_spec = BehaviorSpec(shapes, ActionType.DISCRETE, (7, 3))
+    behavior_spec = BehaviorSpec(shapes, ActionSpec(0, (7, 3)))
     ap_list = generate_list_agent_proto(n_agents, shapes)
     decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)
     masks = decision_steps.action_mask
@@ ... @@
 def test_action_masking_discrete_1():
     n_agents = 10
     shapes = [(3,), (4,)]
-    behavior_spec = BehaviorSpec(shapes, ActionType.DISCRETE, (10,))
+    behavior_spec = BehaviorSpec(shapes, ActionSpec(0, (10,)))
     ap_list = generate_list_agent_proto(n_agents, shapes)
     decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)
     masks = decision_steps.action_mask
@@ ... @@
 def test_action_masking_discrete_2():
     n_agents = 10
     shapes = [(3,), (4,)]
-    behavior_spec = BehaviorSpec(shapes, ActionType.DISCRETE, (2, 2, 6))
+    behavior_spec = BehaviorSpec(shapes, ActionSpec(0, (2, 2, 6)))
     ap_list = generate_list_agent_proto(n_agents, shapes)
     decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)
     masks = decision_steps.action_mask
@@ ... @@
 def test_action_masking_continuous():
     n_agents = 10
     shapes = [(3,), (4,)]
-    behavior_spec = BehaviorSpec(shapes, ActionType.CONTINUOUS, 10)
+    behavior_spec = BehaviorSpec(shapes, ActionSpec(10, ()))
     ap_list = generate_list_agent_proto(n_agents, shapes)
     decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)
     masks = decision_steps.action_mask
@@ ... @@
     bp.vector_action_size.extend([5, 4])
     bp.vector_action_space_type = 0
     behavior_spec = behavior_spec_from_proto(bp, agent_proto)
-    assert behavior_spec.is_action_discrete()
-    assert not behavior_spec.is_action_continuous()
+    assert behavior_spec.action_spec.is_action_discrete()
+    assert not behavior_spec.action_spec.is_action_continuous()
-    assert behavior_spec.discrete_action_branches == (5, 4)
-    assert behavior_spec.action_size == 2
+    assert behavior_spec.action_spec.discrete_action_branches == (5, 4)
+    assert behavior_spec.action_spec.action_size == 2
@@ ... @@
-    assert not behavior_spec.is_action_discrete()
-    assert behavior_spec.is_action_continuous()
-    assert behavior_spec.action_size == 6
+    assert not behavior_spec.action_spec.is_action_discrete()
+    assert behavior_spec.action_spec.is_action_continuous()
+    assert behavior_spec.action_spec.action_size == 6
@@ ... @@
-    behavior_spec = BehaviorSpec(shapes, ActionType.CONTINUOUS, 3)
+    behavior_spec = BehaviorSpec(shapes, ActionSpec(3, ()))
     ap_list = generate_list_agent_proto(n_agents, shapes, infinite_rewards=True)
     with pytest.raises(RuntimeError):
         steps_from_proto(ap_list, behavior_spec)
@@ ... @@
     n_agents = 10
     shapes = [(3,), (4,)]
-    behavior_spec = BehaviorSpec(shapes, ActionType.CONTINUOUS, 3)
+    behavior_spec = BehaviorSpec(shapes, ActionSpec(3, ()))
     ap_list = generate_list_agent_proto(n_agents, shapes, nan_observations=True)
     with pytest.raises(RuntimeError):
         steps_from_proto(ap_list, behavior_spec)

ml-agents-envs/mlagents_envs/tests/test_steps.py (36 changes)


@@ ... @@
 from mlagents_envs.base_env import (
     DecisionSteps,
     TerminalSteps,
-    ActionType,
+    ActionSpec,
     BehaviorSpec,
 )
@@ ... @@
 def test_empty_decision_steps():
     specs = BehaviorSpec(
-        observation_shapes=[(3, 2), (5,)],
-        action_type=ActionType.CONTINUOUS,
-        action_shape=3,
+        observation_shapes=[(3, 2), (5,)], action_spec=ActionSpec(3, ())
     )
     ds = DecisionSteps.empty(specs)
     assert len(ds.obs) == 2
@@ ... @@
 def test_empty_terminal_steps():
     specs = BehaviorSpec(
-        observation_shapes=[(3, 2), (5,)],
-        action_type=ActionType.CONTINUOUS,
-        action_shape=3,
+        observation_shapes=[(3, 2), (5,)], action_spec=ActionSpec(3, ())
     )
     ts = TerminalSteps.empty(specs)
     assert len(ts.obs) == 2
@@ ... @@
 def test_specs():
-    specs = BehaviorSpec(
-        observation_shapes=[(3, 2), (5,)],
-        action_type=ActionType.CONTINUOUS,
-        action_shape=3,
-    )
-    assert specs.discrete_action_branches is None
+    specs = ActionSpec(3, ())
+    assert specs.discrete_action_branches == ()
-    specs = BehaviorSpec(
-        observation_shapes=[(3, 2), (5,)],
-        action_type=ActionType.DISCRETE,
-        action_shape=(3,),
-    )
+    specs = ActionSpec(0, (3,))
     assert specs.discrete_action_branches == (3,)
     assert specs.action_size == 1
     assert specs.create_empty_action(5).shape == (5, 1)
@@ ... @@
 def test_action_generator():
     # Continuous
     action_len = 30
-    specs = BehaviorSpec(
-        observation_shapes=[(5,)],
-        action_type=ActionType.CONTINUOUS,
-        action_shape=action_len,
-    )
+    specs = ActionSpec(action_len, ())
     zero_action = specs.create_empty_action(4)
     assert np.array_equal(zero_action, np.zeros((4, action_len), dtype=np.float32))
     random_action = specs.create_random_action(4)
@@ ... @@
     # Discrete
     action_shape = (10, 20, 30)
-    specs = BehaviorSpec(
-        observation_shapes=[(5,)],
-        action_type=ActionType.DISCRETE,
-        action_shape=action_shape,
-    )
+    specs = ActionSpec(0, action_shape)
     zero_action = specs.create_empty_action(4)
     assert np.array_equal(zero_action, np.zeros((4, len(action_shape)), dtype=np.int32))
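
Note: the assertions above pin down the create_empty_action contract: continuous specs yield float32 zeros of width continuous_size, discrete specs yield int32 zeros with one column per branch. A compact restatement (agent count of 4 as in the tests):

    import numpy as np
    from mlagents_envs.base_env import ActionSpec

    cont = ActionSpec(30, ()).create_empty_action(4)
    assert cont.shape == (4, 30) and cont.dtype == np.float32  # width = continuous size

    disc = ActionSpec(0, (10, 20, 30)).create_empty_action(4)
    assert disc.shape == (4, 3) and disc.dtype == np.int32  # one column per branch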

ml-agents/mlagents/trainers/tests/torch/test_networks.py (40 changes)


@@ ... @@
     SeparateActorCritic,
 )
 from mlagents.trainers.settings import NetworkSettings
-from mlagents_envs.base_env import ActionType
+from mlagents_envs.base_env import ActionSpec
 from mlagents.trainers.torch.distributions import (
     GaussianDistInstance,
     CategoricalDistInstance,
@@ ... @@
     assert _out[0] == pytest.approx(1.0, abs=0.1)


-@pytest.mark.parametrize("action_type", [ActionType.DISCRETE, ActionType.CONTINUOUS])
-def test_simple_actor(action_type):
+@pytest.mark.parametrize("use_discrete", [True, False])
+def test_simple_actor(use_discrete):
@@ ... @@
-    masks = None if action_type == ActionType.CONTINUOUS else torch.ones((1, 1))
-    actor = SimpleActor(obs_shapes, network_settings, action_type, act_size)
+    if use_discrete:
+        masks = torch.ones((1, 1))
+        action_spec = ActionSpec(0, tuple(act_size))
+    else:
+        masks = None
+        action_spec = ActionSpec(act_size[0], ())
+    actor = SimpleActor(obs_shapes, network_settings, action_spec)
@@ ... @@
-        if action_type == ActionType.CONTINUOUS:
-            assert isinstance(dist, GaussianDistInstance)
+        if use_discrete:
+            assert isinstance(dist, CategoricalDistInstance)
         else:
-            assert isinstance(dist, CategoricalDistInstance)
+            assert isinstance(dist, GaussianDistInstance)
@@ ... @@
-        if action_type == ActionType.CONTINUOUS:
-            assert act.shape == (1, act_size[0])
-        else:
+        if use_discrete:
             assert act.shape == (1, 1)
+        else:
+            assert act.shape == (1, act_size[0])

     # Test forward
     actions, ver_num, mem_size, is_cont, act_size_vec = actor.forward(
@@ ... @@
         # This is different from above for ONNX export
-        if action_type == ActionType.CONTINUOUS:
-            assert act.shape == tuple(act_size)
-        else:
+        if use_discrete:
             assert act.shape == (1, 1)
+        else:
+            assert act.shape == tuple(act_size)
@@ ... @@
-    assert is_cont == int(action_type == ActionType.CONTINUOUS)
+    assert is_cont == int(not use_discrete)
     assert act_size_vec == torch.tensor(act_size)
@@ ... @@
     obs_shapes = [(obs_size,)]
     act_size = [2]
     stream_names = [f"stream_name{n}" for n in range(4)]
-    actor = ac_type(
-        obs_shapes, network_settings, ActionType.CONTINUOUS, act_size, stream_names
-    )
+    action_spec = ActionSpec(act_size[0], ())
+    actor = ac_type(obs_shapes, network_settings, action_spec, stream_names)
     if lstm:
         sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size))
         memories = torch.ones(

ml-agents/mlagents/trainers/torch/networks.py (2 changes)


@@ ... @@
             torch.Tensor([int(self.act_type == ActionType.CONTINUOUS)])
         )
         self.act_size_vector = torch.nn.Parameter(
-            torch.Tensor([self.action_spec.action_size]), requires_grad=False
+            torch.Tensor([self.action_spec.total_action_size]), requires_grad=False
         )
         self.network_body = NetworkBody(observation_shapes, network_settings)
         if network_settings.memory is not None:
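
Note: the likely reason for total_action_size here (my reading; the commit message doesn't say): act_size_vector is exported as model metadata, and a discrete policy's concatenated outputs have one column per branch choice, so the advertised width must be the sum of the branch sizes rather than the branch count. With branches (5, 4) that is 9 rather than 2, as in the worked example under base_env.py above; for continuous-only specs the two properties agree.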
