
add target net

/develop/coma-noact
Andrew Cohen, 4 years ago
Current commit
d1285626
1 changed file with 10 additions and 13 deletions
ml-agents/mlagents/trainers/ppo/optimizer_torch.py (23 lines changed)


            self.policy.actor_critic.critic, self.policy.actor_critic.target, 1.0
        )

    def ppo_value_loss(
        self,
        values: Dict[str, torch.Tensor],

        """
        value_losses = []
        for name, head in values.items():
-            #old_val_tensor = old_values[name]
-            returns_tensor = returns[name]# + 0.99 * old_val_tensor
-            #clipped_value_estimate = old_val_tensor + torch.clamp(
+            # old_val_tensor = old_values[name]
+            returns_tensor = returns[name]  # + 0.99 * old_val_tensor
+            # clipped_value_estimate = old_val_tensor + torch.clamp(
-            #)
-            #value_loss = (returns_tensor - head) ** 2
+            # )
+            # value_loss = (returns_tensor - head) ** 2
-            #v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
-            #value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
+            # v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
+            # value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
            value_loss = ModelUtils.masked_mean(v_opt_a, loss_masks)
            value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
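
For reference, the commented-out lines above are the standard PPO clipped value objective; the code left active on this branch keeps only the unclipped squared error, averaged over the loss mask. A minimal standalone sketch of the two variants (hypothetical helper names, plain tensors instead of ml-agents' ModelUtils and dict-of-heads plumbing):

import torch

def clipped_value_loss(values, old_values, returns, epsilon, masks):
    # Standard PPO value clipping: take the worse (larger) of the raw and
    # clipped squared errors so the value head cannot move too far per update.
    masks = masks.float()  # boolean mask -> {0.0, 1.0}
    clipped = old_values + torch.clamp(values - old_values, -epsilon, epsilon)
    v_opt_a = (returns - values) ** 2
    v_opt_b = (returns - clipped) ** 2
    per_step = torch.max(v_opt_a, v_opt_b)
    return (per_step * masks).sum() / masks.sum()

def unclipped_value_loss(values, returns, masks):
    # What this branch leaves active: a masked mean of the plain squared error.
    masks = masks.float()
    per_step = (returns - values) ** 2
    return (per_step * masks).sum() / masks.sum()

The two coincide whenever the new estimate stays within epsilon of the old one; they only differ once the value head tries to move further than epsilon in a single update.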

            returns_q[name] = ModelUtils.list_to_tensor(batch[f"{name}_returns_q"])
            returns_b[name] = ModelUtils.list_to_tensor(batch[f"{name}_returns_b"])
            returns_v[name] = ModelUtils.list_to_tensor(batch[f"{name}_returns_v"])
        #
        #
        n_obs = len(self.policy.behavior_spec.sensor_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)

        old_log_probs = ActionLogProbs.from_dict(batch).flatten()
        log_probs = log_probs.flatten()
        loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
-        q_loss = self.ppo_value_loss(
-            qs, old_values, returns_q, decay_eps, loss_masks
-        )
+        q_loss = self.ppo_value_loss(qs, old_values, returns_q, decay_eps, loss_masks)
        baseline_loss = self.ppo_value_loss(
            baseline_vals, old_marg_values, returns_b, decay_eps, loss_masks
        )

        self.optimizer.step()
        ModelUtils.soft_update(
-            self.policy.actor_critic.critic, self.policy.actor_critic.target, 1.0
+            self.policy.actor_critic.critic, self.policy.actor_critic.target, 0.005
        )
        update_stats = {
            # NOTE: abs() is not technically correct, but matches the behavior in TensorFlow.
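
The substantive change in this commit is the target-network update coefficient: the tau = 1.0 call near the top of the diff initializes the target as an exact copy of the critic, while the per-update call now uses tau = 0.005 so the target trails the critic smoothly instead of being overwritten after every optimizer step. A sketch of what such a soft (Polyak) update typically does, as a standalone stand-in rather than ml-agents' actual ModelUtils.soft_update:

import torch

@torch.no_grad()
def soft_update(source: torch.nn.Module, target: torch.nn.Module, tau: float) -> None:
    # Polyak averaging: target <- tau * source + (1 - tau) * target.
    # tau = 1.0 is a hard copy (used once at construction time);
    # a small tau such as 0.005 makes the target lag the critic slowly,
    # which stabilizes the bootstrapped value targets.
    for src_param, tgt_param in zip(source.parameters(), target.parameters()):
        tgt_param.data.mul_(1.0 - tau).add_(src_param.data, alpha=tau)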
