from unittest import mock

import numpy as np

# mock_brain is assumed to be the shared test helper module that provides
# create_mock_batchedstep, based on how mb is used below.
import mlagents.trainers.tests.mock_brain as mb
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.agent_processor import AgentProcessor
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.brain_conversion_utils import get_global_agent_id
|
|
def create_mock_policy():
    # Minimal policy stand-in: a Mock exposing just what these tests touch.
    # The stubbed return values are assumptions, not a full policy implementation.
    mock_policy = mock.Mock()
    mock_policy.reward_signals = {}
    mock_policy.retrieve_memories.return_value = np.zeros((1, 1), dtype=np.float32)
    mock_policy.retrieve_previous_action.return_value = np.zeros((1, 1), dtype=np.float32)
    return mock_policy
|
|
def test_empty_add_experiences():
    # Reconstructed around a surviving fragment: an ActionInfo that carries no
    # agents must leave the processor untouched. The setup below (processor and
    # zero-agent step) mirrors test_agent_deletion and is an assumption, not
    # the original test body.
    policy = create_mock_policy()
    processor = AgentProcessor(
        policy,
        "test_brain_name",
        max_trajectory_length=5,
        stats_reporter=StatsReporter("testcat"),
    )
    mock_step = mb.create_mock_batchedstep(
        num_agents=0,
        num_vector_observations=8,
        action_shape=[2],
        num_vis_observations=0,
    )
    processor.add_experiences(mock_step, 0, ActionInfo([], [], {}, []))
    # Assert that the AgentProcessor is still empty
    assert len(processor.experience_buffers[0]) == 0
|
|
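# Verifies that the AgentProcessor cleans up all of its per-agent state
# (experience buffers, last action outputs, episode step/reward counters)
# once each agent's episode ends.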
def test_agent_deletion():
    policy = create_mock_policy()
    tqueue = mock.Mock()
    name_behavior_id = "test_brain_name"
    processor = AgentProcessor(
        policy,
        name_behavior_id,
        max_trajectory_length=5,
        stats_reporter=StatsReporter("testcat"),
    )
|
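    # Stand-in for the outputs a policy would report alongside its action
    # (action, entropy, log-probs, etc.); used only for bookkeeping here.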
    fake_action_outputs = {
        "action": [0.1],
        "entropy": np.array([1.0], dtype=np.float32),
        "learning_rate": 1.0,
        "pre_action": [0.1],
        "log_probs": [0.1],
    }
|
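    # A single-agent step with vector observations only; reused for every
    # non-terminal step in the loop below.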
    mock_step = mb.create_mock_batchedstep(
        num_agents=1,
        num_vector_observations=8,
        action_shape=[2],
        num_vis_observations=0,
    )
|
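    # Same step shape, but flagged done=True so it ends the agent's episode.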
    mock_done_step = mb.create_mock_batchedstep(
        num_agents=1,
        num_vector_observations=8,
        action_shape=[2],
        num_vis_observations=0,
        done=True,
    )
|
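    # Associates the fake outputs with the agent ids present in mock_step.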
    fake_action_info = ActionInfo(
        action=[0.1],
        value=[0.1],
        outputs=fake_action_outputs,
        agent_ids=mock_step.agent_id,
    )
|
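    # Completed trajectories will be pushed onto tqueue by the processor.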
    processor.publish_trajectory_queue(tqueue)
    # This is like the initial state after the env reset
    processor.add_experiences(mock_step, 0, ActionInfo.empty())
|
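    # add_calls / remove_calls accumulate the save_previous_action and
    # remove_previous_action calls we expect the policy mock to have seen,
    # so they can be asserted in order afterwards.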
    # Run 3 trajectories, with different workers (to simulate different agents)
    add_calls = []
    add_calls.append(mock.call([get_global_agent_id(0, 0)], [0.1]))
    remove_calls = []
    for _ep in range(3):
        for _ in range(5):
            processor.add_experiences(mock_step, _ep, fake_action_info)
            add_calls.append(mock.call([get_global_agent_id(_ep, 0)], [0.1]))
        processor.add_experiences(mock_done_step, _ep, fake_action_info)
        # Make sure we don't add experiences from the prior agents after the done
        remove_calls.append(mock.call([get_global_agent_id(_ep, 0)]))
|
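    # assert_has_calls verifies the expected calls occurred on the mock policy in order.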
    policy.save_previous_action.assert_has_calls(add_calls)
    policy.remove_previous_action.assert_has_calls(remove_calls)
    # Check that there are no experiences left
    assert len(processor.experience_buffers.keys()) == 0
    assert len(processor.last_take_action_outputs.keys()) == 0
    assert len(processor.episode_steps.keys()) == 0
    assert len(processor.episode_rewards.keys()) == 0
|
|
def test_agent_manager(): |
|
|
|