
Speed up a bit faster

/develop/tf2.0
Ervin Teng, 5 years ago
Current commit: d983a636
1 changed file with 18 additions and 8 deletions
ml-agents/mlagents/trainers/ppo/policy.py (26 changed lines)


             ),
         )

-    def ppo_value_loss(self, values, old_values, returns):
+    def ppo_value_loss(self, values, old_values, returns, epsilon):
         """
         Creates training-specific Tensorflow ops for PPO models.
         :param probs: Current policy probabilities

         :param max_step: Total number of training steps.
         """
-        decay_epsilon = self.trainer_params["epsilon"]
+        decay_epsilon = epsilon
         value_losses = []
         for name, head in values.items():

         value_loss = tf.reduce_mean(value_losses)
         return value_loss
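
The loop body over the value heads is collapsed in this view. For orientation, here is a minimal sketch of a clipped PPO value loss consistent with the signature above, assuming the standard PPO value-clipping scheme; the function and variable names are illustrative, not the committed code:

import tensorflow as tf

def clipped_value_loss(values, old_values, returns, decay_epsilon):
    # Sketch only: `values`, `old_values`, and `returns` are dicts keyed by
    # value-head name; `decay_epsilon` is the clip range passed in as `epsilon`.
    value_losses = []
    for name, head in values.items():
        # Keep the new value estimate within epsilon of the old estimate.
        clipped_value = old_values[name] + tf.clip_by_value(
            head - old_values[name], -decay_epsilon, decay_epsilon
        )
        v_opt_a = tf.math.squared_difference(returns[name], head)
        v_opt_b = tf.math.squared_difference(returns[name], clipped_value)
        # Pessimistic choice between clipped and unclipped squared error.
        value_losses.append(tf.reduce_mean(tf.maximum(v_opt_a, v_opt_b)))
    return tf.reduce_mean(value_losses)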
+    @tf.function
     def ppo_policy_loss(self, advantages, probs, old_probs, masks, epsilon):
         """
         Creates training-specific Tensorflow ops for PPO models.

         """
         advantage = tf.expand_dims(advantages, -1)
-        decay_epsilon = self.trainer_params["epsilon"]
+        decay_epsilon = epsilon
         r_theta = tf.exp(probs - old_probs)
         p_opt_a = r_theta * advantage
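
The hunk ends here; the remainder of the clipped surrogate objective is unchanged and not shown. For reference, a self-contained sketch of the standard PPO policy loss that the lines above are building (illustrative names, not the committed code):

import tensorflow as tf

def clipped_surrogate_loss(advantages, probs, old_probs, decay_epsilon):
    # probs/old_probs are log-probabilities of the taken actions under the
    # current and old policies; decay_epsilon is the PPO clip range.
    advantage = tf.expand_dims(advantages, -1)
    r_theta = tf.exp(probs - old_probs)  # probability ratio
    p_opt_a = r_theta * advantage        # unclipped objective
    p_opt_b = (
        tf.clip_by_value(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon)
        * advantage
    )
    # Maximize the smaller of the two, i.e. minimize its negation.
    return -tf.reduce_mean(tf.minimum(p_opt_a, p_opt_b))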

             action=run_out.get("action"), value=run_out.get("value"), outputs=run_out
         )

+    @tf.function
+    def update_computation(self, obs, actions):
+        values = self.model.get_values(obs)
+        dist = self.model(obs)
+        probs = dist.log_prob(actions)
+        entropy = dist.entropy()
+        return values, probs, entropy
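
The speedup appears to come from moving the forward pass into a @tf.function-decorated method, and from passing epsilon in as an argument instead of reading self.trainer_params inside the decorated losses, so the computation runs as a traced TensorFlow graph rather than op-by-op in eager mode. A tiny, self-contained illustration of that pattern with a hypothetical stand-in model:

import tensorflow as tf

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(8, activation="relu"), tf.keras.layers.Dense(2)]
)

@tf.function
def forward(obs):
    # Traced into a graph on the first call with a new input signature;
    # later calls with the same dtype/shape reuse the traced graph.
    return model(obs)

obs = tf.random.uniform((32, 4))
forward(obs)  # traces and runs the graph
forward(obs)  # reuses the traced graph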
     @timed
     def update(self, mini_batch, num_sequences):
         """

             old_values[name] = mini_batch["{}_value_estimates".format(name)]
         obs = np.array(mini_batch["vector_obs"])
-        values = self.model.get_values(obs)
-        dist = self.model(obs)
-        probs = dist.log_prob(np.array(mini_batch["actions"]))
-        entropy = dist.entropy()
-        value_loss = self.ppo_value_loss(values, old_values, returns)
+        values, probs, entropy = self.update_computation(
+            obs, np.array(mini_batch["actions"])
+        )
+        value_loss = self.ppo_value_loss(
+            values, old_values, returns, self.trainer_params["epsilon"]
+        )
         policy_loss = self.ppo_policy_loss(
             np.array(mini_batch["advantages"]),
             probs,

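The diff is truncated here. In a typical TF2 PPO update, the policy loss, value loss, and entropy computed above are combined into one scalar loss and applied with a gradient tape; a sketch under that assumption (the coefficients are common defaults, not values taken from this commit):

import tensorflow as tf

def apply_ppo_update(model, optimizer, compute_losses, beta=1e-3):
    # `compute_losses` is a hypothetical callable returning
    # (policy_loss, value_loss, entropy) tensors for the current mini-batch.
    with tf.GradientTape() as tape:
        policy_loss, value_loss, entropy = compute_losses()
        loss = policy_loss + 0.5 * value_loss - beta * tf.reduce_mean(entropy)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss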