|
|
|
|
|
|
import pytest |
|
|
|
|
|
|
|
from mlagents.torch_utils import torch |
|
|
|
from mlagents.trainers.torch.networks import ( |
|
|
|
NetworkBody, |
|
|
|
ValueNetwork, |
|
|
|
SharedActorCritic, |
|
|
|
SeparateActorCritic, |
|
|
|
) |
|
|
|
from mlagents.trainers.torch.networks import NetworkBody, ValueNetwork, SimpleActor |
|
|
|
from mlagents.trainers.settings import NetworkSettings |
|
|
|
from mlagents_envs.base_env import ActionSpec |
|
|
|
from mlagents.trainers.tests.dummy_config import create_observation_specs_with_shapes |
|
|
|
|
|
|
assert _out[0] == pytest.approx(1.0, abs=0.1) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("ac_type", [SharedActorCritic, SeparateActorCritic]) |
|
|
|
def test_actor_critic(ac_type, lstm): |
|
|
|
def test_actor_critic(lstm): |
|
|
|
obs_size = 4 |
|
|
|
network_settings = NetworkSettings( |
|
|
|
memory=NetworkSettings.MemorySettings() if lstm else None, normalize=True |
|
|
|
|
|
|
stream_names = [f"stream_name{n}" for n in range(4)] |
|
|
|
# action_spec = ActionSpec.create_continuous(act_size[0]) |
|
|
|
action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size))) |
|
|
|
actor = ac_type(obs_spec, network_settings, action_spec, stream_names) |
|
|
|
actor = SimpleActor(obs_spec, network_settings, action_spec) |
|
|
|
critic = ValueNetwork(stream_names, obs_spec, network_settings) |
|
|
|
if lstm: |
|
|
|
sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size)) |
|
|
|
memories = torch.ones( |
|
|
|
|
|
|
# memories isn't always set to None, the network should be able to |
|
|
|
# deal with that. |
|
|
|
# Test critic pass |
|
|
|
value_out, memories_out = actor.critic_pass([sample_obs], memories=memories) |
|
|
|
value_out, memories_out = critic.critic_pass([sample_obs], memories=memories) |
|
|
|
for stream in stream_names: |
|
|
|
if lstm: |
|
|
|
assert value_out[stream].shape == (network_settings.memory.sequence_length,) |
|
|
|
|
|
|
|
|
|
|
# Test get action stats and_value |
|
|
|
action, log_probs, entropies, value_out, mem_out = actor.get_action_stats_and_value( |
|
|
|
action, log_probs, entropies, mem_out = actor.get_action_and_stats( |
|
|
|
[sample_obs], memories=memories, masks=mask |
|
|
|
) |
|
|
|
if lstm: |
|
|
|
|
|
|
|
|
|
|
if mem_out is not None: |
|
|
|
assert mem_out.shape == memories.shape |
|
|
|
for stream in stream_names: |
|
|
|
if lstm: |
|
|
|
assert value_out[stream].shape == (network_settings.memory.sequence_length,) |
|
|
|
else: |
|
|
|
assert value_out[stream].shape == (1,) |