from collections import defaultdict
from typing import cast

import numpy as np

from mlagents_envs.logging_util import get_logger
from mlagents.trainers.policy.nn_policy import NNPolicy
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.behavior_id_utils import BehaviorIdentifiers
from mlagents.trainers.settings import TrainerSettings, PPOSettings
from mlagents.trainers.stats import CSVWriter, StatsPropertyType

logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
        self.hyperparameters: PPOSettings = cast(
            PPOSettings, self.trainer_settings.hyperparameters
        )
        self.load = load
        self.seed = seed
        self.policy: NNPolicy = None  # type: ignore
        # Writer used to persist per-input sensitivity values to CSV.
        self.csv_writer = CSVWriter("sensitivity")
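        # Note: in the stock mlagents.trainers.stats, CSVWriter's first constructor
        # argument is the base directory its CSV files are written into, one file per
        # stats category, so the values written via self.csv_writer.write_stats()
        # should end up under a local "sensitivity" directory.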
|
|
|
|
|
|
|
    def _process_trajectory(self, trajectory: Trajectory) -> None:
        """
        Takes a trajectory and processes it, putting it into the update buffer.
        Processing involves calculating value estimates for the model-updating step
        and reporting input saliencies for the completed trajectory.
        :param trajectory: The Trajectory tuple containing the steps to be processed.
        """
        super()._process_trajectory(trajectory)
        agent_buffer_trajectory = trajectory.to_agentbuffer()

        # Get value estimates, bootstrapping from the final observation unless the
        # episode terminated on its own.
        value_estimates, value_next = self.optimizer.get_trajectory_value_estimates(
            agent_buffer_trajectory,
            trajectory.next_obs,
            trajectory.done_reached and not trajectory.interrupted,
        )

        # Compute and report per-input saliencies for this trajectory.
        saliencies = self.optimizer.get_saliency(agent_buffer_trajectory)
        self._stats_reporter.add_property(StatsPropertyType.SALIENCY, saliencies)
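        # Note: SALIENCY is a project-specific StatsPropertyType; the stock
        # StatsWriters in mlagents.trainers.stats only react to the property types
        # they know about (e.g. HYPERPARAMETERS), so a custom writer is presumably
        # registered elsewhere to consume these saliency values.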

        # Append value estimates to the buffer and report their means per reward signal.
        for name, v in value_estimates.items():
            agent_buffer_trajectory[f"{name}_value_estimates"].extend(v)
            self._stats_reporter.add_stat(
                self.optimizer.reward_signals[name].value_name, np.mean(v)
            )

    def _update_policy(self):
        """
        Uses the update buffer to update the policy, and logs per-input
        sensitivities computed over the buffered experience.
        """
        num_epoch = self.hyperparameters.num_epoch
        batch_update_stats = defaultdict(list)

        # Compute per-input sensitivities over the whole update buffer and log them
        # to the sensitivity CSV for the current step.
        sensitivities = self.optimizer.compute_input_sensitivity(
            self.update_buffer, self.policy.sequence_length
        )
        self.csv_writer.write_stats("sensitivity", sensitivities, self.step)
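        # Note: the stock CSVWriter.write_stats expects a mapping of stat name to
        # StatsSummary and writes one row per call to <base_dir>/<category>.csv, so
        # compute_input_sensitivity presumably returns such a mapping keyed by
        # input/observation name.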

        for _ in range(num_epoch):
            self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
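            # Shuffling in blocks of sequence_length keeps recurrent sequences
            # contiguous; in the stock ML-Agents PPOTrainer, the minibatch gradient
            # updates over the shuffled buffer follow at this point in the loop.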
|
|
|