您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
99 行
4.1 KiB
99 行
4.1 KiB
import sys
|
|
from typing import Optional
|
|
import uuid
|
|
import mlagents_envs
|
|
import mlagents.trainers
|
|
from mlagents import torch_utils
|
|
from mlagents.trainers.settings import RewardSignalType
|
|
from mlagents_envs.exception import UnityCommunicationException
|
|
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage
|
|
from mlagents_envs.communicator_objects.training_analytics_pb2 import (
|
|
TrainingEnvironmentInitialized,
|
|
TrainingBehaviorInitialized,
|
|
)
|
|
from google.protobuf.any_pb2 import Any
|
|
|
|
from mlagents.trainers.settings import TrainerSettings, RunOptions
|
|
|
|
|
|
class TrainingAnalyticsSideChannel(SideChannel):
    """
    Outgoing-only side channel that reports details about the training run to
    the Unity environment so that they can be logged there.
    """

    def __init__(self) -> None:
        # The channel id is a fixed value that was derived with:
        # >>> uuid.uuid5(uuid.NAMESPACE_URL, "com.unity.ml-agents/TrainingAnalyticsSideChannel")
        # UUID('b664a4a9-d86f-5a5f-95cb-e8353a7e8356')
        channel_id = uuid.UUID("b664a4a9-d86f-5a5f-95cb-e8353a7e8356")
        super().__init__(channel_id)
        # Set by environment_initialized(); consulted by _behavior_uses_curriculum().
        self.run_options: Optional[RunOptions] = None

    def on_message_received(self, msg: IncomingMessage) -> None:
        """
        Reject any message arriving on this channel.

        :param msg: The incoming message (never inspected).
        :raises UnityCommunicationException: Always.
        """
        raise UnityCommunicationException(
            "The TrainingAnalyticsSideChannel received a message from Unity, "
            "this should not have happened."
        )

    def environment_initialized(self, run_options: RunOptions) -> None:
        """
        Remember the run options and queue a TrainingEnvironmentInitialized
        message describing the Python-side setup.

        :param run_options: Options of the current run; stored on the instance
            so later calls can look up curriculum information.
        """
        self.run_options = run_options

        # sys.version_info starts with (major, minor, patch).
        major, minor, patch = sys.version_info[:3]

        env_params = run_options.environment_parameters
        if env_params:
            env_param_count = len(env_params)
        else:
            # Covers both "no parameters configured" and an empty mapping.
            env_param_count = 0

        self._queue_packed(
            TrainingEnvironmentInitialized(
                python_version=f"{major}.{minor}.{patch}",
                mlagents_version=mlagents.trainers.__version__,
                mlagents_envs_version=mlagents_envs.__version__,
                torch_version=torch_utils.torch.__version__,
                torch_device_type=torch_utils.default_device().type,
                num_envs=run_options.env_settings.num_envs,
                num_environment_parameters=env_param_count,
            )
        )

    def training_started(self, behavior_name: str, config: TrainerSettings) -> None:
        """
        Queue a TrainingBehaviorInitialized message summarising the trainer
        configuration used for one behavior.

        :param behavior_name: Name of the behavior that is being trained.
        :param config: Trainer settings applied to that behavior.
        """
        reward_signals = config.reward_signals
        network = config.network_settings

        self._queue_packed(
            TrainingBehaviorInitialized(
                behavior_name=behavior_name,
                trainer_type=config.trainer_type.value,
                extrinsic_reward_enabled=RewardSignalType.EXTRINSIC in reward_signals,
                gail_reward_enabled=RewardSignalType.GAIL in reward_signals,
                curiosity_reward_enabled=RewardSignalType.CURIOSITY in reward_signals,
                rnd_reward_enabled=RewardSignalType.RND in reward_signals,
                behavioral_cloning_enabled=config.behavioral_cloning is not None,
                recurrent_enabled=network.memory is not None,
                visual_encoder=network.vis_encode_type.value,
                num_network_layers=network.num_layers,
                num_network_hidden_units=network.hidden_units,
                trainer_threaded=config.threaded,
                self_play_enabled=config.self_play is not None,
                curriculum_enabled=self._behavior_uses_curriculum(behavior_name),
            )
        )

    def _queue_packed(self, proto_msg) -> None:
        """
        Wrap a protobuf message in an ``Any``, serialize it and queue the
        resulting bytes for sending.

        :param proto_msg: The protobuf message to send.
        """
        wrapper = Any()
        wrapper.Pack(proto_msg)

        outgoing = OutgoingMessage()
        outgoing.set_raw_bytes(wrapper.SerializeToString())
        super().queue_message_to_send(outgoing)

    def _behavior_uses_curriculum(self, behavior_name: str) -> bool:
        """
        Tell whether any lesson of any environment parameter has completion
        criteria that refer to the given behavior.

        :param behavior_name: Name of the behavior to look for.
        :return: True if a matching completion criteria exists; False
            otherwise, including when no run options or environment
            parameters are available.
        """
        options = self.run_options
        if not options:
            return False

        env_params = options.environment_parameters
        if not env_params:
            return False

        # Lazily walk every lesson's completion criteria across all parameters.
        all_criteria = (
            lesson.completion_criteria
            for settings in env_params.values()
            for lesson in settings.curriculum
        )
        return any(
            criteria and criteria.behavior == behavior_name
            for criteria in all_criteria
        )
|