        # Transfer-specific options and bookkeeping
        self.use_op_buffer = self.hyperparameters.use_op_buffer  # whether to use the off-policy buffer
        self.conv_thres = self.hyperparameters.conv_thres  # convergence threshold for model training
        self.use_bisim = self.hyperparameters.use_bisim  # whether to run the bisimulation encoder update
        self.num_check = 0  # number of readiness checks performed
        self.num_update = 0  # number of updates performed
        self.train_model = True  # whether the model is still being trained
        self.old_loss = np.inf  # previous model loss, kept for the convergence check
        print("The current algorithm is PPO Transfer")


        agent_buffer_trajectory = trajectory.to_agentbuffer()

        # Update the normalization with the current and next vector observations
        if self.is_training:
            self.policy.update_normalization(
                agent_buffer_trajectory["vector_obs"],
                agent_buffer_trajectory["next_vector_in"],
            )

        # Get all value estimates
        value_estimates, value_next = self.optimizer.get_trajectory_value_estimates(
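            # The remaining arguments follow the usual ML-Agents signature
            # get_trajectory_value_estimates(batch, next_obs, done); a typical
            # completion (the exact Trajectory attributes are assumed here) is:
            #     agent_buffer_trajectory,
            #     trajectory.next_obs,
            #     trajectory.done_reached,
            # )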


    def _is_ready_update(self) -> bool:
        """
        Returns whether or not the trainer has enough elements to run a model update.
        :return: A boolean corresponding to whether or not update_model() can be run
        """
        # Disabled variant: while the model is trained from the off-policy buffer,
        # only report ready on every 50th check once that buffer is full.
        # if self.train_model and self.use_op_buffer:
        #     size_of_buffer = self.off_policy_buffer.num_experiences
        #     self.num_check += 1
        #     if self.num_check % 50 == 0 and size_of_buffer >= self.hyperparameters.buffer_size:
        #         return True
        #     else:
        #         return False
        # else:
        size_of_buffer = self.update_buffer.num_experiences
        return size_of_buffer > self.hyperparameters.buffer_size

                )
                for stat_name, value in update_stats.items():
                    batch_update_stats[stat_name].append(value)
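            # Bisimulation-style encoder update: the update buffer is shuffled and
            # deep-copied three times, so buffer1/buffer2/buffer3 hold independently
            # ordered experiences and update_encoder sees three different
            # mini-batches at every step.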
            if self.use_bisim:
                self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
                buffer1 = copy.deepcopy(self.update_buffer)
                self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
                buffer2 = copy.deepcopy(self.update_buffer)
                self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
                buffer3 = copy.deepcopy(self.update_buffer)
                max_num_batch = buffer_length // batch_size
                for i in range(0, max_num_batch * batch_size, batch_size):
                    update_stats = self.optimizer.update_encoder(
                        buffer1.make_mini_batch(i, i + batch_size),
                        buffer2.make_mini_batch(i, i + batch_size),
                        buffer3.make_mini_batch(i, i + batch_size),
                    )
                    for stat_name, value in update_stats.items():
                        batch_update_stats[stat_name].append(value)

            # Earlier two-buffer variant of the encoder update, kept disabled:
            # if self.use_bisim:
            #     self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
            #     buffer1 = copy.deepcopy(self.update_buffer)
            #     self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
            #     buffer2 = copy.deepcopy(self.update_buffer)
            #     max_num_batch = buffer_length // batch_size
            #     for i in range(0, max_num_batch * batch_size, batch_size):
            #         update_stats = self.optimizer.update_encoder(
            #             buffer1.make_mini_batch(i, i + batch_size),
            #             buffer2.make_mini_batch(i, i + batch_size),
            #         )
            #         for stat_name, value in update_stats.items():
            #             batch_update_stats[stat_name].append(value)

        else:
            self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
            buffer = self.update_buffer

                )
                for stat_name, value in update_stats.items():
                    batch_update_stats[stat_name].append(value)

            if self.use_bisim:
                self.off_policy_buffer.shuffle(sequence_length=self.policy.sequence_length)
                buffer1 = copy.deepcopy(self.off_policy_buffer)
                self.off_policy_buffer.shuffle(sequence_length=self.policy.sequence_length)
                buffer2 = copy.deepcopy(self.off_policy_buffer)
                max_num_batch = buffer_length // batch_size
                for i in range(0, max_num_batch * batch_size, batch_size):
                    update_stats = self.optimizer.update_encoder(
                        buffer1.make_mini_batch(i, i + batch_size),
                        buffer2.make_mini_batch(i, i + batch_size),
                    )
                    for stat_name, value in update_stats.items():
                        batch_update_stats[stat_name].append(value)

        for stat, stat_list in batch_update_stats.items():
            self._stats_reporter.add_stat(stat, np.mean(stat_list))
            if stat == "Losses/Model Loss":  # and np.mean(stat_list) < 0.01:

                self._stats_reporter.add_stat(stat, val)

        # self.off_policy_buffer.reset_agent()
        # Cap the off-policy buffer: once it holds more than 10x buffer_size
        # experiences, truncate it back to 5x buffer_size.
        if self.off_policy_buffer.num_experiences > 10 * self.hyperparameters.buffer_size:
            print("truncate")
            self.off_policy_buffer.truncate(
                int(5 * self.hyperparameters.buffer_size)
            )

        return True


    def _update_policy_new(self):
        """
        Uses the update buffer (and, when it is in use, the off-policy buffer) to
        update the policy.
        The reward signal generators must be updated in this method at their own pace.
        """
        update_buffer_length = self.update_buffer.num_experiences
        op_buffer_length = self.off_policy_buffer.num_experiences
        self.cumulative_returns_since_policy_update.clear()

        # Make sure batch_size is a multiple of sequence length. During training, we
        # will need to reshape the data into a batch_size x sequence_length tensor.
        batch_size = (
            self.hyperparameters.batch_size
            - self.hyperparameters.batch_size % self.policy.sequence_length
        )
        # Make sure there is at least one sequence
        batch_size = max(batch_size, self.policy.sequence_length)
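        # For example, with batch_size = 1000 and sequence_length = 64 the batch is
        # trimmed to 960 (15 full sequences); a batch_size smaller than
        # sequence_length is raised to one full sequence.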

        n_sequences = max(
            int(self.hyperparameters.batch_size / self.policy.sequence_length), 1
        )

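        # Standardize the advantages over the whole buffer (a standard PPO trick) so
        # the scale of the policy loss stays stable; 1e-10 guards against a zero
        # standard deviation.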
        advantages = self.update_buffer["advantages"].get_batch()
        self.update_buffer["advantages"].set(
            (advantages - advantages.mean()) / (advantages.std() + 1e-10)
        )
        num_epoch = self.hyperparameters.num_epoch
        batch_update_stats = defaultdict(list)
        if self.use_iealter:
            for _ in range(num_epoch):
                self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
                buffer = self.update_buffer
                max_num_batch = update_buffer_length // batch_size
                for i in range(0, max_num_batch * batch_size, batch_size):
                    update_stats = self.optimizer.update_part(
                        buffer.make_mini_batch(i, i + batch_size), n_sequences, "policy"
                    )
                    for stat_name, value in update_stats.items():
                        batch_update_stats[stat_name].append(value)
            for _ in range(num_epoch):
                self.off_policy_buffer.shuffle(sequence_length=self.policy.sequence_length)
                buffer = self.off_policy_buffer
                # update with as much data as the policy has
                max_num_batch = update_buffer_length // batch_size
                for i in range(0, max_num_batch * batch_size, batch_size):
                    update_stats = self.optimizer.update_part(
                        buffer.make_mini_batch(i, i + batch_size), n_sequences, "model"
                    )
                    for stat_name, value in update_stats.items():
                        batch_update_stats[stat_name].append(value)
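
            # Encoder (bisimulation) update: two independently shuffled copies of the
            # off-policy buffer provide paired mini-batches of different experiences
            # for update_encoder.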
            if self.use_bisim:
                for _ in range(num_epoch):
                    self.off_policy_buffer.shuffle(sequence_length=self.policy.sequence_length)
                    buffer1 = copy.deepcopy(self.off_policy_buffer)
                    self.off_policy_buffer.shuffle(sequence_length=self.policy.sequence_length)
                    buffer2 = copy.deepcopy(self.off_policy_buffer)
                    # update with as much data as the policy has
                    max_num_batch = update_buffer_length // batch_size
                    for i in range(0, max_num_batch * batch_size, batch_size):
                        update_stats = self.optimizer.update_encoder(
                            buffer1.make_mini_batch(i, i + batch_size),
                            buffer2.make_mini_batch(i, i + batch_size),
                        )
                        for stat_name, value in update_stats.items():
                            batch_update_stats[stat_name].append(value)
        else:
            for _ in range(num_epoch):
                self.update_buffer.shuffle(sequence_length=self.policy.sequence_length)
                buffer = self.update_buffer
                max_num_batch = update_buffer_length // batch_size
                for i in range(0, max_num_batch * batch_size, batch_size):
                    update_stats = self.optimizer.update(
                        buffer.make_mini_batch(i, i + batch_size), n_sequences
                    )
                    for stat_name, value in update_stats.items():
                        batch_update_stats[stat_name].append(value)

        for stat, stat_list in batch_update_stats.items():
            self._stats_reporter.add_stat(stat, np.mean(stat_list))

        if self.optimizer.bc_module:
            update_stats = self.optimizer.bc_module.update()
            for stat, val in update_stats.items():
                self._stats_reporter.add_stat(stat, val)

        # PPO is on-policy, so the update buffer is cleared after every policy
        # update; the off-policy buffer below is only truncated once it grows past
        # 10x buffer_size.
        self._clear_update_buffer()

        if self.off_policy_buffer.num_experiences > 10 * self.hyperparameters.buffer_size:
            print("truncate")
            self.off_policy_buffer.truncate(
|