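You can select up to 25 topics
Topics must start with a Chinese character, a letter, or a digit, may contain hyphens (-), and must be at most 35 characters long
31 lines
1.1 KiB
31 lines
1.1 KiB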
import pytest
|
|
from mlagents.torch_utils import torch
|
|
|
|
from mlagents.trainers.torch.decoders import ValueHeads
|
|
|
|
|
|
def test_valueheads():
    """Check the output shapes produced by ``ValueHeads``.

    A ``ValueHeads`` module is built over five reward-signal streams and the
    test verifies that:

    * with the default output size, each stream's output has shape
      ``(batch_size,)``;
    * feeding an input whose last dimension does not match ``input_size``
      raises an error;
    * with ``output_size`` values per head (e.g. a discrete Q function), each
      stream's output has shape ``(batch_size, output_size)``.
    """
    stream_names = [f"reward_signal_{num}" for num in range(5)]
    input_size = 5
    batch_size = 4

    # Test default 1 value per head
    value_heads = ValueHeads(stream_names, input_size)
    input_data = torch.ones((batch_size, input_size))
    value_out = value_heads(input_data)  # Note: mean value will be removed shortly

    for stream_name in stream_names:
        assert value_out[stream_name].shape == (batch_size,)

    # Test that inputting the wrong size input will throw an error.
    # A feature-size mismatch in a torch linear layer raises RuntimeError;
    # catching only that keeps unrelated failures (TypeError, AttributeError,
    # ...) from being mistaken for the expected shape error.
    with pytest.raises(RuntimeError):
        value_heads(torch.ones((batch_size, input_size + 2)))

    # Test multiple values per head (e.g. discrete Q function)
    output_size = 4
    value_heads = ValueHeads(stream_names, input_size, output_size)
    input_data = torch.ones((batch_size, input_size))
    value_out = value_heads(input_data)

    for stream_name in stream_names:
        assert value_out[stream_name].shape == (batch_size, output_size)