Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可选择25个主题。主题必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符。
 
 
 
 
 

39 行
1.8 KiB

import unittest.mock as mock
from mlagents.trainers import TrainerMetrics
class TestTrainerMetrics:
    """Tests for ``TrainerMetrics``: exported column names, timers and rows.

    The timer tests replace ``mlagents.trainers.trainer_metrics.time`` with a
    ``MagicMock`` that returns 42 on every call, so the tests do not depend
    on the real clock.
    """

    def test_field_names(self):
        """``FIELD_NAMES`` must equal the expected list of column headers."""
        # Imported inside the test, as in the original, so an import problem
        # surfaces as a failure of this test only.
        from mlagents.trainers.trainer_metrics import FIELD_NAMES

        expected_columns = [
            'Brain name',
            'Time to update policy',
            'Time since start of training',
            'Time for last experience collection',
            'Number of experiences used for training',
            'Mean return',
        ]
        assert FIELD_NAMES == expected_columns

    @mock.patch(
        'mlagents.trainers.trainer_metrics.time',
        mock.MagicMock(return_value=42),
    )
    def test_experience_collection_timer(self):
        """Start/end of experience collection yields a delta of 0.

        With ``time`` patched to a constant, the test expects
        ``delta_last_experience_collection`` to be exactly 0.
        """
        metrics = TrainerMetrics(path='fake', brain_name='fake')

        metrics.start_experience_collection_timer()
        metrics.end_experience_collection_timer()

        assert metrics.delta_last_experience_collection == 0

    @mock.patch(
        'mlagents.trainers.trainer_metrics.time',
        mock.MagicMock(return_value=42),
    )
    def test_policy_update_timer(self):
        """A full collect-then-update cycle records the expected first row.

        The expected row holds the brain name, three zero values, the number
        of experiences (350) and the mean return formatted as ``'0.300'``.
        """
        brain = 'fake'
        experience_count = 350
        average_return = 0.3
        metrics = TrainerMetrics(path='fake', brain_name=brain)

        # Experience collection phase.
        metrics.start_experience_collection_timer()
        metrics.end_experience_collection_timer()

        # Policy update phase.
        metrics.start_policy_update_timer(
            number_experiences=experience_count,
            mean_return=average_return,
        )
        metrics.end_policy_update()

        expected_row = [brain, 0, 0, 0, 350, '0.300']
        assert metrics.rows[0] == expected_row