
[feature] Hybrid SAC (#4574)

/develop/actionmodel-csharp
GitHub · 4 years ago
Commit 12e1fc28
3 files changed: 194 insertions, 140 deletions
1. ml-agents/mlagents/trainers/sac/optimizer_torch.py (313 lines changed)
2. ml-agents/mlagents/trainers/tests/torch/test_hybrid.py (15 lines changed)
3. ml-agents/mlagents/trainers/torch/networks.py (6 lines changed)

ml-agents/mlagents/trainers/sac/optimizer_torch.py (313 lines changed)


import numpy as np
from typing import Dict, List, Mapping, cast, Tuple, Optional
from typing import Dict, List, Mapping, NamedTuple, cast, Tuple, Optional
from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer
from mlagents.trainers.policy.torch_policy import TorchPolicy
from mlagents.trainers.settings import NetworkSettings

from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.buffer import AgentBuffer
from mlagents_envs.timers import timed
from mlagents_envs.base_env import ActionSpec
from mlagents.trainers.exception import UnityTrainerException
from mlagents.trainers.settings import TrainerSettings, SACSettings
from contextlib import ExitStack

action_spec: ActionSpec,
):
super().__init__()
self.action_spec = action_spec
if self.action_spec.is_continuous():
self.act_size = self.action_spec.continuous_size
num_value_outs = 1
num_action_ins = self.act_size
num_value_outs = max(sum(action_spec.discrete_branches), 1)
num_action_ins = int(action_spec.continuous_size)
else:
self.act_size = self.action_spec.discrete_branches
num_value_outs = sum(self.act_size)
num_action_ins = 0
self.q1_network = ValueNetwork(
stream_names,
observation_shapes,

)
return q1_out, q2_out
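For a hybrid action space, the Q networks take the continuous actions as extra inputs and produce one output per discrete action option. Below is a minimal sketch of how those sizes are derived from an action spec; it uses a stand-in ActionSpec-like tuple rather than the real mlagents_envs class.

from typing import NamedTuple, Tuple

class FakeActionSpec(NamedTuple):
    # Stand-in for mlagents_envs.base_env.ActionSpec, for illustration only.
    continuous_size: int
    discrete_branches: Tuple[int, ...]

def q_network_sizes(action_spec: FakeActionSpec) -> Tuple[int, int]:
    # Continuous actions are concatenated onto the Q-network input.
    num_action_ins = int(action_spec.continuous_size)
    # One Q output per discrete option across all branches; at least one output
    # so that a purely continuous spec still gets a single Q head.
    num_value_outs = max(sum(action_spec.discrete_branches), 1)
    return num_action_ins, num_value_outs

# e.g. 2 continuous dimensions plus discrete branches of size 3 and 2:
print(q_network_sizes(FakeActionSpec(2, (3, 2))))  # (2, 5)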
class TargetEntropy(NamedTuple):
discrete: List[float] = [] # One per branch
continuous: float = 0.0
class LogEntCoef(nn.Module):
def __init__(self, discrete, continuous):
super().__init__()
self.discrete = discrete
self.continuous = continuous
def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
super().__init__(policy, trainer_params)
hyperparameters: SACSettings = cast(SACSettings, trainer_params.hyperparameters)

self.policy = policy
self.act_size = policy.act_size
policy_network_settings = policy.network_settings
self.tau = hyperparameters.tau

name: int(not self.reward_signals[name].ignore_done)
for name in self.stream_names
}
self._action_spec = self.policy.behavior_spec.action_spec
self.policy.behavior_spec.action_spec,
self._action_spec,
)
self.target_network = ValueNetwork(

self.policy.actor_critic.critic, self.target_network, 1.0
)
self._log_ent_coef = torch.nn.Parameter(
torch.log(torch.as_tensor([self.init_entcoef] * len(self.act_size))),
# We create one entropy coefficient per action, whether discrete or continuous.
_disc_log_ent_coef = torch.nn.Parameter(
torch.log(
torch.as_tensor(
[self.init_entcoef] * len(self._action_spec.discrete_branches)
)
),
requires_grad=True,
)
_cont_log_ent_coef = torch.nn.Parameter(
torch.log(
torch.as_tensor([self.init_entcoef] * self._action_spec.continuous_size)
),
if self.policy.use_continuous_act:
self.target_entropy = torch.as_tensor(
-1
* self.continuous_target_entropy_scale
* np.prod(self.act_size[0]).astype(np.float32)
)
else:
self.target_entropy = [
self.discrete_target_entropy_scale * np.log(i).astype(np.float32)
for i in self.act_size
]
self._log_ent_coef = TorchSACOptimizer.LogEntCoef(
discrete=_disc_log_ent_coef, continuous=_cont_log_ent_coef
)
_cont_target = (
-1
* self.continuous_target_entropy_scale
* np.prod(self._action_spec.continuous_size).astype(np.float32)
)
_disc_target = [
self.discrete_target_entropy_scale * np.log(i).astype(np.float32)
for i in self._action_spec.discrete_branches
]
self.target_entropy = TorchSACOptimizer.TargetEntropy(
continuous=_cont_target, discrete=_disc_target
)
policy_params = list(self.policy.actor_critic.network_body.parameters()) + list(
self.policy.actor_critic.action_model.parameters()
)

value_params, lr=hyperparameters.learning_rate
)
self.entropy_optimizer = torch.optim.Adam(
[self._log_ent_coef], lr=hyperparameters.learning_rate
self._log_ent_coef.parameters(), lr=hyperparameters.learning_rate
)
self._move_to_device(default_device())
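Because LogEntCoef is an nn.Module, both log-coefficient parameter blocks are exposed through .parameters(), so a single Adam optimizer updates the discrete and continuous entropy coefficients together. A minimal sketch of that pattern (sizes and learning rate are illustrative):

import torch
from torch import nn

class LogEntCoef(nn.Module):
    def __init__(self, discrete: nn.Parameter, continuous: nn.Parameter):
        super().__init__()
        self.discrete = discrete
        self.continuous = continuous

# One log-coefficient per discrete branch and per continuous dimension.
disc = nn.Parameter(torch.log(torch.ones(2)), requires_grad=True)   # 2 branches
cont = nn.Parameter(torch.log(torch.ones(3)), requires_grad=True)   # 3 cont. dims
log_ent_coef = LogEntCoef(discrete=disc, continuous=cont)
entropy_optimizer = torch.optim.Adam(log_ent_coef.parameters(), lr=3e-4)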

q1p_out: Dict[str, torch.Tensor],
q2p_out: Dict[str, torch.Tensor],
loss_masks: torch.Tensor,
discrete: bool,
_ent_coef = torch.exp(self._log_ent_coef)
for name in values.keys():
if not discrete:
min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
else:
action_probs = log_probs.all_discrete_tensor.exp()
_branched_q1p = ModelUtils.break_into_branches(
q1p_out[name] * action_probs, self.act_size
)
_branched_q2p = ModelUtils.break_into_branches(
q2p_out[name] * action_probs, self.act_size
)
_q1p_mean = torch.mean(
torch.stack(
[
torch.sum(_br, dim=1, keepdim=True)
for _br in _branched_q1p
]
),
dim=0,
)
_q2p_mean = torch.mean(
torch.stack(
[
torch.sum(_br, dim=1, keepdim=True)
for _br in _branched_q2p
]
),
dim=0,
)
_cont_ent_coef = self._log_ent_coef.continuous.exp()
_disc_ent_coef = self._log_ent_coef.discrete.exp()
for name in values.keys():
if self._action_spec.discrete_size <= 0:
min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
else:
disc_action_probs = log_probs.all_discrete_tensor.exp()
_branched_q1p = ModelUtils.break_into_branches(
q1p_out[name] * disc_action_probs,
self._action_spec.discrete_branches,
)
_branched_q2p = ModelUtils.break_into_branches(
q2p_out[name] * disc_action_probs,
self._action_spec.discrete_branches,
)
_q1p_mean = torch.mean(
torch.stack(
[torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q1p]
),
dim=0,
)
_q2p_mean = torch.mean(
torch.stack(
[torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q2p]
),
dim=0,
)
min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)
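The branched computation above takes an expectation of Q under the discrete policy: per-option Q values are weighted by the discrete action probabilities, summed within each branch, and averaged across branches. A self-contained sketch of that reduction (branch sizes and tensors are illustrative):

import torch

def branched_expected_q(q_out, disc_action_probs, branch_sizes):
    # q_out, disc_action_probs: [batch, sum(branch_sizes)]
    weighted = q_out * disc_action_probs
    branches = torch.split(weighted, list(branch_sizes), dim=1)
    per_branch = [torch.sum(b, dim=1, keepdim=True) for b in branches]
    return torch.mean(torch.stack(per_branch), dim=0)  # [batch, 1]

q = torch.randn(4, 5)                               # branches (3, 2)
probs = torch.softmax(torch.randn(4, 5), dim=1)     # illustrative probabilities
print(branched_expected_q(q, probs, (3, 2)).shape)  # torch.Size([4, 1])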
if not discrete:
if self._action_spec.discrete_size <= 0:
_ent_coef * log_probs.continuous_tensor, dim=1
_cont_ent_coef * log_probs.continuous_tensor, dim=1
)
value_loss = 0.5 * ModelUtils.masked_mean(
torch.nn.functional.mse_loss(values[name], v_backup), loss_masks

disc_log_probs = log_probs.all_discrete_tensor
log_probs.all_discrete_tensor * log_probs.all_discrete_tensor.exp(),
self.act_size,
disc_log_probs * disc_log_probs.exp(),
self._action_spec.discrete_branches,
torch.sum(_ent_coef[i] * _lp, dim=1, keepdim=True)
torch.sum(_disc_ent_coef[i] * _lp, dim=1, keepdim=True)
for i, _lp in enumerate(branched_per_action_ent)
]
)

branched_ent_bonus, axis=0
)
# Add continuous entropy bonus to minimum Q
if self._action_spec.continuous_size > 0:
torch.sum(
_cont_ent_coef * log_probs.continuous_tensor,
dim=1,
keepdim=True,
)
value_loss = 0.5 * ModelUtils.masked_mean(
torch.nn.functional.mse_loss(values[name], v_backup.squeeze()),
loss_masks,
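ModelUtils.masked_mean, used for every loss term here, averages only over the active (non-padded) timesteps selected by loss_masks. A minimal sketch of the idea, not the exact ml-agents implementation:

import torch

def masked_mean(tensor, masks):
    # Average tensor over entries where the mask is True.
    masks = masks.float()
    return (tensor.float() * masks).sum() / torch.clamp(masks.sum(), min=1.0)

vals = torch.tensor([1.0, 2.0, 3.0, 4.0])
mask = torch.tensor([True, True, False, False])
print(masked_mean(vals, mask))  # tensor(1.5000)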

log_probs: ActionLogProbs,
q1p_outs: Dict[str, torch.Tensor],
loss_masks: torch.Tensor,
discrete: bool,
_ent_coef = torch.exp(self._log_ent_coef)
_cont_ent_coef, _disc_ent_coef = (
self._log_ent_coef.continuous,
self._log_ent_coef.discrete,
)
_cont_ent_coef = _cont_ent_coef.exp()
_disc_ent_coef = _disc_ent_coef.exp()
if not discrete:
mean_q1 = mean_q1.unsqueeze(1)
batch_policy_loss = torch.mean(
_ent_coef * log_probs.continuous_tensor - mean_q1, dim=1
)
policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
else:
action_probs = log_probs.all_discrete_tensor.exp()
batch_policy_loss = 0
if self._action_spec.discrete_size > 0:
disc_log_probs = log_probs.all_discrete_tensor
disc_action_probs = disc_log_probs.exp()
log_probs.all_discrete_tensor * action_probs, self.act_size
disc_log_probs * disc_action_probs, self._action_spec.discrete_branches
mean_q1 * action_probs, self.act_size
mean_q1 * disc_action_probs, self._action_spec.discrete_branches
torch.sum(_ent_coef[i] * _lp - _qt, dim=1, keepdim=True)
torch.sum(_disc_ent_coef[i] * _lp - _qt, dim=1, keepdim=False)
for i, (_lp, _qt) in enumerate(
zip(branched_per_action_ent, branched_q_term)
)

batch_policy_loss = torch.squeeze(branched_policy_loss)
policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
batch_policy_loss += torch.sum(branched_policy_loss, dim=1)
all_mean_q1 = torch.sum(disc_action_probs * mean_q1, dim=1)
else:
all_mean_q1 = mean_q1
if self._action_spec.continuous_size > 0:
cont_log_probs = log_probs.continuous_tensor
batch_policy_loss += torch.mean(
_cont_ent_coef * cont_log_probs - all_mean_q1.unsqueeze(1), dim=1
)
policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
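Putting the pieces of the policy loss together: the discrete term is the expectation over branch action probabilities of (alpha_d * log pi - Q), the continuous term uses the Q value marginalized over the discrete probabilities (all_mean_q1 above), and either term is skipped when that action type is absent. A rough sketch under those assumptions, not the exact ml-agents function:

import torch

def hybrid_policy_loss(cont_ent_coef, disc_ent_coef, cont_log_probs,
                       disc_log_probs, disc_action_probs, mean_q1, branch_sizes):
    batch_loss = torch.zeros(mean_q1.shape[0])
    if disc_log_probs is not None:
        ent = torch.split(disc_log_probs * disc_action_probs, list(branch_sizes), dim=1)
        qs = torch.split(mean_q1 * disc_action_probs, list(branch_sizes), dim=1)
        per_branch = torch.stack(
            [torch.sum(disc_ent_coef[i] * e - q, dim=1)
             for i, (e, q) in enumerate(zip(ent, qs))],
            dim=1,
        )
        batch_loss = batch_loss + torch.sum(per_branch, dim=1)
        all_mean_q1 = torch.sum(disc_action_probs * mean_q1, dim=1)
    else:
        all_mean_q1 = mean_q1
    if cont_log_probs is not None:
        batch_loss = batch_loss + torch.mean(
            cont_ent_coef * cont_log_probs - all_mean_q1.unsqueeze(1), dim=1
        )
    return batch_loss  # masked_mean over loss_masks is applied afterwards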
self, log_probs: ActionLogProbs, loss_masks: torch.Tensor, discrete: bool
self, log_probs: ActionLogProbs, loss_masks: torch.Tensor
if not discrete:
with torch.no_grad():
target_current_diff = torch.sum(
log_probs.continuous_tensor + self.target_entropy, dim=1
)
entropy_loss = -1 * ModelUtils.masked_mean(
self._log_ent_coef * target_current_diff, loss_masks
)
else:
_cont_ent_coef, _disc_ent_coef = (
self._log_ent_coef.continuous,
self._log_ent_coef.discrete,
)
entropy_loss = 0
if self._action_spec.discrete_size > 0:
# Break continuous into separate branch
disc_log_probs = log_probs.all_discrete_tensor
log_probs.all_discrete_tensor * log_probs.all_discrete_tensor.exp(),
self.act_size,
disc_log_probs * disc_log_probs.exp(),
self._action_spec.discrete_branches,
branched_per_action_ent, self.target_entropy
branched_per_action_ent, self.target_entropy.discrete
)
],
axis=1,

)
entropy_loss = -1 * ModelUtils.masked_mean(
torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks
entropy_loss += -1 * ModelUtils.masked_mean(
torch.mean(_disc_ent_coef * target_current_diff, axis=1), loss_masks
)
if self._action_spec.continuous_size > 0:
with torch.no_grad():
cont_log_probs = log_probs.continuous_tensor
target_current_diff = torch.sum(
cont_log_probs + self.target_entropy.continuous, dim=1
)
# We update all the _cont_ent_coef as one block
entropy_loss += -1 * ModelUtils.masked_mean(
torch.mean(_cont_ent_coef) * target_current_diff, loss_masks
)
return entropy_loss
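The entropy-coefficient update itself follows the usual SAC dual update: log alpha is trained so that the (stop-gradient) policy entropy tracks its target. A minimal single-coefficient sketch of that step, with illustrative shapes and learning rate:

import torch

log_alpha = torch.nn.Parameter(torch.zeros(1))
alpha_optimizer = torch.optim.Adam([log_alpha], lr=3e-4)

def alpha_loss(log_probs, target_entropy):
    # log_probs: [batch, action_dims]; treated as constants for this update.
    with torch.no_grad():
        diff = torch.sum(log_probs + target_entropy, dim=1)  # [batch]
    return -1 * torch.mean(log_alpha * diff)

loss = alpha_loss(torch.randn(8, 3), target_entropy=-3.0)
alpha_optimizer.zero_grad()
loss.backward()
alpha_optimizer.step()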

) -> Dict[str, torch.Tensor]:
condensed_q_output = {}
onehot_actions = ModelUtils.actions_to_onehot(discrete_actions, self.act_size)
onehot_actions = ModelUtils.actions_to_onehot(
discrete_actions, self._action_spec.discrete_branches
)
branched_q = ModelUtils.break_into_branches(item, self.act_size)
branched_q = ModelUtils.break_into_branches(
item, self._action_spec.discrete_branches
)
only_action_qs = torch.stack(
[
torch.sum(_act * _q, dim=1, keepdim=True)
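_condense_q_streams collapses the per-option Q outputs down to the Q value of the discrete action actually taken in each branch, using a one-hot mask. A self-contained sketch of that operation (not the exact ModelUtils helper signatures):

import torch
import torch.nn.functional as F

def condense_q(q_out, discrete_actions, branch_sizes):
    # q_out: [batch, sum(branch_sizes)]; discrete_actions: [batch, num_branches]
    onehots = [F.one_hot(discrete_actions[:, i], num_classes=b).float()
               for i, b in enumerate(branch_sizes)]
    branched_q = torch.split(q_out, list(branch_sizes), dim=1)
    per_branch = [torch.sum(oh * q, dim=1, keepdim=True)
                  for oh, q in zip(onehots, branched_q)]
    return torch.mean(torch.stack(per_branch, dim=0), dim=0)  # [batch, 1]

q = torch.randn(4, 5)  # branches (3, 2)
acts = torch.stack([torch.randint(0, 3, (4,)), torch.randint(0, 2, (4,))], dim=1)
print(condense_q(q, acts, (3, 2)).shape)  # torch.Size([4, 1])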

value_estimates, _ = self.policy.actor_critic.critic_pass(
vec_obs, vis_obs, memories, sequence_length=self.policy.sequence_length
)
if self.policy.use_continuous_act:
squeezed_actions = actions.continuous_tensor
# Only need grad for q1, as that is used for policy.
q1p_out, q2p_out = self.value_network(
vec_obs,
vis_obs,
sampled_actions.continuous_tensor,
memories=q_memories,
sequence_length=self.policy.sequence_length,
q2_grad=False,
)
q1_out, q2_out = self.value_network(
vec_obs,
vis_obs,
squeezed_actions,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1_stream, q2_stream = q1_out, q2_out
cont_sampled_actions = sampled_actions.continuous_tensor
cont_actions = actions.continuous_tensor
disc_actions = actions.discrete_tensor
q1p_out, q2p_out = self.value_network(
vec_obs,
vis_obs,
cont_sampled_actions,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1_out, q2_out = self.value_network(
vec_obs,
vis_obs,
cont_actions,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
if self._action_spec.discrete_size:
q1_stream = self._condense_q_streams(q1_out, disc_actions)
q2_stream = self._condense_q_streams(q2_out, disc_actions)
# For discrete, you don't need to backprop through the Q for the policy
q1p_out, q2p_out = self.value_network(
vec_obs,
vis_obs,
memories=q_memories,
sequence_length=self.policy.sequence_length,
q1_grad=False,
q2_grad=False,
)
q1_out, q2_out = self.value_network(
vec_obs,
vis_obs,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1_stream = self._condense_q_streams(q1_out, actions.discrete_tensor)
q2_stream = self._condense_q_streams(q2_out, actions.discrete_tensor)
q1_stream, q2_stream = q1_out, q2_out
with torch.no_grad():
target_values, _ = self.target_network(

sequence_length=self.policy.sequence_length,
)
masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
use_discrete = not self.policy.use_continuous_act
dones = ModelUtils.list_to_tensor(batch["done"])
q1_loss, q2_loss = self.sac_q_loss(

log_probs, value_estimates, q1p_out, q2p_out, masks, use_discrete
log_probs, value_estimates, q1p_out, q2p_out, masks
policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks, use_discrete)
entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete)
policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
entropy_loss = self.sac_entropy_loss(log_probs, masks)
total_value_loss = q1_loss + q2_loss + value_loss

"Losses/Value Loss": value_loss.item(),
"Losses/Q1 Loss": q1_loss.item(),
"Losses/Q2 Loss": q2_loss.item(),
"Policy/Entropy Coeff": torch.mean(torch.exp(self._log_ent_coef)).item(),
"Policy/Discrete Entropy Coeff": torch.mean(
torch.exp(self._log_ent_coef.discrete)
).item(),
"Policy/Continuous Entropy Coeff": torch.mean(
torch.exp(self._log_ent_coef.continuous)
).item(),
"Policy/Learning Rate": decay_lr,
}

ml-agents/mlagents/trainers/tests/torch/test_hybrid.py (15 lines changed)


import pytest
from mlagents.trainers.tests.simple_test_envs import (

max_steps=10000,
)
check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.9)
@pytest.mark.parametrize("action_size", [(1, 1), (2, 2)])
def test_hybrid_sac(action_size):
env = SimpleEnvironment([BRAIN_NAME], action_sizes=action_size)
new_hyperparams = attr.evolve(
SAC_TORCH_CONFIG.hyperparameters, buffer_size=50000, batch_size=128
)
config = attr.evolve(
SAC_TORCH_CONFIG, hyperparameters=new_hyperparams, max_steps=3000
)
check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=0.9)
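The test above relies on attr.evolve to derive a tweaked copy of the frozen SAC settings rather than mutating them in place. A sketch of that pattern with a stand-in settings class (Hyperparams is hypothetical; the real object is SAC_TORCH_CONFIG.hyperparameters):

import attr

@attr.s(auto_attribs=True, frozen=True)
class Hyperparams:
    # Hypothetical stand-in for the SAC hyperparameter settings object.
    learning_rate: float = 3e-4
    buffer_size: int = 10000
    batch_size: int = 256

base = Hyperparams()
tuned = attr.evolve(base, buffer_size=50000, batch_size=128)
print(tuned)  # only buffer_size and batch_size differ from base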

ml-agents/mlagents/trainers/torch/networks.py (6 lines changed)


else 0
)
self.visual_processors, self.vector_processors, encoder_input_size = ModelUtils.create_input_processors(
(
self.visual_processors,
self.vector_processors,
encoder_input_size,
) = ModelUtils.create_input_processors(
observation_shapes,
self.h_size,
network_settings.vis_encode_type,
