|
|
|
|
|
|
dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)]) |
|
|
|
dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.m_size]) |
|
|
|
|
|
|
|
# Need to pass all posslible inputs since currently keyword arguments is not |
|
|
|
# supported by torch.nn.export() |
|
|
|
self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories) |
|
|
|
|
|
|
|
# Input names can only contain actual input used since in torch.nn.export |
|
|
|
# it maps input_names only to input nodes that exist in the graph |
|
|
|
self.input_names = [] |
|
|
|
self.dynamic_axes = {"action": {0: "batch"}, "action_probs": {0: "batch"}} |
|
|
|
if self.policy.use_vec_obs: |
|
|
|
|
|
|
"is_continuous_control", |
|
|
|
"action_output_shape", |
|
|
|
] |
|
|
|
|
|
|
|
self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories) |
|
|
|
|
|
|
|
def export_policy_model(self, output_filepath: str) -> None: |
|
|
|
""" |
|
|
|