浏览代码

add ActionSpec; test_simple_rl torch passes

/develop/action-spec-gym
Andrew Cohen 4 年前
当前提交
7827ca06
共有 8 个文件被更改,包括 113 次插入131 次删除
  1. 101
      ml-agents-envs/mlagents_envs/base_env.py
  2. 27
      ml-agents-envs/mlagents_envs/rpc_utils.py
  3. 15
      ml-agents/mlagents/trainers/policy/policy.py
  4. 32
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py
  5. 17
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py
  6. 26
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py
  7. 24
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_rnd.py
  8. 2
      ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py

101
ml-agents-envs/mlagents_envs/base_env.py


class ActionType(Enum):
DISCRETE = 0
CONTINUOUS = 1
HYBRID = 2
class BehaviorSpec(NamedTuple):
"""
A NamedTuple to containing information about the observations and actions
spaces for a group of Agents under the same behavior.
- observation_shapes is a List of Tuples of int : Each Tuple corresponds
to an observation's dimensions. The shape tuples have the same ordering as
the ordering of the DecisionSteps and TerminalSteps.
- action_type is the type of data of the action. it can be discrete or
continuous. If discrete, the action tensors are expected to be int32. If
continuous, the actions are expected to be float32.
- action_shape is:
- An int in continuous action space corresponding to the number of
floats that constitute the action.
- A Tuple of int in discrete action space where each int corresponds to
the number of discrete actions available to the agent.
"""
class ActionSpec(NamedTuple):
num_continuous_actions: int
discrete_branch_sizes: Tuple[int, ...]
observation_shapes: List[Tuple]
action_type: ActionType
action_shape: Union[int, Tuple[int, ...]]
# For backwards compatibility
return self.action_type == ActionType.DISCRETE
return self.discrete_action_size > 0
# For backwards compatibility
return self.action_type == ActionType.CONTINUOUS
return self.continuous_action_size > 0
@property
def discrete_action_branches(self) -> Optional[Tuple[int, ...]]:
return self.discrete_branch_sizes # type: ignore
def action_size(self) -> int:
"""
Returns the dimension of the action.
- In the continuous case, will return the number of continuous actions.
- In the (multi-)discrete case, will return the number of action.
branches.
"""
if self.action_type == ActionType.DISCRETE:
return len(self.action_shape) # type: ignore
else:
return self.action_shape # type: ignore
def discrete_action_size(self) -> int:
return len(self.discrete_branch_sizes)
@property
def continuous_action_size(self) -> int:
return self.num_continuous_actions
def discrete_action_branches(self) -> Optional[Tuple[int, ...]]:
"""
Returns a Tuple of int corresponding to the number of possible actions
for each branch (only for discrete actions). Will return None in
for continuous actions.
"""
if self.action_type == ActionType.DISCRETE:
return self.action_shape # type: ignore
else:
return None
def action_size(self) -> int:
return self.discrete_action_size + self.continuous_action_size
"""
Generates a numpy array corresponding to an empty action (all zeros)
for a number of agents.
:param n_agents: The number of agents that will have actions generated
"""
if self.action_type == ActionType.DISCRETE:
return np.zeros((n_agents, self.action_size), dtype=np.int32)
else:
return np.zeros((n_agents, self.action_size), dtype=np.float32)
if self.is_action_continuous():
return np.zeros((n_agents, self.continuous_action_size), dtype=np.float32)
return np.zeros((n_agents, self.discrete_action_size), dtype=np.int32)
"""
Generates a numpy array corresponding to a random action (either discrete
or continuous) for a number of agents.
:param n_agents: The number of agents that will have actions generated
:param generator: The random number generator used for creating random action
"""
low=-1.0, high=1.0, size=(n_agents, self.action_size)
low=-1.0, high=1.0, size=(n_agents, self.continuous_action_size)
return action
elif self.is_action_discrete():
else:
branch_size = self.discrete_action_branches
action = np.column_stack(
[

size=(n_agents),
dtype=np.int32,
)
for i in range(self.action_size)
for i in range(self.discrete_action_size)
return action
return action
class BehaviorSpec(NamedTuple):
observation_shapes: List[Tuple]
action_spec: ActionSpec
class BehaviorMapping(Mapping):

"""
@abstractmethod
def set_actions(self, behavior_name: BehaviorName, action: np.ndarray) -> None:
def set_actions(
self, behavior_name: BehaviorName, action: Union[np.ndarray]
) -> None:
"""
Sets the action for all of the agents in the simulation for the next
step. The Actions must be in the same order as the order received in

@abstractmethod
def set_action_for_agent(
self, behavior_name: BehaviorName, agent_id: AgentId, action: np.ndarray
self, behavior_name: BehaviorName, agent_id: AgentId, action: Union[np.ndarray]
) -> None:
"""
Sets the action for one of the agents in the simulation for the next

27
ml-agents-envs/mlagents_envs/rpc_utils.py


from mlagents_envs.base_env import (
BehaviorSpec,
ActionType,
ActionSpec,
DecisionSteps,
TerminalSteps,
)

from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
import numpy as np
import io
from typing import cast, List, Tuple, Union, Collection, Optional, Iterable
from typing import cast, List, Tuple, Collection, Optional, Iterable
from PIL import Image

:return: BehaviorSpec object.
"""
observation_shape = [tuple(obs.shape) for obs in agent_info.observations]
action_type = (
ActionType.DISCRETE
if brain_param_proto.vector_action_space_type == 0
else ActionType.CONTINUOUS
)
if action_type == ActionType.CONTINUOUS:
action_shape: Union[
int, Tuple[int, ...]
] = brain_param_proto.vector_action_size[0]
if brain_param_proto.vector_action_space_type == 1:
action_spec = ActionSpec(brain_param_proto.vector_action_size[0], ())
action_shape = tuple(brain_param_proto.vector_action_size)
return BehaviorSpec(observation_shape, action_type, action_shape)
action_spec = ActionSpec(0, tuple(brain_param_proto.vector_action_size))
return BehaviorSpec(observation_shape, action_spec)
class OffsetBytesIO:

[agent_info.id for agent_info in terminal_agent_info_list], dtype=np.int32
)
action_mask = None
if behavior_spec.is_action_discrete():
if behavior_spec.action_spec.is_action_discrete():
a_size = np.sum(behavior_spec.discrete_action_branches)
a_size = np.sum(behavior_spec.action_spec.discrete_action_branches)
mask_matrix = np.ones((n_agents, a_size), dtype=np.bool)
for agent_index, agent_info in enumerate(decision_agent_info_list):
if agent_info.action_mask is not None:

for k in range(a_size)
]
action_mask = (1 - mask_matrix).astype(np.bool)
indices = _generate_split_indices(behavior_spec.discrete_action_branches)
indices = _generate_split_indices(
behavior_spec.action_spec.discrete_action_branches
)
action_mask = np.split(action_mask, indices, axis=1)
return (
DecisionSteps(

15
ml-agents/mlagents/trainers/policy/policy.py


condition_sigma_on_obs: bool = True,
):
self.behavior_spec = behavior_spec
self.action_spec = behavior_spec.action_spec
# For mixed action spaces
self.continuous_act_size = self.action_spec.continuous_action_size
self.discrete_act_size = self.action_spec.discrete_action_size
self.discrete_act_branches = self.action_spec.discrete_action_branches
list(behavior_spec.discrete_action_branches)
if behavior_spec.is_action_discrete()
else [behavior_spec.action_size]
list(self.action_spec.discrete_action_branches)
if self.action_spec.is_action_discrete()
else [self.action_spec.action_size]
)
self.vec_obs_size = sum(
shape[0] for shape in behavior_spec.observation_shapes if len(shape) == 1

)
self.use_continuous_act = behavior_spec.is_action_continuous()
self.num_branches = self.behavior_spec.action_size
self.use_continuous_act = self.action_spec.is_action_continuous()
self.num_branches = self.action_spec.action_size
self.previous_action_dict: Dict[str, np.array] = {}
self.memory_dict: Dict[str, np.ndarray] = {}
self.normalize = trainer_settings.network_settings.normalize

32
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py


CuriosityRewardProvider,
create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents_envs.base_env import BehaviorSpec, ActionSpec
from mlagents.trainers.settings import CuriositySettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
create_agent_buffer,

SEED = [42]
ACTIONSPEC_CONTINUOUS = ActionSpec(5, ())
ACTIONSPEC_TWODISCRETE = ActionSpec(0, (2, 3))
ACTIONSPEC_DISCRETE = ActionSpec(0, (2,))
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,), (64, 66, 1)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,), (64, 66, 1)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
],
)
def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None:

@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
"behavior_spec", [BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5)]
"behavior_spec", [BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS)]
)
def test_continuous_action_prediction(behavior_spec: BehaviorSpec, seed: int) -> None:
np.random.seed(seed)

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,), (64, 66, 3)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
],
)
def test_next_state_prediction(behavior_spec: BehaviorSpec, seed: int) -> None:

17
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py


ExtrinsicRewardProvider,
create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents_envs.base_env import BehaviorSpec, ActionSpec
ACTIONSPEC_CONTINUOUS = ActionSpec(5, ())
ACTIONSPEC_TWODISCRETE = ActionSpec(0, (2, 3))
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
],
)
def test_reward(behavior_spec: BehaviorSpec, reward: float) -> None:

26
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py


GAILRewardProvider,
create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents_envs.base_env import BehaviorSpec, ActionSpec
from mlagents.trainers.settings import GAILSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
create_agent_buffer,

)
SEED = [42]
ACTIONSPEC_CONTINUOUS = ActionSpec(2, ())
ACTIONSPEC_FOURDISCRETE = ActionSpec(0, (2, 3, 3, 3))
ACTIONSPEC_DISCRETE = ActionSpec(0, (20,))
@pytest.mark.parametrize(
"behavior_spec", [BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2)]
)
@pytest.mark.parametrize("behavior_spec", [BehaviorSpec([(8,)], ACTIONSPEC_CONTINUOUS)])
def test_construction(behavior_spec: BehaviorSpec) -> None:
gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH)
gail_rp = GAILRewardProvider(behavior_spec, gail_settings)

@pytest.mark.parametrize(
"behavior_spec", [BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2)]
)
@pytest.mark.parametrize("behavior_spec", [BehaviorSpec([(8,)], ACTIONSPEC_CONTINUOUS)])
def test_factory(behavior_spec: BehaviorSpec) -> None:
gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH)
gail_rp = create_reward_provider(

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(8,), (24, 26, 1)], ActionType.CONTINUOUS, 2),
BehaviorSpec([(50,)], ActionType.DISCRETE, (2, 3, 3, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (20,)),
BehaviorSpec([(8,), (24, 26, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(50,)], ACTIONSPEC_FOURDISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
],
)
@pytest.mark.parametrize("use_actions", [False, True])

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3, 3, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (20,)),
BehaviorSpec([(8,), (24, 26, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(50,)], ACTIONSPEC_FOURDISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
],
)
@pytest.mark.parametrize("use_actions", [False, True])

24
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_rnd.py


RNDRewardProvider,
create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents_envs.base_env import BehaviorSpec, ActionSpec
from mlagents.trainers.settings import RNDSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
create_agent_buffer,

ACTIONSPEC_CONTINUOUS = ActionSpec(5, ())
ACTIONSPEC_TWODISCRETE = ActionSpec(0, (2, 3))
ACTIONSPEC_DISCRETE = ActionSpec(0, (2,))
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,), (64, 66, 1)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,), (64, 66, 1)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
],
)
def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None:

2
ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py


next_observations = [
np.random.normal(size=shape) for shape in behavior_spec.observation_shapes
]
action = behavior_spec.create_random_action(1)[0, :]
action = behavior_spec.action_spec.create_random_action(1)[0, :]
for _ in range(number):
curr_split_obs = SplitObservations.from_observations(curr_observations)
next_split_obs = SplitObservations.from_observations(next_observations)

正在加载...
取消
保存