|
|
|
|
|
|
for shape in self.policy.behavior_spec.observation_shapes: |
|
|
|
if len(shape) == 1: |
|
|
|
vec_obs_size += shape[0] |
|
|
|
num_vis_obs = sum(1 for shape in self.policy.behavior_spec.observation_shapes if len(shape) == 3) |
|
|
|
num_vis_obs = sum( |
|
|
|
1 |
|
|
|
for shape in self.policy.behavior_spec.observation_shapes |
|
|
|
if len(shape) == 3 |
|
|
|
) |
|
|
|
dummy_vec_obs = [torch.zeros(batch_dim + [vec_obs_size])] |
|
|
|
# create input shape of NCHW |
|
|
|
# (It's NHWC in self.policy.behavior_spec.observation_shapes) |
|
|
|