|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ActionModel(nn.Module, abc.ABC):
    """
    Abstract base for action models: given an input encoding (and action
    masks), produces action distributions and supports sampling and
    evaluation of actions against them.

    NOTE(review): this class was reconstructed from a garbled merge — the
    `_sample_action` header and the `forward` signature were missing in the
    source and have been restored from the subclass's call sites
    (`self._sample_action(dists)` and `forward(inputs, masks)`); confirm
    against the original revision.
    """

    def _sample_action(self, dists) -> List[torch.Tensor]:
        """
        Samples actions from list of distribution instances.

        :param dists: Iterable of distribution instances, each exposing a
            ``sample()`` method returning a tensor of sampled actions.
        :return: List of sampled action tensors, one per distribution.
        """
        actions = []
        for action_dist in dists:
            action = action_dist.sample()
            actions.append(action)
        return actions

    @abc.abstractmethod
    def evaluate(self, inputs: torch.Tensor, masks: torch.Tensor, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the log_probs and entropies of actions
        """
        pass

    @abc.abstractmethod
    def get_action_out(self, inputs: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
        """
        Returns the tensor to be exported to ONNX for the distribution
        """
        pass

    @abc.abstractmethod
    def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns the actions, log probs and entropies for given input
        """
        pass
|
|
|
|
|
|
|
class HybridActionModel(ActionModel):
    """
    ActionModel for a hybrid action space: a block of continuous action
    dimensions followed by a set of discrete action branches, each backed by
    its own distribution.

    NOTE(review): the source for this class interleaved two revisions (a
    separate continuous/discrete code path and a unified `_distributions`
    list). Only the unified revision was syntactically recoverable and is
    kept here. `get_action_out` (abstract on ActionModel) was not visible in
    this chunk and is intentionally left unimplemented — restore it from the
    original revision before instantiating this class.
    """

    def __init__(self, hidden_size: int, continuous_act_size: int, discrete_act_size: List[int]):
        """
        :param hidden_size: Size of the input encoding fed to each distribution.
        :param continuous_act_size: Number of continuous action dimensions.
        :param discrete_act_size: List of branch sizes for the discrete actions.
        """
        super().__init__()
        self.encoding_size = hidden_size
        self.continuous_act_size = continuous_act_size
        self.discrete_act_size = discrete_act_size
        # Distributions in the order their actions appear in the flat action
        # tensor; _split_list holds the matching per-distribution column
        # counts consumed by torch.split() in evaluate().
        self._distributions: List[Union[GaussianDistribution, MultiCategoricalDistribution]] = []
        self._split_list: List[int] = []
        # NOTE(review): the GaussianDistribution constructor call was
        # truncated in the source; args reconstructed to mirror the
        # MultiCategoricalDistribution call below — confirm.
        self._distributions.append(
            GaussianDistribution(self.encoding_size, continuous_act_size)
        )
        self._split_list.append(continuous_act_size)
        self._distributions.append(
            MultiCategoricalDistribution(self.encoding_size, discrete_act_size)
        )
        # Each discrete branch contributes exactly one column to the flat
        # action tensor.
        self._split_list += [1 for _ in range(len(discrete_act_size))]

    def evaluate(self, inputs: torch.Tensor, masks: torch.Tensor, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the log probs and entropies of the given actions under the
        distributions produced for `inputs`/`masks`.

        :param actions: Flat (batch, sum(_split_list)) action tensor whose
            columns line up with `_split_list`.
        """
        dists = self._get_dists(inputs, masks)
        # Undo the concatenation done in forward(): one chunk per entry in
        # _split_list.
        split_actions = torch.split(actions, self._split_list, dim=1)
        action_lists: List[torch.Tensor] = []
        for split_action in split_actions:
            action_lists += [
                split_action[..., i] for i in range(split_action.shape[-1])
            ]
        log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_lists, dists)
        # NOTE(review): return statement reconstructed from the
        # ActionModel.evaluate contract ("returns the log_probs and
        # entropies") — the original line did not survive the merge.
        return log_probs, entropies

    def _get_dists(self, inputs: torch.Tensor, masks: torch.Tensor) -> List[DistInstance]:
        """
        Instantiates every registered distribution for the given inputs and
        masks, flattening the per-distribution instance lists into one list.
        """
        distribution_instances: List[DistInstance] = []
        for distribution in self._distributions:
            dist_instances = distribution(inputs, masks)
            for dist_instance in dist_instances:
                distribution_instances.append(dist_instance)
        return distribution_instances

    def forward(self, inputs: torch.Tensor, masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Samples actions from the distributions for the given input and
        returns the actions together with their log probs and entropies.
        """
        dists = self._get_dists(inputs, masks)
        action_outs: List[torch.Tensor] = []
        action_lists = self._sample_action(dists)
        for action_list, dist in zip(action_lists, dists):
            # Give each sampled tensor a trailing dim so the distribution can
            # shape it for concatenation.
            action_out = action_list.unsqueeze(-1)
            action_outs.append(dist.structure_action(action_out))
        log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_lists, dists)
        action = torch.cat(action_outs, dim=1)
        return (action, log_probs, entropies)