|
|
|
|
|
|
class ModelSerializer: |
|
|
|
def __init__(self, policy): |
|
|
|
self.policy = policy |
|
|
|
dummy_vec_obs = [torch.zeros([1] + [self.policy.vec_obs_size])] |
|
|
|
# dimension for batch (and sequence_length if use recurrent) |
|
|
|
dummy_dim = [1, 1] if self.policy.use_recurrent else [1] |
|
|
|
|
|
|
|
dummy_vec_obs = [torch.zeros(dummy_dim + [self.policy.vec_obs_size])] |
|
|
|
[torch.zeros([1] + list(self.policy.vis_obs_shape))] |
|
|
|
[torch.zeros(dummy_dim + list(self.policy.vis_obs_shape))] |
|
|
|
dummy_memories = torch.zeros([1] + [self.policy.m_size]) |
|
|
|
dummy_memories = torch.zeros(dummy_dim + [self.policy.m_size]) |
|
|
|
|
|
|
|
self.input_names = [ |
|
|
|
"vector_observation", |
|
|
|
|
|
|
} |
|
|
|
self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def export_policy_model(self, output_filepath: str) -> None: |
|
|
|
""" |
|
|
|
Exports a Torch model for a Policy to .onnx format for Unity embedding. |
|
|
|