|
|
|
|
|
|
assert (agent_1_act.discrete_tensor[3:] == 0).all() |
|
|
|
|
|
|
|
|
|
|
|
def test_slice(): |
|
|
|
# Both continuous and discrete |
|
|
|
aa = AgentAction( |
|
|
|
torch.tensor([[1.0], [1.0], [1.0]]), |
|
|
|
[torch.tensor([2, 1, 0]), torch.tensor([1, 2, 0])], |
|
|
|
) |
|
|
|
saa = aa.slice(0, 2) |
|
|
|
assert saa.continuous_tensor.shape == (2, 1) |
|
|
|
assert saa.discrete_tensor.shape == (2, 2) |
|
|
|
|
|
|
|
|
|
|
|
def test_to_flat(): |
|
|
|
# Both continuous and discrete |
|
|
|
aa = AgentAction( |
|
|
|