Comparing commits

...
This merge request has changes that conflict with the target branch.
/docs/Migrating.md
/ml-agents/mlagents/trainers/trainer_controller.py
/ml-agents/mlagents/trainers/stats.py
/ml-agents/mlagents/trainers/ghost/trainer.py
/ml-agents/mlagents/trainers/behavior_id_utils.py
/ml-agents/mlagents/trainers/ppo/trainer.py
/ml-agents/mlagents/trainers/sac/trainer.py
/ml-agents/mlagents/trainers/trainer/trainer.py
/ml-agents/mlagents/trainers/tests/test_stats.py
/ml-agents/mlagents/trainers/ghost/controller.py
/config/trainer_config.yaml
/docs/Training-Self-Play.md
/ml-agents/mlagents/trainers/tests/test_ghost.py
/ml-agents/mlagents/trainers/tests/test_simple_rl.py
/ml-agents/mlagents/trainers/trainer_util.py
/ml-agents/mlagents/trainers/policy/tf_policy.py

5 commits

16 files changed, with 618 insertions and 215 deletions
  1. config/trainer_config.yaml (6 changes)
  2. docs/Training-Self-Play.md (98 changes)
  3. docs/Migrating.md (1 change)
  4. ml-agents/mlagents/trainers/stats.py (18 changes)
  5. ml-agents/mlagents/trainers/behavior_id_utils.py (52 changes)
  6. ml-agents/mlagents/trainers/trainer_controller.py (7 changes)
  7. ml-agents/mlagents/trainers/ppo/trainer.py (7 changes)
  8. ml-agents/mlagents/trainers/sac/trainer.py (5 changes)
  9. ml-agents/mlagents/trainers/tests/test_ghost.py (36 changes)
  10. ml-agents/mlagents/trainers/tests/test_simple_rl.py (55 changes)
  11. ml-agents/mlagents/trainers/tests/test_stats.py (7 changes)
  12. ml-agents/mlagents/trainers/trainer_util.py (6 changes)
  13. ml-agents/mlagents/trainers/policy/tf_policy.py (4 changes)
  14. ml-agents/mlagents/trainers/trainer/trainer.py (5 changes)
  15. ml-agents/mlagents/trainers/ghost/trainer.py (434 changes)
  16. ml-agents/mlagents/trainers/ghost/controller.py (92 changes)

config/trainer_config.yaml (6 changes)


  time_horizon: 1000
  self_play:
    window: 10
    play_against_current_self_ratio: 0.5
    play_against_latest_model_ratio: 0.5
    team_change: 100000

Soccer:
  normalize: false

  num_layers: 2
  self_play:
    window: 10
    play_against_current_self_ratio: 0.5
    play_against_latest_model_ratio: 0.5
    team_change: 100000

CrawlerStatic:
  normalize: true

docs/Training-Self-Play.md (98 changes)


# Training with Self-Play
ML-Agents provides the functionality to train symmetric, adversarial games with [Self-Play](https://openai.com/blog/competitive-self-play/).
A symmetric game is one in which opposing agents are *equal* in form and function. In reinforcement learning,
this means both agents have the same observation and action spaces.
With self-play, an agent learns in adversarial games by competing against fixed, past versions of itself
to provide a more stable, stationary learning environment. This is compared
to competing against its current self in every episode, which is a constantly changing opponent.
ML-Agents provides the functionality to train both symmetric and asymmetric adversarial games with
[Self-Play](https://openai.com/blog/competitive-self-play/).
A symmetric game is one in which opposing agents are equal in form, function and objective. Examples of symmetric games
are our Tennis and Soccer example environments. In reinforcement learning, this means both agents have the same observation and
action spaces and learn from the same reward function and so *they can share the same policy*. In asymmetric games,
this is not the case. An example of an asymmetric game is Hide and Seek. Agents in these
types of games do not always have the same observation or action spaces and so sharing policy networks is not
necessarily ideal.
With self-play, an agent learns in adversarial games by competing against fixed, past versions of its opponent
(which could be itself as in symmetric games) to provide a more stable, stationary learning environment. This is compared
to competing against the current, best opponent in every episode, which is constantly changing (because it's learning).
However, from the perspective of an individual agent, these scenarios appear to have non-stationary dynamics because the opponent is often changing.
This can cause significant issues in the experience replay mechanism used by SAC. Thus, we recommend that users use PPO. For further reading on
this issue in particular, see the paper [Stabilising Experience Replay for Deep Multi-Agent Reinforcement Learning](https://arxiv.org/pdf/1702.08887.pdf).
For more general information on training with ML-Agents, see [Training ML-Agents](Training-ML-Agents.md).
For more algorithm specific instruction, please see the documentation for [PPO](Training-PPO.md) or [SAC](Training-SAC.md).

See the trainer configuration and agent prefabs for our Tennis environment for an example.
***Team ID must be 0 or an integer greater than 0.***
In symmetric games, since all agents (even on opposing teams) will share the same policy, they should have the same 'Behavior Name' in their
Behavior Parameters Script. In asymmetric games, they should have a different Behavior Name in their Behavior Parameters script.
Note, in asymmetric games, the agents must have both different Behavior Names *and* different team IDs! Then, specify the trainer configuration
for each Behavior Name in your scene as you would normally, and remember to include the self-play hyperparameter hierarchy!
For examples of how to use this feature, you can see the trainer configurations and agent prefabs for our Tennis and Soccer environments.
Tennis and Soccer provide examples of symmetric games. To train an asymmetric game, specify trainer configurations for each of your behavior names
and include the self-play hyperparameter hierarchy in both.
## Best Practices Training with Self-Play

Training against a set of slowly changing or unchanging adversaries with low diversity
results in a more stable learning process than training against a set of quickly
changing adversaries with high diversity. With this context, this guide discusses
the exposed self-play hyperparameters and intuitions for tuning them.
## Hyperparameters

### Save Steps
The `save_steps` parameter corresponds to the number of *trainer steps* between snapshots. For example, if `save_steps=10000` then a snapshot of the current policy will be saved every `10000` trainer steps. Note, trainer steps are counted per agent. For more information, please see the [migration doc](Migrating.md) after v0.13.
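As a rough back-of-envelope sketch of how `save_steps` interacts with the run length and the snapshot window (the `max_steps` and `window` values here are hypothetical, not taken from this PR):

```python
# Rough snapshot arithmetic, assuming hypothetical values for max_steps and window.
max_steps = 50000000  # total trainer steps for the run (hypothetical)
save_steps = 50000    # snapshot the current policy every 50000 trainer steps
window = 10           # how many past snapshots are retained (see Window below)

snapshots_taken = max_steps // save_steps          # 1000 snapshots saved over the run
snapshots_retained = min(snapshots_taken, window)  # only the most recent 10 are kept
```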
### Team Change
The `team_change` parameter corresponds to the number of *trainer steps* between switching the learning team.
This is the number of trainer steps the teams associated with a specific ghost trainer will train before a different team
becomes the new learning team. It is possible that, in asymmetric games, opposing teams require fewer trainer steps to make similar
performance gains. This enables users to train a more complicated team of agents for more trainer steps than a simpler team of agents
per team switch.
A larger value of `team_change` will allow the agent to train longer against its opponents. The longer an agent trains against the same set of opponents,
the more able it will be to defeat them. However, training against them for too long may result in overfitting to the particular opponent strategies,
and so the agent may fail against the next batch of opponents.
The value of `team_change` determines how many snapshots of the agent's policy are saved to be used as opponents for the other team, so we
recommend setting this value as a function of the `save_steps` parameter discussed previously.
Recommended Range : 4x-10x where x=`save_steps`
### Swap Steps
The `swap_steps` parameter corresponds to the number of *trainer steps* between swapping the opponent's policy with a different snapshot. As in the `save_steps` discussion, note that trainer steps are counted per agent. For more information, please see the [migration doc](Migrating.md) after v0.13.
The `swap_steps` parameter corresponds to the number of *ghost steps* (not trainer steps) between swapping the opponent's policy with a different snapshot.
A 'ghost step' refers to a step taken by an agent *that is following a fixed policy and not learning*. The reason for this distinction is that in asymmetric games,
we may have teams with an unequal number of agents e.g. a 2v1 scenario. The team with two agents collects
twice as many agent steps per environment step as the team with one agent. Thus, these two values will need to be distinct to ensure that the same number
of trainer steps corresponds to the same number of opponent swaps for each team. The formula for `swap_steps` if
a user desires `x` swaps of a team with `num_agents` agents against an opponent team with `num_opponent_agents`
agents during `team_change` total steps is:
```
swap_steps = (num_agents / num_opponent_agents) * (team_change / x)
```
As an example, in a 2v1 scenario, if we want the swap to occur `x=4` times during `team_change=200000` steps,
the `swap_steps` for the team of one agent is:
```
swap_steps = (1 / 2) * (200000 / 4) = 25000
```
The `swap_steps` for the team of two agents is:
```
swap_steps = (2 / 1) * (200000 / 4) = 100000
```
Note, with equal team sizes, the first term is equal to 1 and `swap_steps` can be calculated by just dividing the total steps by the desired number of swaps.
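The worked examples above can be checked with a small helper; this is just a sketch of the formula, not code from this PR:

```python
def compute_swap_steps(num_agents, num_opponent_agents, team_change, x):
    # swap_steps = (num_agents / num_opponent_agents) * (team_change / x)
    return (num_agents / num_opponent_agents) * (team_change / x)

# The 2v1 example above, with x=4 swaps during team_change=200000 steps:
assert compute_swap_steps(1, 2, 200000, 4) == 25000   # team of one agent
assert compute_swap_steps(2, 1, 200000, 4) == 100000  # team of two agents
```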
### Play against current self ratio
### Play against latest model ratio
The `play_against_current_self_ratio` parameter corresponds to the probability
an agent will play against its ***current*** self. With probability
1 - `play_against_current_self_ratio`, the agent will play against a snapshot of itself
from a past iteration.
The `play_against_latest_model_ratio` parameter corresponds to the probability
an agent will play against the latest opponent policy. With probability
1 - `play_against_latest_model_ratio`, the agent will play against a snapshot of its
opponent from a past iteration.
A larger value of `play_against_current_self_ratio` indicates that an agent will be playing against itself more often. Since the agent is updating its policy, the opponent will be different from iteration to iteration. This can lead to an unstable learning environment, but poses the agent with an [auto-curricula](https://openai.com/blog/emergent-tool-use/) of increasingly challenging situations which may lead to a stronger final policy.
A larger value of `play_against_latest_model_ratio` indicates that an agent will be playing against the current opponent more often. Since the agent is updating its policy, the opponent will be different from iteration to iteration. This can lead to an unstable learning environment, but poses the agent with an [auto-curricula](https://openai.com/blog/emergent-tool-use/) of increasingly challenging situations which may lead to a stronger final policy.
Recommended Range : 0.0 - 1.0
Range : 0.0 - 1.0
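Concretely, opponent selection at each swap reduces to a single random draw; a minimal sketch, assuming a hypothetical `snapshots` list of past policies and a `latest` opponent policy (these names are illustrative, not this PR's API):

```python
import numpy as np

def sample_opponent(snapshots, latest, play_against_latest_model_ratio):
    # With probability (1 - ratio), play a randomly chosen past snapshot;
    # otherwise play the latest opponent policy.
    if np.random.uniform() < (1 - play_against_latest_model_ratio):
        return snapshots[np.random.randint(len(snapshots))]
    return latest
```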
### Window

In adversarial games, the cumulative environment reward may not be a meaningful metric by which to track learning progress. This is because cumulative reward is entirely dependent on the skill of the opponent. An agent at a particular skill level will get more or less reward against a worse or better agent, respectively.
We provide an implementation of the ELO rating system, a method for calculating the relative skill level between two players from a given population in a zero-sum game. For more information on ELO, please see [the ELO wiki](https://en.wikipedia.org/wiki/Elo_rating_system).
In a proper training run, the ELO of the agent should steadily increase. The absolute value of the ELO is less important than the change in ELO over training iterations.
Note, this implementation will support any number of teams, but ELO is only applicable to games with two teams. It is ongoing work to implement
a reliable metric for measuring progress in scenarios with three or more teams. These scenarios can still train, though as of now, reward and qualitative observations
are the only metrics by which we can judge performance.
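For reference, a minimal sketch of the standard two-player ELO update this PR uses (mathematically equivalent to the `compute_elo_rating_changes` implementation in `ghost/controller.py` below):

```python
def elo_change(rating_learning, rating_opponent, result):
    # Expected score of the learning team under the logistic ELO model.
    expected = 1.0 / (1.0 + 10 ** ((rating_opponent - rating_learning) / 400))
    # result is 1.0 for a win, 0.5 for a draw, 0.0 for a loss;
    # the learning team's ELO moves by +change, the opponent's by -change.
    return result - expected
```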

docs/Migrating.md (1 change)


### Important changes
* The `--load` and `--train` command-line flags have been deprecated and replaced with `--resume` and `--inference`.
* Running with the same `--run-id` twice will now throw an error.
* The `play_against_current_self_ratio` self-play trainer hyperparameter has been renamed to `play_against_latest_model_ratio`.
### Steps to Migrate
* Replace the `--load` flag with `--resume` when calling `mlagents-learn`, and don't use the `--train` flag, as training now happens by default.

ml-agents/mlagents/trainers/stats.py (18 changes)


class StatsPropertyType(Enum):
    HYPERPARAMETERS = "hyperparameters"
    SELF_PLAY = "selfplay"
    SELF_PLAY_TEAM = "selfplayteam"


class StatsWriter(abc.ABC):

            )
            if self.self_play and "Self-play/ELO" in values:
                elo_stats = values["Self-play/ELO"]
                mean_opponent_elo = values["Self-play/Mean Opponent ELO"]
                std_opponent_elo = values["Self-play/Std Opponent ELO"]
                logger.info(
                    "{} Team {}: ELO: {:0.3f}. "
                    "Mean Opponent ELO: {:0.3f}. "
                    "Std Opponent ELO: {:0.3f}. ".format(
                        category,
                        self.self_play_team,
                        elo_stats.mean,
                        mean_opponent_elo.mean,
                        std_opponent_elo.mean,
                    )
                )
                logger.info("{} ELO: {:0.3f}. ".format(category, elo_stats.mean))
        else:
            logger.info(
                "{}: Step: {}. No episode was completed since last summary. {}".format(

        elif property_type == StatsPropertyType.SELF_PLAY:
            assert isinstance(value, bool)
            self.self_play = value
        elif property_type == StatsPropertyType.SELF_PLAY_TEAM:
            assert isinstance(value, int)
            self.self_play_team = value

    def _dict_to_str(self, param_dict: Dict[str, Any], num_tabs: int) -> str:
        """

ml-agents/mlagents/trainers/behavior_id_utils.py (52 changes)


from typing import Dict, NamedTuple
from typing import NamedTuple
from urllib.parse import urlparse, parse_qs

    name_behavior_id: str
    """
    BehaviorIdentifiers is a named tuple of the identifiers that uniquely distinguish
    an agent encountered in the trainer_controller. The named tuple consists of the
    fully qualified behavior name, the name of the brain name (corresponds to trainer
    in the trainer controller) and the team id. In the future, this can be extended
    to support further identifiers.
    """

    behavior_id: str
    behavior_ids: Dict[str, int]
    team_id: int

        Parses a name_behavior_id of the form name?team=0&param1=i&...
        Parses a name_behavior_id of the form name?team=0
        This allows you to access the brain name and distinguishing identifiers
        without parsing more than once.
        This allows you to access the brain name and team id of an agent
        ids: Dict[str, int] = {}
        if "?" in name_behavior_id:
            name, identifiers = name_behavior_id.rsplit("?", 1)
            if "&" in identifiers:
                list_of_identifiers = identifiers.split("&")
            else:
                list_of_identifiers = [identifiers]
            for identifier in list_of_identifiers:
                key, value = identifier.split("=")
                ids[key] = int(value)
        else:
            name = name_behavior_id
        parsed = urlparse(name_behavior_id)
        name = parsed.path
        ids = parse_qs(parsed.query)
        team_id: int = 0
        if "team" in ids:
            team_id = int(ids["team"][0])
            name_behavior_id=name_behavior_id, brain_name=name, behavior_ids=ids
            behavior_id=name_behavior_id, brain_name=name, team_id=team_id


def create_name_behavior_id(name: str, team_id: int) -> str:
    """
    Reconstructs fully qualified behavior name from name and team_id
    :param name: brain name
    :param team_id: team ID
    :return: name_behavior_id
    """
    return name + "?team=" + str(team_id)
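A quick round-trip of the new parsing, as a sketch using only standard-library calls and the `create_name_behavior_id` helper above:

```python
from urllib.parse import urlparse, parse_qs

name_behavior_id = "Soccer?team=1"
parsed = urlparse(name_behavior_id)
name = parsed.path            # "Soccer"
ids = parse_qs(parsed.query)  # {"team": ["1"]}
team_id = int(ids["team"][0]) if "team" in ids else 0
assert create_name_behavior_id(name, team_id) == name_behavior_id
```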

ml-agents/mlagents/trainers/trainer_controller.py (7 changes)


        self, env_manager: EnvManager, name_behavior_id: str
    ) -> None:
        brain_name = BehaviorIdentifiers.from_name_behavior_id(
            name_behavior_id
        ).brain_name
        parsed_behavior_id = BehaviorIdentifiers.from_name_behavior_id(name_behavior_id)
        brain_name = parsed_behavior_id.brain_name
        try:
            trainer = self.trainers[brain_name]
        except KeyError:

        policy = trainer.create_policy(env_manager.external_brains[name_behavior_id])
        trainer.add_policy(name_behavior_id, policy)
        trainer.add_policy(parsed_behavior_id, policy)
        agent_manager = AgentManager(
            policy,

ml-agents/mlagents/trainers/ppo/trainer.py (7 changes)


from mlagents.trainers.ppo.optimizer import PPOOptimizer
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers

logger = get_logger(__name__)

        return policy

    def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None:
    def add_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy
    ) -> None:
        :param name_behavior_id: Behavior ID that the policy should belong to.
        :param parsed_behavior_id: Behavior identifiers that the policy should belong to.
        :param policy: Policy to associate with name_behavior_id.
        """
        if self.policy:

ml-agents/mlagents/trainers/sac/trainer.py (5 changes)


from mlagents.trainers.trajectory import Trajectory, SplitObservations
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers

logger = get_logger(__name__)

        for stat, stat_list in batch_update_stats.items():
            self._stats_reporter.add_stat(stat, np.mean(stat_list))

    def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None:
    def add_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy
    ) -> None:
        """
        Adds policy to trainer.
        :param brain_parameters: specifications for policy construction

ml-agents/mlagents/trainers/tests/test_ghost.py (36 changes)


import yaml

from mlagents.trainers.ghost.trainer import GhostTrainer
from mlagents.trainers.ghost.controller import GhostController
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.test_trajectory import make_fake_trajectory

    dummy_config["summary_path"] = "./summaries/test_trainer_summary"
    dummy_config["model_path"] = "./models/test_trainer_models/TestModel"
    ppo_trainer = PPOTrainer(brain_name, 0, dummy_config, True, False, 0, "0")
    trainer = GhostTrainer(ppo_trainer, brain_name, 0, dummy_config, True, "0")
    controller = GhostController(100)
    trainer = GhostTrainer(
        ppo_trainer, brain_name, controller, 0, dummy_config, True, "0"
    )
    trainer.add_policy(brain_params_team0.brain_name, policy)
    parsed_behavior_id0 = BehaviorIdentifiers.from_name_behavior_id(
        brain_params_team0.brain_name
    )
    trainer.add_policy(parsed_behavior_id0, policy)
    trainer.add_policy(brain_params_team1.brain_name, policy)
    parsed_behavior_id1 = BehaviorIdentifiers.from_name_behavior_id(
        brain_params_team1.brain_name
    )
    trainer.add_policy(parsed_behavior_id1, policy)
    trajectory_queue1 = AgentManagerQueue(brain_params_team1.brain_name)
    trainer.subscribe_trajectory_queue(trajectory_queue1)

        vector_action_space_type=0,
    )
    brain_name = BehaviorIdentifiers.from_name_behavior_id(
    parsed_behavior_id0 = BehaviorIdentifiers.from_name_behavior_id(
    ).brain_name
    )
    brain_name = parsed_behavior_id0.brain_name
    brain_params_team1 = BrainParameters(
        brain_name="test_brain?team=1",

    dummy_config["summary_path"] = "./summaries/test_trainer_summary"
    dummy_config["model_path"] = "./models/test_trainer_models/TestModel"
    ppo_trainer = PPOTrainer(brain_name, 0, dummy_config, True, False, 0, "0")
    trainer = GhostTrainer(ppo_trainer, brain_name, 0, dummy_config, True, "0")
    controller = GhostController(100)
    trainer = GhostTrainer(
        ppo_trainer, brain_name, controller, 0, dummy_config, True, "0"
    )
    trainer.add_policy(brain_params_team0.brain_name, policy)
    trainer.add_policy(parsed_behavior_id0, policy)
    trainer.add_policy(brain_params_team1.brain_name, policy)
    parsed_behavior_id1 = BehaviorIdentifiers.from_name_behavior_id(
        brain_params_team1.brain_name
    )
    trainer.add_policy(parsed_behavior_id1, policy)
    policy_queue1 = AgentManagerQueue(brain_params_team1.brain_name)
    trainer.publish_policy_queue(policy_queue1)

ml-agents/mlagents/trainers/tests/test_simple_rl.py (55 changes)


    override_vals = {
        "max_steps": 2500,
        "self_play": {
            "play_against_current_self_ratio": 1.0,
            "play_against_latest_model_ratio": 1.0,
            "save_steps": 2000,
            "swap_steps": 2000,
        },

    override_vals = {
        "max_steps": 2500,
        "self_play": {
            "play_against_current_self_ratio": 1.0,
            "play_against_latest_model_ratio": 1.0,
            "save_steps": 2000,
            "swap_steps": 4000,
        },

        default_reward_processor(rewards) for rewards in env.final_rewards.values()
    ]
    success_threshold = 0.9
    assert any(reward > success_threshold for reward in processed_rewards) and any(
        reward < success_threshold for reward in processed_rewards
    )


@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_asymm_ghost(use_discrete):
    # Make opponent for asymmetric case
    brain_name_opp = BRAIN_NAME + "Opp"
    env = SimpleEnvironment(
        [BRAIN_NAME + "?team=0", brain_name_opp + "?team=1"], use_discrete=use_discrete
    )
    override_vals = {
        "max_steps": 2000,
        "self_play": {
            "play_against_latest_model_ratio": 1.0,
            "save_steps": 5000,
            "swap_steps": 5000,
            "team_change": 2000,
        },
    }
    config = generate_config(PPO_CONFIG, override_vals)
    config[brain_name_opp] = config[BRAIN_NAME]
    _check_environment_trains(env, config)


@pytest.mark.parametrize("use_discrete", [True, False])
def test_simple_asymm_ghost_fails(use_discrete):
    # Make opponent for asymmetric case
    brain_name_opp = BRAIN_NAME + "Opp"
    env = SimpleEnvironment(
        [BRAIN_NAME + "?team=0", brain_name_opp + "?team=1"], use_discrete=use_discrete
    )
    # This config should fail because the team that is not learning when both have reached
    # max step should be executing the initial, untrained policy.
    override_vals = {
        "max_steps": 2000,
        "self_play": {
            "play_against_latest_model_ratio": 0.0,
            "save_steps": 5000,
            "swap_steps": 5000,
            "team_change": 2000,
        },
    }
    config = generate_config(PPO_CONFIG, override_vals)
    config[brain_name_opp] = config[BRAIN_NAME]
    _check_environment_trains(env, config, success_threshold=None)
    processed_rewards = [
        default_reward_processor(rewards) for rewards in env.final_rewards.values()
    ]
    success_threshold = 0.99
    assert any(reward > success_threshold for reward in processed_rewards) and any(
        reward < success_threshold for reward in processed_rewards
    )

ml-agents/mlagents/trainers/tests/test_stats.py (7 changes)


        category = "category1"
        console_writer = ConsoleWriter()
        console_writer.add_property(category, StatsPropertyType.SELF_PLAY, True)
        console_writer.add_property(category, StatsPropertyType.SELF_PLAY_TEAM, 1)
        statssummary1 = StatsSummary(mean=1.0, std=1.0, num=1)
        console_writer.write_stats(
            category,

                "Self-play/ELO": statssummary1,
                "Self-play/Mean Opponent ELO": statssummary1,
                "Self-play/Std Opponent ELO": statssummary1,
            },
            10,
        )

        )
        self.assertIn(
            "category1 Team 1: ELO: 1.000. Mean Opponent ELO: 1.000. Std Opponent ELO: 1.000.",
            cm.output[1],
        )

ml-agents/mlagents/trainers/trainer_util.py (6 changes)


from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.sac.trainer import SACTrainer
from mlagents.trainers.ghost.trainer import GhostTrainer
from mlagents.trainers.ghost.controller import GhostController

logger = get_logger(__name__)

        self.seed = seed
        self.meta_curriculum = meta_curriculum
        self.multi_gpu = multi_gpu
        self.ghost_controller = GhostController()

    def generate(self, brain_name: str) -> Trainer:
        return initialize_trainer(

            self.keep_checkpoints,
            self.train_model,
            self.load_model,
            self.ghost_controller,
            self.seed,
            self.meta_curriculum,
            self.multi_gpu,

    keep_checkpoints: int,
    train_model: bool,
    load_model: bool,
    ghost_controller: GhostController,
    seed: int,
    meta_curriculum: MetaCurriculum = None,
    multi_gpu: bool = False,

    :param keep_checkpoints: How many model checkpoints to keep
    :param train_model: Whether to train the model (vs. run inference)
    :param load_model: Whether to load the model or randomly initialize
    :param ghost_controller: The object that coordinates ghost trainers
    :param seed: The random seed to use
    :param meta_curriculum: Optional meta_curriculum, used to determine a reward buffer length for PPOTrainer
    :return:

        trainer = GhostTrainer(
            trainer,
            brain_name,
            ghost_controller,
            min_lesson_length,
            trainer_parameters,
            train_model,

ml-agents/mlagents/trainers/policy/tf_policy.py (4 changes)


        self.assign_ops.append(tf.assign(var, assign_ph))

    def load_weights(self, values):
        if len(self.assign_ops) == 0:
            logger.warning(
                "Calling load_weights in tf_policy but assign_ops is empty. Did you forget to call init_load_weights?"
            )
        with self.graph.as_default():
            feed_dict = {}
            for assign_ph, value in zip(self.assign_phs, values):

ml-agents/mlagents/trainers/trainer/trainer.py (5 changes)


from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.policy import Policy
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers

logger = get_logger(__name__)

        pass

    @abc.abstractmethod
    def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None:
    def add_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy
    ) -> None:
        """
        Adds policy to trainer.
        """

ml-agents/mlagents/trainers/ghost/trainer.py (434 changes)


# # Unity ML-Agents Toolkit
# ## ML-Agent Learning (Ghost Trainer)
from typing import Deque, Dict, List, Any, cast
from typing import Deque, Dict, List, cast
import numpy as np

from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.stats import StatsPropertyType
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.behavior_id_utils import (
    BehaviorIdentifiers,
    create_name_behavior_id,
)
logger = get_logger(__name__)

"""
The GhostTrainer trains agents in adversarial games (there are teams in opposition) using a self-play mechanism.
In adversarial settings with self-play, at any time, there is only a single learning team. The other team(s) is
"ghosted" which means that its agents are executing fixed policies and not learning. The GhostTrainer wraps
a standard RL trainer which trains the learning team and ensures that only the trajectories collected
by the learning team are used for training. The GhostTrainer also maintains past policy snapshots to be used
as the fixed policies when the team is not learning. The GhostTrainer is 1:1 with brain_names as the other
trainers, and is responsible for one or more teams. Note, a GhostTrainer can have only one team in
asymmetric games where there is only one team with a particular behavior i.e. Hide and Seek.
The GhostController manages high level coordination between multiple ghost trainers. The learning team id
is cycled throughout a training run.
"""
        self, trainer, brain_name, reward_buff_cap, trainer_parameters, training, run_id
        self,
        trainer,
        brain_name,
        controller,
        reward_buff_cap,
        trainer_parameters,
        training,
        run_id,

        Responsible for collecting experiences and training trainer model via self_play.
        Creates a GhostTrainer.
        :param controller: GhostController that coordinates all ghost trainers and calculates ELO
        :param reward_buff_cap: Max reward history to track in the reward buffer
        :param trainer_parameters: The parameters for the trainer (dictionary).
        :param training: Whether the trainer is set for training.

        )
        self.trainer = trainer
        self.controller = controller

        self.internal_policy_queues: List[AgentManagerQueue[Policy]] = []
        self.internal_trajectory_queues: List[AgentManagerQueue[Trajectory]] = []
        self.ignored_trajectory_queues: List[AgentManagerQueue[Trajectory]] = []
        self.learning_policy_queues: Dict[str, AgentManagerQueue[Policy]] = {}
        self._internal_trajectory_queues: Dict[str, AgentManagerQueue[Trajectory]] = {}
        self._internal_policy_queues: Dict[str, AgentManagerQueue[Policy]] = {}
        self._team_to_name_to_policy_queue: Dict[
            int, Dict[str, AgentManagerQueue[Policy]]
        ] = {}
        self._name_to_parsed_behavior_id: Dict[str, BehaviorIdentifiers] = {}

        # assign ghost's stats collection to wrapped trainer's
        self._stats_reporter = self.trainer.stats_reporter

        self_play_parameters = trainer_parameters["self_play"]
        self.window = self_play_parameters.get("window", 10)
        self.play_against_current_self_ratio = self_play_parameters.get(
            "play_against_current_self_ratio", 0.5
        self.play_against_latest_model_ratio = self_play_parameters.get(
            "play_against_latest_model_ratio", 0.5
        if (
            self.play_against_latest_model_ratio > 1.0
            or self.play_against_latest_model_ratio < 0.0
        ):
            logger.warning(
                "The play_against_latest_model_ratio is not between 0 and 1."
            )

        self.steps_to_train_team = self_play_parameters.get("team_change", 100000)
        if self.steps_to_train_team > self.get_max_steps:
            logger.warning(
                "The max steps of the GhostTrainer for behavior name {} is less than team change. This team will not face \
opposition that has been trained if the opposition is managed by a different GhostTrainer as in an \
asymmetric game.".format(
                    self.brain_name
                )
            )

        self.policies: Dict[str, TFPolicy] = {}
        self.policy_snapshots: List[Any] = []

        # Counts the number of steps of the ghost policies. Snapshot swapping
        # depends on this counter whereas snapshot saving and team switching depends
        # on the wrapped trainer. This ensures that all teams train for the same number of trainer
        # steps.
        self.ghost_step: int = 0

        # A list of dicts from brain name to a single snapshot for this trainer's policies
        self.policy_snapshots: List[Dict[str, List[float]]] = []
        # A dict from brain name to the current snapshot of this trainer's policies
        self.current_policy_snapshot: Dict[str, List[float]] = {}

        self.learning_behavior_name: str = None
        self.current_policy_snapshot = None
        self.last_save = 0
        self.last_swap = 0

        self.policies: Dict[str, TFPolicy] = {}

        # wrapped_trainer_team and learning team need to be separate
        # in the situation where new agents are created or destroyed
        # after the learning team switches. These agents need to be added
        # to trainers properly.
        self._learning_team: int = None
        self.wrapped_trainer_team: int = None
        self.last_save: int = 0
        self.last_swap: int = 0
        self.last_team_change: int = 0

        self.current_elo: float = self.initial_elo
        self.policy_elos: List[float] = [self.initial_elo] * (
            self.window + 1
        )  # for learning policy

    def get_step(self) -> int:
        """
        Returns the number of steps the trainer has performed
        :return: the step count of the trainer
        Returns the number of steps the wrapped trainer has performed
        :return: the step count of the wrapped trainer
        """
        return self.trainer.get_step

"""
return self.trainer.reward_buffer
@property
def current_elo(self) -> float:
"""
Gets ELO of current policy which is always last in the list
:return: ELO of current policy
"""
return self.policy_elos[-1]
def change_current_elo(self, change: float) -> None:
"""
Changes elo of current policy which is always last in the list
:param change: Amount to change current elo by
"""
self.policy_elos[-1] += change
def get_opponent_elo(self) -> float:
"""
Get elo of current opponent policy
:return: ELO of current opponent policy
"""
return self.policy_elos[self.current_opponent]
def change_opponent_elo(self, change: float) -> None:
"""
Changes elo of current opponent policy
:param change: Amount to change current opponent elo by
"""
self.policy_elos[self.current_opponent] -= change
        if trajectory.done_reached and not trajectory.max_step_reached:
            # Assumption is that final reward is 1/.5/0 for win/draw/loss
        """
        Determines the final result of an episode and asks the GhostController
        to calculate the ELO change. The GhostController changes the ELO
        of the opponent policy since this may be in a different GhostTrainer
        i.e. in asymmetric games. We assume the last reward determines the winner.
        :param trajectory: Trajectory.
        """
        if trajectory.done_reached:
            # Assumption is that final reward is >0/0/<0 for win/draw/loss
            final_reward = trajectory.steps[-1].reward
            result = 0.5
            if final_reward > 0:

            change = compute_elo_rating_changes(
                self.current_elo, self.policy_elos[self.current_opponent], result
            change = self.controller.compute_elo_rating_changes(
                self.current_elo, result
            self.current_elo += change
            self.policy_elos[self.current_opponent] -= change
            opponents = np.array(self.policy_elos, dtype=np.float32)
            self.change_current_elo(change)

            self._stats_reporter.add_stat(
                "Self-play/Mean Opponent ELO", opponents.mean()
            )
            self._stats_reporter.add_stat("Self-play/Std Opponent ELO", opponents.std())
        for traj_queue, internal_traj_queue in zip(
            self.trajectory_queues, self.internal_trajectory_queues
        ):
            try:
                # We grab at most the maximum length of the queue.
                # This ensures that even if the queue is being filled faster than it is
                # being emptied, the trajectories in the queue are on-policy.
                for _ in range(traj_queue.maxlen):
                    t = traj_queue.get_nowait()
                    # adds to the wrapped trainer's queue
                    internal_traj_queue.put(t)
                    self._process_trajectory(t)
            except AgentManagerQueue.Empty:
                pass
        for trajectory_queue in self.trajectory_queues:
            parsed_behavior_id = self._name_to_parsed_behavior_id[
                trajectory_queue.behavior_id
            ]
            if parsed_behavior_id.team_id == self._learning_team:
                # With a future multiagent trainer, this will be indexed by 'role'
                internal_trajectory_queue = self._internal_trajectory_queues[
                    parsed_behavior_id.brain_name
                ]
                try:
                    # We grab at most the maximum length of the queue.
                    # This ensures that even if the queue is being filled faster than it is
                    # being emptied, the trajectories in the queue are on-policy.
                    for _ in range(trajectory_queue.maxlen):
                        t = trajectory_queue.get_nowait()
                        # adds to the wrapped trainer's queue
                        internal_trajectory_queue.put(t)
                        self._process_trajectory(t)
                except AgentManagerQueue.Empty:
                    pass
            else:
                # Dump trajectories from non-learning policy
                try:
                    for _ in range(trajectory_queue.maxlen):
                        t = trajectory_queue.get_nowait()
                        # count ghost steps
                        self.ghost_step += len(t.steps)
                except AgentManagerQueue.Empty:
                    pass

        if self.get_step - self.last_team_change > self.steps_to_train_team:
            self.controller.change_training_team(self.get_step)
            self.last_team_change = self.get_step
        for internal_q in self.internal_policy_queues:
            # Get policies that correspond to the policy queue in question

        next_learning_team = self.controller.get_learning_team

        # CASE 1: Current learning team is managed by this GhostTrainer.
        # If the learning team changes, the following loop over queues will push the
        # new policy into the policy queue for the new learning agent if
        # that policy is managed by this GhostTrainer. Otherwise, it will save the current snapshot.
        # CASE 2: Current learning team is managed by a different GhostTrainer.
        # If the learning team changes to a team managed by this GhostTrainer, this loop
        # will push the current_snapshot into the correct queue. Otherwise,
        # it will continue skipping and swap_snapshot will continue to handle
        # pushing fixed snapshots
        # Case 3: No team change. The if statement just continues to push the policy
        # into the correct queue (or not if not learning team).
        for brain_name in self._internal_policy_queues:
            internal_policy_queue = self._internal_policy_queues[brain_name]
                policy = cast(TFPolicy, internal_q.get_nowait())
                self.current_policy_snapshot = policy.get_weights()
                self.learning_policy_queues[internal_q.behavior_id].put(policy)
                policy = cast(TFPolicy, internal_policy_queue.get_nowait())
                self.current_policy_snapshot[brain_name] = policy.get_weights()
            if next_learning_team in self._team_to_name_to_policy_queue:
                name_to_policy_queue = self._team_to_name_to_policy_queue[
                    next_learning_team
                ]
                if brain_name in name_to_policy_queue:
                    behavior_id = create_name_behavior_id(
                        brain_name, next_learning_team
                    )
                    policy = self.get_policy(behavior_id)
                    policy.load_weights(self.current_policy_snapshot[brain_name])
                    name_to_policy_queue[brain_name].put(policy)

        # Note save and swap should be on different step counters.
        # We don't want to save unless the policy is learning.
            self._save_snapshot(self.trainer.policy)
            self._save_snapshot()

        if self.get_step - self.last_swap > self.steps_between_swap:
        if (
            self._learning_team != next_learning_team
            or self.ghost_step - self.last_swap > self.steps_between_swap
        ):
            self._learning_team = next_learning_team
            self.last_swap = self.get_step

        # Dump trajectories from non-learning policy
        for traj_queue in self.ignored_trajectory_queues:
            try:
                for _ in range(traj_queue.maxlen):
                    traj_queue.get_nowait()
            except AgentManagerQueue.Empty:
                pass
            self.last_swap = self.ghost_step
"""
Forwarding call to wrapped trainers end_episode
"""
"""
Forwarding call to wrapped trainers save_model
"""
self.trainer.export_model(name_behavior_id)
"""
Forwarding call to wrapped trainers export_model.
First loads the current snapshot.
"""
parsed_behavior_id = self._name_to_parsed_behavior_id[name_behavior_id]
brain_name = parsed_behavior_id.brain_name
policy = self.trainer.get_policy(brain_name)
policy.load_weights(self.current_policy_snapshot[brain_name])
self.trainer.export_model(brain_name)
"""
Creates policy with the wrapped trainer's create_policy function
"""
    def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None:
    def add_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy
    ) -> None:
        Adds policy to trainer. For the first policy added, add a trainer
        to the policy and set the learning behavior name to name_behavior_id.
        Adds policy to trainer. The first policy encountered sets the wrapped
        trainer team. This is to ensure that all agents from the same multi-agent
        team are grouped. All policies associated with this team are added to the
        wrapped trainer to be trained.
        name_behavior_id = parsed_behavior_id.behavior_id
        team_id = parsed_behavior_id.team_id
        self.controller.subscribe_team_id(team_id, self)

        # First policy encountered
        if not self.learning_behavior_name:
            weights = policy.get_weights()
            self.current_policy_snapshot = weights
            self.trainer.add_policy(name_behavior_id, policy)
            self._save_snapshot(policy)  # Need to save after trainer initializes policy
            self.learning_behavior_name = name_behavior_id
            behavior_id_parsed = BehaviorIdentifiers.from_name_behavior_id(
                self.learning_behavior_name
            )
            team_id = behavior_id_parsed.behavior_ids["team"]
            self._stats_reporter.add_property(StatsPropertyType.SELF_PLAY_TEAM, team_id)
        else:
            # for saving/swapping snapshots
            policy.init_load_weights()

        self._name_to_parsed_behavior_id[name_behavior_id] = parsed_behavior_id
        # for saving/swapping snapshots
        policy.init_load_weights()

        # First policy or a new agent on the same team encountered
        if self.wrapped_trainer_team is None or team_id == self.wrapped_trainer_team:
            self.current_policy_snapshot[
                parsed_behavior_id.brain_name
            ] = policy.get_weights()
            self._save_snapshot()  # Need to save after trainer initializes policy
            self.trainer.add_policy(parsed_behavior_id, policy)
            self._learning_team = self.controller.get_learning_team
            self.wrapped_trainer_team = team_id
"""
Gets policy associated with name_behavior_id
:param name_behavior_id: Fully qualified behavior name
:return: Policy associated with name_behavior_id
"""
def _save_snapshot(self, policy: TFPolicy) -> None:
weights = policy.get_weights()
try:
self.policy_snapshots[self.snapshot_counter] = weights
except IndexError:
self.policy_snapshots.append(weights)
def _save_snapshot(self) -> None:
"""
Saves a snapshot of the current weights of the policy and maintains the policy_snapshots
according to the window size
"""
for brain_name in self.current_policy_snapshot:
current_snapshot_for_brain_name = self.current_policy_snapshot[brain_name]
try:
self.policy_snapshots[self.snapshot_counter][
brain_name
] = current_snapshot_for_brain_name
except IndexError:
self.policy_snapshots.append(
{brain_name: current_snapshot_for_brain_name}
)
        for q in self.policy_queues:
            name_behavior_id = q.behavior_id
            # here is the place for a sampling protocol
            if name_behavior_id == self.learning_behavior_name:
        """
        Swaps the appropriate weight to the policy and pushes it to respective policy queues
        """
        for team_id in self._team_to_name_to_policy_queue:
            if team_id == self._learning_team:
            elif np.random.uniform() < (1 - self.play_against_current_self_ratio):
            elif np.random.uniform() < (1 - self.play_against_latest_model_ratio):
                self.policy_elos[-1] = self.current_elo
            logger.debug(
                "Step {}: Swapping snapshot {} to id {} with {} learning".format(
                    self.get_step, x, name_behavior_id, self.learning_behavior_name
            name_to_policy_queue = self._team_to_name_to_policy_queue[team_id]
            for brain_name in self._team_to_name_to_policy_queue[team_id]:
                behavior_id = create_name_behavior_id(brain_name, team_id)
                policy = self.get_policy(behavior_id)
                policy.load_weights(snapshot[brain_name])
                name_to_policy_queue[brain_name].put(policy)
                logger.debug(
                    "Step {}: Swapping snapshot {} to id {} with team {} learning".format(
                        self.ghost_step, x, behavior_id, self._learning_team
                    )
                )
            policy = self.get_policy(name_behavior_id)
            policy.load_weights(snapshot)
            q.put(policy)

        Adds a policy queue to the list of queues to publish to when this Trainer
        makes a policy update
        Adds a policy queue for every member of the team to the list of queues to publish to when this Trainer
        makes a policy update. Creates an internal policy queue for the wrapped
        trainer to push to. The GhostTrainer pushes all policies to the env.
        """
        if policy_queue.behavior_id == self.learning_behavior_name:
        parsed_behavior_id = self._name_to_parsed_behavior_id[policy_queue.behavior_id]
        try:
            self._team_to_name_to_policy_queue[parsed_behavior_id.team_id][
                parsed_behavior_id.brain_name
            ] = policy_queue
        except KeyError:
            self._team_to_name_to_policy_queue[parsed_behavior_id.team_id] = {
                parsed_behavior_id.brain_name: policy_queue
            }
        if parsed_behavior_id.team_id == self.wrapped_trainer_team:
            # With a future multiagent trainer, this will be indexed by 'role'
                policy_queue.behavior_id
                parsed_behavior_id.brain_name
            self.internal_policy_queues.append(internal_policy_queue)
            self.learning_policy_queues[policy_queue.behavior_id] = policy_queue
            self._internal_policy_queues[
                parsed_behavior_id.brain_name
            ] = internal_policy_queue
            self.trainer.publish_policy_queue(internal_policy_queue)
    def subscribe_trajectory_queue(

        Adds a trajectory queue to the list of queues for the trainer to ingest Trajectories from.
        Adds a trajectory queue for every member of the team to the list of queues for the trainer
        to ingest Trajectories from. Creates an internal trajectory queue to push trajectories from
        the learning team. The wrapped trainer subscribes to this queue.
        """
        if trajectory_queue.behavior_id == self.learning_behavior_name:
            super().subscribe_trajectory_queue(trajectory_queue)
        super().subscribe_trajectory_queue(trajectory_queue)
        parsed_behavior_id = self._name_to_parsed_behavior_id[
            trajectory_queue.behavior_id
        ]
        if parsed_behavior_id.team_id == self.wrapped_trainer_team:
            # With a future multiagent trainer, this will be indexed by 'role'
                ] = AgentManagerQueue(trajectory_queue.behavior_id)
                ] = AgentManagerQueue(parsed_behavior_id.brain_name)
            self.internal_trajectory_queues.append(internal_trajectory_queue)
            self._internal_trajectory_queues[
                parsed_behavior_id.brain_name
            ] = internal_trajectory_queue
        else:
            self.ignored_trajectory_queues.append(trajectory_queue)
# Taken from https://github.com/Unity-Technologies/ml-agents/pull/1975 and
# https://metinmediamath.wordpress.com/2013/11/27/how-to-calculate-the-elo-rating-including-example/
# ELO calculation
def compute_elo_rating_changes(rating1: float, rating2: float, result: float) -> float:
    r1 = pow(10, rating1 / 400)
    r2 = pow(10, rating2 / 400)
    summed = r1 + r2
    e1 = r1 / summed
    change = result - e1
    return change

ml-agents/mlagents/trainers/ghost/controller.py (92 changes)


from mlagents_envs.logging_util import get_logger
from typing import Deque, Dict
from collections import deque
from mlagents.trainers.ghost.trainer import GhostTrainer

logger = get_logger(__name__)


class GhostController:
    """
    GhostController contains a queue of team ids. GhostTrainers subscribe to the GhostController and query
    it to get the current learning team. The GhostController cycles through team ids every 'swap_interval'
    which corresponds to the number of trainer steps between changing learning teams.
    The GhostController is a unique object and there can only be one per training run.
    """

    def __init__(self, maxlen: int = 10):
        """
        Create a GhostController.
        :param maxlen: Maximum number of GhostTrainers allowed in this GhostController
        """
        # Tracks last swap step for each learning team because trainer
        # steps of all GhostTrainers do not increment together
        self._queue: Deque[int] = deque(maxlen=maxlen)
        self._learning_team: int = -1
        # Dict from team id to GhostTrainer for ELO calculation
        self._ghost_trainers: Dict[int, GhostTrainer] = {}

    @property
    def get_learning_team(self) -> int:
        """
        Returns the current learning team.
        :return: The learning team id
        """
        return self._learning_team

    def subscribe_team_id(self, team_id: int, trainer: GhostTrainer) -> None:
        """
        Given a team_id and trainer, add to queue and trainers if not already.
        The GhostTrainer is used later by the controller to get ELO ratings of agents.
        :param team_id: The team_id of an agent managed by this GhostTrainer
        :param trainer: A GhostTrainer that manages this team_id.
        """
        if team_id not in self._ghost_trainers:
            self._ghost_trainers[team_id] = trainer
            if self._learning_team < 0:
                self._learning_team = team_id
            else:
                self._queue.append(team_id)

    def change_training_team(self, step: int) -> None:
        """
        The current learning team is added to the end of the queue and then updated with the
        next in line.
        :param step: The step of the trainer for debugging
        """
        self._queue.append(self._learning_team)
        self._learning_team = self._queue.popleft()
        logger.debug(
            "Learning team {} swapped on step {}".format(self._learning_team, step)
        )

    # Adapted from https://github.com/Unity-Technologies/ml-agents/pull/1975 and
    # https://metinmediamath.wordpress.com/2013/11/27/how-to-calculate-the-elo-rating-including-example/
    # ELO calculation
    # TODO : Generalize this to more than two teams
    def compute_elo_rating_changes(self, rating: float, result: float) -> float:
        """
        Calculates ELO. Given the rating of the learning team and result. The GhostController
        queries the other GhostTrainers for the ELO of their agent that is currently being deployed.
        Note, this could be the current agent or a past snapshot.
        :param rating: Rating of the learning team.
        :param result: Win, loss, or draw from the perspective of the learning team.
        :return: The change in ELO.
        """
        opponent_rating: float = 0.0
        for team_id, trainer in self._ghost_trainers.items():
            if team_id != self._learning_team:
                opponent_rating = trainer.get_opponent_elo()
        r1 = pow(10, rating / 400)
        r2 = pow(10, opponent_rating / 400)
        summed = r1 + r2
        e1 = r1 / summed
        change = result - e1
        for team_id, trainer in self._ghost_trainers.items():
            if team_id != self._learning_team:
                trainer.change_opponent_elo(change)
        return change