Unity 机器学习代理工具包（ML-Agents）是一个开源项目，它使游戏和模拟场景能够用作训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头，可以包含连字符（-），并且长度不得超过 35 个字符。
 
 
 
 
 

37 行
1.4 KiB

from typing import Dict, Type
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.components.reward_signals import RewardSignal
from mlagents.trainers.components.reward_signals.extrinsic.signal import (
ExtrinsicRewardSignal,
)
from mlagents.trainers.components.reward_signals.gail.signal import GAILRewardSignal
from mlagents.trainers.components.reward_signals.curiosity.signal import (
CuriosityRewardSignal,
)
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.settings import RewardSignalSettings, RewardSignalType
# Registry of supported reward signals: each RewardSignalType member maps to
# the RewardSignal subclass that implements it. create_reward_signal looks the
# requested type up here via NAME_TO_CLASS.get(...).
NAME_TO_CLASS: Dict[RewardSignalType, Type[RewardSignal]] = dict(
    (
        (RewardSignalType.EXTRINSIC, ExtrinsicRewardSignal),
        (RewardSignalType.CURIOSITY, CuriosityRewardSignal),
        (RewardSignalType.GAIL, GAILRewardSignal),
    )
)
def create_reward_signal(
    policy: TFPolicy, name: RewardSignalType, settings: RewardSignalSettings
) -> RewardSignal:
    """
    Creates a reward signal instance for the given reward signal type and settings.

    The implementing class is looked up in NAME_TO_CLASS by ``name`` and is
    constructed as ``cls(policy, settings)``.

    :param policy: The policy which the reward signal will be applied to.
    :param name: The type of the reward signal to create.
    :param settings: The settings for that reward signal.
    :return: The instantiated reward signal.
    :raises UnityTrainerException: If ``name`` has no registered reward signal class.
    """
    rcls = NAME_TO_CLASS.get(name)
    # dict.get returns None for unregistered types; check that sentinel explicitly.
    if rcls is None:
        raise UnityTrainerException(f"Unknown reward signal type {name}")
    return rcls(policy, settings)