|
|
|
|
|
|
|
|
|
|
# Add some stuff to inference dict from optimizer |
|
|
|
self.policy.inference_dict["learning_rate"] = self.learning_rate |
|
|
|
self.policy.initialize_or_load() |
|
|
|
|
|
|
|
def create_cc_critic( |
|
|
|
self, h_size: int, num_layers: int, vis_encode_type: EncoderType |
|
|
|
|
|
|
"{}_value_estimates".format(name) |
|
|
|
] |
|
|
|
|
|
|
|
if "actions_pre" in mini_batch: |
|
|
|
if self.policy.output_pre is not None and "actions_pre" in mini_batch: |
|
|
|
feed_dict[self.policy.output_pre] = mini_batch["actions_pre"] |
|
|
|
else: |
|
|
|
feed_dict[self.policy.action_holder] = mini_batch["actions"] |
|
|
|
|
|
|
for i, _ in enumerate(self.policy.visual_in): |
|
|
|
feed_dict[self.policy.visual_in[i]] = mini_batch["visual_obs%d" % i] |
|
|
|
if self.policy.use_recurrent: |
|
|
|
feed_dict[self.policy.memory_in] = self._make_zero_mem( |
|
|
|
self.policy.m_size, mini_batch.num_experiences |
|
|
|
) |
|
|
|
feed_dict[self.policy.memory_in] = [ |
|
|
|
mini_batch["memory"][i] |
|
|
|
for i in range( |
|
|
|
0, len(mini_batch["memory"]), self.policy.sequence_length |
|
|
|
) |
|
|
|
] |
|
|
|
feed_dict[self.memory_in] = self._make_zero_mem( |
|
|
|
self.m_size, mini_batch.num_experiences |
|
|
|
) |