|
|
|
|
|
|
import pytest |
|
|
|
|
|
|
|
from mlagents.torch_utils import torch |
|
|
|
from mlagents.trainers.torch.networks import NetworkBody, ValueNetwork, SimpleActor |
|
|
|
from mlagents.trainers.torch.networks import ( |
|
|
|
NetworkBody, |
|
|
|
ValueNetwork, |
|
|
|
SimpleActor, |
|
|
|
SharedActorCritic, |
|
|
|
) |
|
|
|
from mlagents.trainers.settings import NetworkSettings |
|
|
|
from mlagents_envs.base_env import ActionSpec |
|
|
|
from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes |
|
|
|
|
|
|
assert _out[0] == pytest.approx(1.0, abs=0.1) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("shared", [True, False]) |
|
|
|
def test_actor_critic(lstm): |
|
|
|
def test_actor_critic(lstm, shared): |
|
|
|
obs_size = 4 |
|
|
|
network_settings = NetworkSettings( |
|
|
|
memory=NetworkSettings.MemorySettings() if lstm else None, normalize=True |
|
|
|
|
|
|
stream_names = [f"stream_name{n}" for n in range(4)] |
|
|
|
# action_spec = ActionSpec.create_continuous(act_size[0]) |
|
|
|
action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size))) |
|
|
|
actor = SimpleActor(obs_spec, network_settings, action_spec) |
|
|
|
critic = ValueNetwork(stream_names, obs_spec, network_settings) |
|
|
|
if shared: |
|
|
|
actor = critic = SharedActorCritic( |
|
|
|
obs_spec, network_settings, action_spec, stream_names, network_settings |
|
|
|
) |
|
|
|
else: |
|
|
|
actor = SimpleActor(obs_spec, network_settings, action_spec) |
|
|
|
critic = ValueNetwork(stream_names, obs_spec, network_settings) |
|
|
|
if lstm: |
|
|
|
sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size)) |
|
|
|
memories = torch.ones( |
|
|
|