
small mypy cleanup (#2637)

* small mypy cleanup

* sac cleanup

* types for ppo policy init
0.10.1
Ervin Teng, 5 years ago
Commit 024e3677
5 files changed, 23 insertions(+), 26 deletions(-)
1. ml-agents/mlagents/trainers/ppo/policy.py (20 changes)
2. ml-agents/mlagents/trainers/sac/policy.py (5 changes)
3. ml-agents/mlagents/trainers/sac/trainer.py (3 changes)
4. ml-agents/mlagents/trainers/trainer_metrics.py (19 changes)
5. ml-agents/mlagents/trainers/trainer_util.py (2 changes)

ml-agents/mlagents/trainers/ppo/policy.py (20 changes)


 import logging
 import numpy as np
-from typing import Any, Dict
+from typing import Any, Dict, Optional
-from mlagents.envs.brain import BrainInfo
+from mlagents.envs.brain import BrainInfo, BrainParameters
 from mlagents.trainers.models import EncoderType, LearningRateSchedule
 from mlagents.trainers.ppo.models import PPOModel
 from mlagents.trainers.tf_policy import TFPolicy

 class PPOPolicy(TFPolicy):
-    def __init__(self, seed, brain, trainer_params, is_training, load):
+    def __init__(
+        self,
+        seed: int,
+        brain: BrainParameters,
+        trainer_params: Dict[str, Any],
+        is_training: bool,
+        load: bool,
+    ):
         """
         Policy for Proximal Policy Optimization Networks.
         :param seed: Random seed.

         super().__init__(seed, brain, trainer_params)
         reward_signal_configs = trainer_params["reward_signals"]
-        self.inference_dict = {}
-        self.update_dict = {}
+        self.inference_dict: Dict[str, tf.Tensor] = {}
+        self.update_dict: Dict[str, tf.Tensor] = {}
         self.stats_name_to_update_name = {
             "Losses/Value Loss": "value_loss",
             "Losses/Policy Loss": "policy_loss",

         self.create_reward_signals(reward_signal_configs)
         with self.graph.as_default():
+            self.bc_module: Optional[BCModule] = None
             # Create pretrainer if needed
             if "pretraining" in trainer_params:
                 BCModule.check_config(trainer_params["pretraining"])

                     default_num_epoch=trainer_params["num_epoch"],
                     **trainer_params["pretraining"],
                 )
-            else:
-                self.bc_module = None
         if load:
             self._load_graph()
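A minimal sketch of the Optional-attribute pattern introduced above; BCModuleStub and PolicySketch are hypothetical stand-ins, not the ml-agents classes. Declaring the attribute once with its full Optional type lets mypy accept the later conditional assignment and makes the old "else: self.bc_module = None" branch unnecessary.

from typing import Any, Dict, Optional


class BCModuleStub:
    def __init__(self, config: Dict[str, Any]) -> None:
        self.samples_per_update = config.get("samples_per_update", 0)


class PolicySketch:
    def __init__(self, trainer_params: Dict[str, Any]) -> None:
        # Declare the attribute once, with its full type, before any branch.
        # mypy then treats self.bc_module as Optional[BCModuleStub] everywhere.
        self.bc_module: Optional[BCModuleStub] = None
        if "pretraining" in trainer_params:
            self.bc_module = BCModuleStub(trainer_params["pretraining"])

    def update(self) -> None:
        # Reading the attribute later requires a None check, which mypy enforces.
        if self.bc_module is not None:
            print(self.bc_module.samples_per_update)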

ml-agents/mlagents/trainers/sac/policy.py (5 changes)


 import logging
-from typing import Dict, Any
+from typing import Dict, Any, Optional
 import numpy as np
 import tensorflow as tf

         with self.graph.as_default():
             # Create pretrainer if needed
+            self.bc_module: Optional[BCModule] = None
             if "pretraining" in trainer_params:
                 BCModule.check_config(trainer_params["pretraining"])
                 self.bc_module = BCModule(

                         "Pretraining: Samples Per Update is not a valid setting for SAC."
                     )
                     self.bc_module.samples_per_update = 1
-            else:
-                self.bc_module = None
         if load:
             self._load_graph()
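A short sketch (hypothetical names, not the ml-agents API) of why the write to samples_per_update inside the "pretraining" branch still type-checks once the attribute is declared as Optional: after the assignment in the same branch, mypy narrows the attribute to the non-None type.

from typing import Any, Dict, Optional


class BCModuleStub:
    samples_per_update: int = 0


class SACPolicySketch:
    def __init__(self, trainer_params: Dict[str, Any]) -> None:
        self.bc_module: Optional[BCModuleStub] = None
        if "pretraining" in trainer_params:
            self.bc_module = BCModuleStub()
            # After the assignment above, mypy treats self.bc_module as
            # BCModuleStub in this branch, so the attribute write is accepted
            # without an explicit None check.
            self.bc_module.samples_per_update = 1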

ml-agents/mlagents/trainers/sac/trainer.py (3 changes)


 from mlagents.envs.brain import AllBrainInfo
 from mlagents.envs.action_info import ActionInfoOutputs
 from mlagents.envs.timers import timed
-from mlagents.trainers.buffer import Buffer
 from mlagents.trainers.sac.policy import SACPolicy
 from mlagents.trainers.rl_trainer import RLTrainer, AllRewardsOutput

         with open(filename, "wb") as file_object:
             self.training_buffer.update_buffer.save_to_file(file_object)
-    def load_replay_buffer(self) -> Buffer:
+    def load_replay_buffer(self) -> None:
         """
         Loads the last saved replay buffer from a file.
         """

ml-agents/mlagents/trainers/trainer_metrics.py (19 changes)


         self.delta_policy_update = 0
         delta_train_start = time() - self.time_training_start
         LOGGER.debug(
-            " Policy Update Training Metrics for {}: "
-            "\n\t\tTime to update Policy: {:0.3f} s \n"
-            "\t\tTime elapsed since training: {:0.3f} s \n"
-            "\t\tTime for experience collection: {:0.3f} s \n"
-            "\t\tBuffer Length: {} \n"
-            "\t\tReturns : {:0.3f}\n".format(
-                self.brain_name,
-                self.delta_policy_update,
-                delta_train_start,
-                self.delta_last_experience_collection,
-                self.last_buffer_length,
-                self.last_mean_return,
-            )
+            f" Policy Update Training Metrics for {self.brain_name}: "
+            f"\n\t\tTime to update Policy: {self.delta_policy_update:0.3f} s \n"
+            f"\t\tTime elapsed since training: {delta_train_start:0.3f} s \n"
+            f"\t\tTime for experience collection: {(self.delta_last_experience_collection or 0):0.3f} s \n"
+            f"\t\tBuffer Length: {self.last_buffer_length or 0} \n"
+            f"\t\tReturns : {(self.last_mean_return or 0):0.3f}\n"
         )
         self._add_row(delta_train_start)
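A small runnable demo of the (value or 0) fallback used in the f-strings above: formatting None with a float format spec raises a TypeError at runtime, so metrics declared as Optional are coalesced to 0 before formatting. The variable names here are illustrative only.

from typing import Optional

delta_last_experience_collection: Optional[float] = None
last_mean_return: Optional[float] = 2.5

print(
    f"Time for experience collection: {(delta_last_experience_collection or 0):0.3f} s\n"
    f"Returns : {(last_mean_return or 0):0.3f}"
)
# Output:
# Time for experience collection: 0.000 s
# Returns : 2.500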

ml-agents/mlagents/trainers/trainer_util.py (2 changes)


     :param multi_gpu: Whether to use multi-GPU training
     :return:
     """
-    trainers = {}
+    trainers: Dict[str, Trainer] = {}
     trainer_parameters_dict = {}
     for brain_name in external_brains:
         trainer_parameters = trainer_config["default"].copy()
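A minimal sketch of the empty-dict annotation above, using a hypothetical TrainerStub in place of the real Trainer class: without an annotation on an empty literal, mypy typically reports 'Need type annotation for "trainers"', and with the annotation it can check every subsequent insertion.

from typing import Dict, List


class TrainerStub:
    def __init__(self, brain_name: str) -> None:
        self.brain_name = brain_name


def initialize_trainers_sketch(brain_names: List[str]) -> Dict[str, TrainerStub]:
    # Annotating the empty dict tells mypy the key and value types up front.
    trainers: Dict[str, TrainerStub] = {}
    for brain_name in brain_names:
        # mypy verifies that both the key and the value match the annotation.
        trainers[brain_name] = TrainerStub(brain_name)
    return trainers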
