|
|
|
|
|
|
return exporting_to_onnx._local_data._is_exporting |
|
|
|
|
|
|
|
|
|
|
|
class TensorNames: |
|
|
|
BatchSizePlaceholder = "batch_size" |
|
|
|
SequenceLengthPlaceholder = "sequence_length" |
|
|
|
VectorObservationPlaceholder = "vector_observation" |
|
|
|
RecurrentInPlaceholder = "recurrent_in" |
|
|
|
recurrentInPlaceholderH = "recurrent_in_h" |
|
|
|
recurrentInPlaceholderC = "recurrent_in_c" |
|
|
|
VisualObservationPlaceholderPrefix = "visual_observation_" |
|
|
|
ObservationPlaceholderPrefix = "obs_" |
|
|
|
PreviousActionPlaceholder = "prev_action" |
|
|
|
ActionMaskPlaceholder = "action_masks" |
|
|
|
RandomNormalEpsilonPlaceholder = "epsilon" |
|
|
|
|
|
|
|
ValueEstimateOutput = "value_estimate" |
|
|
|
RecurrentOutput = "recurrent_out" |
|
|
|
recurrentOutputH = "recurrent_out_h" |
|
|
|
recurrentOutputC = "recurrent_out_c" |
|
|
|
MemorySize = "memory_size" |
|
|
|
VersionNumber = "version_number" |
|
|
|
ContinuousActionOutputShape = "continuous_action_output_shape" |
|
|
|
DiscreteActionOutputShape = "discrete_action_output_shape" |
|
|
|
ContinuousActionOutput = "continuous_actions" |
|
|
|
DiscreteActionOutput = "discrete_actions" |
|
|
|
|
|
|
|
# Deprecated TensorNames entries for backward compatibility |
|
|
|
IsContinuousControlDeprecated = "is_continuous_control" |
|
|
|
ActionOutputDeprecated = "action" |
|
|
|
ActionOutputShapeDeprecated = "action_output_shape" |
|
|
|
|
|
|
|
|
|
|
|
class ModelSerializer: |
|
|
|
def __init__(self, policy): |
|
|
|
# ONNX only support input in NCHW (channel first) format. |
|
|
|
|
|
|
dummy_memories, |
|
|
|
) |
|
|
|
|
|
|
|
self.input_names = ["vector_observation"] |
|
|
|
self.input_names = [TensorNames.VectorObservationPlaceholder] |
|
|
|
self.input_names.append(f"visual_observation_{i}") |
|
|
|
self.input_names.append( |
|
|
|
TensorNames.VisualObservationPlaceholderPrefix + str(i) |
|
|
|
) |
|
|
|
self.input_names.append(f"obs_{i}") |
|
|
|
self.input_names += ["action_masks", "memories"] |
|
|
|
self.input_names.append( |
|
|
|
TensorNames.ObservationPlaceholderPrefix + str(i) |
|
|
|
) |
|
|
|
self.input_names += [ |
|
|
|
TensorNames.ActionMaskPlaceholder, |
|
|
|
TensorNames.RecurrentInPlaceholder, |
|
|
|
] |
|
|
|
self.output_names = ["version_number", "memory_size"] |
|
|
|
self.output_names = [TensorNames.VersionNumber, TensorNames.MemorySize] |
|
|
|
"continuous_actions", |
|
|
|
"continuous_action_output_shape", |
|
|
|
TensorNames.ContinuousActionOutput, |
|
|
|
TensorNames.ContinuousActionOutputShape, |
|
|
|
self.dynamic_axes.update({"continuous_actions": {0: "batch"}}) |
|
|
|
self.dynamic_axes.update({TensorNames.ContinuousActionOutput: {0: "batch"}}) |
|
|
|
self.output_names += ["discrete_actions", "discrete_action_output_shape"] |
|
|
|
self.dynamic_axes.update({"discrete_actions": {0: "batch"}}) |
|
|
|
self.output_names += [ |
|
|
|
TensorNames.DiscreteActionOutput, |
|
|
|
TensorNames.DiscreteActionOutputShape, |
|
|
|
] |
|
|
|
self.dynamic_axes.update({TensorNames.DiscreteActionOutput: {0: "batch"}}) |
|
|
|
"action", |
|
|
|
"is_continuous_control", |
|
|
|
"action_output_shape", |
|
|
|
TensorNames.ActionOutputDeprecated, |
|
|
|
TensorNames.IsContinuousControlDeprecated, |
|
|
|
TensorNames.ActionOutputShapeDeprecated, |
|
|
|
self.dynamic_axes.update({"action": {0: "batch"}}) |
|
|
|
self.dynamic_axes.update({TensorNames.ActionOutputDeprecated: {0: "batch"}}) |
|
|
|
|
|
|
|
if self.policy.export_memory_size > 0: |
|
|
|
self.output_names += [TensorNames.RecurrentOutput] |
|
|
|
|
|
|
|
def export_policy_model(self, output_filepath: str) -> None: |
|
|
|
""" |
|
|
|