|
|
|
|
|
|
memories: Optional[torch.Tensor] = None, |
|
|
|
sequence_length: int = 1, |
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
|
encodes = [] |
|
|
|
goal_signal = None |
|
|
|
obs_encodes = [] |
|
|
|
goal_encodes = [] |
|
|
|
for idx, processor in enumerate(self.processors): |
|
|
|
obs_input = inputs[idx] |
|
|
|
processed_obs = processor(obs_input) |
|
|
|
|
|
|
): |
|
|
|
encodes.append(processed_obs) |
|
|
|
obs_encodes.append(processed_obs) |
|
|
|
if goal_signal is not None: |
|
|
|
raise Exception("TODO : Cannot currently handle more than one goal") |
|
|
|
goal_signal = processed_obs |
|
|
|
goal_encodes.append(processed_obs) |
|
|
|
else: |
|
|
|
raise Exception("TODO : Something other than a goal or observation was passed to the agent.") |
|
|
|
if len(encodes) == 0: |
|
|
|
if len(obs_encodes) == 0: |
|
|
|
inputs = torch.cat(encodes + [actions], dim=-1) |
|
|
|
obs_inputs = torch.cat(obs_encodes + [actions], dim=-1) |
|
|
|
inputs = torch.cat(encodes, dim=-1) |
|
|
|
obs_inputs = torch.cat(obs_encodes, dim=-1) |
|
|
|
if goal_signal is None: |
|
|
|
encoding = self.linear_encoder(inputs) |
|
|
|
if len(goal_encodes) == 0: |
|
|
|
encoding = self.linear_encoder(obs_inputs) |
|
|
|
encoding = self.linear_encoder(inputs, goal_signal) |
|
|
|
goal_inputs = torch.cat(goal_encodes, dim=-1) |
|
|
|
encoding = self.linear_encoder(obs_inputs, goal_inputs) |
|
|
|
|
|
|
|
if self.use_lstm: |
|
|
|
# Resize to (batch, sequence length, encoding size) |
|
|
|