浏览代码

Fix BCTrainer increment_steps (#2384)

/hotfix-v0.9.2a
GitHub 5 年前
当前提交
b498c19d
共有 4 个文件被更改,包括 59 次插入27 次删除
  1. 7
      ml-agents/mlagents/trainers/bc/trainer.py
  2. 20
      ml-agents/mlagents/trainers/tests/mock_brain.py
  3. 30
      ml-agents/mlagents/trainers/tests/test_bc.py
  4. 29
      ml-agents/mlagents/trainers/tests/test_bcmodule.py

7
ml-agents/mlagents/trainers/bc/trainer.py


"""
return self.policy.get_current_step()
def increment_step(self):
def increment_step(self, n_steps: int) -> None:
:param n_steps: number of steps to increment the step count by
self.policy.increment_step()
return
self.step = self.policy.increment_step(n_steps)
def add_experiences(
self,

20
ml-agents/mlagents/trainers/tests/mock_brain.py


buffer.append_update_buffer(0, batch_size=None, training_length=sequence_length)
return buffer
def create_mock_3dball_brain():
mock_brain = create_mock_brainparams(
vector_action_space_type="continuous",
vector_action_space_size=[2],
vector_observation_space_size=8,
)
mock_brain.brain_name = "Ball3DBrain"
return mock_brain
def create_mock_banana_brain():
mock_brain = create_mock_brainparams(
number_visual_observations=1,
vector_action_space_type="discrete",
vector_action_space_size=[3, 3, 3, 2],
vector_observation_space_size=0,
)
return mock_brain

30
ml-agents/mlagents/trainers/tests/test_bc.py


import unittest.mock as mock
import pytest
import os
import numpy as np
import tensorflow as tf

import mlagents.trainers.tests.mock_brain as mb
from mlagents.trainers.bc.offline_trainer import BCTrainer
from mlagents.envs import UnityEnvironment
from mlagents.envs.mock_communicator import MockCommunicator

use_recurrent: false
sequence_length: 32
memory_size: 32
batches_per_epoch: 1
batch_size: 32
summary_freq: 2000
max_steps: 4000
@mock.patch("mlagents.envs.UnityEnvironment")
def test_bc_trainer(mock_env, dummy_config):
mock_brain = mb.create_mock_3dball_brain()
mock_braininfo = mb.create_mock_braininfo(num_agents=12, num_vector_observations=8)
mb.setup_mock_unityenvironment(mock_env, mock_brain, mock_braininfo)
env = mock_env()
trainer_parameters = dummy_config
trainer_parameters["summary_path"] = "tmp"
trainer_parameters["model_path"] = "tmp"
trainer_parameters["demo_path"] = (
os.path.dirname(os.path.abspath(__file__)) + "/test.demo"
)
trainer = BCTrainer(
mock_brain, trainer_parameters, training=True, load=False, seed=0, run_id=0
)
trainer.demonstration_buffer = mb.simulate_rollout(env, trainer.policy, 100)
trainer.update_policy()
assert len(trainer.stats["Losses/Cloning Loss"]) > 0
trainer.increment_step(1)
assert trainer.step == 1
@mock.patch("mlagents.envs.UnityEnvironment.executable_launcher")

29
ml-agents/mlagents/trainers/tests/test_bcmodule.py


)
def create_mock_3dball_brain():
mock_brain = mb.create_mock_brainparams(
vector_action_space_type="continuous",
vector_action_space_size=[2],
vector_observation_space_size=8,
)
return mock_brain
def create_mock_banana_brain():
mock_brain = mb.create_mock_brainparams(
number_visual_observations=1,
vector_action_space_type="discrete",
vector_action_space_size=[3, 3, 3, 2],
vector_observation_space_size=0,
)
return mock_brain
def create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, use_rnn, demo_file
):

@mock.patch("mlagents.envs.UnityEnvironment")
def test_bcmodule_defaults(mock_env, dummy_config):
# See if default values match
mock_brain = create_mock_3dball_brain()
mock_brain = mb.create_mock_3dball_brain()
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, False, "test.demo"
)

# Test with continuous control env and vector actions
@mock.patch("mlagents.envs.UnityEnvironment")
def test_bcmodule_update(mock_env, dummy_config):
mock_brain = create_mock_3dball_brain()
mock_brain = mb.create_mock_3dball_brain()
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, False, "test.demo"
)

# Test with RNN
@mock.patch("mlagents.envs.UnityEnvironment")
def test_bcmodule_rnn_update(mock_env, dummy_config):
mock_brain = create_mock_3dball_brain()
mock_brain = mb.create_mock_3dball_brain()
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, True, "test.demo"
)

# Test with discrete control and visual observations
@mock.patch("mlagents.envs.UnityEnvironment")
def test_bcmodule_dc_visual_update(mock_env, dummy_config):
mock_brain = create_mock_banana_brain()
mock_brain = mb.create_mock_banana_brain()
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, False, "testdcvis.demo"
)

# Test with discrete control, visual observations and RNN
@mock.patch("mlagents.envs.UnityEnvironment")
def test_bcmodule_rnn_dc_update(mock_env, dummy_config):
mock_brain = create_mock_banana_brain()
mock_brain = mb.create_mock_banana_brain()
env, policy = create_ppo_policy_with_bc_mock(
mock_env, mock_brain, dummy_config, True, "testdcvis.demo"
)

正在加载...
取消
保存