
Add ActionSpec (#4586)

Co-authored-by: Ervin T <ervin@unity3d.com>
/MLA-1734-demo-provider
GitHub · 4 years ago
Current commit
cb8e4d25
49 files changed, with 440 insertions and 422 deletions
  1. gym-unity/gym_unity/envs/__init__.py (21 changes)
  2. gym-unity/gym_unity/tests/test_gym.py (8 changes)
  3. ml-agents-envs/mlagents_envs/base_env.py (154 changes)
  4. ml-agents-envs/mlagents_envs/environment.py (35 changes)
  5. ml-agents-envs/mlagents_envs/rpc_utils.py (27 changes)
  6. ml-agents-envs/mlagents_envs/tests/test_envs.py (15 changes)
  7. ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py (30 changes)
  8. ml-agents-envs/mlagents_envs/tests/test_steps.py (60 changes)
  9. ml-agents/mlagents/trainers/demo_loader.py (13 changes)
  10. ml-agents/mlagents/trainers/policy/policy.py (19 changes)
  11. ml-agents/mlagents/trainers/policy/tf_policy.py (6 changes)
  12. ml-agents/mlagents/trainers/policy/torch_policy.py (3 changes)
  13. ml-agents/mlagents/trainers/ppo/trainer.py (2 changes)
  14. ml-agents/mlagents/trainers/sac/optimizer_torch.py (18 changes)
  15. ml-agents/mlagents/trainers/tests/mock_brain.py (60 changes)
  16. ml-agents/mlagents/trainers/tests/simple_test_envs.py (15 changes)
  17. ml-agents/mlagents/trainers/tests/tensorflow/test_ghost.py (2 changes)
  18. ml-agents/mlagents/trainers/tests/tensorflow/test_models.py (5 changes)
  19. ml-agents/mlagents/trainers/tests/tensorflow/test_nn_policy.py (8 changes)
  20. ml-agents/mlagents/trainers/tests/tensorflow/test_ppo.py (12 changes)
  21. ml-agents/mlagents/trainers/tests/tensorflow/test_sac.py (9 changes)
  22. ml-agents/mlagents/trainers/tests/tensorflow/test_saver.py (4 changes)
  23. ml-agents/mlagents/trainers/tests/tensorflow/test_simple_rl.py (8 changes)
  24. ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py (25 changes)
  25. ml-agents/mlagents/trainers/tests/test_agent_processor.py (19 changes)
  26. ml-agents/mlagents/trainers/tests/test_rl_trainer.py (6 changes)
  27. ml-agents/mlagents/trainers/tests/test_trajectory.py (4 changes)
  28. ml-agents/mlagents/trainers/tests/torch/test_ghost.py (2 changes)
  29. ml-agents/mlagents/trainers/tests/torch/test_networks.py (41 changes)
  30. ml-agents/mlagents/trainers/tests/torch/test_policy.py (11 changes)
  31. ml-agents/mlagents/trainers/tests/torch/test_ppo.py (8 changes)
  32. ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py (32 changes)
  33. ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py (18 changes)
  34. ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py (27 changes)
  35. ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_rnd.py (25 changes)
  36. ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py (2 changes)
  37. ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (6 changes)
  38. ml-agents/mlagents/trainers/tf/components/bc/model.py (2 changes)
  39. ml-agents/mlagents/trainers/tf/components/bc/module.py (4 changes)
  40. ml-agents/mlagents/trainers/tf/components/reward_signals/curiosity/model.py (2 changes)
  41. ml-agents/mlagents/trainers/tf/components/reward_signals/gail/model.py (2 changes)
  42. ml-agents/mlagents/trainers/torch/components/bc/module.py (2 changes)
  43. ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py (16 changes)
  44. ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py (3 changes)
  45. ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py (1 change)
  46. ml-agents/mlagents/trainers/torch/model_serialization.py (4 changes)
  47. ml-agents/mlagents/trainers/torch/networks.py (42 changes)
  48. ml-agents/mlagents/trainers/torch/utils.py (16 changes)
  49. ml-agents/tests/yamato/scripts/run_llapi.py (8 changes)

gym-unity/gym_unity/envs/__init__.py (21 changes)


self._previous_decision_step = decision_steps
# Set action spaces
if self.group_spec.is_action_discrete():
branches = self.group_spec.discrete_action_branches
if self.group_spec.action_size == 1:
if self.group_spec.action_spec.is_discrete():
self.action_size = self.group_spec.action_spec.discrete_size
branches = self.group_spec.action_spec.discrete_branches
if self.group_spec.action_spec.discrete_size == 1:
self._action_space = spaces.Discrete(branches[0])
else:
if flatten_branched:

self._action_space = spaces.MultiDiscrete(branches)
else:
elif self.group_spec.action_spec.is_continuous():
high = np.array([1] * self.group_spec.action_shape)
self.action_size = self.group_spec.action_spec.continuous_size
high = np.array([1] * self.group_spec.action_spec.continuous_size)
else:
raise UnityGymException(
"The gym wrapper does not provide explicit support for both discrete "
"and continuous actions."
)
# Set observations space
list_spaces: List[gym.Space] = []

# Translate action into list
action = self._flattener.lookup_action(action)
spec = self.group_spec
action = np.array(action).reshape((1, spec.action_size))
action = np.array(action).reshape((1, self.action_size))
self._env.set_actions(self.name, action)
self._env.step()
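For reference, a hedged end-to-end sketch of driving the wrapper after this change (connecting with file_name=None and the flatten_branched flag are assumptions about the surrounding API, not part of this diff):

from mlagents_envs.environment import UnityEnvironment
from gym_unity.envs import UnityToGymWrapper

unity_env = UnityEnvironment(file_name=None)        # connect to a running Unity instance
env = UnityToGymWrapper(unity_env, flatten_branched=True)
print(env.action_space)                             # Discrete/MultiDiscrete/Box, chosen from behavior_spec.action_spec
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
env.close()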

gym-unity/gym_unity/tests/test_gym.py (8 changes)


from gym_unity.envs import UnityToGymWrapper
from mlagents_envs.base_env import (
BehaviorSpec,
ActionType,
ActionSpec,
DecisionSteps,
TerminalSteps,
BehaviorMapping,

Creates a mock BrainParameters object with parameters.
"""
# Avoid using mutable object as default param
act_type = ActionType.DISCRETE
act_type = ActionType.CONTINUOUS
action_spec = ActionSpec.create_continuous(vector_action_space_size)
action_spec = ActionSpec.create_discrete(vector_action_space_size)
return BehaviorSpec(obs_shapes, act_type, vector_action_space_size)
return BehaviorSpec(obs_shapes, action_spec)
def create_mock_vector_steps(specs, num_agents=1, number_visual_observations=0):

ml-agents-envs/mlagents_envs/base_env.py (154 changes)


NamedTuple,
Tuple,
Optional,
Union,
Dict,
Iterator,
Any,

from enum import Enum
from mlagents_envs.exception import UnityActionException
AgentId = int
BehaviorName = str

)
class ActionType(Enum):
DISCRETE = 0
CONTINUOUS = 1
class BehaviorSpec(NamedTuple):
class ActionSpec(NamedTuple):
A NamedTuple containing information about the observation and action
spaces for a group of Agents under the same behavior.
- observation_shapes is a List of Tuples of int : Each Tuple corresponds
to an observation's dimensions. The shape tuples have the same ordering as
the ordering of the DecisionSteps and TerminalSteps.
- action_type is the type of data of the action. it can be discrete or
continuous. If discrete, the action tensors are expected to be int32. If
continuous, the actions are expected to be float32.
- action_shape is:
- An int in continuous action space corresponding to the number of
floats that constitute the action.
- A Tuple of int in discrete action space where each int corresponds to
the number of discrete actions available to the agent.
A NamedTuple containing utility functions and information about the action spaces
for a group of Agents under the same behavior.
- num_continuous_actions is an int corresponding to the number of floats which
constitute the action.
- discrete_branch_sizes is a Tuple of int where each int corresponds to
the number of discrete actions available to the agent on an independent action branch.
observation_shapes: List[Tuple]
action_type: ActionType
action_shape: Union[int, Tuple[int, ...]]
continuous_size: int
discrete_branches: Tuple[int, ...]
def __eq__(self, other):
return (
self.continuous_size == other.continuous_size
and self.discrete_branches == other.discrete_branches
)
def __str__(self):
return f"Continuous: {self.continuous_size}, Discrete: {self.discrete_branches}"
def is_action_discrete(self) -> bool:
# For backwards compatibility
def is_discrete(self) -> bool:
return self.action_type == ActionType.DISCRETE
return self.discrete_size > 0 and self.continuous_size == 0
def is_action_continuous(self) -> bool:
# For backwards compatibility
def is_continuous(self) -> bool:
return self.action_type == ActionType.CONTINUOUS
return self.discrete_size == 0 and self.continuous_size > 0
def action_size(self) -> int:
def discrete_size(self) -> int:
Returns the dimension of the action.
- In the continuous case, will return the number of continuous actions.
- In the (multi-)discrete case, will return the number of action branches.
Returns an int corresponding to the number of discrete branches.
if self.action_type == ActionType.DISCRETE:
return len(self.action_shape) # type: ignore
else:
return self.action_shape # type: ignore
@property
def discrete_action_branches(self) -> Optional[Tuple[int, ...]]:
"""
Returns a Tuple of int corresponding to the number of possible actions
for each branch (only for discrete actions). Will return None
for continuous actions.
"""
if self.action_type == ActionType.DISCRETE:
return self.action_shape # type: ignore
else:
return None
return len(self.discrete_branches)
def create_empty_action(self, n_agents: int) -> np.ndarray:
def empty_action(self, n_agents: int) -> np.ndarray:
if self.action_type == ActionType.DISCRETE:
return np.zeros((n_agents, self.action_size), dtype=np.int32)
else:
return np.zeros((n_agents, self.action_size), dtype=np.float32)
if self.is_continuous():
return np.zeros((n_agents, self.continuous_size), dtype=np.float32)
return np.zeros((n_agents, self.discrete_size), dtype=np.int32)
def create_random_action(self, n_agents: int) -> np.ndarray:
def random_action(self, n_agents: int) -> np.ndarray:
:param generator: The random number generator used for creating random action
if self.is_action_continuous():
if self.is_continuous():
low=-1.0, high=1.0, size=(n_agents, self.action_size)
low=-1.0, high=1.0, size=(n_agents, self.continuous_size)
return action
elif self.is_action_discrete():
branch_size = self.discrete_action_branches
else:
branch_size = self.discrete_branches
action = np.column_stack(
[
np.random.randint(

dtype=np.int32,
)
for i in range(self.action_size)
for i in range(self.discrete_size)
return action
return action
def _validate_action(
self, actions: np.ndarray, n_agents: int, name: str
) -> np.ndarray:
"""
Validates that the action has the correct dimensions
for the given number of agents and casts it to the expected dtype.
"""
if self.continuous_size > 0:
_size = self.continuous_size
else:
_size = self.discrete_size
_expected_shape = (n_agents, _size)
if actions.shape != _expected_shape:
raise UnityActionException(
f"The behavior {name} needs an input of dimension "
f"{_expected_shape} for (<number of agents>, <action size>) but "
f"received input of dimension {actions.shape}"
)
_expected_type = np.float32 if self.is_continuous() else np.int32
if actions.dtype != _expected_type:
actions = actions.astype(_expected_type)
return actions
@staticmethod
def create_continuous(continuous_size: int) -> "ActionSpec":
"""
Creates an ActionSpec that is homogeneously continuous
"""
return ActionSpec(continuous_size, ())
@staticmethod
def create_discrete(discrete_branches: Tuple[int]) -> "ActionSpec":
"""
Creates an ActionSpec that is homogeneously discrete
"""
return ActionSpec(0, discrete_branches)
class BehaviorSpec(NamedTuple):
"""
A NamedTuple containing information about the observation and action
spaces for a group of Agents under the same behavior.
- observation_shapes is a List of Tuples of int : Each Tuple corresponds
to an observation's dimensions. The shape tuples have the same ordering as
the ordering of the DecisionSteps and TerminalSteps.
- action_spec is an ActionSpec NamedTuple
"""
observation_shapes: List[Tuple]
action_spec: ActionSpec
class BehaviorMapping(Mapping):
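Putting the new API together, a minimal sketch of the intended usage (all names are taken from the diff above; runtime behavior is not verified here):

from mlagents_envs.base_env import ActionSpec, BehaviorSpec

continuous_spec = ActionSpec.create_continuous(3)      # 3 floats per action
discrete_spec = ActionSpec.create_discrete((7, 3))     # two branches with 7 and 3 choices

assert continuous_spec.is_continuous() and continuous_spec.continuous_size == 3
assert discrete_spec.is_discrete() and discrete_spec.discrete_size == 2

# BehaviorSpec now pairs observation shapes with an ActionSpec instead of ActionType/action_shape.
behavior_spec = BehaviorSpec([(10,), (84, 84, 3)], discrete_spec)
zeros = discrete_spec.empty_action(5)                  # shape (5, 2), dtype int32
sample = continuous_spec.random_action(5)              # shape (5, 3), floats in [-1, 1]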

ml-agents-envs/mlagents_envs/environment.py (35 changes)


n_agents = len(self._env_state[group_name][0])
self._env_actions[group_name] = self._env_specs[
group_name
].create_empty_action(n_agents)
].action_spec.empty_action(n_agents)
step_input = self._generate_step_input(self._env_actions)
with hierarchical_timer("communicator.exchange"):
outputs = self._communicator.exchange(step_input)

self._assert_behavior_exists(behavior_name)
if behavior_name not in self._env_state:
return
spec = self._env_specs[behavior_name]
expected_type = np.float32 if spec.is_action_continuous() else np.int32
expected_shape = (len(self._env_state[behavior_name][0]), spec.action_size)
if action.shape != expected_shape:
raise UnityActionException(
f"The behavior {behavior_name} needs an input of dimension "
f"{expected_shape} for (<number of agents>, <action size>) but "
f"received input of dimension {action.shape}"
)
if action.dtype != expected_type:
action = action.astype(expected_type)
action_spec = self._env_specs[behavior_name].action_spec
num_agents = len(self._env_state[behavior_name][0])
action = action_spec._validate_action(action, num_agents, behavior_name)
self._env_actions[behavior_name] = action
def set_action_for_agent(

if behavior_name not in self._env_state:
return
spec = self._env_specs[behavior_name]
expected_shape = (spec.action_size,)
if action.shape != expected_shape:
raise UnityActionException(
f"The Agent {agent_id} with BehaviorName {behavior_name} needs "
f"an input of dimension {expected_shape} but received input of "
f"dimension {action.shape}"
)
expected_type = np.float32 if spec.is_action_continuous() else np.int32
if action.dtype != expected_type:
action = action.astype(expected_type)
action_spec = self._env_specs[behavior_name].action_spec
num_agents = len(self._env_state[behavior_name][0])
action = action_spec._validate_action(action, num_agents, behavior_name)
self._env_actions[behavior_name] = spec.create_empty_action(
len(self._env_state[behavior_name][0])
)
self._env_actions[behavior_name] = action_spec.empty_action(num_agents)
try:
index = np.where(self._env_state[behavior_name][0].agent_id == agent_id)[0][
0
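In short, set_actions and set_action_for_agent now delegate shape and dtype checks to ActionSpec._validate_action and fall back to action_spec.empty_action. A hedged sketch of the resulting LLAPI flow (assuming the public behavior_specs mapping on UnityEnvironment):

spec = env.behavior_specs["MyBehavior"]                # a BehaviorSpec
decision_steps, terminal_steps = env.get_steps("MyBehavior")
n_agents = len(decision_steps)

action = spec.action_spec.empty_action(n_agents)       # float32 for continuous, int32 for discrete
env.set_actions("MyBehavior", action)                  # validated via ActionSpec._validate_action
env.step()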

ml-agents-envs/mlagents_envs/rpc_utils.py (27 changes)


from mlagents_envs.base_env import (
BehaviorSpec,
ActionType,
ActionSpec,
DecisionSteps,
TerminalSteps,
)

from mlagents_envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
import numpy as np
import io
from typing import cast, List, Tuple, Union, Collection, Optional, Iterable
from typing import cast, List, Tuple, Collection, Optional, Iterable
from PIL import Image

:return: BehaviorSpec object.
"""
observation_shape = [tuple(obs.shape) for obs in agent_info.observations]
action_type = (
ActionType.DISCRETE
if brain_param_proto.vector_action_space_type == 0
else ActionType.CONTINUOUS
)
if action_type == ActionType.CONTINUOUS:
action_shape: Union[
int, Tuple[int, ...]
] = brain_param_proto.vector_action_size[0]
if brain_param_proto.vector_action_space_type == 1:
action_spec = ActionSpec(brain_param_proto.vector_action_size[0], ())
action_shape = tuple(brain_param_proto.vector_action_size)
return BehaviorSpec(observation_shape, action_type, action_shape)
action_spec = ActionSpec(0, tuple(brain_param_proto.vector_action_size))
return BehaviorSpec(observation_shape, action_spec)
class OffsetBytesIO:

[agent_info.id for agent_info in terminal_agent_info_list], dtype=np.int32
)
action_mask = None
if behavior_spec.is_action_discrete():
if behavior_spec.action_spec.discrete_size > 0:
a_size = np.sum(behavior_spec.discrete_action_branches)
a_size = np.sum(behavior_spec.action_spec.discrete_branches)
mask_matrix = np.ones((n_agents, a_size), dtype=np.bool)
for agent_index, agent_info in enumerate(decision_agent_info_list):
if agent_info.action_mask is not None:

for k in range(a_size)
]
action_mask = (1 - mask_matrix).astype(np.bool)
indices = _generate_split_indices(behavior_spec.discrete_action_branches)
indices = _generate_split_indices(
behavior_spec.action_spec.discrete_branches
)
action_mask = np.split(action_mask, indices, axis=1)
return (
DecisionSteps(

ml-agents-envs/mlagents_envs/tests/test_envs.py (15 changes)


from unittest import mock
import pytest
import numpy as np
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.base_env import DecisionSteps, TerminalSteps
from mlagents_envs.exception import UnityEnvironmentException, UnityActionException

env.step()
decision_steps, terminal_steps = env.get_steps("RealFakeBrain")
n_agents = len(decision_steps)
env.set_actions(
"RealFakeBrain", np.zeros((n_agents, spec.action_size), dtype=np.float32)
)
env.set_actions("RealFakeBrain", spec.action_spec.empty_action(n_agents))
env.set_actions(
"RealFakeBrain",
np.zeros((n_agents - 1, spec.action_size), dtype=np.float32),
)
env.set_actions("RealFakeBrain", spec.action_spec.empty_action(n_agents - 1))
env.set_actions(
"RealFakeBrain", -1 * np.ones((n_agents, spec.action_size), dtype=np.float32)
)
env.set_actions("RealFakeBrain", spec.action_spec.empty_action(n_agents) - 1)
env.step()
env.close()

ml-agents-envs/mlagents_envs/tests/test_rpc_utils.py (30 changes)


from mlagents_envs.communicator_objects.agent_action_pb2 import AgentActionProto
from mlagents_envs.base_env import (
BehaviorSpec,
ActionType,
ActionSpec,
DecisionSteps,
TerminalSteps,
)

def test_batched_step_result_from_proto():
n_agents = 10
shapes = [(3,), (4,)]
spec = BehaviorSpec(shapes, ActionType.CONTINUOUS, 3)
spec = BehaviorSpec(shapes, ActionSpec.create_continuous(3))
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, spec)
for agent_id in range(n_agents):

def test_action_masking_discrete():
n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionType.DISCRETE, (7, 3))
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_discrete((7, 3)))
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)
masks = decision_steps.action_mask

def test_action_masking_discrete_1():
n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionType.DISCRETE, (10,))
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_discrete((10,)))
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)
masks = decision_steps.action_mask

def test_action_masking_discrete_2():
n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionType.DISCRETE, (2, 2, 6))
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_discrete((2, 2, 6)))
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)
masks = decision_steps.action_mask

def test_action_masking_continuous():
n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionType.CONTINUOUS, 10)
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_continuous(10))
ap_list = generate_list_agent_proto(n_agents, shapes)
decision_steps, terminal_steps = steps_from_proto(ap_list, behavior_spec)
masks = decision_steps.action_mask

bp.vector_action_size.extend([5, 4])
bp.vector_action_space_type = 0
behavior_spec = behavior_spec_from_proto(bp, agent_proto)
assert behavior_spec.is_action_discrete()
assert not behavior_spec.is_action_continuous()
assert behavior_spec.action_spec.is_discrete()
assert not behavior_spec.action_spec.is_continuous()
assert behavior_spec.discrete_action_branches == (5, 4)
assert behavior_spec.action_size == 2
assert behavior_spec.action_spec.discrete_branches == (5, 4)
assert behavior_spec.action_spec.discrete_size == 2
assert not behavior_spec.is_action_discrete()
assert behavior_spec.is_action_continuous()
assert behavior_spec.action_size == 6
assert not behavior_spec.action_spec.is_discrete()
assert behavior_spec.action_spec.is_continuous()
assert behavior_spec.action_spec.continuous_size == 6
behavior_spec = BehaviorSpec(shapes, ActionType.CONTINUOUS, 3)
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_continuous(3))
ap_list = generate_list_agent_proto(n_agents, shapes, infinite_rewards=True)
with pytest.raises(RuntimeError):
steps_from_proto(ap_list, behavior_spec)

n_agents = 10
shapes = [(3,), (4,)]
behavior_spec = BehaviorSpec(shapes, ActionType.CONTINUOUS, 3)
behavior_spec = BehaviorSpec(shapes, ActionSpec.create_continuous(3))
ap_list = generate_list_agent_proto(n_agents, shapes, nan_observations=True)
with pytest.raises(RuntimeError):
steps_from_proto(ap_list, behavior_spec)

ml-agents-envs/mlagents_envs/tests/test_steps.py (60 changes)


from mlagents_envs.base_env import (
DecisionSteps,
TerminalSteps,
ActionType,
ActionSpec,
BehaviorSpec,
)

def test_empty_decision_steps():
specs = BehaviorSpec(
observation_shapes=[(3, 2), (5,)],
action_type=ActionType.CONTINUOUS,
action_shape=3,
observation_shapes=[(3, 2), (5,)], action_spec=ActionSpec.create_continuous(3)
)
ds = DecisionSteps.empty(specs)
assert len(ds.obs) == 2

def test_empty_terminal_steps():
specs = BehaviorSpec(
observation_shapes=[(3, 2), (5,)],
action_type=ActionType.CONTINUOUS,
action_shape=3,
observation_shapes=[(3, 2), (5,)], action_spec=ActionSpec.create_continuous(3)
)
ts = TerminalSteps.empty(specs)
assert len(ts.obs) == 2

def test_specs():
specs = BehaviorSpec(
observation_shapes=[(3, 2), (5,)],
action_type=ActionType.CONTINUOUS,
action_shape=3,
)
assert specs.discrete_action_branches is None
assert specs.action_size == 3
assert specs.create_empty_action(5).shape == (5, 3)
assert specs.create_empty_action(5).dtype == np.float32
specs = ActionSpec.create_continuous(3)
assert specs.discrete_branches == ()
assert specs.discrete_size == 0
assert specs.continuous_size == 3
assert specs.empty_action(5).shape == (5, 3)
assert specs.empty_action(5).dtype == np.float32
specs = BehaviorSpec(
observation_shapes=[(3, 2), (5,)],
action_type=ActionType.DISCRETE,
action_shape=(3,),
)
assert specs.discrete_action_branches == (3,)
assert specs.action_size == 1
assert specs.create_empty_action(5).shape == (5, 1)
assert specs.create_empty_action(5).dtype == np.int32
specs = ActionSpec.create_discrete((3,))
assert specs.discrete_branches == (3,)
assert specs.discrete_size == 1
assert specs.continuous_size == 0
assert specs.empty_action(5).shape == (5, 1)
assert specs.empty_action(5).dtype == np.int32
specs = BehaviorSpec(
observation_shapes=[(5,)],
action_type=ActionType.CONTINUOUS,
action_shape=action_len,
)
zero_action = specs.create_empty_action(4)
specs = ActionSpec.create_continuous(action_len)
zero_action = specs.empty_action(4)
random_action = specs.create_random_action(4)
random_action = specs.random_action(4)
assert random_action.dtype == np.float32
assert random_action.shape == (4, action_len)
assert np.min(random_action) >= -1

action_shape = (10, 20, 30)
specs = BehaviorSpec(
observation_shapes=[(5,)],
action_type=ActionType.DISCRETE,
action_shape=action_shape,
)
zero_action = specs.create_empty_action(4)
specs = ActionSpec.create_discrete(action_shape)
zero_action = specs.empty_action(4)
random_action = specs.create_random_action(4)
random_action = specs.random_action(4)
assert random_action.dtype == np.int32
assert random_action.shape == (4, len(action_shape))
assert np.min(random_action) >= 0

ml-agents/mlagents/trainers/demo_loader.py (13 changes)


demo_buffer = make_demo_buffer(info_action_pair, behavior_spec, sequence_length)
if expected_behavior_spec:
# check action dimensions in demonstration match
if behavior_spec.action_shape != expected_behavior_spec.action_shape:
if behavior_spec.action_spec != expected_behavior_spec.action_spec:
"The action dimensions {} in demonstration do not match the policy's {}.".format(
behavior_spec.action_shape, expected_behavior_spec.action_shape
)
)
# check the action types in demonstration match
if behavior_spec.action_type != expected_behavior_spec.action_type:
raise RuntimeError(
"The action type of {} in demonstration do not match the policy's {}.".format(
behavior_spec.action_type, expected_behavior_spec.action_type
"The action spaces {} in demonstration do not match the policy's {}.".format(
behavior_spec.action_spec, expected_behavior_spec.action_spec
)
)
# check observations match

ml-agents/mlagents/trainers/policy/policy.py (19 changes)


self.trainer_settings = trainer_settings
self.network_settings: NetworkSettings = trainer_settings.network_settings
self.seed = seed
if (
self.behavior_spec.action_spec.continuous_size > 0
and self.behavior_spec.action_spec.discrete_size > 0
):
raise UnityPolicyException("Trainers do not support mixed action spaces.")
list(behavior_spec.discrete_action_branches)
if behavior_spec.is_action_discrete()
else [behavior_spec.action_size]
list(self.behavior_spec.action_spec.discrete_branches)
if self.behavior_spec.action_spec.is_discrete()
else [self.behavior_spec.action_spec.continuous_size]
)
self.vec_obs_size = sum(
shape[0] for shape in behavior_spec.observation_shapes if len(shape) == 1

)
self.use_continuous_act = behavior_spec.is_action_continuous()
self.num_branches = self.behavior_spec.action_size
self.use_continuous_act = self.behavior_spec.action_spec.is_continuous()
# This line will be removed in the ActionBuffer change
self.num_branches = (
self.behavior_spec.action_spec.continuous_size
+ self.behavior_spec.action_spec.discrete_size
)
self.previous_action_dict: Dict[str, np.array] = {}
self.memory_dict: Dict[str, np.ndarray] = {}
self.normalize = trainer_settings.network_settings.normalize

ml-agents/mlagents/trainers/policy/tf_policy.py (6 changes)


mask = np.ones(
(
len(batched_step_result),
sum(self.behavior_spec.discrete_action_branches),
sum(self.behavior_spec.action_spec.discrete_branches),
),
dtype=np.float32,
)

self.mask = tf.cast(self.mask_input, tf.int32)
tf.Variable(
int(self.behavior_spec.is_action_continuous()),
int(self.behavior_spec.action_spec.is_continuous()),
name="is_continuous_control",
trainable=False,
dtype=tf.int32,

tf.Variable(
self.m_size, name="memory_size", trainable=False, dtype=tf.int32
)
if self.behavior_spec.is_action_continuous():
if self.behavior_spec.action_spec.is_continuous():
tf.Variable(
self.act_size[0],
name="action_output_shape",

ml-agents/mlagents/trainers/policy/torch_policy.py (3 changes)


self.actor_critic = ac_class(
observation_shapes=self.behavior_spec.observation_shapes,
network_settings=trainer_settings.network_settings,
act_type=behavior_spec.action_type,
act_size=self.act_size,
action_spec=behavior_spec.action_spec,
stream_names=reward_signal_names,
conditional_sigma=self.condition_sigma_on_obs,
tanh_squash=tanh_squash,

ml-agents/mlagents/trainers/ppo/trainer.py (2 changes)


behavior_spec,
self.trainer_settings,
condition_sigma_on_obs=False, # Faster training for PPO
separate_critic=behavior_spec.is_action_continuous(),
separate_critic=behavior_spec.action_spec.is_continuous(),
)
return policy

ml-agents/mlagents/trainers/sac/optimizer_torch.py (18 changes)


from mlagents.torch_utils import torch, nn, default_device
from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import ActionType
from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.settings import NetworkSettings

stream_names: List[str],
observation_shapes: List[Tuple[int, ...]],
network_settings: NetworkSettings,
act_type: ActionType,
act_size: List[int],
action_spec: ActionSpec,
if act_type == ActionType.CONTINUOUS:
self.action_spec = action_spec
if self.action_spec.is_continuous():
self.act_size = self.action_spec.continuous_size
num_action_ins = sum(act_size)
num_action_ins = self.act_size
num_value_outs = sum(act_size)
self.act_size = self.action_spec.discrete_branches
num_value_outs = sum(self.act_size)
num_action_ins = 0
self.q1_network = ValueNetwork(
stream_names,

self.stream_names,
self.policy.behavior_spec.observation_shapes,
policy_network_settings,
self.policy.behavior_spec.action_type,
self.act_size,
self.policy.behavior_spec.action_spec,
)
self.target_network = ValueNetwork(

ml-agents/mlagents/trainers/tests/mock_brain.py (60 changes)


from typing import List, Tuple, Union
from collections.abc import Iterable
from typing import List, Tuple
import numpy as np
from mlagents.trainers.buffer import AgentBuffer

TerminalSteps,
BehaviorSpec,
ActionType,
ActionSpec,
)

action_shape: Union[int, Tuple[int]] = None,
discrete: bool = False,
action_spec: ActionSpec,
done: bool = False,
) -> Tuple[DecisionSteps, TerminalSteps]:
"""

:bool discrete: Whether or not action space is discrete
:bool done: Whether all the agents in the batch are done
"""
if action_shape is None:
action_shape = 2
if discrete and isinstance(action_shape, Iterable):
if action_spec.is_discrete():
for action_size in action_shape # type: ignore
for action_size in action_spec.discrete_branches # type: ignore
behavior_spec = BehaviorSpec(
observation_shapes,
ActionType.DISCRETE if discrete else ActionType.CONTINUOUS,
action_shape,
)
behavior_spec = BehaviorSpec(observation_shapes, action_spec)
if done:
return (
DecisionSteps.empty(behavior_spec),

return create_mock_steps(
num_agents=num_agents,
observation_shapes=behavior_spec.observation_shapes,
action_shape=behavior_spec.action_shape,
discrete=behavior_spec.is_action_discrete(),
action_spec=behavior_spec.action_spec,
)

action_spec: ActionSpec,
action_space: Union[int, Tuple[int]] = 2,
is_discrete: bool = True,
) -> Trajectory:
"""
Makes a fake trajectory of length length. If max_step_complete,

action_size = action_spec.discrete_size + action_spec.continuous_size
action_probs = np.ones(
int(np.sum(action_spec.discrete_branches) + action_spec.continuous_size),
dtype=np.float32,
)
for _i in range(length - 1):
obs = []
for _shape in observation_shapes:

if is_discrete:
action_size = len(action_space) # type: ignore
action_probs = np.ones(np.sum(action_space), dtype=np.float32)
else:
action_size = int(action_space) # type: ignore
action_probs = np.ones((action_size), dtype=np.float32)
[[False for _ in range(branch)] for branch in action_space] # type: ignore
if is_discrete
[
[False for _ in range(branch)]
for branch in action_spec.discrete_branches
] # type: ignore
if action_spec.is_discrete()
else None
)
prev_action = np.ones(action_size, dtype=np.float32)

memory_size: int = 10,
exclude_key_list: List[str] = None,
) -> AgentBuffer:
action_space = behavior_spec.action_shape
is_discrete = behavior_spec.is_action_discrete()
action_space=action_space,
action_spec=behavior_spec.action_spec,
is_discrete=is_discrete,
)
buffer = trajectory.to_agentbuffer()
# If a key_list was given, remove those keys

def setup_test_behavior_specs(
use_discrete=True, use_visual=False, vector_action_space=2, vector_obs_space=8
):
if use_discrete:
action_spec = ActionSpec.create_discrete(tuple(vector_action_space))
else:
action_spec = ActionSpec.create_continuous(vector_action_space)
[(84, 84, 3)] * int(use_visual) + [(vector_obs_space,)],
ActionType.DISCRETE if use_discrete else ActionType.CONTINUOUS,
tuple(vector_action_space) if use_discrete else vector_action_space,
[(84, 84, 3)] * int(use_visual) + [(vector_obs_space,)], action_spec
)
return behavior_spec
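For context, a hedged example of calling the updated test helper (per the snippet above, vector_action_space is a list of branch sizes when discrete and a plain int when continuous):

discrete_spec = setup_test_behavior_specs(
    use_discrete=True, use_visual=False, vector_action_space=[3, 3], vector_obs_space=8
)
continuous_spec = setup_test_behavior_specs(
    use_discrete=False, use_visual=False, vector_action_space=2, vector_obs_space=8
)
assert discrete_spec.action_spec.discrete_branches == (3, 3)
assert continuous_spec.action_spec.continuous_size == 2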

ml-agents/mlagents/trainers/tests/simple_test_envs.py (15 changes)


import numpy as np
from mlagents_envs.base_env import (
ActionSpec,
ActionType,
BehaviorMapping,
)
from mlagents_envs.tests.test_rpc_utils import proto_from_steps_and_action

self.num_vector = num_vector
self.vis_obs_size = vis_obs_size
self.vec_obs_size = vec_obs_size
action_type = ActionType.DISCRETE if use_discrete else ActionType.CONTINUOUS
self.behavior_spec = BehaviorSpec(
self._make_obs_spec(),
action_type,
tuple(2 for _ in range(action_size)) if use_discrete else action_size,
)
if use_discrete:
action_spec = ActionSpec.create_discrete(
tuple(2 for _ in range(action_size))
)
else:
action_spec = ActionSpec.create_continuous(action_size)
self.behavior_spec = BehaviorSpec(self._make_obs_spec(), action_spec)
self.action_size = action_size
self.names = brain_names
self.positions: Dict[str, List[float]] = {}

ml-agents/mlagents/trainers/tests/tensorflow/test_ghost.py (2 changes)


length=time_horizon,
max_step_complete=True,
observation_shapes=[(1,)],
action_space=[2],
action_spec=mock_specs.action_spec,
)
trajectory_queue0.put(trajectory)
trainer.advance()

ml-agents/mlagents/trainers/tests/tensorflow/test_models.py (5 changes)


from mlagents.trainers.tf.models import ModelUtils
from mlagents.tf_utils import tf
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents_envs.base_env import BehaviorSpec, ActionSpec
ActionType.DISCRETE,
(1,),
ActionSpec.create_discrete((1,)),
)
return behavior_spec

ml-agents/mlagents/trainers/tests/tensorflow/test_nn_policy.py (8 changes)


length=time_horizon,
max_step_complete=True,
observation_shapes=[(1,)],
action_space=[2],
action_spec=behavior_spec.action_spec,
)
for i in range(time_horizon):
trajectory.steps[i].obs[0] = np.array([large_obs1[i]], dtype=np.float32)

length=time_horizon,
max_step_complete=True,
observation_shapes=[(1,)],
action_space=[2],
action_spec=behavior_spec.action_spec,
)
for i in range(time_horizon):
trajectory.steps[i].obs[0] = np.array([large_obs2[i]], dtype=np.float32)

length=time_horizon,
max_step_complete=True,
observation_shapes=[(1,)],
action_space=[2],
action_spec=behavior_spec.action_spec,
)
# Change half of the obs to 0
for i in range(3):

length=time_horizon,
max_step_complete=True,
observation_shapes=[(1,)],
action_space=[2],
action_spec=behavior_spec.action_spec,
)
trajectory_buffer = trajectory.to_agentbuffer()
policy.update_normalization(trajectory_buffer["vector_obs"])

ml-agents/mlagents/trainers/tests/tensorflow/test_ppo.py (12 changes)


ppo_dummy_config,
)
from mlagents_envs.base_env import ActionSpec
@pytest.fixture
def dummy_config():

DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 64
NUM_AGENTS = 12
CONTINUOUS_ACTION_SPEC = ActionSpec.create_continuous(VECTOR_ACTION_SPACE)
DISCRETE_ACTION_SPEC = ActionSpec.create_discrete(tuple(DISCRETE_ACTION_SPACE))
def _create_ppo_optimizer_ops_mock(dummy_config, use_rnn, use_discrete, use_visual):

length=time_horizon,
observation_shapes=optimizer.policy.behavior_spec.observation_shapes,
max_step_complete=True,
action_space=DISCRETE_ACTION_SPACE if discrete else VECTOR_ACTION_SPACE,
is_discrete=discrete,
action_spec=DISCRETE_ACTION_SPEC if discrete else CONTINUOUS_ACTION_SPEC,
)
run_out, final_value_out = optimizer.get_trajectory_value_estimates(
trajectory.to_agentbuffer(), trajectory.next_obs, done=False

length=time_horizon,
observation_shapes=behavior_spec.observation_shapes,
max_step_complete=True,
action_space=[2],
action_spec=behavior_spec.action_spec,
)
trajectory_queue.put(trajectory)
trainer.advance()

length=time_horizon + 1,
max_step_complete=False,
observation_shapes=behavior_spec.observation_shapes,
action_space=[2],
action_spec=behavior_spec.action_spec,
)
trajectory_queue.put(trajectory)
trainer.advance()

ml-agents/mlagents/trainers/tests/tensorflow/test_sac.py (9 changes)


length=15,
observation_shapes=specs.observation_shapes,
max_step_complete=True,
action_space=2,
is_discrete=False,
action_spec=specs.action_spec,
)
trajectory_queue.put(trajectory)
trainer.advance()

length=6,
observation_shapes=specs.observation_shapes,
max_step_complete=False,
action_space=2,
is_discrete=False,
action_spec=specs.action_spec,
)
trajectory_queue.put(trajectory)
trainer.advance()

trajectory = make_fake_trajectory(
length=5,
observation_shapes=specs.observation_shapes,
action_spec=specs.action_spec,
action_space=2,
is_discrete=False,
)
trajectory_queue.put(trajectory)
trainer.advance()

ml-agents/mlagents/trainers/tests/tensorflow/test_saver.py (4 changes)


length=time_horizon,
max_step_complete=True,
observation_shapes=[(1,)],
action_space=[2],
action_spec=behavior_spec.action_spec,
)
# Change half of the obs to 0
for i in range(3):

length=time_horizon,
max_step_complete=True,
observation_shapes=[(1,)],
action_space=[2],
action_spec=behavior_spec.action_spec,
)
trajectory_buffer = trajectory.to_agentbuffer()
policy1.update_normalization(trajectory_buffer["vector_obs"])

ml-agents/mlagents/trainers/tests/tensorflow/test_simple_rl.py (8 changes)


PPO_TF_CONFIG,
hyperparameters=new_hyperparams,
network_settings=new_networksettings,
max_steps=500,
max_steps=300,
summary_freq=100,
framework=FrameworkType.TENSORFLOW,
)

@pytest.mark.parametrize("use_discrete", [True, False])
def test_recurrent_sac(use_discrete):
step_size = 0.5 if use_discrete else 0.2
step_size = 0.2 if use_discrete else 0.5
memory=NetworkSettings.MemorySettings(memory_size=16, sequence_length=16),
memory=NetworkSettings.MemorySettings(memory_size=16),
)
new_hyperparams = attr.evolve(
SAC_TF_CONFIG.hyperparameters,

SAC_TF_CONFIG,
hyperparameters=new_hyperparams,
network_settings=new_networksettings,
max_steps=5000,
max_steps=4000,
framework=FrameworkType.TENSORFLOW,
)
_check_environment_trains(env, {BRAIN_NAME: config})

ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py (25 changes)


from unittest.mock import MagicMock
from mlagents.trainers.settings import TrainerSettings
import numpy as np
from mlagents_envs.base_env import ActionSpec
def basic_mock_brain():
mock_brain = MagicMock()
mock_brain.vector_action_space_type = "continuous"
mock_brain.vector_observation_space_size = 1
mock_brain.vector_action_space_size = [1]
mock_brain.brain_name = "MockBrain"
return mock_brain
def basic_behavior_spec():
dummy_actionspec = ActionSpec.create_continuous(1)
dummy_groupspec = BehaviorSpec([(1,)], dummy_actionspec)
return dummy_groupspec
class FakePolicy(TFPolicy):

def test_take_action_returns_empty_with_no_agents():
test_seed = 3
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output")
# Doesn't really matter what this is
dummy_groupspec = BehaviorSpec([(1,)], "continuous", 1)
no_agent_step = DecisionSteps.empty(dummy_groupspec)
behavior_spec = basic_behavior_spec()
policy = FakePolicy(test_seed, behavior_spec, TrainerSettings(), "output")
no_agent_step = DecisionSteps.empty(behavior_spec)
result = policy.get_action(no_agent_step)
assert result == ActionInfo.empty()

policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output")
behavior_spec = basic_behavior_spec()
policy = FakePolicy(test_seed, behavior_spec, TrainerSettings(), "output")
policy.evaluate = MagicMock(return_value={})
policy.save_memories = MagicMock()
step_with_agents = DecisionSteps(

def test_take_action_returns_action_info_when_available():
test_seed = 3
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output")
behavior_spec = basic_behavior_spec()
policy = FakePolicy(test_seed, behavior_spec, TrainerSettings(), "output")
policy_eval_out = {
"action": np.array([1.0], dtype=np.float32),
"memory_out": np.array([[2.5]], dtype=np.float32),

ml-agents/mlagents/trainers/tests/test_agent_processor.py (19 changes)


from mlagents.trainers.behavior_id_utils import get_global_agent_id
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
from mlagents_envs.base_env import ActionSpec
def create_mock_policy():
mock_policy = mock.Mock()

mock_decision_steps, mock_terminal_steps = mb.create_mock_steps(
num_agents=2,
observation_shapes=[(8,)] + num_vis_obs * [(84, 84, 3)],
action_shape=2,
action_spec=ActionSpec.create_continuous(2),
)
fake_action_info = ActionInfo(
action=[0.1, 0.1],

mock_decision_steps, mock_terminal_steps = mb.create_mock_steps(
num_agents=0,
observation_shapes=[(8,)] + num_vis_obs * [(84, 84, 3)],
action_shape=2,
action_spec=ActionSpec.create_continuous(2),
)
processor.add_experiences(
mock_decision_steps, mock_terminal_steps, 0, ActionInfo([], [], {}, [])

"log_probs": [0.1],
}
mock_decision_step, mock_terminal_step = mb.create_mock_steps(
num_agents=1, observation_shapes=[(8,)], action_shape=2
num_agents=1,
observation_shapes=[(8,)],
action_spec=ActionSpec.create_continuous(2),
num_agents=1, observation_shapes=[(8,)], action_shape=2, done=True
num_agents=1,
observation_shapes=[(8,)],
action_spec=ActionSpec.create_continuous(2),
done=True,
)
fake_action_info = ActionInfo(
action=[0.1],

"log_probs": [0.1],
}
mock_decision_step, mock_terminal_step = mb.create_mock_steps(
num_agents=1, observation_shapes=[(8,)], action_shape=2
num_agents=1,
observation_shapes=[(8,)],
action_spec=ActionSpec.create_continuous(2),
)
fake_action_info = ActionInfo(
action=[0.1],

ml-agents/mlagents/trainers/tests/test_rl_trainer.py (6 changes)


from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.settings import TrainerSettings, FrameworkType
from mlagents_envs.base_env import ActionSpec
# Add concrete implementations of abstract methods
class FakeTrainer(RLTrainer):

length=time_horizon,
observation_shapes=[(1,)],
max_step_complete=True,
action_space=[2],
action_spec=ActionSpec.create_discrete((2,)),
)
trajectory_queue.put(trajectory)

length=time_horizon,
observation_shapes=[(1,)],
max_step_complete=True,
action_space=[2],
action_spec=ActionSpec.create_discrete((2,)),
)
# Check that we can turn off the trainer and that the buffer is cleared
num_trajectories = 5

ml-agents/mlagents/trainers/tests/test_trajectory.py (4 changes)


from mlagents.trainers.trajectory import SplitObservations
from mlagents.trainers.tests.mock_brain import make_fake_trajectory
from mlagents_envs.base_env import ActionSpec
VEC_OBS_SIZE = 6
ACTION_SIZE = 4

trajectory = make_fake_trajectory(
length=length,
observation_shapes=[(VEC_OBS_SIZE,), (84, 84, 3)],
action_space=[ACTION_SIZE],
action_spec=ActionSpec.create_continuous(ACTION_SIZE),
)
agentbuffer = trajectory.to_agentbuffer()
seen_keys = set()

ml-agents/mlagents/trainers/tests/torch/test_ghost.py (2 changes)


length=time_horizon,
max_step_complete=True,
observation_shapes=[(1,)],
action_space=[2],
action_spec=mock_specs.action_spec,
)
trajectory_queue0.put(trajectory)
trainer.advance()

ml-agents/mlagents/trainers/tests/torch/test_networks.py (41 changes)


SeparateActorCritic,
)
from mlagents.trainers.settings import NetworkSettings
from mlagents_envs.base_env import ActionType
from mlagents_envs.base_env import ActionSpec
def test_networkbody_vector():

assert _out[0] == pytest.approx(1.0, abs=0.1)
@pytest.mark.parametrize("action_type", [ActionType.DISCRETE, ActionType.CONTINUOUS])
def test_simple_actor(action_type):
@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_actor(use_discrete):
masks = None if action_type == ActionType.CONTINUOUS else torch.ones((1, 1))
actor = SimpleActor(obs_shapes, network_settings, action_type, act_size)
if use_discrete:
masks = torch.ones((1, 1))
action_spec = ActionSpec.create_discrete(tuple(act_size))
else:
masks = None
action_spec = ActionSpec.create_continuous(act_size[0])
actor = SimpleActor(obs_shapes, network_settings, action_spec)
if action_type == ActionType.CONTINUOUS:
assert isinstance(dist, GaussianDistInstance)
if use_discrete:
assert isinstance(dist, CategoricalDistInstance)
assert isinstance(dist, CategoricalDistInstance)
assert isinstance(dist, GaussianDistInstance)
if action_type == ActionType.CONTINUOUS:
assert act.shape == (1, act_size[0])
else:
if use_discrete:
else:
assert act.shape == (1, act_size[0])
# Test forward
actions, ver_num, mem_size, is_cont, act_size_vec = actor.forward(

# This is different from above for ONNX export
if action_type == ActionType.CONTINUOUS:
if use_discrete:
assert act.shape == tuple(act_size)
else:
else:
assert act.shape == tuple(act_size)
assert is_cont == int(action_type == ActionType.CONTINUOUS)
assert is_cont == int(not use_discrete)
assert act_size_vec == torch.tensor(act_size)

obs_shapes = [(obs_size,)]
act_size = [2]
stream_names = [f"stream_name{n}" for n in range(4)]
actor = ac_type(
obs_shapes, network_settings, ActionType.CONTINUOUS, act_size, stream_names
)
action_spec = ActionSpec.create_continuous(act_size[0])
actor = ac_type(obs_shapes, network_settings, action_spec, stream_names)
if lstm:
sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size))
memories = torch.ones(

ml-agents/mlagents/trainers/tests/torch/test_policy.py (11 changes)


memories=memories,
seq_len=policy.sequence_length,
)
assert log_probs.shape == (64, policy.behavior_spec.action_size)
if discrete:
_size = policy.behavior_spec.action_spec.discrete_size
else:
_size = policy.behavior_spec.action_spec.continuous_size
assert log_probs.shape == (64, _size)
assert entropy.shape == (64,)
for val in values.values():
assert val.shape == (64,)

if discrete:
assert log_probs.shape == (
64,
sum(policy.behavior_spec.discrete_action_branches),
sum(policy.behavior_spec.action_spec.discrete_branches),
assert log_probs.shape == (64, policy.behavior_spec.action_shape)
assert log_probs.shape == (64, policy.behavior_spec.action_spec.continuous_size)
assert entropies.shape == (64,)
if rnn:

ml-agents/mlagents/trainers/tests/torch/test_ppo.py (8 changes)


gail_dummy_config,
)
from mlagents_envs.base_env import ActionSpec
@pytest.fixture
def dummy_config():

DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 64
NUM_AGENTS = 12
CONTINUOUS_ACTION_SPEC = ActionSpec.create_continuous(VECTOR_ACTION_SPACE)
DISCRETE_ACTION_SPEC = ActionSpec.create_discrete(tuple(DISCRETE_ACTION_SPACE))
def create_test_ppo_optimizer(dummy_config, use_rnn, use_discrete, use_visual):

trajectory = make_fake_trajectory(
length=time_horizon,
observation_shapes=optimizer.policy.behavior_spec.observation_shapes,
action_spec=DISCRETE_ACTION_SPEC if discrete else CONTINUOUS_ACTION_SPEC,
action_space=DISCRETE_ACTION_SPACE if discrete else VECTOR_ACTION_SPACE,
is_discrete=discrete,
)
run_out, final_value_out = optimizer.get_trajectory_value_estimates(
trajectory.to_agentbuffer(), trajectory.next_obs, done=False

ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_curiosity.py (32 changes)


CuriosityRewardProvider,
create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents_envs.base_env import BehaviorSpec, ActionSpec
from mlagents.trainers.settings import CuriositySettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
create_agent_buffer,

SEED = [42]
ACTIONSPEC_CONTINUOUS = ActionSpec.create_continuous(5)
ACTIONSPEC_TWODISCRETE = ActionSpec.create_discrete((2, 3))
ACTIONSPEC_DISCRETE = ActionSpec.create_discrete((2,))
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,), (64, 66, 1)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,), (64, 66, 1)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
],
)
def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None:

@pytest.mark.parametrize("seed", SEED)
@pytest.mark.parametrize(
"behavior_spec", [BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5)]
"behavior_spec", [BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS)]
)
def test_continuous_action_prediction(behavior_spec: BehaviorSpec, seed: int) -> None:
np.random.seed(seed)

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,), (64, 66, 3)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
],
)
def test_next_state_prediction(behavior_spec: BehaviorSpec, seed: int) -> None:

ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_extrinsic.py (18 changes)


ExtrinsicRewardProvider,
create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents_envs.base_env import BehaviorSpec, ActionSpec
from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
create_agent_buffer,

ACTIONSPEC_CONTINUOUS = ActionSpec.create_continuous(5)
ACTIONSPEC_TWODISCRETE = ActionSpec.create_discrete((2, 3))
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
],
)
def test_reward(behavior_spec: BehaviorSpec, reward: float) -> None:

ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_gail.py (27 changes)


GAILRewardProvider,
create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents_envs.base_env import BehaviorSpec, ActionSpec
from mlagents.trainers.settings import GAILSettings, RewardSignalType
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
create_agent_buffer,

)
CONTINUOUS_PATH = (
os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)

)
SEED = [42]
ACTIONSPEC_CONTINUOUS = ActionSpec.create_continuous(2)
ACTIONSPEC_FOURDISCRETE = ActionSpec.create_discrete((2, 3, 3, 3))
ACTIONSPEC_DISCRETE = ActionSpec.create_discrete((20,))
@pytest.mark.parametrize(
"behavior_spec", [BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2)]
)
@pytest.mark.parametrize("behavior_spec", [BehaviorSpec([(8,)], ACTIONSPEC_CONTINUOUS)])
def test_construction(behavior_spec: BehaviorSpec) -> None:
gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH)
gail_rp = GAILRewardProvider(behavior_spec, gail_settings)

@pytest.mark.parametrize(
"behavior_spec", [BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2)]
)
@pytest.mark.parametrize("behavior_spec", [BehaviorSpec([(8,)], ACTIONSPEC_CONTINUOUS)])
def test_factory(behavior_spec: BehaviorSpec) -> None:
gail_settings = GAILSettings(demo_path=CONTINUOUS_PATH)
gail_rp = create_reward_provider(

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(8,), (24, 26, 1)], ActionType.CONTINUOUS, 2),
BehaviorSpec([(50,)], ActionType.DISCRETE, (2, 3, 3, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (20,)),
BehaviorSpec([(8,), (24, 26, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(50,)], ACTIONSPEC_FOURDISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
],
)
@pytest.mark.parametrize("use_actions", [False, True])

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(8,)], ActionType.CONTINUOUS, 2),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3, 3, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (20,)),
BehaviorSpec([(8,), (24, 26, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(50,)], ACTIONSPEC_FOURDISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
],
)
@pytest.mark.parametrize("use_actions", [False, True])

ml-agents/mlagents/trainers/tests/torch/test_reward_providers/test_rnd.py (25 changes)


RNDRewardProvider,
create_reward_provider,
)
from mlagents_envs.base_env import BehaviorSpec, ActionType
from mlagents_envs.base_env import BehaviorSpec, ActionSpec
ACTIONSPEC_CONTINUOUS = ActionSpec.create_continuous(5)
ACTIONSPEC_TWODISCRETE = ActionSpec.create_discrete((2, 3))
ACTIONSPEC_DISCRETE = ActionSpec.create_discrete((2,))
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
],
)
def test_construction(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,), (64, 66, 1)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
BehaviorSpec([(10,)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,), (64, 66, 3), (84, 86, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,), (64, 66, 1)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
],
)
def test_factory(behavior_spec: BehaviorSpec) -> None:

@pytest.mark.parametrize(
"behavior_spec",
[
BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ActionType.CONTINUOUS, 5),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2, 3)),
BehaviorSpec([(10,)], ActionType.DISCRETE, (2,)),
BehaviorSpec([(10,), (64, 66, 3), (24, 26, 1)], ACTIONSPEC_CONTINUOUS),
BehaviorSpec([(10,)], ACTIONSPEC_TWODISCRETE),
BehaviorSpec([(10,)], ACTIONSPEC_DISCRETE),
],
)
def test_reward_decreases(behavior_spec: BehaviorSpec, seed: int) -> None:

ml-agents/mlagents/trainers/tests/torch/test_reward_providers/utils.py (2 changes)


next_observations = [
np.random.normal(size=shape) for shape in behavior_spec.observation_shapes
]
action = behavior_spec.create_random_action(1)[0, :]
action = behavior_spec.action_spec.random_action(1)[0, :]
for _ in range(number):
curr_split_obs = SplitObservations.from_observations(curr_observations)
next_split_obs = SplitObservations.from_observations(next_observations)

ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (6 changes)


PPO_TORCH_CONFIG,
hyperparameters=new_hyperparams,
network_settings=new_networksettings,
max_steps=700,
max_steps=900,
summary_freq=100,
)
# The number of steps is pretty small for these encoders

)
new_hyperparams = attr.evolve(
SAC_TORCH_CONFIG.hyperparameters,
batch_size=128,
batch_size=256,
learning_rate=1e-3,
buffer_init_steps=1000,
steps_per_update=2,

hyperparameters=new_hyperparams,
network_settings=new_networksettings,
max_steps=5000,
max_steps=2000,
)
check_environment_trains(env, {BRAIN_NAME: config})

ml-agents/mlagents/trainers/tf/components/bc/model.py (2 changes)


self.done_expert = tf.placeholder(shape=[None, 1], dtype=tf.float32)
self.done_policy = tf.placeholder(shape=[None, 1], dtype=tf.float32)
if self.policy.behavior_spec.is_action_continuous():
if self.policy.behavior_spec.action_spec.is_continuous():
action_length = self.policy.act_size[0]
self.action_in_expert = tf.placeholder(
shape=[None, action_length], dtype=tf.float32

ml-agents/mlagents/trainers/tf/components/bc/module.py (4 changes)


self.policy.sequence_length_ph: self.policy.sequence_length,
}
feed_dict[self.model.action_in_expert] = mini_batch_demo["actions"]
if self.policy.behavior_spec.is_action_discrete():
if self.policy.behavior_spec.action_spec.is_discrete():
sum(self.policy.behavior_spec.discrete_action_branches),
sum(self.policy.behavior_spec.action_spec.discrete_branches),
),
dtype=np.float32,
)

ml-agents/mlagents/trainers/tf/components/reward_signals/curiosity/model.py (2 changes)


"""
combined_input = tf.concat([encoded_state, encoded_next_state], axis=1)
hidden = tf.layers.dense(combined_input, 256, activation=ModelUtils.swish)
if self.policy.behavior_spec.is_action_continuous():
if self.policy.behavior_spec.action_spec.is_continuous():
pred_action = tf.layers.dense(
hidden, self.policy.act_size[0], activation=None
)

ml-agents/mlagents/trainers/tf/components/reward_signals/gail/model.py (2 changes)


self.done_expert = tf.expand_dims(self.done_expert_holder, -1)
self.done_policy = tf.expand_dims(self.done_policy_holder, -1)
if self.policy.behavior_spec.is_action_continuous():
if self.policy.behavior_spec.action_spec.is_continuous():
action_length = self.policy.act_size[0]
self.action_in_expert = tf.placeholder(
shape=[None, action_length], dtype=tf.float32

ml-agents/mlagents/trainers/torch/components/bc/module.py (2 changes)


np.ones(
(
self.n_sequences * self.policy.sequence_length,
sum(self.policy.behavior_spec.discrete_action_branches),
sum(self.policy.behavior_spec.action_spec.discrete_branches),
),
dtype=np.float32,
)

ml-agents/mlagents/trainers/torch/components/reward_providers/curiosity_reward_provider.py (16 changes)


def __init__(self, specs: BehaviorSpec, settings: CuriositySettings) -> None:
super().__init__()
self._policy_specs = specs
self._action_spec = specs.action_spec
state_encoder_settings = NetworkSettings(
normalize=False,
hidden_units=settings.encoding_size,

specs.observation_shapes, state_encoder_settings
)
self._action_flattener = ModelUtils.ActionFlattener(specs)
self._action_flattener = ModelUtils.ActionFlattener(self._action_spec)
self.inverse_model_action_prediction = torch.nn.Sequential(
LinearEncoder(2 * settings.encoding_size, 1, 256),

(self.get_current_state(mini_batch), self.get_next_state(mini_batch)), dim=1
)
hidden = self.inverse_model_action_prediction(inverse_model_input)
if self._policy_specs.is_action_continuous():
if self._action_spec.is_continuous():
hidden, self._policy_specs.discrete_action_branches
hidden, self._action_spec.discrete_branches
)
branches = [torch.softmax(b, dim=1) for b in branches]
return torch.cat(branches, dim=1)

Uses the current state embedding and the action of the mini_batch to predict
the next state embedding.
"""
if self._policy_specs.is_action_continuous():
if self._action_spec.is_continuous():
self._policy_specs.discrete_action_branches,
self._action_spec.discrete_branches,
),
dim=1,
)

action prediction (given the current and next state).
"""
predicted_action = self.predict_action(mini_batch)
if self._policy_specs.is_action_continuous():
if self._action_spec.is_continuous():
sq_difference = (
ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.float)
- predicted_action

true_action = torch.cat(
ModelUtils.actions_to_onehot(
ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long),
self._policy_specs.discrete_action_branches,
self._action_spec.discrete_branches,
),
dim=1,
)

ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py (3 changes)


def __init__(self, specs: BehaviorSpec, settings: GAILSettings) -> None:
super().__init__()
self._policy_specs = specs
self._use_vail = settings.use_vail
self._settings = settings

vis_encode_type=EncoderType.SIMPLE,
memory=None,
)
self._action_flattener = ModelUtils.ActionFlattener(specs)
self._action_flattener = ModelUtils.ActionFlattener(specs.action_spec)
unencoded_size = (
self._action_flattener.flattened_size + 1 if settings.use_actions else 0
) # +1 is for dones

ml-agents/mlagents/trainers/torch/components/reward_providers/rnd_reward_provider.py (1 change)


def __init__(self, specs: BehaviorSpec, settings: RNDSettings) -> None:
super().__init__()
self._policy_specs = specs
state_encoder_settings = NetworkSettings(
normalize=True,
hidden_units=settings.encoding_size,

ml-agents/mlagents/trainers/torch/model_serialization.py (4 changes)


for shape in self.policy.behavior_spec.observation_shapes
if len(shape) == 3
]
dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)])
dummy_masks = torch.ones(
batch_dim + [sum(self.policy.behavior_spec.action_spec.discrete_branches)]
)
dummy_memories = torch.zeros(
batch_dim + seq_len_dim + [self.policy.export_memory_size]
)

ml-agents/mlagents/trainers/torch/networks.py (42 changes)


from mlagents.torch_utils import torch, nn
from mlagents_envs.base_env import ActionType
from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.torch.distributions import (
GaussianDistribution,
MultiCategoricalDistribution,

self,
observation_shapes: List[Tuple[int, ...]],
network_settings: NetworkSettings,
act_type: ActionType,
act_size: List[int],
action_spec: ActionSpec,
self.act_type = act_type
self.act_size = act_size
self.action_spec = action_spec
torch.Tensor([int(act_type == ActionType.CONTINUOUS)])
torch.Tensor([int(self.action_spec.is_continuous())])
torch.Tensor([sum(act_size)]), requires_grad=False
torch.Tensor(
[
self.action_spec.continuous_size
+ sum(self.action_spec.discrete_branches)
]
),
requires_grad=False,
)
self.network_body = NetworkBody(observation_shapes, network_settings)
if network_settings.memory is not None:

if self.act_type == ActionType.CONTINUOUS:
if self.action_spec.is_continuous():
act_size[0],
self.action_spec.continuous_size,
self.encoding_size, act_size
self.encoding_size, self.action_spec.discrete_branches
)
@property

encoding, memories = self.network_body(
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
)
if self.act_type == ActionType.CONTINUOUS:
if self.action_spec.is_continuous():
dists = self.distribution(encoding)
else:
dists = self.distribution(encoding, masks)

Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
"""
dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
if self.act_type == ActionType.CONTINUOUS:
if self.action_spec.is_continuous():
action_list = self.sample_action(dists)
action_out = torch.stack(action_list, dim=-1)
else:

self,
observation_shapes: List[Tuple[int, ...]],
network_settings: NetworkSettings,
act_type: ActionType,
act_size: List[int],
action_spec: ActionSpec,
stream_names: List[str],
conditional_sigma: bool = False,
tanh_squash: bool = False,

network_settings,
act_type,
act_size,
action_spec,
conditional_sigma,
tanh_squash,
)

encoding, memories = self.network_body(
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
)
if self.act_type == ActionType.CONTINUOUS:
if self.action_spec.is_continuous():
dists = self.distribution(encoding)
else:
dists = self.distribution(encoding, masks=masks)

self,
observation_shapes: List[Tuple[int, ...]],
network_settings: NetworkSettings,
act_type: ActionType,
act_size: List[int],
action_spec: ActionSpec,
stream_names: List[str],
conditional_sigma: bool = False,
tanh_squash: bool = False,

super().__init__(
observation_shapes,
network_settings,
act_type,
act_size,
action_spec,
conditional_sigma,
tanh_squash,
)
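A hedged sketch of constructing the updated actor with an ActionSpec (keyword names follow the signature fragments above; using NetworkSettings defaults is an assumption):

from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.networks import SimpleActor
from mlagents_envs.base_env import ActionSpec

# ActionSpec replaces the old act_type/act_size pair in the actor constructors.
actor = SimpleActor(
    observation_shapes=[(8,)],
    network_settings=NetworkSettings(),
    action_spec=ActionSpec.create_continuous(2),
)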

ml-agents/mlagents/trainers/torch/utils.py (16 changes)


)
from mlagents.trainers.settings import EncoderType, ScheduleType
from mlagents.trainers.exception import UnityTrainerException
from mlagents_envs.base_env import BehaviorSpec
from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance

}
class ActionFlattener:
def __init__(self, behavior_spec: BehaviorSpec):
self._specs = behavior_spec
def __init__(self, action_spec: ActionSpec):
self._specs = action_spec
if self._specs.is_action_continuous():
return self._specs.action_size
if self._specs.is_continuous():
return self._specs.continuous_size
return sum(self._specs.discrete_action_branches)
return sum(self._specs.discrete_branches)
if self._specs.is_action_continuous():
if self._specs.is_continuous():
self._specs.discrete_action_branches,
self._specs.discrete_branches,
),
dim=1,
)
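A hedged reconstruction of the flattened-size rule the updated ActionFlattener uses (pieced together from the fragments above, not verbatim source):

from mlagents_envs.base_env import ActionSpec

def flattened_size(action_spec: ActionSpec) -> int:
    # Continuous actions are already a flat float vector; discrete actions flatten to
    # the concatenation of one one-hot vector per branch.
    if action_spec.is_continuous():
        return action_spec.continuous_size
    return sum(action_spec.discrete_branches)

assert flattened_size(ActionSpec.create_continuous(4)) == 4
assert flattened_size(ActionSpec.create_discrete((2, 3, 3))) == 8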

ml-agents/tests/yamato/scripts/run_llapi.py (8 changes)


episode_rewards = 0
tracked_agent = -1
while not done:
if group_spec.is_action_continuous():
if group_spec.action_spec.is_continuous():
len(decision_steps), group_spec.action_size
len(decision_steps), group_spec.action_spec.continuous_size
elif group_spec.is_action_discrete():
branch_size = group_spec.discrete_action_branches
elif group_spec.action_spec.is_discrete():
branch_size = group_spec.action_spec.discrete_branches
action = np.column_stack(
[
np.random.randint(
