|
|
|
|
|
|
) -> tf.Tensor: |
|
|
|
|
|
|
|
if num_layers < 0: |
|
|
|
return self.current_action |
|
|
|
return action |
|
|
|
|
|
|
|
hidden_stream = ModelUtils.create_vector_observation_encoder( |
|
|
|
action, |
|
|
|
|
|
|
scope=f"main_graph", |
|
|
|
reuse=False, |
|
|
|
) |
|
|
|
# hidden_policy = ModelUtils.create_vector_observation_encoder( |
|
|
|
# self.processed_vector_in, |
|
|
|
# h_size, |
|
|
|
# ModelUtils.swish, |
|
|
|
# num_layers, |
|
|
|
# scope=f"main_graph", |
|
|
|
# reuse=False, |
|
|
|
# ) |
|
|
|
distribution = GaussianDistribution( |
|
|
|
hidden_policy, |
|
|
|
self.act_size, |
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
if separate_train: |
|
|
|
encoded_state = tf.stop_gradient(encoded_state) |
|
|
|
encoded_state = tf.stop_gradient(encoded_state) |
|
|
|
# encoded_action = tf.stop_gradient(encoded_action) |
|
|
|
|
|
|
|
combined_input = tf.concat([encoded_state, encoded_action], axis=1) |
|
|
|
hidden = combined_input |
|
|
|
|
|
|
|
|
|
|
def create_forward_loss(self, reuse: bool, transfer: bool): |
|
|
|
|
|
|
|
if not transfer: |
|
|
|
if reuse: |
|
|
|
encoded_next_state = tf.stop_gradient(self.next_encoder) |
|
|
|
else: |
|
|
|
encoded_next_state = self.next_targ_encoder # gradient of target encode is already stopped |
|
|
|
# if not transfer: |
|
|
|
if reuse: |
|
|
|
encoded_next_state = tf.stop_gradient(self.next_encoder) |
|
|
|
else: |
|
|
|
encoded_next_state = self.next_targ_encoder # gradient of target encode is already stopped |
|
|
|
squared_difference = 0.5 * tf.reduce_sum( |
|
|
|
tf.squared_difference(tf.tanh(self.predict), encoded_next_state), axis=1 |
|
|
|
) |
|
|
|
self.forward_loss = tf.reduce_mean( |
|
|
|
tf.dynamic_partition(squared_difference, self.mask, 2)[1] |
|
|
|
) |
|
|
|
squared_difference = 0.5 * tf.reduce_sum( |
|
|
|
tf.squared_difference(tf.tanh(self.predict), encoded_next_state), axis=1 |
|
|
|
) |
|
|
|
self.forward_loss = tf.reduce_mean( |
|
|
|
tf.dynamic_partition(squared_difference, self.mask, 2)[1] |
|
|
|
) |
|
|
|
else: |
|
|
|
if reuse: |
|
|
|
squared_difference_1 = 0.5 * tf.reduce_sum( |
|
|
|
tf.squared_difference(tf.tanh(self.predict), tf.stop_gradient(self.next_encoder)), |
|
|
|
axis=1 |
|
|
|
) |
|
|
|
squared_difference_2 = 0.5 * tf.reduce_sum( |
|
|
|
tf.squared_difference(tf.tanh(tf.stop_gradient(self.predict)), self.next_encoder), |
|
|
|
axis=1 |
|
|
|
) |
|
|
|
else: |
|
|
|
squared_difference_1 = 0.5 * tf.reduce_sum( |
|
|
|
tf.squared_difference(tf.tanh(self.predict), self.next_targ_encoder), |
|
|
|
axis=1 |
|
|
|
) |
|
|
|
squared_difference_2 = 0.5 * tf.reduce_sum( |
|
|
|
tf.squared_difference(tf.tanh(self.targ_predict), self.next_encoder), |
|
|
|
axis=1 |
|
|
|
) |
|
|
|
self.forward_loss = tf.reduce_mean( |
|
|
|
tf.dynamic_partition(0.5 * squared_difference_1 + 0.5 * squared_difference_2, self.mask, 2)[1] |
|
|
|
) |
|
|
|
# else: |
|
|
|
# if reuse: |
|
|
|
# squared_difference_1 = 0.5 * tf.reduce_sum( |
|
|
|
# tf.squared_difference(tf.tanh(self.predict), tf.stop_gradient(self.next_encoder)), |
|
|
|
# axis=1 |
|
|
|
# ) |
|
|
|
# squared_difference_2 = 0.5 * tf.reduce_sum( |
|
|
|
# tf.squared_difference(tf.tanh(tf.stop_gradient(self.predict)), self.next_encoder), |
|
|
|
# axis=1 |
|
|
|
# ) |
|
|
|
# else: |
|
|
|
# squared_difference_1 = 0.5 * tf.reduce_sum( |
|
|
|
# tf.squared_difference(tf.tanh(self.predict), self.next_targ_encoder), |
|
|
|
# axis=1 |
|
|
|
# ) |
|
|
|
# squared_difference_2 = 0.5 * tf.reduce_sum( |
|
|
|
# tf.squared_difference(tf.tanh(self.targ_predict), self.next_encoder), |
|
|
|
# axis=1 |
|
|
|
# ) |
|
|
|
# self.forward_loss = tf.reduce_mean( |
|
|
|
# tf.dynamic_partition(0.5 * squared_difference_1 + 0.5 * squared_difference_2, self.mask, 2)[1] |
|
|
|
# ) |
|
|
|
|
|
|
|
|
|
|
|
def create_reward_model( |
|
|
|
|
|
|
): |
|
|
|
|
|
|
|
if separate_train: |
|
|
|
encoded_state = tf.stop_gradient(encoded_state) |
|
|
|
encoded_state = tf.stop_gradient(encoded_state) |
|
|
|
# encoded_action = tf.stop_gradient(encoded_action) |
|
|
|
|
|
|
|
combined_input = tf.concat([encoded_state, encoded_action], axis=1) |
|
|
|
|
|
|
|