您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
212 行
7.2 KiB
212 行
7.2 KiB
import pytest
|
|
|
|
import torch
|
|
from mlagents.trainers.torch.networks import (
|
|
NetworkBody,
|
|
ValueNetwork,
|
|
SimpleActor,
|
|
SharedActorCritic,
|
|
SeparateActorCritic,
|
|
)
|
|
from mlagents.trainers.settings import NetworkSettings
|
|
from mlagents_envs.base_env import ActionType
|
|
from mlagents.trainers.torch.distributions import (
|
|
GaussianDistInstance,
|
|
CategoricalDistInstance,
|
|
)
|
|
|
|
|
|
def test_networkbody_vector():
    """Fit a vector-observation ``NetworkBody`` towards an all-ones encoding.

    One constant observation/action pair is pushed through the body for 300
    Adam updates minimising the MSE against a tensor of ones. The encoding
    produced by the final forward pass must be within 0.1 of 1.0 everywhere.
    """
    torch.manual_seed(0)
    n_obs = 4
    settings = NetworkSettings()
    shapes = [(n_obs,)]

    body = NetworkBody(shapes, settings, encoded_act_size=2)
    adam = torch.optim.Adam(body.parameters(), lr=3e-3)
    obs_batch = 0.1 * torch.ones((1, n_obs))
    act_batch = 0.1 * torch.ones((1, 2))

    for _step in range(300):
        encoded, _ = body([obs_batch], [], act_batch)
        assert encoded.shape == (1, settings.hidden_units)
        # Regress every element of the encoding towards 1.
        target = torch.ones(encoded.shape)
        loss = torch.nn.functional.mse_loss(encoded, target)
        adam.zero_grad()
        loss.backward()
        adam.step()

    # `encoded` is the output of the last forward pass in the loop.
    for element in encoded.flatten():
        assert element == pytest.approx(1.0, abs=0.1)
|
|
|
|
|
|
def test_networkbody_lstm():
    """Fit a memory-enabled ``NetworkBody`` towards an all-ones encoding.

    The body is configured with a sequence length of 16 and a memory size of
    12, then trained for 200 Adam updates on a constant observation sequence.
    The encoding from the final forward pass must be within 0.1 of 1.0.
    """
    torch.manual_seed(0)
    n_obs = 4
    seq_len = 16
    mem_size = 12
    memory_cfg = NetworkSettings.MemorySettings(
        sequence_length=seq_len, memory_size=mem_size
    )
    settings = NetworkSettings(memory=memory_cfg)
    shapes = [(n_obs,)]

    body = NetworkBody(shapes, settings)
    adam = torch.optim.Adam(body.parameters(), lr=3e-4)
    obs_seq = torch.ones((1, seq_len, n_obs))

    for _step in range(200):
        # A fresh all-ones memory tensor is supplied on every pass.
        initial_memories = torch.ones(1, seq_len, mem_size)
        encoded, _ = body([obs_seq], [], memories=initial_memories)
        # Regress every element of the encoding towards 1.
        target = torch.ones(encoded.shape)
        loss = torch.nn.functional.mse_loss(encoded, target)
        adam.zero_grad()
        loss.backward()
        adam.step()

    # `encoded` is the output of the last forward pass in the loop.
    for element in encoded.flatten():
        assert element == pytest.approx(1.0, abs=0.1)
|
|
|
|
|
|
def test_networkbody_visual():
    """Fit a ``NetworkBody`` with vector and visual inputs towards all ones.

    The body receives one vector observation of size 4 and one (84, 84, 3)
    visual observation. It is trained for 150 Adam updates against a tensor
    of ones; the encoding from the final forward pass must be within 0.1 of
    1.0 everywhere.
    """
    vec_obs_size = 4
    obs_size = (84, 84, 3)
    network_settings = NetworkSettings()
    obs_shapes = [(vec_obs_size,), obs_size]

    # Seed once, immediately before the network is built. The original seeded
    # twice with the same value (torch.manual_seed and its alias
    # torch.random.manual_seed); only the call closest to construction
    # determines the RNG state the network sees.
    torch.manual_seed(0)
    networkbody = NetworkBody(obs_shapes, network_settings)
    optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
    sample_obs = torch.ones((1, 84, 84, 3))
    sample_vec_obs = torch.ones((1, vec_obs_size))

    for _ in range(150):
        encoded, _ = networkbody([sample_vec_obs], [sample_obs])
        assert encoded.shape == (1, network_settings.hidden_units)
        # Try to force output to 1
        loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # In the last step, values should be close to 1
    for _enc in encoded.flatten():
        assert _enc == pytest.approx(1.0, abs=0.1)
|
|
|
|
|
|
def test_valuenetwork():
    """Fit a four-stream ``ValueNetwork`` so each stream predicts ones.

    Every stream emits ``num_outputs`` values per sample. The summed
    per-stream MSE against ones is minimised for 50 Adam updates, after which
    the first output of each stream must be within 0.1 of 1.0.
    """
    torch.manual_seed(0)
    n_obs = 4
    num_outputs = 2
    settings = NetworkSettings()
    shapes = [(n_obs,)]

    stream_names = [f"stream_name{n}" for n in range(4)]
    critic = ValueNetwork(
        stream_names, shapes, settings, outputs_per_stream=num_outputs
    )
    adam = torch.optim.Adam(critic.parameters(), lr=3e-3)

    for _step in range(50):
        obs_batch = torch.ones((1, n_obs))
        values, _ = critic([obs_batch], [])
        loss = 0
        for name in stream_names:
            stream_value = values[name]
            assert stream_value.shape == (1, num_outputs)
            # Regress this stream's outputs towards 1.
            target = torch.ones((1, num_outputs))
            loss += torch.nn.functional.mse_loss(stream_value, target)

        adam.zero_grad()
        loss.backward()
        adam.step()

    # Only the first output of every row is checked, as in the original test.
    for stream_value in values.values():
        for row in stream_value:
            assert row[0] == pytest.approx(1.0, abs=0.1)
|
|
|
|
|
|
@pytest.mark.parametrize("action_type", [ActionType.DISCRETE, ActionType.CONTINUOUS])
def test_simple_actor(action_type):
    """Check distribution types and action shapes produced by ``SimpleActor``.

    Exercises ``get_dists``, ``sample_action`` and ``forward`` for both
    discrete and continuous action types with a single action branch of size 2.
    """
    n_obs = 4
    settings = NetworkSettings()
    shapes = [(n_obs,)]
    act_size = [2]
    continuous = action_type == ActionType.CONTINUOUS

    # Per-action-type expectations, resolved once up front.
    if continuous:
        masks = None
        dist_cls = GaussianDistInstance
        sampled_shape = (1, act_size[0])
        forward_shape = (act_size[0], 1)
    else:
        masks = torch.ones((1, 1))
        dist_cls = CategoricalDistInstance
        sampled_shape = (1, 1)
        forward_shape = tuple(act_size)

    actor = SimpleActor(shapes, settings, action_type, act_size)

    # get_dists: every returned distribution has the expected class.
    obs_batch = torch.ones((1, n_obs))
    dists, _ = actor.get_dists([obs_batch], [], masks=masks)
    for dist in dists:
        assert isinstance(dist, dist_cls)

    # sample_action: one sample per distribution with the batch-first shape.
    sampled = actor.sample_action(dists)
    for action in sampled:
        assert action.shape == sampled_shape

    # forward: shapes differ from sample_action (layout used for ONNX export).
    exported, _version, mem_size, is_cont, act_size_vec = actor.forward(
        [obs_batch], [], masks=masks
    )
    for action in exported:
        assert action.shape == forward_shape

    assert mem_size == 0
    assert is_cont == int(continuous)
    assert act_size_vec == torch.tensor(act_size)
|
|
|
|
|
|
@pytest.mark.parametrize("ac_type", [SharedActorCritic, SeparateActorCritic])
@pytest.mark.parametrize("lstm", [True, False])
def test_actor_critic(ac_type, lstm):
    """Check output shapes of actor-critic networks with and without memory.

    Runs ``critic_pass`` and ``get_dist_and_value`` on the given actor-critic
    class with continuous actions and four value streams, asserting the shape
    of every value stream and of the returned memories.

    Args:
        ac_type: Actor-critic class under test.
        lstm: Whether the network settings include memory settings.
    """
    obs_size = 4
    network_settings = NetworkSettings(
        memory=NetworkSettings.MemorySettings() if lstm else None
    )
    obs_shapes = [(obs_size,)]
    act_size = [2]
    stream_names = [f"stream_name{n}" for n in range(4)]
    actor = ac_type(
        obs_shapes, network_settings, ActionType.CONTINUOUS, act_size, stream_names
    )

    # Build the inputs and the expected per-stream value shape once, instead of
    # re-deciding lstm vs. non-lstm inside every assertion loop.
    if lstm:
        seq_len = network_settings.memory.sequence_length
        sample_obs = torch.ones((1, seq_len, obs_size))
        memories = torch.ones((1, seq_len, actor.memory_size))
        expected_value_shape = (seq_len,)
    else:
        sample_obs = torch.ones((1, obs_size))
        # memories isn't always set to None, the network should be able to
        # deal with that.
        memories = torch.tensor([])
        expected_value_shape = (1,)

    # Test critic pass
    value_out, memories_out = actor.critic_pass([sample_obs], [], memories=memories)
    if lstm:
        # Loop-invariant check: evaluated once rather than once per stream.
        assert memories_out.shape == memories.shape
    for stream in stream_names:
        assert value_out[stream].shape == expected_value_shape

    # Test get_dist_and_value
    dists, value_out, mem_out = actor.get_dist_and_value(
        [sample_obs], [], memories=memories
    )
    if mem_out is not None:
        assert mem_out.shape == memories.shape
    for dist in dists:
        assert isinstance(dist, GaussianDistInstance)
    for stream in stream_names:
        assert value_out[stream].shape == expected_value_shape
|