浏览代码

[change] Throw a proper error when sequence length is greater than batch size. (#3583)

/bug-failed-api-check
GitHub 4 年前
当前提交
c42a11c3
共有 4 个文件被更改,包括 59 次插入8 次删除
  1. 15
      ml-agents/mlagents/trainers/ppo/trainer.py
  2. 14
      ml-agents/mlagents/trainers/sac/trainer.py
  3. 23
      ml-agents/mlagents/trainers/tests/test_ppo.py
  4. 15
      ml-agents/mlagents/trainers/tests/test_sac.py

15
ml-agents/mlagents/trainers/ppo/trainer.py


from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.ppo.optimizer import PPOOptimizer
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.exception import UnityTrainerException
logger = logging.getLogger("mlagents.trainers")

self.load = load
self.seed = seed
self.policy: NNPolicy = None # type: ignore
def _check_param_keys(self):
super()._check_param_keys()
# Check that batch size is greater than sequence length. Else, throw
# an exception.
if (
self.trainer_parameters["sequence_length"]
> self.trainer_parameters["batch_size"]
and self.trainer_parameters["use_recurrent"]
):
raise UnityTrainerException(
"batch_size must be greater than or equal to sequence_length when use_recurrent is True."
)
def _process_trajectory(self, trajectory: Trajectory) -> None:
"""

14
ml-agents/mlagents/trainers/sac/trainer.py


from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.trajectory import Trajectory, SplitObservations
from mlagents.trainers.brain import BrainParameters
from mlagents.trainers.exception import UnityTrainerException
logger = logging.getLogger("mlagents.trainers")

if "save_replay_buffer" in trainer_parameters
else False
)
def _check_param_keys(self):
super()._check_param_keys()
# Check that batch size is greater than sequence length. Else, throw
# an exception.
if (
self.trainer_parameters["sequence_length"]
> self.trainer_parameters["batch_size"]
and self.trainer_parameters["use_recurrent"]
):
raise UnityTrainerException(
"batch_size must be greater than or equal to sequence_length when use_recurrent is True."
)
def save_model(self, name_behavior_id: str) -> None:
"""

23
ml-agents/mlagents/trainers/tests/test_ppo.py


from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.mock_brain import make_brain_parameters
from mlagents.trainers.tests.test_trajectory import make_fake_trajectory
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.tests.test_reward_signals import ( # noqa: F401; pylint: disable=unused-variable
curiosity_dummy_config,
gail_dummy_config,

num_epoch: 5
num_layers: 2
time_horizon: 64
sequence_length: 64
sequence_length: 16
summary_freq: 1000
use_recurrent: false
normalize: true

update_buffer["extrinsic_value_estimates"] = update_buffer["environment_rewards"]
optimizer.update(
update_buffer,
num_sequences=update_buffer.num_experiences // dummy_config["sequence_length"],
num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length,
)

trainer.update_buffer = buffer
trainer._update_policy()
# Make batch length a larger multiple of sequence length
trainer.trainer_parameters["batch_size"] = 128
trainer._update_policy()
# Make batch length a larger non-multiple of sequence length
trainer.trainer_parameters["batch_size"] = 100
trainer._update_policy()
def test_process_trajectory(dummy_config):

policy = mock.Mock()
with pytest.raises(RuntimeError):
trainer.add_policy(brain_params, policy)
def test_bad_config(dummy_config):
brain_params = make_brain_parameters(
discrete_action=False, visual_inputs=0, vec_obs_size=6
)
# Test that we throw an error if we have sequence length greater than batch size
dummy_config["sequence_length"] = 64
dummy_config["batch_size"] = 32
dummy_config["use_recurrent"] = True
with pytest.raises(UnityTrainerException):
_ = PPOTrainer(brain_params, 0, dummy_config, True, False, 0, "0")
if __name__ == "__main__":

15
ml-agents/mlagents/trainers/tests/test_sac.py


from mlagents.trainers.tests import mock_brain as mb
from mlagents.trainers.tests.mock_brain import make_brain_parameters
from mlagents.trainers.tests.test_trajectory import make_fake_trajectory
from mlagents.trainers.exception import UnityTrainerException
@pytest.fixture

assert (
trainer.stats_reporter.get_stats_summaries("Policy/Extrinsic Reward").mean > 0
)
def test_bad_config(dummy_config):
brain_params = make_brain_parameters(
discrete_action=False, visual_inputs=0, vec_obs_size=6
)
# Test that we throw an error if we have sequence length greater than batch size
dummy_config["sequence_length"] = 64
dummy_config["batch_size"] = 32
dummy_config["use_recurrent"] = True
dummy_config["summary_path"] = "./summaries/test_trainer_summary"
dummy_config["model_path"] = "./models/test_trainer_models/TestModel"
with pytest.raises(UnityTrainerException):
_ = SACTrainer(brain_params, 0, dummy_config, True, False, 0, "0")
if __name__ == "__main__":
正在加载...
取消
保存