Unity 机器学习代理工具包 (ML-Agents) 是一个开源项目,它使游戏和模拟能够作为训练智能代理的环境。
您最多可选择 25 个主题。主题必须以中文、字母或数字开头,可以包含连字符 (-),并且长度不得超过 35 个字符。
 
 
 
 
 

71 行
2.4 KiB

from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents_envs.base_env import DecisionSteps, BehaviorSpec
from mlagents.trainers.action_info import ActionInfo
from unittest.mock import MagicMock
from mlagents.trainers.settings import TrainerSettings
import numpy as np
def basic_mock_brain() -> MagicMock:
    """Build a mock brain description shared by the policy tests.

    Returns:
        A fresh ``MagicMock`` whose ``vector_action_space_type`` is
        ``"continuous"``, ``vector_observation_space_size`` is ``1`` and
        ``vector_action_space_size`` is ``[1]``.
    """
    mock_brain = MagicMock()
    mock_brain.vector_action_space_type = "continuous"
    mock_brain.vector_observation_space_size = 1
    mock_brain.vector_action_space_size = [1]
    return mock_brain
class FakePolicy(TFPolicy):
    """Stub TFPolicy subclass used by the tests in this module.

    Both overridden hooks are trivial: ``create_tf_graph`` does nothing and
    ``get_trainable_variables`` returns an empty list.
    """

    def create_tf_graph(self):
        """Do nothing; this stub defines no graph of its own."""
        pass

    def get_trainable_variables(self):
        """Return an empty list; this stub exposes no trainable variables."""
        return []
def test_take_action_returns_empty_with_no_agents() -> None:
    """Check that get_action on an empty DecisionSteps equals ActionInfo.empty()."""
    test_seed = 3
    policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings())
    # Doesn't really matter what this is
    dummy_groupspec = BehaviorSpec([(1,)], "continuous", 1)
    no_agent_step = DecisionSteps.empty(dummy_groupspec)
    result = policy.get_action(no_agent_step)
    assert result == ActionInfo.empty()
def test_take_action_returns_nones_on_missing_values() -> None:
    """Check get_action when the mocked evaluate() returns an empty dict.

    The expected result is ``ActionInfo(None, None, {}, [0])`` for a step
    holding the single agent id ``0``.
    """
    test_seed = 3
    policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings())
    # Replace evaluation and memory saving with mocks so no real graph runs.
    policy.evaluate = MagicMock(return_value={})
    policy.save_memories = MagicMock()
    step_with_agents = DecisionSteps(
        [], np.array([], dtype=np.float32), np.array([0]), None
    )
    result = policy.get_action(step_with_agents, worker_id=0)
    assert result == ActionInfo(None, None, {}, [0])
def test_take_action_returns_action_info_when_available() -> None:
    """Check get_action when the mocked evaluate() supplies action and value.

    The expected ``ActionInfo`` carries the ``"action"`` and ``"value"``
    entries of the evaluate output, the full output dict, and agent ids
    ``[0]``.
    """
    test_seed = 3
    policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings())
    policy_eval_out = {
        "action": np.array([1.0], dtype=np.float32),
        "memory_out": np.array([[2.5]], dtype=np.float32),
        "value": np.array([1.1], dtype=np.float32),
    }
    # Replace evaluation with a mock that returns the canned output above.
    policy.evaluate = MagicMock(return_value=policy_eval_out)
    step_with_agents = DecisionSteps(
        [], np.array([], dtype=np.float32), np.array([0]), None
    )
    result = policy.get_action(step_with_agents)
    expected = ActionInfo(
        policy_eval_out["action"], policy_eval_out["value"], policy_eval_out, [0]
    )
    assert result == expected
def test_convert_version_string() -> None:
    """Check that TFPolicy._convert_version_string yields a 3-int tuple.

    Covers a plain ``"major.minor.patch"`` string and one with a trailing
    ``.dev0`` component, which is expected to give the same tuple.
    """
    result = TFPolicy._convert_version_string("200.300.100")
    assert result == (200, 300, 100)
    # Test dev versions
    result = TFPolicy._convert_version_string("200.300.100.dev0")
    assert result == (200, 300, 100)