|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
policy: TFPolicy, |
|
|
|
policy_model: LearningModel, |
|
|
|
strength: float, |
|
|
|
gamma: float, |
|
|
|
encoding_size: int = 128, |
|
|
|
|
|
|
:param encoding_size: The size of the hidden encoding layer for the ICM |
|
|
|
:param learning_rate: The learning rate for the ICM. |
|
|
|
""" |
|
|
|
super().__init__(policy, policy_model, strength, gamma) |
|
|
|
super().__init__(policy, strength, gamma) |
|
|
|
policy_model, encoding_size=encoding_size, learning_rate=learning_rate |
|
|
|
policy, encoding_size=encoding_size, learning_rate=learning_rate |
|
|
|
) |
|
|
|
self.use_terminal_states = False |
|
|
|
self.update_dict = { |
|
|
|
|
|
|
|
|
|
|
def prepare_update( |
|
|
|
self, |
|
|
|
policy_model: LearningModel, |
|
|
|
policy: TFPolicy, |
|
|
|
mini_batch: Dict[str, np.ndarray], |
|
|
|
num_sequences: int, |
|
|
|
) -> Dict[tf.Tensor, Any]: |
|
|
|
|
|
|
:return: Feed_dict needed for update. |
|
|
|
""" |
|
|
|
feed_dict = { |
|
|
|
policy_model.batch_size: num_sequences, |
|
|
|
policy_model.sequence_length: self.policy.sequence_length, |
|
|
|
policy_model.mask_input: mini_batch["masks"], |
|
|
|
policy.batch_size: num_sequences, |
|
|
|
policy.sequence_length: self.policy.sequence_length, |
|
|
|
policy.mask_input: mini_batch["masks"], |
|
|
|
feed_dict[policy_model.selected_actions] = mini_batch["actions"] |
|
|
|
feed_dict[policy.selected_actions] = mini_batch["actions"] |
|
|
|
feed_dict[policy_model.action_holder] = mini_batch["actions"] |
|
|
|
feed_dict[policy.action_holder] = mini_batch["actions"] |
|
|
|
feed_dict[policy_model.vector_in] = mini_batch["vector_obs"] |
|
|
|
feed_dict[policy.vector_in] = mini_batch["vector_obs"] |
|
|
|
if policy_model.vis_obs_size > 0: |
|
|
|
for i, vis_in in enumerate(policy_model.visual_in): |
|
|
|
if policy.vis_obs_size > 0: |
|
|
|
for i, vis_in in enumerate(policy.visual_in): |
|
|
|
feed_dict[vis_in] = mini_batch["visual_obs%d" % i] |
|
|
|
for i, next_vis_in in enumerate(self.model.next_visual_in): |
|
|
|
feed_dict[next_vis_in] = mini_batch["next_visual_obs%d" % i] |
|
|
|