浏览代码

fix test_networks

/develop/action-slice
Andrew Cohen 4 年前
当前提交
543f22bc
共有 1 个文件被更改,包括 6 次插入16 次删除
  1. 22
      ml-agents/mlagents/trainers/tests/torch/test_networks.py

22
ml-agents/mlagents/trainers/tests/torch/test_networks.py


import pytest
from mlagents.torch_utils import torch
from mlagents.trainers.torch.networks import (
NetworkBody,
ValueNetwork,
SharedActorCritic,
SeparateActorCritic,
)
from mlagents.trainers.torch.networks import NetworkBody, ValueNetwork, SimpleActor
from mlagents.trainers.settings import NetworkSettings
from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes

assert _out[0] == pytest.approx(1.0, abs=0.1)
@pytest.mark.parametrize("ac_type", [SharedActorCritic, SeparateActorCritic])
def test_actor_critic(ac_type, lstm):
def test_actor_critic(lstm):
obs_size = 4
network_settings = NetworkSettings(
memory=NetworkSettings.MemorySettings() if lstm else None, normalize=True

stream_names = [f"stream_name{n}" for n in range(4)]
# action_spec = ActionSpec.create_continuous(act_size[0])
action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size)))
actor = ac_type(obs_spec, network_settings, action_spec, stream_names)
actor = SimpleActor(obs_spec, network_settings, action_spec)
critic = ValueNetwork(stream_names, obs_spec, network_settings)
if lstm:
sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size))
memories = torch.ones(

# memories isn't always set to None, the network should be able to
# deal with that.
# Test critic pass
value_out, memories_out = actor.critic_pass([sample_obs], memories=memories)
value_out, memories_out = critic.critic_pass([sample_obs], memories=memories)
for stream in stream_names:
if lstm:
assert value_out[stream].shape == (network_settings.memory.sequence_length,)

# Test get action stats and_value
action, log_probs, entropies, value_out, mem_out = actor.get_action_stats_and_value(
action, log_probs, entropies, mem_out = actor.get_action_and_stats(
[sample_obs], memories=memories, masks=mask
)
if lstm:

if mem_out is not None:
assert mem_out.shape == memories.shape
for stream in stream_names:
if lstm:
assert value_out[stream].shape == (network_settings.memory.sequence_length,)
else:
assert value_out[stream].shape == (1,)
正在加载...
取消
保存