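# NOTE: excerpt from the trainer's policy-update routine. It assumes that
# `buffer_length`, `batch_size`, and `n_sequences` were computed earlier in the
# enclosing method, and that `defaultdict` (collections), `copy`, and `np` (numpy)
# are imported at module level.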
num_epoch = self.hyperparameters.num_epoch
batch_update_stats = defaultdict(list)
for _ in range(num_epoch):
    # print("update epoch")
    self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
    buffer = self.update_buffer
    max_num_batch = buffer_length // batch_size
    for i in range(0, max_num_batch * batch_size, batch_size):
        update_stats = self.optimizer.update_part(
            buffer.make_mini_batch(i, i + batch_size), n_sequences, "policy"
        )
        for stat_name, value in update_stats.items():
            batch_update_stats[stat_name].append(value)

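    # Model update within the same epoch: re-shuffle the on-policy buffer and train
    # the "model" part of the optimizer (the "model_only" variant is kept below as a
    # commented alternative).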
    # if self.train_model:
    self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
    buffer = self.update_buffer
    for i in range(0, max_num_batch * batch_size, batch_size):
        update_stats = self.optimizer.update_part(
            buffer.make_mini_batch(i, i + batch_size), n_sequences, "model"
            # buffer.make_mini_batch(i, i + batch_size), n_sequences, "model_only"
        )
        for stat_name, value in update_stats.items():
            batch_update_stats[stat_name].append(value)

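    # Re-shuffle the on-policy buffer and run a further policy update pass within
    # the same epoch.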
    self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
    buffer = self.update_buffer
    max_num_batch = buffer_length // batch_size
    for i in range(0, max_num_batch * batch_size, batch_size):
        update_stats = self.optimizer.update_part(
            buffer.make_mini_batch(i, i + batch_size), n_sequences, "policy"
        )
        for stat_name, value in update_stats.items():
            batch_update_stats[stat_name].append(value)

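    # Encoder update for the bisimulation objective: pair minibatches drawn from two
    # independently shuffled copies of the update buffer.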
    if self.use_bisim:
        self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
        buffer1 = copy.deepcopy(self.update_buffer)
        self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
        buffer2 = copy.deepcopy(self.update_buffer)
        max_num_batch = buffer_length // batch_size
        for i in range(0, max_num_batch * batch_size, batch_size):
            update_stats = self.optimizer.update_encoder(
                buffer1.make_mini_batch(i, i + batch_size),
                buffer2.make_mini_batch(i, i + batch_size),
            )
            for stat_name, value in update_stats.items():
                batch_update_stats[stat_name].append(value)
    else:
        self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
        buffer = self.update_buffer

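# Model-only training from the off-policy buffer. Note that batch_update_stats is
# re-initialized here, so the report below only covers statistics accumulated from
# this point on.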
num_epoch = self.hyperparameters.num_epoch
batch_update_stats = defaultdict(list)
for _ in range(num_epoch):
    # print("model epoch")
    self.off_policy_buffer.shuffle(sequence_length=self.policy.sequence_length)
    buffer = self.off_policy_buffer
    max_num_batch = buffer_length // batch_size
    for i in range(0, max_num_batch * batch_size, batch_size):
        update_stats = self.optimizer.update_part(
            buffer.make_mini_batch(i, i + batch_size), n_sequences, "model_only"
        )
        for stat_name, value in update_stats.items():
            batch_update_stats[stat_name].append(value)

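# Report the mean of each accumulated statistic for this update.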
for stat, stat_list in batch_update_stats.items():
    self._stats_reporter.add_stat(stat, np.mean(stat_list))
    if stat == "Losses/Model Loss":  # and np.mean(stat_list) < 0.01: