浏览代码

Improve tests

/develop/sac-apex
Ervin Teng 5 年前
当前提交
99ce4b59
共有 2 个文件被更改,包括 47 次插入5 次删除
  1. 23
      ml-agents/mlagents/trainers/tests/test_rl_trainer.py
  2. 29
      ml-agents/mlagents/trainers/tests/test_sac.py

23
ml-agents/mlagents/trainers/tests/test_rl_trainer.py


import yaml
from unittest import mock
import pytest
import mlagents.trainers.tests.mock_brain as mb
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.tests.test_buffer import construct_fake_buffer

# Add concrete implementations of abstract methods
class FakeTrainer(RLTrainer):
    """Concrete RLTrainer stand-in used by the tests in this file.

    Supplies implementations for the methods the tests exercise and lets a
    test decide, through ``set_is_policy_updating``, what ``_update_policy``
    reports.
    """

    def set_is_policy_updating(self, is_updating):
        """Store the value that later ``_update_policy`` calls will return.

        :param is_updating: Value kept on the instance and handed back
            verbatim by ``_update_policy``.
        """
        self.update_policy = is_updating

    def get_policy(self, name_behavior_id):
        """Return a new ``mock.Mock()``; ``name_behavior_id`` is ignored."""
        return mock.Mock()

    def _update_policy(self):
        """Return the flag stored by ``set_is_policy_updating``.

        ``set_is_policy_updating`` is expected to have been called first;
        this class does not itself give the flag a default.
        """
        # A hard-coded `return True` previously preceded this statement and
        # made the configurable flag unreachable, so tests could never
        # exercise the "policy did not update" path.
        return self.update_policy

    def add_policy(self):
        """No-op implementation; nothing is registered."""
        pass

def create_rl_trainer():
    """Build a FakeTrainer for tests with policy updating switched on.

    The trainer is constructed from ``create_mock_brain()`` and
    ``dummy_config()`` plus the positional arguments ``True`` and ``0``,
    then told that its policy is updating before being returned.
    """
    fake_trainer = FakeTrainer(create_mock_brain(), dummy_config(), True, 0)
    fake_trainer.set_is_policy_updating(True)
    return fake_trainer

def test_advance(mocked_clear_update_buffer):
# Exercises trainer.advance() and checks what ends up on the policy queue.
# NOTE(review): this text comes from a flattened diff rendering -- indentation
# is lost, removed and added lines are shown together, and some lines are
# elided. `mocked_clear_update_buffer` is presumably injected by a mock.patch
# decorator that is not visible here -- TODO confirm against the real file.
trainer = create_rl_trainer()
trajectory_queue = AgentManagerQueue("testbrain")
policy_queue = AgentManagerQueue("testbrain")
time_horizon = 15
# Hand the policy queue to the trainer; the test reads from it below.
trainer.publish_policy_queue(policy_queue)
# NOTE(review): this reassignment overrides the 15 above before it is used;
# the two values look like the old and new sides of the diff -- confirm.
time_horizon = 10
# NOTE(review): the call below is never closed in this view; its remaining
# arguments and closing parenthesis appear to be elided.
trajectory = mb.make_fake_trajectory(
length=time_horizon,
max_step_complete=True,

trajectory_queue.put(trajectory)
trainer.advance()
# A non-blocking get is expected to raise when the queue is empty (see the
# pytest.raises check further down), so this doubles as a presence check.
policy_queue.get(block=False)
for _ in range(0, 5):
trajectory_queue.put(trajectory)
trainer.advance()
# Check that there is stuff in the policy queue
policy_queue.get(block=False)
# Check that if the policy doesn't update, we don't push it to the queue
trainer.set_is_policy_updating(False)
# Check that there is nothing in the policy queue
with pytest.raises(AgentManagerQueue.Empty):
policy_queue.get(block=False)
# Check that the buffer has been cleared
# NOTE(review): the assertion visible here checks `should_still_train`; any
# assertion on the mocked buffer-clearing call is not shown in this view.
assert not trainer.should_still_train

29
ml-agents/mlagents/trainers/tests/test_sac.py


return yaml.safe_load(
"""
trainer: sac
batch_size: 32
batch_size: 8
buffer_size: 10240
buffer_init_steps: 0
hidden_units: 32

trainer.add_policy(brain_params, policy)
def test_process_trajectory(dummy_config):
def test_advance(dummy_config):
# Feeds trajectories to the trainer and checks when a policy is (and is not)
# placed on the policy queue.
# NOTE(review): `trainer` and `trajectory_queue` are used below but never
# assigned in the lines visible here; their setup appears to be elided from
# this flattened diff rendering -- confirm against the real file.
dummy_config["steps_per_update"] = 20
policy_queue = AgentManagerQueue("testbrain")
trainer.publish_policy_queue(policy_queue)
trajectory = make_fake_trajectory(
length=15,

action_space=[2],
is_discrete=False,
)
trajectory_queue.put(trajectory)
trainer.advance()

# Add a terminal trajectory
# NOTE(review): `length` is passed twice in the call below, which looks like
# the removed (15) and added (6) diff lines shown together -- confirm which
# one the real file keeps.
trajectory = make_fake_trajectory(
length=15,
length=6,
is_discrete=False,
)
trajectory_queue.put(trajectory)
trainer.advance()

assert (
trainer.stats_reporter.get_stats_summaries("Policy/Extrinsic Reward").mean > 0
)
# Make sure there is a policy on the queue
policy_queue.get(block=False)
# Add another trajectory. Since this is less than 20 steps total (enough for
# two updates), there should NOT be a policy on the queue.
trajectory = make_fake_trajectory(
length=5,
max_step_complete=False,
vec_obs_size=6,
num_vis_obs=0,
action_space=[2],
is_discrete=False,
)
trajectory_queue.put(trajectory)
trainer.advance()
# The non-blocking get must raise Empty, i.e. no new policy was published.
with pytest.raises(AgentManagerQueue.Empty):
policy_queue.get(block=False)
def test_bad_config(dummy_config):

正在加载...
取消
保存