
Add hybrid actions to SAC

/develop/hybrid-actions
Ervin Teng · 4 years ago
Commit 8dec4771

5 files changed, 218 insertions and 124 deletions
1. ml-agents/mlagents/trainers/policy/torch_policy.py (4 changes)
2. ml-agents/mlagents/trainers/sac/optimizer_torch.py (153 changes)
3. ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (8 changes)
4. ml-agents/mlagents/trainers/torch/action_models.py (165 changes)
5. ml-agents/mlagents/trainers/torch/networks.py (12 changes)

ml-agents/mlagents/trainers/policy/torch_policy.py (4 changes)


self, decision_requests: DecisionSteps
) -> Tuple[SplitObservations, np.ndarray]:
vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)
#mask = None
# mask = None
mask = torch.ones([len(decision_requests), np.sum(self.discrete_act_size)])
if decision_requests.action_mask is not None:
mask = torch.as_tensor(

:param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
"""
actions, log_probs, entropies, value_heads, memories = self.actor_critic.get_action_stats_and_value(
vec_obs, vis_obs, masks, memories, seq_len
vec_obs, vis_obs, masks, memories, seq_len, all_log_probs
)
return (
actions,
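The hunk above replaces the `mask = None` default with an all-ones mask sized by the total number of discrete actions, and threads an `all_log_probs` flag through to `get_action_stats_and_value`. A minimal sketch of that mask construction, assuming `discrete_act_size` is a list of branch sizes and `num_agents` is `len(decision_requests)`:

```python
import numpy as np
import torch

def default_action_mask(num_agents: int, discrete_act_size) -> torch.Tensor:
    """All-ones mask: one column per discrete action across all branches,
    i.e. every action allowed unless the environment masks it afterwards."""
    return torch.ones([num_agents, int(np.sum(discrete_act_size))])

# e.g. two agents, branches of size 3 and 2 -> shape (2, 5)
print(default_action_mask(2, [3, 2]).shape)
```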

ml-agents/mlagents/trainers/sac/optimizer_torch.py (153 changes)


from mlagents_envs.communicator_objects.space_type_pb2 import continuous
import numpy as np
from typing import Dict, List, Mapping, cast, Tuple, Optional
from mlagents.torch_utils import torch, nn, default_device

stream_names: List[str],
observation_shapes: List[Tuple[int, ...]],
network_settings: NetworkSettings,
act_type: ActionType,
act_size: List[int],
continuous_act_size: List[int],
discrete_act_size: List[int],
if act_type == ActionType.CONTINUOUS:
num_value_outs = 1
num_action_ins = sum(act_size)
else:
num_value_outs = sum(act_size)
num_action_ins = 0
num_value_outs = max(sum(discrete_act_size), 1)
num_action_ins = int(continuous_act_size)
self.q1_network = ValueNetwork(
stream_names,
observation_shapes,

self.init_entcoef = hyperparameters.init_entcoef
self.policy = policy
self.act_size = policy.act_size
policy_network_settings = policy.network_settings
self.tau = hyperparameters.tau

self.stream_names,
self.policy.behavior_spec.observation_shapes,
policy_network_settings,
self.policy.behavior_spec.action_type,
self.act_size,
self.policy.continuous_act_size,
self.policy.discrete_act_size,
)
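With hybrid actions the Q-networks are sized differently than in the pure continuous or pure discrete case: continuous actions are concatenated to the observation as extra inputs, while each discrete action gets its own Q output (with at least one output overall). A hedged sketch of the sizing rule implied by the hunk above, treating `continuous_act_size` as an int as the rest of this branch does:

```python
def q_network_sizes(continuous_act_size: int, discrete_act_size: list):
    # One Q output per discrete action across all branches, or a single head
    # when there are no discrete branches; continuous actions are extra inputs.
    num_value_outs = max(sum(discrete_act_size), 1)
    num_action_ins = int(continuous_act_size)
    return num_value_outs, num_action_ins

assert q_network_sizes(2, [3, 2]) == (5, 2)
assert q_network_sizes(2, []) == (1, 2)  # pure continuous: single Q head
```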
self.target_network = ValueNetwork(

)
self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0)
_total_act_size = 0
if self.policy.continuous_act_size > 0:
_total_act_size += 1
_total_act_size += len(self.policy.discrete_act_size)
torch.log(torch.as_tensor([self.init_entcoef] * len(self.act_size))),
torch.log(torch.as_tensor([self.init_entcoef] * _total_act_size)),
if self.policy.use_continuous_act:
self.target_entropy = torch.as_tensor(
-1
* self.continuous_target_entropy_scale
* np.prod(self.act_size[0]).astype(np.float32)
self.target_entropy = []
if self.policy.continuous_act_size > 0:
self.target_entropy.append(
torch.as_tensor(
-1
* self.continuous_target_entropy_scale
* np.prod(self.policy.continuous_act_size).astype(np.float32)
)
self.target_entropy = [
self.target_entropy += [
for i in self.act_size
for i in self.policy.discrete_act_size
self.policy.actor_critic.distribution.parameters()
self.policy.actor_critic.action_model.parameters()
)
value_params = list(self.value_network.parameters()) + list(
self.policy.actor_critic.critic.parameters()
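The entropy-coefficient setup above allocates one log-alpha entry for the whole continuous block plus one per discrete branch, and builds a matching list of target entropies. A sketch under those assumptions; the scale constants and the epsilon inside the log are illustrative defaults, not values taken from the diff:

```python
import numpy as np
import torch

def init_entropy_terms(
    init_entcoef: float,
    continuous_act_size: int,
    discrete_act_size: list,
    continuous_scale: float = 1.0,   # illustrative default
    discrete_scale: float = 0.2,     # illustrative default
):
    # One alpha for the continuous block (if any) plus one per discrete branch.
    total = (1 if continuous_act_size > 0 else 0) + len(discrete_act_size)
    log_ent_coef = torch.nn.Parameter(
        torch.log(torch.as_tensor([init_entcoef] * total)), requires_grad=True
    )
    target_entropy = []
    if continuous_act_size > 0:
        # Standard SAC heuristic for continuous actions: -dim(A).
        target_entropy.append(
            -1 * continuous_scale * np.prod(continuous_act_size).astype(np.float32)
        )
    # One target per discrete branch: a fraction of its maximum entropy log(n).
    target_entropy += [
        discrete_scale * np.log(n + 1e-6).astype(np.float32) for n in discrete_act_size
    ]
    return log_ent_coef, target_entropy
```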

q1p_out: Dict[str, torch.Tensor],
q2p_out: Dict[str, torch.Tensor],
loss_masks: torch.Tensor,
discrete: bool,
if not discrete:
if len(self.policy.discrete_act_size) <= 0:
action_probs = log_probs.exp()
disc_action_probs = log_probs[:, self.policy.continuous_act_size:].exp()
q1p_out[name] * action_probs, self.act_size
q1p_out[name] * disc_action_probs, self.policy.discrete_act_size
q2p_out[name] * action_probs, self.act_size
q2p_out[name] * disc_action_probs, self.policy.discrete_act_size
)
_q1p_mean = torch.mean(
torch.stack(

min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)
value_losses = []
if not discrete:
if len(self.policy.discrete_act_size) <= 0:
for name in values.keys():
with torch.no_grad():
v_backup = min_policy_qs[name] - torch.sum(

value_losses.append(value_loss)
else:
branched_per_action_ent = ModelUtils.break_into_branches(
log_probs * log_probs.exp(), self.act_size
log_probs * log_probs.exp(), self.policy.discrete_act_size
)
# We have to do entropy bonus per action branch
branched_ent_bonus = torch.stack(

log_probs: torch.Tensor,
q1p_outs: Dict[str, torch.Tensor],
loss_masks: torch.Tensor,
discrete: bool,
if not discrete:
if len(self.policy.discrete_act_size) <= 0:
action_probs = log_probs.exp()
disc_log_probs = log_probs[:, self.policy.continuous_act_size:]
disc_action_probs = disc_log_probs.exp()
log_probs * action_probs, self.act_size
disc_log_probs * disc_action_probs, self.policy.discrete_act_size
mean_q1 * action_probs, self.act_size
mean_q1 * disc_action_probs, self.policy.discrete_act_size
)
branched_policy_loss = torch.stack(
[

return policy_loss
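For the discrete branches, the policy loss above takes the expectation of alpha * log pi minus Q under each branch's action probabilities, using the log-probs sliced off after the continuous columns. A hedged sketch of that term; `torch.split` stands in for `ModelUtils.break_into_branches`, and the per-branch `ent_coef` values are assumed to be the exponentiated discrete log-alphas:

```python
import torch

def discrete_policy_loss(disc_log_probs, mean_q1, ent_coef, branch_sizes, loss_masks):
    # disc_log_probs, mean_q1: (batch, sum(branch_sizes)); loss_masks: (batch,)
    probs = disc_log_probs.exp()
    branched_ent = torch.split(disc_log_probs * probs, branch_sizes, dim=1)
    branched_q = torch.split(mean_q1 * probs, branch_sizes, dim=1)
    per_branch = torch.stack(
        [
            torch.sum(_alpha * _lp - _q, dim=1)  # expectation over the branch's actions
            for _alpha, _lp, _q in zip(ent_coef, branched_ent, branched_q)
        ],
        dim=1,
    )
    # Average over branches, then mask out padded timesteps before the batch mean.
    return torch.mean(torch.mean(per_branch, dim=1) * loss_masks)
```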
def sac_entropy_loss(
self, log_probs: torch.Tensor, loss_masks: torch.Tensor, discrete: bool
self, log_probs: torch.Tensor, loss_masks: torch.Tensor
if not discrete:
if len(self.policy.discrete_act_size) <= 0:
with torch.no_grad():
target_current_diff = torch.sum(log_probs + self.target_entropy, dim=1)
entropy_loss = -torch.mean(

with torch.no_grad():
branched_per_action_ent = ModelUtils.break_into_branches(
log_probs * log_probs.exp(), self.act_size
log_probs * log_probs.exp(), self.policy.discrete_act_size
)
target_current_diff_branched = torch.stack(
[

self, q_output: Dict[str, torch.Tensor], discrete_actions: torch.Tensor
) -> Dict[str, torch.Tensor]:
condensed_q_output = {}
onehot_actions = ModelUtils.actions_to_onehot(discrete_actions, self.act_size)
onehot_actions = ModelUtils.actions_to_onehot(
discrete_actions, self.policy.discrete_act_size
)
branched_q = ModelUtils.break_into_branches(item, self.act_size)
branched_q = ModelUtils.break_into_branches(
item, self.policy.discrete_act_size
)
only_action_qs = torch.stack(
[
torch.sum(_act * _q, dim=1, keepdim=True)
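`_condense_q_streams` turns the branched Q output into a single stream by masking each branch with a one-hot of the action that was actually taken and averaging over branches. A minimal reconstruction of that idea, with `torch.split` and `torch.nn.functional.one_hot` assumed in place of the `ModelUtils` helpers used in the diff:

```python
import torch
import torch.nn.functional as F

def condense_q_stream(q_out, discrete_actions, branch_sizes):
    # q_out: (batch, sum(branch_sizes)); discrete_actions: (batch, num_branches)
    branched_q = torch.split(q_out, branch_sizes, dim=-1)
    onehots = [
        F.one_hot(discrete_actions[:, i].long(), n).float()
        for i, n in enumerate(branch_sizes)
    ]
    only_action_qs = torch.stack(
        [torch.sum(_act * _q, dim=1, keepdim=True) for _act, _q in zip(onehots, branched_q)]
    )
    return torch.mean(only_action_qs, dim=0)  # condensed stream, shape (batch, 1)
```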

vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
next_vec_obs = [ModelUtils.list_to_tensor(batch["next_vector_in"])]
act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
if self.policy.use_continuous_act:
actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1)
else:
actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)
actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1)
memories_list = [
ModelUtils.list_to_tensor(batch["memory"][i])

masks=act_masks,
memories=memories,
seq_len=self.policy.sequence_length,
all_log_probs=not self.policy.use_continuous_act,
all_log_probs=True,
if self.policy.use_continuous_act:
squeezed_actions = actions.squeeze(-1)
q1p_out, q2p_out = self.value_network(
vec_obs,
vis_obs,
sampled_actions,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1_out, q2_out = self.value_network(
vec_obs,
vis_obs,
squeezed_actions,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1_stream, q2_stream = q1_out, q2_out
cont_sampled_actions = sampled_actions[:, : self.policy.continuous_act_size]
# disc_sampled_actions = sampled_actions[:, self.policy.continuous_act_size :]
cont_actions = actions[:, : self.policy.continuous_act_size]
disc_actions = actions[:, self.policy.continuous_act_size :]
q1p_out, q2p_out = self.value_network(
vec_obs,
vis_obs,
cont_sampled_actions,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1_out, q2_out = self.value_network(
vec_obs,
vis_obs,
cont_actions.squeeze(-1),
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
if self.policy.discrete_act_size:
q1_stream = self._condense_q_streams(q1_out, disc_actions.squeeze(-1))
q2_stream = self._condense_q_streams(q2_out, disc_actions.squeeze(-1))
with torch.no_grad():
q1p_out, q2p_out = self.value_network(
vec_obs,
vis_obs,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1_out, q2_out = self.value_network(
vec_obs,
vis_obs,
memories=q_memories,
sequence_length=self.policy.sequence_length,
)
q1_stream = self._condense_q_streams(q1_out, actions)
q2_stream = self._condense_q_streams(q2_out, actions)
q1_stream, q2_stream = q1_out, q2_out
with torch.no_grad():
target_values, _ = self.target_network(

sequence_length=self.policy.sequence_length,
)
masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
use_discrete = not self.policy.use_continuous_act
dones = ModelUtils.list_to_tensor(batch["done"])
q1_loss, q2_loss = self.sac_q_loss(

log_probs, sampled_values, q1p_out, q2p_out, masks, use_discrete
log_probs, sampled_values, q1p_out, q2p_out, masks
policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks, use_discrete)
entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete)
policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
entropy_loss = self.sac_entropy_loss(log_probs, masks)
total_value_loss = q1_loss + q2_loss + value_loss
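Putting the update together for hybrid actions: the continuous slice of the stored (or freshly sampled) actions is fed to the Q networks as an input, and when discrete branches exist the branched Q outputs are condensed with the stored branch indices; otherwise the raw outputs are used directly. A sketch of that control flow, where `q_network` and `condense` stand in for `self.value_network` and `self._condense_q_streams`:

```python
def hybrid_q_streams(q_network, condense, obs, actions, continuous_act_size, discrete_act_size):
    cont_actions = actions[:, :continuous_act_size]   # network input, as in pure-continuous SAC
    disc_actions = actions[:, continuous_act_size:]   # branch indices, used to pick Q values
    q1_out, q2_out = q_network(obs, cont_actions)
    if discrete_act_size:
        q1_stream = condense(q1_out, disc_actions, discrete_act_size)
        q2_stream = condense(q2_out, disc_actions, discrete_act_size)
    else:
        q1_stream, q2_stream = q1_out, q2_out
    return q1_stream, q2_stream
```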

ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py (8 changes)


config = attr.evolve(PPO_CONFIG, hyperparameters=new_hyperparams, max_steps=100000)
_check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=5.0)
def test_2dhybrid_sac():
env = HybridEnvironment([BRAIN_NAME], continuous_action_size=1, discrete_action_size=2, step_size=0.8)
new_hyperparams = attr.evolve(
SAC_CONFIG.hyperparameters, buffer_size=50000, batch_size=128
)
config = attr.evolve(SAC_CONFIG, hyperparameters=new_hyperparams, max_steps=100000)
_check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=5.0)
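The new test exercises one size combination (1 continuous dimension, 2 discrete actions). If broader coverage is wanted, the same fixtures could be parametrized, roughly as sketched below; `HybridEnvironment`, `SAC_CONFIG`, `BRAIN_NAME`, and `_check_environment_trains` are the names already used in this file and are assumed to be importable here:

```python
import attr
import pytest

@pytest.mark.parametrize("continuous_size,discrete_size", [(1, 1), (1, 2), (2, 1)])
def test_hybrid_sac_sizes(continuous_size, discrete_size):
    env = HybridEnvironment(
        [BRAIN_NAME],
        continuous_action_size=continuous_size,
        discrete_action_size=discrete_size,
        step_size=0.8,
    )
    new_hyperparams = attr.evolve(
        SAC_CONFIG.hyperparameters, buffer_size=50000, batch_size=128
    )
    config = attr.evolve(SAC_CONFIG, hyperparameters=new_hyperparams, max_steps=100000)
    _check_environment_trains(env, {BRAIN_NAME: config}, success_threshold=5.0)
```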
#
#@pytest.mark.parametrize("use_discrete", [True, False])
#def test_2d_ppo(use_discrete):

ml-agents/mlagents/trainers/torch/action_models.py (165 changes)


import numpy as np
import math
from mlagents.trainers.torch.layers import linear_layer, Initialization
from mlagents.trainers.torch.distributions import DistInstance, DiscreteDistInstance, GaussianDistribution, MultiCategoricalDistribution
from mlagents.trainers.torch.distributions import (
DistInstance,
DiscreteDistInstance,
GaussianDistribution,
MultiCategoricalDistribution,
)
from mlagents.trainers.torch.utils import ModelUtils

#@abc.abstractmethod
#def entropy(self, action_list: np.ndarray) -> torch.Tensor:
# @abc.abstractmethod
# def entropy(self, action_list: np.ndarray) -> torch.Tensor:
#@abc.abstractmethod
#def log_probs(self, action_list: np.ndarray) -> torch.Tensor:
# @abc.abstractmethod
# def log_probs(self, action_list: np.ndarray) -> torch.Tensor:
# pass
def _sample_action(self, dists: List[DistInstance]) -> List[torch.Tensor]:

@abc.abstractmethod
def forward(self, inputs: torch.Tensor, masks: torch.Tensor):
pass
class HybridActionModel(ActionModel):
def __init__(

self.encoding_size = hidden_size
self.continuous_act_size = continuous_act_size
self.discrete_act_size = discrete_act_size
self.continuous_distribution = None #: List[GaussianDistribution] = []
self.discrete_distribution = None #: List[MultiCategoricalDistribution] = []
self.continuous_distribution = None #: List[GaussianDistribution] = []
self.discrete_distribution = None #: List[MultiCategoricalDistribution] = []
self.encoding_size,
continuous_act_size,
conditional_sigma=conditional_sigma,
tanh_squash=tanh_squash,
)
self.encoding_size,
continuous_act_size,
conditional_sigma=conditional_sigma,
tanh_squash=tanh_squash,
)
self.discrete_distribution = MultiCategoricalDistribution(self.encoding_size, discrete_act_size)
self.discrete_distribution = MultiCategoricalDistribution(
self.encoding_size, discrete_act_size
)
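The constructor above pairs a single GaussianDistribution covering all continuous dimensions with a single MultiCategoricalDistribution covering all discrete branches, both driven by the same encoder output. A condensed sketch of that layout using the same distribution classes (only the arguments visible in the hunk are assumed):

```python
from mlagents.trainers.torch.distributions import (
    GaussianDistribution,
    MultiCategoricalDistribution,
)

class HybridHeads:
    def __init__(self, encoding_size, continuous_act_size, discrete_act_size,
                 conditional_sigma=False, tanh_squash=False):
        # One Gaussian head for the whole continuous block...
        self.continuous_distribution = GaussianDistribution(
            encoding_size,
            continuous_act_size,
            conditional_sigma=conditional_sigma,
            tanh_squash=tanh_squash,
        )
        # ...and one multi-categorical head spanning all discrete branches.
        self.discrete_distribution = MultiCategoricalDistribution(
            encoding_size, discrete_act_size
        )
```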
def evaluate(self, inputs: torch.Tensor, masks: torch.Tensor, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def evaluate(
self, inputs: torch.Tensor, masks: torch.Tensor, actions: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
continuous_actions, discrete_actions = torch.split(actions, [self.continuous_act_size, len(self.discrete_act_size)], dim=1)
continuous_actions, discrete_actions = torch.split(
actions, [self.continuous_act_size, len(self.discrete_act_size)], dim=1
)
continuous_action_list = [
continuous_actions[..., i] for i in range(continuous_actions.shape[-1])
]
(
continuous_log_probs,
continuous_entropies,
_,
) = ModelUtils.get_probs_and_entropy(continuous_action_list, continuous_dists)
continuous_action_list = [continuous_actions[..., i] for i in range(continuous_actions.shape[-1])]
continuous_log_probs, continuous_entropies, _ = ModelUtils.get_probs_and_entropy(continuous_action_list, continuous_dists)
discrete_action_list = [
discrete_actions[:, i] for i in range(len(self.discrete_act_size))
]
discrete_log_probs, discrete_entropies, _ = ModelUtils.get_probs_and_entropy(
discrete_action_list, discrete_dists
)
discrete_action_list = [discrete_actions[:, i] for i in range(len(self.discrete_act_size))]
discrete_log_probs, discrete_entropies, _ = ModelUtils.get_probs_and_entropy(discrete_action_list, discrete_dists)
log_probs = torch.cat([continuous_log_probs, discrete_log_probs], dim=1)
entropies = torch.cat([continuous_entropies, torch.mean(discrete_entropies, dim=0).unsqueeze(0)], dim=1)
log_probs = torch.cat([continuous_log_probs, discrete_log_probs], dim=1)
entropies = torch.cat(
[continuous_entropies, torch.mean(discrete_entropies, dim=0).unsqueeze(0)],
dim=1,
)
return log_probs, entropies
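evaluate() relies on the stored hybrid action tensor laying out the continuous values first, followed by one index per discrete branch, so a single `torch.split` recovers both halves. A tiny worked example of that split:

```python
import torch

continuous_act_size, discrete_act_size = 2, [3, 2]
actions = torch.tensor([[0.1, -0.4, 2.0, 1.0]])  # 2 continuous dims, then 2 branch indices
continuous_actions, discrete_actions = torch.split(
    actions, [continuous_act_size, len(discrete_act_size)], dim=1
)
print(continuous_actions)  # tensor([[ 0.1000, -0.4000]])
print(discrete_actions)    # tensor([[2., 1.]])
```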
def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:

def _get_dists(self, inputs: torch.Tensor, masks: torch.Tensor) -> Tuple[List[DistInstance], List[DiscreteDistInstance]]:
#continuous_distributions: List[DistInstance] = []
#discrete_distributions: List[DiscreteDistInstance] = []
continuous_dist_instances = self.continuous_distribution(inputs)# for continuous_dist in self.continuous_distributions]
discrete_dist_instances = self.discrete_distribution(inputs, masks)# for discrete_dist in self.discrete_distributions]
#for continuous_dist in self.continuous_distributions:
def _get_dists(
self, inputs: torch.Tensor, masks: torch.Tensor
) -> Tuple[List[DistInstance], List[DiscreteDistInstance]]:
# continuous_distributions: List[DistInstance] = []
# discrete_distributions: List[DiscreteDistInstance] = []
continuous_dist_instances = self.continuous_distribution(
inputs
) # for continuous_dist in self.continuous_distributions]
discrete_dist_instances = self.discrete_distribution(
inputs, masks
) # for discrete_dist in self.discrete_distributions]
# for continuous_dist in self.continuous_distributions:
#for discrete_dist in self.discrete_distributions:
# for discrete_dist in self.discrete_distributions:
def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def _get_discrete_continuous_outputs(
self, inputs: torch.Tensor, masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
continuous_entropies, continuous_log_probs, continuous_all_probs = ModelUtils.get_probs_and_entropy(
continuous_action_list, continuous_dists
)
(
continuous_entropies,
continuous_log_probs,
continuous_all_probs,
) = ModelUtils.get_probs_and_entropy(continuous_action_list, continuous_dists)
discrete_entropies, discrete_log_probs, discrete_all_probs = ModelUtils.get_probs_and_entropy(
discrete_action_list, discrete_dists
)
(
discrete_entropies,
discrete_log_probs,
discrete_all_probs,
) = ModelUtils.get_probs_and_entropy(discrete_action_list, discrete_dists)
action = torch.cat([continuous_actions, discrete_actions.type(torch.float)], dim=1)
log_probs = torch.cat([continuous_log_probs, discrete_log_probs], dim=1)
entropies = torch.cat([continuous_entropies, discrete_entropies], dim=1)
return (
discrete_actions,
discrete_entropies,
discrete_log_probs,
discrete_all_probs,
continuous_actions,
continuous_entropies,
continuous_log_probs,
continuous_all_probs,
)
def forward_all_disc_probs(
self, inputs: torch.Tensor, masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
(
discrete_actions,
discrete_entropies,
_,
discrete_all_probs,
continuous_actions,
continuous_entropies,
continuous_log_probs,
_,
) = self._get_discrete_continuous_outputs(inputs, masks)
action = torch.cat(
[continuous_actions, discrete_actions.type(torch.float)], dim=1
)
log_probs = torch.cat([continuous_log_probs, discrete_all_probs], dim=1)
entropies = torch.cat([continuous_entropies, discrete_entropies], dim=1)
return (action, log_probs, entropies)
def forward(
self, inputs: torch.Tensor, masks: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
(
discrete_actions,
discrete_entropies,
discrete_log_probs,
_,
continuous_actions,
continuous_entropies,
continuous_log_probs,
_,
) = self._get_discrete_continuous_outputs(inputs, masks)
action = torch.cat(
[continuous_actions, discrete_actions.type(torch.float)], dim=1
)
log_probs = torch.cat([continuous_log_probs, discrete_log_probs], dim=1)
entropies = torch.cat([continuous_entropies, discrete_entropies], dim=1)
return (action, log_probs, entropies)
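forward() packs the sampled continuous actions and the sampled discrete branch indices (cast to float) into one tensor, and concatenates the corresponding log-probs and entropies the same way. A small worked example of the action concatenation:

```python
import torch

continuous_actions = torch.randn(4, 2)           # batch of 4, two continuous dimensions
discrete_actions = torch.randint(0, 3, (4, 2))   # two discrete branches, indices in [0, 3)
action = torch.cat([continuous_actions, discrete_actions.type(torch.float)], dim=1)
print(action.shape)  # torch.Size([4, 4])
```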

ml-agents/mlagents/trainers/torch/networks.py (12 changes)


masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
all_discrete_probs: bool = False,
action, log_probs, entropies = self.action_model(encoding, masks)
if all_discrete_probs:
action, log_probs, entropies = self.action_model.forward_all_disc_probs(encoding, masks)
else:
action, log_probs, entropies = self.action_model(encoding, masks)
value_outputs = self.value_heads(encoding)
return action, log_probs, entropies, value_outputs, memories
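The `all_discrete_probs` flag added above exists because SAC's discrete losses need log-probs for every discrete action, not just the sampled one, so the actor-critic routes through `forward_all_disc_probs` when the optimizer asks for them. A hedged sketch of that dispatch:

```python
def run_action_model(action_model, encoding, masks, all_discrete_probs=False):
    if all_discrete_probs:
        # Full per-action log-probs for the discrete branches (used in SAC's
        # expectation over actions), alongside the continuous log-probs.
        return action_model.forward_all_disc_probs(encoding, masks)
    # Otherwise: log-probs of the sampled actions only.
    return action_model(encoding, masks)
```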

masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
all_discrete_probs: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
if self.use_lstm:
# Use only the back half of memories for critic and actor

encoding, memories = self.network_body(
vec_inputs, vis_inputs, memories=memories, sequence_length=sequence_length
)
action, log_probs, entropies = self.action_model(encoding, masks)
if all_discrete_probs:
action, log_probs, entropies = self.action_model.forward_all_disc_probs(encoding, masks)
else:
action, log_probs, entropies = self.action_model(encoding, masks)
value_outputs, critic_mem_outs = self.critic(
vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
)
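When an LSTM is used, the comment in the hunk notes that the memory tensor carries both actor and critic state. A minimal sketch of splitting it in half along the last dimension; the exact layout is an assumption here:

```python
import torch

def split_actor_critic_memories(memories: torch.Tensor):
    half = memories.shape[-1] // 2
    actor_mem, critic_mem = memories[..., :half], memories[..., half:]
    return actor_mem, critic_mem
```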
