
[change] Clean up trainer interface, clean up GhostTrainer stats (#3634)

/bug-failed-api-check
GitHub · 4 years ago
Current commit 6709a9bf
10 files changed, with 194 insertions and 146 deletions
  1. ml-agents/mlagents/trainers/agent_processor.py (4 changes)
  2. ml-agents/mlagents/trainers/ghost/trainer.py (42 changes)
  3. ml-agents/mlagents/trainers/ppo/trainer.py (11 changes)
  4. ml-agents/mlagents/trainers/sac/trainer.py (11 changes)
  5. ml-agents/mlagents/trainers/stats.py (27 changes)
  6. ml-agents/mlagents/trainers/tests/test_learn.py (3 changes)
  7. ml-agents/mlagents/trainers/tests/test_rl_trainer.py (7 changes)
  8. ml-agents/mlagents/trainers/tests/test_stats.py (27 changes)
  9. ml-agents/mlagents/trainers/trainer/rl_trainer.py (102 changes)
  10. ml-agents/mlagents/trainers/trainer/trainer.py (106 changes)

ml-agents/mlagents/trainers/agent_processor.py (4 changes)


self.experience_buffers[global_id] = []
if curr_agent_step.done:
self.stats_reporter.add_stat(
"Environment/Cumulative Reward",
self.episode_rewards.get(global_id, 0),
)
self.stats_reporter.add_stat(
"Environment/Episode Length",
self.episode_steps.get(global_id, 0),
)
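
For context, the add_stat calls above go through the StatsReporter facade from mlagents.trainers.stats: a reporter is created per category, per-episode values are accumulated with add_stat, and every registered StatsWriter receives the aggregated summaries on write_stats. A minimal sketch of that flow, assuming the StatsReporter.add_writer classmethod used by learn.py at this version; the "3DBall" category name is just an example:

    from mlagents.trainers.stats import StatsReporter, ConsoleWriter

    # Writers are shared by all reporters via the class-level StatsReporter.writers list.
    StatsReporter.add_writer(ConsoleWriter())

    # One reporter per category (trainers pass their summary_path here).
    reporter = StatsReporter("3DBall")

    # Values accumulated as episodes finish...
    reporter.add_stat("Environment/Cumulative Reward", 1.5)
    reporter.add_stat("Environment/Episode Length", 42)

    # ...are aggregated (mean/std) and flushed to every registered writer at a summary step.
    reporter.write_stats(10000)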

ml-agents/mlagents/trainers/ghost/trainer.py (42 changes)


from mlagents.trainers.trainer import Trainer
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.stats import StatsPropertyType
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
logger = logging.getLogger("mlagents.trainers")

self.learning_policy_queues: Dict[str, AgentManagerQueue[Policy]] = {}
# assign ghost's stats collection to wrapped trainer's
- self.stats_reporter = self.trainer.stats_reporter
+ self._stats_reporter = self.trainer.stats_reporter
# Set the logging to print ELO in the console
self._stats_reporter.add_property(StatsPropertyType.SELF_PLAY, True)
self_play_parameters = trainer_parameters["self_play"]
self.window = self_play_parameters.get("window", 10)

"""
return self.trainer.reward_buffer
def _write_summary(self, step: int) -> None:
"""
Saves training statistics to Tensorboard.
"""
opponents = np.array(self.policy_elos, dtype=np.float32)
logger.info(
" Learning brain {} ELO: {:0.3f}\n"
"Mean Opponent ELO: {:0.3f}"
" Std Opponent ELO: {:0.3f}".format(
self.learning_behavior_name,
self.current_elo,
opponents.mean(),
opponents.std(),
)
)
self.stats_reporter.add_stat("ELO", self.current_elo)
def _process_trajectory(self, trajectory: Trajectory) -> None:
if trajectory.done_reached and not trajectory.max_step_reached:
# Assumption is that final reward is 1/.5/0 for win/draw/loss

)
self.current_elo += change
self.policy_elos[self.current_opponent] -= change
def _is_ready_update(self) -> bool:
return False
def _update_policy(self) -> None:
pass
opponents = np.array(self.policy_elos, dtype=np.float32)
self._stats_reporter.add_stat("Self-play/ELO", self.current_elo)
self._stats_reporter.add_stat(
"Self-play/Mean Opponent ELO", opponents.mean()
)
self._stats_reporter.add_stat("Self-play/Std Opponent ELO", opponents.std())
def advance(self) -> None:
"""

pass
self.next_summary_step = self.trainer.next_summary_step
self._maybe_write_summary(self.get_step)
for internal_q in self.internal_policy_queues:
# Get policies that correspond to the policy queue in question

self.trainer.add_policy(name_behavior_id, policy)
self._save_snapshot(policy) # Need to save after trainer initializes policy
self.learning_behavior_name = name_behavior_id
behavior_id_parsed = BehaviorIdentifiers.from_name_behavior_id(
self.learning_behavior_name
)
team_id = behavior_id_parsed.behavior_ids["team"]
self._stats_reporter.add_property(StatsPropertyType.SELF_PLAY_TEAM, team_id)
else:
# for saving/swapping snapshots
policy.init_load_weights()
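
Net effect of this hunk: the GhostTrainer no longer overrides _write_summary to log ELO itself. It reuses the wrapped trainer's reporter (self._stats_reporter = self.trainer.stats_reporter), flags the run as self-play via StatsPropertyType.SELF_PLAY and SELF_PLAY_TEAM so the ConsoleWriter knows to print ELO, and emits the ELO values as ordinary stats from _process_trajectory. A condensed, illustrative sketch; GhostTrainerSketch and the 1200.0 starting rating are placeholders, not code from the PR:

    import numpy as np
    from mlagents.trainers.stats import StatsPropertyType

    class GhostTrainerSketch:
        def __init__(self, trainer, team_id):
            self.trainer = trainer
            # Share the wrapped trainer's reporter so ghost stats land in the same category.
            self._stats_reporter = trainer.stats_reporter
            # Tell the ConsoleWriter to print ELO lines for this team.
            self._stats_reporter.add_property(StatsPropertyType.SELF_PLAY, True)
            self._stats_reporter.add_property(StatsPropertyType.SELF_PLAY_TEAM, team_id)
            self.current_elo = 1200.0    # placeholder starting rating
            self.policy_elos = [1200.0]  # ratings of saved opponent snapshots

        def _process_trajectory(self, trajectory):
            # ... ELO update against the current opponent elided ...
            opponents = np.array(self.policy_elos, dtype=np.float32)
            self._stats_reporter.add_stat("Self-play/ELO", self.current_elo)
            self._stats_reporter.add_stat("Self-play/Mean Opponent ELO", opponents.mean())
            self._stats_reporter.add_stat("Self-play/Std Opponent ELO", opponents.std())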

ml-agents/mlagents/trainers/ppo/trainer.py (11 changes)


super()._process_trajectory(trajectory)
agent_id = trajectory.agent_id # All the agents should have the same ID
# Add to episode_steps
self.episode_steps[agent_id] += len(trajectory.steps)
agent_buffer_trajectory = trajectory.to_agentbuffer()
# Update the normalization
if self.is_training:

)
for name, v in value_estimates.items():
agent_buffer_trajectory["{}_value_estimates".format(name)].extend(v)
- self.stats_reporter.add_stat(
+ self._stats_reporter.add_stat(
self.optimizer.reward_signals[name].value_name, np.mean(v)
)

batch_update_stats[stat_name].append(value)
for stat, stat_list in batch_update_stats.items():
- self.stats_reporter.add_stat(stat, np.mean(stat_list))
+ self._stats_reporter.add_stat(stat, np.mean(stat_list))
- self.stats_reporter.add_stat(stat, val)
- self.clear_update_buffer()
+ self._stats_reporter.add_stat(stat, val)
+ self._clear_update_buffer()
def create_policy(self, brain_parameters: BrainParameters) -> TFPolicy:
"""

ml-agents/mlagents/trainers/sac/trainer.py (11 changes)


last_step = trajectory.steps[-1]
agent_id = trajectory.agent_id # All the agents should have the same ID
# Add to episode_steps
self.episode_steps[agent_id] += len(trajectory.steps)
agent_buffer_trajectory = trajectory.to_agentbuffer()
# Update the normalization

agent_buffer_trajectory, trajectory.next_obs, trajectory.done_reached
)
for name, v in value_estimates.items():
- self.stats_reporter.add_stat(
+ self._stats_reporter.add_stat(
self.optimizer.reward_signals[name].value_name, np.mean(v)
)

)
for stat, stat_list in batch_update_stats.items():
- self.stats_reporter.add_stat(stat, np.mean(stat_list))
+ self._stats_reporter.add_stat(stat, np.mean(stat_list))
- self.stats_reporter.add_stat(stat, val)
+ self._stats_reporter.add_stat(stat, val)
def update_reward_signals(self) -> None:
"""

for stat_name, value in update_stats.items():
batch_update_stats[stat_name].append(value)
for stat, stat_list in batch_update_stats.items():
- self.stats_reporter.add_stat(stat, np.mean(stat_list))
+ self._stats_reporter.add_stat(stat, np.mean(stat_list))
def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None:
"""

ml-agents/mlagents/trainers/stats.py (27 changes)


class StatsPropertyType(Enum):
HYPERPARAMETERS = "hyperparameters"
SELF_PLAY = "selfplay"
SELF_PLAY_TEAM = "selfplayteam"
class StatsWriter(abc.ABC):

class ConsoleWriter(StatsWriter):
def __init__(self):
self.training_start_time = time.time()
# If self-play, we want to print ELO as well as reward
self.self_play = False
self.self_play_team = -1
def write_stats(
self, category: str, values: Dict[str, StatsSummary], step: int

stats_summary = values["Is Training"]
if stats_summary.mean > 0.0:
is_training = "Training."
if "Environment/Cumulative Reward" in values:
stats_summary = values["Environment/Cumulative Reward"]
logger.info(

is_training,
)
)
if self.self_play and "Self-play/ELO" in values:
elo_stats = values["Self-play/ELO"]
mean_opponent_elo = values["Self-play/Mean Opponent ELO"]
std_opponent_elo = values["Self-play/Std Opponent ELO"]
logger.info(
"{} Team {}: ELO: {:0.3f}. "
"Mean Opponent ELO: {:0.3f}. "
"Std Opponent ELO: {:0.3f}. ".format(
category,
self.self_play_team,
elo_stats.mean,
mean_opponent_elo.mean,
std_opponent_elo.mean,
)
)
else:
logger.info(
"{}: Step: {}. No episode was completed since last summary. {}".format(

category, self._dict_to_str(value, 0)
)
)
elif property_type == StatsPropertyType.SELF_PLAY:
assert isinstance(value, bool)
self.self_play = value
elif property_type == StatsPropertyType.SELF_PLAY_TEAM:
assert isinstance(value, int)
self.self_play_team = value
def _dict_to_str(self, param_dict: Dict[str, Any], num_tabs: int) -> str:
"""

ml-agents/mlagents/trainers/tests/test_learn.py (3 changes)


from mlagents.trainers.trainer_controller import TrainerController
from mlagents.trainers.learn import parse_command_line
from mlagents_envs.exception import UnityEnvironmentException
from mlagents.trainers.stats import StatsReporter
def basic_options(extra_args=None):

sampler_manager_mock.return_value,
None,
)
StatsReporter.writers.clear() # make sure there aren't any writers as added by learn.py
@patch("mlagents.trainers.learn.SamplerManager")

mock_init.assert_called_once()
assert mock_init.call_args[0][1] == "/dockertarget/models/ppo"
assert mock_init.call_args[0][2] == "/dockertarget/summaries"
StatsReporter.writers.clear() # make sure there aren't any writers as added by learn.py
def test_bad_env_path():

ml-agents/mlagents/trainers/tests/test_rl_trainer.py (7 changes)


def test_rl_trainer():
trainer = create_rl_trainer()
agent_id = "0"
trainer.episode_steps[agent_id] = 3
for agent_id in trainer.episode_steps:
assert trainer.episode_steps[agent_id] == 0
for rewards in trainer.collected_rewards.values():
for agent_id in rewards:
assert rewards[agent_id] == 0

trainer = create_rl_trainer()
trainer.update_buffer = construct_fake_buffer(0)
- trainer.clear_update_buffer()
+ trainer._clear_update_buffer()
- @mock.patch("mlagents.trainers.trainer.rl_trainer.RLTrainer.clear_update_buffer")
+ @mock.patch("mlagents.trainers.trainer.rl_trainer.RLTrainer._clear_update_buffer")
def test_advance(mocked_clear_update_buffer):
trainer = create_rl_trainer()
trajectory_queue = AgentManagerQueue("testbrain")

ml-agents/mlagents/trainers/tests/test_stats.py (27 changes)


self.assertIn("Hyperparameters for behavior name", cm.output[2])
self.assertIn("example:\t1.0", cm.output[2])
def test_selfplay_console_writer(self):
with self.assertLogs("mlagents.trainers", level="INFO") as cm:
category = "category1"
console_writer = ConsoleWriter()
console_writer.add_property(category, StatsPropertyType.SELF_PLAY, True)
console_writer.add_property(category, StatsPropertyType.SELF_PLAY_TEAM, 1)
statssummary1 = StatsSummary(mean=1.0, std=1.0, num=1)
console_writer.write_stats(
category,
{
"Environment/Cumulative Reward": statssummary1,
"Is Training": statssummary1,
"Self-play/ELO": statssummary1,
"Self-play/Mean Opponent ELO": statssummary1,
"Self-play/Std Opponent ELO": statssummary1,
},
10,
)
self.assertIn(
"Mean Reward: 1.000. Std of Reward: 1.000. Training.", cm.output[0]
)
self.assertIn(
"category1 Team 1: ELO: 1.000. Mean Opponent ELO: 1.000. Std Opponent ELO: 1.000.",
cm.output[1],
)

ml-agents/mlagents/trainers/trainer/rl_trainer.py (102 changes)


# # Unity ML-Agents Toolkit
- from typing import Dict
+ from typing import Dict, List
import abc
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
from mlagents.trainers.buffer import AgentBuffer

from mlagents_envs.timers import hierarchical_timer
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.stats import StatsPropertyType
RewardSignalResults = Dict[str, RewardSignalResult]

# collected_rewards is a dictionary from name of reward signal to a dictionary of agent_id to cumulative reward
# used for reporting only. We always want to report the environment reward to Tensorboard, regardless
# of what reward signals are actually present.
self.cumulative_returns_since_policy_update: List[float] = []
self.episode_steps: Dict[str, int] = defaultdict(lambda: 0)
self._stats_reporter.add_property(
StatsPropertyType.HYPERPARAMETERS, self.trainer_parameters
)
def end_episode(self) -> None:
"""

for agent_id in self.episode_steps:
self.episode_steps[agent_id] = 0
self.episode_steps[agent_id] = 0
self.stats_reporter.add_stat(
"Environment/Cumulative Reward", rewards.get(agent_id, 0)
)
self.cumulative_returns_since_policy_update.append(
rewards.get(agent_id, 0)
)

)
rewards[agent_id] = 0
- def clear_update_buffer(self) -> None:
+ def _clear_update_buffer(self) -> None:
@abc.abstractmethod
def _is_ready_update(self):
"""
Returns whether or not the trainer has enough elements to run update model
:return: A boolean corresponding to whether or not update_model() can be run
"""
return False
@abc.abstractmethod
def _update_policy(self):
"""
Uses demonstration_buffer to update model.
"""
pass
def _increment_step(self, n_steps: int, name_behavior_id: str) -> None:
"""
Increment the step count of the trainer
:param n_steps: number of steps to increment the step count by
"""
self.step += n_steps
self.next_summary_step = self._get_next_summary_step()
p = self.get_policy(name_behavior_id)
if p:
p.increment_step(n_steps)
def _get_next_summary_step(self) -> int:
"""
Get the next step count that should result in a summary write.
"""
return self.step + (self.summary_freq - self.step % self.summary_freq)
def _write_summary(self, step: int) -> None:
"""
Saves training statistics to Tensorboard.
"""
self.stats_reporter.add_stat("Is Training", float(self.should_still_train))
self.stats_reporter.write_stats(int(step))
@abc.abstractmethod
def _process_trajectory(self, trajectory: Trajectory) -> None:
"""
Takes a trajectory and processes it, putting it into the update buffer.
:param trajectory: The Trajectory tuple containing the steps to be processed.
"""
self._maybe_write_summary(self.get_step + len(trajectory.steps))
self._increment_step(len(trajectory.steps), trajectory.behavior_id)
def _maybe_write_summary(self, step_after_process: int) -> None:
"""
If processing the trajectory will make the step exceed the next summary write,
write the summary. This logic ensures summaries are written on the update step and not in between.
:param step_after_process: the step count after processing the next trajectory.
"""
if step_after_process >= self.next_summary_step and self.get_step != 0:
self._write_summary(self.next_summary_step)
- Steps the trainer, taking in trajectories and updates if ready
+ Steps the trainer, taking in trajectories and updates if ready.
super().advance()
if not self.should_still_train:
self.clear_update_buffer()
with hierarchical_timer("process_trajectory"):
for traj_queue in self.trajectory_queues:
# We grab at most the maximum length of the queue.
# This ensures that even if the queue is being filled faster than it is
# being emptied, the trajectories in the queue are on-policy.
for _ in range(traj_queue.maxlen):
try:
t = traj_queue.get_nowait()
self._process_trajectory(t)
except AgentManagerQueue.Empty:
break
if self.should_still_train:
if self._is_ready_update():
with hierarchical_timer("_update_policy"):
self._update_policy()
for q in self.policy_queues:
# Get policies that correspond to the policy queue in question
q.put(self.get_policy(q.behavior_id))
else:
self._clear_update_buffer()
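
After the move, RLTrainer owns the whole step/summary cycle as a template method: advance() drains the trajectory queues, each _process_trajectory override calls back into _maybe_write_summary and _increment_step, and _write_summary flushes the reporter at the precomputed next_summary_step. A stripped-down sketch of that control flow using the names from the diff, with queue and policy plumbing elided, _increment_step inlined, and print standing in for the StatsReporter call:

    class RLTrainerSketch:
        def __init__(self, summary_freq=1000):
            self.summary_freq = summary_freq
            self.step = 0
            self.next_summary_step = self._get_next_summary_step()

        def _get_next_summary_step(self):
            # Next multiple of summary_freq strictly greater than the current step.
            return self.step + (self.summary_freq - self.step % self.summary_freq)

        def _write_summary(self, step):
            print(f"summary at step {step}")  # stands in for stats_reporter.write_stats(step)

        def _maybe_write_summary(self, step_after_process):
            # Write on the boundary step itself, never partway between updates.
            if step_after_process >= self.next_summary_step and self.step != 0:
                self._write_summary(self.next_summary_step)

        def _process_trajectory(self, trajectory_len):
            self._maybe_write_summary(self.step + trajectory_len)
            # Real code does the two lines below in _increment_step(...).
            self.step += trajectory_len
            self.next_summary_step = self._get_next_summary_step()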

ml-agents/mlagents/trainers/trainer/trainer.py (106 changes)


# # Unity ML-Agents Toolkit
import logging
from typing import Dict, List, Deque, Any
import time
import abc
from collections import deque

- from mlagents.trainers.stats import StatsReporter, StatsPropertyType
+ from mlagents.trainers.stats import StatsReporter
from mlagents_envs.timers import hierarchical_timer
logger = logging.getLogger("mlagents.trainers")

self.run_id = run_id
self.trainer_parameters = trainer_parameters
self.summary_path = trainer_parameters["summary_path"]
- self.stats_reporter = StatsReporter(self.summary_path)
- self.cumulative_returns_since_policy_update: List[float] = []
+ self._stats_reporter = StatsReporter(self.summary_path)
self.training_start_time = time.time()
- self.stats_reporter.add_property(
- StatsPropertyType.HYPERPARAMETERS, self.trainer_parameters
- )
+ @property
+ def stats_reporter(self):
+ """
+ Returns the stats reporter associated with this Trainer.
+ """
+ return self._stats_reporter
def _check_param_keys(self):
for k in self.param_keys:

"""
return self._reward_buffer
def _increment_step(self, n_steps: int, name_behavior_id: str) -> None:
"""
Increment the step count of the trainer
:param n_steps: number of steps to increment the step count by
"""
self.step += n_steps
self.next_summary_step = self._get_next_summary_step()
p = self.get_policy(name_behavior_id)
if p:
p.increment_step(n_steps)
def _get_next_summary_step(self) -> int:
"""
Get the next step count that should result in a summary write.
"""
return self.step + (self.summary_freq - self.step % self.summary_freq)
def save_model(self, name_behavior_id: str) -> None:
"""
Saves the model

settings = SerializationSettings(policy.model_path, policy.brain.brain_name)
export_policy_model(settings, policy.graph, policy.sess)
def _write_summary(self, step: int) -> None:
"""
Saves training statistics to Tensorboard.
"""
self.stats_reporter.add_stat("Is Training", float(self.should_still_train))
self.stats_reporter.write_stats(int(step))
@abc.abstractmethod
def _process_trajectory(self, trajectory: Trajectory) -> None:
"""
Takes a trajectory and processes it, putting it into the update buffer.
:param trajectory: The Trajectory tuple containing the steps to be processed.
"""
self._maybe_write_summary(self.get_step + len(trajectory.steps))
self._increment_step(len(trajectory.steps), trajectory.behavior_id)
def _maybe_write_summary(self, step_after_process: int) -> None:
"""
If processing the trajectory will make the step exceed the next summary write,
write the summary. This logic ensures summaries are written on the update step and not in between.
:param step_after_process: the step count after processing the next trajectory.
"""
if step_after_process >= self.next_summary_step and self.get_step != 0:
self._write_summary(self.next_summary_step)
@abc.abstractmethod
def end_episode(self):
"""

@abc.abstractmethod
def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None:
"""
- Adds policy to trainer
+ Adds policy to trainer.
"""
pass

- Gets policy from trainer
+ Gets policy from trainer.
- def _is_ready_update(self):
+ def advance(self) -> None:
- Returns whether or not the trainer has enough elements to run update model
- :return: A boolean corresponding to whether or not update_model() can be run
"""
- return False
@abc.abstractmethod
def _update_policy(self):
"""
Uses demonstration_buffer to update model.
Advances the trainer. Typically, this means grabbing trajectories
from all subscribed trajectory queues (self.trajectory_queues), and updating
a policy using the steps in them, and if needed pushing a new policy onto the right
policy queues (self.policy_queues).
def advance(self) -> None:
"""
Steps the trainer, taking in trajectories and updates if ready.
"""
with hierarchical_timer("process_trajectory"):
for traj_queue in self.trajectory_queues:
# We grab at most the maximum length of the queue.
# This ensures that even if the queue is being filled faster than it is
# being emptied, the trajectories in the queue are on-policy.
for _ in range(traj_queue.maxlen):
try:
t = traj_queue.get_nowait()
self._process_trajectory(t)
except AgentManagerQueue.Empty:
break
if self.should_still_train:
if self._is_ready_update():
with hierarchical_timer("_update_policy"):
self._update_policy()
for q in self.policy_queues:
# Get policies that correspond to the policy queue in question
q.put(self.get_policy(q.behavior_id))
- :param queue: Policy queue to publish to.
+ :param policy_queue: Policy queue to publish to.
"""
self.policy_queues.append(policy_queue)

"""
Adds a trajectory queue to the list of queues for the trainer to ingest Trajectories from.
- :param queue: Trajectory queue to publish to.
+ :param trajectory_queue: Trajectory queue to read from.
"""
self.trajectory_queues.append(trajectory_queue)
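
With the per-update machinery moved down into RLTrainer, the base Trainer is left as a thin abstract interface: the read-only stats_reporter property plus abstract hooks that concrete trainers (and the GhostTrainer wrapper) implement, alongside the existing queue plumbing, save_model, and create_policy. A sketch of the resulting contract, condensed from the lines above; TrainerContract is an illustrative name, not the real class, and the docstrings are paraphrased:

    import abc

    class TrainerContract(abc.ABC):
        @property
        def stats_reporter(self):
            """Read-only access to this trainer's StatsReporter."""
            return self._stats_reporter

        @abc.abstractmethod
        def end_episode(self):
            """Resets episode-level bookkeeping when the environment resets."""

        @abc.abstractmethod
        def add_policy(self, name_behavior_id, policy):
            """Adds policy to trainer."""

        @abc.abstractmethod
        def get_policy(self, name_behavior_id):
            """Gets policy from trainer."""

        @abc.abstractmethod
        def advance(self):
            """Pull trajectories from self.trajectory_queues, update the policy if
            ready, and push updated policies onto self.policy_queues."""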