浏览代码

-

/exp-robot
vincentpierre 4 年前
当前提交
6a61eb05
共有 2 个文件被更改,包括 27 次插入20 次删除
  1. 22
      ml-agents/mlagents/trainers/torch/model_serialization.py
  2. 25
      ml-agents/mlagents/trainers/torch/networks.py

22
ml-agents/mlagents/trainers/torch/model_serialization.py


self.dynamic_axes = {name: {0: "batch"} for name in self.input_names}
self.output_names = ["version_number", "memory_size"]
if self.policy.behavior_spec.action_spec.continuous_size > 0:
if True:
self.output_names += [
"continuous_actions",
"continuous_action_output_shape",

self.output_names += ["discrete_actions", "discrete_action_output_shape"]
self.dynamic_axes.update({"discrete_actions": {0: "batch"}})
if (
self.policy.behavior_spec.action_spec.continuous_size == 0
or self.policy.behavior_spec.action_spec.discrete_size == 0
):
self.output_names += [
"action",
"is_continuous_control",
"action_output_shape",
]
self.dynamic_axes.update({"action": {0: "batch"}})
# if (
# self.policy.behavior_spec.action_spec.continuous_size == 0
# or self.policy.behavior_spec.action_spec.discrete_size == 0
# ):
# self.output_names += [
# "action",
# "is_continuous_control",
# "action_output_shape",
# ]
# self.dynamic_axes.update({"action": {0: "batch"}})
def export_policy_model(self, output_filepath: str) -> None:
"""

25
ml-agents/mlagents/trainers/torch/networks.py


loss = torch.mean(loss)
return loss
def get_prediction(self, inputs: List[torch.Tensor]) -> torch.Tensor:
    """Return the surrogate predictor's output for the given inputs.

    Calls ``self.forward(inputs)``, keeps only the first element of the
    returned pair, and feeds it through ``self.surrogate_predictor``.

    :param inputs: Tensors handed unchanged to ``self.forward``.
    :return: Whatever ``self.surrogate_predictor`` produces for the first
        value returned by ``self.forward``.
    """
    # The second value returned by forward() is not needed here.
    encoded, _unused = self.forward(inputs)
    return self.surrogate_predictor(encoded)
class ValueNetwork(nn.Module):
def __init__(

action_out_deprecated,
) = self.action_model.get_action_out(encoding, masks)
export_out = [self.version_number, self.memory_size_vector]
if self.action_spec.continuous_size > 0:
export_out += [cont_action_out, self.continuous_act_size_vector]
if True:
# export_out += [cont_action_out, self.continuous_act_size_vector]
export_out += [self.network_body.get_prediction(inputs), torch.nn.Parameter(
torch.Tensor([int(9)]), requires_grad=False
)]
# Only export deprecated nodes with non-hybrid action spec
if self.action_spec.continuous_size == 0 or self.action_spec.discrete_size == 0:
export_out += [
action_out_deprecated,
self.is_continuous_int_deprecated,
self.act_size_vector_deprecated,
]
# # Only export deprecated nodes with non-hybrid action spec
# if self.action_spec.continuous_size == 0 or self.action_spec.discrete_size == 0:
# export_out += [
# action_out_deprecated,
# self.is_continuous_int_deprecated,
# self.act_size_vector_deprecated,
# ]
return tuple(export_out)

正在加载...
取消
保存