from typing import Dict, Optional

import torch

from mlagents_envs.base_env import DecisionSteps
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.torch.components.bc.module import BCModule
from mlagents.trainers.components.reward_signals.extrinsic.signal import (
    ExtrinsicRewardSignal,
)

# Recurrent (LSTM) memory bookkeeping: memory_out holds the memories from the
# most recent forward pass, and m_size is the size of the memory vector.
self.memory_out: torch.Tensor = None
self.m_size: int = 0
self.global_step = torch.tensor(0)

# The behavioral cloning (BC) module is optional: it is only constructed when
# behavioral cloning is enabled in the trainer settings, and stays None otherwise.
self.bc_module: Optional[BCModule] = None
if trainer_settings.behavioral_cloning is not None:
    self.bc_module = BCModule(
        self.policy,
        trainer_settings.behavioral_cloning,
        policy_learning_rate=trainer_settings.hyperparameters.learning_rate,
        default_batch_size=trainer_settings.hyperparameters.batch_size,
        default_num_epoch=3,
    )
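
# Usage sketch (illustrative, not part of this file): during a trainer update,
# the BC loss is applied by calling bc_module.update(), which returns its own
# stats dict that can be merged with the stats returned by update() below:
#
#     update_stats = optimizer.update(batch, num_sequences)
#     if optimizer.bc_module is not None:
#         update_stats.update(optimizer.bc_module.update())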

def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
    """
    Performs an update on the policy from a batch of experiences. Implemented
    by the concrete optimizers (e.g. the PPO and SAC torch optimizers); returns
    a dict mapping loss/statistic names to values for this update step.
    """
    pass
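
# Sketch of a concrete override (hypothetical names; the real implementations
# live in the PPO/SAC torch optimizers):
#
#     def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
#         loss = self._compute_losses(batch, num_sequences)  # hypothetical helper
#         self.optimizer.zero_grad()
#         loss.backward()
#         self.optimizer.step()
#         return {"Losses/Policy Loss": loss.item()}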