|
|
|
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
|
vec_encodes = [] |
|
|
|
encodes = [] |
|
|
|
for idx, encoder in enumerate(self.vector_encoders): |
|
|
|
vec_input = vec_inputs[idx] |
|
|
|
if actions is not None: |
|
|
|
|
|
|
vec_encodes.append(hidden) |
|
|
|
encodes.append(hidden) |
|
|
|
vis_encodes = [] |
|
|
|
vis_encodes.append(hidden) |
|
|
|
encodes.append(hidden) |
|
|
|
if len(vec_encodes) > 0 and len(vis_encodes) > 0: |
|
|
|
vec_encodes_tensor = torch.stack(vec_encodes, dim=-1).sum(dim=-1) |
|
|
|
vis_encodes_tensor = torch.stack(vis_encodes, dim=-1).sum(dim=-1) |
|
|
|
encoding = torch.stack( |
|
|
|
[vec_encodes_tensor, vis_encodes_tensor], dim=-1 |
|
|
|
).sum(dim=-1) |
|
|
|
elif len(vec_encodes) > 0: |
|
|
|
encoding = torch.stack(vec_encodes, dim=-1).sum(dim=-1) |
|
|
|
elif len(vis_encodes) > 0: |
|
|
|
encoding = torch.stack(vis_encodes, dim=-1).sum(dim=-1) |
|
|
|
else: |
|
|
|
if len(encodes) == 0: |
|
|
|
|
|
|
|
# Constants don't work in Barracuda, so the running sum starts from encodes[0] rather than from a constant. |
|
|
|
encoding = encodes[0] |
|
|
|
if len(encodes) > 1: |
|
|
|
for _enc in encodes[1:]: |
|
|
|
encoding += _enc |
|
|
|
|
|
|
|
if self.use_lstm: |
|
|
|
encoding = encoding.view([sequence_length, -1, self.h_size]) |
|
|
|
|
|
|
) |
|
|
|
action_list = self.sample_action(dists) |
|
|
|
sampled_actions = torch.stack(action_list, dim=-1) |
|
|
|
if self.act_type == ActionType.CONTINUOUS: |
|
|
|
log_probs = dists[0].log_prob(sampled_actions) |
|
|
|
else: |
|
|
|
log_probs = dists[0].all_log_prob() |
|
|
|
dists[0].pdf(sampled_actions), |
|
|
|
log_probs, |
|
|
|
self.version_number, |
|
|
|
self.memory_size, |
|
|
|
self.is_continuous_int, |
|
|
|