from unittest import mock

import numpy as np

import mlagents.trainers.tests.mock_brain as mb
from mlagents.trainers.agent_processor import (
    AgentProcessor,
    AgentManagerQueue,
)
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.behavior_id_utils import get_global_agent_id
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.torch.action_log_probs import LogProbsTuple
from mlagents_envs.base_env import ActionSpec, ActionTuple
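
# These tests exercise AgentProcessor, which consumes per-step policy results
# (actions, log-probs, values) and assembles them into trajectories for the
# trainer.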
def create_mock_policy():
    mock_policy = mock.Mock()
    mock_policy.reward_signals = {}
    mock_policy.retrieve_memories.return_value = np.zeros((1, 1), dtype=np.float32)
    mock_policy.retrieve_previous_action.return_value = np.zeros(
        (1, 1), dtype=np.float32
    )
    return mock_policy


def test_agentprocessor():
    policy = create_mock_policy()
    tqueue = mock.Mock()
    processor = AgentProcessor(
        policy,
        "test_brain_name",
        max_trajectory_length=5,
        stats_reporter=StatsReporter("testcat"),
    )
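    # Batched policy-style outputs for two agents, one continuous action each;
    # ActionTuple and LogProbsTuple are the containers the torch policy emits.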
    fake_action_outputs = {
        "action": ActionTuple(continuous=np.array([[0.1], [0.1]])),
        "log_probs": LogProbsTuple(continuous=np.array([[0.1], [0.1]])),
    }
    mock_decision_steps, mock_terminal_steps = mb.create_mock_steps(
        num_agents=2,
        observation_shapes=[(8,)],
        action_spec=ActionSpec.create_continuous(2),  # spec is arbitrary here
    )
    fake_action_info = ActionInfo(
        action=ActionTuple(continuous=np.array([[0.1], [0.1]])),
        value=[0.1, 0.1],
        outputs=fake_action_outputs,
        agent_ids=mock_decision_steps.agent_id,
    )
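    # A minimal sketch of exercising the processor with these fixtures; the
    # queue hookup and step counts are illustrative, not from the original.
    processor.publish_trajectory_queue(tqueue)
    for _ in range(5):
        processor.add_experiences(
            mock_decision_steps, mock_terminal_steps, 0, fake_action_info
        )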


def test_agent_deletion():
    policy = create_mock_policy()
    tqueue = mock.Mock()
    processor = AgentProcessor(
        policy,
        "test_brain_name",
        max_trajectory_length=5,
        stats_reporter=StatsReporter("testcat"),
    )
    fake_action_outputs = {
        "action": ActionTuple(continuous=np.array([[0.1]])),
        "log_probs": LogProbsTuple(continuous=np.array([[0.1]])),
    }
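    # One ongoing step and one terminal ("done") step let the test walk a
    # single agent all the way through episode termination.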
    mock_decision_step, mock_terminal_step = mb.create_mock_steps(
        num_agents=1,
        observation_shapes=[(8,)],
        action_spec=ActionSpec.create_continuous(2),
    )
    mock_done_decision_step, mock_done_terminal_step = mb.create_mock_steps(
        num_agents=1,
        observation_shapes=[(8,)],
        action_spec=ActionSpec.create_continuous(2),
        done=True,
    )
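    # ActionInfo bundles the chosen action, the value estimate, the raw policy
    # outputs, and the ids of the agents the action applies to.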
    fake_action_info = ActionInfo(
        action=ActionTuple(continuous=np.array([[0.1]])),
        value=[0.1],
        outputs=fake_action_outputs,
        agent_ids=mock_decision_step.agent_id,
    )
    processor.publish_trajectory_queue(tqueue)
    # Run episodes on several worker ids to simulate multiple agents, recording
    # the save_previous_action calls we expect the processor to make.
    add_calls = []
    for _ep in range(3):
        for _ in range(5):
            processor.add_experiences(
                mock_decision_step, mock_terminal_step, _ep, fake_action_info
            )
            add_calls.append(
                mock.call([get_global_agent_id(_ep, 0)], fake_action_outputs["action"])
            )
        # Terminate the episode so the processor deletes this agent's state.
        processor.add_experiences(
            mock_done_decision_step, mock_done_terminal_step, _ep, fake_action_info
        )
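    # A hedged sketch of the checks this builds toward, assuming the processor
    # mirrors its bookkeeping onto the mocked policy and drops per-agent
    # buffers once an episode ends:
    policy.save_previous_action.assert_has_calls(add_calls)
    assert len(processor.experience_buffers.keys()) == 0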


def test_end_episode():
    policy = create_mock_policy()
    tqueue = mock.Mock()
    processor = AgentProcessor(
        policy,
        "test_brain_name",
        max_trajectory_length=5,
        stats_reporter=StatsReporter("testcat"),
    )
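    # Same single-agent fixtures as in test_agent_deletion, reused here to
    # drive an externally forced end of episode.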
    fake_action_outputs = {
        "action": ActionTuple(continuous=np.array([[0.1]])),
        "log_probs": LogProbsTuple(continuous=np.array([[0.1]])),
    }
    mock_decision_step, mock_terminal_step = mb.create_mock_steps(
        num_agents=1,
        observation_shapes=[(8,)],
        action_spec=ActionSpec.create_continuous(2),
    )
    fake_action_info = ActionInfo(
        action=ActionTuple(continuous=np.array([[0.1]])),
        value=[0.1],
        outputs=fake_action_outputs,
        agent_ids=mock_decision_step.agent_id,
    )
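    # A hedged sketch of how the test is expected to finish: feed a few steps,
    # then force-end the episode and confirm the processor dropped per-agent
    # state (end_episode and experience_buffers are AgentProcessor attributes).
    processor.publish_trajectory_queue(tqueue)
    # The first call after an env reset carries no previous action.
    processor.add_experiences(
        mock_decision_step, mock_terminal_step, 0, ActionInfo.empty()
    )
    for _ in range(3):
        processor.add_experiences(
            mock_decision_step, mock_terminal_step, 0, fake_action_info
        )
    processor.end_episode()
    assert len(processor.experience_buffers.keys()) == 0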