浏览代码

Fix ONNX import for continuous

/develop/add-fire/ckpt-2
GitHub 4 年前
当前提交
ce6ab357
共有 3 个文件被更改,包括 18 次插入19 次删除
  1. 4
      ml-agents/mlagents/trainers/torch/model_serialization.py
  2. 31
      ml-agents/mlagents/trainers/torch/networks.py
  3. 2
      ml-agents/mlagents/trainers/trainer/rl_trainer.py

4
ml-agents/mlagents/trainers/torch/model_serialization.py


self.input_names = [
"vector_observation",
"visual_observation",
"action_mask",
"action_masks",
"memories",
]
self.output_names = [

self.dynamic_axes = {
"vector_observation": [0],
"visual_observation": [0],
"action_mask": [0],
"action_masks": [0],
"memories": [0],
"action": [0],
"action_probs": [0],

31
ml-agents/mlagents/trainers/torch/networks.py


memories: Optional[torch.Tensor] = None,
sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
vec_encodes = []
encodes = []
for idx, encoder in enumerate(self.vector_encoders):
vec_input = vec_inputs[idx]
if actions is not None:

vec_encodes.append(hidden)
encodes.append(hidden)
vis_encodes = []
vis_encodes.append(hidden)
encodes.append(hidden)
if len(vec_encodes) > 0 and len(vis_encodes) > 0:
vec_encodes_tensor = torch.stack(vec_encodes, dim=-1).sum(dim=-1)
vis_encodes_tensor = torch.stack(vis_encodes, dim=-1).sum(dim=-1)
encoding = torch.stack(
[vec_encodes_tensor, vis_encodes_tensor], dim=-1
).sum(dim=-1)
elif len(vec_encodes) > 0:
encoding = torch.stack(vec_encodes, dim=-1).sum(dim=-1)
elif len(vis_encodes) > 0:
encoding = torch.stack(vis_encodes, dim=-1).sum(dim=-1)
else:
if len(encodes) == 0:
# Constants don't work in Barracuda
encoding = encodes[0]
if len(encodes) > 1:
for _enc in encodes[1:]:
encoding += _enc
if self.use_lstm:
encoding = encoding.view([sequence_length, -1, self.h_size])

)
action_list = self.sample_action(dists)
sampled_actions = torch.stack(action_list, dim=-1)
if self.act_type == ActionType.CONTINUOUS:
log_probs = dists[0].log_prob(sampled_actions)
else:
log_probs = dists[0].all_log_prob()
dists[0].pdf(sampled_actions),
log_probs,
self.version_number,
self.memory_size,
self.is_continuous_int,

2
ml-agents/mlagents/trainers/trainer/rl_trainer.py


def create_saver(
framework: str, trainer_settings: TrainerSettings, model_path: str, load: bool
) -> BaseSaver:
if framework == "torch":
if framework == FrameworkType.PYTORCH:
saver = TorchSaver( # type: ignore
trainer_settings, model_path, load
)

正在加载...
取消
保存