|
|
|
|
|
|
self, |
|
|
|
obs_only: List[List[torch.Tensor]], |
|
|
|
obs: List[List[torch.Tensor]], |
|
|
|
actions: Optional[List[AgentAction]], |
|
|
|
actions: List[AgentAction], |
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
|
|
|
|
concat_f_inp = [] |
|
|
|
if actions is not None: |
|
|
|
for inputs, action in zip(obs, actions): |
|
|
|
encodes = [] |
|
|
|
for idx, processor in enumerate(self.processors): |
|
|
|
obs_input = inputs[idx] |
|
|
|
obs_input[obs_input.isnan()] = 0.0 # Remove NaNs |
|
|
|
processed_obs = processor(obs_input) |
|
|
|
encodes.append(processed_obs) |
|
|
|
cat_encodes = [ |
|
|
|
torch.cat(encodes, dim=-1), |
|
|
|
action.to_flat(self.action_spec.discrete_branches), |
|
|
|
] |
|
|
|
concat_f_inp.append(torch.cat(cat_encodes, dim=1)) |
|
|
|
for inputs, action in zip(obs, actions): |
|
|
|
encodes = [] |
|
|
|
for idx, processor in enumerate(self.processors): |
|
|
|
obs_input = inputs[idx] |
|
|
|
obs_input[obs_input.isnan()] = 0.0 # Remove NaNs |
|
|
|
processed_obs = processor(obs_input) |
|
|
|
encodes.append(processed_obs) |
|
|
|
cat_encodes = [ |
|
|
|
torch.cat(encodes, dim=-1), |
|
|
|
action.to_flat(self.action_spec.discrete_branches), |
|
|
|
] |
|
|
|
concat_f_inp.append(torch.cat(cat_encodes, dim=1)) |
|
|
|
|
|
|
|
if concat_f_inp: |
|
|
|
f_inp = torch.stack(concat_f_inp, dim=1) |
|
|
|