浏览代码

ignoring commit checks but write to csv

/sensitivity
Andrew Cohen 5 年前
当前提交
23b84dea
共有 2 个文件被更改,包括 9 次插入3 次删除
  1. 6
      ml-agents/mlagents/trainers/ppo/optimizer.py
  2. 6
      ml-agents/mlagents/trainers/ppo/trainer.py

6
ml-agents/mlagents/trainers/ppo/optimizer.py


from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.stats import StatsSummary
class PPOOptimizer(TFOptimizer):
def __init__(self, policy: TFPolicy, trainer_params: Dict[str, Any]):

"""
# Create the graph here to give more granular control of the TF graph to the Optimizer.
policy.create_tf_graph()
with policy.graph.as_default():
with tf.variable_scope("optimizer/"):
super().__init__(policy, trainer_params)

update_stats[stat_name] = update_vals[update_name]
return update_stats
def compute_input_sensitivity(self, batch: AgentBuffer, num_sequences: int) -> None:
def compute_input_sensitivity(self, batch: AgentBuffer, num_sequences: int) -> Dict[int, float]:
out = dict((obs,StatsSummary(grad, 0.0, 0.0)) for obs, grad in enumerate(sens))
return out
def _construct_feed_dict(
self, mini_batch: AgentBuffer, num_sequences: int

6
ml-agents/mlagents/trainers/ppo/trainer.py


from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.stats import CSVWriter
logger = get_logger(__name__)

self.load = load
self.seed = seed
self.policy: NNPolicy = None # type: ignore
self.csv_writer = CSVWriter("sensitivity")
def _check_param_keys(self):
super()._check_param_keys()

)
num_epoch = self.trainer_parameters["num_epoch"]
batch_update_stats = defaultdict(list)
self.optimizer.compute_input_sensitivity(
sensitivities = self.optimizer.compute_input_sensitivity(
self.csv_writer.write_stats("sensitivity", sensitivities, self.step)
for _ in range(num_epoch):
self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
buffer = self.update_buffer

正在加载...
取消
保存