浏览代码

Fix TF policy test

/develop/action-spec-gym
Ervin Teng 4 年前
当前提交
ceeea719
共有 1 个文件被更改,包括 12 次插入14 次删除
  1. 26
      ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py

26
ml-agents/mlagents/trainers/tests/tensorflow/test_tf_policy.py


from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents_envs.base_env import DecisionSteps, BehaviorSpec
from mlagents_envs.base_env import ActionSpec, DecisionSteps, BehaviorSpec
from mlagents.trainers.action_info import ActionInfo
from unittest.mock import MagicMock
from mlagents.trainers.settings import TrainerSettings

def basic_mock_brain():
mock_brain = MagicMock()
mock_brain.vector_action_space_type = "continuous"
mock_brain.vector_observation_space_size = 1
mock_brain.vector_action_space_size = [1]
mock_brain.brain_name = "MockBrain"
return mock_brain
def basic_behavior_spec():
dummy_actionspec = ActionSpec(1, ())
dummy_groupspec = BehaviorSpec([(1,)], dummy_actionspec)
return dummy_groupspec
class FakePolicy(TFPolicy):

def test_take_action_returns_empty_with_no_agents():
test_seed = 3
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output")
# Doesn't really matter what this is
dummy_groupspec = BehaviorSpec([(1,)], "continuous", 1)
no_agent_step = DecisionSteps.empty(dummy_groupspec)
behavior_spec = basic_behavior_spec()
policy = FakePolicy(test_seed, behavior_spec, TrainerSettings(), "output")
no_agent_step = DecisionSteps.empty(behavior_spec)
result = policy.get_action(no_agent_step)
assert result == ActionInfo.empty()

policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output")
behavior_spec = basic_behavior_spec()
policy = FakePolicy(test_seed, behavior_spec, TrainerSettings(), "output")
policy.evaluate = MagicMock(return_value={})
policy.save_memories = MagicMock()
step_with_agents = DecisionSteps(

def test_take_action_returns_action_info_when_available():
test_seed = 3
policy = FakePolicy(test_seed, basic_mock_brain(), TrainerSettings(), "output")
behavior_spec = basic_behavior_spec()
policy = FakePolicy(test_seed, behavior_spec, TrainerSettings(), "output")
policy_eval_out = {
"action": np.array([1.0], dtype=np.float32),
"memory_out": np.array([[2.5]], dtype=np.float32),

正在加载...
取消
保存