import torch

# Assumption: SerializationSettings (which provides `onnx_opset` below) lives in
# the trainer settings module; adjust this import to the actual project layout.
from mlagents.trainers.settings import SerializationSettings


class ModelSerializer:
    def __init__(self, policy):
        self.policy = policy

        # Dimension for the batch axis of the dummy inputs below; the extra
        # sequence_length axis needed for recurrent memories is added separately.
        batch_dim = [1]
        dummy_vec_obs = [torch.zeros(batch_dim + [self.policy.vec_obs_size])]
        dummy_vis_obs = [torch.zeros(batch_dim + list(self.policy.vis_obs_shape))]
        dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)])
        dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.m_size])
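        # The dummy tensors above only supply shapes for tracing the graph at
        # export time; their values are never inspected.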

        # Collect only the inputs this policy actually uses, and mark axis 0 of
        # each exported input/output as a dynamic "batch" axis.
        self.input_names = []
        self.dynamic_axes = {"action": {0: 'batch'}, "action_probs": {0: 'batch'}}
        if self.policy.use_vec_obs:
            self.input_names.append("vector_observation")
            self.dynamic_axes.update({"vector_observation": {0: 'batch'}})
        if self.policy.use_vis_obs:
            self.input_names.append("visual_observation")
            self.dynamic_axes.update({"visual_observation": {0: 'batch'}})
        if not self.policy.use_continuous_act:
            self.input_names.append("action_mask")
            self.dynamic_axes.update({"action_mask": {0: 'batch'}})
        if self.policy.use_recurrent:
            self.input_names.append("memories")
            self.dynamic_axes.update({"memories": {0: 'batch'}})
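        # `dynamic_axes` tells torch.onnx.export not to pin axis 0 of these tensors
        # to the size of the dummy inputs, so the exported graph accepts any batch size.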

        self.output_names = [
            "action",
            "action_probs",
            "action_output_shape",
        ]

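        # This tuple is passed positionally to actor_critic's forward pass when
        # torch.onnx.export traces the model.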
        self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories)

    def export_policy_model(self, output_filepath: str) -> None:
        # Assumed: the .onnx suffix is appended to the requested output path.
        onnx_output_path = f"{output_filepath}.onnx"
        torch.onnx.export(
            self.policy.actor_critic,
            self.dummy_input,
            onnx_output_path,
            verbose=False,
            opset_version=SerializationSettings.onnx_opset,
            input_names=self.input_names,
            output_names=self.output_names,
            dynamic_axes=self.dynamic_axes,
        )
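

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the same
# torch.onnx.export pattern used above, applied to a self-contained toy module.
# `_ToyActor` and the "toy_actor.onnx" file name are hypothetical, chosen only
# to show how `input_names`, `output_names`, and `dynamic_axes` fit together.
# Typical use of the class above would simply be
# `ModelSerializer(policy).export_policy_model("results/MyBehavior")`, assuming
# `policy` exposes the attributes referenced in __init__.
def _example_export_toy_model() -> None:
    class _ToyActor(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 2)

        def forward(self, vector_observation: torch.Tensor) -> torch.Tensor:
            return self.linear(vector_observation)

    # Dummy input: batch of 1, vector observation of size 4 (only the shape matters).
    dummy_vec_obs = torch.zeros([1, 4])
    torch.onnx.export(
        _ToyActor(),
        (dummy_vec_obs,),
        "toy_actor.onnx",
        input_names=["vector_observation"],
        output_names=["action"],
        dynamic_axes={"vector_observation": {0: "batch"}, "action": {0: "batch"}},
    )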