Commit 14193ada (5 years ago)
20 files changed: 1818 insertions(+), 444 deletions(-)
Changed files (lines changed):

- 44    UnitySDK/Assets/ML-Agents/Examples/Tennis/Prefabs/TennisArea.prefab
- 118   UnitySDK/Assets/ML-Agents/Examples/Tennis/Scenes/Tennis.unity
- 157   UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/HitWall.cs
- 9     UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/TennisAgent.cs
- 2     UnitySDK/Assets/ML-Agents/Examples/Tennis/Scripts/TennisArea.cs
- 1001  UnitySDK/Assets/ML-Agents/Examples/Tennis/TFModels/Tennis.nn
- 13    UnitySDK/ProjectSettings/EditorSettings.asset
- 9     config/sac_trainer_config.yaml
- 9     config/trainer_config.yaml
- 33    ml-agents/mlagents/trainers/tests/test_simple_rl.py
- 27    ml-agents/mlagents/trainers/tf_policy.py
- 2     ml-agents/mlagents/trainers/trainer.py
- 10    ml-agents/mlagents/trainers/trainer_controller.py
- 12    ml-agents/mlagents/trainers/trainer_util.py
- 87    docs/Training-Self-Play.md
- 188   docs/images/team_id.png
- 36    ml-agents/mlagents/trainers/behavior_id_utils.py
- 231   ml-agents/mlagents/trainers/tests/test_ghost.py
- 274   ml-agents/mlagents/trainers/ghost/trainer.py
(The diff for UnitySDK/Assets/ML-Agents/Examples/Tennis/TFModels/Tennis.nn is too large to display.)
docs/Training-Self-Play.md

# Training with Self-Play

ML-Agents provides the functionality to train symmetric, adversarial games with
[Self-Play](https://openai.com/blog/competitive-self-play/).
A symmetric game is one in which opposing agents are *equal* in form and function.
In reinforcement learning, this means both agents have the same observation and
action spaces. With self-play, an agent learns in adversarial games by competing
against fixed, past versions of itself. This provides a more stable, stationary
learning environment than competing against its current self in every episode,
where the opponent is constantly changing.

Self-play can be used with our implementations of both
[Proximal Policy Optimization (PPO)](Training-PPO.md) and
[Soft Actor-Critic (SAC)](Training-SAC.md).
For more general information on training with ML-Agents, see
[Training ML-Agents](Training-ML-Agents.md).
For more algorithm-specific instruction, please see the documentation for
[PPO](Training-PPO.md) or [SAC](Training-SAC.md).

Self-play is triggered by including the self-play hyperparameter hierarchy in the
trainer configuration file; a detailed description of each self-play hyperparameter
is given below, and an illustrative configuration follows this paragraph.
Furthermore, to distinguish opposing agents, set the team ID to different integer
values in the Behavior Parameters script on the agent prefab.
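As a minimal sketch of that hierarchy (the behavior name `Tennis` and the elided
PPO hyperparameters are illustrative; the `self_play` values shown are the ghost
trainer's defaults from this change, not tuned recommendations):

```yaml
Tennis:
  trainer: ppo
  # ... the usual PPO or SAC hyperparameters ...
  self_play:
    window: 10
    play_against_current_self_ratio: 0.5
    save_steps: 20000
    swap_steps: 20000
```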
![Team ID](images/team_id.png)

See the trainer configuration and agent prefabs for our Tennis environment for an example.
## Best Practices Training with Self-Play

Training with self-play adds additional confounding factors to the usual issues
faced by reinforcement learning. In general, the tradeoff is between the skill
level and generality of the final policy and the stability of learning. Training
against a set of slowly changing or unchanging adversaries with low diversity
results in a more stable learning process than training against a set of quickly
changing adversaries with high diversity. With this context, this guide discusses
the exposed self-play hyperparameters and intuitions for tuning them.
## Hyperparameters

### Reward Signals

We make the assumption that the final reward in a trajectory corresponds to the
outcome of an episode. A final reward of +1 indicates winning, -1 indicates
losing, and 0 indicates a draw. The ELO calculation (discussed below) depends on
this final reward being +1, 0, or -1.

The reward signal should still be used as described in the documentation for the
other trainers and [reward signals](Reward-Signals.md). However, we encourage
users to be a bit more conservative when shaping reward functions due to the
instability and non-stationarity of learning in adversarial games. Specifically,
we encourage users to begin with the simplest possible reward function (+1 for
winning, -1 for losing) and to allow for more iterations of training to
compensate for the sparsity of reward.
### Save Steps

The `save_steps` parameter corresponds to the number of *trainer steps* between
snapshots. For example, if `save_steps=10000`, then a snapshot of the current
policy will be saved every 10000 trainer steps. Note that trainer steps are
counted per agent. For more information, please see the
[migration doc](Migrating.md) after v0.13.

A larger value of `save_steps` will yield a set of opponents that cover a wider
range of skill levels and possibly play styles, since the policy receives more
training between snapshots. As a result, the agent trains against a wider variety
of opponents. Learning a policy to defeat more diverse opponents is a harder
problem and so may require more overall training steps, but may also lead to a
more general and robust policy at the end of training. This value is also
dependent on how intrinsically difficult the environment is for the agent.

Recommended Range: 10000-100000
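For intuition, here is a small, self-contained sketch (illustrative numbers, not
recommendations) of how `save_steps` paces snapshots and, together with the
`window` parameter described below, bounds the age of the oldest opponent:

```python
# Illustrative only: how save_steps paces snapshots over a training run.
save_steps = 10_000   # trainer steps between snapshots
window = 10           # snapshots retained (see the Window section below)
total_steps = 200_000

snapshot_steps = list(range(save_steps, total_steps + 1, save_steps))
pool = snapshot_steps[-window:]  # only the newest `window` snapshots are kept
# The oldest opponent lags the current policy by roughly save_steps * window steps.
print(f"{len(snapshot_steps)} snapshots taken; opponent pool spans steps {pool[0]}..{pool[-1]}")
```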
### Swap Steps

The `swap_steps` parameter corresponds to the number of *trainer steps* between
swapping the opponent's policy with a different snapshot. As in the `save_steps`
discussion, note that trainer steps are counted per agent. For more information,
please see the [migration doc](Migrating.md) after v0.13.

A larger value of `swap_steps` means that an agent will play against the same
fixed opponent for a larger number of training iterations. This results in a
more stable training scenario, but leaves the agent open to the risk of
overfitting its behavior to this particular opponent. Thus, when a new opponent
is swapped in, the agent may lose more often than expected.

Recommended Range: 10000-100000
### Play against current self ratio

The `play_against_current_self_ratio` parameter corresponds to the probability
an agent will play against its ***current*** self. With probability
1 - `play_against_current_self_ratio`, the agent will play against a snapshot of
itself from a past iteration, as sketched below.

A larger value of `play_against_current_self_ratio` indicates that an agent will
be playing against itself more often. Since the agent is updating its policy,
the opponent will be different from iteration to iteration. This can lead to an
unstable learning environment, but presents the agent with an
[auto-curriculum](https://openai.com/blog/emergent-tool-use/) of increasingly
challenging situations, which may lead to a stronger final policy.

Recommended Range: 0.0 - 1.0
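The selection rule corresponds to the logic in `GhostTrainer._swap_snapshots`,
included later in this diff; a minimal sketch with placeholder snapshots:

```python
import numpy as np

play_against_current_self_ratio = 0.5
policy_snapshots = ["snapshot_0", "snapshot_1", "snapshot_2"]  # placeholder weights
current_policy_snapshot = "current"

# With probability (1 - ratio), sample a past snapshot uniformly at random;
# otherwise the opponent is the current policy.
if np.random.uniform() < (1 - play_against_current_self_ratio):
    opponent = policy_snapshots[np.random.randint(len(policy_snapshots))]
else:
    opponent = current_policy_snapshot
print(opponent)
```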
### Window

The `window` parameter corresponds to the size of the sliding window of past
snapshots from which the agent's opponents are sampled. For example, a `window`
size of 5 will keep the last 5 snapshots taken. Each time a new snapshot is
taken, the oldest is discarded.

A larger value of `window` means that an agent's pool of opponents will contain
a larger diversity of behaviors, since it will contain policies from earlier in
the training run. As with the `save_steps` hyperparameter, the agent trains
against a wider variety of opponents. Learning a policy to defeat more diverse
opponents is a harder problem and so may require more overall training steps,
but may also lead to a more general and robust policy at the end of training.

Recommended Range: 5 - 30
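The sliding window is implemented as a fixed-size circular buffer in
`GhostTrainer._save_snapshot` (later in this diff); a self-contained sketch of
the same mechanism:

```python
from typing import Any, List

class SnapshotWindow:
    """Keeps at most `window` snapshots, overwriting the oldest slot first."""

    def __init__(self, window: int) -> None:
        self.window = window
        self.snapshots: List[Any] = []
        self.counter = 0

    def save(self, weights: Any) -> None:
        if self.counter < len(self.snapshots):
            self.snapshots[self.counter] = weights  # overwrite the oldest slot
        else:
            self.snapshots.append(weights)  # still filling the window
        self.counter = (self.counter + 1) % self.window
```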
## Training Statistics

To view training statistics, use TensorBoard. For information on launching and
using TensorBoard, see
[here](./Getting-Started-with-Balance-Ball.md#observing-training-progress).
### ELO

In adversarial games, the cumulative environment reward may not be a meaningful
metric by which to track learning progress. This is because cumulative reward is
entirely dependent on the skill of the opponent. An agent at a particular skill
level will get more or less reward against a worse or better agent, respectively.

We provide an implementation of the ELO rating system, a method for calculating
the relative skill level between two players from a given population in a
zero-sum game. For more information on ELO, please see
[the ELO wiki](https://en.wikipedia.org/wiki/Elo_rating_system).

In a proper training run, the ELO of the agent should steadily increase. The
absolute value of the ELO is less important than the change in ELO over training
iterations.
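For reference, the update used here follows the standard Elo expected-score
formula: player 1's expected score is 1 / (1 + 10^((R2 - R1) / 400)), and the
change applied after a game is result - expected (the implementation introduced
in this change, `compute_elo_rating_changes`, applies it without an explicit
K-factor). A worked example:

```python
def expected_score(rating1: float, rating2: float) -> float:
    # Standard Elo expected score for player 1.
    return 1.0 / (1.0 + 10 ** ((rating2 - rating1) / 400.0))

# A win (result = 1.0) against an equally rated opponent: the expected score
# is 0.5, so the winner gains 0.5 and the loser drops 0.5.
print(1.0 - expected_score(1200.0, 1200.0))  # 0.5
```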
ml-agents/mlagents/trainers/behavior_id_utils.py

from typing import Dict, NamedTuple


class BehaviorIdentifiers(NamedTuple):
    name_behavior_id: str
    brain_name: str
    behavior_ids: Dict[str, int]

    @staticmethod
    def from_name_behavior_id(name_behavior_id: str) -> "BehaviorIdentifiers":
        """
        Parses a name_behavior_id of the form name?team=0&param1=i&...
        into a BehaviorIdentifiers NamedTuple.
        This allows you to access the brain name and distinguishing identifiers
        without parsing more than once.
        :param name_behavior_id: String of behavior params in HTTP format.
        :returns: A BehaviorIdentifiers object.
        """
        ids: Dict[str, int] = {}
        if "?" in name_behavior_id:
            name, identifiers = name_behavior_id.rsplit("?", 1)
            if "&" in identifiers:
                list_of_identifiers = identifiers.split("&")
            else:
                list_of_identifiers = [identifiers]

            for identifier in list_of_identifiers:
                key, value = identifier.split("=")
                ids[key] = int(value)
        else:
            name = name_behavior_id

        return BehaviorIdentifiers(
            name_behavior_id=name_behavior_id, brain_name=name, behavior_ids=ids
        )
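For illustration, parsing one of the behavior names used in the tests below
(`test_brain?team=0`) yields:

```python
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers

ids = BehaviorIdentifiers.from_name_behavior_id("test_brain?team=0")
assert ids.brain_name == "test_brain"
assert ids.behavior_ids == {"team": 0}
```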
ml-agents/mlagents/trainers/tests/test_ghost.py

import pytest

import numpy as np

import yaml

from mlagents.trainers.ghost.trainer import GhostTrainer
from mlagents.trainers.ppo.trainer import PPOTrainer
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.test_trajectory import make_fake_trajectory


@pytest.fixture
def dummy_config():
    return yaml.safe_load(
        """
        trainer: ppo
        batch_size: 32
        beta: 5.0e-3
        buffer_size: 512
        epsilon: 0.2
        hidden_units: 128
        lambd: 0.95
        learning_rate: 3.0e-4
        max_steps: 5.0e4
        normalize: true
        num_epoch: 5
        num_layers: 2
        time_horizon: 64
        sequence_length: 64
        summary_freq: 1000
        use_recurrent: false
        memory_size: 8
        curiosity_strength: 0.0
        curiosity_enc_size: 1
        summary_path: test
        model_path: test
        reward_signals:
            extrinsic:
                strength: 1.0
                gamma: 0.99
        self_play:
            window: 5
            play_against_current_self_ratio: 0.5
            save_steps: 1000
            swap_steps: 1000
        """
    )


VECTOR_ACTION_SPACE = [1]
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 513
NUM_AGENTS = 12


@pytest.mark.parametrize("use_discrete", [True, False])
def test_load_and_set(dummy_config, use_discrete):
    mock_brain = mb.setup_mock_brain(
        use_discrete,
        False,
        vector_action_space=VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
        discrete_action_space=DISCRETE_ACTION_SPACE,
    )

    trainer_params = dummy_config
    trainer = PPOTrainer(
        mock_brain.brain_name, 0, trainer_params, True, False, 0, "0", False
    )
    trainer.seed = 1
    policy = trainer.create_policy(mock_brain)
    trainer.seed = 20  # otherwise graphs are the same
    to_load_policy = trainer.create_policy(mock_brain)
    to_load_policy.init_load_weights()

    weights = policy.get_weights()
    load_weights = to_load_policy.get_weights()
    # The two policies are built with different seeds, so their weights are
    # expected to differ; an equality failure here is swallowed deliberately.
    try:
        for w, lw in zip(weights, load_weights):
            np.testing.assert_array_equal(w, lw)
    except AssertionError:
        pass

    to_load_policy.load_weights(weights)
    load_weights = to_load_policy.get_weights()

    for w, lw in zip(weights, load_weights):
        np.testing.assert_array_equal(w, lw)


def test_process_trajectory(dummy_config):
    brain_params_team0 = BrainParameters(
        brain_name="test_brain?team=0",
        vector_observation_space_size=1,
        camera_resolutions=[],
        vector_action_space_size=[2],
        vector_action_descriptions=[],
        vector_action_space_type=0,
    )

    brain_name = BehaviorIdentifiers.from_name_behavior_id(
        brain_params_team0.brain_name
    ).brain_name

    brain_params_team1 = BrainParameters(
        brain_name="test_brain?team=1",
        vector_observation_space_size=1,
        camera_resolutions=[],
        vector_action_space_size=[2],
        vector_action_descriptions=[],
        vector_action_space_type=0,
    )
    dummy_config["summary_path"] = "./summaries/test_trainer_summary"
    dummy_config["model_path"] = "./models/test_trainer_models/TestModel"
    ppo_trainer = PPOTrainer(brain_name, 0, dummy_config, True, False, 0, "0", False)
    trainer = GhostTrainer(ppo_trainer, brain_name, 0, dummy_config, True, "0")

    # first policy encountered becomes policy trained by wrapped PPO
    policy = trainer.create_policy(brain_params_team0)
    trainer.add_policy(brain_params_team0.brain_name, policy)
    trajectory_queue0 = AgentManagerQueue(brain_params_team0.brain_name)
    trainer.subscribe_trajectory_queue(trajectory_queue0)

    # Ghost trainer should ignore this queue because it is off-policy
    policy = trainer.create_policy(brain_params_team1)
    trainer.add_policy(brain_params_team1.brain_name, policy)
    trajectory_queue1 = AgentManagerQueue(brain_params_team1.brain_name)
    trainer.subscribe_trajectory_queue(trajectory_queue1)

    time_horizon = 15
    trajectory = make_fake_trajectory(
        length=time_horizon,
        max_step_complete=True,
        vec_obs_size=1,
        num_vis_obs=0,
        action_space=[2],
    )
    trajectory_queue0.put(trajectory)
    trainer.advance()

    # Check that trainer put trajectory in update buffer
    assert trainer.trainer.update_buffer.num_experiences == 15

    trajectory_queue1.put(trajectory)
    trainer.advance()

    # Check that ghost trainer ignored the off-policy queue
    assert trainer.trainer.update_buffer.num_experiences == 15


def test_publish_queue(dummy_config):
    brain_params_team0 = BrainParameters(
        brain_name="test_brain?team=0",
        vector_observation_space_size=8,
        camera_resolutions=[],
        vector_action_space_size=[1],
        vector_action_descriptions=[],
        vector_action_space_type=0,
    )

    brain_name = BehaviorIdentifiers.from_name_behavior_id(
        brain_params_team0.brain_name
    ).brain_name

    brain_params_team1 = BrainParameters(
        brain_name="test_brain?team=1",
        vector_observation_space_size=8,
        camera_resolutions=[],
        vector_action_space_size=[1],
        vector_action_descriptions=[],
        vector_action_space_type=0,
    )
    dummy_config["summary_path"] = "./summaries/test_trainer_summary"
    dummy_config["model_path"] = "./models/test_trainer_models/TestModel"
    ppo_trainer = PPOTrainer(brain_name, 0, dummy_config, True, False, 0, "0", False)
    trainer = GhostTrainer(ppo_trainer, brain_name, 0, dummy_config, True, "0")

    # First policy encountered becomes policy trained by wrapped PPO
    # This queue should remain empty after swap snapshot
    policy = trainer.create_policy(brain_params_team0)
    trainer.add_policy(brain_params_team0.brain_name, policy)
    policy_queue0 = AgentManagerQueue(brain_params_team0.brain_name)
    trainer.publish_policy_queue(policy_queue0)

    # Ghost trainer should use this queue for ghost policy swap
    policy = trainer.create_policy(brain_params_team1)
    trainer.add_policy(brain_params_team1.brain_name, policy)
    policy_queue1 = AgentManagerQueue(brain_params_team1.brain_name)
    trainer.publish_policy_queue(policy_queue1)

    # check ghost trainer swap pushes to ghost queue and not trainer
    assert policy_queue0.empty() and policy_queue1.empty()
    trainer._swap_snapshots()
    assert policy_queue0.empty() and not policy_queue1.empty()
    # clear
    policy_queue1.get_nowait()

    mock_brain = mb.setup_mock_brain(
        False,
        False,
        vector_action_space=VECTOR_ACTION_SPACE,
        vector_obs_space=VECTOR_OBS_SPACE,
        discrete_action_space=DISCRETE_ACTION_SPACE,
    )

    buffer = mb.simulate_rollout(BUFFER_INIT_SAMPLES, mock_brain)
    # Mock out reward signal eval
    buffer["extrinsic_rewards"] = buffer["environment_rewards"]
    buffer["extrinsic_returns"] = buffer["environment_rewards"]
    buffer["extrinsic_value_estimates"] = buffer["environment_rewards"]
    buffer["curiosity_rewards"] = buffer["environment_rewards"]
    buffer["curiosity_returns"] = buffer["environment_rewards"]
    buffer["curiosity_value_estimates"] = buffer["environment_rewards"]
    buffer["advantages"] = buffer["environment_rewards"]
    trainer.trainer.update_buffer = buffer

    # when the ghost trainer advances and the wrapped trainer's buffer is full,
    # the wrapped trainer pushes the updated policy to the correct queue
    assert policy_queue0.empty() and policy_queue1.empty()
    trainer.advance()
    assert not policy_queue0.empty() and policy_queue1.empty()


if __name__ == "__main__":
    pytest.main()
ml-agents/mlagents/trainers/ghost/trainer.py

# # Unity ML-Agents Toolkit
# ## ML-Agent Learning (Ghost Trainer)

import logging
from typing import Deque, Dict, List, Any, cast

import numpy as np

from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.policy import Policy
from mlagents.trainers.tf_policy import TFPolicy

from mlagents.trainers.trainer import Trainer
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.agent_processor import AgentManagerQueue

LOGGER = logging.getLogger("mlagents.trainers")


class GhostTrainer(Trainer):
    def __init__(
        self, trainer, brain_name, reward_buff_cap, trainer_parameters, training, run_id
    ):
        """
        Responsible for collecting experiences and training the wrapped trainer's model via self-play.
        :param trainer: The trainer of the policy/policies being trained with self_play
        :param brain_name: The name of the brain associated with trainer config
        :param reward_buff_cap: Max reward history to track in the reward buffer
        :param trainer_parameters: The parameters for the trainer (dictionary).
        :param training: Whether the trainer is set for training.
        :param run_id: The identifier of the current run
        """
        super(GhostTrainer, self).__init__(
            brain_name, trainer_parameters, training, run_id, reward_buff_cap
        )

        self.trainer = trainer

        self.internal_policy_queues: List[AgentManagerQueue[Policy]] = []
        self.internal_trajectory_queues: List[AgentManagerQueue[Trajectory]] = []
        self.learning_policy_queues: Dict[str, AgentManagerQueue[Policy]] = {}

        # assign ghost's stats collection to wrapped trainer's
        self.stats_reporter = self.trainer.stats_reporter

        self_play_parameters = trainer_parameters["self_play"]
        self.window = self_play_parameters.get("window", 10)
        self.play_against_current_self_ratio = self_play_parameters.get(
            "play_against_current_self_ratio", 0.5
        )
        self.steps_between_save = self_play_parameters.get("save_steps", 20000)
        self.steps_between_swap = self_play_parameters.get("swap_steps", 20000)

        self.policies: Dict[str, TFPolicy] = {}
        self.policy_snapshots: List[Any] = []
        self.snapshot_counter: int = 0
        self.learning_behavior_name: str = None
        self.current_policy_snapshot = None
        self.last_save = 0
        self.last_swap = 0

        # Chosen because it is the initial ELO in Chess
        self.initial_elo: float = self_play_parameters.get("initial_elo", 1200.0)
        self.current_elo: float = self.initial_elo
        self.policy_elos: List[float] = [self.initial_elo] * (
            self.window + 1
        )  # for learning policy
        self.current_opponent: int = 0

    @property
    def get_step(self) -> int:
        """
        Returns the number of steps the trainer has performed
        :return: the step count of the trainer
        """
        return self.trainer.get_step

    @property
    def reward_buffer(self) -> Deque[float]:
        """
        Returns the reward buffer. The reward buffer contains the cumulative
        rewards of the most recent episodes completed by agents using this
        trainer.
        :return: the reward buffer.
        """
        return self.trainer.reward_buffer

    def _write_summary(self, step: int) -> None:
        """
        Saves training statistics to Tensorboard.
        """
        opponents = np.array(self.policy_elos, dtype=np.float32)
        LOGGER.info(
            " Learning brain {} ELO: {:0.3f}\n"
            "Mean Opponent ELO: {:0.3f}"
            " Std Opponent ELO: {:0.3f}".format(
                self.learning_behavior_name,
                self.current_elo,
                opponents.mean(),
                opponents.std(),
            )
        )
        self.stats_reporter.add_stat("ELO", self.current_elo)

    def _process_trajectory(self, trajectory: Trajectory) -> None:
        if trajectory.done_reached and not trajectory.max_step_reached:
            # Assumption is that the final environment reward is +1/0/-1 for a
            # win/draw/loss, which maps to an Elo game result of 1.0/0.5/0.0
            final_reward = trajectory.steps[-1].reward
            result = 0.5
            if final_reward > 0:
                result = 1.0
            elif final_reward < 0:
                result = 0.0

            change = compute_elo_rating_changes(
                self.current_elo, self.policy_elos[self.current_opponent], result
            )
            self.current_elo += change
            self.policy_elos[self.current_opponent] -= change

    def _is_ready_update(self) -> bool:
        return False

    def _update_policy(self) -> None:
        pass

    def advance(self) -> None:
        """
        Steps the trainer, passing trajectories to wrapped trainer and calling trainer advance
        """
        for traj_queue, internal_traj_queue in zip(
            self.trajectory_queues, self.internal_trajectory_queues
        ):
            try:
                t = traj_queue.get_nowait()
                # adds to wrapped trainer's queue
                internal_traj_queue.put(t)
                self._process_trajectory(t)
            except AgentManagerQueue.Empty:
                pass

        self.next_summary_step = self.trainer.next_summary_step
        self.trainer.advance()
        self._maybe_write_summary(self.get_step)

        for internal_q in self.internal_policy_queues:
            # Get policies that correspond to the policy queue in question
            try:
                policy = cast(TFPolicy, internal_q.get_nowait())
                self.current_policy_snapshot = policy.get_weights()
                self.learning_policy_queues[internal_q.behavior_id].put(policy)
            except AgentManagerQueue.Empty:
                pass

        if self.get_step - self.last_save > self.steps_between_save:
            self._save_snapshot(self.trainer.policy)
            self.last_save = self.get_step

        if self.get_step - self.last_swap > self.steps_between_swap:
            self._swap_snapshots()
            self.last_swap = self.get_step

    def end_episode(self):
        self.trainer.end_episode()

    def save_model(self, name_behavior_id: str) -> None:
        self.trainer.save_model(name_behavior_id)

    def export_model(self, name_behavior_id: str) -> None:
        self.trainer.export_model(name_behavior_id)

    def create_policy(self, brain_parameters: BrainParameters) -> TFPolicy:
        return self.trainer.create_policy(brain_parameters)

    def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None:
        # for saving/swapping snapshots
        policy.init_load_weights()
        self.policies[name_behavior_id] = policy

        # First policy encountered
        if not self.learning_behavior_name:
            weights = policy.get_weights()
            self.current_policy_snapshot = weights
            self._save_snapshot(policy)
            self.trainer.add_policy(name_behavior_id, policy)
            self.learning_behavior_name = name_behavior_id

    def get_policy(self, name_behavior_id: str) -> TFPolicy:
        return self.policies[name_behavior_id]

    def _save_snapshot(self, policy: TFPolicy) -> None:
        weights = policy.get_weights()
        try:
            self.policy_snapshots[self.snapshot_counter] = weights
        except IndexError:
            self.policy_snapshots.append(weights)
        self.policy_elos[self.snapshot_counter] = self.current_elo
        self.snapshot_counter = (self.snapshot_counter + 1) % self.window

    def _swap_snapshots(self) -> None:
        for q in self.policy_queues:
            name_behavior_id = q.behavior_id
            # here is the place for a sampling protocol
            if name_behavior_id == self.learning_behavior_name:
                continue
            elif np.random.uniform() < (1 - self.play_against_current_self_ratio):
                x = np.random.randint(len(self.policy_snapshots))
                snapshot = self.policy_snapshots[x]
            else:
                snapshot = self.current_policy_snapshot
                x = "current"
            self.policy_elos[-1] = self.current_elo
            self.current_opponent = -1 if x == "current" else x
            LOGGER.debug(
                "Step {}: Swapping snapshot {} to id {} with {} learning".format(
                    self.get_step, x, name_behavior_id, self.learning_behavior_name
                )
            )
            policy = self.get_policy(name_behavior_id)
            policy.load_weights(snapshot)
            q.put(policy)

    def publish_policy_queue(self, policy_queue: AgentManagerQueue[Policy]) -> None:
        """
        Adds a policy queue to the list of queues to publish to when this Trainer
        makes a policy update
        :param queue: Policy queue to publish to.
        """
        super().publish_policy_queue(policy_queue)
        if policy_queue.behavior_id == self.learning_behavior_name:
            internal_policy_queue: AgentManagerQueue[Policy] = AgentManagerQueue(
                policy_queue.behavior_id
            )
            self.internal_policy_queues.append(internal_policy_queue)
            self.learning_policy_queues[policy_queue.behavior_id] = policy_queue
            self.trainer.publish_policy_queue(internal_policy_queue)

    def subscribe_trajectory_queue(
        self, trajectory_queue: AgentManagerQueue[Trajectory]
    ) -> None:
        """
        Adds a trajectory queue to the list of queues for the trainer to ingest Trajectories from.
        :param queue: Trajectory queue to publish to.
        """
        if trajectory_queue.behavior_id == self.learning_behavior_name:
            super().subscribe_trajectory_queue(trajectory_queue)

            internal_trajectory_queue: AgentManagerQueue[
                Trajectory
            ] = AgentManagerQueue(trajectory_queue.behavior_id)

            self.internal_trajectory_queues.append(internal_trajectory_queue)
            self.trainer.subscribe_trajectory_queue(internal_trajectory_queue)


# Taken from https://github.com/Unity-Technologies/ml-agents/pull/1975 and
# https://metinmediamath.wordpress.com/2013/11/27/how-to-calculate-the-elo-rating-including-example/
# ELO calculation


def compute_elo_rating_changes(rating1: float, rating2: float, result: float) -> float:
    r1 = pow(10, rating1 / 400)
    r2 = pow(10, rating2 / 400)

    summed = r1 + r2
    e1 = r1 / summed

    change = result - e1
    return change
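A quick sanity check of `compute_elo_rating_changes` (illustrative, not part of
the file): a win against an equally rated opponent yields a change of +0.5,
which `_process_trajectory` adds to the learner's rating and subtracts from the
opponent's.

```python
from mlagents.trainers.ghost.trainer import compute_elo_rating_changes

delta = compute_elo_rating_changes(rating1=1200.0, rating2=1200.0, result=1.0)
assert abs(delta - 0.5) < 1e-9  # expected score against an equal opponent is 0.5
```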