您最多选择 25 个主题。主题必须以中文、字母或数字开头，可以包含连字符（-），并且长度不得超过 35 个字符。
 
 
 
 
 

102 行
3.2 KiB

from pose_estimation.single_cube_dataset import (
SingleCubeDataset,
RawDataIterator,
)
import os
import pytest
from yacs.config import CfgNode as CN
# Fixture locations. Paths are built from the current working directory, so
# these tests expect to be launched from the directory that contains ``tests/``.
zip_file_name = "test_single_cube_dataset"  # fixture name under ``data_root``
data_root = os.path.join(os.getcwd(), "tests")  # parent directory of the fixture
root = os.path.join(data_root, zip_file_name)  # full path handed to RawDataIterator
@pytest.fixture
def config():
"""prepare config."""
with open("tests/config/test_config.yaml") as f:
cfg = CN.load_cfg(f)
return cfg
class TestSingleCubeDataset:
    """Tests for ``RawDataIterator`` and ``SingleCubeDataset`` on the test fixture.

    The assertions expect the fixture to hold 10 captures, each indexable by
    object position ``0 .. len(config.dataset.symmetric) - 1``.
    """

    @staticmethod
    def _collect(dataset):
        """Iterate ``dataset`` once and split its 3-tuples into three lists.

        Returns:
            ``(images, translations, orientations)``, with one entry per item
            yielded by ``dataset``.
        """
        images, translations, orientations = [], [], []
        for image, translation, orientation in dataset:
            images.append(image)
            translations.append(translation)
            orientations.append(orientation)
        return images, translations, orientations

    @staticmethod
    def _assert_pose_lengths(translations, orientations, n_objects):
        """Assert the first ``n_objects`` poses of every capture have the expected sizes.

        Each translation must have 3 components and each orientation 4
        (presumably a quaternion; confirm against the dataset code).
        ``translations`` and ``orientations`` must be aligned per capture.
        """
        for translation, orientation in zip(translations, orientations):
            for j in range(n_objects):
                assert len(translation[j]) == 3
                assert len(orientation[j]) == 4

    def test_RawDataIterator(self, config):
        """Raw iteration yields 10 distinct images with well-formed pose dicts."""
        data_dicts, image_paths = [], []
        for data_dict, image_path in RawDataIterator(path=root):
            data_dicts.append(data_dict)
            image_paths.append(image_path)

        assert len(data_dicts) == 10
        assert len(image_paths) == 10
        # Each capture must reference a different image file.
        assert len(set(image_paths)) == 10

        n_objects = len(config.dataset.symmetric)
        for data_dict in data_dicts:
            for j in range(n_objects):
                pose = data_dict[j]
                assert "translation" in pose
                assert "rotation" in pose
                assert len(pose["translation"]) == 3
                assert len(pose["rotation"]) == 4

    def test_SingleCubeIterator(self, config):
        """``SingleCubeDataset`` yields 10 samples in full and 2 when sampled."""
        n_objects = len(config.dataset.symmetric)

        # sample_size=0: the assertions below expect all 10 fixture samples.
        full_dataset = SingleCubeDataset(
            config=config,
            data_root=data_root,
            zip_file_name=zip_file_name,
            sample_size=0,
            download=False,
        )
        assert len(full_dataset) == 10
        images, targets_trans, targets_orient = self._collect(full_dataset)
        assert len(images) == 10
        assert len(targets_trans) == 10
        assert len(targets_orient) == 10
        self._assert_pose_lengths(targets_trans, targets_orient, n_objects)

        # download=False is passed explicitly, as for the full dataset above,
        # so this check reads the same local fixture and does not depend on
        # the constructor's ``download`` default.
        sampled_dataset = SingleCubeDataset(
            config=config,
            data_root=data_root,
            zip_file_name=zip_file_name,
            sample_size=2,
            download=False,
        )
        sample_images, sample_trans, sample_orient = self._collect(sampled_dataset)
        assert len(sample_images) == 2
        assert len(sample_trans) == 2
        assert len(sample_orient) == 2
        self._assert_pose_lengths(sample_trans, sample_orient, n_objects)