Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
 
 
 
 
 

31 行
1.1 KiB

import pytest
from mlagents.torch_utils import torch
from mlagents.trainers.torch.decoders import ValueHeads
def test_valueheads():
stream_names = [f"reward_signal_{num}" for num in range(5)]
input_size = 5
batch_size = 4
# Test default 1 value per head
value_heads = ValueHeads(stream_names, input_size)
input_data = torch.ones((batch_size, input_size))
value_out = value_heads(input_data) # Note: mean value will be removed shortly
for stream_name in stream_names:
assert value_out[stream_name].shape == (batch_size,)
# Test that inputting the wrong size input will throw an error
with pytest.raises(Exception):
value_out = value_heads(torch.ones((batch_size, input_size + 2)))
# Test multiple values per head (e.g. discrete Q function)
output_size = 4
value_heads = ValueHeads(stream_names, input_size, output_size)
input_data = torch.ones((batch_size, input_size))
value_out = value_heads(input_data)
for stream_name in stream_names:
assert value_out[stream_name].shape == (batch_size, output_size)