浏览代码

Fix batch size issue with BC (#2965)

/develop/tanhsquash
GitHub 5 年前
当前提交
a2194ea7
共有 2 个文件被更改,包括 10 次插入7 次删除
  1. 7
      ml-agents/mlagents/trainers/bc/trainer.py
  2. 10
      ml-agents/mlagents/trainers/tests/test_bc.py

7
ml-agents/mlagents/trainers/bc/trainer.py


"""
self.demonstration_buffer.update_buffer.shuffle(self.policy.sequence_length)
batch_losses = []
batch_size = self.n_sequences * self.policy.sequence_length
# We either divide the entire buffer into num_batches batches, or limit the number
# of batches to batches_per_epoch.
len(self.demonstration_buffer.update_buffer["actions"]) // self.n_sequences,
len(self.demonstration_buffer.update_buffer["actions"]) // batch_size,
batch_size = self.n_sequences * self.policy.sequence_length
for i in range(0, num_batches * batch_size, batch_size):
update_buffer = self.demonstration_buffer.update_buffer

10
ml-agents/mlagents/trainers/tests/test_bc.py


use_recurrent: false
sequence_length: 32
memory_size: 32
batches_per_epoch: 1
batches_per_epoch: 100 # Force code to use all possible batches
batch_size: 32
summary_freq: 2000
max_steps: 4000

def create_bc_trainer(dummy_config, is_discrete=False):
def create_bc_trainer(dummy_config, is_discrete=False, use_recurrent=False):
mock_env = mock.Mock()
if is_discrete:
mock_brain = mb.create_mock_pushblock_brain()

trainer_parameters["demo_path"] = (
os.path.dirname(os.path.abspath(__file__)) + "/test.demo"
)
trainer_parameters["use_recurrent"] = use_recurrent
trainer = BCTrainer(
mock_brain, trainer_parameters, training=True, load=False, seed=0, run_id=0
)

def test_bc_trainer_step(dummy_config):
trainer, env = create_bc_trainer(dummy_config)
@pytest.mark.parametrize("use_recurrent", [True, False])
def test_bc_trainer_step(dummy_config, use_recurrent):
trainer, env = create_bc_trainer(dummy_config, use_recurrent=use_recurrent)
# Test get_step
assert trainer.get_step == 0
# Test update policy

正在加载...
取消
保存