|
|
|
|
|
|
import numpy as np
from typing import Dict, List, Mapping, cast, Tuple, Optional

from mlagents.torch_utils import torch, nn, default_device
from mlagents.trainers.settings import NetworkSettings
from mlagents.trainers.torch.networks import ValueNetwork
from mlagents.trainers.torch.utils import ModelUtils
|
|
|
|
|
|
            stream_names: List[str],
            observation_shapes: List[Tuple[int, ...]],
            network_settings: NetworkSettings,
            continuous_act_size: int,
            discrete_act_size: List[int],
        ):
            super().__init__()
            # One Q output per discrete branch action (at least one for a purely
            # continuous action space); continuous actions are fed to the Q
            # networks as inputs rather than enumerated as outputs.
            num_value_outs = max(sum(discrete_act_size), 1)
            num_action_ins = int(continuous_act_size)
|
|
|
|
|
|
|
            self.q1_network = ValueNetwork(
                stream_names,
                observation_shapes,
                network_settings,
                num_action_ins,
                num_value_outs,
            )
|
|
|
|
|
|
self.init_entcoef = hyperparameters.init_entcoef |
|
|
|
|
|
|
|
self.policy = policy |
|
|
|
self.act_size = policy.act_size |
|
|
|
policy_network_settings = policy.network_settings |
|
|
|
|
|
|
|
self.tau = hyperparameters.tau |
|
|
|
|
|
|
self.stream_names, |
|
|
|
self.policy.behavior_spec.observation_shapes, |
|
|
|
policy_network_settings, |
|
|
|
            self.policy.continuous_act_size,
            self.policy.discrete_act_size,
|
|
|
) |
|
|
|
|
|
|
|
        self.target_network = ValueNetwork(
            self.stream_names,
            self.policy.behavior_spec.observation_shapes,
            policy_network_settings,
        )
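        # Initialize the target network as an exact copy of the critic; a soft
        # update with tau=1.0 copies the weights over directly.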
|
|
|
self.soft_update(self.policy.actor_critic.critic, self.target_network, 1.0) |
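        # One (log) entropy coefficient is kept per discrete action branch, plus
        # a single coefficient shared by all continuous actions.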
|
|
|
|
|
|
|
_total_act_size = 0 |
|
|
|
if self.policy.continuous_act_size > 0: |
|
|
|
_total_act_size += 1 |
|
|
|
_total_act_size += len(self.policy.discrete_act_size) |
|
|
|
|
|
|
|
            torch.log(torch.as_tensor([self.init_entcoef] * _total_act_size)),
|
|
|
        # Target entropies: roughly -1 per continuous action dimension (scaled by
        # continuous_target_entropy_scale) plus a scaled log(branch size) for
        # each discrete branch.
        self.target_entropy = []
        if self.policy.continuous_act_size > 0:
            self.target_entropy.append(
                torch.as_tensor(
                    -1
                    * self.continuous_target_entropy_scale
                    * np.prod(self.policy.continuous_act_size).astype(np.float32)
                )
            )
        self.target_entropy += [
            self.discrete_target_entropy_scale * np.log(i).astype(np.float32)
            for i in self.policy.discrete_act_size
        ]
|
|
|
            self.policy.actor_critic.action_model.parameters()
|
|
|
) |
|
|
|
value_params = list(self.value_network.parameters()) + list( |
|
|
|
self.policy.actor_critic.critic.parameters() |
|
|
|
|
|
|
q1p_out: Dict[str, torch.Tensor], |
|
|
|
q2p_out: Dict[str, torch.Tensor], |
|
|
|
loss_masks: torch.Tensor, |
|
|
|
|
|
|
                if len(self.policy.discrete_act_size) <= 0:
|
|
|
                    disc_action_probs = log_probs[:, self.policy.continuous_act_size :].exp()
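                    # Expected Q under the current policy, computed per discrete
                    # branch (probability-weighted sum), averaged across branches,
                    # and finally the min over the twin critics.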
|
|
|
                        q1p_out[name] * disc_action_probs, self.policy.discrete_act_size
|
|
|
                        q2p_out[name] * disc_action_probs, self.policy.discrete_act_size
|
|
|
) |
|
|
|
_q1p_mean = torch.mean( |
|
|
|
torch.stack( |
|
|
|
|
|
|
min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean) |
|
|
|
|
|
|
|
value_losses = [] |
|
|
|
        if len(self.policy.discrete_act_size) <= 0:
|
|
|
for name in values.keys(): |
|
|
|
with torch.no_grad(): |
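                    # Soft value backup: min over the twin critics of Q(s, a~pi)
                    # minus the entropy term.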
|
|
|
v_backup = min_policy_qs[name] - torch.sum( |
|
|
|
|
|
|
value_losses.append(value_loss) |
|
|
|
else: |
|
|
|
branched_per_action_ent = ModelUtils.break_into_branches( |
|
|
|
                log_probs * log_probs.exp(), self.policy.discrete_act_size
|
|
|
) |
|
|
|
# We have to do entropy bonus per action branch |
|
|
|
branched_ent_bonus = torch.stack( |
|
|
|
|
|
|
log_probs: torch.Tensor, |
|
|
|
q1p_outs: Dict[str, torch.Tensor], |
|
|
|
loss_masks: torch.Tensor, |
|
|
|
|
|
|
        if len(self.policy.discrete_act_size) <= 0:
|
|
|
            disc_log_probs = log_probs[:, self.policy.continuous_act_size :]
            disc_action_probs = disc_log_probs.exp()
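            # log_probs lays out the continuous dimensions first, followed by the
            # discrete branch actions, so the discrete part starts at
            # continuous_act_size.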
|
|
|
                disc_log_probs * disc_action_probs, self.policy.discrete_act_size
|
|
|
                mean_q1 * disc_action_probs, self.policy.discrete_act_size
|
|
|
) |
|
|
|
branched_policy_loss = torch.stack( |
|
|
|
[ |
|
|
|
|
|
|
return policy_loss |
|
|
|
|
|
|
|
def sac_entropy_loss( |
|
|
|
        self, log_probs: torch.Tensor, loss_masks: torch.Tensor
|
|
|
        if len(self.policy.discrete_act_size) <= 0:
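            # Tune the (log) entropy coefficient toward the target entropy: alpha
            # grows when the policy entropy falls below the target and shrinks
            # when it exceeds it.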
|
|
|
with torch.no_grad(): |
|
|
|
target_current_diff = torch.sum(log_probs + self.target_entropy, dim=1) |
|
|
|
entropy_loss = -torch.mean( |
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
branched_per_action_ent = ModelUtils.break_into_branches( |
|
|
|
                    log_probs * log_probs.exp(), self.policy.discrete_act_size
|
|
|
) |
|
|
|
target_current_diff_branched = torch.stack( |
|
|
|
[ |
|
|
|
|
|
|
self, q_output: Dict[str, torch.Tensor], discrete_actions: torch.Tensor |
|
|
|
) -> Dict[str, torch.Tensor]: |
|
|
|
condensed_q_output = {} |
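        # Reduce per-branch Q outputs to Q(s, a) for the discrete actions actually
        # taken: one-hot the actions, mask the branch Q values, and sum per branch.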
|
|
|
        onehot_actions = ModelUtils.actions_to_onehot(
            discrete_actions, self.policy.discrete_act_size
        )
|
|
|
            branched_q = ModelUtils.break_into_branches(
                item, self.policy.discrete_act_size
            )
|
|
|
only_action_qs = torch.stack( |
|
|
|
[ |
|
|
|
torch.sum(_act * _q, dim=1, keepdim=True) |
|
|
|
|
|
|
vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])] |
|
|
|
next_vec_obs = [ModelUtils.list_to_tensor(batch["next_vector_in"])] |
|
|
|
act_masks = ModelUtils.list_to_tensor(batch["action_mask"]) |
|
|
|
        actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1)
|
|
|
|
|
|
|
memories_list = [ |
|
|
|
ModelUtils.list_to_tensor(batch["memory"][i]) |
|
|
|
|
|
|
masks=act_masks, |
|
|
|
memories=memories, |
|
|
|
seq_len=self.policy.sequence_length, |
|
|
|
            all_log_probs=True,
|
|
|
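        # Hybrid actions: the continuous slice of the sampled/buffered actions is
        # fed to the Q networks as an input, while the discrete slice is only used
        # afterwards to condense the per-branch Q outputs.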
|
|
|
|
|
|
|
cont_sampled_actions = sampled_actions[:, : self.policy.continuous_act_size] |
|
|
|
# disc_sampled_actions = sampled_actions[:, self.policy.continuous_act_size :] |
|
|
|
|
|
|
|
cont_actions = actions[:, : self.policy.continuous_act_size] |
|
|
|
disc_actions = actions[:, self.policy.continuous_act_size :] |
|
|
|
q1p_out, q2p_out = self.value_network( |
|
|
|
vec_obs, |
|
|
|
vis_obs, |
|
|
|
cont_sampled_actions, |
|
|
|
memories=q_memories, |
|
|
|
sequence_length=self.policy.sequence_length, |
|
|
|
) |
|
|
|
q1_out, q2_out = self.value_network( |
|
|
|
vec_obs, |
|
|
|
vis_obs, |
|
|
|
cont_actions.squeeze(-1), |
|
|
|
memories=q_memories, |
|
|
|
sequence_length=self.policy.sequence_length, |
|
|
|
) |
|
|
|
if self.policy.discrete_act_size: |
|
|
|
q1_stream = self._condense_q_streams(q1_out, disc_actions.squeeze(-1)) |
|
|
|
q2_stream = self._condense_q_streams(q2_out, disc_actions.squeeze(-1)) |
|
|
|
        else:
            q1_stream, q2_stream = q1_out, q2_out
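        # Evaluate the target value network on the next observations (no gradient)
        # to form the backup for the Q losses.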
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
target_values, _ = self.target_network( |
|
|
|
|
|
|
sequence_length=self.policy.sequence_length, |
|
|
|
) |
|
|
|
masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool) |
|
|
|
|
|
|
dones = ModelUtils.list_to_tensor(batch["done"]) |
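        # Compute the SAC losses: twin Q losses against the discounted backup, the
        # soft value loss, the policy loss, and the entropy-coefficient loss.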
|
|
|
|
|
|
|
q1_loss, q2_loss = self.sac_q_loss( |
|
|
|
|
|
|
            log_probs, sampled_values, q1p_out, q2p_out, masks
|
|
|
        policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
        entropy_loss = self.sac_entropy_loss(log_probs, masks)
|
|
|
|
|
|
|
total_value_loss = q1_loss + q2_loss + value_loss |
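        # The twin Q losses and the value loss are combined into a single scalar
        # and minimized together.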
|
|
|
|
|
|
|