
revert bc default batch/epoch

/develop/add-fire/bc
Andrew Cohen, 4 years ago
Commit 0a7444f9
4 files changed, 14 insertions(+), 5 deletions(-)
  1. ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (2 changes)
  2. ml-agents/mlagents/trainers/settings.py (6 changes)
  3. ml-agents/mlagents/trainers/tests/torch/test_bcmodule.py (2 changes)
  4. ml-agents/mlagents/trainers/torch/components/bc/module.py (9 changes)

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py (2 changes)

     self.policy,
     trainer_settings.behavioral_cloning,
     policy_learning_rate=trainer_settings.hyperparameters.learning_rate,
+    default_batch_size=trainer_settings.hyperparameters.batch_size,
+    default_num_epoch=3,
 )
 def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:

ml-agents/mlagents/trainers/settings.py (6 changes)

     steps: int = 0
     strength: float = 1.0
     samples_per_update: int = 0
-    num_epoch: int = 3
-    batch_size: int = 1024
+    # Setting either of these to None will allow the Optimizer
+    # to decide these parameters, based on Trainer hyperparams
+    num_epoch: Optional[int] = None
+    batch_size: Optional[int] = None
 @attr.s(auto_attribs=True)
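The new comment states the intent: leaving num_epoch or batch_size as None defers the choice to the Optimizer, which fills them in from the Trainer hyperparameters. A minimal, self-contained sketch of that fallback, assuming only the attrs package (field list abbreviated; the 2048 trainer batch size here is an illustrative stand-in, not a value from this commit):

    from typing import Optional
    import attr

    @attr.s(auto_attribs=True)
    class BehavioralCloningSettings:
        steps: int = 0
        strength: float = 1.0
        samples_per_update: int = 0
        # None means "let the Optimizer decide from the Trainer hyperparams"
        num_epoch: Optional[int] = None
        batch_size: Optional[int] = None

    settings = BehavioralCloningSettings()             # user left both fields unset
    trainer_batch_size, bc_default_epochs = 2048, 3    # illustrative trainer-level values
    batch_size = settings.batch_size if settings.batch_size else trainer_batch_size
    num_epoch = settings.num_epoch if settings.num_epoch else bc_default_epochs
    print(batch_size, num_epoch)                        # -> 2048 3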

ml-agents/mlagents/trainers/tests/torch/test_bcmodule.py (2 changes)

     policy,
     settings=bc_settings,
     policy_learning_rate=trainer_config.hyperparameters.learning_rate,
+    default_batch_size=trainer_config.hyperparameters.batch_size,
+    default_num_epoch=3,
 )
 return bc_module

ml-agents/mlagents/trainers/torch/components/bc/module.py (9 changes)

     policy: TorchPolicy,
     settings: BehavioralCloningSettings,
     policy_learning_rate: float,
+    default_batch_size: int,
+    default_num_epoch: int,
 ):
     """
     A BC trainer that can be used inline with RL.

 )
 params = self.policy.actor_critic.parameters()
 self.optimizer = torch.optim.Adam(params, lr=self.current_lr)
-self.batch_size = settings.batch_size
-self.num_epoch = settings.num_epoch
+self.batch_size = (
+    settings.batch_size if settings.batch_size else default_batch_size
+)
+self.num_epoch = settings.num_epoch if settings.num_epoch else default_num_epoch
 self.n_sequences = max(
     min(self.batch_size, self.demonstration_buffer.num_experiences)
     // policy.sequence_length,