|
|
|
|
|
|
""" |
|
|
|
Extracts the current state embedding from a mini_batch. |
|
|
|
""" |
|
|
|
n_vis = len(self._state_encoder.visual_processors) |
|
|
|
vec_inputs=[ |
|
|
|
ModelUtils.list_to_tensor(mini_batch["vector_obs"], dtype=torch.float) |
|
|
|
], |
|
|
|
vis_inputs=[ |
|
|
|
ModelUtils.list_to_tensor( |
|
|
|
mini_batch["visual_obs%d" % i], dtype=torch.float |
|
|
|
) |
|
|
|
for i in range(n_vis) |
|
|
|
], |
|
|
|
net_inputs=ModelUtils.list_to_tensor_list( |
|
|
|
AgentBuffer.obs_list_to_obs_batch(mini_batch["obs"]), dtype=torch.float |
|
|
|
) |
|
|
|
) |
|
|
|
return hidden |
|
|
|
|
|
|
|
|
|
|
""" |
|
|
|
n_vis = len(self._state_encoder.visual_processors) |
|
|
|
vec_inputs=[ |
|
|
|
ModelUtils.list_to_tensor( |
|
|
|
mini_batch["next_vector_in"], dtype=torch.float |
|
|
|
) |
|
|
|
], |
|
|
|
vis_inputs=[ |
|
|
|
ModelUtils.list_to_tensor( |
|
|
|
mini_batch["next_visual_obs%d" % i], dtype=torch.float |
|
|
|
) |
|
|
|
for i in range(n_vis) |
|
|
|
], |
|
|
|
net_inputs=ModelUtils.list_to_tensor_list( |
|
|
|
AgentBuffer.obs_list_to_obs_batch(mini_batch["next_obs"]), |
|
|
|
dtype=torch.float, |
|
|
|
) |
|
|
|
) |
|
|
|
return hidden |
|
|
|
|
|
|
|