Fix mypy errors in trainer code. (#3135)

/asymm-envs
GitHub, 5 years ago
Current commit: c7da0139
4 files changed, 10 insertions and 10 deletions
  1. .pre-commit-config.yaml (6 lines changed)
  2. ml-agents/mlagents/trainers/ppo/trainer.py (4 lines changed)
  3. ml-agents/mlagents/trainers/sac/policy.py (6 lines changed)
  4. ml-agents/mlagents/trainers/sac/trainer.py (4 lines changed)

.pre-commit-config.yaml (6 lines changed)

     )$
 - repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v0.750
-  # Currently mypy may assert after logging one message. To get all the messages at once, change repo and rev to
-  # repo: https://github.com/chriselion/mypy
-  # rev: 3d0b6164a9487a6c5cf9d144110b86600fd85e25
-  # This is a fork with the assert disabled, although precommit has trouble installing it sometimes.
+  rev: v0.761
   hooks:
   - id: mypy
     name: mypy-ml-agents

ml-agents/mlagents/trainers/ppo/trainer.py (4 lines changed)

         self.load = load
         self.multi_gpu = multi_gpu
         self.seed = seed
-        self.policy: TFPolicy = None
+        self.policy: PPOPolicy = None  # type: ignore

     def process_trajectory(self, trajectory: Trajectory) -> None:
         """

                     self.__class__.__name__
                 )
             )
+        if not isinstance(policy, PPOPolicy):
+            raise RuntimeError("Non-PPOPolicy passed to PPOTrainer.add_policy()")
         self.policy = policy

     def get_policy(self, name_behavior_id: str) -> TFPolicy:
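The first hunk narrows the attribute annotation from the base TFPolicy to the concrete PPOPolicy, so PPO-specific attribute access on self.policy type-checks; the # type: ignore is needed because None is not a valid PPOPolicy (the stricter alternative, Optional[PPOPolicy], would force None checks at every use). The second hunk adds a runtime guard that matches the annotation. A minimal sketch of the pattern, with illustrative class bodies and a simplified add_policy signature:

class TFPolicy:
    """Illustrative stand-in for the base policy class."""


class PPOPolicy(TFPolicy):
    """Illustrative stand-in for the PPO-specific policy subclass."""


class PPOTrainer:
    def __init__(self) -> None:
        # The real policy is attached later via add_policy().  Annotating the
        # attribute with the concrete subclass keeps PPO-specific calls checked;
        # the ignore is required because None is not a PPOPolicy.
        self.policy: PPOPolicy = None  # type: ignore

    def add_policy(self, policy: TFPolicy) -> None:
        # Runtime guard matching the static annotation.
        if not isinstance(policy, PPOPolicy):
            raise RuntimeError("Non-PPOPolicy passed to PPOTrainer.add_policy()")
        self.policy = policy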

ml-agents/mlagents/trainers/sac/policy.py (6 lines changed)

 import logging
-from typing import Dict, Any, Optional
+from typing import Dict, Any, Optional, Mapping
 import numpy as np
 from mlagents.tf_utils import tf

         return update_stats

     def update_reward_signals(
-        self, reward_signal_minibatches: Dict[str, Dict], num_sequences: int
+        self, reward_signal_minibatches: Mapping[str, Dict], num_sequences: int
     ) -> Dict[str, float]:
         """
         Only update the reward signals.

         feed_dict: Dict[tf.Tensor, Any],
         update_dict: Dict[str, tf.Tensor],
         stats_needed: Dict[str, str],
-        reward_signal_minibatches: Dict[str, Dict],
+        reward_signal_minibatches: Mapping[str, Dict],
         num_sequences: int,
     ) -> None:
         """

ml-agents/mlagents/trainers/sac/trainer.py (4 lines changed)

         self.check_param_keys()
         self.load = load
         self.seed = seed
-        self.policy: TFPolicy = None
+        self.policy: SACPolicy = None  # type: ignore
         self.step = 0
         self.train_interval = (

                     self.__class__.__name__
                 )
             )
+        if not isinstance(policy, SACPolicy):
+            raise RuntimeError("Non-SACPolicy passed to SACTrainer.add_policy()")
         self.policy = policy

     def get_policy(self, name_behavior_id: str) -> TFPolicy:
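The SAC trainer gets the same treatment as the PPO trainer, and here the narrowed annotation pays off directly: update_reward_signals (see the sac/policy.py hunk above) exists only on the SAC policy, so with self.policy annotated as the base TFPolicy such a call would not type-check. A rough sketch of the effect, assuming the trainer calls that method on self.policy; class and method bodies are illustrative:

from typing import Dict, Mapping


class TFPolicy:
    """Illustrative base class with no SAC-specific methods."""


class SACPolicy(TFPolicy):
    def update_reward_signals(
        self, reward_signal_minibatches: Mapping[str, Dict], num_sequences: int
    ) -> Dict[str, float]:
        """SAC-only method, mirroring the signature changed in sac/policy.py."""
        return {}


class SACTrainer:
    def __init__(self) -> None:
        self.policy: SACPolicy = None  # type: ignore

    def _update_reward_signals(self) -> None:
        # With the old annotation (self.policy: TFPolicy) mypy would flag this
        # call: "TFPolicy" has no attribute "update_reward_signals".  The
        # narrowed SACPolicy annotation lets it type-check without a cast.
        self.policy.update_reward_signals({}, num_sequences=1)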
