浏览代码

Right loss function for stability, fix some pypi

/develop/action-slice
Ervin Teng 4 年前
当前提交
a4fcbb63
共有 2 个文件被更改,包括 4 次插入4 次删除
  1. 6
      ml-agents/mlagents/trainers/coma/optimizer_torch.py
  2. 2
      ml-agents/mlagents/trainers/coma/trainer.py

6
ml-agents/mlagents/trainers/coma/optimizer_torch.py


)
loss = (
policy_loss
+ 0.5 * (value_loss + baseline_loss)
+ 0.5 * (value_loss + 0.5 * baseline_loss)
- decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
)

modules.update(reward_provider.get_modules())
return modules
def get_trajectory_value_estimates(
def get_trajectory_and_baseline_value_estimates(
) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Dict[str, float]]:
n_obs = len(self.policy.behavior_spec.observation_specs)

2
ml-agents/mlagents/trainers/coma/trainer.py


value_estimates,
baseline_estimates,
value_next,
) = self.optimizer.get_trajectory_value_estimates(
) = self.optimizer.get_trajectory_and_baseline_value_estimates(
agent_buffer_trajectory,
trajectory.next_obs,
trajectory.next_group_obs,

正在加载...
取消
保存