|
|
|
|
|
|
return exporting_to_onnx._local_data._is_exporting |
|
|
|
|
|
|
|
|
|
|
|
class TensorNames:
    """String names for the tensors of an exported model.

    ``ModelSerializer`` reads these constants when it builds its
    ``input_names``, ``output_names`` and ``dynamic_axes``, so the values
    are part of the exported model's interface and must not be changed
    casually.
    """

    # Names of input ("placeholder") tensors.
    batch_size_placeholder = "batch_size"
    sequence_length_placeholder = "sequence_length"
    vector_observation_placeholder = "vector_observation"
    recurrent_in_placeholder = "recurrent_in"
    visual_observation_placeholder_prefix = "visual_observation_"
    observation_placeholder_prefix = "obs_"
    previous_action_placeholder = "prev_action"
    action_mask_placeholder = "action_masks"
    random_normal_epsilon_placeholder = "epsilon"

    # Names of output tensors and exported constants.
    value_estimate_output = "value_estimate"
    recurrent_output = "recurrent_out"
    memory_size = "memory_size"
    version_number = "version_number"
    continuous_action_output_shape = "continuous_action_output_shape"
    discrete_action_output_shape = "discrete_action_output_shape"
    continuous_action_output = "continuous_actions"
    discrete_action_output = "discrete_actions"

    # Deprecated TensorNames entries for backward compatibility
    is_continuous_control_deprecated = "is_continuous_control"
    action_output_deprecated = "action"
    action_output_shape_deprecated = "action_output_shape"

    @staticmethod
    def get_visual_observation_name(index: int) -> str:
        """
        Returns the name of the visual observation with a given index
        """
        return f"{TensorNames.visual_observation_placeholder_prefix}{index}"

    @staticmethod
    def get_observation_name(index: int) -> str:
        """
        Returns the name of the observation with a given index
        """
        return f"{TensorNames.observation_placeholder_prefix}{index}"
|
|
|
|
|
|
|
|
|
|
|
class ModelSerializer: |
|
|
|
def __init__(self, policy): |
|
|
|
# ONNX only supports input in NCHW (channel-first) format. |
|
|
|
|
|
|
dummy_memories, |
|
|
|
) |
|
|
|
|
|
|
|
self.input_names = ["vector_observation"] |
|
|
|
self.input_names = [TensorNames.vector_observation_placeholder] |
|
|
|
self.input_names.append(f"visual_observation_{i}") |
|
|
|
self.input_names.append(TensorNames.get_visual_observation_name(i)) |
|
|
|
self.input_names.append(f"obs_{i}") |
|
|
|
self.input_names += ["action_masks", "memories"] |
|
|
|
self.input_names.append(TensorNames.get_observation_name(i)) |
|
|
|
self.input_names += [ |
|
|
|
TensorNames.action_mask_placeholder, |
|
|
|
TensorNames.recurrent_in_placeholder, |
|
|
|
] |
|
|
|
self.output_names = ["version_number", "memory_size"] |
|
|
|
self.output_names = [TensorNames.version_number, TensorNames.memory_size] |
|
|
|
"continuous_actions", |
|
|
|
"continuous_action_output_shape", |
|
|
|
TensorNames.continuous_action_output, |
|
|
|
TensorNames.continuous_action_output_shape, |
|
|
|
self.dynamic_axes.update({"continuous_actions": {0: "batch"}}) |
|
|
|
self.dynamic_axes.update( |
|
|
|
{TensorNames.continuous_action_output: {0: "batch"}} |
|
|
|
) |
|
|
|
self.output_names += ["discrete_actions", "discrete_action_output_shape"] |
|
|
|
self.dynamic_axes.update({"discrete_actions": {0: "batch"}}) |
|
|
|
self.output_names += [ |
|
|
|
TensorNames.discrete_action_output, |
|
|
|
TensorNames.discrete_action_output_shape, |
|
|
|
] |
|
|
|
self.dynamic_axes.update({TensorNames.discrete_action_output: {0: "batch"}}) |
|
|
|
"action", |
|
|
|
"is_continuous_control", |
|
|
|
"action_output_shape", |
|
|
|
TensorNames.action_output_deprecated, |
|
|
|
TensorNames.is_continuous_control_deprecated, |
|
|
|
TensorNames.action_output_shape_deprecated, |
|
|
|
self.dynamic_axes.update({"action": {0: "batch"}}) |
|
|
|
self.dynamic_axes.update( |
|
|
|
{TensorNames.action_output_deprecated: {0: "batch"}} |
|
|
|
) |
|
|
|
|
|
|
|
if self.policy.export_memory_size > 0: |
|
|
|
self.output_names += [TensorNames.recurrent_output] |
|
|
|
|
|
|
|
def export_policy_model(self, output_filepath: str) -> None: |
|
|
|
""" |
|
|
|