
Don't run value during policy evaluate, optimized soft update function (#4501)

* Don't run value during inference

* Execute critic with LSTM

* Address comments

* Unformat

* Optimized soft update

* Move soft update to model utils

* Add test for soft update
/MLA-1734-demo-provider
GitHub · 4 years ago
Current commit 4e4ad7b0
7 files changed, 75 insertions and 46 deletions
1. ml-agents/mlagents/trainers/policy/torch_policy.py (37 changes)
2. ml-agents/mlagents/trainers/sac/optimizer_torch.py (27 changes)
3. ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py (4 changes)
4. ml-agents/mlagents/trainers/tests/torch/test_policy.py (10 changes)
5. ml-agents/mlagents/trainers/tests/torch/test_utils.py (17 changes)
6. ml-agents/mlagents/trainers/torch/components/bc/module.py (2 changes)
7. ml-agents/mlagents/trainers/torch/utils.py (24 changes)

ml-agents/mlagents/trainers/policy/torch_policy.py (37 changes)


memories: Optional[torch.Tensor] = None,
seq_len: int = 1,
all_log_probs: bool = False,
- ) -> Tuple[
-     torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
- ]:
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
:param vec_obs: List of vector observations.
:param vis_obs: List of visual observations.
:param masks: Loss masks for RNN, else None.
:param memories: Input memories when using RNN, else None.
:param seq_len: Sequence length when using RNN.
:return: Tuple of actions, log probabilities (dependent on all_log_probs), entropies, and
output memories, all as Torch Tensors.
- dists, value_heads, memories = self.actor_critic.get_dist_and_value(
-     vec_obs, vis_obs, masks, memories, seq_len
- )
+ if memories is None:
+     dists, memories = self.actor_critic.get_dists(
+         vec_obs, vis_obs, masks, memories, seq_len
+     )
+ else:
+     # If we're using LSTM, we need to execute the critic to get the critic memories
+     dists, _, memories = self.actor_critic.get_dist_and_value(
+         vec_obs, vis_obs, masks, memories, seq_len
+     )
action_list = self.actor_critic.sample_action(dists)
log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy(
action_list, dists

else:
actions = actions[:, 0, :]
- return (
-     actions,
-     all_logs if all_log_probs else log_probs,
-     entropies,
-     value_heads,
-     memories,
- )
+ return (actions, all_logs if all_log_probs else log_probs, entropies, memories)
def evaluate_actions(
self,

run_out = {}
with torch.no_grad():
- action, log_probs, entropy, value_heads, memories = self.sample_actions(
+ action, log_probs, entropy, memories = self.sample_actions(
vec_obs, vis_obs, masks=masks, memories=memories
)
run_out["action"] = ModelUtils.to_numpy(action)

run_out["entropy"] = ModelUtils.to_numpy(entropy)
- run_out["value_heads"] = {
-     name: ModelUtils.to_numpy(t) for name, t in value_heads.items()
- }
- run_out["value"] = np.mean(list(run_out["value_heads"].values()), 0)
run_out["learning_rate"] = 0.0
if self.use_recurrent:
run_out["memory_out"] = ModelUtils.to_numpy(memories).squeeze(0)
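In short, sample_actions (and the evaluate path built on it) no longer runs the critic at inference time unless an LSTM is involved, in which case the critic is still stepped so its memories stay current. Below is a minimal, self-contained sketch of that control flow; TinyActorCritic and its methods are hypothetical stand-ins for illustration, not the ml-agents API.

import torch

class TinyActorCritic(torch.nn.Module):
    # Hypothetical stand-in for the ml-agents actor-critic, only to show the control flow.
    def __init__(self, obs_size: int = 4, act_size: int = 2):
        super().__init__()
        self.actor = torch.nn.Linear(obs_size, act_size)
        self.critic = torch.nn.Linear(obs_size, 1)

    def get_dists(self, obs):
        # Actor only: enough for inference when there is no recurrent critic state.
        return torch.distributions.Normal(self.actor(obs), 1.0)

    def get_dist_and_value(self, obs, memories):
        # With an LSTM, the critic must still be executed so its memories stay in sync.
        return torch.distributions.Normal(self.actor(obs), 1.0), self.critic(obs), memories

def sample_actions(actor_critic, obs, memories=None):
    if memories is None:
        dists = actor_critic.get_dists(obs)  # critic skipped entirely
    else:
        dists, _, memories = actor_critic.get_dist_and_value(obs, memories)
    actions = dists.sample()
    # Values are no longer part of the return tuple, mirroring the change above.
    return actions, dists.log_prob(actions), memories

actions, log_probs, _ = sample_actions(TinyActorCritic(), torch.randn(3, 4))
print(actions.shape, log_probs.shape)  # torch.Size([3, 2]), twice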

ml-agents/mlagents/trainers/sac/optimizer_torch.py (27 changes)


self.policy.behavior_spec.observation_shapes,
policy_network_settings,
)
- self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0)
+ ModelUtils.soft_update(
+     self.policy.actor_critic.critic, self.target_network, 1.0
+ )
self._log_ent_coef = torch.nn.Parameter(
torch.log(torch.as_tensor([self.init_entcoef] * len(self.act_size))),

q1_loss = torch.mean(torch.stack(q1_losses))
q2_loss = torch.mean(torch.stack(q2_losses))
return q1_loss, q2_loss
- def soft_update(self, source: nn.Module, target: nn.Module, tau: float) -> None:
-     for source_param, target_param in zip(source.parameters(), target.parameters()):
-         target_param.data.copy_(
-             target_param.data * (1.0 - tau) + source_param.data * tau
-         )
def sac_value_loss(
self,

self.target_network.network_body.copy_normalization(
self.policy.actor_critic.network_body
)
- (
-     sampled_actions,
-     log_probs,
-     entropies,
-     sampled_values,
-     _,
- ) = self.policy.sample_actions(
+ (sampled_actions, log_probs, _, _) = self.policy.sample_actions(
vec_obs,
vis_obs,
masks=act_masks,

)
+ value_estimates, _ = self.policy.actor_critic.critic_pass(
+     vec_obs, vis_obs, memories, sequence_length=self.policy.sequence_length
+ )
if self.policy.use_continuous_act:
squeezed_actions = actions.squeeze(-1)

q1_stream, q2_stream, target_values, dones, rewards, masks
)
value_loss = self.sac_value_loss(
- log_probs, sampled_values, q1p_out, q2p_out, masks, use_discrete
+ log_probs, value_estimates, q1p_out, q2p_out, masks, use_discrete
)
policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks, use_discrete)
entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete)

self.entropy_optimizer.step()
# Update target network
- self.soft_update(self.policy.actor_critic.critic, self.target_network, self.tau)
+ ModelUtils.soft_update(
+     self.policy.actor_critic.critic, self.target_network, self.tau
+ )
update_stats = {
"Losses/Policy Loss": policy_loss.item(),
"Losses/Value Loss": value_loss.item(),

ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py (4 changes)


).unsqueeze(0)
with torch.no_grad():
- _, log_probs1, _, _, _ = policy1.sample_actions(
+ _, log_probs1, _, _ = policy1.sample_actions(
- _, log_probs2, _, _, _ = policy2.sample_actions(
+ _, log_probs2, _, _ = policy2.sample_actions(
vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True
)

ml-agents/mlagents/trainers/tests/torch/test_policy.py (10 changes)


if len(memories) > 0:
memories = torch.stack(memories).unsqueeze(0)
- (
-     sampled_actions,
-     log_probs,
-     entropies,
-     sampled_values,
-     memories,
- ) = policy.sample_actions(
+ (sampled_actions, log_probs, entropies, memories) = policy.sample_actions(
vec_obs,
vis_obs,
masks=act_masks,

else:
assert log_probs.shape == (64, policy.behavior_spec.action_shape)
assert entropies.shape == (64, policy.behavior_spec.action_size)
- for val in sampled_values.values():
-     assert val.shape == (64,)
if rnn:
assert memories.shape == (1, 1, policy.m_size)

ml-agents/mlagents/trainers/tests/torch/test_utils.py (17 changes)


masks = torch.tensor([False, False, True, True, True])
mean = ModelUtils.masked_mean(test_input, masks=masks)
assert mean == 4.0
+ def test_soft_update():
+     class TestModule(torch.nn.Module):
+         def __init__(self, vals):
+             super().__init__()
+             self.parameter = torch.nn.Parameter(torch.ones(5, 5, 5) * vals)
+     tm1 = TestModule(0)
+     tm2 = TestModule(1)
+     tm3 = TestModule(2)
+     ModelUtils.soft_update(tm1, tm3, tau=0.5)
+     assert torch.equal(tm3.parameter, torch.ones(5, 5, 5))
+     ModelUtils.soft_update(tm1, tm2, tau=1.0)
+     assert torch.equal(tm2.parameter, tm1.parameter)
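Sanity check on the expected values in this test: with tau=0.5 each target parameter becomes 0.5 * 0 + 0.5 * 2 = 1, hence torch.ones(5, 5, 5); with tau=1.0 the target parameters become an exact copy of the source's.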

ml-agents/mlagents/trainers/torch/components/bc/module.py (2 changes)


else:
vis_obs = []
- selected_actions, all_log_probs, _, _, _ = self.policy.sample_actions(
+ selected_actions, all_log_probs, _, _ = self.policy.sample_actions(
vec_obs,
vis_obs,
masks=act_masks,

ml-agents/mlagents/trainers/torch/utils.py (24 changes)


return (tensor.T * masks).sum() / torch.clamp(
(torch.ones_like(tensor.T) * masks).float().sum(), min=1.0
)
+ @staticmethod
+ def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
+     """
+     Performs an in-place polyak update of the target module based on the source,
+     by a ratio of tau. Note that source and target modules must have the same
+     parameters, where:
+         target = tau * source + (1-tau) * target
+     :param source: Source module whose parameters will be used.
+     :param target: Target module whose parameters will be updated.
+     :param tau: Percentage of source parameters to use in average. Setting tau to
+         1 will copy the source parameters to the target.
+     """
+     with torch.no_grad():
+         for source_param, target_param in zip(
+             source.parameters(), target.parameters()
+         ):
+             target_param.data.mul_(1.0 - tau)
+             torch.add(
+                 target_param.data,
+                 source_param.data,
+                 alpha=tau,
+                 out=target_param.data,
+             )
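For reference, here is a minimal standalone sketch (plain functions, not the ml-agents classes) contrasting the removed per-optimizer soft_update with the new ModelUtils.soft_update: the old form allocates a temporary tensor for target * (1.0 - tau) + source * tau on every parameter, while the new one scales the target in place and accumulates the scaled source with torch.add(..., out=...) under torch.no_grad(). The SAC optimizer calls it with tau=1.0 at construction (a hard copy into the target network) and with self.tau during updates.

import torch
from torch import nn

def soft_update_old(source: nn.Module, target: nn.Module, tau: float) -> None:
    # Removed version: builds a temporary blended tensor, then copies it into the target.
    for source_param, target_param in zip(source.parameters(), target.parameters()):
        target_param.data.copy_(
            target_param.data * (1.0 - tau) + source_param.data * tau
        )

def soft_update_new(source: nn.Module, target: nn.Module, tau: float) -> None:
    # New version (mirrors ModelUtils.soft_update above): in-place scale-and-accumulate,
    # with no intermediate allocation.
    with torch.no_grad():
        for source_param, target_param in zip(source.parameters(), target.parameters()):
            target_param.data.mul_(1.0 - tau)
            torch.add(target_param.data, source_param.data, alpha=tau, out=target_param.data)

# Both produce the same result.
src, tgt_a, tgt_b = nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)
tgt_b.load_state_dict(tgt_a.state_dict())
soft_update_old(src, tgt_a, tau=0.005)
soft_update_new(src, tgt_b, tau=0.005)
assert all(torch.allclose(a, b) for a, b in zip(tgt_a.parameters(), tgt_b.parameters()))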