|
|
|
|
|
|
self.export_model(self.global_step) |
|
|
|
|
|
|
|
def load_model(self, step=0): # TODO: this doesn't work |
|
|
|
load_path = self.model_path + "/model-" + str(step) + ".pt" |
|
|
|
load_path = os.path.join(self.model_path, "model-" + str(step) + ".pt") |
|
|
|
self.actor_critic.load_state_dict(torch.load(load_path)) |
|
|
|
|
|
|
|
def export_model(self, step=0): |
|
|
|
|
|
|
print(fake_vec_obs[0].shape, fake_vis_obs[0].shape, fake_masks.shape) |
|
|
|
export_path = "./model-" + str(step) + ".onnx" |
|
|
|
output_names = ["action", "action_probs"] |
|
|
|
export_path = os.path.join(self.model_path, "model-" + str(step) + ".onnx") |
|
|
|
output_names = ["action", "action_probs", "is_continuous_control", \ |
|
|
|
"version_number", "memory_size", "action_output_shape"] |
|
|
|
input_names = ["vector_observation", "action_mask"] |
|
|
|
dynamic_axes = {"vector_observation": [0], "action": [0], "action_probs": [0]} |
|
|
|
onnx.export( |
|
|
|
|
|
|
verbose=True, |
|
|
|
opset_version=12, |
|
|
|
opset_version=9, |
|
|
|
input_names=input_names, |
|
|
|
output_names=output_names, |
|
|
|
dynamic_axes=dynamic_axes, |
|
|
|