
Merge branch 'master' into self-play-mutex

/develop/cubewars
Andrew Cohen, 4 years ago
Current commit 4c9ac553
27 changed files with 220 additions and 118 deletions
1. Dockerfile (6 changes)
2. README.md (1 change)
3. com.unity.ml-agents/CHANGELOG.md (2 changes)
4. com.unity.ml-agents/Runtime/Agent.cs (3 changes)
5. com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs (17 changes)
6. com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs (5 changes)
7. docs/Custom-SideChannels.md (2 changes)
8. docs/Learning-Environment-Design-Agents.md (1 change)
9. docs/Migrating.md (4 changes)
10. docs/Python-API.md (6 changes)
11. gym-unity/gym_unity/envs/__init__.py (2 changes)
12. ml-agents-envs/mlagents_envs/communicator.py (2 changes)
13. ml-agents-envs/mlagents_envs/environment.py (15 changes)
14. ml-agents-envs/mlagents_envs/mock_communicator.py (3 changes)
15. ml-agents-envs/mlagents_envs/rpc_communicator.py (3 changes)
16. ml-agents-envs/mlagents_envs/tests/test_envs.py (23 changes)
17. ml-agents/mlagents/trainers/agent_processor.py (31 changes)
18. ml-agents/mlagents/trainers/distributions.py (36 changes)
19. ml-agents/mlagents/trainers/learn.py (2 changes)
20. ml-agents/mlagents/trainers/policy/nn_policy.py (1 change)
21. ml-agents/mlagents/trainers/policy/tf_policy.py (19 changes)
22. ml-agents/mlagents/trainers/sac/trainer.py (1 change)
23. ml-agents/mlagents/trainers/tests/simple_test_envs.py (78 changes)
24. ml-agents/mlagents/trainers/tests/test_agent_processor.py (9 changes)
25. ml-agents/mlagents/trainers/tests/test_meta_curriculum.py (4 changes)
26. ml-agents/mlagents/trainers/tests/test_simple_rl.py (58 changes)
27. ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py (4 changes)

Dockerfile (6 changes)


WORKDIR /ml-agents
RUN pip install -e .
# Port 5005 is the port used in Editor training.
EXPOSE 5005
# Port 5004 is the port used in Editor training.
# Environments will start from port 5005,
# so allow enough ports for several environments.
EXPOSE 5004-5050
ENTRYPOINT ["mlagents-learn"]

README.md (1 change)


## Releases & Documentation
**Our latest, stable release is 0.15.0. Click
[here](https://github.com/Unity-Technologies/ml-agents/tree/latest_release/docs/Readme.md) to
get started with the latest release of ML-Agents.**
The table below lists all our releases, including our `master` branch which is under active

com.unity.ml-agents/CHANGELOG.md (2 changes)


- Format of console output has changed slightly and now matches the name of the model/summary directory. (#3630, #3616)
- Raise the wall in CrawlerStatic scene to prevent Agent from falling off. (#3650)
- Renamed 'Generalization' feature to 'Environment Parameter Randomization'.
- Fixed an issue where specifying `vis_encode_type` was required only for SAC. (#3677)
- The way that UnityEnvironment decides the port was changed. If no port is specified, the behavior will depend on the `file_name` parameter. If it is `None`, 5004 (the editor port) will be used; otherwise 5005 (the base environment port) will be used.
## [0.15.0-preview] - 2020-03-18
### Major Changes

com.unity.ml-agents/Runtime/Agent.cs (3 changes)


void NotifyAgentDone(DoneReason doneReason)
{
m_Info.episodeId = m_EpisodeId;
m_Info.reward = m_Reward;
m_Info.done = true;
m_Info.maxStepReached = doneReason == DoneReason.MaxStepReached;

// If everything is the same, don't make any changes.
return;
}
NotifyAgentDone(DoneReason.Disabled);
m_PolicyFactory.model = model;
m_PolicyFactory.inferenceDevice = inferenceDevice;
m_PolicyFactory.behaviorName = behaviorName;

com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs (17 changes)


{
if (m_CurrentUnityRlOutput.AgentInfos.ContainsKey(behaviorName))
{
if (output == null)
if (m_CurrentUnityRlOutput.AgentInfos[behaviorName].CalculateSize() > 0)
output = new UnityRLInitializationOutputProto();
}
// Only send the BrainParameters if there is a non-empty list of
// AgentInfos ready to be sent.
// This is to ensure that the Python side will always have a first
// observation when receiving the BrainParameters
if (output == null)
{
output = new UnityRLInitializationOutputProto();
}
var brainParameters = m_UnsentBrainKeys[behaviorName];
output.BrainParameters.Add(brainParameters.ToProto(behaviorName, true));
var brainParameters = m_UnsentBrainKeys[behaviorName];
output.BrainParameters.Add(brainParameters.ToProto(behaviorName, true));
}
}
}

com.unity.ml-agents/Runtime/Policies/HeuristicPolicy.cs (5 changes)


public void RequestDecision(AgentInfo info, List<ISensor> sensors)
{
StepSensors(sensors);
m_LastDecision = m_Heuristic.Invoke();
if (!info.done)
{
m_LastDecision = m_Heuristic.Invoke();
}
}
/// <inheritdoc />

docs/Custom-SideChannels.md (2 changes)


string_log = StringLogChannel()
# We start the communication with the Unity Editor and pass the string_log side channel as input
env = UnityEnvironment(base_port=UnityEnvironment.DEFAULT_EDITOR_PORT, side_channels=[string_log])
env = UnityEnvironment(side_channels=[string_log])
env.reset()
string_log.send_string("The environment was reset")
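For reference, a minimal sketch of what the `StringLogChannel` used above might look like on the Python side, assuming the `SideChannel` / `IncomingMessage` / `OutgoingMessage` interface noted in Migrating.md below; the import path, channel UUID, and method bodies are illustrative, not copied from the docs:

```python
import uuid
from mlagents_envs.side_channel.side_channel import (
    SideChannel,
    IncomingMessage,
    OutgoingMessage,
)


class StringLogChannel(SideChannel):
    def __init__(self) -> None:
        # The channel id must match the one used by the corresponding C# side channel.
        super().__init__(uuid.UUID("621f0a70-4f87-11ea-a6bf-784f4387d1f7"))

    def on_message_received(self, msg: IncomingMessage) -> None:
        # Print whatever string the Unity side sent.
        print(msg.read_string())

    def send_string(self, data: str) -> None:
        # Wrap the string in an OutgoingMessage and queue it for the next exchange.
        msg = OutgoingMessage()
        msg.write_string(data)
        super().queue_message_to_send(msg)
```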

docs/Learning-Environment-Design-Agents.md (1 change)


```csharp
normalizedValue = (currentValue - minValue)/(maxValue - minValue)
```
:warning: For vectors, you should apply the above formula to each component (x, y, and z). Note that this is *not* the same as using the `Vector3.normalized` property or `Vector3.Normalize()` method in Unity (and similar for `Vector2`).
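A small illustrative sketch (not part of the docs) of applying that min/max formula component-wise, assuming known bounds per component:

```python
def normalize(value, min_value, max_value):
    # Standard min/max normalization to [0, 1].
    return (value - min_value) / (max_value - min_value)

# Apply the formula to each component rather than using Vector3.normalized.
position = (3.0, 0.5, -12.0)
bounds = ((-10.0, 10.0), (0.0, 5.0), (-20.0, 20.0))
normalized = [normalize(v, lo, hi) for v, (lo, hi) in zip(position, bounds)]
```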
Rotations and angles should also be normalized. For angles between 0 and 360
degrees, you can use the following formulas:

docs/Migrating.md (4 changes)


* The interface for SideChannels was changed:
* In C#, `OnMessageReceived` now takes an `IncomingMessage` argument, and `QueueMessageToSend` takes an `OutgoingMessage` argument.
* In Python, `on_message_received` now takes an `IncomingMessage` argument, and `queue_message_to_send` takes an `OutgoingMessage` argument.
* Automatic stepping for Academy is now controlled from the AutomaticSteppingEnabled property.
### Steps to Migrate
* Add the `using MLAgents.Sensors;` in addition to `using MLAgents;` on top of your Agent's script.

* We strongly recommend replacing the following methods with their new equivalents, as they will be removed in a later release:
* `InitializeAgent()` to `Initialize()`
* `AgentAction()` to `OnActionReceived()`
* `AgentReset()` to `OnEpisodeBegin()`
* Replace calls to Academy.EnableAutomaticStepping()/DisableAutomaticStepping() with Academy.AutomaticSteppingEnabled = true/false.
## Migrating from 0.13 to 0.14

docs/Python-API.md (6 changes)


```python
from mlagents_envs.environment import UnityEnvironment
env = UnityEnvironment(file_name="3DBall", base_port=5005, seed=1, side_channels=[])
env = UnityEnvironment(file_name="3DBall", seed=1, side_channels=[])
```
- `file_name` is the name of the environment binary (located in the root

channel = EngineConfigurationChannel()
env = UnityEnvironment(base_port = UnityEnvironment.DEFAULT_EDITOR_PORT, side_channels = [channel])
env = UnityEnvironment(side_channels=[channel])
channel.set_configuration_parameters(time_scale = 2.0)

channel = FloatPropertiesChannel()
env = UnityEnvironment(base_port = UnityEnvironment.DEFAULT_EDITOR_PORT, side_channels = [channel])
env = UnityEnvironment(side_channels=[channel])
channel.set_property("parameter_1", 2.0)
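Putting the new defaults together, an illustrative example (not from the docs) of the two common cases: with no `file_name` the Editor port (5004) is used, and with a binary the base environment port (5005) is used, as described in the CHANGELOG entry above.

```python
from mlagents_envs.environment import UnityEnvironment

# No file_name: connect to the Unity Editor, defaulting to DEFAULT_EDITOR_PORT (5004).
editor_env = UnityEnvironment(side_channels=[])

# With a binary: defaults to BASE_ENVIRONMENT_PORT (5005), offset by worker_id.
built_env = UnityEnvironment(file_name="3DBall", side_channels=[])
```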

gym-unity/gym_unity/envs/__init__.py (2 changes)


:param no_graphics: Whether to run the Unity simulator in no-graphics mode
:param allow_multiple_visual_obs: If True, return a list of visual observations instead of only one.
"""
base_port = 5005
base_port = UnityEnvironment.BASE_ENVIRONMENT_PORT
if environment_filename is None:
base_port = UnityEnvironment.DEFAULT_EDITOR_PORT

ml-agents-envs/mlagents_envs/communicator.py (2 changes)


"""
Python side of the communication. Must be used together with the corresponding Unity-side Communicator.
:int worker_id: Offset from base_port. Used for training multiple environments simultaneously.
:int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios.
"""
def initialize(self, inputs: UnityInputProto) -> UnityOutputProto:

ml-agents-envs/mlagents_envs/environment.py (15 changes)


# isn't specified, this port will be used.
DEFAULT_EDITOR_PORT = 5004
# Default base port for environments. Each environment will be offset from this
# by its worker_id.
BASE_ENVIRONMENT_PORT = 5005
# Command line argument used to pass the port to the executable environment.
PORT_COMMAND_LINE_ARG = "--mlagents-port"

worker_id: int = 0,
base_port: int = 5005,
base_port: Optional[int] = None,
seed: int = 0,
docker_training: bool = False,
no_graphics: bool = False,

:string file_name: Name of Unity environment binary.
:int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
:int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios.
If no environment is specified (i.e. file_name is None), the DEFAULT_EDITOR_PORT will be used.
:int worker_id: Offset from base_port. Used for training multiple environments simultaneously.
:bool docker_training: Informs this class whether the process is being run within a container.
:bool no_graphics: Whether to run the Unity simulator in no-graphics mode
:int timeout_wait: Time (in seconds) to wait for connection from environment.
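As a sketch of how `worker_id` keeps parallel environments from colliding (each instance listens on `base_port + worker_id`, per the constructor logic shown below); the binary name and count are illustrative:

```python
from mlagents_envs.environment import UnityEnvironment

# Three copies of the same executable, each on its own port
# (5005, 5006, and 5007 with the default base_port).
envs = [
    UnityEnvironment(file_name="3DBall", worker_id=i, side_channels=[])
    for i in range(3)
]
```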

args = args or []
atexit.register(self._close)
# If base port is not specified, use BASE_ENVIRONMENT_PORT if we have
# an environment, otherwise DEFAULT_EDITOR_PORT
if base_port is None:
base_port = (
self.BASE_ENVIRONMENT_PORT if file_name else self.DEFAULT_EDITOR_PORT
)
self.port = base_port + worker_id
self._buffer_size = 12000
# If true, this means the environment was successfully loaded

ml-agents-envs/mlagents_envs/mock_communicator.py (3 changes)


):
"""
Python side of the gRPC communication. Python is the client and Unity the server.
:int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
:int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios.
"""
super().__init__()
self.is_discrete = discrete_action

ml-agents-envs/mlagents_envs/rpc_communicator.py (3 changes)


:int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this.
:int worker_id: Number to add to communication port (5005) [0]. Used for asynchronous agent scenarios.
:int worker_id: Offset from base_port. Used for training multiple environments simultaneously.
:int timeout_wait: Timeout (in seconds) to wait for a response before exiting.
"""
super().__init__(worker_id, base_port)
self.port = base_port + worker_id

ml-agents-envs/mlagents_envs/tests/test_envs.py (23 changes)


env.close()
@pytest.mark.parametrize(
"base_port,file_name,expected",
[
# Non-None base port value will always be used
(6001, "foo.exe", 6001),
# No port specified and environment specified, so use BASE_ENVIRONMENT_PORT
(None, "foo.exe", UnityEnvironment.BASE_ENVIRONMENT_PORT),
# No port specified and no environment, so use DEFAULT_EDITOR_PORT
(None, None, UnityEnvironment.DEFAULT_EDITOR_PORT),
],
)
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
def test_port_defaults(
mock_communicator, mock_launcher, base_port, file_name, expected
):
mock_communicator.return_value = MockCommunicator(
discrete_action=False, visual_inputs=0
)
env = UnityEnvironment(file_name=file_name, worker_id=0, base_port=base_port)
assert expected == env.port
@mock.patch("mlagents_envs.environment.UnityEnvironment.executable_launcher")
@mock.patch("mlagents_envs.environment.UnityEnvironment.get_communicator")
def test_reset(mock_communicator, mock_launcher):

ml-agents/mlagents/trainers/agent_processor.py (31 changes)


import sys
from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Set
from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Any
from collections import defaultdict, Counter, deque
from mlagents_envs.base_env import BatchedStepResult, StepResult

for _entropy in take_action_outputs["entropy"]:
self.stats_reporter.add_stat("Policy/Entropy", _entropy)
terminated_agents: Set[str] = set()
# Make unique agent_ids that are global across workers
action_global_agent_ids = [
get_global_agent_id(worker_id, ag_id) for ag_id in previous_action.agent_ids

stored_take_action_outputs = self.last_take_action_outputs.get(
global_id, None
)
if stored_agent_step is not None and stored_take_action_outputs is not None:
# We know the step is from the same worker, so use the local agent id.
obs = stored_agent_step.obs

traj_queue.put(trajectory)
self.experience_buffers[global_id] = []
if curr_agent_step.done:
# Record episode length for agents which have taken at least 1 step.
# Agents that are done immediately after a reset are ignored.
terminated_agents.add(global_id)
elif not curr_agent_step.done:
self.episode_steps[global_id] += 1

batched_step_result.agent_id_to_index[_id],
)
for terminated_id in terminated_agents:
self._clean_agent_data(terminated_id)
# Delete all done agents, regardless of whether they had a 0-length episode.
if curr_agent_step.done:
self._clean_agent_data(global_id)
for _gid in action_global_agent_ids:
# If the ID doesn't have a last step result, the agent just reset,

"""
Removes the data for an Agent.
"""
del self.experience_buffers[global_id]
del self.last_take_action_outputs[global_id]
del self.last_step_result[global_id]
del self.episode_steps[global_id]
del self.episode_rewards[global_id]
self._safe_delete(self.experience_buffers, global_id)
self._safe_delete(self.last_take_action_outputs, global_id)
self._safe_delete(self.last_step_result, global_id)
self._safe_delete(self.episode_steps, global_id)
self._safe_delete(self.episode_rewards, global_id)
def _safe_delete(self, my_dictionary: Dict[Any, Any], key: Any) -> None:
"""
Safely removes a key from a dictionary. Does nothing if the key is not present.
"""
if key in my_dictionary:
del my_dictionary[key]
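The `_safe_delete` helper above behaves like `dict.pop(key, None)`. A tiny sketch of the difference it makes, using a hypothetical dictionary for illustration only:

```python
episode_steps = {"agent-0": 12}

# Unconditional delete raises KeyError for an agent that was never registered:
# del episode_steps["agent-1"]

# Safe delete removes the key only if present, otherwise does nothing:
episode_steps.pop("agent-1", None)
```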
def publish_trajectory_queue(
self, trajectory_queue: "AgentManagerQueue[Trajectory]"

ml-agents/mlagents/trainers/distributions.py (36 changes)


act_size: List[int],
reparameterize: bool = False,
tanh_squash: bool = False,
condition_sigma: bool = True,
log_sigma_min: float = -20,
log_sigma_max: float = 2,
):

:param log_sigma_max: Maximum log standard deviation to clip by.
"""
encoded = self._create_mu_log_sigma(
logits, act_size, log_sigma_min, log_sigma_max
logits,
act_size,
log_sigma_min,
log_sigma_max,
condition_sigma=condition_sigma,
)
self._sampled_policy = self._create_sampled_policy(encoded)
if not reparameterize:

act_size: List[int],
log_sigma_min: float,
log_sigma_max: float,
condition_sigma: bool,
) -> "GaussianDistribution.MuSigmaTensors":
mu = tf.layers.dense(

reuse=tf.AUTO_REUSE,
)
# Policy-dependent log_sigma_sq
log_sigma = tf.layers.dense(
logits,
act_size[0],
activation=None,
name="log_std",
kernel_initializer=ModelUtils.scaled_init(0.01),
)
if condition_sigma:
# Policy-dependent log_sigma_sq
log_sigma = tf.layers.dense(
logits,
act_size[0],
activation=None,
name="log_std",
kernel_initializer=ModelUtils.scaled_init(0.01),
)
else:
log_sigma = tf.get_variable(
"log_std",
[act_size[0]],
dtype=tf.float32,
initializer=tf.zeros_initializer(),
)
log_sigma = tf.clip_by_value(log_sigma, log_sigma_min, log_sigma_max)
sigma = tf.exp(log_sigma)
return self.MuSigmaTensors(mu, log_sigma, sigma)
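A numpy sketch (illustrative only, not the TF graph above) of what `condition_sigma` toggles: a state-dependent log standard deviation computed from the encoded observation versus a single learned vector shared across all states, clipped and exponentiated in both cases.

```python
import numpy as np

LOG_SIGMA_MIN, LOG_SIGMA_MAX = -20.0, 2.0


def sigma_from_encoding(encoding, weights, bias, shared_log_sigma, condition_sigma=True):
    """Return per-action standard deviations under either parameterization."""
    if condition_sigma:
        # State-dependent: log sigma is a linear function of the encoded observation.
        log_sigma = encoding @ weights + bias
    else:
        # State-independent: one learned vector, broadcast across the batch.
        log_sigma = np.broadcast_to(
            shared_log_sigma, (encoding.shape[0], len(shared_log_sigma))
        )
    log_sigma = np.clip(log_sigma, LOG_SIGMA_MIN, LOG_SIGMA_MAX)
    return np.exp(log_sigma)
```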

"""
Adjust probabilities for squashed sample before output
"""
probs -= tf.log(1 - squashed_policy ** 2 + EPSILON)
return probs
adjusted_probs = probs - tf.log(1 - squashed_policy ** 2 + EPSILON)
return adjusted_probs
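For reference, the adjustment above is the standard change-of-variables correction for a tanh-squashed Gaussian. With `u` the pre-squash sample and `a = tanh(u)`:

```
log pi(a) = log N(u; mu, sigma) - sum_i log(1 - tanh(u_i)^2 + EPSILON)
```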
@property
def total_log_probs(self) -> tf.Tensor:

ml-agents/mlagents/trainers/learn.py (2 changes)


)
argparser.add_argument(
"--base-port",
default=5005,
default=UnityEnvironment.BASE_ENVIRONMENT_PORT,
type=int,
help="Base port for environment communication",
)

ml-agents/mlagents/trainers/policy/nn_policy.py (1 change)


self.act_size,
reparameterize=reparameterize,
tanh_squash=tanh_squash,
condition_sigma=condition_sigma_on_obs,
)
if tanh_squash:

ml-agents/mlagents/trainers/policy/tf_policy.py (19 changes)


if batched_step_result.n_agents() == 0:
return ActionInfo.empty()
agents_done = [
agent
for agent, done in zip(
batched_step_result.agent_id, batched_step_result.done
)
if done
]
self.remove_memories(agents_done)
self.remove_previous_action(agents_done)
global_agent_ids = [
get_global_agent_id(worker_id, int(agent_id))
for agent_id in batched_step_result.agent_id

def create_input_placeholders(self):
with self.graph.as_default():
self.global_step, self.increment_step_op, self.steps_to_increment = (
ModelUtils.create_global_steps()
)
(
self.global_step,
self.increment_step_op,
self.steps_to_increment,
) = ModelUtils.create_global_steps()
self.visual_in = ModelUtils.create_visual_input_placeholders(
self.brain.camera_resolutions
)

ml-agents/mlagents/trainers/sac/trainer.py (1 change)


"memory_size",
"model_path",
"reward_signals",
"vis_encode_type",
]
self._check_param_keys()

ml-agents/mlagents/trainers/tests/simple_test_envs.py (78 changes)


return max(min_val, min(x, max_val))
class Simple1DEnvironment(BaseEnv):
class SimpleEnvironment(BaseEnv):
"""
Very simple "game" - the agent has a position on [-1, 1], gets a reward of 1 if it reaches 1, and a reward of -1 if
it reaches -1. The position is incremented by the action amount (clamped to [-step_size, step_size]).

num_vector=1,
vis_obs_size=VIS_OBS_SIZE,
vec_obs_size=OBS_SIZE,
action_size=1,
):
super().__init__()
self.discrete = use_discrete

self.vec_obs_size = vec_obs_size
action_type = ActionType.DISCRETE if use_discrete else ActionType.CONTINUOUS
self.group_spec = AgentGroupSpec(
self._make_obs_spec(), action_type, (2,) if use_discrete else 1
self._make_obs_spec(),
action_type,
tuple(2 for _ in range(action_size)) if use_discrete else action_size,
self.action_size = action_size
self.position: Dict[str, float] = {}
self.positions: Dict[str, List[float]] = {}
self.step_count: Dict[str, float] = {}
self.random = random.Random(str(self.group_spec))
self.goal: Dict[str, int] = {}

return self.step_result[name]
def _take_action(self, name: str) -> bool:
deltas = []
for _act in self.action[name][0]:
if self.discrete:
deltas.append(1 if _act else -1)
else:
deltas.append(_act)
for i, _delta in enumerate(deltas):
_delta = clamp(_delta, -self.step_size, self.step_size)
self.positions[name][i] += _delta
self.positions[name][i] = clamp(self.positions[name][i], -1, 1)
self.step_count[name] += 1
# All components must reach +/- 1.0 for the episode to be done
done = all(pos >= 1.0 or pos <= -1.0 for pos in self.positions[name])
return done
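A small worked example (values are illustrative) of the multi-dimensional `_take_action` logic above, with `step_size=0.5` and `action_size=2`:

```python
def clamp(x, lo, hi):
    return max(lo, min(x, hi))

step_size = 0.5
positions = [0.0, 0.0]
for action in [(1.0, -0.2), (1.0, -1.0), (0.3, -1.0)]:
    deltas = [clamp(a, -step_size, step_size) for a in action]
    positions = [clamp(p + d, -1, 1) for p, d in zip(positions, deltas)]
    done = all(abs(p) >= 1.0 for p in positions)
    print(positions, done)
# -> [0.5, -0.2] False, then [1.0, -0.7] False, then [1.0, -1.0] True
```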
def _generate_mask(self):
act = self.action[name][0][0]
delta = 1 if act else -1
# LL-Python API will return an empty dim if there is only 1 agent.
ndmask = np.array(2 * self.action_size * [False], dtype=np.bool)
ndmask = np.expand_dims(ndmask, axis=0)
action_mask = [ndmask]
delta = self.action[name][0][0]
delta = clamp(delta, -self.step_size, self.step_size)
self.position[name] += delta
self.position[name] = clamp(self.position[name], -1, 1)
self.step_count[name] += 1
done = self.position[name] >= 1.0 or self.position[name] <= -1.0
return done
action_mask = None
return action_mask
reward = SUCCESS_REWARD * self.position[name] * self.goal[name]
reward = 0.0
for _pos in self.positions[name]:
reward += (SUCCESS_REWARD * _pos * self.goal[name]) / len(
self.positions[name]
)
def _reset_agent(self, name):
self.goal[name] = self.random.choice([-1, 1])
self.positions[name] = [0.0 for _ in range(self.action_size)]
self.step_count[name] = 0
self.final_rewards[name].append(self.rewards[name])
self.rewards[name] = 0
self.agent_id[name] = self.agent_id[name] + 1
def _make_batched_step(
self, name: str, done: bool, reward: float

self.rewards[name] += reward
self.step_result[name] = self._make_batched_step(name, done, reward)
def _generate_mask(self):
if self.discrete:
# LL-Python API will return an empty dim if there is only 1 agent.
ndmask = np.array(2 * [False], dtype=np.bool)
ndmask = np.expand_dims(ndmask, axis=0)
action_mask = [ndmask]
else:
action_mask = None
return action_mask
def _reset_agent(self, name):
self.goal[name] = self.random.choice([-1, 1])
self.position[name] = 0.0
self.step_count[name] = 0
self.final_rewards[name].append(self.rewards[name])
self.rewards[name] = 0
self.agent_id[name] = self.agent_id[name] + 1
def reset(self) -> None: # type: ignore
for name in self.names:
self._reset_agent(name)

pass
class Memory1DEnvironment(Simple1DEnvironment):
class MemoryEnvironment(SimpleEnvironment):
def __init__(self, brain_names, use_discrete, step_size=0.2):
super().__init__(brain_names, use_discrete, step_size=step_size)
# Number of steps to reveal the goal for. Lower is harder. Should be

)
class Record1DEnvironment(Simple1DEnvironment):
class RecordEnvironment(SimpleEnvironment):
def __init__(
self,
brain_names,

ml-agents/mlagents/trainers/tests/test_agent_processor.py (9 changes)


assert len(processor.last_take_action_outputs.keys()) == 0
assert len(processor.episode_steps.keys()) == 0
assert len(processor.episode_rewards.keys()) == 0
assert len(processor.last_step_result.keys()) == 0
# check that steps with immediate dones don't add to dicts
processor.add_experiences(mock_done_step, 0, ActionInfo.empty())
assert len(processor.experience_buffers.keys()) == 0
assert len(processor.last_take_action_outputs.keys()) == 0
assert len(processor.episode_steps.keys()) == 0
assert len(processor.episode_rewards.keys()) == 0
assert len(processor.last_step_result.keys()) == 0
def test_end_episode():

ml-agents/mlagents/trainers/tests/test_meta_curriculum.py (4 changes)


import json
import yaml
from mlagents.trainers.tests.simple_test_envs import Simple1DEnvironment
from mlagents.trainers.tests.simple_test_envs import SimpleEnvironment
from mlagents.trainers.tests.test_simple_rl import _check_environment_trains, BRAIN_NAME
from mlagents.trainers.tests.test_curriculum import dummy_curriculum_json_str

@pytest.mark.parametrize("curriculum_brain_name", [BRAIN_NAME, "WrongBrainName"])
def test_simple_metacurriculum(curriculum_brain_name):
env = Simple1DEnvironment([BRAIN_NAME], use_discrete=False)
env = SimpleEnvironment([BRAIN_NAME], use_discrete=False)
curriculum_config = json.loads(dummy_curriculum_json_str)
mc = MetaCurriculum({curriculum_brain_name: curriculum_config})
trainer_config = yaml.safe_load(TRAINER_CONFIG)

ml-agents/mlagents/trainers/tests/test_simple_rl.py (58 changes)


from typing import Dict, Any
from mlagents.trainers.tests.simple_test_envs import (
Simple1DEnvironment,
Memory1DEnvironment,
Record1DEnvironment,
SimpleEnvironment,
MemoryEnvironment,
RecordEnvironment,
)
from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.trainer_util import TrainerFactory

@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_ppo(use_discrete):
env = Simple1DEnvironment([BRAIN_NAME], use_discrete=use_discrete)
env = SimpleEnvironment([BRAIN_NAME], use_discrete=use_discrete)
config = generate_config(PPO_CONFIG)
_check_environment_trains(env, config)
@pytest.mark.parametrize("use_discrete", [True, False])
def test_2d_ppo(use_discrete):
env = SimpleEnvironment(
[BRAIN_NAME], use_discrete=use_discrete, action_size=2, step_size=0.5
)
config = generate_config(PPO_CONFIG)
_check_environment_trains(env, config)

def test_visual_ppo(num_visual, use_discrete):
env = Simple1DEnvironment(
env = SimpleEnvironment(
[BRAIN_NAME],
use_discrete=use_discrete,
num_visual=num_visual,

@pytest.mark.parametrize("num_visual", [1, 2])
@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn"])
def test_visual_advanced_ppo(vis_encode_type, num_visual):
env = Simple1DEnvironment(
env = SimpleEnvironment(
[BRAIN_NAME],
use_discrete=True,
num_visual=num_visual,

@pytest.mark.parametrize("use_discrete", [True, False])
def test_recurrent_ppo(use_discrete):
env = Memory1DEnvironment([BRAIN_NAME], use_discrete=use_discrete)
env = MemoryEnvironment([BRAIN_NAME], use_discrete=use_discrete)
"max_steps": 3000,
"max_steps": 4000,
"learning_rate": 1e-3,
_check_environment_trains(env, config)
_check_environment_trains(env, config, success_threshold=0.9)
env = Simple1DEnvironment([BRAIN_NAME], use_discrete=use_discrete)
env = SimpleEnvironment([BRAIN_NAME], use_discrete=use_discrete)
def test_2d_sac(use_discrete):
env = SimpleEnvironment(
[BRAIN_NAME], use_discrete=use_discrete, action_size=2, step_size=0.5
)
override_vals = {"buffer_init_steps": 2000, "max_steps": 3000}
config = generate_config(SAC_CONFIG, override_vals)
_check_environment_trains(env, config)
@pytest.mark.parametrize("use_discrete", [True, False])
env = Simple1DEnvironment(
env = SimpleEnvironment(
[BRAIN_NAME],
use_discrete=use_discrete,
num_visual=num_visual,

@pytest.mark.parametrize("num_visual", [1, 2])
@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn"])
def test_visual_advanced_sac(vis_encode_type, num_visual):
env = Simple1DEnvironment(
env = SimpleEnvironment(
[BRAIN_NAME],
use_discrete=True,
num_visual=num_visual,

@pytest.mark.parametrize("use_discrete", [True, False])
def test_recurrent_sac(use_discrete):
env = Memory1DEnvironment([BRAIN_NAME], use_discrete=use_discrete)
env = MemoryEnvironment([BRAIN_NAME], use_discrete=use_discrete)
override_vals = {"batch_size": 32, "use_recurrent": True, "max_steps": 2000}
config = generate_config(SAC_CONFIG, override_vals)
_check_environment_trains(env, config)

def test_simple_ghost(use_discrete):
env = Simple1DEnvironment(
env = SimpleEnvironment(
[BRAIN_NAME + "?team=0", BRAIN_NAME + "?team=1"], use_discrete=use_discrete
)
override_vals = {

@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_ghost_fails(use_discrete):
env = Simple1DEnvironment(
env = SimpleEnvironment(
[BRAIN_NAME + "?team=0", BRAIN_NAME + "?team=1"], use_discrete=use_discrete
)
# This config should fail because the ghosted policy is never swapped with a competent policy.

@pytest.fixture(scope="session")
def simple_record(tmpdir_factory):
def record_demo(use_discrete, num_visual=0, num_vector=1):
env = Record1DEnvironment(
env = RecordEnvironment(
[BRAIN_NAME],
use_discrete=use_discrete,
num_visual=num_visual,

@pytest.mark.parametrize("trainer_config", [PPO_CONFIG, SAC_CONFIG])
def test_gail(simple_record, use_discrete, trainer_config):
demo_path = simple_record(use_discrete)
env = Simple1DEnvironment([BRAIN_NAME], use_discrete=use_discrete, step_size=0.2)
env = SimpleEnvironment([BRAIN_NAME], use_discrete=use_discrete, step_size=0.2)
override_vals = {
"max_steps": 500,
"behavioral_cloning": {"demo_path": demo_path, "strength": 1.0, "steps": 1000},

@pytest.mark.parametrize("use_discrete", [True, False])
def test_gail_visual_ppo(simple_record, use_discrete):
demo_path = simple_record(use_discrete, num_visual=1, num_vector=0)
env = Simple1DEnvironment(
env = SimpleEnvironment(
[BRAIN_NAME],
num_visual=1,
num_vector=0,

@pytest.mark.parametrize("use_discrete", [True, False])
def test_gail_visual_sac(simple_record, use_discrete):
demo_path = simple_record(use_discrete, num_visual=1, num_vector=0)
env = Simple1DEnvironment(
env = SimpleEnvironment(
[BRAIN_NAME],
num_visual=1,
num_vector=0,

ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py (4 changes)


from mlagents.trainers.env_manager import EnvironmentStep
from mlagents_envs.base_env import BaseEnv
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig
from mlagents.trainers.tests.simple_test_envs import Simple1DEnvironment
from mlagents.trainers.tests.simple_test_envs import SimpleEnvironment
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.tests.test_simple_rl import (
_check_environment_trains,

def simple_env_factory(worker_id, config):
env = Simple1DEnvironment(["1D"], use_discrete=True)
env = SimpleEnvironment(["1D"], use_discrete=True)
return env
