|
|
|
|
|
|
): |
|
|
|
visual_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx]) |
|
|
|
# Make sure to permute visual obs, as PyTorch uses NCHW |
|
|
|
visual_obs.append(visual_ob.permute([0, 3, 1, 2])) |
|
|
|
visual_obs.append(ModelUtils.nhwc_to_nchw(visual_ob)) |
|
|
|
else: |
|
|
|
visual_obs = [] |
|
|
|
|
|
|
|
|
|
|
ModelUtils.list_to_tensor(vec_vis_obs.vector_observations).unsqueeze(0) |
|
|
|
] |
|
|
|
next_vis_obs = [ |
|
|
|
ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0).permute([0, 3, 1, 2]) |
|
|
|
ModelUtils.nhwc_to_nchw(ModelUtils.list_to_tensor(_vis_ob).unsqueeze(0)) |
|
|
|
for _vis_ob in vec_vis_obs.visual_observations |
|
|
|
] |
|
|
|
|
|
|
|