|
|
|
|
|
|
self.policy = policy |
|
|
|
batch_dim = [1] |
|
|
|
dummy_vec_obs = [torch.zeros(batch_dim + [self.policy.vec_obs_size])] |
|
|
|
dummy_vis_obs = ( |
|
|
|
[torch.zeros(batch_dim + list(self.policy.vis_obs_shape))] |
|
|
|
if self.policy.vis_obs_size > 0 |
|
|
|
else [] |
|
|
|
) |
|
|
|
dummy_vis_obs = [ |
|
|
|
torch.zeros(batch_dim + list(shape)) |
|
|
|
for shape in self.policy.behavior_spec.observation_shapes |
|
|
|
if len(shape) == 3 |
|
|
|
] |
|
|
|
# Need to pass all possible inputs since currently keyword arguments is not |
|
|
|
# Need to pass all possible inputs since currently keyword arguments is not |
|
|
|
# supported by torch.onnx.export() |
|
|
|
self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories) |
|
|
|
|
|
|
|
|
|
|
if self.policy.use_vec_obs: |
|
|
|
self.input_names.append("vector_observation") |
|
|
|
self.dynamic_axes.update({"vector_observation": {0: "batch"}}) |
|
|
|
if self.policy.use_vis_obs: |
|
|
|
self.input_names.append("visual_observation") |
|
|
|
self.dynamic_axes.update({"visual_observation": {0: "batch"}}) |
|
|
|
for i in range(self.policy.vis_obs_size): |
|
|
|
self.input_names.append(f"visual_observation_{i}") |
|
|
|
self.dynamic_axes.update({f"visual_observation_{i}": {0: "batch"}}) |
|
|
|
if not self.policy.use_continuous_act: |
|
|
|
self.input_names.append("action_masks") |
|
|
|
self.dynamic_axes.update({"action_masks": {0: "batch"}}) |
|
|
|