浏览代码

Cherry-pick BC fixes to Release 10 (#4668)

/release_10_branch
GitHub 4 年前
当前提交
f0ed3a38
共有 5 个文件被更改,包括 33 次插入15 次删除
  1. 18
      ml-agents/mlagents/trainers/policy/torch_policy.py
  2. 2
      ml-agents/mlagents/trainers/sac/optimizer_torch.py
  3. 4
      ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py
  4. 12
      ml-agents/mlagents/trainers/tests/torch/test_policy.py
  5. 12
      ml-agents/mlagents/trainers/torch/components/bc/module.py

18
ml-agents/mlagents/trainers/policy/torch_policy.py


memories: Optional[torch.Tensor] = None,
seq_len: int = 1,
all_log_probs: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
:param vec_obs: List of vector observations.
:param vis_obs: List of visual observations.

:param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
:return: Tuple of actions, log probabilities (dependent on all_log_probs), entropies, and
output memories, all as Torch Tensors.
:return: Tuple of actions, actions clipped to -1, 1, log probabilities (dependent on all_log_probs),
entropies, and output memories, all as Torch Tensors.
"""
if memories is None:
dists, memories = self.actor_critic.get_dists(

actions = actions[:, 0, :]
# Use the sum of entropy across actions, not the mean
entropy_sum = torch.sum(entropies, dim=1)
if self._clip_action and self.use_continuous_act:
clipped_action = torch.clamp(actions, -3, 3) / 3
else:
clipped_action = actions
clipped_action,
all_logs if all_log_probs else log_probs,
entropy_sum,
memories,

run_out = {}
with torch.no_grad():
action, log_probs, entropy, memories = self.sample_actions(
action, clipped_action, log_probs, entropy, memories = self.sample_actions(
if self._clip_action and self.use_continuous_act:
clipped_action = torch.clamp(action, -3, 3) / 3
else:
clipped_action = action
run_out["pre_action"] = ModelUtils.to_numpy(action)
run_out["action"] = ModelUtils.to_numpy(clipped_action)
# Todo - make pre_action difference

2
ml-agents/mlagents/trainers/sac/optimizer_torch.py


self.target_network.network_body.copy_normalization(
self.policy.actor_critic.network_body
)
(sampled_actions, log_probs, _, _) = self.policy.sample_actions(
(sampled_actions, _, log_probs, _, _) = self.policy.sample_actions(
vec_obs,
vis_obs,
masks=act_masks,

4
ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py


).unsqueeze(0)
with torch.no_grad():
_, log_probs1, _, _ = policy1.sample_actions(
_, _, log_probs1, _, _ = policy1.sample_actions(
_, log_probs2, _, _ = policy2.sample_actions(
_, _, log_probs2, _, _ = policy2.sample_actions(
vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True
)

12
ml-agents/mlagents/trainers/tests/torch/test_policy.py


if len(memories) > 0:
memories = torch.stack(memories).unsqueeze(0)
(sampled_actions, log_probs, entropies, memories) = policy.sample_actions(
(
sampled_actions,
clipped_actions,
log_probs,
entropies,
memories,
) = policy.sample_actions(
vec_obs,
vis_obs,
masks=act_masks,

)
else:
assert log_probs.shape == (64, policy.behavior_spec.action_spec.continuous_size)
assert clipped_actions.shape == (
64,
policy.behavior_spec.action_spec.continuous_size,
)
assert entropies.shape == (64,)
if rnn:

12
ml-agents/mlagents/trainers/torch/components/bc/module.py


# Don't continue training if the learning rate has reached 0, to reduce training time.
decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
if self.current_lr <= 0:
if self.current_lr <= 1e-10: # Unlike in TF, this never actually reaches 0.
return {"Losses/Pretraining Loss": 0}
batch_losses = []

else:
vis_obs = []
selected_actions, all_log_probs, _, _ = self.policy.sample_actions(
(
selected_actions,
clipped_actions,
all_log_probs,
_,
_,
) = self.policy.sample_actions(
vec_obs,
vis_obs,
masks=act_masks,

)
bc_loss = self._behavioral_cloning_loss(
selected_actions, all_log_probs, expert_actions
clipped_actions, all_log_probs, expert_actions
)
self.optimizer.zero_grad()
bc_loss.backward()

正在加载...
取消
保存