
Use Vince's ONNX export code

/comms-grad
Ervin Teng, 4 years ago
Current commit
2b8ab09d
2 changed files with 21 additions and 4 deletions
  1. ml-agents/mlagents/trainers/model_saver/torch_model_saver.py (2 changes)
  2. ml-agents/mlagents/trainers/torch/networks.py (23 changes)

ml-agents/mlagents/trainers/model_saver/torch_model_saver.py (2 changes)


        }
        torch.save(state_dict, f"{checkpoint_path}.pt")
        torch.save(state_dict, os.path.join(self.model_path, "checkpoint.pt"))
        # self.export(checkpoint_path, behavior_name)
        self.export(checkpoint_path, behavior_name)
        return checkpoint_path

    def export(self, output_filepath: str, behavior_name: str) -> None:
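With the export call re-enabled, every checkpoint save writes an ONNX model next to the .pt state dict. Below is a minimal standalone sketch of that save-then-export pattern; the SimplePolicy module, the paths, and the dummy input are made up for illustration, and torch.onnx.export stands in for the ml-agents serialization path.

import os

import torch
import torch.nn as nn


class SimplePolicy(nn.Module):
    """Hypothetical stand-in for the trainer's policy network."""

    def __init__(self, obs_size: int = 8, act_size: int = 2):
        super().__init__()
        self.body = nn.Sequential(
            nn.Linear(obs_size, 32), nn.ReLU(), nn.Linear(32, act_size)
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.body(obs)


def save_checkpoint(model: nn.Module, model_path: str, step: int) -> str:
    os.makedirs(model_path, exist_ok=True)
    checkpoint_path = os.path.join(model_path, f"policy-{step}")
    # Save the raw state dict for resuming training ...
    torch.save(model.state_dict(), f"{checkpoint_path}.pt")
    # ... and export an inference-ready ONNX graph at the same time.
    torch.onnx.export(model, torch.zeros(1, 8), f"{checkpoint_path}.onnx")
    return checkpoint_path


save_checkpoint(SimplePolicy(), "results/run1", step=1000)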

ml-agents/mlagents/trainers/torch/networks.py (23 changes)


import enum
from typing import Callable, List, Dict, Tuple, Optional
import abc

from mlagents.trainers.torch.decoders import ValueHeads
from mlagents.trainers.torch.layers import LSTM, LinearEncoder
from mlagents.trainers.torch.model_serialization import exporting_to_onnx
from mlagents.trainers.torch.encoders import VectorInput
ActivationFunction = Callable[[torch.Tensor], torch.Tensor]
EncoderFunction = Callable[

network_settings.vis_encode_type,
normalize=self.normalize,
)
self.observation_shapes = observation_shapes
total_enc_size = encoder_input_size + encoded_act_size
self.linear_encoder = LinearEncoder(
total_enc_size, network_settings.num_layers, self.h_size
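For reference, the constructor now records observation_shapes on the network body (the export-time forward further down needs the per-sensor sizes) and sizes its trunk from the encoded observation width plus the encoded action width. A rough sketch of that sizing with assumed widths, using a plain nn.Sequential in place of ml-agents' LinearEncoder:

import torch
import torch.nn as nn

encoder_input_size = 128  # assumed width of the encoded observations
encoded_act_size = 2      # assumed width of the encoded previous action
h_size = 256
num_layers = 2

total_enc_size = encoder_input_size + encoded_act_size

# Stand-in for LinearEncoder(total_enc_size, num_layers, h_size): a small MLP trunk.
layers = [nn.Linear(total_enc_size, h_size), nn.ReLU()]
for _ in range(num_layers - 1):
    layers += [nn.Linear(h_size, h_size), nn.ReLU()]
linear_encoder = nn.Sequential(*layers)

encoding = linear_encoder(torch.rand(1, total_enc_size))
print(encoding.shape)  # torch.Size([1, 256])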

    def forward(
        self,
        net_inputs: List[torch.Tensor],
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, int, int, int, int]:

        dists, _ = self.get_dists(net_inputs, masks, memories, 1)
        # Split the single concatenated vector observation back into one tensor
        # per vector sensor, so the inputs line up with the processors.
        concatenated_vec_obs = vec_inputs[0]
        inputs = []
        start = 0
        end = 0
        vis_index = 0
        for i, enc in enumerate(self.network_body.processors):
            if isinstance(enc, VectorInput):
                # This is a vec_obs
                vec_size = self.network_body.observation_shapes[i][0]
                end = start + vec_size
                inputs.append(concatenated_vec_obs[:, start:end])
                start = end
            else:
                inputs.append(vis_inputs[vis_index])
                vis_index += 1
        dists, _ = self.get_dists(inputs, masks, memories, 1)
        if self.action_spec.is_continuous():
            action_list = self.sample_action(dists)
            action_out = torch.stack(action_list, dim=-1)
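The loop above is the heart of the export-time forward: the exported graph receives all vector sensors as a single concatenated input, so the method slices that tensor back into one entry per VectorInput processor and passes visual inputs through as-is. A self-contained sketch of the same slicing logic, with made-up observation shapes and batch size:

import torch

# Hypothetical layout: two vector sensors (sizes 3 and 5) and one visual sensor.
observation_shapes = [(3,), (5,), (84, 84, 3)]
batch = 4

concatenated_vec_obs = torch.rand(batch, 3 + 5)  # all vector obs packed together
vis_inputs = [torch.rand(batch, 84, 84, 3)]      # visual obs stay separate

inputs = []
start = end = 0
vis_index = 0
for shape in observation_shapes:
    if len(shape) == 1:
        # Vector sensor: slice its columns back out of the packed tensor.
        end = start + shape[0]
        inputs.append(concatenated_vec_obs[:, start:end])
        start = end
    else:
        # Visual sensor: pass the corresponding tensor through unchanged.
        inputs.append(vis_inputs[vis_index])
        vis_index += 1

print([tuple(t.shape) for t in inputs])
# [(4, 3), (4, 5), (4, 84, 84, 3)]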
