from enum import Enum
from typing import Callable, List, NamedTuple

import numpy as np
import torch
from torch import nn

# Small constant used to avoid log(0) and division by zero in the masking math below.
EPSILON = 1e-7


        conv_3 = torch.relu(self.conv3(conv_2))
        hidden = self.dense(conv_3.reshape([-1, self.final_flat]))
        return hidden


class DiscreteActionMask(nn.Module):
    """
    Applies action masks to per-branch discrete-action logits, renormalizes the masked
    probabilities, and samples one action per branch.
    """

    def __init__(self, action_size):
        super(DiscreteActionMask, self).__init__()
        self.action_size = action_size

    @staticmethod
    def break_into_branches(
        concatenated_logits: torch.Tensor, action_size: List[int]
    ) -> List[torch.Tensor]:
        """
        Takes a concatenated set of logits that represent multiple discrete action branches
        and breaks it up into one Tensor per branch.
        :param concatenated_logits: Tensor that represents the concatenated action branches.
        :param action_size: List of ints containing the number of possible actions for each branch.
        :return: A List of Tensors containing one tensor per branch.
        """
        action_idx = [0] + list(np.cumsum(action_size))
        branched_logits = [
            concatenated_logits[:, action_idx[i] : action_idx[i + 1]]
            for i in range(len(action_size))
        ]
        return branched_logits
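
    # Illustrative example (sizes assumed, not taken from the original code): with
    # action_size=[2, 3], a concatenated tensor of shape (batch, 5) is split into two
    # tensors of shapes (batch, 2) and (batch, 3), e.g.
    #   break_into_branches(torch.zeros(4, 5), [2, 3])  # -> [(4, 2) tensor, (4, 3) tensor]
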
    def forward(self, branches_logits, action_masks):
        # Split the concatenated masks so each branch gets its own mask tensor.
        branch_masks = self.break_into_branches(action_masks, self.action_size)
        # Zero out the probability of masked actions in each branch.
        raw_probs = [
            torch.mul(
                torch.softmax(branches_logits[k], dim=-1) + EPSILON, branch_masks[k]
            )
            for k in range(len(self.action_size))
        ]
        # Renormalize so each branch's probabilities sum to one again.
        normalized_probs = [
            torch.div(raw_probs[k], torch.sum(raw_probs[k], dim=1, keepdim=True))
            for k in range(len(self.action_size))
        ]
        # Sample one action per branch. torch.multinomial expects non-negative weights,
        # so sample from the probabilities directly rather than their logarithms.
        output = torch.cat(
            [
                torch.multinomial(normalized_probs[k], 1)
                for k in range(len(self.action_size))
            ],
            dim=1,
        )
        return (
            output,
            torch.cat(
                [normalized_probs[k] for k in range(len(self.action_size))], dim=1
            ),
            torch.cat(
                [
                    torch.log(normalized_probs[k] + EPSILON)
                    for k in range(len(self.action_size))
                ],
                dim=1,
            ),
        )
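

# A minimal usage sketch (not part of the original module; the function name, branch
# sizes, and values are illustrative assumptions): two hypothetical action branches of
# sizes 2 and 3, with the last action of the second branch masked out for every sample.
def _discrete_action_mask_example() -> None:
    logits = torch.zeros(4, 5)  # batch of 4; concatenated branches of sizes 2 + 3
    masks = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.0]).repeat(4, 1)  # disable the last action
    mask_module = DiscreteActionMask([2, 3])
    branch_logits = DiscreteActionMask.break_into_branches(logits, [2, 3])
    actions, probs, log_probs = mask_module(branch_logits, masks)
    print(actions.shape, probs.shape, log_probs.shape)  # (4, 2), (4, 5), (4, 5)

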
class GlobalSteps(nn.Module):