|
|
|
|
|
|
batch_dim = [1] |
|
|
|
seq_len_dim = [1] |
|
|
|
vec_obs_size = 0 |
|
|
|
for shape in self.policy.behavior_spec.observation_shapes: |
|
|
|
if len(shape) == 1: |
|
|
|
vec_obs_size += shape[0] |
|
|
|
for sens_spec in self.policy.behavior_spec.sensor_specs: |
|
|
|
if len(sens_spec.shape) == 1: |
|
|
|
vec_obs_size += sens_spec.shape[0] |
|
|
|
for shape in self.policy.behavior_spec.observation_shapes |
|
|
|
if len(shape) == 3 |
|
|
|
for sens_spec in self.policy.behavior_spec.sensor_specs |
|
|
|
if len(sens_spec.shape) == 3 |
|
|
|
) |
|
|
|
dummy_vec_obs = [torch.zeros(batch_dim + [vec_obs_size])] |
|
|
|
# create input shape of NCHW |
|
|
|