Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

318 行
13 KiB

# # Unity ML-Agents Toolkit
from typing import Dict, List, Optional
from collections import defaultdict
import abc
import time
import attr
import numpy as np
from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod
from mlagents.trainers.policy.checkpoint_manager import (
ModelCheckpoint,
ModelCheckpointManager,
)
from mlagents_envs.logging_util import get_logger
from mlagents_envs.timers import timed
from mlagents.trainers.optimizer import Optimizer
from mlagents.trainers.buffer import AgentBuffer, BufferKey
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.torch.components.reward_providers.base_reward_provider import (
BaseRewardProvider,
)
from mlagents_envs.timers import hierarchical_timer
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.policy.policy import Policy
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.agent_processor import AgentManagerQueue
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.stats import StatsPropertyType
from mlagents.trainers.model_saver.model_saver import BaseModelSaver
# Module-level logger, obtained from the ML-Agents logging utility and named
# after this module; used below for the checkpoint/save/group-reward warnings.
logger = get_logger(__name__)
class RLTrainer(Trainer):
    """
    This class is the base class for trainers that use Reward Signals.

    On top of ``Trainer`` it keeps per-agent cumulative-reward bookkeeping, an
    update buffer, the scheduling of periodic summaries/checkpoints, and the
    ``advance`` loop that drains the trajectory queues and triggers policy
    updates. Subclasses must implement ``_is_ready_update``,
    ``create_torch_policy``, ``_update_policy`` and ``_process_trajectory``.
    """

    def __init__(self, *args, **kwargs):
        """
        Set up reward tracking, the update buffer and the model saver.
        All arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        # Episode returns recorded since they were last cleared; averaged by
        # _policy_mean_reward() when a checkpoint is written.
        self.cumulative_returns_since_policy_update: List[float] = []
        # collected_rewards is a dictionary from name of reward signal to a
        # dictionary of agent_id to cumulative reward, used for reporting only.
        # We always want to report the environment reward to Tensorboard,
        # regardless of what reward signals are actually present.
        self.collected_rewards: Dict[str, Dict[str, float]] = {
            "environment": defaultdict(lambda: 0)
        }
        self.update_buffer: AgentBuffer = AgentBuffer()
        self._stats_reporter.add_property(
            StatsPropertyType.HYPERPARAMETERS, self.trainer_settings.as_dict()
        )
        # A value of 0 means "not scheduled yet"; the real targets are computed
        # lazily by _maybe_save_model() / _maybe_write_summary().
        self._next_save_step = 0
        self._next_summary_step = 0
        self.model_saver = self.create_model_saver(
            self.trainer_settings, self.artifact_path, self.load
        )
        # Ensures the group-reward warning is logged at most once.
        self._has_warned_group_rewards = False

    def end_episode(self) -> None:
        """
        A signal that the Episode has ended. The buffer must be reset.
        Get only called when the academy resets.
        """
        # Zero every tracked agent's cumulative reward for every reward signal.
        for rewards in self.collected_rewards.values():
            for agent_id in rewards:
                rewards[agent_id] = 0

    def _update_end_episode_stats(self, agent_id: str, optimizer: Optimizer) -> None:
        """
        Report the cumulative rewards of an agent whose episode just ended and
        reset its counters.
        :param agent_id: Identifier of the agent whose episode ended.
        :param optimizer: Optimizer whose ``reward_signals`` provide the stat
            names for the non-environment rewards.
        """
        for name, rewards in self.collected_rewards.items():
            # Look the value up once; agents that were never seen report 0.
            episode_reward = rewards.get(agent_id, 0)
            if name == "environment":
                self.stats_reporter.add_stat(
                    "Environment/Cumulative Reward",
                    episode_reward,
                    aggregation=StatsAggregationMethod.HISTOGRAM,
                )
                self.cumulative_returns_since_policy_update.append(episode_reward)
                self.reward_buffer.appendleft(episode_reward)
            else:
                reward_signal = optimizer.reward_signals[name]
                if isinstance(reward_signal, BaseRewardProvider):
                    self.stats_reporter.add_stat(
                        f"Policy/{reward_signal.name.capitalize()} Reward",
                        episode_reward,
                    )
                else:
                    self.stats_reporter.add_stat(
                        reward_signal.stat_name, episode_reward
                    )
            # Start the next episode of this agent from a zero cumulative reward.
            rewards[agent_id] = 0

    def _clear_update_buffer(self) -> None:
        """
        Clear the buffers that have been built up during inference.
        """
        self.update_buffer.reset_agent()

    @abc.abstractmethod
    def _is_ready_update(self) -> bool:
        """
        Returns whether or not the trainer has enough elements to run update model
        :return: A boolean corresponding to whether or not update_model() can be run
        """
        return False

    def create_policy(
        self,
        parsed_behavior_id: BehaviorIdentifiers,
        behavior_spec: BehaviorSpec,
        create_graph: bool = False,
    ) -> Policy:
        """
        Create the policy for a behavior by delegating to create_torch_policy().
        :param parsed_behavior_id: Parsed identifiers of the behavior.
        :param behavior_spec: Specification of the behavior.
        :param create_graph: Accepted but not used by this implementation.
        :return: The policy returned by create_torch_policy().
        """
        return self.create_torch_policy(parsed_behavior_id, behavior_spec)

    @abc.abstractmethod
    def create_torch_policy(
        self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
    ) -> TorchPolicy:
        """
        Create a Policy object that uses the PyTorch backend.
        """
        pass

    @staticmethod
    def create_model_saver(
        trainer_settings: TrainerSettings, model_path: str, load: bool
    ) -> BaseModelSaver:
        """
        Build the model saver used for checkpointing and final export.
        :param trainer_settings: Settings handed to the saver.
        :param model_path: Path handed to the saver.
        :param load: Load flag handed to the saver.
        :return: A TorchModelSaver constructed from the three arguments.
        """
        return TorchModelSaver(  # type: ignore
            trainer_settings, model_path, load
        )

    def _policy_mean_reward(self) -> Optional[float]:
        """ Returns the mean episode reward for the current policy. """
        rewards = self.cumulative_returns_since_policy_update
        if not rewards:
            # No finished episode has been recorded, so there is no mean.
            return None
        return sum(rewards) / len(rewards)

    @timed
    def _checkpoint(self) -> ModelCheckpoint:
        """
        Checkpoints the policy associated with this trainer.
        :return: The ModelCheckpoint describing the files that were written.
        """
        if len(self.policies) > 1:
            logger.warning(
                "Trainer has multiple policies, but default behavior only saves the first."
            )
        export_path, auxiliary_paths = self.model_saver.save_checkpoint(
            self.brain_name, self._step
        )
        new_checkpoint = ModelCheckpoint(
            int(self._step),
            export_path,
            self._policy_mean_reward(),
            time.time(),
            auxillary_file_paths=auxiliary_paths,
        )
        # Register the checkpoint; keep_checkpoints is passed on to the manager.
        ModelCheckpointManager.add_checkpoint(
            self.brain_name, new_checkpoint, self.trainer_settings.keep_checkpoints
        )
        return new_checkpoint

    def save_model(self) -> None:
        """
        Saves the policy associated with this trainer.
        Writes a checkpoint, copies it as the final model and records it as the
        final checkpoint. Does nothing but warn if the trainer has no policies.
        """
        n_policies = len(self.policies)
        if n_policies == 0:
            logger.warning("Trainer has no policies, not saving anything.")
            return
        if n_policies > 1:
            # NOTE: _checkpoint() below logs this same warning a second time.
            logger.warning(
                "Trainer has multiple policies, but default behavior only saves the first."
            )

        model_checkpoint = self._checkpoint()
        self.model_saver.copy_final_model(model_checkpoint.file_path)
        export_ext = "onnx"
        # Same checkpoint record, but pointing at "<model_path>.onnx".
        final_checkpoint = attr.evolve(
            model_checkpoint, file_path=f"{self.model_saver.model_path}.{export_ext}"
        )
        ModelCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)

    @abc.abstractmethod
    def _update_policy(self) -> bool:
        """
        Run a policy update. Called by advance() once _is_ready_update()
        returns True while the trainer should still train.
        :return: Whether or not the policy was updated.
        """
        pass

    def _increment_step(self, n_steps: int, name_behavior_id: str) -> None:
        """
        Increment the step count of the trainer
        :param n_steps: number of steps to increment the step count by
        :param name_behavior_id: behavior id of the policy whose own step count
            is incremented as well, if such a policy exists
        """
        self._step += n_steps
        # Re-schedule the next summary and checkpoint relative to the new step.
        self._next_summary_step = self._get_next_interval_step(self.summary_freq)
        self._next_save_step = self._get_next_interval_step(
            self.trainer_settings.checkpoint_interval
        )
        policy = self.get_policy(name_behavior_id)
        if policy:
            policy.increment_step(n_steps)
        self.stats_reporter.set_stat("Step", float(self.get_step))

    def _get_next_interval_step(self, interval: int) -> int:
        """
        Get the next step count that should result in an action.
        For a positive interval this is the smallest multiple of the interval
        that is strictly greater than the current step.
        :param interval: The interval between actions. Must be non-zero, the
            modulo below raises ZeroDivisionError otherwise.
        :return: The step count at which the next action is due.
        """
        return self._step + (interval - self._step % interval)

    def _write_summary(self, step: int) -> None:
        """
        Saves training statistics to Tensorboard.
        :param step: The step count under which the statistics are written.
        """
        self.stats_reporter.add_stat("Is Training", float(self.should_still_train))
        self.stats_reporter.write_stats(int(step))

    @abc.abstractmethod
    def _process_trajectory(self, trajectory: Trajectory) -> None:
        """
        Takes a trajectory and processes it, putting it into the update buffer.
        The base implementation only does the bookkeeping: it writes a summary
        and/or a checkpoint if this trajectory crosses the respective threshold
        and then advances the step count. Overrides are presumably expected to
        call it through super() - confirm against the subclasses.
        :param trajectory: The Trajectory tuple containing the steps to be processed.
        """
        n_new_steps = len(trajectory.steps)
        # Both checks use the step count as it will be after this trajectory.
        self._maybe_write_summary(self.get_step + n_new_steps)
        self._maybe_save_model(self.get_step + n_new_steps)
        self._increment_step(n_new_steps, trajectory.behavior_id)

    def _maybe_write_summary(self, step_after_process: int) -> None:
        """
        If processing the trajectory will make the step exceed the next summary write,
        write the summary. This logic ensures summaries are written on the update step and not in between.
        :param step_after_process: the step count after processing the next trajectory.
        """
        if self._next_summary_step == 0:  # Don't write out the first one
            self._next_summary_step = self._get_next_interval_step(self.summary_freq)
        if step_after_process >= self._next_summary_step and self.get_step != 0:
            self._write_summary(self._next_summary_step)

    def _append_to_update_buffer(self, agentbuffer_trajectory: AgentBuffer) -> None:
        """
        Append an AgentBuffer to the update buffer. If the trainer isn't training,
        don't update to avoid a memory leak.
        :param agentbuffer_trajectory: The buffer to re-sequence and append.
        """
        if not self.should_still_train:
            return
        memory_settings = self.trainer_settings.network_settings.memory
        # Without memory settings, sequences have length 1.
        seq_len = memory_settings.sequence_length if memory_settings is not None else 1
        agentbuffer_trajectory.resequence_and_append(
            self.update_buffer, training_length=seq_len
        )

    def _maybe_save_model(self, step_after_process: int) -> None:
        """
        If processing the trajectory will make the step exceed the next model write,
        save the model. This logic ensures models are written on the update step and not in between.
        :param step_after_process: the step count after processing the next trajectory.
        """
        if self._next_save_step == 0:  # Don't save the first one
            self._next_save_step = self._get_next_interval_step(
                self.trainer_settings.checkpoint_interval
            )
        if step_after_process >= self._next_save_step and self.get_step != 0:
            self._checkpoint()

    def _warn_if_group_reward(self, buffer: AgentBuffer) -> None:
        """
        Warn if the trainer receives a Group Reward but isn't a multiagent trainer (e.g. POCA).
        The warning is logged at most once per trainer instance.
        :param buffer: The buffer whose GROUP_REWARD entries are inspected.
        """
        if self._has_warned_group_rewards:
            return
        if np.any(buffer[BufferKey.GROUP_REWARD]):
            logger.warning(
                "An agent recieved a Group Reward, but you are not using a multi-agent trainer. "
                "Please use the POCA trainer for best results."
            )
            self._has_warned_group_rewards = True

    def advance(self) -> None:
        """
        Steps the trainer, taking in trajectories and updates if ready.
        When the trainer is threaded, sleeps briefly for every trajectory queue
        that was empty instead of blocking.
        """
        with hierarchical_timer("process_trajectory"):
            for traj_queue in self.trajectory_queues:
                # We grab at most the maximum length of the queue.
                # This ensures that even if the queue is being filled faster than it is
                # being emptied, the trajectories in the queue are on-policy.
                queried = False
                for _ in range(traj_queue.qsize()):
                    queried = True
                    # Only the non-blocking get is guarded, so an Empty raised
                    # while processing a trajectory is not mistaken for an
                    # exhausted queue.
                    try:
                        trajectory = traj_queue.get_nowait()
                    except AgentManagerQueue.Empty:
                        break
                    self._process_trajectory(trajectory)
                if self.threaded and not queried:
                    # Yield thread to avoid busy-waiting
                    time.sleep(0.0001)
        if self.should_still_train and self._is_ready_update():
            with hierarchical_timer("_update_policy"):
                if self._update_policy():
                    for q in self.policy_queues:
                        # Get policies that correspond to the policy queue in question
                        q.put(self.get_policy(q.behavior_id))