"""Tests for the torch action-distribution heads and distribution instances."""
import pytest

from mlagents.torch_utils import torch
from mlagents.trainers.torch.distributions import (
    GaussianDistribution,
    MultiCategoricalDistribution,
    GaussianDistInstance,
    TanhGaussianDistInstance,
    CategoricalDistInstance,
)


@pytest.mark.parametrize("tanh_squash", [True, False])
@pytest.mark.parametrize("conditional_sigma", [True, False])
def test_gaussian_distribution(conditional_sigma, tanh_squash):
    """Train a GaussianDistribution head so log_prob of a zero action reaches -2.

    Also checks that the head returns a TanhGaussianDistInstance when
    ``tanh_squash`` is set and a GaussianDistInstance otherwise.
    """
    torch.manual_seed(0)
    hidden_size = 16
    act_size = 4
    sample_embedding = torch.ones((1, hidden_size))
    gauss_dist = GaussianDistribution(
        hidden_size,
        act_size,
        conditional_sigma=conditional_sigma,
        tanh_squash=tanh_squash,
    )

    # Make sure backprop works: regress log_prob(force_action) onto a fixed target.
    force_action = torch.zeros((1, act_size))
    optimizer = torch.optim.Adam(gauss_dist.parameters(), lr=3e-3)

    for _ in range(50):
        dist_inst = gauss_dist(sample_embedding)
        if tanh_squash:
            assert isinstance(dist_inst, TanhGaussianDistInstance)
        else:
            assert isinstance(dist_inst, GaussianDistInstance)
        log_prob = dist_inst.log_prob(force_action)
        # full_like keeps the target on the same dtype/device as log_prob.
        target = torch.full_like(log_prob, -2.0)
        loss = torch.nn.functional.mse_loss(log_prob, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # log_prob is from the last forward pass (computed before the final step).
    for prob in log_prob.flatten().tolist():
        assert prob == pytest.approx(-2, abs=0.1)


def test_multi_categorical_distribution():
    """Train a MultiCategoricalDistribution head towards fixed per-branch log-probs.

    Afterwards, verify that masking all but the last action of every branch
    puts (almost) all probability mass on that last action.
    """
    torch.manual_seed(0)
    hidden_size = 16
    act_size = [3, 3, 4]
    sample_embedding = torch.ones((1, hidden_size))
    multi_cat_dist = MultiCategoricalDistribution(hidden_size, act_size)

    # Make sure backprop works
    optimizer = torch.optim.Adam(multi_cat_dist.parameters(), lr=3e-3)

    def create_test_prob(size: int) -> torch.Tensor:
        """Return log-probs of shape (1, size) that strongly favor the first action."""
        test_prob = torch.tensor(
            [[1.0 - 0.01 * (size - 1)] + [0.01] * (size - 1)]
        )  # High prob for first action
        return test_prob.log()

    all_unmasked = torch.ones((1, sum(act_size)))
    for _ in range(100):
        dist_insts = multi_cat_dist(sample_embedding, masks=all_unmasked)
        # One distribution instance is expected per action branch.
        assert len(dist_insts) == len(act_size)
        loss = 0
        for dist_inst, size in zip(dist_insts, act_size):
            assert isinstance(dist_inst, CategoricalDistInstance)
            log_prob = dist_inst.all_log_prob()
            # Force log_probs to match the high probability for the first
            # action generated by create_test_prob.
            loss += torch.nn.functional.mse_loss(log_prob, create_test_prob(size))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    for dist_inst, size in zip(dist_insts, act_size):
        # Check that the log probs are close to the fake ones that we generated.
        test_log_probs = create_test_prob(size)
        for _prob, _test_prob in zip(
            dist_inst.all_log_prob().flatten().tolist(),
            test_log_probs.flatten().tolist(),
        ):
            assert _prob == pytest.approx(_test_prob, abs=0.1)

    # Test masks: in every branch, only the last action is left unmasked.
    mask_values = []
    for branch in act_size:
        mask_values += [0] * (branch - 1) + [1]
    masks = torch.tensor([mask_values])
    dist_insts = multi_cat_dist(sample_embedding, masks=masks)
    assert len(dist_insts) == len(act_size)
    for dist_inst in dist_insts:
        log_prob = dist_inst.all_log_prob()
        # log(1) == 0: the only unmasked action should carry all the mass.
        assert log_prob.flatten()[-1].tolist() == pytest.approx(0, abs=0.001)


def test_gaussian_dist_instance():
    """Check sample shape, log_prob and entropy of a standard-normal instance."""
    torch.manual_seed(0)
    act_size = 4
    dist_instance = GaussianDistInstance(
        torch.zeros(1, act_size), torch.ones(1, act_size)
    )
    action = dist_instance.sample()
    assert action.shape == (1, act_size)
    for log_prob in (
        dist_instance.log_prob(torch.zeros((1, act_size))).flatten().tolist()
    ):
        # Log prob of standard normal at 0: -ln(sqrt(2*pi)) ~= -0.919
        assert log_prob == pytest.approx(-0.919, abs=0.01)

    for ent in dist_instance.entropy().flatten().tolist():
        # Entropy of a normal with sigma=1: 1/2 + ln(sqrt(2*pi) * sigma) ~= 1.419
        assert ent == pytest.approx(1.42, abs=0.01)


def test_tanh_gaussian_dist_instance():
    """Samples from a TanhGaussianDistInstance must lie strictly inside (-1, 1)."""
    torch.manual_seed(0)
    act_size = 4
    dist_instance = TanhGaussianDistInstance(
        torch.zeros(1, act_size), torch.ones(1, act_size)
    )
    for _ in range(10):
        action = dist_instance.sample()
        assert action.shape == (1, act_size)
        # Separate asserts so a failure reports which bound was violated.
        assert torch.max(action) < 1.0
        assert torch.min(action) > -1.0


def test_categorical_dist_instance():
    """Check sampling bounds and log_prob ordering of a CategoricalDistInstance."""
    torch.manual_seed(0)
    act_size = 4
    # The first entry is the largest, so action 0 should be the most likely.
    test_prob = torch.tensor(
        [[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)]
    )  # High prob for first action
    dist_instance = CategoricalDistInstance(test_prob)

    for _ in range(10):
        action = dist_instance.sample()
        assert action.shape == (1, 1)
        assert 0 <= action.item() < act_size

    # Make sure the first action has higher probability than the others.
    log_prob_first_action = dist_instance.log_prob(torch.tensor([0]))
    for i in range(1, act_size):
        assert dist_instance.log_prob(torch.tensor([i])) < log_prob_first_action