浏览代码

Modified the model_serialization to have correct inputs and outputs

/develop/gail-srl-hack
vincentpierre 4 年前
当前提交
22db0335
共有 2 个文件被更改,包括 59 次插入14 次删除
  1. 71
      ml-agents/mlagents/trainers/torch/model_serialization.py
  2. 2
      ml-agents/mlagents/trainers/torch/networks.py

71
ml-agents/mlagents/trainers/torch/model_serialization.py


return exporting_to_onnx._local_data._is_exporting
class TensorNames:
    """String constants naming the tensors of an exported policy model.

    ``ModelSerializer`` reads these attributes when it builds the
    ``input_names``, ``output_names`` and ``dynamic_axes`` it uses for the
    export, so the literal values are part of the exported model's interface
    and must not be changed casually.

    NOTE(review): a few attributes start with a lowercase letter
    (``recurrentInPlaceholderH`` etc.) unlike the rest; the names are kept
    as-is because renaming them would break existing references.
    """

    # --- Inputs ---------------------------------------------------------
    BatchSizePlaceholder = "batch_size"
    SequenceLengthPlaceholder = "sequence_length"
    VectorObservationPlaceholder = "vector_observation"
    RecurrentInPlaceholder = "recurrent_in"
    # presumably the separate hidden (H) / cell (C) recurrent states -- TODO confirm
    recurrentInPlaceholderH = "recurrent_in_h"
    recurrentInPlaceholderC = "recurrent_in_c"
    # Prefixes: the serializer appends the observation index, e.g. "obs_0".
    VisualObservationPlaceholderPrefix = "visual_observation_"
    ObservationPlaceholderPrefix = "obs_"
    PreviousActionPlaceholder = "prev_action"
    ActionMaskPlaceholder = "action_masks"
    RandomNormalEpsilonPlaceholder = "epsilon"

    # --- Outputs --------------------------------------------------------
    ValueEstimateOutput = "value_estimate"
    RecurrentOutput = "recurrent_out"
    recurrentOutputH = "recurrent_out_h"
    recurrentOutputC = "recurrent_out_c"
    MemorySize = "memory_size"
    VersionNumber = "version_number"
    ContinuousActionOutputShape = "continuous_action_output_shape"
    DiscreteActionOutputShape = "discrete_action_output_shape"
    ContinuousActionOutput = "continuous_actions"
    DiscreteActionOutput = "discrete_actions"

    # Deprecated TensorNames entries for backward compatibility
    IsContinuousControlDeprecated = "is_continuous_control"
    ActionOutputDeprecated = "action"
    ActionOutputShapeDeprecated = "action_output_shape"
class ModelSerializer:
def __init__(self, policy):
# ONNX only supports input in NCHW (channel-first) format.

dummy_memories,
)
self.input_names = ["vector_observation"]
self.input_names = [TensorNames.VectorObservationPlaceholder]
self.input_names.append(f"visual_observation_{i}")
self.input_names.append(
TensorNames.VisualObservationPlaceholderPrefix + str(i)
)
self.input_names.append(f"obs_{i}")
self.input_names += ["action_masks", "memories"]
self.input_names.append(
TensorNames.ObservationPlaceholderPrefix + str(i)
)
self.input_names += [
TensorNames.ActionMaskPlaceholder,
TensorNames.RecurrentInPlaceholder,
]
self.output_names = ["version_number", "memory_size"]
self.output_names = [TensorNames.VersionNumber, TensorNames.MemorySize]
"continuous_actions",
"continuous_action_output_shape",
TensorNames.ContinuousActionOutput,
TensorNames.ContinuousActionOutputShape,
self.dynamic_axes.update({"continuous_actions": {0: "batch"}})
self.dynamic_axes.update({TensorNames.ContinuousActionOutput: {0: "batch"}})
self.output_names += ["discrete_actions", "discrete_action_output_shape"]
self.dynamic_axes.update({"discrete_actions": {0: "batch"}})
self.output_names += [
TensorNames.DiscreteActionOutput,
TensorNames.DiscreteActionOutputShape,
]
self.dynamic_axes.update({TensorNames.DiscreteActionOutput: {0: "batch"}})
"action",
"is_continuous_control",
"action_output_shape",
TensorNames.ActionOutputDeprecated,
TensorNames.IsContinuousControlDeprecated,
TensorNames.ActionOutputShapeDeprecated,
self.dynamic_axes.update({"action": {0: "batch"}})
self.dynamic_axes.update({TensorNames.ActionOutputDeprecated: {0: "batch"}})
if self.policy.export_memory_size > 0:
self.output_names += [TensorNames.RecurrentOutput]
def export_policy_model(self, output_filepath: str) -> None:
"""

2
ml-agents/mlagents/trainers/torch/networks.py


self.is_continuous_int_deprecated,
self.act_size_vector_deprecated,
]
if self.network_body.memory_size > 0:
export_out += [memories_out]
return tuple(export_out)

正在加载...
取消
保存