浏览代码

fix test agent processor

/develop/actionmodel-csharp
Andrew Cohen 4 年前
当前提交
0c5934ec
共有 1 个文件被更改,包括 14 次插入和 13 次删除
  1. 27
      ml-agents/mlagents/trainers/tests/test_agent_processor.py

27
ml-agents/mlagents/trainers/tests/test_agent_processor.py


AgentManagerQueue,
)
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.torch.action_log_probs import LogProbsTuple
from mlagents_envs.base_env import ActionSpec
from mlagents_envs.base_env import ActionSpec, ActionTuple
def create_mock_policy():

)
fake_action_outputs = {
"action": {"continuous_action": [0.1, 0.1]},
"action": ActionTuple(continuous=np.array([[0.1], [0.1]])),
"log_probs": {"continuous_log_probs": [0.1, 0.1]},
"log_probs": LogProbsTuple(continuous=np.array([[0.1], [0.1]])),
}
mock_decision_steps, mock_terminal_steps = mb.create_mock_steps(
num_agents=2,

fake_action_info = ActionInfo(
action={"continuous_action": [0.1, 0.1]},
action=ActionTuple(continuous=np.array([[0.1], [0.1]])),
value=[0.1, 0.1],
outputs=fake_action_outputs,
agent_ids=mock_decision_steps.agent_id,

max_trajectory_length=5,
stats_reporter=StatsReporter("testcat"),
)
"action": {"continuous_action": [0.1]},
"action": ActionTuple(continuous=np.array([[0.1]])),
"log_probs": {"continuous_log_probs": [0.1]},
"log_probs": LogProbsTuple(continuous=np.array([[0.1]])),
mock_decision_step, mock_terminal_step = mb.create_mock_steps(
num_agents=1,
observation_shapes=[(8,)],

done=True,
)
fake_action_info = ActionInfo(
action={"continuous_action": [0.1]},
action=ActionTuple(continuous=np.array([[0.1]])),
value=[0.1],
outputs=fake_action_outputs,
agent_ids=mock_decision_step.agent_id,

mock_decision_step, mock_terminal_step, _ep, fake_action_info
)
add_calls.append(
mock.call([get_global_agent_id(_ep, 0)], {"continuous_action": [0.1]})
mock.call([get_global_agent_id(_ep, 0)], fake_action_outputs["action"])
)
processor.add_experiences(
mock_done_decision_step, mock_done_terminal_step, _ep, fake_action_info

max_trajectory_length=5,
stats_reporter=StatsReporter("testcat"),
)
"action": {"continuous_action": [0.1]},
"action": ActionTuple(continuous=np.array([[0.1]])),
"log_probs": {"continuous_log_probs": [0.1]},
"log_probs": LogProbsTuple(continuous=np.array([[0.1]])),
mock_decision_step, mock_terminal_step = mb.create_mock_steps(
num_agents=1,
observation_shapes=[(8,)],

action={"continuous_action": [0.1]},
action=ActionTuple(continuous=np.array([[0.1]])),
value=[0.1],
outputs=fake_action_outputs,
agent_ids=mock_decision_step.agent_id,

正在加载...
取消
保存