浏览代码

Fix buffer tests and truncate

/develop-newnormalization
Ervin Teng 5 年前
当前提交
9053610f
共有 2 个文件被更改,包括 8 次插入和 8 次删除
  1. 2
      ml-agents/mlagents/trainers/buffer.py
  2. 14
      ml-agents/mlagents/trainers/tests/test_buffer.py

2
ml-agents/mlagents/trainers/buffer.py


we're not truncating at each update. Note that we must truncate an integer number of sequence_lengths
param: max_length: The length at which to truncate the buffer.
"""
current_length = len(next(iter(self)))
current_length = len(next(iter(self.values())))
# make max_length an integer number of sequence_lengths
max_length -= max_length % sequence_length
if current_length > max_length:

14
ml-agents/mlagents/trainers/tests/test_buffer.py


assert la[i] == lb[i]
def construct_fake_buffer():
def construct_fake_processing_buffer():
b["vector_observation"].append(
b[fake_agent_id]["vector_observation"].append(
[
100 * fake_agent_id + 10 * step + 1,
100 * fake_agent_id + 10 * step + 2,

b["action"].append(
b[fake_agent_id]["action"].append(
[
100 * fake_agent_id + 10 * step + 4,
100 * fake_agent_id + 10 * step + 5,

def test_buffer():
b = construct_fake_buffer()
b = construct_fake_processing_buffer()
a = b[1]["vector_observation"].get_batch(
batch_size=2, training_length=1, sequential=True
)

def test_buffer_sample():
b = construct_fake_buffer()
b = construct_fake_processing_buffer()
update_buffer = AgentBuffer()
b.append_update_buffer(update_buffer, 3, batch_size=None, training_length=2)
b.append_update_buffer(update_buffer, 2, batch_size=None, training_length=2)

def test_buffer_truncate():
b = construct_fake_buffer()
b = construct_fake_processing_buffer()
assert len(b.update_buffer["action"]) == 2
assert len(update_buffer["action"]) == 2
b.append_update_buffer(update_buffer, 3, batch_size=None, training_length=2)
b.append_update_buffer(update_buffer, 2, batch_size=None, training_length=2)

正在加载...
取消
保存