from typing import Dict

from mlagents.torch_utils import torch

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.torch.networks import ValueNetwork
from mlagents.trainers.torch.components.bc.module import BCModule
from mlagents.trainers.torch.components.reward_providers import create_reward_provider
from mlagents.trainers.optimizer import Optimizer
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.settings import TrainerSettings

|
class TorchOptimizer(Optimizer):
    def __init__(
        self,
        policy: TorchPolicy,
        critic: ValueNetwork,
        trainer_settings: TrainerSettings,
    ):
        super().__init__()
        self.policy = policy
        self.critic = critic
        self.trainer_settings = trainer_settings
        # Tensors tracked during updates; populated by subclasses.
        self.update_dict: Dict[str, torch.Tensor] = {}
        # Value estimates keyed by reward signal name (one head per signal).
        self.value_heads: Dict[str, torch.Tensor] = {}
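
# A minimal sketch (not part of the original module) of how a caller might
# wire a ValueNetwork critic into this constructor. The helper name, the
# stream name "extrinsic", and the attribute lookups below are assumptions
# based on typical ML-Agents usage, not definitive API; the function is
# defined for illustration only and is never called.
def _example_build_optimizer(
    policy: TorchPolicy, trainer_settings: TrainerSettings
) -> TorchOptimizer:
    # One value stream per reward signal; "extrinsic" is a placeholder.
    critic = ValueNetwork(
        stream_names=["extrinsic"],
        observation_specs=policy.behavior_spec.observation_specs,
        network_settings=trainer_settings.network_settings,
    )
    return TorchOptimizer(policy, critic, trainer_settings)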