浏览代码

test for SharedActorCritic

/develop/action-slice
Andrew Cohen 4 年前
当前提交
66742dc8
共有 1 个文件被更改,包括 15 次插入4 次删除
  1. 19
      ml-agents/mlagents/trainers/tests/torch/test_networks.py

19
ml-agents/mlagents/trainers/tests/torch/test_networks.py


import pytest
from mlagents.torch_utils import torch
from mlagents.trainers.torch.networks import NetworkBody, ValueNetwork, SimpleActor
from mlagents.trainers.torch.networks import (
NetworkBody,
ValueNetwork,
SimpleActor,
SharedActorCritic,
)
from mlagents.trainers.settings import NetworkSettings
from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes

assert _out[0] == pytest.approx(1.0, abs=0.1)
@pytest.mark.parametrize("shared", [True, False])
def test_actor_critic(lstm):
def test_actor_critic(lstm, shared):
obs_size = 4
network_settings = NetworkSettings(
memory=NetworkSettings.MemorySettings() if lstm else None, normalize=True

stream_names = [f"stream_name{n}" for n in range(4)]
# action_spec = ActionSpec.create_continuous(act_size[0])
action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size)))
actor = SimpleActor(obs_spec, network_settings, action_spec)
critic = ValueNetwork(stream_names, obs_spec, network_settings)
if shared:
actor = critic = SharedActorCritic(
obs_spec, network_settings, action_spec, stream_names, network_settings
)
else:
actor = SimpleActor(obs_spec, network_settings, action_spec)
critic = ValueNetwork(stream_names, obs_spec, network_settings)
if lstm:
sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size))
memories = torch.ones(

正在加载...
取消
保存