您最多选择25个主题
主题必须以中文、字母或数字开头，可以包含连字符（-），并且长度不得超过35个字符。
46 行
1.7 KiB
46 行
1.7 KiB
import unittest.mock as mock
|
|
from mlagents.trainers.trainer_metrics import TrainerMetrics
|
|
|
|
|
|
class TestTrainerMetrics:
    """Unit tests for ``TrainerMetrics`` column names, timers and recorded rows."""

    def test_field_names(self):
        """``FIELD_NAMES`` lists the metric columns in the expected order."""
        from mlagents.trainers.trainer_metrics import FIELD_NAMES

        field_names = [
            "Brain name",
            "Time to update policy",
            "Time since start of training",
            "Time for last experience collection",
            "Number of experiences used for training",
            "Mean return",
        ]
        assert FIELD_NAMES == field_names

    # ``time`` inside trainer_metrics is replaced by a mock that always
    # returns 42 (a frozen clock), so the elapsed-time delta checked below is 0.
    @mock.patch(
        "mlagents.trainers.trainer_metrics.time", mock.MagicMock(return_value=42)
    )
    def test_experience_collection_timer(self):
        """A start/end timer pair under a frozen clock records a zero delta."""
        mock_path = "fake"
        mock_brain_name = "fake"
        trainer_metrics = TrainerMetrics(path=mock_path, brain_name=mock_brain_name)
        trainer_metrics.start_experience_collection_timer()
        trainer_metrics.end_experience_collection_timer()
        assert trainer_metrics.delta_last_experience_collection == 0

    # Same frozen clock as above: every timing column of the row is 0.
    @mock.patch(
        "mlagents.trainers.trainer_metrics.time", mock.MagicMock(return_value=42)
    )
    def test_policy_update_timer(self):
        """A collect/update cycle yields a first row laid out like ``FIELD_NAMES``."""
        mock_path = "fake"
        mock_brain_name = "fake"
        fake_buffer_length = 350
        fake_mean_return = 0.3
        trainer_metrics = TrainerMetrics(path=mock_path, brain_name=mock_brain_name)
        trainer_metrics.start_experience_collection_timer()
        trainer_metrics.end_experience_collection_timer()
        trainer_metrics.start_policy_update_timer(
            number_experiences=fake_buffer_length, mean_return=fake_mean_return
        )
        trainer_metrics.end_policy_update()
        # Columns: brain name, three timing deltas (0 under the frozen clock),
        # the experience count passed in above, and the mean return rendered
        # as a three-decimal string. The count is taken from the fixture
        # variable so the expectation cannot drift from the input.
        fake_row = [mock_brain_name, 0, 0, 0, fake_buffer_length, "0.300"]
        assert trainer_metrics.rows[0] == fake_row
|