您最多选择25个主题
主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
141 行
4.8 KiB
141 行
4.8 KiB
import math

import pytest
import torch

from mlagents.trainers.torch.distributions import (
    GaussianDistribution,
    MultiCategoricalDistribution,
    GaussianDistInstance,
    TanhGaussianDistInstance,
    CategoricalDistInstance,
)
|
|
|
|
|
|
@pytest.mark.parametrize("tanh_squash", [True, False])
@pytest.mark.parametrize("conditional_sigma", [True, False])
def test_gaussian_distribution(conditional_sigma, tanh_squash):
    """Train a GaussianDistribution so that log_prob(0) approaches -2 per action dim.

    Checks that the module returns a TanhGaussianDistInstance when
    ``tanh_squash`` is set (a GaussianDistInstance otherwise), and that
    gradients flow from ``log_prob`` back into the module's parameters.
    """
    torch.manual_seed(0)
    hidden_size = 16
    act_size = 4
    target_log_prob = -2.0
    # Embedding width must match the hidden_size the distribution is built with.
    sample_embedding = torch.ones((1, hidden_size))
    gauss_dist = GaussianDistribution(
        hidden_size,
        act_size,
        conditional_sigma=conditional_sigma,
        tanh_squash=tanh_squash,
    )

    # Make sure backprop works
    force_action = torch.zeros((1, act_size))
    optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3)
    expected_type = TanhGaussianDistInstance if tanh_squash else GaussianDistInstance

    for _ in range(50):
        dist_inst = gauss_dist(sample_embedding)[0]
        assert isinstance(dist_inst, expected_type)
        log_prob = dist_inst.log_prob(force_action)
        # full_like keeps the target on the same dtype/device as log_prob.
        loss = torch.nn.functional.mse_loss(
            log_prob, torch.full_like(log_prob, target_log_prob)
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # log_prob comes from the last training iteration. Compare plain Python
    # floats rather than grad-tracking tensor elements.
    for prob in log_prob.detach().flatten().tolist():
        assert prob == pytest.approx(target_log_prob, abs=0.1)
|
|
|
|
|
|
def test_multi_categorical_distribution():
    """Train a MultiCategoricalDistribution toward fixed per-branch targets, then test masking.

    Each branch is trained so that its first action receives most of the
    probability mass. A mask that allows only the last action of every branch
    is then applied, and that action's log-prob is expected to be ~0
    (probability ~1).
    """
    torch.manual_seed(0)
    hidden_size = 16
    act_size = [3, 3, 4]
    # Embedding width must match the hidden_size the distribution is built with.
    sample_embedding = torch.ones((1, hidden_size))
    cat_dist = MultiCategoricalDistribution(hidden_size, act_size)

    # Make sure backprop works
    optimizer = torch.optim.Adam(cat_dist.parameters(), lr=3e-3)

    def create_test_prob(size: int) -> torch.Tensor:
        """Return (1, size) log-probs: 1 - 0.01 * (size - 1) for the first action,
        0.01 for each of the others."""
        test_prob = torch.tensor(
            [[1.0 - 0.01 * (size - 1)] + [0.01] * (size - 1)]
        )  # High prob for first action
        return test_prob.log()

    # All-ones mask: every action of every branch is allowed during training.
    all_allowed = torch.ones((1, sum(act_size)))
    for _ in range(100):
        dist_insts = cat_dist(sample_embedding, masks=all_allowed)
        # One distribution instance per branch; checked explicitly because zip
        # below would otherwise silently truncate a mismatch.
        assert len(dist_insts) == len(act_size)
        loss = 0
        for dist_inst, size in zip(dist_insts, act_size):
            assert isinstance(dist_inst, CategoricalDistInstance)
            # Force log_probs to match the high probability for the first action
            # generated by create_test_prob
            loss += torch.nn.functional.mse_loss(
                dist_inst.all_log_prob(), create_test_prob(size)
            )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    for dist_inst, size in zip(dist_insts, act_size):
        # Check that the log probs are close to the fake ones that we generated.
        actual = dist_inst.all_log_prob().flatten().tolist()
        expected = create_test_prob(size).flatten().tolist()
        assert actual == pytest.approx(expected, abs=0.1)

    # Test masks: allow only the last action of each branch (1 = allowed, 0 = masked).
    last_only = [
        allowed for branch in act_size for allowed in [0] * (branch - 1) + [1]
    ]
    dist_insts = cat_dist(sample_embedding, masks=torch.tensor([last_only]))
    for dist_inst in dist_insts:
        log_prob = dist_inst.all_log_prob()
        # The single unmasked action should carry ~all probability mass: log(1) == 0.
        assert log_prob.flatten()[-1].item() == pytest.approx(0, abs=0.001)
|
|
|
|
|
|
def test_gaussian_dist_instance():
    """Check sample shape, log_prob and entropy of a standard-normal GaussianDistInstance."""
    torch.manual_seed(0)
    act_size = 4
    dist_instance = GaussianDistInstance(
        torch.zeros(1, act_size), torch.ones(1, act_size)
    )
    action = dist_instance.sample()
    assert action.shape == (1, act_size)

    # Closed-form values for N(0, 1):
    #   log pdf at 0 = -0.5 * ln(2*pi)          ~= -0.919
    #   entropy      =  0.5 + 0.5 * ln(2*pi)    ~=  1.419
    # Entropy does not depend on any evaluation point; it is a property of the
    # distribution, equal to 1/2 + ln(sqrt(2*pi) * sigma) with sigma = 1.
    expected_log_prob = -0.5 * math.log(2 * math.pi)
    expected_entropy = 0.5 + 0.5 * math.log(2 * math.pi)

    log_probs = dist_instance.log_prob(torch.zeros((1, act_size)))
    for log_prob in log_probs.flatten().tolist():
        assert log_prob == pytest.approx(expected_log_prob, abs=0.01)

    for ent in dist_instance.entropy().flatten().tolist():
        assert ent == pytest.approx(expected_entropy, abs=0.01)
|
|
|
|
|
|
def test_tanh_gaussian_dist_instance():
    """Samples from a TanhGaussianDistInstance must have shape (1, act_size)
    and lie strictly inside the open interval (-1, 1)."""
    torch.manual_seed(0)
    act_size = 4
    centre = torch.zeros(1, act_size)
    spread = torch.ones(1, act_size)
    dist_instance = TanhGaussianDistInstance(centre, spread)
    for _ in range(10):
        sampled = dist_instance.sample()
        assert sampled.shape == (1, act_size)
        # max(|x|) < 1 is equivalent to max(x) < 1 and min(x) > -1.
        assert sampled.abs().max() < 1.0
|
|
|
|
|
|
def test_categorical_dist_instance():
    """Sample a CategoricalDistInstance and check that the first action is the most likely."""
    torch.manual_seed(0)
    act_size = 4
    # NOTE(review): these values are built as probabilities (~[0.7, 0.1, 0.1, 0.1]).
    # Confirm whether CategoricalDistInstance expects probabilities or logits.
    # In either case the first entry is the largest, and that ordering is all
    # the assertions below rely on.
    test_prob = torch.tensor(
        [1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)
    )  # High prob for first action
    dist_instance = CategoricalDistInstance(test_prob)

    for _ in range(10):
        action = dist_instance.sample()
        assert action.shape == (1,)
        # The sampled index must be a valid action id in [0, act_size).
        assert 0 <= action.item() < act_size

    # Make sure the first action has higher probability than the others.
    log_prob_first_action = dist_instance.log_prob(torch.tensor([0]))
    for i in range(1, act_size):
        assert dist_instance.log_prob(torch.tensor([i])) < log_prob_first_action
|