浏览代码

fix precommit errors

/develop/action-slice
Andrew Cohen 4 年前
当前提交
3aec18a1
共有 2 个文件被更改,包括 10 次插入4 次删除
  1. 9
      ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
  2. 5
      ml-agents/mlagents/trainers/ppo/optimizer_torch.py

9
ml-agents/mlagents/trainers/optimizer/torch_optimizer.py


from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.torch.networks import ValueNetwork
from mlagents.trainers.torch.components.bc.module import BCModule
from mlagents.trainers.torch.components.reward_providers import create_reward_provider

class TorchOptimizer(Optimizer):
def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
def __init__(
self,
policy: TorchPolicy,
critic: ValueNetwork,
trainer_settings: TrainerSettings,
):
self.critic = critic
self.trainer_settings = trainer_settings
self.update_dict: Dict[str, torch.Tensor] = {}
self.value_heads: Dict[str, torch.Tensor] = {}

5
ml-agents/mlagents/trainers/ppo/optimizer_torch.py


"""
# Create the graph here to give more granular control of the TF graph to the Optimizer.
super().__init__(policy, trainer_settings)
self.critic = ValueNetwork(
critic = ValueNetwork(
super().__init__(policy, critic, trainer_settings)
params = list(self.policy.actor.parameters()) + list(self.critic.parameters())
self.hyperparameters: PPOSettings = cast(

正在加载...
取消
保存