浏览代码

Fix AgentProcessor tests

/develop/critic-op-lstm-currentmem
Ervin Teng 4 年前
当前提交
7471a2fd
共有 1 个文件被更改，包括 17 次插入和 13 次删除
  1. 30
      ml-agents/mlagents/trainers/tests/test_agent_processor.py

30
ml-agents/mlagents/trainers/tests/test_agent_processor.py


def create_mock_policy() -> mock.Mock:
    """Build a mock policy pre-loaded with canned return values for the tests.

    Returns:
        A ``mock.Mock`` whose ``reward_signals`` attribute is an empty dict,
        whose ``retrieve_memories`` and ``retrieve_previous_memories`` methods
        each return a ``(1, 1)`` float32 zero array, and whose
        ``retrieve_previous_action`` method returns a ``(1, 1)`` int32 zero
        array.
    """
    mock_policy = mock.Mock()
    mock_policy.reward_signals = {}
    # The stubbed methods ignore their arguments and always hand back the same
    # pre-built array stored in ``return_value``.
    mock_policy.retrieve_memories.return_value = np.zeros((1, 1), dtype=np.float32)
    mock_policy.retrieve_previous_memories.return_value = np.zeros(
        (1, 1), dtype=np.float32
    )
    mock_policy.retrieve_previous_action.return_value = np.zeros((1, 1), dtype=np.int32)
    return mock_policy

)
fake_action_outputs = {
"action": ActionTuple(continuous=np.array([[0.1], [0.1]])),
"action": ActionTuple(continuous=np.array([[0.1], [0.1]], dtype=np.float32)),
"log_probs": LogProbsTuple(continuous=np.array([[0.1], [0.1]])),
"log_probs": LogProbsTuple(
continuous=np.array([[0.1], [0.1]], dtype=np.float32)
),
}
mock_decision_steps, mock_terminal_steps = mb.create_mock_steps(
num_agents=2,

action_spec=ActionSpec.create_continuous(2),
)
fake_action_info = ActionInfo(
action=ActionTuple(continuous=np.array([[0.1], [0.1]])),
env_action=ActionTuple(continuous=np.array([[0.1], [0.1]])),
action=ActionTuple(continuous=np.array([[0.1], [0.1]], dtype=np.float32)),
env_action=ActionTuple(continuous=np.array([[0.1], [0.1]], dtype=np.float32)),
outputs=fake_action_outputs,
agent_ids=mock_decision_steps.agent_id,
)

stats_reporter=StatsReporter("testcat"),
)
fake_action_outputs = {
"action": ActionTuple(continuous=np.array([[0.1]])),
"action": ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)),
"log_probs": LogProbsTuple(continuous=np.array([[0.1]])),
"log_probs": LogProbsTuple(continuous=np.array([[0.1]], dtype=np.float32)),
}
mock_decision_step, mock_terminal_step = mb.create_mock_steps(

done=True,
)
fake_action_info = ActionInfo(
action=ActionTuple(continuous=np.array([[0.1]])),
env_action=ActionTuple(continuous=np.array([[0.1]])),
action=ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)),
env_action=ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)),
outputs=fake_action_outputs,
agent_ids=mock_decision_step.agent_id,
)

stats_reporter=StatsReporter("testcat"),
)
fake_action_outputs = {
"action": ActionTuple(continuous=np.array([[0.1]])),
"action": ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)),
"log_probs": LogProbsTuple(continuous=np.array([[0.1]])),
"log_probs": LogProbsTuple(continuous=np.array([[0.1]], dtype=np.float32)),
}
mock_decision_step, mock_terminal_step = mb.create_mock_steps(

)
fake_action_info = ActionInfo(
action=ActionTuple(continuous=np.array([[0.1]])),
env_action=ActionTuple(continuous=np.array([[0.1]])),
action=ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)),
env_action=ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)),
outputs=fake_action_outputs,
agent_ids=mock_decision_step.agent_id,
)

正在加载...
取消
保存