|
|
|
|
|
|
assert act.shape == (1, 1) |
|
|
|
|
|
|
|
# Test forward |
|
|
|
actions, probs, ver_num, mem_size, is_cont, act_size_vec = actor.forward( |
|
|
|
actions, ver_num, mem_size, is_cont, act_size_vec = actor.forward( |
|
|
|
# This is different from above for ONNX export |
|
|
|
assert act.shape == ( |
|
|
|
act_size[0], |
|
|
|
1, |
|
|
|
) # This is different from above for ONNX export |
|
|
|
assert act.shape == (act_size[0], 1) |
|
|
|
assert act.shape == (1, 1) |
|
|
|
assert act.shape == tuple(act_size) |
|
|
|
# TODO: Once export works properly. fix the shapes here. |
|
|
|
assert mem_size == 0 |
|
|
|
assert is_cont == int(action_type == ActionType.CONTINUOUS) |
|
|
|
assert act_size_vec == torch.tensor(act_size) |
|
|
|