Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

83 行
2.1 KiB

from typing import List, Tuple
from mlagents_envs.base_env import SensorSpec, DimensionProperty
import pytest
import copy
import os
from mlagents.trainers.settings import (
TrainerSettings,
PPOSettings,
SACSettings,
GAILSettings,
CuriositySettings,
RewardSignalSettings,
NetworkSettings,
TrainerType,
RewardSignalType,
ScheduleType,
)
CONTINUOUS_DEMO_PATH = os.path.dirname(os.path.abspath(__file__)) + "/test.demo"
DISCRETE_DEMO_PATH = os.path.dirname(os.path.abspath(__file__)) + "/testdcvis.demo"
_PPO_CONFIG = TrainerSettings(
trainer_type=TrainerType.PPO,
hyperparameters=PPOSettings(
learning_rate=5.0e-3,
learning_rate_schedule=ScheduleType.CONSTANT,
batch_size=16,
buffer_size=64,
),
network_settings=NetworkSettings(num_layers=1, hidden_units=32),
summary_freq=500,
max_steps=3000,
threaded=False,
)
_SAC_CONFIG = TrainerSettings(
trainer_type=TrainerType.SAC,
hyperparameters=SACSettings(
learning_rate=5.0e-3,
learning_rate_schedule=ScheduleType.CONSTANT,
batch_size=8,
buffer_init_steps=100,
buffer_size=5000,
tau=0.01,
init_entcoef=0.01,
),
network_settings=NetworkSettings(num_layers=1, hidden_units=16),
summary_freq=100,
max_steps=1000,
threaded=False,
)
def ppo_dummy_config():
return copy.deepcopy(_PPO_CONFIG)
def sac_dummy_config():
return copy.deepcopy(_SAC_CONFIG)
@pytest.fixture
def gail_dummy_config():
return {RewardSignalType.GAIL: GAILSettings(demo_path=CONTINUOUS_DEMO_PATH)}
@pytest.fixture
def curiosity_dummy_config():
return {RewardSignalType.CURIOSITY: CuriositySettings()}
@pytest.fixture
def extrinsic_dummy_config():
return {RewardSignalType.EXTRINSIC: RewardSignalSettings()}
def create_sensor_specs_with_shapes(shapes: List[Tuple[int, ...]]) -> List[SensorSpec]:
sen_spec: List[SensorSpec] = []
for shape in shapes:
dim_prop = (DimensionProperty.UNSPECIFIED,) * len(shape)
spec = SensorSpec(shape, dim_prop)
sen_spec.append(spec)
return sen_spec