
Merge branch 'self-play-mutex' into soccer-2v1

Branch: asymm-envs
Andrew Cohen, 5 years ago
Commit c7a34413
24 changed files with 426 additions and 50 deletions
  1. Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorSettings.cs (12 lines changed)
  2. com.unity.ml-agents/CHANGELOG.md (3 lines changed)
  3. com.unity.ml-agents/Runtime/Academy.cs (28 lines changed)
  4. com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs (16 lines changed)
  5. com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs (28 lines changed)
  6. com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs (7 lines changed)
  7. config/trainer_config.yaml (8 lines changed)
  8. docs/Getting-Started.md (4 lines changed)
  9. docs/Training-Self-Play.md (4 lines changed)
  10. docs/Using-Tensorboard.md (7 lines changed)
  11. ml-agents-envs/mlagents_envs/environment.py (18 lines changed)
  12. ml-agents/mlagents/trainers/agent_processor.py (21 lines changed)
  13. ml-agents/mlagents/trainers/env_manager.py (10 lines changed)
  14. ml-agents/mlagents/trainers/ghost/controller.py (18 lines changed)
  15. ml-agents/mlagents/trainers/ghost/trainer.py (32 lines changed)
  16. ml-agents/mlagents/trainers/policy/tf_policy.py (4 lines changed)
  17. ml-agents/mlagents/trainers/simple_env_manager.py (6 lines changed)
  18. ml-agents/mlagents/trainers/subprocess_env_manager.py (19 lines changed)
  19. ml-agents/mlagents/trainers/tests/test_agent_processor.py (34 lines changed)
  20. ml-agents/mlagents/trainers/tests/test_simple_rl.py (55 lines changed)
  21. ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py (11 lines changed)
  22. com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs (72 lines changed)
  23. com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs.meta (11 lines changed)
  24. ml-agents-envs/mlagents_envs/side_channel/stats_side_channel.py (48 lines changed)

Project/Assets/ML-Agents/Examples/FoodCollector/Scripts/FoodCollectorSettings.cs (12 lines changed)


using UnityEngine;
using UnityEngine.UI;
using MLAgents;
using MLAgents.SideChannels;
public class FoodCollectorSettings : MonoBehaviour
{

public int totalScore;
public Text scoreText;
StatsSideChannel m_statsSideChannel;
m_statsSideChannel = Academy.Instance.GetSideChannel<StatsSideChannel>();
}
public void EnvironmentReset()

public void Update()
{
scoreText.text = $"Score: {totalScore}";
// Send stats via SideChannel so that they'll appear in TensorBoard.
// These values get averaged every summary_frequency steps, so we don't
// need to send every Update() call.
if ((Time.frameCount % 100) == 0)
{
m_statsSideChannel?.AddStat("TotalScore", totalScore);
}
}
}

com.unity.ml-agents/CHANGELOG.md (3 lines changed)


### Minor Changes
- Format of console output has changed slightly and now matches the name of the model/summary directory. (#3630, #3616)
- Raise the wall in CrawlerStatic scene to prevent Agent from falling off. (#3650)
- Added a feature to allow sending stats from C# environments to TensorBoard (and other python StatsWriters). To do this from your code, use `Academy.Instance.GetSideChannel<StatsSideChannel>().AddStat(key, value)` (#3660)
- Fixed an issue where switching models using `SetModel()` during training would use an excessive amount of memory. (#3664)
- Environment subprocesses now close immediately on timeout or wrong API version. (#3679)
## [0.15.0-preview] - 2020-03-18
### Major Changes

com.unity.ml-agents/Runtime/Academy.cs (28 lines changed)


}
/// <summary>
/// Returns the SideChannel of Type T if there is one registered, or null if none is registered.
/// If there are multiple SideChannels of the same type registered, the returned instance is arbitrary.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public T GetSideChannel<T>() where T: SideChannel
{
return Communicator?.GetSideChannel<T>();
}
/// <summary>
/// Returns all SideChannels of Type T that are registered. Use <see cref="GetSideChannel{T}()"/> if possible,
/// as that does not make any memory allocations.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
public List<T> GetSideChannels<T>() where T: SideChannel
{
if (Communicator == null)
{
// Make sure we return a non-null List.
return new List<T>();
}
return Communicator.GetSideChannels<T>();
}
/// <summary>
/// Disable stepping of the Academy during the FixedUpdate phase. If this is called, the Academy must be
/// stepped manually by the user by calling Academy.EnvironmentStep().
/// </summary>

{
Communicator.RegisterSideChannel(new EngineConfigurationChannel());
Communicator.RegisterSideChannel(floatProperties);
Communicator.RegisterSideChannel(new StatsSideChannel());
// We try to exchange the first message with Python. If this fails, it means
// no Python Process is ready to train the environment. In this case, the
// environment must use Inference.
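
For these messages to arrive, the Python process registers channels with matching UUIDs. A minimal sketch of the Python-side counterpart, assuming the `mlagents_envs` API from this changeset (the channel classes and the `side_channels` constructor argument used here are the ones this diff works with):

```python
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.engine_configuration_channel import (
    EngineConfigurationChannel,
)
from mlagents_envs.side_channel.float_properties_channel import FloatPropertiesChannel
from mlagents_envs.side_channel.stats_side_channel import StatsSideChannel

# Python-side channels whose UUIDs match the C# channels registered above.
env = UnityEnvironment(
    file_name=None,  # None waits for an Editor instance instead of launching a build
    side_channels=[
        FloatPropertiesChannel(),
        EngineConfigurationChannel(),
        StatsSideChannel(),
    ],
)
```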

com.unity.ml-agents/Runtime/Communicator/ICommunicator.cs (16 lines changed)


/// </summary>
/// <param name="sideChannel"> The side channel to be unregistered.</param>
void UnregisterSideChannel(SideChannel sideChannel);
/// <summary>
/// Returns the SideChannel of Type T if there is one registered, or null if none is registered.
/// If there are multiple SideChannels of the same type registered, the returned instance is arbitrary.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
T GetSideChannel<T>() where T : SideChannel;
/// <summary>
/// Returns all SideChannels of Type T that are registered. Use <see cref="GetSideChannel{T}()"/> if possible,
/// as that does not make any memory allocations.
/// </summary>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
List<T> GetSideChannels<T>() where T : SideChannel;
}
}

com.unity.ml-agents/Runtime/Communicator/RpcCommunicator.cs (28 lines changed)


}
}
/// <inheritdoc/>
public T GetSideChannel<T>() where T: SideChannel
{
foreach (var sc in m_SideChannels.Values)
{
if (sc.GetType() == typeof(T))
{
return (T) sc;
}
}
return null;
}
/// <inheritdoc/>
public List<T> GetSideChannels<T>() where T: SideChannel
{
var output = new List<T>();
foreach (var sc in m_SideChannels.Values)
{
if (sc.GetType() == typeof(T))
{
output.Add((T) sc);
}
}
return output;
}
/// <summary>
/// Grabs the messages that the registered side channels will send to Python at the current step
into a single byte array.

com.unity.ml-agents/Runtime/SideChannels/EngineConfigurationChannel.cs (7 lines changed)


/// </summary>
public class EngineConfigurationChannel : SideChannel
{
- private const string k_EngineConfigId = "e951342c-4f7e-11ea-b238-784f4387d1f7";
+ const string k_EngineConfigId = "e951342c-4f7e-11ea-b238-784f4387d1f7";
- /// Initializes the side channel.
+ /// Initializes the side channel. The constructor is internal because only one instance is
+ /// supported at a time, and is created by the Academy.
- public EngineConfigurationChannel()
+ internal EngineConfigurationChannel()
{
ChannelId = new Guid(k_EngineConfigId);
}

config/trainer_config.yaml (8 lines changed)


  time_horizon: 1000
  self_play:
    window: 10
-   play_against_current_best_ratio: 0.2
+   play_against_latest_model_ratio: 0.2
    save_steps: 50000
    swap_steps: 50000
    team_change: 100000

  num_layers: 2
  self_play:
    window: 100
-   play_against_current_best_ratio: 0.2
+   play_against_latest_model_ratio: 0.2
    save_steps: 50000
    swap_steps: 25000
    team_change: 200000

  num_layers: 2
  self_play:
    window: 100
-   play_against_current_best_ratio: 0.2
+   play_against_latest_model_ratio: 0.2
    save_steps: 50000
    swap_steps: 100000
    team_change: 200000

  num_layers: 2
  self_play:
    window: 100
-   play_against_current_best_ratio: 0.2
+   play_against_latest_model_ratio: 0.2
    save_steps: 50000
    swap_steps: 50000
    team_change: 200000
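
For reference: `window` is the number of past snapshots kept for opponent sampling, `play_against_latest_model_ratio` (renamed here from `play_against_current_best_ratio`) is the probability of facing the opponent's latest policy rather than a past snapshot, `save_steps` is the number of trainer steps between snapshots, `swap_steps` the number of ghost steps between opponent swaps, and `team_change` the number of steps between changes of the learning team. The tests later in this diff set the same keys programmatically; a sketch of that pattern, with illustrative values only:

```python
from mlagents.trainers.tests.test_simple_rl import PPO_CONFIG, generate_config

# Illustrative override in the style of the tests below; not a shipped config.
override_vals = {
    "max_steps": 2500,
    "self_play": {
        "window": 10,
        "play_against_latest_model_ratio": 0.5,
        "save_steps": 50000,
        "swap_steps": 50000,
        "team_change": 200000,
    },
}
config = generate_config(PPO_CONFIG, override_vals)
```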

docs/Getting-Started.md (4 lines changed)


Depending on your version of Unity, it may be necessary to change the **Scripting Runtime Version** of your project. This can be done as follows:
- 1. Launch Unity
- 2. On the Projects dialog, choose the **Open** option at the top of the window.
+ 1. Launch Unity Hub
+ 2. On the Projects dialog, choose the **Add** option at the top of the window.
3. Using the file dialog that opens, locate the `Project` folder
within the ML-Agents toolkit project and click **Open**.
4. Go to **Edit** > **Project Settings** > **Player**

docs/Training-Self-Play.md (4 lines changed)


ML-Agents provides the functionality to train both symmetric and asymmetric adversarial games with
[Self-Play](https://openai.com/blog/competitive-self-play/).
- A symmetric game is one in which opposing agents are equal in form, function snd objective. Examples of symmetric games
+ A symmetric game is one in which opposing agents are equal in form, function and objective. Examples of symmetric games
are our Tennis and Soccer example environments. In reinforcement learning, this means both agents have the same observation and
action spaces and learn from the same reward function and so *they can share the same policy*. In asymmetric games,
this is not the case. Examples of asymmetric games are Hide and Seek or Strikers vs Goalie in Soccer. Agents in these

### Swap Steps
- The `swap_steps` parameter corresponds to the number of *ghost steps* (note, not trainer steps) between swapping the opponents policy with a different snapshot.
+ The `swap_steps` parameter corresponds to the number of *ghost steps* (not trainer steps) between swapping the opponent's policy with a different snapshot.
A 'ghost step' refers to a step taken by an agent *that is following a fixed policy and not learning*. The reason for this distinction is that in asymmetric games,
we may have teams with an unequal number of agents e.g. the 2v1 scenario in our Strikers Vs Goalie environment. The team with two agents collects
twice as many agent steps per environment step as the team with one agent. Thus, these two values will need to be distinct to ensure that the same number of opponent swaps occurs for each team.
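
As a back-of-the-envelope illustration (hypothetical numbers, not from this changeset): if both teams should swap opponents the same number of times while the other team learns, the two-agent team needs a `swap_steps` value twice as large, because it accumulates ghost steps twice as fast.

```python
# Hypothetical sanity check of the proportionality described above.
steps_while_fixed = 200000   # steps a team spends as the fixed (ghost) team
desired_swaps = 4            # opponent swaps we want in that period

for team_size in (1, 2):
    ghost_steps = team_size * steps_while_fixed  # each agent contributes a ghost step
    print(f"team of {team_size}: swap_steps = {ghost_steps // desired_swaps}")
# team of 1: swap_steps = 50000
# team of 2: swap_steps = 100000
```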

docs/Using-Tensorboard.md (7 lines changed)


taken between two observations.
* `Losses/Cloning Loss` (BC) - The mean magnitude of the behavioral cloning loss. Corresponds to how well the model imitates the demonstration data.
## Custom Metrics from C#
To get custom metrics from a C# environment into Tensorboard, you can use the StatsSideChannel:
```csharp
var statsSideChannel = Academy.Instance.GetSideChannel<StatsSideChannel>();
statsSideChannel.AddStat("MyMetric", 1.0);
```
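
When training with `mlagents-learn`, these stats are forwarded to TensorBoard automatically. If you drive the environment through the low-level Python API instead, you can read them yourself; a minimal sketch, assuming a built environment at a hypothetical `env_path`:

```python
from mlagents_envs.environment import UnityEnvironment
from mlagents_envs.side_channel.stats_side_channel import StatsSideChannel

env_path = "builds/FoodCollector"  # hypothetical path to your built environment
stats_channel = StatsSideChannel()
env = UnityEnvironment(file_name=env_path, side_channels=[stats_channel])
env.reset()
for _ in range(200):
    env.step()
    # Each entry maps the stat key to (value, aggregation method).
    for key, (value, agg) in stats_channel.get_and_reset_stats().items():
        print(key, value, agg)
env.close()
```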

ml-agents-envs/mlagents_envs/environment.py (18 lines changed)


aca_output = self.send_academy_parameters(rl_init_parameters_in)
aca_params = aca_output.rl_initialization_output
except UnityTimeOutException:
- self._close()
+ self._close(0)
- self._close()
+ self._close(0)
raise UnityEnvironmentException(
f"The communication API version is not compatible between Unity and python. "
f"Python API: {UnityEnvironment.API_VERSION}, Unity API: {unity_communicator_version}.\n "

def executable_launcher(self, file_name, docker_training, no_graphics, args):
launch_string = self.validate_environment_path(file_name)
if launch_string is None:
- self._close()
+ self._close(0)
raise UnityEnvironmentException(
f"Couldn't launch the {file_name} environment. Provided filename does not match any environments."
)

else:
raise UnityEnvironmentException("No Unity environment is loaded.")
- def _close(self):
+ def _close(self, timeout: Optional[int] = None) -> None:
"""
Close the communicator and environment subprocess (if necessary).
:param timeout: [Optional] Number of seconds to wait for the environment to shut down before
force-killing it. Defaults to `self.timeout_wait`.
"""
if timeout is None:
timeout = self.timeout_wait
- self.proc1.wait(timeout=self.timeout_wait)
+ self.proc1.wait(timeout=timeout)
signal_name = self.returncode_to_signal_name(self.proc1.returncode)
signal_name = f" ({signal_name})" if signal_name else ""
return_info = f"Environment shut down with return code {self.proc1.returncode}{signal_name}."
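
`_close(0)` on the error paths above means "no grace period". The underlying pattern is the standard `subprocess` wait-then-kill idiom; a standalone sketch (not the actual `_close` body, whose escalation details fall outside this hunk):

```python
import subprocess

def shutdown(proc: subprocess.Popen, timeout: int) -> None:
    """Wait up to `timeout` seconds for a clean exit, then force-kill."""
    try:
        proc.wait(timeout=timeout)  # timeout=0 checks once and raises immediately
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.wait()
```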

ml-agents/mlagents/trainers/agent_processor.py (21 lines changed)


from collections import defaultdict, Counter, deque
from mlagents_envs.base_env import BatchedStepResult, StepResult
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
from mlagents.trainers.trajectory import Trajectory, AgentExperience
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.policy import Policy

self.behavior_id
)
self.publish_trajectory_queue(self.trajectory_queue)
def record_environment_stats(
self, env_stats: Dict[str, Tuple[float, StatsAggregationMethod]], worker_id: int
) -> None:
"""
Pass stats from the environment to the StatsReporter.
Depending on the StatsAggregationMethod, either StatsReporter.add_stat or StatsReporter.set_stat is used.
The worker_id is used to determine whether StatsReporter.set_stat should be used.
:param env_stats:
:param worker_id:
:return:
"""
for stat_name, (val, agg_type) in env_stats.items():
if agg_type == StatsAggregationMethod.AVERAGE:
self.stats_reporter.add_stat(stat_name, val)
elif agg_type == StatsAggregationMethod.MOST_RECENT:
# In order to prevent conflicts between multiple environments,
# only stats from the first environment are recorded.
if worker_id == 0:
self.stats_reporter.set_stat(stat_name, val)
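
For illustration, `env_stats` is a `Dict[str, Tuple[float, StatsAggregationMethod]]`; a hypothetical payload (keys invented for the example):

```python
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod

# "TotalScore" is averaged into the summary period; "Goalie/Saves" keeps only
# the latest value, and only from worker 0, per the branches above.
env_stats = {
    "TotalScore": (12.0, StatsAggregationMethod.AVERAGE),
    "Goalie/Saves": (3.0, StatsAggregationMethod.MOST_RECENT),
}
```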

ml-agents/mlagents/trainers/env_manager.py (10 lines changed)


from abc import ABC, abstractmethod
import logging
- from typing import List, Dict, NamedTuple, Iterable
+ from typing import List, Dict, NamedTuple, Iterable, Tuple
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.agent_processor import AgentManager, AgentManagerQueue

current_all_step_result: AllStepResult
worker_id: int
brain_name_to_action_info: Dict[AgentGroup, ActionInfo]
environment_stats: Dict[str, Tuple[float, StatsAggregationMethod]]
@property
def name_behavior_ids(self) -> Iterable[AgentGroup]:

def empty(worker_id: int) -> "EnvironmentStep":
- return EnvironmentStep({}, worker_id, {})
+ return EnvironmentStep({}, worker_id, {}, {})
class EnvManager(ABC):

step_info.brain_name_to_action_info.get(
name_behavior_id, ActionInfo.empty()
),
)
self.agent_managers[name_behavior_id].record_environment_stats(
step_info.environment_stats, step_info.worker_id
)
return len(step_infos)

ml-agents/mlagents/trainers/ghost/controller.py (18 lines changed)


GhostController contains a queue of team ids. GhostTrainers subscribe to the GhostController and query
it to get the current learning team. The GhostController cycles through team ids every 'swap_interval'
which corresponds to the number of trainer steps between changing learning teams.
The GhostController is a unique object and there can only be one per training run.
"""
def __init__(self, maxlen: int = 10):

self._learning_team: int = -1
# Dict from team id to GhostTrainer for ELO calculation
self._ghost_trainers: Dict[int, GhostTrainer] = {}
@property
def get_learning_team(self) -> int:
"""
Returns the current learning team.
:return: The learning team id
"""
return self._learning_team
def subscribe_team_id(self, team_id: int, trainer: GhostTrainer) -> None:
"""

else:
self._queue.append(team_id)
def get_learning_team(self) -> int:
"""
Returns the current learning team.
:return: The learning team id
"""
return self._learning_team
- def finish_training(self, step: int) -> None:
+ def change_training_team(self, step: int) -> None:
"""
The current learning team is added to the end of the queue and then updated with the
next in line.
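
A standalone sketch of the rotation described here (a hypothetical reproduction, not the `GhostController` class itself):

```python
from collections import deque

team_queue = deque([0, 1])            # subscribed team ids, in subscription order
learning_team = team_queue.popleft()  # team 0 learns first

# change_training_team: the current team goes to the back of the queue
# and the next team in line becomes the learning team.
team_queue.append(learning_team)
learning_team = team_queue.popleft()  # now team 1 is learning
```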

ml-agents/mlagents/trainers/ghost/trainer.py (32 lines changed)


self_play_parameters = trainer_parameters["self_play"]
self.window = self_play_parameters.get("window", 10)
- self.play_against_current_best_ratio = self_play_parameters.get(
-     "play_against_current_best_ratio", 0.5
- )
+ self.play_against_latest_model_ratio = self_play_parameters.get(
+     "play_against_latest_model_ratio", 0.5
+ )
if self.steps_to_train_team > self.get_max_steps:
logger.warning(
"The max steps of the GhostTrainer for behavior name {} is less than team change. This team will not face \
opposition that has been trained if the opposition is managed by a different GhostTrainer as in an \
asymmetric game.".format(
self.brain_name
)
)
# Counts the number of steps of the ghost policies. Snapshot swapping
# depends on this counter whereas snapshot saving and team switching depends

@property
def get_step(self) -> int:
"""
- Returns the number of steps the trainer has performed
- :return: the step count of the trainer
+ Returns the number of steps the wrapped trainer has performed
+ :return: the step count of the wrapped trainer
"""
return self.trainer.get_step

self.next_summary_step = self.trainer.next_summary_step
self.trainer.advance()
- self.controller.finish_training(self.get_step)
+ self.controller.change_training_team(self.get_step)
- next_learning_team = self.controller.get_learning_team()
+ next_learning_team = self.controller.get_learning_team
# CASE 1: Current learning team is managed by this GhostTrainer.
# If the learning team changes, the following loop over queues will push the

policy.create_tf_graph()
self._name_to_parsed_behavior_id[name_behavior_id] = parsed_behavior_id
# for saving/swapping snapshots
policy.init_load_weights()
# First policy or a new agent on the same team encountered
if self.wrapped_trainer_team is None or team_id == self.wrapped_trainer_team:

- self._learning_team = self.controller.get_learning_team()
+ self._learning_team = self.controller.get_learning_team
else:
# for saving/swapping snapshots
policy.init_load_weights()
def get_policy(self, name_behavior_id: str) -> TFPolicy:
"""

for team_id in self._team_to_name_to_policy_queue:
if team_id == self._learning_team:
continue
- elif np.random.uniform() < (1 - self.play_against_current_best_ratio):
+ elif np.random.uniform() < (1 - self.play_against_latest_model_ratio):
self.current_opponent = -1 if x == "current" else x
name_to_policy_queue = self._team_to_name_to_policy_queue[team_id]
for brain_name in self._team_to_name_to_policy_queue[team_id]:
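
The `elif` above implements opponent sampling: with probability `play_against_latest_model_ratio` the opponent queue receives the latest policy, otherwise a snapshot drawn from the window. A simplified sketch of that rule (hypothetical helper; the real trainer also updates `current_opponent` for ELO bookkeeping):

```python
import numpy as np

def sample_opponent(num_snapshots: int, play_against_latest_model_ratio: float):
    """Return 'current' with the configured probability, else a snapshot index."""
    if np.random.uniform() < (1 - play_against_latest_model_ratio):
        return np.random.randint(num_snapshots)  # a past snapshot from the window
    return "current"                             # the latest, still-learning policy

x = sample_opponent(num_snapshots=10, play_against_latest_model_ratio=0.5)
current_opponent = -1 if x == "current" else x   # mirrors the line above
```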

ml-agents/mlagents/trainers/policy/tf_policy.py (4 lines changed)


self.assign_ops.append(tf.assign(var, assign_ph))
def load_weights(self, values):
if len(self.assign_ops) == 0:
logger.warning(
"Calling load_weights in tf_policy but assign_ops is empty. Did you forget to call init_load_weights?"
)
with self.graph.as_default():
feed_dict = {}
for assign_ph, value in zip(self.assign_phs, values):

ml-agents/mlagents/trainers/simple_env_manager.py (6 lines changed)


self.env.step()
all_step_result = self._generate_all_results()
- step_info = EnvironmentStep(all_step_result, 0, self.previous_all_action_info)
+ step_info = EnvironmentStep(
+     all_step_result, 0, self.previous_all_action_info, {}
+ )
self.previous_step = step_info
return [step_info]

self.shared_float_properties.set_property(k, v)
self.env.reset()
all_step_result = self._generate_all_results()
- self.previous_step = EnvironmentStep(all_step_result, 0, {})
+ self.previous_step = EnvironmentStep(all_step_result, 0, {}, {})
return [self.previous_step]
@property

ml-agents/mlagents/trainers/subprocess_env_manager.py (19 lines changed)


import logging
- from typing import Dict, NamedTuple, List, Any, Optional, Callable, Set
+ from typing import Dict, NamedTuple, List, Any, Optional, Callable, Set, Tuple
import cloudpickle
from mlagents_envs.environment import UnityEnvironment

EngineConfigurationChannel,
EngineConfig,
)
from mlagents_envs.side_channel.stats_side_channel import (
StatsSideChannel,
StatsAggregationMethod,
)
from mlagents_envs.side_channel.side_channel import SideChannel
from mlagents.trainers.brain_conversion_utils import group_spec_to_brain_parameters

class StepResponse(NamedTuple):
all_step_result: AllStepResult
timer_root: Optional[TimerNode]
environment_stats: Dict[str, Tuple[float, StatsAggregationMethod]]
class UnityEnvWorker:

shared_float_properties = FloatPropertiesChannel()
engine_configuration_channel = EngineConfigurationChannel()
engine_configuration_channel.set_configuration(engine_configuration)
+ stats_channel = StatsSideChannel()
- worker_id, [shared_float_properties, engine_configuration_channel]
+ worker_id,
+ [shared_float_properties, engine_configuration_channel, stats_channel],
)
def _send_response(cmd_name, payload):

# Note that we could randomly return timers a fraction of the time if we wanted to reduce
# the data transferred.
# TODO get gauges from the workers and merge them in the main process too.
- step_response = StepResponse(all_step_result, get_timer_root())
+ env_stats = stats_channel.get_and_reset_stats()
+ step_response = StepResponse(
+     all_step_result, get_timer_root(), env_stats
+ )
step_queue.put(EnvironmentResponse("step", worker_id, step_response))
reset_timers()
elif cmd.name == "external_brains":

ew.send("reset", config)
# Next (synchronously) collect the reset observations from each worker in sequence
for ew in self.env_workers:
- ew.previous_step = EnvironmentStep(ew.recv().payload, ew.worker_id, {})
+ ew.previous_step = EnvironmentStep(ew.recv().payload, ew.worker_id, {}, {})
return list(map(lambda ew: ew.previous_step, self.env_workers))
@property

payload.all_step_result,
step.worker_id,
env_worker.previous_all_action_info,
payload.environment_stats,
)
step_infos.append(new_step)
env_worker.previous_step = new_step

ml-agents/mlagents/trainers/tests/test_agent_processor.py (34 lines changed)


)
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.trajectory import Trajectory
- from mlagents.trainers.stats import StatsReporter
+ from mlagents.trainers.stats import StatsReporter, StatsSummary
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
def create_mock_brain():

queue_traj = queue.get_nowait()
assert isinstance(queue_traj, Trajectory)
assert queue.empty()
def test_agent_manager_stats():
policy = mock.Mock()
stats_reporter = StatsReporter("FakeCategory")
writer = mock.Mock()
stats_reporter.add_writer(writer)
manager = AgentManager(policy, "MyBehavior", stats_reporter)
all_env_stats = [
{
"averaged": (1.0, StatsAggregationMethod.AVERAGE),
"most_recent": (2.0, StatsAggregationMethod.MOST_RECENT),
},
{
"averaged": (3.0, StatsAggregationMethod.AVERAGE),
"most_recent": (4.0, StatsAggregationMethod.MOST_RECENT),
},
]
for env_stats in all_env_stats:
manager.record_environment_stats(env_stats, worker_id=0)
expected_stats = {
"averaged": StatsSummary(mean=2.0, std=mock.ANY, num=2),
"most_recent": StatsSummary(mean=4.0, std=0.0, num=1),
}
stats_reporter.write_stats(123)
writer.write_stats.assert_any_call("FakeCategory", expected_stats, 123)
# clean up our Mock from the global list
StatsReporter.writers.remove(writer)

ml-agents/mlagents/trainers/tests/test_simple_rl.py (55 lines changed)


override_vals = {
"max_steps": 2500,
"self_play": {
"play_against_current_self_ratio": 1.0,
"play_against_latest_model_ratio": 1.0,
"save_steps": 2000,
"swap_steps": 2000,
},

override_vals = {
"max_steps": 2500,
"self_play": {
"play_against_current_self_ratio": 1.0,
"play_against_latest_model_ratio": 1.0,
_check_environment_trains(env, config, success_threshold=None)
processed_rewards = [
default_reward_processor(rewards) for rewards in env.final_rewards.values()
]
success_threshold = 0.99
assert any(reward > success_threshold for reward in processed_rewards) and any(
reward < success_threshold for reward in processed_rewards
)
@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_asymm_ghost(use_discrete):
# Make opponent for asymmetric case
brain_name_opp = BRAIN_NAME + "Opp"
env = SimpleEnvironment(
[BRAIN_NAME + "?team=0", brain_name_opp + "?team=1"], use_discrete=use_discrete
)
override_vals = {
"max_steps": 2000,
"self_play": {
"play_against_latest_model_ratio": 1.0,
"save_steps": 5000,
"swap_steps": 5000,
"team_change": 2000,
},
}
config = generate_config(PPO_CONFIG, override_vals)
config[brain_name_opp] = config[BRAIN_NAME]
_check_environment_trains(env, config)
@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_asymm_ghost_fails(use_discrete):
# Make opponent for asymmetric case
brain_name_opp = BRAIN_NAME + "Opp"
env = SimpleEnvironment(
[BRAIN_NAME + "?team=0", brain_name_opp + "?team=1"], use_discrete=use_discrete
)
# This config should fail because the team that is not learning when both have reached
# max step should be executing the initial, untrained policy.
override_vals = {
"max_steps": 2000,
"self_play": {
"play_against_latest_model_ratio": 0.0,
"save_steps": 5000,
"swap_steps": 5000,
"team_change": 2000,
},
}
config = generate_config(PPO_CONFIG, override_vals)
config[brain_name_opp] = config[BRAIN_NAME]
_check_environment_trains(env, config, success_threshold=None)
processed_rewards = [
default_reward_processor(rewards) for rewards in env.final_rewards.values()

ml-agents/mlagents/trainers/tests/test_subprocess_env_manager.py (11 lines changed)


from mlagents.trainers.env_manager import EnvironmentStep
from mlagents_envs.base_env import BaseEnv
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfig
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
from mlagents.trainers.tests.simple_test_envs import SimpleEnvironment
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.tests.test_simple_rl import (

)
manager.step_queue = Mock()
manager.step_queue.get_nowait.side_effect = [
- EnvironmentResponse("step", 0, StepResponse(0, None)),
- EnvironmentResponse("step", 1, StepResponse(1, None)),
+ EnvironmentResponse("step", 0, StepResponse(0, None, {})),
+ EnvironmentResponse("step", 1, StepResponse(1, None, {})),
EmptyQueue(),
]
step_mock = Mock()

env_manager.set_agent_manager(brain_name, agent_manager_mock)
step_info_dict = {brain_name: Mock()}
- step_info = EnvironmentStep(step_info_dict, 0, action_info_dict)
+ env_stats = {
+     "averaged": (1.0, StatsAggregationMethod.AVERAGE),
+     "most_recent": (2.0, StatsAggregationMethod.MOST_RECENT),
+ }
+ step_info = EnvironmentStep(step_info_dict, 0, action_info_dict, env_stats)
step_mock.return_value = [step_info]
env_manager.advance()

com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs (72 lines changed)


using System;
namespace MLAgents.SideChannels
{
/// <summary>
/// Determines the behavior of how multiple stats within the same summary period are combined.
/// </summary>
public enum StatAggregationMethod
{
/// <summary>
/// Values within the summary period are averaged before reporting.
/// Note that values from the same C# environment in the same step may replace each other.
/// </summary>
Average = 0,
/// <summary>
/// Only the most recent value is reported.
/// To avoid conflicts between multiple environments, the ML Agents environment will only
/// keep stats from worker index 0.
/// </summary>
MostRecent = 1
}
/// <summary>
/// Add stats (key-value pairs) for reporting. The ML Agents environment will send these to a StatsReporter
/// instance, which means the values will appear in the Tensorboard summary, as well as trainer gauges.
/// Note that stats are only written every summary_frequency steps; See <see cref="StatAggregationMethod"/>
/// for options on how multiple values are handled.
/// </summary>
public class StatsSideChannel : SideChannel
{
const string k_StatsSideChannelDefaultId = "a1d8f7b7-cec8-50f9-b78b-d3e165a78520";
/// <summary>
/// Initializes the side channel with the provided channel ID.
/// The constructor is internal because only one instance is
/// supported at a time, and is created by the Academy.
/// </summary>
internal StatsSideChannel()
{
ChannelId = new Guid(k_StatsSideChannelDefaultId);
}
/// <summary>
/// Add a stat value for reporting. This will appear in the Tensorboard summary and trainer gauges.
/// You can nest stats in Tensorboard with "/".
/// Note that stats are only written to Tensorboard each summary_frequency steps; if a stat is
/// received multiple times, only the most recent version is used.
/// To avoid conflicts between multiple environments, only stats from worker index 0 are used.
/// </summary>
/// <param name="key">The stat name.</param>
/// <param name="value">The stat value. You can nest stats in Tensorboard by using "/". </param>
/// <param name="aggregationMethod">How multiple values should be treated.</param>
public void AddStat(
string key, float value, StatAggregationMethod aggregationMethod = StatAggregationMethod.Average
)
{
using (var msg = new OutgoingMessage())
{
msg.WriteString(key);
msg.WriteFloat32(value);
msg.WriteInt32((int)aggregationMethod);
QueueMessageToSend(msg);
}
}
/// <inheritdoc/>
public override void OnMessageReceived(IncomingMessage msg)
{
throw new UnityAgentsException("StatsSideChannel should never receive messages.");
}
}
}

com.unity.ml-agents/Runtime/SideChannels/StatsSideChannel.cs.meta (11 lines changed)


fileFormatVersion: 2
guid: 83a07fdb9e8f04536908a51447dfe548
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

ml-agents-envs/mlagents_envs/side_channel/stats_side_channel.py (48 lines changed)


from mlagents_envs.side_channel import SideChannel, IncomingMessage
import uuid
from typing import Dict, Tuple
from enum import Enum
# Determines the behavior of how multiple stats within the same summary period are combined.
class StatsAggregationMethod(Enum):
# Values within the summary period are averaged before reporting.
AVERAGE = 0
# Only the most recent value is reported.
MOST_RECENT = 1
class StatsSideChannel(SideChannel):
"""
Side channel that receives (string, float) pairs from the environment, so that they can eventually
be passed to a StatsReporter.
"""
def __init__(self) -> None:
# >>> uuid.uuid5(uuid.NAMESPACE_URL, "com.unity.ml-agents/StatsSideChannel")
# UUID('a1d8f7b7-cec8-50f9-b78b-d3e165a78520')
super().__init__(uuid.UUID("a1d8f7b7-cec8-50f9-b78b-d3e165a78520"))
self.stats: Dict[str, Tuple[float, StatsAggregationMethod]] = {}
def on_message_received(self, msg: IncomingMessage) -> None:
"""
Receive the message from the environment, and save it for later retrieval.
:param msg:
:return:
"""
key = msg.read_string()
val = msg.read_float32()
agg_type = StatsAggregationMethod(msg.read_int32())
self.stats[key] = (val, agg_type)
def get_and_reset_stats(self) -> Dict[str, Tuple[float, StatsAggregationMethod]]:
"""
Returns the current stats, and resets the internal storage of the stats.
:return:
"""
s = self.stats
self.stats = {}
return s
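
Note that the channel ID is deterministic: it is derived with `uuid5` from a fixed name, which keeps the Python UUID and the C# constant `k_StatsSideChannelDefaultId` in sync. This can be checked directly:

```python
import uuid

# Reproduces the ID used in the constructor above and in StatsSideChannel.cs.
derived = uuid.uuid5(uuid.NAMESPACE_URL, "com.unity.ml-agents/StatsSideChannel")
assert str(derived) == "a1d8f7b7-cec8-50f9-b78b-d3e165a78520"
```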