from typing import Dict, Optional

from mlagents.torch_utils import torch

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.torch.networks import ValueNetwork
from mlagents.trainers.torch.components.bc.module import BCModule
from mlagents.trainers.torch.components.reward_providers import create_reward_provider
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.optimizer import Optimizer
from mlagents.trainers.settings import TrainerSettings


class TorchOptimizer(Optimizer):
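    """
    Base class for PyTorch-based trainer optimizers in ML-Agents.

    Keeps a reference to the policy being trained, the trainer settings, and an
    optional behavioral cloning module. Concrete optimizers are expected to
    provide a critic and implement the update step.
    """
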
    def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
        super().__init__()
        self.policy = policy
        self.trainer_settings = trainer_settings
        self.update_dict: Dict[str, torch.Tensor] = {}
        self.value_heads: Dict[str, torch.Tensor] = {}
        # Only create a behavioral cloning module if it is enabled in the
        # trainer settings.
        self.bc_module: Optional[BCModule] = None
        if trainer_settings.behavioral_cloning is not None:
            self.bc_module = BCModule(
                self.policy,
                trainer_settings.behavioral_cloning,
                policy_learning_rate=trainer_settings.hyperparameters.learning_rate,
                default_batch_size=trainer_settings.hyperparameters.batch_size,
                default_num_epoch=3,
            )

    @property
    def critic(self):
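        """Value network used for value estimation; concrete subclasses must override."""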
        raise NotImplementedError

    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
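        """
        Update the policy and value networks from a minibatch of experiences.

        :param batch: AgentBuffer containing the experiences to learn from.
        :param num_sequences: Number of recurrent sequences in the batch.
        :return: Statistics (e.g. losses) from the update, keyed by name.
        """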
        pass