
Develop new ll api (#3022)

* initial commit for LL-API

* fixing ml-agents-envs tests

* Implementing action masks

* training is fixed for 3DBall

* Tests all fixed, gym is broken and missing documentation changes

* adding case where no vector obs

* Fixed Gym

* fixing tests of float64

* fixing float64

* reverting some of brain.py

* removing old proto apis

* comment type fixes

* added properties to AgentGroupSpec and edited the notebooks.

* clearing the notebook outputs

* Update gym-unity/gym_unity/tests/test_gym.py

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* Update gym-unity/gym_unity/tests/test_gym.py

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* Update ml-agents-envs/mlagents/envs/base_env.py

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* Update ml-agents-envs/mlagents/envs/base_env.py

Co-Authored-By: Chris Elion <chris.elion@unity3d.com>

* addressing first comments

* NaN checks for r...
/develop
GitHub · 5 years ago
Current commit: a6df9f43
22 files changed, 1298 insertions(+), 657 deletions(-)
  1. docs/Migrating.md (1)
  2. docs/Python-API.md (202)
  3. gym-unity/gym_unity/envs/__init__.py (40)
  4. gym-unity/gym_unity/tests/test_gym.py (87)
  5. ml-agents-envs/mlagents/envs/brain.py (46)
  6. ml-agents-envs/mlagents/envs/environment.py (327)
  7. ml-agents-envs/mlagents/envs/simple_env_manager.py (42)
  8. ml-agents-envs/mlagents/envs/subprocess_env_manager.py (44)
  9. ml-agents-envs/mlagents/envs/tests/test_envs.py (82)
  10. ml-agents-envs/mlagents/envs/tests/test_subprocess_env_manager.py (4)
  11. ml-agents/mlagents/trainers/learn.py (4)
  12. ml-agents/mlagents/trainers/tests/test_bc.py (23)
  13. ml-agents/mlagents/trainers/tests/test_ppo.py (36)
  14. ml-agents/mlagents/trainers/tests/test_simple_rl.py (112)
  15. notebooks/getting-started-gym.ipynb (4)
  16. notebooks/getting-started.ipynb (80)
  17. ml-agents-envs/mlagents/envs/base_env.py (301)
  18. ml-agents-envs/mlagents/envs/brain_conversion_utils.py (70)
  19. ml-agents-envs/mlagents/envs/rpc_utils.py (165)
  20. ml-agents-envs/mlagents/envs/tests/test_rpc_utils.py (187)
  21. ml-agents-envs/mlagents/envs/tests/test_brain.py (73)
  22. ml-agents-envs/mlagents/envs/base_unity_environment.py (25)

docs/Migrating.md (1)


## Migrating from master to develop
### Important changes
* The low level Python API has changed. You can look at the document [Low Level Python API documentation](Python-API.md) for more information. This should only affect you if you're writing a custom trainer; if you use `mlagents-learn` for training, this should be a transparent change.
* `CustomResetParameters` are now removed.
* `reset()` on the Low-Level Python API no longer takes a `train_mode` argument. To modify the performance/speed of the engine, you must use an `EngineConfigurationChannel`.
* `reset()` on the Low-Level Python API no longer takes a `config` argument. `UnityEnvironment` no longer has a `reset_parameters` field. To modify float properties in the environment, you must use a `FloatPropertiesChannel`. For more information, refer to the [Low Level Python API documentation](Python-API.md); a brief sketch is shown below.
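
A rough sketch of the new pattern (the `FloatPropertiesChannel` import path is assumed for this version of the package):

```python
from mlagents.envs.environment import UnityEnvironment
from mlagents.envs.side_channel.float_properties_channel import FloatPropertiesChannel

# Channels are passed at construction time; reset() no longer takes
# train_mode or config arguments.
float_props = FloatPropertiesChannel()
env = UnityEnvironment(file_name="3DBall", side_channels=[float_props])

float_props.set_property("parameter_1", 2.0)  # replaces the old reset config dict
env.reset()
```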

docs/Python-API.md (202)


# Unity ML-Agents Python Interface and Trainers
The `mlagents` Python package is part of the [ML-Agents
Toolkit](https://github.com/Unity-Technologies/ml-agents). `mlagents` provides a
Python API that allows direct interaction with the Unity game engine as well as
a collection of trainers and algorithms to train agents in Unity environments.
# Unity ML-Agents Python Low Level API
The `mlagents` Python package contains two components: a low level API which
allows you to interact directly with a Unity Environment (`mlagents.envs`) and

You can use the Python Low Level API to interact directly with your learning
environment, and use it to develop new learning algorithms.
The ML-Agents Toolkit provides a Python API for controlling the Agent simulation
The ML-Agents Toolkit Low Level API is a Python API for controlling the simulation
loop of an environment or game built with Unity. This API is used by the
training algorithms inside the ML-Agents Toolkit, but you can also write your own
Python programs using this API. Go [here](../notebooks/getting-started.ipynb)

- **UnityEnvironment** — the main interface between the Unity application and
your code. Use UnityEnvironment to start and control a simulation or training
session.
- **BrainInfo** — contains all the data from Agents in the simulation, such as
observations and rewards.
- **BrainParameters** — describes the data elements in a BrainInfo object. For
example, provides the array length of an observation in BrainInfo.
- **BatchedStepResult** — contains the data from Agents belonging to the same
"AgentGroup" in the simulation, such as observations and rewards.
- **AgentGroupSpec** — describes the shape of the data inside a BatchedStepResult.
For example, provides the dimensions of the observations of a group.
These classes are all defined in the [base_env](../ml-agents-envs/mlagents/envs/base_env.py)
script.
These classes are all defined in the `ml-agents/mlagents/envs` folder of
the ML-Agents SDK.
An Agent Group is a group of Agents identified by a string name that share the same
observations and action types. You can think of an Agent Group as a group of agents
that share the same policy or behavior. All Agents in a group have the same goal
and reward signals.
Agent must use a LearningBrain.
Your code is expected to return
actions for Agents with LearningBrains.
An Agent in the simulation must have `Behavior Parameters` set to communicate. You
must set the `Behavior Type` to `Default` and give it a `Behavior Name`.
__Note__: The `Behavior Name` corresponds to the Agent Group name on Python.
_Notice: Currently communication between Unity and Python takes place over an
open socket without authentication. As such, please make sure that the network

### Loading a Unity Environment
## Loading a Unity Environment
Python-side communication happens through `UnityEnvironment` which is located in
`ml-agents/mlagents/envs`. To load a Unity environment from a built binary

```python
from mlagents.envs.environment import UnityEnvironment
env = UnityEnvironment(file_name="3DBall", worker_id=0, seed=1)
env = UnityEnvironment(file_name="3DBall", base_port=5005, seed=1, side_channels=[])
```
- `file_name` is the name of the environment binary (located in the root

training process. In environments which do not involve physics calculations,
setting the seed enables reproducible experimentation by ensuring that the
environment and trainers utilize the same random seed.
- `side_channels` provides a way to exchange data with the Unity simulation that
is not related to the reinforcement learning loop. For example: configurations
or properties. More on them in the [Modifying the environment from Python](Python-API.md#modifying-the-environment-from-python) section.
If you want to directly interact with the Editor, you need to use
`file_name=None`, then press the :arrow_forward: button in the Editor when the

### Interacting with a Unity Environment
A BrainInfo object contains the following fields:
#### The BaseEnv interface
- **`visual_observations`** : A list of 4 dimensional numpy arrays. Matrix n of
the list corresponds to the n<sup>th</sup> observation of the Brain.
- **`vector_observations`** : A two dimensional numpy array of dimension `(batch
size, vector observation size)`.
- **`rewards`** : A list as long as the number of Agents using the Brain
containing the rewards they each obtained at the previous step.
- **`local_done`** : A list as long as the number of Agents using the Brain
containing `done` flags (whether or not the Agent is done).
- **`max_reached`** : A list as long as the number of Agents using the Brain
containing true if the Agents reached their max steps.
- **`agents`** : A list of the unique ids of the Agents using the Brain.
A `BaseEnv` has the following methods:
Once loaded, the UnityEnvironment object, referenced by a variable named `env`
in this example, can be used in the following way:
- **Reset : `env.reset()`** Sends a signal to reset the environment. Returns None.
- **Step : `env.step()`** Sends a signal to step the environment. Returns None.
Note that a "step" for Python does not correspond to either Unity `Update` nor
`FixedUpdate`. When `step()` or `reset()` is called, the Unity simulation will
move forward until an Agent in the simulation needs a input from Python to act.
- **Close : `env.close()`** Sends a shutdown signal to the environment and terminates
the communication.
- **Get Agent Group Names : `env.get_agent_groups()`** Returns a list of agent group ids.
Note that the number of groups can change over time if new agent groups are
created in the simulation.
- **Get Agent Group Spec : `env.get_agent_group_spec(agent_group: str)`** Returns
the `AgentGroupSpec` corresponding to the agent_group given as input. An
`AgentGroupSpec` contains information such as the observation shapes, the action
type (multi-discrete or continuous) and the action shape. Note that the `AgentGroupSpec`
for a specific group is fixed throughout the simulation.
- **Get Batched Step Result for Agent Group : `env.get_step_result(agent_group: str)`**
Returns a `BatchedStepResult` corresponding to the agent_group given as input.
A `BatchedStepResult` contains information about the state of the agents in a group
such as the observations, the rewards, the done flags and the agent identifiers. The
data is in `np.array` of which the first dimension is always the number of agents which
requested a decision in the simulation since the last call to `env.step()`. Note that the
number of agents is not guaranteed to remain constant during the simulation.
- **Set Actions for Agent Group :`env.set_actions(agent_group: str, action: np.array)`**
Sets the actions for a whole agent group. `action` is a 2D `np.array` of `dtype=np.int32`
in the discrete action case and `dtype=np.float32` in the continuous action case.
The first dimension of `action` is the number of agents that requested a decision
since the last call to `env.step()`. The second dimension is the number of branches in the
multi-discrete action case and the number of actions in the continuous action case.
- **Set Action for Agent : `env.set_action_for_agent(agent_group: str, agent_id: int, action: np.array)`**
Sets the action for a specific Agent in an agent group. `agent_group` is the name of the
group the Agent belongs to and `agent_id` is the integer identifier of the Agent. `action`
is a 1D array: of type `np.int32` with size equal to the number of branches in the
multi-discrete action case, or of type `np.float32` with size equal to the number of
actions in the continuous action case.
- **Print : `print(str(env))`**
Prints all parameters relevant to the loaded environment and the
Brains.
- **Reset : `env.reset()`**
Send a reset signal to the environment, and provides a dictionary mapping
Brain names to BrainInfo objects.
- **Step : `env.step(action)`**
Sends a step signal to the environment using the actions. For each Brain :
- `action` can be one dimensional arrays or two dimensional arrays if you have
multiple Agents per Brain.
__Note:__ If no action is provided for an agent group between two calls to `env.step()`, then
the default action will be all zeros (in either discrete or continuous action space). A minimal
interaction loop using these methods is sketched below.
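
The following is a rough sketch of that loop, assuming a single agent group with
continuous actions and a local `3DBall` build:

```python
import numpy as np
from mlagents.envs.environment import UnityEnvironment

env = UnityEnvironment(file_name="3DBall", side_channels=[])
env.reset()

group_name = env.get_agent_groups()[0]      # single group assumed
spec = env.get_agent_group_spec(group_name)

for _ in range(100):
    step_result = env.get_step_result(group_name)
    n_agents = step_result.n_agents()
    # Random continuous actions of shape (n_agents, action_size);
    # a discrete group would use dtype=np.int32 instead.
    action = np.random.randn(n_agents, spec.action_size).astype(np.float32)
    env.set_actions(group_name, action)
    env.step()

env.close()
```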
#### BatchedStepResult and StepResult
A `BatchedStepResult` has the following fields:
- `obs` is a list of numpy array observations collected by the group of
agents. The first dimension of the array corresponds to the batch size of
the group (number of agents requesting a decision since the last call to
`env.step()`).
- `reward` is a float vector of length batch size. Corresponds to the
rewards collected by each agent since the last simulation step.
- `done` is an array of booleans of length batch size. Is true if the
associated Agent was terminated during the last simulation step.
- `max_step` is an array of booleans of length batch size. Is true if the
associated Agent reached its maximum number of steps during the last
simulation step.
- `agent_id` is an int vector of length batch size containing the unique
identifiers of the corresponding Agents. This is used to track Agents
across simulation steps.
- `action_mask` is an optional list of two dimensional arrays of booleans.
Only available in multi-discrete action space type.
Each array corresponds to an action branch. The first dimension of each
array is the batch size and the second contains a mask for each action of
the branch. If true, the action is not available for the agent during
this simulation step.
It also has the following two methods:
- `n_agents()` Returns the number of agents requesting a decision since
the last call to `env.step()`.
- `get_agent_step_result(agent_id: int)` Returns a `StepResult`
for the Agent with the `agent_id` unique identifier.
Returns a dictionary mapping Brain names to BrainInfo objects.
A `StepResult` has the following fields:
- `obs` is a list of numpy array observations collected by the agent.
(Each array has one less dimension than the arrays in `BatchedStepResult`)
- `reward` is a float. Corresponds to the reward collected by the agent
since the last simulation step.
- `done` is a bool. Is true if the Agent was terminated during the last
simulation step.
- `max_step` is a bool. Is true if the Agent reached its maximum number of
steps during the last simulation step.
- `agent_id` is an int and a unique identifier for the corresponding Agent.
- `action_mask` is an optional list of one dimensional arrays of booleans.
Only available in multi-discrete action space type.
Each array corresponds to an action branch. Each array contains a mask
for each action of the branch. If true, the action is not available for
the agent during this simulation step.
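For illustration, reading a `BatchedStepResult` and a per-agent `StepResult` might
look like this sketch (reusing `env` and `group_name` from the loop above):

```python
step_result = env.get_step_result(group_name)

print(step_result.n_agents())     # agents that requested a decision last step
print(step_result.obs[0].shape)   # (n_agents,) + first observation shape
print(step_result.reward)         # float32 vector of length n_agents
print(step_result.done)           # boolean vector of length n_agents

# Per-agent view: the same fields with the batch dimension removed.
first_id = step_result.agent_id[0]
single = step_result.get_agent_step_result(first_id)
print(single.reward, single.done, single.obs[0].shape)
```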
For example, to access the BrainInfo belonging to a Brain called
'brain_name', and the BrainInfo field 'vector_observations':
#### AgentGroupSpec
```python
info = env.step()
brainInfo = info['brain_name']
observations = brainInfo.vector_observations
```
An Agent group can have either discrete or continuous actions. To check which type
it is, use `spec.is_action_discrete()` or `spec.is_action_continuous()`. If discrete,
the action tensors are expected to be `np.int32`. If continuous, the actions are
expected to be `np.float32`.
Note that if you have more than one LearningBrain in the scene, you
must provide dictionaries from Brain names to arrays for `action`, `memory`
and `value`. For example: If you have two Learning Brains named `brain1` and
`brain2` each with one Agent taking two continuous actions, then you can
have:
An `AgentGroupSpec` has the following fields:
```python
action = {'brain1':[1.0, 2.0], 'brain2':[3.0,4.0]}
```
- `observation_shapes` is a list of tuples of ints: each tuple corresponds
to an observation's dimensions (without the number of agents dimension).
The shape tuples have the same ordering as the observations in the
BatchedStepResult and StepResult.
- `action_type` is the type of data of the action. It can be discrete or
continuous. If discrete, the action tensors are expected to be `np.int32`. If
continuous, the actions are expected to be `np.float32`.
- `action_size` is an `int` corresponding to the expected dimension of the action
array.
- In continuous action space it is the number of floats that constitute the action.
- In discrete action space (same as multi-discrete) it corresponds to the
number of branches (the number of independent actions)
- `discrete_action_branches` is a Tuple of int only for discrete action space. Each int
corresponds to the number of different options for each branch of the action.
For example, in a game with a direction input (no movement, left, right) and a jump input
(no jump, jump), there will be two branches (direction and jump), the first one with 3
options and the second with 2 options (`action_size = 2` and
`discrete_action_branches = (3, 2)`). A sketch of such a spec is shown below.
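
Here is that example written against `AgentGroupSpec` as it is used elsewhere in this
change (arguments passed positionally, since the constructor parameter names are not
shown here; the observation size of 4 is purely illustrative):

```python
from mlagents.envs.base_env import AgentGroupSpec, ActionType

# Two discrete branches: direction (3 options) and jump (2 options),
# plus one vector observation of size 4.
spec = AgentGroupSpec([(4,)], ActionType.DISCRETE, (3, 2))

assert spec.is_action_discrete()
assert spec.action_size == 2                   # number of branches
assert spec.discrete_action_branches == (3, 2)
```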
Returns a dictionary mapping Brain names to BrainInfo objects.
- **Close : `env.close()`**
Sends a shutdown signal to the environment and closes the communication
socket.
### Modifying the environment from Python
The Environment can be modified by using side channels to send data to the

var sharedProperties = academy.FloatProperties;
float property1 = sharedProperties.GetPropertyWithDefault("parameter_1", 0.0f);
```
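
A Python-side counterpart of the C# snippet above might look like the sketch below
(the `FloatPropertiesChannel` import path is assumed):

```python
from mlagents.envs.environment import UnityEnvironment
from mlagents.envs.side_channel.float_properties_channel import FloatPropertiesChannel

float_props = FloatPropertiesChannel()
env = UnityEnvironment(file_name="3DBall", side_channels=[float_props])

# The value becomes readable in C# via Academy.FloatProperties.GetPropertyWithDefault.
float_props.set_property("parameter_1", 1.0)
env.reset()
print(float_props.list_properties())  # property keys known to this channel
```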
## mlagents-learn
For more detailed documentation on using `mlagents-learn`, check out
[Training ML-Agents](Training-ML-Agents.md)

gym-unity/gym_unity/envs/__init__.py (40)


import numpy as np
from mlagents.envs.environment import UnityEnvironment
from gym import error, spaces
from mlagents.envs.brain_conversion_utils import (
step_result_to_brain_info,
group_spec_to_brain_parameters,
)
class UnityGymException(error.Error):

)
# Take a single step so that the brain information will be sent over
if not self._env.brains:
if not self._env.get_agent_groups():
self.name = self._env.academy_name
self.visual_obs = None
self._current_state = None
self._n_agents = None

self._allow_multiple_visual_obs = allow_multiple_visual_obs
# Check brain configuration
if len(self._env.brains) != 1:
if len(self._env.get_agent_groups()) != 1:
if len(self._env.external_brain_names) <= 0:
raise UnityGymException(
"There are not any external brain in the UnityEnvironment"
)
self.brain_name = self._env.external_brain_names[0]
brain = self._env.brains[self.brain_name]
self.brain_name = self._env.get_agent_groups()[0]
self.name = self.brain_name
brain = group_spec_to_brain_parameters(
self.brain_name, self._env.get_agent_group_spec(self.brain_name)
)
if use_visual and brain.number_visual_observations == 0:
raise UnityGymException(

)
# Check for number of agents in scene.
initial_info = self._env.reset()[self.brain_name]
self._env.reset()
initial_info = step_result_to_brain_info(
self._env.get_step_result(self.brain_name),
self._env.get_agent_group_spec(self.brain_name),
)
self._check_agents(len(initial_info.agents))
# Set observation and action spaces

Returns: observation (object/list): the initial observation of the
space.
"""
info = self._env.reset()[self.brain_name]
self._env.reset()
info = step_result_to_brain_info(
self._env.get_step_result(self.brain_name),
self._env.get_agent_group_spec(self.brain_name),
)
n_agents = len(info.agents)
self._check_agents(n_agents)
self.game_over = False

# Translate action into list
action = self._flattener.lookup_action(action)
info = self._env.step(action)[self.brain_name]
spec = self._env.get_agent_group_spec(self.brain_name)
action = np.array(action).reshape((self._n_agents, spec.action_size))
self._env.set_actions(self.brain_name, action)
self._env.step()
info = step_result_to_brain_info(
self._env.get_step_result(self.brain_name), spec
)
n_agents = len(info.agents)
self._check_agents(n_agents)
self._current_state = info

gym-unity/gym_unity/tests/test_gym.py (87)


from gym import spaces
from gym_unity.envs import UnityEnv, UnityGymException
from mlagents.envs.brain import CameraResolution
from mlagents.envs.base_env import AgentGroupSpec, ActionType, BatchedStepResult
mock_brain = create_mock_brainparams()
mock_braininfo = create_mock_vector_braininfo()
mock_brain = create_mock_group_spec()
mock_braininfo = create_mock_vector_step_result()
setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo)
env = UnityEnv(" ", use_visual=False, multiagent=False)

assert env.observation_space.contains(obs)
assert isinstance(obs, np.ndarray)
assert isinstance(rew, float)
assert isinstance(done, bool)
assert isinstance(done, (bool, np.bool_))
mock_brain = create_mock_brainparams()
mock_braininfo = create_mock_vector_braininfo(num_agents=2)
mock_brain = create_mock_group_spec()
mock_braininfo = create_mock_vector_step_result(num_agents=2)
setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo)
with pytest.raises(UnityGymException):

@mock.patch("gym_unity.envs.UnityEnvironment")
def test_branched_flatten(mock_env):
mock_brain = create_mock_brainparams(
mock_brain = create_mock_group_spec(
mock_braininfo = create_mock_vector_braininfo(num_agents=1)
mock_braininfo = create_mock_vector_step_result(num_agents=1)
setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo)
env = UnityEnv(" ", use_visual=False, multiagent=False, flatten_branched=True)

@pytest.mark.parametrize("use_uint8", [True, False], ids=["float", "uint8"])
@mock.patch("gym_unity.envs.UnityEnvironment")
def test_gym_wrapper_visual(mock_env, use_uint8):
mock_brain = create_mock_brainparams(number_visual_observations=1)
mock_braininfo = create_mock_vector_braininfo(number_visual_observations=1)
mock_brain = create_mock_group_spec(number_visual_observations=1)
mock_braininfo = create_mock_vector_step_result(number_visual_observations=1)
setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo)
env = UnityEnv(" ", use_visual=True, multiagent=False, uint8_visual=use_uint8)

assert env.observation_space.contains(obs)
assert isinstance(obs, np.ndarray)
assert isinstance(rew, float)
assert isinstance(done, bool)
assert isinstance(done, (bool, np.bool_))
assert isinstance(info, dict)

def create_mock_brainparams(
def create_mock_group_spec(
number_visual_observations=0,
vector_action_space_type="continuous",
vector_observation_space_size=3,

Creates a mock BrainParameters object with parameters.
"""
# Avoid using mutable object as default param
if vector_action_space_size is None:
vector_action_space_size = [2]
mock_brain = mock.Mock()
mock_brain.return_value.number_visual_observations = number_visual_observations
if number_visual_observations:
mock_brain.return_value.camera_resolutions = [
CameraResolution(width=8, height=8, num_channels=3)
for _ in range(number_visual_observations)
]
act_type = ActionType.DISCRETE
if vector_action_space_type == "continuous":
act_type = ActionType.CONTINUOUS
if vector_action_space_size is None:
vector_action_space_size = 2
else:
vector_action_space_size = vector_action_space_size[0]
else:
if vector_action_space_size is None:
vector_action_space_size = (2,)
else:
vector_action_space_size = tuple(vector_action_space_size)
obs_shapes = [(vector_observation_space_size,)]
for i in range(number_visual_observations):
obs_shapes += [(8, 8, 3)]
return AgentGroupSpec(obs_shapes, act_type, vector_action_space_size)
mock_brain.return_value.vector_action_space_type = vector_action_space_type
mock_brain.return_value.vector_observation_space_size = (
vector_observation_space_size
)
mock_brain.return_value.vector_action_space_size = vector_action_space_size
return mock_brain()
def create_mock_vector_braininfo(num_agents=1, number_visual_observations=0):
def create_mock_vector_step_result(num_agents=1, number_visual_observations=0):
"""
Creates a mock BrainInfo with vector observations. Imitates constant
vector observations, rewards, dones, and agents.

mock_braininfo = mock.Mock()
mock_braininfo.return_value.vector_observations = np.array([num_agents * [1, 2, 3]])
obs = [np.array([num_agents * [1, 2, 3]])]
mock_braininfo.return_value.visual_observations = [
[np.zeros(shape=(8, 8, 3), dtype=np.float32)]
]
mock_braininfo.return_value.rewards = num_agents * [1.0]
mock_braininfo.return_value.local_done = num_agents * [False]
mock_braininfo.return_value.agents = range(0, num_agents)
return mock_braininfo()
obs += [np.zeros(shape=(num_agents, 8, 8, 3), dtype=np.float32)]
rewards = np.array(num_agents * [1.0])
done = np.array(num_agents * [False])
agents = np.array(range(0, num_agents))
return BatchedStepResult(obs, rewards, done, done, agents, None)
def setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo):
def setup_mock_unityenvironment(mock_env, mock_spec, mock_result):
:Mock mock_brain: A mock Brain object that specifies the params of this environment.
:Mock mock_braininfo: A mock BrainInfo object that will be returned at each step and reset.
:Mock mock_spec: An AgentGroupSpec object that specifies the params of this environment.
:Mock mock_result: A BatchedStepResult object that will be returned at each step and reset.
mock_env.return_value.academy_name = "MockAcademy"
mock_env.return_value.brains = {"MockBrain": mock_brain}
mock_env.return_value.external_brain_names = ["MockBrain"]
mock_env.return_value.reset.return_value = {"MockBrain": mock_braininfo}
mock_env.return_value.step.return_value = {"MockBrain": mock_braininfo}
mock_env.return_value.get_agent_groups.return_value = ["MockBrain"]
mock_env.return_value.get_agent_group_spec.return_value = mock_spec
mock_env.return_value.get_step_result.return_value = mock_result

ml-agents-envs/mlagents/envs/brain.py (46)


from mlagents.envs.communicator_objects.brain_parameters_pb2 import BrainParametersProto
from mlagents.envs.communicator_objects.observation_pb2 import ObservationProto
from mlagents.envs.timers import hierarchical_timer, timed
from typing import Dict, List, NamedTuple, Optional
from typing import Dict, List, NamedTuple
from PIL import Image
logger = logging.getLogger("mlagents.envs")

self.max_reached = max_reached
self.agents = agents
self.action_masks = action_mask
@staticmethod
def merge_memories(m1, m2, agents1, agents2):
if len(m1) == 0 and len(m2) != 0:
m1 = np.zeros((len(agents1), m2.shape[1]), dtype=np.float32)
elif len(m2) == 0 and len(m1) != 0:
m2 = np.zeros((len(agents2), m1.shape[1]), dtype=np.float32)
elif m2.shape[1] > m1.shape[1]:
new_m1 = np.zeros((m1.shape[0], m2.shape[1]), dtype=np.float32)
new_m1[0 : m1.shape[0], 0 : m1.shape[1]] = m1
return np.append(new_m1, m2, axis=0)
elif m1.shape[1] > m2.shape[1]:
new_m2 = np.zeros((m2.shape[0], m1.shape[1]), dtype=np.float32)
new_m2[0 : m2.shape[0], 0 : m2.shape[1]] = m2
return np.append(m1, new_m2, axis=0)
return np.append(m1, m2, axis=0)
@staticmethod
@timed

f"An agent had a NaN observation for brain {brain_params.brain_name}"
)
return vector_obs
def safe_concat_lists(l1: Optional[List], l2: Optional[List]) -> Optional[List]:
if l1 is None:
if l2 is None:
return None
else:
return l2.copy()
else:
if l2 is None:
return l1.copy()
else:
copy = l1.copy()
copy.extend(l2)
return copy
def safe_concat_np_ndarray(
a1: Optional[np.ndarray], a2: Optional[np.ndarray]
) -> Optional[np.ndarray]:
if a1 is not None and a1.size != 0:
if a2 is not None and a2.size != 0:
return np.append(a1, a2, axis=0)
else:
return a1.copy()
elif a2 is not None and a2.size != 0:
return a2.copy()
return None
# Renaming of dictionary of brain name to BrainInfo for clarity

ml-agents-envs/mlagents/envs/environment.py (327)


from typing import Dict, List, Optional, Any
from mlagents.envs.side_channel.side_channel import SideChannel
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.base_env import (
BaseEnv,
BatchedStepResult,
AgentGroupSpec,
AgentGroup,
AgentId,
)
from .brain import AllBrainInfo, BrainInfo, BrainParameters
)
from mlagents.envs.rpc_utils import (
agent_group_spec_from_proto,
batched_step_result_from_proto,
)
from mlagents.envs.communicator_objects.unity_rl_input_pb2 import UnityRLInputProto

logger = logging.getLogger("mlagents.envs")
class UnityEnvironment(BaseUnityEnvironment):
class UnityEnvironment(BaseEnv):
SCALAR_ACTION_TYPES = (int, np.int32, np.int64, float, np.float32, np.float64)
SINGLE_BRAIN_ACTION_TYPES = SCALAR_ACTION_TYPES + (list, np.ndarray)
API_VERSION = "API-12"

"{1}.\nPlease go to https://github.com/Unity-Technologies/ml-agents to download the latest version "
"of ML-Agents.".format(self._version_, self._unity_version)
)
self._n_agents: Dict[str, int] = {}
self._env_state: Dict[str, BatchedStepResult] = {}
self._env_specs: Dict[str, AgentGroupSpec] = {}
self._env_actions: Dict[str, np.ndarray] = {}
self._academy_name = aca_params.name
self._log_path = aca_params.log_path
self._brains: Dict[str, BrainParameters] = {}
self._external_brain_names: List[str] = []
self._num_external_brains = 0
self._update_brain_parameters(aca_output)
logger.info(
"\n'{0}' started successfully!\n{1}".format(self._academy_name, str(self))
)
@property
def logfile_path(self):
return self._log_path
@property
def brains(self):
return self._brains
@property
def academy_name(self):
return self._academy_name
@property
def number_external_brains(self):
return self._num_external_brains
@property
def external_brain_names(self):
return self._external_brain_names
self._update_group_specs(aca_output)
@property
def external_brains(self):
external_brains = {}
for brain_name in self.external_brain_names:
external_brains[brain_name] = self.brains[brain_name]
return external_brains
def executable_launcher(self, file_name, docker_training, no_graphics, args):
cwd = os.getcwd()

shell=True,
)
def __str__(self):
return """Unity Academy name: {0}""".format(self._academy_name)
def _update_group_specs(self, output: UnityOutputProto) -> None:
init_output = output.rl_initialization_output
for brain_param in init_output.brain_parameters:
# Each BrainParameter in the rl_initialization_output should have at least one AgentInfo
# Get that agent, because we need some of its observations.
agent_infos = output.rl_output.agentInfos[brain_param.brain_name]
if agent_infos.value:
agent = agent_infos.value[0]
new_spec = agent_group_spec_from_proto(brain_param, agent)
self._env_specs[brain_param.brain_name] = new_spec
logger.info(f"Connected new brain:\n{brain_param.brain_name}")
def reset(self) -> AllBrainInfo:
def _update_state(self, output: UnityRLOutputProto) -> None:
Sends a signal to reset the unity environment.
:return: AllBrainInfo : A data structure corresponding to the initial reset state of the environment.
Collects experience information from all external brains in environment at current step.
for brain_name in self._env_specs.keys():
if brain_name in output.agentInfos:
agent_info_list = output.agentInfos[brain_name].value
self._env_state[brain_name] = batched_step_result_from_proto(
agent_info_list, self._env_specs[brain_name]
)
else:
self._env_state[brain_name] = BatchedStepResult.empty(
self._env_specs[brain_name]
)
self._parse_side_channel_message(self.side_channels, output.side_channel)
def reset(self) -> None:
self._update_brain_parameters(outputs)
self._update_group_specs(outputs)
s = self._get_state(rl_output)
for _b in self._external_brain_names:
self._n_agents[_b] = len(s[_b].agents)
self._update_state(rl_output)
return s
self._env_actions.clear()
def step(
self,
vector_action: Dict[str, np.ndarray] = None,
value: Optional[Dict[str, np.ndarray]] = None,
) -> AllBrainInfo:
"""
Provides the environment with an action, moves the environment dynamics forward accordingly,
and returns observation, state, and reward information to the agent.
:param value: Value estimates provided by agents.
:param vector_action: Agent's vector action. Can be a scalar or vector of int/floats.
:param memory: Vector corresponding to memory used for recurrent policies.
:return: AllBrainInfo : A Data structure corresponding to the new state of the environment.
"""
def step(self) -> None:
vector_action = {} if vector_action is None else vector_action
value = {} if value is None else value
# Check that environment is loaded, and episode is currently running.
else:
if isinstance(vector_action, self.SINGLE_BRAIN_ACTION_TYPES):
if self._num_external_brains == 1:
vector_action = {self._external_brain_names[0]: vector_action}
elif self._num_external_brains > 1:
raise UnityActionException(
"You have {0} brains, you need to feed a dictionary of brain names a keys, "
"and vector_actions as values".format(self._num_external_brains)
)
else:
raise UnityActionException(
"There are no external brains in the environment, "
"step cannot take a vector_action input"
)
# fill the blanks for missing actions
for group_name in self._env_specs:
if group_name not in self._env_actions:
n_agents = 0
if group_name in self._env_state:
n_agents = self._env_state[group_name].n_agents()
self._env_actions[group_name] = self._env_specs[
group_name
].create_empty_action(n_agents)
step_input = self._generate_step_input(self._env_actions)
with hierarchical_timer("communicator.exchange"):
outputs = self.communicator.exchange(step_input)
if outputs is None:
raise UnityCommunicationException("Communicator has stopped.")
self._update_group_specs(outputs)
rl_output = outputs.rl_output
self._update_state(rl_output)
self._env_actions.clear()
if isinstance(value, self.SINGLE_BRAIN_ACTION_TYPES):
if self._num_external_brains == 1:
value = {self._external_brain_names[0]: value}
elif self._num_external_brains > 1:
raise UnityActionException(
"You have {0} brains, you need to feed a dictionary of brain names as keys "
"and state/action value estimates as values".format(
self._num_external_brains
)
)
else:
raise UnityActionException(
"There are no external brains in the environment, "
"step cannot take a value input"
)
def get_agent_groups(self) -> List[AgentGroup]:
return list(self._env_specs.keys())
for brain_name in list(vector_action.keys()):
if brain_name not in self._external_brain_names:
raise UnityActionException(
"The name {0} does not correspond to an external brain "
"in the environment".format(brain_name)
)
def _assert_group_exists(self, agent_group: str) -> None:
if agent_group not in self._env_specs:
raise UnityActionException(
"The group {0} does not correspond to an existing agent group "
"in the environment".format(agent_group)
)
for brain_name in self._external_brain_names:
n_agent = self._n_agents[brain_name]
if brain_name not in vector_action:
if self._brains[brain_name].vector_action_space_type == "discrete":
vector_action[brain_name] = (
[0.0]
* n_agent
* len(self._brains[brain_name].vector_action_space_size)
)
else:
vector_action[brain_name] = (
[0.0]
* n_agent
* self._brains[brain_name].vector_action_space_size[0]
)
else:
vector_action[brain_name] = self._flatten(vector_action[brain_name])
discrete_check = (
self._brains[brain_name].vector_action_space_type == "discrete"
def set_actions(self, agent_group: AgentGroup, action: np.ndarray) -> None:
self._assert_group_exists(agent_group)
if agent_group not in self._env_state:
return
spec = self._env_specs[agent_group]
expected_type = np.float32 if spec.is_action_continuous() else np.int32
expected_shape = (self._env_state[agent_group].n_agents(), spec.action_size)
if action.shape != expected_shape:
raise UnityActionException(
"The group {0} needs an input of dimension {1} but received input of dimension {2}".format(
agent_group, expected_shape, action.shape
)
if action.dtype != expected_type:
action = action.astype(expected_type)
self._env_actions[agent_group] = action
expected_discrete_size = n_agent * len(
self._brains[brain_name].vector_action_space_size
def set_action_for_agent(
self, agent_group: AgentGroup, agent_id: AgentId, action: np.ndarray
) -> None:
self._assert_group_exists(agent_group)
if agent_group not in self._env_state:
return
spec = self._env_specs[agent_group]
expected_shape = (spec.action_size,)
if action.shape != expected_shape:
raise UnityActionException(
"The Agent {0} in group {1} needs an input of dimension {2} but received input of dimension {3}".format(
agent_id, agent_group, expected_shape, action.shape
)
expected_type = np.float32 if spec.is_action_continuous() else np.int32
if action.dtype != expected_type:
action = action.astype(expected_type)
continuous_check = (
self._brains[brain_name].vector_action_space_type == "continuous"
if agent_group not in self._env_actions:
self._env_actions[agent_group] = self._empty_action(
spec, self._env_state[agent_group].n_agents()
)
try:
index = np.where(self._env_state[agent_group].agent_id == agent_id)[0][0]
except IndexError as ie:
raise IndexError(
"agent_id {} is did not request a decision at the previous step".format(
agent_id
) from ie
self._env_actions[agent_group][index] = action
expected_continuous_size = (
self._brains[brain_name].vector_action_space_size[0] * n_agent
)
def get_step_result(self, agent_group: AgentGroup) -> BatchedStepResult:
self._assert_group_exists(agent_group)
return self._env_state[agent_group]
if not (
(
discrete_check
and len(vector_action[brain_name]) == expected_discrete_size
)
or (
continuous_check
and len(vector_action[brain_name]) == expected_continuous_size
)
):
raise UnityActionException(
"There was a mismatch between the provided action and "
"the environment's expectation: "
"The brain {0} expected {1} {2} action(s), but was provided: {3}".format(
brain_name,
str(expected_discrete_size)
if discrete_check
else str(expected_continuous_size),
self._brains[brain_name].vector_action_space_type,
str(vector_action[brain_name]),
)
)
step_input = self._generate_step_input(vector_action, value)
with hierarchical_timer("communicator.exchange"):
outputs = self.communicator.exchange(step_input)
if outputs is None:
raise UnityCommunicationException("Communicator has stopped.")
self._update_brain_parameters(outputs)
rl_output = outputs.rl_output
state = self._get_state(rl_output)
for _b in self._external_brain_names:
self._n_agents[_b] = len(state[_b].agents)
return state
def get_agent_group_spec(self, agent_group: AgentGroup) -> AgentGroupSpec:
self._assert_group_exists(agent_group)
return self._env_specs[agent_group]
def close(self):
"""

arr = [float(x) for x in arr]
return arr
def _get_state(self, output: UnityRLOutputProto) -> AllBrainInfo:
"""
Collects experience information from all external brains in environment at current step.
:return: a dictionary of BrainInfo objects.
"""
_data = {}
for brain_name in output.agentInfos:
agent_info_list = output.agentInfos[brain_name].value
_data[brain_name] = BrainInfo.from_agent_proto(
self.worker_id, agent_info_list, self.brains[brain_name]
)
self._parse_side_channel_message(self.side_channels, output.side_channel)
return _data
@staticmethod
def _parse_side_channel_message(
side_channels: Dict[int, SideChannel], data: bytearray

channel.message_queue = []
return result
def _update_brain_parameters(self, output: UnityOutputProto) -> None:
init_output = output.rl_initialization_output
for brain_param in init_output.brain_parameters:
# Each BrainParameter in the rl_initialization_output should have at least one AgentInfo
# Get that agent, because we need some of its observations.
agent_infos = output.rl_output.agentInfos[brain_param.brain_name]
if agent_infos.value:
agent = agent_infos.value[0]
new_brain = BrainParameters.from_proto(brain_param, agent)
self._brains[brain_param.brain_name] = new_brain
logger.info(f"Connected new brain:\n{new_brain}")
self._external_brain_names = list(self._brains.keys())
self._num_external_brains = len(self._external_brain_names)
self, vector_action: Dict[str, np.ndarray], value: Dict[str, np.ndarray]
self, vector_action: Dict[str, np.ndarray]
n_agents = self._n_agents[b]
n_agents = self._env_state[b].n_agents()
_a_s = len(vector_action[b]) // n_agents
action = AgentActionProto(
vector_actions=vector_action[b][i * _a_s : (i + 1) * _a_s]
)
if b in value:
if value[b] is not None:
action.value = float(value[b][i])
action = AgentActionProto(vector_actions=vector_action[b][i])
rl_in.agent_actions[b].value.extend([action])
rl_in.command = 0
rl_in.side_channel = bytes(self._generate_side_channel_data(self.side_channels))

ml-agents-envs/mlagents/envs/simple_env_manager.py (42)


from typing import Dict, List
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.base_env import BaseEnv
from mlagents.envs.brain import BrainParameters
from mlagents.envs.brain import BrainParameters, AllBrainInfo
from mlagents.envs.brain_conversion_utils import (
step_result_to_brain_info,
group_spec_to_brain_parameters,
)
Simple implementation of the EnvManager interface that only handles one BaseUnityEnvironment at a time.
Simple implementation of the EnvManager interface that only handles one BaseEnv at a time.
def __init__(
self, env: BaseUnityEnvironment, float_prop_channel: FloatPropertiesChannel
):
def __init__(self, env: BaseEnv, float_prop_channel: FloatPropertiesChannel):
super().__init__()
self.shared_float_properties = float_prop_channel
self.env = env

def step(self) -> List[EnvironmentStep]:
actions = {}
values = {}
actions[brain_name] = action_info.action
values[brain_name] = action_info.value
all_brain_info = self.env.step(vector_action=actions, value=values)
self.env.set_actions(brain_name, action_info.action)
self.env.step()
all_brain_info = self._generate_all_brain_info()
step_brain_info = all_brain_info
step_info = EnvironmentStep(

if config is not None:
for k, v in config.items():
self.shared_float_properties.set_property(k, v)
all_brain_info = self.env.reset()
self.env.reset()
all_brain_info = self._generate_all_brain_info()
return self.env.external_brains
result = {}
for brain_name in self.env.get_agent_groups():
result[brain_name] = group_spec_to_brain_parameters(
brain_name, self.env.get_agent_group_spec(brain_name)
)
return result
@property
def get_properties(self) -> Dict[str, float]:

brain_info
)
return all_action_info
def _generate_all_brain_info(self) -> AllBrainInfo:
all_brain_info = {}
for brain_name in self.env.get_agent_groups():
all_brain_info[brain_name] = step_result_to_brain_info(
self.env.get_step_result(brain_name),
self.env.get_agent_group_spec(brain_name),
)
return all_brain_info

ml-agents-envs/mlagents/envs/subprocess_env_manager.py (44)


from multiprocessing import Process, Pipe, Queue
from multiprocessing.connection import Connection
from queue import Empty as EmptyQueueException
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.base_env import BaseEnv
from mlagents.envs.env_manager import EnvManager, EnvironmentStep
from mlagents.envs.timers import (
TimerNode,

EngineConfig,
)
from mlagents.envs.side_channel.side_channel import SideChannel
from mlagents.envs.brain_conversion_utils import (
step_result_to_brain_info,
group_spec_to_brain_parameters,
)
logger = logging.getLogger("mlagents.envs")

shared_float_properties = FloatPropertiesChannel()
engine_configuration_channel = EngineConfigurationChannel()
engine_configuration_channel.set_configuration(engine_configuration)
env: BaseUnityEnvironment = env_factory(
env: BaseEnv = env_factory(
worker_id, [shared_float_properties, engine_configuration_channel]
)

def _generate_all_brain_info() -> AllBrainInfo:
all_brain_info = {}
for brain_name in env.get_agent_groups():
all_brain_info[brain_name] = step_result_to_brain_info(
env.get_step_result(brain_name),
env.get_agent_group_spec(brain_name),
worker_id,
)
return all_brain_info
def external_brains():
result = {}
for brain_name in env.get_agent_groups():
result[brain_name] = group_spec_to_brain_parameters(
brain_name, env.get_agent_group_spec(brain_name)
)
return result
actions = {}
values = {}
actions[brain_name] = action_info.action
values[brain_name] = action_info.value
all_brain_info = env.step(vector_action=actions, value=values)
if len(action_info.action) != 0:
env.set_actions(brain_name, action_info.action)
env.step()
all_brain_info = _generate_all_brain_info()
# The timers in this process are independent from all the processes and the "main" process
# So after we send back the root timer, we can safely clear them.
# Note that we could randomly return timers a fraction of the time if we wanted to reduce

step_queue.put(EnvironmentResponse("step", worker_id, step_response))
reset_timers()
elif cmd.name == "external_brains":
_send_response("external_brains", env.external_brains)
_send_response("external_brains", external_brains())
elif cmd.name == "get_properties":
reset_params = {}
for k in shared_float_properties.list_properties():

elif cmd.name == "reset":
for k, v in cmd.payload.items():
shared_float_properties.set_property(k, v)
all_brain_info = env.reset()
env.reset()
all_brain_info = _generate_all_brain_info()
_send_response("reset", all_brain_info)
elif cmd.name == "close":
break

class SubprocessEnvManager(EnvManager):
def __init__(
self,
env_factory: Callable[[int, List[SideChannel]], BaseUnityEnvironment],
env_factory: Callable[[int, List[SideChannel]], BaseEnv],
engine_configuration: EngineConfig,
n_env: int = 1,
):

def create_worker(
worker_id: int,
step_queue: Queue,
env_factory: Callable[[int, List[SideChannel]], BaseUnityEnvironment],
env_factory: Callable[[int, List[SideChannel]], BaseEnv],
engine_configuration: EngineConfig,
) -> UnityEnvWorker:
parent_conn, child_conn = Pipe()

ml-agents-envs/mlagents/envs/tests/test_envs.py (82)


import numpy as np
from mlagents.envs.environment import UnityEnvironment
from mlagents.envs.base_env import BatchedStepResult
from mlagents.envs.brain import BrainInfo
from mlagents.envs.mock_communicator import MockCommunicator

discrete_action=False, visual_inputs=0
)
env = UnityEnvironment(" ")
assert env.external_brain_names[0] == "RealFakeBrain"
assert env.get_agent_groups() == ["RealFakeBrain"]
env.close()

discrete_action=False, visual_inputs=0
)
env = UnityEnvironment(" ")
brain = env.brains["RealFakeBrain"]
brain_info = env.reset()
spec = env.get_agent_group_spec("RealFakeBrain")
env.reset()
batched_step_result = env.get_step_result("RealFakeBrain")
assert isinstance(brain_info, dict)
assert isinstance(brain_info["RealFakeBrain"], BrainInfo)
assert isinstance(brain_info["RealFakeBrain"].visual_observations, list)
assert isinstance(brain_info["RealFakeBrain"].vector_observations, np.ndarray)
assert (
len(brain_info["RealFakeBrain"].visual_observations)
== brain.number_visual_observations
)
assert len(brain_info["RealFakeBrain"].vector_observations) == len(
brain_info["RealFakeBrain"].agents
)
assert (
len(brain_info["RealFakeBrain"].vector_observations[0])
== brain.vector_observation_space_size
)
assert isinstance(batched_step_result, BatchedStepResult)
assert len(spec.observation_shapes) == len(batched_step_result.obs)
n_agents = batched_step_result.n_agents()
for shape, obs in zip(spec.observation_shapes, batched_step_result.obs):
assert (n_agents,) + shape == obs.shape
@mock.patch("mlagents.envs.environment.UnityEnvironment.executable_launcher")

discrete_action=False, visual_inputs=0
)
env = UnityEnvironment(" ")
brain = env.brains["RealFakeBrain"]
brain_info = env.step()
brain_info = env.step(
[0]
* brain.vector_action_space_size[0]
* len(brain_info["RealFakeBrain"].agents)
spec = env.get_agent_group_spec("RealFakeBrain")
env.step()
batched_step_result = env.get_step_result("RealFakeBrain")
n_agents = batched_step_result.n_agents()
env.set_actions(
"RealFakeBrain", np.zeros((n_agents, spec.action_size), dtype=np.float32)
env.step()
env.step([0])
brain_info = env.step(
[-1]
* brain.vector_action_space_size[0]
* len(brain_info["RealFakeBrain"].agents)
env.set_actions(
"RealFakeBrain",
np.zeros((n_agents - 1, spec.action_size), dtype=np.float32),
)
batched_step_result = env.get_step_result("RealFakeBrain")
n_agents = batched_step_result.n_agents()
env.set_actions(
"RealFakeBrain", -1 * np.ones((n_agents, spec.action_size), dtype=np.float32)
env.step()
assert isinstance(brain_info, dict)
assert isinstance(brain_info["RealFakeBrain"], BrainInfo)
assert isinstance(brain_info["RealFakeBrain"].visual_observations, list)
assert isinstance(brain_info["RealFakeBrain"].vector_observations, np.ndarray)
assert (
len(brain_info["RealFakeBrain"].visual_observations)
== brain.number_visual_observations
)
assert len(brain_info["RealFakeBrain"].vector_observations) == len(
brain_info["RealFakeBrain"].agents
)
assert (
len(brain_info["RealFakeBrain"].vector_observations[0])
== brain.vector_observation_space_size
)
print("\n\n\n\n\n\n\n" + str(brain_info["RealFakeBrain"].local_done))
assert not brain_info["RealFakeBrain"].local_done[0]
assert brain_info["RealFakeBrain"].local_done[2]
assert isinstance(batched_step_result, BatchedStepResult)
assert len(spec.observation_shapes) == len(batched_step_result.obs)
for shape, obs in zip(spec.observation_shapes, batched_step_result.obs):
assert (n_agents,) + shape == obs.shape
assert not batched_step_result.done[0]
assert batched_step_result.done[2]
@mock.patch("mlagents.envs.environment.UnityEnvironment.executable_launcher")

ml-agents-envs/mlagents/envs/tests/test_subprocess_env_manager.py (4)


EnvironmentResponse,
StepResponse,
)
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.base_env import BaseEnv
return mock.create_autospec(spec=BaseUnityEnvironment)
return mock.create_autospec(spec=BaseEnv)
class MockEnvWorker:

ml-agents/mlagents/trainers/learn.py (4)


from mlagents.envs.environment import UnityEnvironment
from mlagents.envs.sampler_class import SamplerManager
from mlagents.envs.exception import SamplerException
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.base_env import BaseEnv
from mlagents.envs.subprocess_env_manager import SubprocessEnvManager
from mlagents.envs.side_channel.side_channel import SideChannel
from mlagents.envs.side_channel.engine_configuration_channel import EngineConfig

seed: Optional[int],
start_port: int,
env_args: Optional[List[str]],
) -> Callable[[int, List[SideChannel]], BaseUnityEnvironment]:
) -> Callable[[int, List[SideChannel]], BaseEnv]:
if env_path is not None:
# Strip out executable extensions if passed
env_path = (

ml-agents/mlagents/trainers/tests/test_bc.py (23)


import mlagents.trainers.tests.mock_brain as mb
from mlagents.trainers.bc.policy import BCPolicy
from mlagents.trainers.bc.offline_trainer import BCTrainer
from mlagents.envs.environment import UnityEnvironment
from mlagents.envs.environment import UnityEnvironment
from mlagents.envs.brain_conversion_utils import (
step_result_to_brain_info,
group_spec_to_brain_parameters,
)
@pytest.fixture

discrete_action=False, visual_inputs=0
)
env = UnityEnvironment(" ")
brain_infos = env.reset()
brain_info = brain_infos[env.external_brain_names[0]]
env.reset()
brain_name = env.get_agent_groups()[0]
brain_info = step_result_to_brain_info(
env.get_step_result(brain_name), env.get_agent_group_spec(brain_name)
)
brain_params = group_spec_to_brain_parameters(
brain_name, env.get_agent_group_spec(brain_name)
)
model_path = env.external_brain_names[0]
model_path = brain_name
policy = BCPolicy(
0, env.brains[env.external_brain_names[0]], trainer_parameters, False
)
policy = BCPolicy(0, brain_params, trainer_parameters, False)
run_out = policy.evaluate(brain_info)
assert run_out["action"].shape == (3, 2)

ml-agents/mlagents/trainers/tests/test_ppo.py (36)


from mlagents.envs.mock_communicator import MockCommunicator
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.mock_brain import make_brain_parameters
from mlagents.envs.brain_conversion_utils import (
step_result_to_brain_info,
group_spec_to_brain_parameters,
)
@pytest.fixture

discrete_action=False, visual_inputs=0
)
env = UnityEnvironment(" ")
brain_infos = env.reset()
brain_info = brain_infos[env.external_brain_names[0]]
env.reset()
brain_name = env.get_agent_groups()[0]
brain_info = step_result_to_brain_info(
env.get_step_result(brain_name), env.get_agent_group_spec(brain_name)
)
brain_params = group_spec_to_brain_parameters(
brain_name, env.get_agent_group_spec(brain_name)
)
model_path = env.external_brain_names[0]
model_path = brain_name
policy = PPOPolicy(
0, env.brains[env.external_brain_names[0]], trainer_parameters, False, False
)
policy = PPOPolicy(0, brain_params, trainer_parameters, False, False)
run_out = policy.evaluate(brain_info)
assert run_out["action"].shape == (3, 2)
env.close()

discrete_action=False, visual_inputs=0
)
env = UnityEnvironment(" ")
brain_infos = env.reset()
brain_info = brain_infos[env.external_brain_names[0]]
env.reset()
brain_name = env.get_agent_groups()[0]
brain_info = step_result_to_brain_info(
env.get_step_result(brain_name), env.get_agent_group_spec(brain_name)
)
brain_params = group_spec_to_brain_parameters(
brain_name, env.get_agent_group_spec(brain_name)
)
model_path = env.external_brain_names[0]
model_path = brain_name
policy = PPOPolicy(
0, env.brains[env.external_brain_names[0]], trainer_parameters, False, False
)
policy = PPOPolicy(0, brain_params, trainer_parameters, False, False)
run_out = policy.get_value_estimates(brain_info, 0, done=False)
for key, val in run_out.items():
assert type(key) is str

ml-agents/mlagents/trainers/tests/test_simple_rl.py (112)


import tempfile
import pytest
import yaml
from typing import Any, Dict
from typing import Dict
import numpy as np
from mlagents.envs.base_unity_environment import BaseUnityEnvironment
from mlagents.envs.brain import BrainInfo, AllBrainInfo, BrainParameters
from mlagents.envs.communicator_objects.agent_info_pb2 import AgentInfoProto
from mlagents.envs.communicator_objects.observation_pb2 import (
ObservationProto,
NONE as COMPRESSION_TYPE_NONE,
from mlagents.envs.base_env import (
BaseEnv,
AgentGroupSpec,
BatchedStepResult,
ActionType,
from mlagents.envs.brain import BrainParameters
BRAIN_NAME = __name__
OBS_SIZE = 1

return max(min_val, min(x, max_val))
class Simple1DEnvironment(BaseUnityEnvironment):
class Simple1DEnvironment(BaseEnv):
"""
Very simple "game" - the agent has a position on [-1, 1], gets a reward of 1 if it reaches 1, and a reward of -1 if
it reaches -1. The position is incremented by the action amount (clamped to [-step_size, step_size]).

super().__init__()
self.discrete = use_discrete
self._brains: Dict[str, BrainParameters] = {}
brain_params = BrainParameters(
brain_name=BRAIN_NAME,
vector_observation_space_size=OBS_SIZE,
camera_resolutions=[],
vector_action_space_size=[2] if use_discrete else [1],
vector_action_descriptions=["moveDirection"],
vector_action_space_type=0 if use_discrete else 1,
action_type = ActionType.DISCRETE if use_discrete else ActionType.CONTINUOUS
self.group_spec = AgentGroupSpec(
[(OBS_SIZE,)], action_type, (2,) if use_discrete else 1
self._brains[BRAIN_NAME] = brain_params
self.random = random.Random(str(brain_params))
self.random = random.Random(str(self.group_spec))
self.action = None
self.step_result = None
def step(
self,
vector_action: Dict[str, Any] = None,
memory: Dict[str, Any] = None,
value: Dict[str, Any] = None,
) -> AllBrainInfo:
assert vector_action is not None
def get_agent_groups(self):
return [BRAIN_NAME]
def get_agent_group_spec(self, name):
return self.group_spec
def set_action_for_agent(self, name, id, data):
pass
def set_actions(self, name, data):
self.action = data
def get_step_result(self, name):
return self.step_result
def step(self) -> None:
assert self.action is not None
act = vector_action[BRAIN_NAME][0][0]
act = self.action[0][0]
delta = vector_action[BRAIN_NAME][0][0]
delta = self.action[0][0]
delta = clamp(delta, -STEP_SIZE, STEP_SIZE)
self.position += delta
self.position = clamp(self.position, -1, 1)

else:
reward = -TIME_PENALTY
vector_obs = [self.goal] * OBS_SIZE
vector_obs_proto = ObservationProto(
float_data=ObservationProto.FloatData(data=vector_obs),
shape=[len(vector_obs)],
compression_type=COMPRESSION_TYPE_NONE,
)
agent_info = AgentInfoProto(
reward=reward, done=bool(done), observations=[vector_obs_proto]
)
m_vector_obs = [np.ones((1, OBS_SIZE), dtype=np.float32) * self.goal]
m_reward = np.array