
add mede opt with format

/exp-mede
Andrew Cohen, 5 years ago
Commit 88153b61
1 file changed, 26 insertions(+), 13 deletions(-)

ml-agents/mlagents/trainers/sac/mede_optimizer.py (39 changes)


    if self.policy.use_continuous_act:
        self.discp = ModelUtils.create_discriminator(
-           obs, self.num_diverse, action_input=self.policy.output
+           obs,
+           self.num_diverse,
+           action_input=self.policy.output,
        )
    # The optimizer's m_size is 3 times the policy (Q1, Q2, and Value)
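For orientation, here is a rough sketch of what a discriminator factory like ModelUtils.create_discriminator plausibly builds: an MLP over the observation (optionally concatenated with the action, as action_input above) that emits one logit per skill. Layer sizes and names below are assumptions, not the actual ml-agents implementation.

import tensorflow as tf

def create_discriminator(obs, num_diverse, action_input=None):
    # Hypothetical stand-in: a small MLP that predicts which of the
    # num_diverse skills produced this (obs, action) pair.
    disc_input = obs if action_input is None else tf.concat([obs, action_input], axis=1)
    with tf.variable_scope("discriminator"):
        hidden = tf.layers.dense(disc_input, 128, activation=tf.nn.relu)
        logits = tf.layers.dense(hidden, num_diverse)  # one logit per skill
    return logits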

"update_disc": self.update_batch_disc,
"learning_rate": self.learning_rate,
}
-       return tf.split(observation_and_skill, [self.policy.vec_obs_size - self.num_diverse, self.num_diverse], 1)
+       return tf.split(
+           observation_and_skill,
+           [self.policy.vec_obs_size - self.num_diverse, self.num_diverse],
+           1,
+       )
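The split being reformatted here peels the skill one-hot off the end of the stacked vector observation: the size list says the last num_diverse entries of each row are the skill id. A toy TF1-style check (values made up):

import tensorflow as tf

observation_and_skill = tf.constant([[0.1, 0.2, 0.3, 1.0, 0.0]])
# Assuming vec_obs_size = 5 and num_diverse = 2: the first 3 entries are the
# raw observation, the last 2 are the one-hot skill id.
obs, skill = tf.split(observation_and_skill, [5 - 2, 2], 1)
with tf.Session() as sess:
    print(sess.run(obs))    # [[0.1 0.2 0.3]]
    print(sess.run(skill))  # [[1. 0.]]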
    def _create_inputs_and_outputs(self) -> None:
        """

        self.rewards_holders = {}
        self.min_policy_qs = {}
-       #discriminator loss
-       self.disc_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self._z_one_hot, logits=self.disc))
+       # discriminator loss
+       self.disc_loss = tf.reduce_mean(
+           tf.nn.softmax_cross_entropy_with_logits(
+               labels=self._z_one_hot, logits=self.disc
+           )
+       )
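For reference, the loss being reformatted is the standard skill-discriminator objective from diversity-driven RL (a sketch of the math, where q_phi is the discriminator and z the one-hot skill label):

\mathcal{L}_{\mathrm{disc}} = \mathbb{E}_{(s,a,z)}\left[-\log q_\phi(z \mid s, a)\right]

i.e. the cross-entropy between the true skill and the discriminator's prediction.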
-       discriminabilityp = -1 * tf.nn.softmax_cross_entropy_with_logits(labels=self._z_one_hot, logits=self.discp)
-       self.discriminability = -1 * tf.nn.softmax_cross_entropy_with_logits(labels=self._z_one_hot, logits=self.disc)
+       discriminabilityp = -1 * tf.nn.softmax_cross_entropy_with_logits(
+           labels=self._z_one_hot, logits=self.discp
+       )
+       self.discriminability = -1 * tf.nn.softmax_cross_entropy_with_logits(
+           labels=self._z_one_hot, logits=self.disc
+       )
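The "discriminability" terms are just the negative of that same cross-entropy, i.e. the log-probability the discriminator assigns to the true skill, which later enters the policy loss as a diversity bonus. A minimal NumPy sketch of the quantity (function name is illustrative, not from the source):

import numpy as np

def discriminability(z_one_hot, logits):
    # log softmax, then pick out the log-probability of the true skill
    log_probs = logits - np.log(np.sum(np.exp(logits)))
    return np.sum(z_one_hot * log_probs)  # log q(z | s, a)

print(discriminability(np.array([1.0, 0.0]), np.array([2.0, 0.5])))  # ~ -0.20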
        for name in stream_names:
            if discrete:

                ]
            )
            self.policy_loss = tf.reduce_mean(
-               tf.to_float(self.policy.mask) * tf.squeeze(branched_policy_loss) - self.discriminability
+               tf.to_float(self.policy.mask) * tf.squeeze(branched_policy_loss)
+               - self.discriminability
            )
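Reading the reformatted expression: the masked SAC policy loss has the diversity term subtracted, so the policy is additionally rewarded when the discriminator can identify the active skill. Schematically (a sketch, with m the action mask from self.policy.mask):

\mathcal{L}_\pi = \mathbb{E}\left[m \cdot \mathcal{L}_{\mathrm{SAC}} - \log q_\phi(z \mid s, a)\right]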
        # Do vbackup entropy bonus per branch as well.

            )
        )
        batch_policy_loss = tf.reduce_mean(
            self.ent_coef * self.policy.all_log_probs - self.policy_network.q1_p,
            axis=1,
        )
        self.policy_loss = tf.reduce_mean(

        )
    ]
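This continuous-action branch is computing the usual SAC policy objective before the outer mean; schematically (a sketch, with alpha = self.ent_coef and Q_1 the soft Q-value from self.policy_network.q1_p):

\mathcal{L}_\pi = \mathbb{E}\left[\alpha \log \pi(a \mid s) - Q_1(s, a)\right]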
-       discriminator_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="discriminator")
+       discriminator_vars = tf.get_collection(
+           tf.GraphKeys.TRAINABLE_VARIABLES, scope="discriminator"
+       )
        self.update_batch_disc = discriminator_optimizer.minimize(
            self.disc_loss, var_list=discriminator_vars
        )
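Collecting trainable variables by scope and passing them as var_list keeps this train op from touching the policy/value weights. A minimal, self-contained TF1 sketch of the pattern (the setup here is assumed, not the source file):

import tensorflow as tf

with tf.variable_scope("discriminator"):
    w = tf.get_variable("w", shape=[4, 2])  # lives in the "discriminator" scope
loss = tf.reduce_sum(tf.square(w))

# Only variables under "discriminator" are handed to the optimizer, so the
# resulting train op leaves every other trainable variable untouched.
disc_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="discriminator")
update_disc = tf.train.AdamOptimizer(1e-3).minimize(loss, var_list=disc_vars)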
