浏览代码

[🐛🔨 ] Fix sac target for continuous actions (#5372)

* Fix of the target entropy for continuous SAC

* Lowering required steps of test and remove unecessary unsqueeze

* Changing the target from -dim(a)^2 to -dim(a) by removing implicit broadcasting
/colab-links
GitHub 4 年前
当前提交
fc6e8c35
共有 2 个文件被更改,包括 5 次插入5 次删除
  1. 8
      ml-agents/mlagents/trainers/sac/optimizer_torch.py
  2. 2
      ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py

8
ml-agents/mlagents/trainers/sac/optimizer_torch.py


all_mean_q1 = mean_q1
if self._action_spec.continuous_size > 0:
cont_log_probs = log_probs.continuous_tensor
batch_policy_loss += torch.mean(
_cont_ent_coef * cont_log_probs - all_mean_q1.unsqueeze(1), dim=1
batch_policy_loss += (
_cont_ent_coef * torch.sum(cont_log_probs, dim=1) - all_mean_q1
)
policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)

if self._action_spec.continuous_size > 0:
with torch.no_grad():
cont_log_probs = log_probs.continuous_tensor
target_current_diff = torch.sum(
cont_log_probs + self.target_entropy.continuous, dim=1
target_current_diff = (
torch.sum(cont_log_probs, dim=1) + self.target_entropy.continuous
)
# We update all the _cont_ent_coef as one block
entropy_loss += -1 * ModelUtils.masked_mean(

2
ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py


SAC_TORCH_CONFIG.hyperparameters, buffer_init_steps=2000
)
config = attr.evolve(
SAC_TORCH_CONFIG, hyperparameters=new_hyperparams, max_steps=6000
SAC_TORCH_CONFIG, hyperparameters=new_hyperparams, max_steps=3000
)
check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.8)

正在加载...
取消
保存