from typing import Any, Dict

from mlagents.tf_utils import tf

from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.stats import StatsSummary


class PPOOptimizer(TFOptimizer):
    def __init__(self, policy: TFPolicy, trainer_params: Dict[str, Any]):
        """
        Takes a Policy and a Dict of trainer parameters and creates an Optimizer around
        the Policy. The PPO Optimizer adds a value estimator and a PPO loss to the policy graph.
        :param policy: A TFPolicy object that will be updated by this PPO Optimizer.
        :param trainer_params: Trainer parameters dictionary that specifies the
            properties of the trainer.
        """
        # Create the graph here to give more granular control of the TF graph to the Optimizer.
        policy.create_tf_graph()

        with policy.graph.as_default():
            with tf.variable_scope("optimizer/"):
                super().__init__(policy, trainer_params)

    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        """
        Performs an update of the policy on the given batch of experiences.
        :param batch: Batch of experiences to learn from.
        :param num_sequences: Number of sequences in the batch.
        :return: Dict of training statistics produced by the update.
        """
        feed_dict = self._construct_feed_dict(batch, num_sequences)
        update_vals = self._execute_model(feed_dict, self.update_dict)
        update_stats = {}
        # Report each fetched value under its user-facing stat name.
        for stat_name, update_name in self.stats_name_to_update_name.items():
            update_stats[stat_name] = update_vals[update_name]
        return update_stats

    def compute_input_sensitivity(
        self, batch: AgentBuffer, num_sequences: int
    ) -> Dict[int, StatsSummary]:
        """
        Computes a sensitivity value for each observation input over the given batch.
        :param batch: Batch of experiences to evaluate.
        :param num_sequences: Number of sequences in the batch.
        :return: Dict mapping each observation index to a StatsSummary of its sensitivity.
        """
        # `sens` holds one sensitivity value per observation input; wrap each value in a
        # StatsSummary (mean, std, num) so it can be reported like other training stats.
        out = {obs: StatsSummary(grad, 0.0, 0.0) for obs, grad in enumerate(sens)}
        return out
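
    # Usage sketch (hypothetical; `optimizer`, `batch`, `n_seq`, and `stats_reporter` are
    # illustrative names): a trainer could forward the per-observation sensitivities
    # returned above to its StatsReporter for logging, e.g.
    #     for obs_idx, summary in optimizer.compute_input_sensitivity(batch, n_seq).items():
    #         stats_reporter.add_stat(f"Policy/Input Sensitivity {obs_idx}", summary.mean)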

    def _construct_feed_dict(
        self, mini_batch: AgentBuffer, num_sequences: int
|
|
|