浏览代码

Fix clear update buffer when trainer stops training, add test (#3422)

* Fix clear update buffer when trainer stops training, add test

* Fix buffer changing types when truncated
/release-0.14.0
GitHub 4 年前
当前提交
1f9d04f2
共有 4 个文件被更改,包括 34 次插入2 次删除
  1. 2
      ml-agents/mlagents/trainers/buffer.py
  2. 2
      ml-agents/mlagents/trainers/rl_trainer.py
  3. 2
      ml-agents/mlagents/trainers/tests/test_buffer.py
  4. 30
      ml-agents/mlagents/trainers/tests/test_rl_trainer.py

2
ml-agents/mlagents/trainers/buffer.py


max_length -= max_length % sequence_length
if current_length > max_length:
for _key in self.keys():
self[_key] = self[_key][current_length - max_length :]
self[_key][:] = self[_key][current_length - max_length :]
def resequence_and_append(
self,

2
ml-agents/mlagents/trainers/rl_trainer.py


Steps the trainer, taking in trajectories and updates if ready
"""
super().advance()
if not self.is_training:
if not self.should_still_train:
self.clear_update_buffer()

2
ml-agents/mlagents/trainers/tests/test_buffer.py


# Test LSTM, truncate should be some multiple of sequence_length
update_buffer.truncate(4, sequence_length=3)
assert update_buffer.num_experiences == 3
for buffer_field in update_buffer.values():
assert isinstance(buffer_field, AgentBuffer.AgentBufferField)

30
ml-agents/mlagents/trainers/tests/test_rl_trainer.py


import mlagents.trainers.tests.mock_brain as mb
from mlagents.trainers.rl_trainer import RLTrainer
from mlagents.trainers.tests.test_buffer import construct_fake_buffer
from mlagents.trainers.agent_processor import AgentManagerQueue
def dummy_config():

summary_freq: 1000
max_steps: 100
reward_signals:
extrinsic:
strength: 1.0

trainer.clear_update_buffer()
for _, arr in trainer.update_buffer.items():
assert len(arr) == 0
@mock.patch("mlagents.trainers.rl_trainer.RLTrainer.clear_update_buffer")
def test_advance(mocked_clear_update_buffer):
trainer = create_rl_trainer()
trajectory_queue = AgentManagerQueue("testbrain")
trainer.subscribe_trajectory_queue(trajectory_queue)
time_horizon = 15
trajectory = mb.make_fake_trajectory(
length=time_horizon,
max_step_complete=True,
vec_obs_size=1,
num_vis_obs=0,
action_space=[2],
)
trajectory_queue.put(trajectory)
trainer.advance()
# Check that get_step is correct
assert trainer.get_step == time_horizon
# Check that we can turn off the trainer and that the buffer is cleared
for _ in range(0, 10):
trajectory_queue.put(trajectory)
trainer.advance()
# Check that the buffer has been cleared
assert not trainer.should_still_train
assert mocked_clear_update_buffer.call_count > 0
正在加载...
取消
保存