|
|
|
|
|
|
|
|
|
|
from mlagents.trainers.stats import StatsSummary |
|
|
|
|
|
|
|
|
|
|
|
class PPOOptimizer(TFOptimizer): |
|
|
|
def __init__(self, policy: TFPolicy, trainer_params: Dict[str, Any]): |
|
|
|
""" |
|
|
|
|
|
|
update_stats[stat_name] = update_vals[update_name] |
|
|
|
return update_stats |
|
|
|
|
|
|
|
def compute_input_sensitivity(self, batch: AgentBuffer, num_sequences: int) -> Dict[int, float]: |
|
|
|
def compute_input_sensitivity( |
|
|
|
self, batch: AgentBuffer, num_sequences: int |
|
|
|
) -> Dict[int, float]: |
|
|
|
out = dict((obs,StatsSummary(grad, 0.0, 0.0)) for obs, grad in enumerate(sens)) |
|
|
|
out = dict((obs, StatsSummary(grad, 0.0, 0.0)) for obs, grad in enumerate(sens)) |
|
|
|
for obs, grad in sorted(enumerate(sens), reverse=True, key=lambda x: x[1]): |
|
|
|
print("Observation {} has relevance {}".format(obs, grad)) |
|
|
|
return out |
|
|
|