    )
else:
    log_probs = self.policy_model.all_log_probs
    # Offsets of each discrete action branch inside the concatenated
    # logits, e.g. action_size = [3, 2] gives action_idx = [0, 3, 5].
    action_idx = [0] + list(np.cumsum(action_size))
    # Per-branch entropy via the identity H(p) = CE(p, p): feeding the same
    # logits as both `labels` (after softmax) and `logits` makes the
    # cross-entropy op return each branch's entropy.
    entropy = tf.reduce_sum(
        tf.stack(
            [
                tf.nn.softmax_cross_entropy_with_logits_v2(
                    labels=tf.nn.softmax(
                        log_probs[:, action_idx[i] : action_idx[i + 1]]
                    ),
                    logits=log_probs[:, action_idx[i] : action_idx[i + 1]],
                )
                for i in range(len(action_size))
            ],
            axis=1,
        ),
        axis=1,
    )
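    # Worked note (illustrative, not in the original source): with
    # q = softmax(z), CE(labels=q, logits=z) = -sum_a q_a * log_softmax(z)_a
    # = -sum_a q_a * log(q_a) = H(q), so each stacked term is one branch's
    # entropy. E.g. a two-way branch with logits [0, 0] gives q = [0.5, 0.5]
    # and the op returns log 2 ≈ 0.693 for that branch.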
    # Cross-entropy loss against the expert actions (expert_action is
    # multiplied elementwise with the log-probabilities); the 1e-7 term
    # guards against log(0).
    self.loss = tf.reduce_mean(
        -tf.log(tf.nn.softmax(log_probs) + 1e-7) * self.expert_action
    )
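
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original model): a NumPy re-derivation
# of the two quantities built above, assuming `logits` is a
# [batch, sum(action_size)] array and `expert_action` is one-hot. The helper
# names `np_softmax`, `branch_entropy`, and `bc_loss` are hypothetical, for
# illustration only.
import numpy as np


def np_softmax(x):
    """Numerically stable softmax over the last axis."""
    shifted = x - x.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def branch_entropy(logits):
    """Entropy of softmax(logits) over the last axis -- what each stacked
    softmax_cross_entropy_with_logits_v2 call computes for one branch."""
    p = np_softmax(logits)
    return -(p * np.log(p)).sum(axis=-1)


def bc_loss(logits, expert_action, eps=1e-7):
    """Mean of -log(softmax(logits) + eps) weighted by the one-hot expert
    actions, mirroring the tf.reduce_mean expression above."""
    return np.mean(-np.log(np_softmax(logits) + eps) * expert_action)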