浏览代码

fix onnx save path and output_name

/develop/add-fire
Ruo-Ping Dong 4 年前
当前提交
9449d711
共有 1 个文件被更改,包括 6 次插入4 次删除
  1. 10
      ml-agents/mlagents/trainers/policy/torch_policy.py

10
ml-agents/mlagents/trainers/policy/torch_policy.py


self.export_model(self.global_step)
def load_model(self, step=0): # TODO: this doesn't work
load_path = self.model_path + "/model-" + str(step) + ".pt"
load_path = os.path.join(self.model_path, "model-" + str(step) + ".pt")
self.actor_critic.load_state_dict(torch.load(load_path))
def export_model(self, step=0):

print(fake_vec_obs[0].shape, fake_vis_obs[0].shape, fake_masks.shape)
export_path = "./model-" + str(step) + ".onnx"
output_names = ["action", "action_probs"]
export_path = os.path.join(self.model_path, "model-" + str(step) + ".onnx")
output_names = ["action", "action_probs", "is_continuous_control", \
"version_number", "memory_size", "action_output_shape"]
input_names = ["vector_observation", "action_mask"]
dynamic_axes = {"vector_observation": [0], "action": [0], "action_probs": [0]}
onnx.export(

verbose=True,
opset_version=12,
opset_version=9,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,

正在加载...
取消
保存