def create_mock_policy():
    mock_policy = mock.Mock()
    mock_policy.reward_signals = {}
    # Memories and previous actions come back as fixed-shape zero arrays, so
    # code under test sees correctly typed tensors without a real network.
    mock_policy.retrieve_memories.return_value = np.zeros((1, 1), dtype=np.float32)
    mock_policy.retrieve_previous_memories.return_value = np.zeros(
        (1, 1), dtype=np.float32
    )
    mock_policy.retrieve_previous_action.return_value = np.zeros((1, 1), dtype=np.int32)
    return mock_policy
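

# Sketch of the intended use (hypothetical: AgentProcessor, the behavior name,
# and max_trajectory_length are assumptions, not taken from this file; only
# the stats_reporter tail of the constructor call appears below):
#
#     policy = create_mock_policy()
#     processor = AgentProcessor(
#         policy,
#         "fake_behavior_name",
#         max_trajectory_length=5,
#         stats_reporter=StatsReporter("testcat"),
#     )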

        stats_reporter=StatsReporter("testcat"),
    )

    fake_action_outputs = {
        "action": ActionTuple(continuous=np.array([[0.1], [0.1]], dtype=np.float32)),
        "log_probs": LogProbsTuple(
            continuous=np.array([[0.1], [0.1]], dtype=np.float32)
        ),
    }
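    # The explicit dtype=np.float32 matters: np.array of Python floats defaults
    # to float64, and the annotation keeps these fixtures consistent with the
    # float32 zero arrays returned by create_mock_policy above.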
    mock_decision_steps, mock_terminal_steps = mb.create_mock_steps(
        num_agents=2,
        action_spec=ActionSpec.create_continuous(2),
    )
    fake_action_info = ActionInfo(
        action=ActionTuple(continuous=np.array([[0.1], [0.1]], dtype=np.float32)),
        env_action=ActionTuple(continuous=np.array([[0.1], [0.1]], dtype=np.float32)),
        outputs=fake_action_outputs,
        agent_ids=mock_decision_steps.agent_id,
    )
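
    # A hypothetical next step (processor, tqueue, and the method names here
    # are all assumptions): hand the mocked steps and action info to the
    # processor under test.
    #
    #     processor.publish_trajectory_queue(tqueue)
    #     processor.add_experiences(
    #         mock_decision_steps, mock_terminal_steps, 0, fake_action_info
    #     )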


        stats_reporter=StatsReporter("testcat"),
    )

    fake_action_outputs = {
        "action": ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)),
        "log_probs": LogProbsTuple(continuous=np.array([[0.1]], dtype=np.float32)),
    }
    mock_decision_step, mock_terminal_step = mb.create_mock_steps(
        done=True,
    )
    fake_action_info = ActionInfo(
        action=ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)),
        env_action=ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)),
        outputs=fake_action_outputs,
        agent_ids=mock_decision_step.agent_id,
    )
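
    # With done=True above, mb.create_mock_steps yields populated terminal
    # steps, so feeding these fixtures to a processor would end the episode
    # for the mocked agent and flush its trajectory (behavior assumed, not
    # shown in this file).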


        stats_reporter=StatsReporter("testcat"),
    )

    fake_action_outputs = {
        "action": ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)),
        "log_probs": LogProbsTuple(continuous=np.array([[0.1]], dtype=np.float32)),
    }
    mock_decision_step, mock_terminal_step = mb.create_mock_steps(
    )
    fake_action_info = ActionInfo(
        action=ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)),
        env_action=ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)),
        outputs=fake_action_outputs,
        agent_ids=mock_decision_step.agent_id,
    )