Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目，它使游戏和模拟能够作为训练智能代理的环境。
您最多可以选择 25 个主题。主题必须以中文、字母或数字开头，可以包含连字符 (-)，并且长度不得超过 35 个字符。
 
 
 
 
 

111 行
2.9 KiB

from typing import List, Tuple
from mlagents_envs.base_env import ObservationSpec, DimensionProperty, ObservationType
import pytest
import copy
import os
from mlagents.trainers.settings import (
POCASettings,
TrainerSettings,
PPOSettings,
SACSettings,
GAILSettings,
CuriositySettings,
RewardSignalSettings,
NetworkSettings,
TrainerType,
RewardSignalType,
ScheduleType,
)
# Demonstration files are looked up in the directory containing this module.
# The "/" separator is written literally, so the resulting strings are the
# absolute directory path followed by "/<file name>".
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
CONTINUOUS_DEMO_PATH = f"{_THIS_DIR}/test.demo"
DISCRETE_DEMO_PATH = f"{_THIS_DIR}/testdcvis.demo"
# Template PPO trainer settings. ppo_dummy_config() hands out deep copies of
# this object, so the template itself is never given to callers directly.
_PPO_CONFIG = TrainerSettings(
    trainer_type=TrainerType.PPO,
    # Run-level scalars.
    max_steps=3000,
    summary_freq=500,
    threaded=False,
    # One layer of 32 hidden units.
    network_settings=NetworkSettings(hidden_units=32, num_layers=1),
    # Constant learning-rate schedule with small batch and buffer sizes.
    hyperparameters=PPOSettings(
        batch_size=16,
        buffer_size=64,
        learning_rate=5.0e-3,
        learning_rate_schedule=ScheduleType.CONSTANT,
    ),
)
# Template SAC trainer settings. sac_dummy_config() hands out deep copies of
# this object, so the template itself is never given to callers directly.
_SAC_CONFIG = TrainerSettings(
    trainer_type=TrainerType.SAC,
    # Run-level scalars.
    max_steps=1000,
    summary_freq=100,
    threaded=False,
    # One layer of 16 hidden units.
    network_settings=NetworkSettings(hidden_units=16, num_layers=1),
    # Constant learning-rate schedule plus the SAC-specific buffer, tau and
    # initial entropy-coefficient values.
    hyperparameters=SACSettings(
        batch_size=8,
        buffer_size=5000,
        buffer_init_steps=100,
        learning_rate=5.0e-3,
        learning_rate_schedule=ScheduleType.CONSTANT,
        init_entcoef=0.01,
        tau=0.01,
    ),
)
# Template POCA trainer settings. poca_dummy_config() hands out deep copies of
# this object. Apart from the trainer type and the hyperparameter class, the
# values match those used for _PPO_CONFIG.
_POCA_CONFIG = TrainerSettings(
    trainer_type=TrainerType.POCA,
    # Run-level scalars.
    max_steps=3000,
    summary_freq=500,
    threaded=False,
    # One layer of 32 hidden units.
    network_settings=NetworkSettings(hidden_units=32, num_layers=1),
    # Constant learning-rate schedule with small batch and buffer sizes.
    hyperparameters=POCASettings(
        batch_size=16,
        buffer_size=64,
        learning_rate=5.0e-3,
        learning_rate_schedule=ScheduleType.CONSTANT,
    ),
)
def ppo_dummy_config():
    """Return an independent deep copy of the module-level ``_PPO_CONFIG``.

    Copying means a caller can mutate the returned settings without
    affecting the shared template.
    """
    fresh_settings = copy.deepcopy(_PPO_CONFIG)
    return fresh_settings
def sac_dummy_config():
    """Return an independent deep copy of the module-level ``_SAC_CONFIG``.

    Copying means a caller can mutate the returned settings without
    affecting the shared template.
    """
    fresh_settings = copy.deepcopy(_SAC_CONFIG)
    return fresh_settings
def poca_dummy_config():
    """Return an independent deep copy of the module-level ``_POCA_CONFIG``.

    Copying means a caller can mutate the returned settings without
    affecting the shared template.
    """
    fresh_settings = copy.deepcopy(_POCA_CONFIG)
    return fresh_settings
@pytest.fixture
def gail_dummy_config():
    """Pytest fixture: a one-entry reward-signal mapping for GAIL.

    The single key is ``RewardSignalType.GAIL`` and its value is a
    ``GAILSettings`` whose ``demo_path`` is ``CONTINUOUS_DEMO_PATH``.
    """
    gail_settings = GAILSettings(demo_path=CONTINUOUS_DEMO_PATH)
    return {RewardSignalType.GAIL: gail_settings}
@pytest.fixture
def curiosity_dummy_config():
    """Pytest fixture: a one-entry reward-signal mapping for curiosity.

    The single key is ``RewardSignalType.CURIOSITY`` and its value is a
    ``CuriositySettings`` built with no arguments.
    """
    curiosity_settings = CuriositySettings()
    return {RewardSignalType.CURIOSITY: curiosity_settings}
@pytest.fixture
def extrinsic_dummy_config():
    """Pytest fixture: a one-entry reward-signal mapping for extrinsic reward.

    The single key is ``RewardSignalType.EXTRINSIC`` and its value is a
    ``RewardSignalSettings`` built with no arguments.
    """
    extrinsic_settings = RewardSignalSettings()
    return {RewardSignalType.EXTRINSIC: extrinsic_settings}
def create_observation_specs_with_shapes(
    shapes: List[Tuple[int, ...]]
) -> List[ObservationSpec]:
    """Build one ``ObservationSpec`` per entry of *shapes*, in order.

    Each spec is named ``"observation {i} with shape {shape}"`` (``i`` is the
    position in *shapes*), carries the shape unchanged, and uses
    ``ObservationType.DEFAULT``.

    Dimension properties depend only on the number of dimensions:
    two-dimensional shapes get ``(VARIABLE_SIZE, NONE)``; every other shape
    gets ``UNSPECIFIED`` repeated once per dimension.

    :param shapes: Observation shapes, one tuple per observation.
    :return: The list of specs; empty when *shapes* is empty.
    """

    def _dimension_properties(shape):
        # Two-dimensional shapes are the only special case.
        rank = len(shape)
        if rank == 2:
            return (DimensionProperty.VARIABLE_SIZE, DimensionProperty.NONE)
        return (DimensionProperty.UNSPECIFIED,) * rank

    return [
        ObservationSpec(
            name=f"observation {index} with shape {shape}",
            shape=shape,
            dimension_property=_dimension_properties(shape),
            observation_type=ObservationType.DEFAULT,
        )
        for index, shape in enumerate(shapes)
    ]