
buffer split for SAC

/develop-newnormalization
Ervin Teng, 5 years ago
Current commit
e5459c49
1 changed file with 22 additions and 24 deletions
  1. ml-agents/mlagents/trainers/sac/trainer.py (46 changed lines)



                 )
             LOGGER.debug(
                 "Loaded update buffer with {} sequences".format(
-                    len(self.training_buffer.update_buffer["actions"])
+                    len(self.update_buffer["actions"])
                 )
             )

         filename = os.path.join(self.policy.model_path, "last_replay_buffer.hdf5")
         LOGGER.info("Saving Experience Replay Buffer to {}".format(filename))
         with open(filename, "wb") as file_object:
-            self.training_buffer.update_buffer.save_to_file(file_object)
+            self.update_buffer.save_to_file(file_object)
     def load_replay_buffer(self) -> None:
         """

         LOGGER.info("Loading Experience Replay Buffer from {}".format(filename))
         with open(filename, "rb+") as file_object:
-            self.training_buffer.update_buffer.load_from_file(file_object)
+            self.update_buffer.load_from_file(file_object)
-                len(self.training_buffer.update_buffer["actions"])
+                len(self.update_buffer["actions"])
             )
         )
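
For context on the persistence path above: the hunks keep the same open()/save_to_file/load_from_file flow and only change which object owns the data. Below is a minimal sketch of persisting a dict-of-arrays replay buffer to HDF5; it assumes h5py and the helper names save_buffer/load_buffer, which are illustrative and not the ml-agents API.

    # Sketch only: persist a replay buffer held as {field: array} to HDF5.
    # Assumes h5py is installed; save_buffer/load_buffer are hypothetical names.
    import h5py
    import numpy as np

    def save_buffer(buffer: dict, file_object) -> None:
        # One HDF5 dataset per buffer field ("actions", "rewards", ...).
        with h5py.File(file_object, "w") as f:
            for key, values in buffer.items():
                f.create_dataset(key, data=np.asarray(values))

    def load_buffer(file_object) -> dict:
        # Read every dataset back into memory as a numpy array.
        with h5py.File(file_object, "r") as f:
            return {key: f[key][()] for key in f.keys()}

    # Usage mirroring the trainer: open the file in binary mode and delegate.
    with open("last_replay_buffer.hdf5", "wb") as file_object:
        save_buffer({"actions": np.zeros((8, 2))}, file_object)
    with open("last_replay_buffer.hdf5", "rb+") as file_object:
        restored = load_buffer(file_object)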

         Takes the output of the last action and stores it in the training buffer.
         """
         actions = take_action_outputs["action"]
-        self.training_buffer[agent_id]["actions"].append(actions[agent_idx])
+        self.processing_buffer[agent_id]["actions"].append(actions[agent_idx])
     def add_rewards_outputs(
         self,

         """
         Takes the value output of the last action and stores it in the training buffer.
         """
-        self.training_buffer[agent_id]["environment_rewards"].append(
+        self.processing_buffer[agent_id]["environment_rewards"].append(
             rewards_out.environment[agent_next_idx]
         )
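
The two hunks above only rename the per-agent store from training_buffer to processing_buffer. As a rough mental model (not the ml-agents implementation), the per-agent buffer behaves like a nested dict of lists keyed by agent id and field name:

    # Sketch of a per-agent "processing buffer": agent_id -> field -> list of values.
    # ProcessingBuffer here is illustrative, not the class introduced by this commit.
    from collections import defaultdict
    from typing import Any, DefaultDict, List

    class ProcessingBuffer:
        def __init__(self) -> None:
            self._data: DefaultDict[str, DefaultDict[str, List[Any]]] = defaultdict(
                lambda: defaultdict(list)
            )

        def __getitem__(self, agent_id: str) -> DefaultDict[str, List[Any]]:
            return self._data[agent_id]

    processing_buffer = ProcessingBuffer()
    processing_buffer["agent-0"]["actions"].append([0.1, -0.3])
    processing_buffer["agent-0"]["environment_rewards"].append(1.0)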

         if self.is_training:
             self.policy.update_normalization(next_info.vector_observations)
         for l in range(len(next_info.agents)):
-            agent_actions = self.training_buffer[next_info.agents[l]]["actions"]
+            agent_actions = self.processing_buffer[next_info.agents[l]]["actions"]
             if (
                 next_info.local_done[l]
                 or len(agent_actions) >= self.trainer_parameters["time_horizon"]

                 # Bootstrap using last brain info. Set last element to duplicate obs and remove dones.
                 if next_info.max_reached[l]:
-                    bootstrapping_info = self.training_buffer[agent_id].last_brain_info
+                    bootstrapping_info = self.processing_buffer[
+                        agent_id
+                    ].last_brain_info
-                        self.training_buffer[agent_id]["next_visual_obs%d" % i][
+                        self.processing_buffer[agent_id]["next_visual_obs%d" % i][
-                        self.training_buffer[agent_id]["next_vector_in"][
+                        self.processing_buffer[agent_id]["next_vector_in"][
-                    self.training_buffer[agent_id]["done"][-1] = False
+                    self.processing_buffer[agent_id]["done"][-1] = False
-                self.training_buffer.append_update_buffer(
+                self.processing_buffer.append_update_buffer(
+                    self.update_buffer,
                     agent_id,
                     batch_size=None,

-                self.training_buffer[agent_id].reset_agent()
+                self.processing_buffer[agent_id].reset_agent()
                 if next_info.local_done[l]:
                     self.stats["Environment/Episode Length"].append(
                         self.episode_steps.get(agent_id, 0)
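
This is the heart of the split: when an agent's episode ends or its trajectory reaches time_horizon, the per-agent experience is appended to the shared update buffer and the agent's slot is cleared. A hedged sketch of that hand-off, with append_trajectory as a hypothetical stand-in for append_update_buffer:

    # Sketch: move one agent's accumulated experience into the shared update
    # buffer, then reset that agent's local store. Names are illustrative.
    from typing import Any, Dict, List

    def append_trajectory(
        update_buffer: Dict[str, List[Any]],
        processing_buffer: Dict[str, Dict[str, List[Any]]],
        agent_id: str,
    ) -> None:
        for field, values in processing_buffer[agent_id].items():
            update_buffer.setdefault(field, []).extend(values)
        # Equivalent of reset_agent(): the hand-off is done, start fresh.
        processing_buffer[agent_id] = {}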

         :return: A boolean corresponding to whether or not update_model() can be run
         """
         return (
-            len(self.training_buffer.update_buffer["actions"])
-            >= self.trainer_parameters["batch_size"]
+            len(self.update_buffer["actions"]) >= self.trainer_parameters["batch_size"]
             and self.step >= self.trainer_parameters["buffer_init_steps"]
         )
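
The condition above reads as: at least one batch worth of samples is in the replay buffer, and enough environment steps have been taken to pre-fill the buffer before updates begin. A standalone restatement, assuming the same trainer_parameters keys shown in the diff:

    # Sketch of the readiness check shown in the hunk above.
    def is_ready_update(update_buffer, step, trainer_parameters) -> bool:
        return (
            len(update_buffer["actions"]) >= trainer_parameters["batch_size"]
            and step >= trainer_parameters["buffer_init_steps"]
        )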

"""
if self.step % self.train_interval == 0:
self.trainer_metrics.start_policy_update_timer(
number_experiences=len(self.training_buffer.update_buffer["actions"]),
number_experiences=len(self.update_buffer["actions"]),
mean_return=float(np.mean(self.cumulative_returns_since_policy_update)),
)
self.update_sac_policy()

         batch_update_stats: Dict[str, list] = defaultdict(list)
         for _ in range(num_updates):
             LOGGER.debug("Updating SAC policy at step {}".format(self.step))
-            buffer = self.training_buffer.update_buffer
+            buffer = self.update_buffer
-                len(self.training_buffer.update_buffer["actions"])
+                len(self.update_buffer["actions"])
                 >= self.trainer_parameters["batch_size"]
             ):
                 sampled_minibatch = buffer.sample_mini_batch(
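
After the change, buffer refers directly to the shared update buffer, from which a minibatch is drawn on each update. Below is a minimal sketch of uniform random sampling from a flat dict-of-arrays buffer; the real sample_mini_batch also accounts for sequence_length when the policy is recurrent, which this sketch omits.

    # Sketch only: uniform random minibatch from a {field: array} replay buffer.
    import numpy as np

    def sample_mini_batch(buffer: dict, batch_size: int) -> dict:
        n = len(buffer["actions"])
        idx = np.random.randint(0, n, size=batch_size)
        return {key: np.asarray(values)[idx] for key, values in buffer.items()}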

             # Truncate update buffer if necessary. Truncate more than we need to, to avoid truncating
             # a large buffer at each update.
-            if (
-                len(self.training_buffer.update_buffer["actions"])
-                > self.trainer_parameters["buffer_size"]
-            ):
-                self.training_buffer.truncate_update_buffer(
+            if len(self.update_buffer["actions"]) > self.trainer_parameters["buffer_size"]:
+                self.update_buffer.truncate(
                     int(self.trainer_parameters["buffer_size"] * BUFFER_TRUNCATE_PERCENT)
                 )
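
Truncation keeps the replay buffer bounded: once it grows past buffer_size, it is cut back to a fraction of capacity (BUFFER_TRUNCATE_PERCENT) in one go, so the trim does not have to run on every update. A sketch of one way to implement such a truncate, keeping the most recent entries, with an illustrative percent value:

    # Sketch: drop the oldest entries, keeping the newest `keep` samples per field.
    def truncate(buffer: dict, keep: int) -> None:
        for key in buffer:
            buffer[key] = buffer[key][-keep:]

    BUFFER_TRUNCATE_PERCENT = 0.8  # illustrative; not necessarily the trainer's constant
    buffer_size = 50000
    # truncate(update_buffer, int(buffer_size * BUFFER_TRUNCATE_PERCENT))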

         N times, then the reward signals are updated N times. Normally, the reward signal
         and policy are updated in parallel.
         """
-        buffer = self.training_buffer.update_buffer
+        buffer = self.update_buffer
         num_updates = self.reward_signal_updates_per_train
         n_sequences = max(
             int(self.trainer_parameters["batch_size"] / self.policy.sequence_length), 1
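
The n_sequences arithmetic above converts the batch size from individual experiences into whole sequences (recurrent policies sample sequence_length steps at a time), clamped to at least one sequence. A worked example with illustrative hyperparameter values:

    # Sketch: batch size expressed in sequences, floored at 1. Values are examples.
    batch_size = 256
    sequence_length = 16       # 1 for non-recurrent policies
    n_sequences = max(int(batch_size / sequence_length), 1)
    assert n_sequences == 16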
