浏览代码

Move private methods out of trainer, simplify interface

/develop/trainerinterface
Ervin Teng 5 年前
当前提交
db743971
共有 6 个文件被更改,包括 184 次插入177 次删除
  1. 2
      ml-agents/mlagents/trainers/ppo/trainer.py
  2. 168
      ml-agents/mlagents/trainers/rl_trainer.py
  3. 2
      ml-agents/mlagents/trainers/sac/trainer.py
  4. 4
      ml-agents/mlagents/trainers/tests/test_trainer_controller.py
  5. 175
      ml-agents/mlagents/trainers/trainer.py
  6. 10
      ml-agents/mlagents/trainers/trainer_controller.py

2
ml-agents/mlagents/trainers/ppo/trainer.py


def _update_policy(self):
"""
Uses demonstration_buffer to update the policy.
Uses update buffer to update the policy.
The reward signal generators must be updated in this method at their own pace.
"""
buffer_length = self.update_buffer.num_experiences

168
ml-agents/mlagents/trainers/rl_trainer.py


# # Unity ML-Agents Toolkit
import logging
from typing import Dict
from typing import Dict, List, Any
import abc
import time
from mlagents.tf_utils import tf
from mlagents_envs.timers import set_gauge
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents_envs.timers import hierarchical_timer
LOGGER = logging.getLogger("mlagents.trainers")

class RLTrainer(Trainer): # pylint: disable=abstract-method
class RLTrainer(Trainer, abc.ABC): # pylint: disable=abstract-method
"""
This class is the base class for trainers that use Reward Signals.
"""

self.param_keys: List[str] = []
self.cumulative_returns_since_policy_update: List[float] = []
self.step: int = 0
self.training_start_time = time.time()
self.summary_freq = self.trainer_parameters["summary_freq"]
self.next_update_step = self.summary_freq
# Make sure we have at least one reward_signal
if not self.trainer_parameters["reward_signals"]:
raise UnityTrainerException(

}
self.update_buffer: AgentBuffer = AgentBuffer()
self.episode_steps: Dict[str, int] = defaultdict(lambda: 0)
# Write hyperparameters to Tensorboard
if self.is_training:
self.write_tensorboard_text("Hyperparameters", self.trainer_parameters)
def _check_param_keys(self):
for k in self.param_keys:
if k not in self.trainer_parameters:
raise UnityTrainerException(
"The hyper-parameter {0} could not be found for the {1} trainer of "
"brain {2}.".format(k, self.__class__, self.brain_name)
)
def _increment_step(self, n_steps: int, name_behavior_id: str) -> None:
"""
Increment the step count of the trainer
:param n_steps: number of steps to increment the step count by
"""
self.step += n_steps
self.next_update_step = self.step + (
self.summary_freq - self.step % self.summary_freq
)
p = self.get_policy(name_behavior_id)
if p:
p.increment_step(n_steps)
def _write_summary(self, step: int) -> None:
"""
Saves training statistics to Tensorboard.
"""
is_training = "Training." if self.training_progress < 1.0 else "Not Training."
stats_summary = self.stats_reporter.get_stats_summaries(
"Environment/Cumulative Reward"
)
if stats_summary.num > 0:
LOGGER.info(
" {}: {}: Step: {}. "
"Time Elapsed: {:0.3f} s "
"Mean "
"Reward: {:0.3f}"
". Std of Reward: {:0.3f}. {}".format(
self.run_id,
self.brain_name,
step,
time.time() - self.training_start_time,
stats_summary.mean,
stats_summary.std,
is_training,
)
)
set_gauge(f"{self.brain_name}.mean_reward", stats_summary.mean)
else:
LOGGER.info(
" {}: {}: Step: {}. No episode was completed since last summary. {}".format(
self.run_id, self.brain_name, step, is_training
)
)
self.stats_reporter.write_stats(int(step))
def _maybe_write_summary(self, step_after_process: int) -> None:
"""
If processing the trajectory will make the step exceed the next summary write,
write the summary. This logic ensures summaries are written on the update step and not in between.
:param step_after_process: the step count after processing the next trajectory.
"""
if step_after_process >= self.next_update_step and self.step != 0:
self._write_summary(self.next_update_step)
def end_episode(self) -> None:
"""

"""
Steps the trainer, taking in trajectories and updates if ready
"""
super().advance()
with hierarchical_timer("process_trajectory"):
for traj_queue in self.trajectory_queues:
try:
t = traj_queue.get_nowait()
self._process_trajectory(t)
except AgentManagerQueue.Empty:
pass
if self.training_progress < 1.0:
if self._is_ready_update():
with hierarchical_timer("update_policy"):
self._update_policy()
for q in self.policy_queues:
# Get policies that correspond to the policy queue in question
q.put(self.get_policy(q.behavior_id))
@property
def training_progress(self) -> float:
"""
Returns a float between 0 and 1 indicating how far along in the training progress the Trainer is.
If 1, the Trainer wasn't training to begin with, or max_steps
is reached.
"""
if self.is_training:
return min(self.step / float(self.trainer_parameters["max_steps"]), 1.0)
else:
return 1.0
def write_tensorboard_text(self, key: str, input_dict: Dict[str, Any]) -> None:
"""
Saves text to Tensorboard.
Note: Only works on tensorflow r1.2 or above.
:param key: The name of the text.
:param input_dict: A dictionary that will be displayed in a table on Tensorboard.
"""
try:
with tf.Session() as sess:
s_op = tf.summary.text(
key,
tf.convert_to_tensor(
([[str(x), str(input_dict[x])] for x in input_dict])
),
)
s = sess.run(s_op)
self.stats_reporter.write_text(s, self.step)
except Exception:
LOGGER.info("Could not write text summary for Tensorboard.")
pass
def save_model(self, name_behavior_id: str) -> None:
"""
Saves the model
"""
self.get_policy(name_behavior_id).save_model(self.step)
def export_model(self, name_behavior_id: str) -> None:
"""
Exports the model
"""
self.get_policy(name_behavior_id).export_model()
@abc.abstractmethod
def _update_policy(self) -> None:
"""
Uses update buffer to update the policy.
The reward signal generators must be updated in this method at their own pace.
"""
pass
@abc.abstractmethod
def _process_trajectory(self, trajectory: Trajectory) -> None:
"""
Takes a trajectory and processes it, putting it into the update buffer.
:param trajectory: The Trajectory tuple containing the steps to be processed.
"""
self._maybe_write_summary(self.step + len(trajectory.steps))
self._increment_step(len(trajectory.steps), trajectory.behavior_id)
@abc.abstractmethod
def _is_ready_update(self):
"""
Returns whether or not the trainer has enough elements to run update model
:return: A boolean corresponding to wether or not update_model() can be run
"""
return False

2
ml-agents/mlagents/trainers/sac/trainer.py


Saves the model. Overrides the default save_model since we want to save
the replay buffer as well.
"""
self.policy.save_model(self.get_step)
super().save_model(name_behavior_id)
if self.checkpoint_replay_buffer:
self.save_replay_buffer()

4
ml-agents/mlagents/trainers/tests/test_trainer_controller.py


trainer_mock = MagicMock()
trainer_mock.get_step = 0
trainer_mock.get_max_steps = 5
trainer_mock.should_still_train = True
trainer_mock.training_progress = 0.0
trainer_mock.parameters = {"some": "parameter"}
trainer_mock.write_tensorboard_text = MagicMock()

not tc.trainers["testbrain"].get_step
<= tc.trainers["testbrain"].get_max_steps
):
tc.trainers["testbrain"].should_still_train = False
tc.trainers["testbrain"].training_progress = 1.0
if tc.trainers["testbrain"].get_step > 10:
raise KeyboardInterrupt
return 1

175
ml-agents/mlagents/trainers/trainer.py


# # Unity ML-Agents Toolkit
import logging
from typing import Dict, List, Deque, Any
import time
from mlagents.tf_utils import tf
from mlagents_envs.timers import set_gauge
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.trajectory import Trajectory

from mlagents_envs.timers import hierarchical_timer
LOGGER = logging.getLogger("mlagents.trainers")

:str run_id: The identifier of the current run
:int reward_buff_cap:
"""
self.param_keys: List[str] = []
self.is_training = training
self.cumulative_returns_since_policy_update: List[float] = []
self.is_training = training
self.step: int = 0
self.training_start_time = time.time()
self.summary_freq = self.trainer_parameters["summary_freq"]
self.next_update_step = self.summary_freq
def _check_param_keys(self):
for k in self.param_keys:
if k not in self.trainer_parameters:
raise UnityTrainerException(
"The hyper-parameter {0} could not be found for the {1} trainer of "
"brain {2}.".format(k, self.__class__, self.brain_name)
)
def write_tensorboard_text(self, key: str, input_dict: Dict[str, Any]) -> None:
"""
Saves text to Tensorboard.
Note: Only works on tensorflow r1.2 or above.
:param key: The name of the text.
:param input_dict: A dictionary that will be displayed in a table on Tensorboard.
"""
try:
with tf.Session() as sess:
s_op = tf.summary.text(
key,
tf.convert_to_tensor(
([[str(x), str(input_dict[x])] for x in input_dict])
),
)
s = sess.run(s_op)
self.stats_reporter.write_text(s, self.get_step)
except Exception:
LOGGER.info("Could not write text summary for Tensorboard.")
pass
def _dict_to_str(self, param_dict: Dict[str, Any], num_tabs: int) -> str:
"""

)
@property
def parameters(self) -> Dict[str, Any]:
"""
Returns the trainer parameters of the trainer.
"""
return self.trainer_parameters
@property
def get_max_steps(self) -> int:
"""
Returns the maximum number of steps. Is used to know when the trainer should be stopped.
:return: The maximum number of steps of the trainer
"""
return int(float(self.trainer_parameters["max_steps"]))
@property
def get_step(self) -> int:
"""
Returns the number of steps the trainer has performed
:return: the step count of the trainer
"""
return self.step
@property
def should_still_train(self) -> bool:
@abc.abstractmethod
def training_progress(self) -> float:
Returns whether or not the trainer should train. A Trainer could
stop training if it wasn't training to begin with, or if max_steps
Returns a float between 0 and 1 indicating how far along in the training progress the Trainer is.
If 1, the Trainer wasn't training to begin with, or max_steps
return self.is_training and self.get_step <= self.get_max_steps
pass
@property
def reward_buffer(self) -> Deque[float]:

"""
return self._reward_buffer
def _increment_step(self, n_steps: int, name_behavior_id: str) -> None:
"""
Increment the step count of the trainer
:param n_steps: number of steps to increment the step count by
"""
self.step += n_steps
self.next_update_step = self.step + (
self.summary_freq - self.step % self.summary_freq
)
p = self.get_policy(name_behavior_id)
if p:
p.increment_step(n_steps)
@abc.abstractmethod
self.get_policy(name_behavior_id).save_model(self.get_step)
pass
@abc.abstractmethod
self.get_policy(name_behavior_id).export_model()
def _write_summary(self, step: int) -> None:
"""
Saves training statistics to Tensorboard.
"""
is_training = "Training." if self.should_still_train else "Not Training."
stats_summary = self.stats_reporter.get_stats_summaries(
"Environment/Cumulative Reward"
)
if stats_summary.num > 0:
LOGGER.info(
" {}: {}: Step: {}. "
"Time Elapsed: {:0.3f} s "
"Mean "
"Reward: {:0.3f}"
". Std of Reward: {:0.3f}. {}".format(
self.run_id,
self.brain_name,
step,
time.time() - self.training_start_time,
stats_summary.mean,
stats_summary.std,
is_training,
)
)
set_gauge(f"{self.brain_name}.mean_reward", stats_summary.mean)
else:
LOGGER.info(
" {}: {}: Step: {}. No episode was completed since last summary. {}".format(
self.run_id, self.brain_name, step, is_training
)
)
self.stats_reporter.write_stats(int(step))
@abc.abstractmethod
def _process_trajectory(self, trajectory: Trajectory) -> None:
"""
Takes a trajectory and processes it, putting it into the update buffer.
:param trajectory: The Trajectory tuple containing the steps to be processed.
"""
self._maybe_write_summary(self.get_step + len(trajectory.steps))
self._increment_step(len(trajectory.steps), trajectory.behavior_id)
def _maybe_write_summary(self, step_after_process: int) -> None:
"""
If processing the trajectory will make the step exceed the next summary write,
write the summary. This logic ensures summaries are written on the update step and not in between.
:param step_after_process: the step count after processing the next trajectory.
"""
if step_after_process >= self.next_update_step and self.get_step != 0:
self._write_summary(self.next_update_step)
pass
@abc.abstractmethod
def end_episode(self):

pass
@abc.abstractmethod
def _is_ready_update(self):
"""
Returns whether or not the trainer has enough elements to run update model
:return: A boolean corresponding to wether or not update_model() can be run
def advance(self) -> None:
return False
@abc.abstractmethod
def _update_policy(self):
"""
Uses demonstration_buffer to update model.
Steps the trainer, taking in trajectories and updates if ready
def advance(self) -> None:
"""
Steps the trainer, taking in trajectories and updates if ready.
"""
with hierarchical_timer("process_trajectory"):
for traj_queue in self.trajectory_queues:
try:
t = traj_queue.get_nowait()
self._process_trajectory(t)
except AgentManagerQueue.Empty:
pass
if self.should_still_train:
if self._is_ready_update():
with hierarchical_timer("_update_policy"):
self._update_policy()
for q in self.policy_queues:
# Get policies that correspond to the policy queue in question
q.put(self.get_policy(q.behavior_id))
def publish_policy_queue(self, policy_queue: AgentManagerQueue[Policy]) -> None:
"""

10
ml-agents/mlagents/trainers/trainer_controller.py


if brain_name not in self.trainers:
continue
if curriculum.measure == "progress":
measure_val = self.trainers[brain_name].get_step / float(
self.trainers[brain_name].get_max_steps
)
measure_val = self.trainers[brain_name].training_progress
brain_names_to_measure_vals[brain_name] = measure_val
elif curriculum.measure == "reward":
measure_val = np.mean(self.trainers[brain_name].reward_buffer)

def _not_done_training(self) -> bool:
return (
any(t.should_still_train for t in self.trainers.values())
any(t.training_progress < 1.0 for t in self.trainers.values())
or not self.train_model
) or len(self.trainers) == 0

trainer = self.trainer_factory.generate(brain_name)
self.trainers[brain_name] = trainer
self.logger.info(trainer)
if self.train_model:
trainer.write_tensorboard_text("Hyperparameters", trainer.parameters)
policy = trainer.create_policy(env_manager.external_brains[name_behavior_id])
trainer.add_policy(name_behavior_id, policy)

policy,
name_behavior_id,
trainer.stats_reporter,
trainer.parameters.get("time_horizon", sys.maxsize),
trainer.trainer_parameters.get("time_horizon", sys.maxsize),
)
trainer.publish_policy_queue(agent_manager.policy_queue)
trainer.subscribe_trajectory_queue(agent_manager.trajectory_queue)

正在加载...
取消
保存