import logging
import os

import numpy as np

from mlagents.envs.timers import timed
from mlagents.trainers.brain import BrainParameters, BrainInfo
from mlagents.trainers.action_info import ActionInfoOutputs
from mlagents.trainers.rl_trainer import RLTrainer
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.sac.policy import SACPolicy

LOGGER = logging.getLogger("mlagents.trainers")


class SACTrainer(RLTrainer):
    """
    The SACTrainer is an implementation of the SAC algorithm, with support
    for discrete actions and recurrent networks.
    """

    def __init__(
        self, brain, reward_buff_cap, trainer_parameters, training, load, seed, run_id
    ):
        # Signature and super() call assumed to mirror the repo's other trainers.
        super().__init__(brain, trainer_parameters, training, run_id, reward_buff_cap)
        self.load = load
        self.seed = seed

        # Checkpoint the replay buffer alongside the model only when the
        # optional "save_replay_buffer" trainer parameter is set.
        self.checkpoint_replay_buffer = (
            trainer_parameters["save_replay_buffer"]
            if "save_replay_buffer" in trainer_parameters
            else False
        )
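
        # For reference, a minimal sketch of the trainer YAML that enables
        # buffer checkpointing (brain name and other keys are illustrative,
        # not prescriptive):
        #
        #     YourBrainName:
        #         trainer: sac
        #         save_replay_buffer: true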
|
|
|
|
|
|
|
        # Cumulative rewards collected per reward signal; populated further in
        # create_policy(). "environment" tracks the extrinsic reward.
        self.collected_rewards = {"environment": {}}
        # Steps taken so far in each ongoing episode, keyed by agent id.
        self.episode_steps = {}
|
|
|
|
|
|
|
    def save_model(self) -> None:
        """
        Saves the model. Overridden from the base trainer so that the replay
        buffer can be checkpointed along with it.
        """
        # Defer the policy checkpoint to the base trainer, then persist the
        # replay buffer if buffer checkpointing is enabled.
        super().save_model()
        if self.checkpoint_replay_buffer:
            self.save_replay_buffer()

    def save_replay_buffer(self) -> None:
        """
        Save the training buffer's update buffer to a file on disk.
        """
        filename = os.path.join(
            self.trainer_parameters["model_path"], "last_replay_buffer.hdf5"
        )
        LOGGER.info("Saving Experience Replay Buffer to {}".format(filename))
        with open(filename, "wb") as file_object:
            self.update_buffer.save_to_file(file_object)
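
    # A rough sketch of the round trip this pair of methods enables, assuming
    # a hypothetical driver that owns the trainer (not part of this module):
    #
    #     trainer.save_model()      # also writes last_replay_buffer.hdf5
    #     # ...process restarts with load=True...
    #     policy = trainer.create_policy(brain_params)  # buffer is reloaded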
|
|
|
|
|
|
    def load_replay_buffer(self) -> None:
        """
        Loads the last saved replay buffer from a file.
        """
        filename = os.path.join(
            self.trainer_parameters["model_path"], "last_replay_buffer.hdf5"
        )
        LOGGER.info("Loading Experience Replay Buffer from {}".format(filename))
        with open(filename, "rb+") as file_object:
            self.update_buffer.load_from_file(file_object)
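
        # A missing or incompatible buffer file raises FileNotFoundError or
        # AttributeError here; create_policy() below catches both and falls
        # back to an empty buffer rather than aborting the run.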
|
|
|
|
|
|
    def create_policy(self, brain_parameters: BrainParameters) -> TFPolicy:
        # NOTE: the SACPolicy constructor arguments below are assumed from
        # this codebase's trainer API.
        policy = SACPolicy(
            self.seed,
            brain_parameters,
            self.trainer_parameters,
            self.is_training,
            self.load,
        )
        # Track collected rewards separately for each reward signal.
        for _reward_signal in policy.reward_signals.keys():
            self.collected_rewards[_reward_signal] = {}

        # Load the replay buffer when resuming a run, if buffer checkpointing
        # is enabled.
        if self.load and self.checkpoint_replay_buffer:
            try:
                self.load_replay_buffer()
            except (AttributeError, FileNotFoundError):
                LOGGER.warning(
                    "Replay buffer was unable to load, starting from scratch."
                )
            LOGGER.debug(
                "Loaded update buffer with {} sequences".format(
                    self.update_buffer.num_experiences
                )
            )
        return policy
|
|
|
|
|
|
|
    @timed
    def update_sac_policy(self) -> None:
|
|
|