Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可选择25个主题。主题必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符。
 
 
 
 
 

99 行
4.1 KiB

import sys
from typing import Optional
import uuid
import mlagents_envs
import mlagents.trainers
from mlagents import torch_utils
from mlagents.trainers.settings import RewardSignalType
from mlagents_envs.exception import UnityCommunicationException
from mlagents_envs.side_channel import SideChannel, IncomingMessage, OutgoingMessage
from mlagents_envs.communicator_objects.training_analytics_pb2 import (
TrainingEnvironmentInitialized,
TrainingBehaviorInitialized,
)
from google.protobuf.any_pb2 import Any
from mlagents.trainers.settings import TrainerSettings, RunOptions
class TrainingAnalyticsSideChannel(SideChannel):
    """
    Side channel that reports information about the training run to the Unity
    environment so that it can be logged there.

    Every report is a protobuf message wrapped in a ``google.protobuf.Any``
    and queued as raw bytes. The channel is send-only from the Python side:
    receiving a message on it raises ``UnityCommunicationException``.
    """

    def __init__(self) -> None:
        # The channel id is a fixed UUID, obtained with:
        # >>> uuid.uuid5(uuid.NAMESPACE_URL, "com.unity.ml-agents/TrainingAnalyticsSideChannel")
        # UUID('b664a4a9-d86f-5a5f-95cb-e8353a7e8356')
        super().__init__(uuid.UUID("b664a4a9-d86f-5a5f-95cb-e8353a7e8356"))
        # Filled in by environment_initialized(); read by _behavior_uses_curriculum().
        self.run_options: Optional[RunOptions] = None

    def on_message_received(self, msg: IncomingMessage) -> None:
        """
        Reject any message arriving on this channel.

        :param msg: The incoming message (never inspected).
        :raises UnityCommunicationException: Always.
        """
        raise UnityCommunicationException(
            "The TrainingAnalyticsSideChannel received a message from Unity, "
            "this should not have happened."
        )

    def environment_initialized(self, run_options: RunOptions) -> None:
        """
        Remember the run options and queue a TrainingEnvironmentInitialized
        message describing the Python-side setup (interpreter, package and
        torch versions, torch device type, environment counts).

        :param run_options: The options of the current training run.
        """
        self.run_options = run_options

        # Only the first three fields (major, minor, patch) are reported.
        vi = sys.version_info

        # A missing or empty parameter collection is reported as zero.
        env_params = run_options.environment_parameters
        if env_params:
            param_count = len(env_params)
        else:
            param_count = 0

        self._queue_proto(
            TrainingEnvironmentInitialized(
                python_version=f"{vi[0]}.{vi[1]}.{vi[2]}",
                mlagents_version=mlagents.trainers.__version__,
                mlagents_envs_version=mlagents_envs.__version__,
                torch_version=torch_utils.torch.__version__,
                torch_device_type=torch_utils.default_device().type,
                num_envs=run_options.env_settings.num_envs,
                num_environment_parameters=param_count,
            )
        )

    def training_started(self, behavior_name: str, config: TrainerSettings) -> None:
        """
        Queue a TrainingBehaviorInitialized message summarising the trainer
        configuration used for one behavior.

        :param behavior_name: Name of the behavior being trained.
        :param config: Trainer settings applied to that behavior.
        """
        signals = config.reward_signals
        network = config.network_settings

        self._queue_proto(
            TrainingBehaviorInitialized(
                behavior_name=behavior_name,
                trainer_type=config.trainer_type.value,
                extrinsic_reward_enabled=RewardSignalType.EXTRINSIC in signals,
                gail_reward_enabled=RewardSignalType.GAIL in signals,
                curiosity_reward_enabled=RewardSignalType.CURIOSITY in signals,
                rnd_reward_enabled=RewardSignalType.RND in signals,
                behavioral_cloning_enabled=config.behavioral_cloning is not None,
                recurrent_enabled=network.memory is not None,
                visual_encoder=network.vis_encode_type.value,
                num_network_layers=network.num_layers,
                num_network_hidden_units=network.hidden_units,
                trainer_threaded=config.threaded,
                self_play_enabled=config.self_play is not None,
                curriculum_enabled=self._behavior_uses_curriculum(behavior_name),
            )
        )

    def _queue_proto(self, proto_msg) -> None:
        """
        Wrap a protobuf message in an ``Any`` envelope and queue its
        serialized bytes for sending.

        :param proto_msg: The protobuf message to send.
        """
        envelope = Any()
        envelope.Pack(proto_msg)
        outgoing = OutgoingMessage()
        outgoing.set_raw_bytes(envelope.SerializeToString())
        super().queue_message_to_send(outgoing)

    def _behavior_uses_curriculum(self, behavior_name: str) -> bool:
        """
        Tell whether any curriculum lesson has completion criteria that name
        the given behavior.

        :param behavior_name: Name of the behavior to look for.
        :return: True if at least one lesson's completion criteria reference
            the behavior; False otherwise, including when no run options or
            no environment parameters are available.
        """
        options = self.run_options
        if not options:
            return False
        env_params = options.environment_parameters
        if not env_params:
            return False

        all_criteria = (
            lesson.completion_criteria
            for param_settings in env_params.values()
            for lesson in param_settings.curriculum
        )
        # Lessons without completion criteria are skipped by the truthiness test.
        return any(cc and cc.behavior == behavior_name for cc in all_criteria)