|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TensorNames:
    """String names used to identify the model's tensors and constants.

    Every value is defined exactly once under its snake_case attribute name.
    The older PascalCase / camelCase spellings are kept as aliases bound to
    the very same string objects, so code that still refers to either
    spelling keeps resolving to an identical value and the two sets cannot
    drift apart.
    """

    # --- Placeholder names -------------------------------------------------
    batch_size_placeholder = "batch_size"
    sequence_length_placeholder = "sequence_length"
    vector_observation_placeholder = "vector_observation"
    recurrent_in_placeholder = "recurrent_in"
    recurrent_in_placeholder_h = "recurrent_in_h"
    recurrent_in_placeholder_c = "recurrent_in_c"
    # The two *_prefix values are not complete names: callers append an
    # index to them (e.g. prefix + str(i)).
    visual_observation_placeholder_prefix = "visual_observation_"
    observation_placeholder_prefix = "obs_"
    previous_action_placeholder = "prev_action"
    action_mask_placeholder = "action_masks"
    random_normal_epsilon_placeholder = "epsilon"

    # --- Output names ------------------------------------------------------
    value_estimate_output = "value_estimate"
    recurrent_output = "recurrent_out"
    recurrent_output_h = "recurrent_out_h"
    recurrent_output_c = "recurrent_out_c"
    memory_size = "memory_size"
    version_number = "version_number"
    continuous_action_output_shape = "continuous_action_output_shape"
    discrete_action_output_shape = "discrete_action_output_shape"
    continuous_action_output = "continuous_actions"
    discrete_action_output = "discrete_actions"

    # --- Names explicitly marked as deprecated ------------------------------
    is_continuous_control_deprecated = "is_continuous_control"
    action_output_deprecated = "action"
    action_output_shape_deprecated = "action_output_shape"

    # --- Legacy spellings (aliases of the attributes above) -----------------
    # Spelled exactly as before, including the lower-camelCase outliers, so
    # that existing attribute lookups are not broken.
    BatchSizePlaceholder = batch_size_placeholder
    SequenceLengthPlaceholder = sequence_length_placeholder
    VectorObservationPlaceholder = vector_observation_placeholder
    RecurrentInPlaceholder = recurrent_in_placeholder
    recurrentInPlaceholderH = recurrent_in_placeholder_h
    recurrentInPlaceholderC = recurrent_in_placeholder_c
    VisualObservationPlaceholderPrefix = visual_observation_placeholder_prefix
    ObservationPlaceholderPrefix = observation_placeholder_prefix
    PreviousActionPlaceholder = previous_action_placeholder
    ActionMaskPlaceholder = action_mask_placeholder
    RandomNormalEpsilonPlaceholder = random_normal_epsilon_placeholder

    ValueEstimateOutput = value_estimate_output
    RecurrentOutput = recurrent_output
    recurrentOutputH = recurrent_output_h
    recurrentOutputC = recurrent_output_c
    MemorySize = memory_size
    VersionNumber = version_number
    ContinuousActionOutputShape = continuous_action_output_shape
    DiscreteActionOutputShape = discrete_action_output_shape
    ContinuousActionOutput = continuous_action_output
    DiscreteActionOutput = discrete_action_output

    IsContinuousControlDeprecated = is_continuous_control_deprecated
    ActionOutputDeprecated = action_output_deprecated
    ActionOutputShapeDeprecated = action_output_shape_deprecated
|
|
|
|
|
|
|
|
|
|
|
class ModelSerializer: |
|
|
|
|
|
|
dummy_memories, |
|
|
|
) |
|
|
|
|
|
|
|
self.input_names = [TensorNames.VectorObservationPlaceholder] |
|
|
|
self.input_names = [TensorNames.vector_observation_placeholder] |
|
|
|
TensorNames.VisualObservationPlaceholderPrefix + str(i) |
|
|
|
TensorNames.visual_observation_placeholder_prefix + str(i) |
|
|
|
TensorNames.ObservationPlaceholderPrefix + str(i) |
|
|
|
TensorNames.observation_placeholder_prefix + str(i) |
|
|
|
TensorNames.ActionMaskPlaceholder, |
|
|
|
TensorNames.RecurrentInPlaceholder, |
|
|
|
TensorNames.action_mask_placeholder, |
|
|
|
TensorNames.recurrent_in_placeholder, |
|
|
|
self.output_names = [TensorNames.VersionNumber, TensorNames.MemorySize] |
|
|
|
self.output_names = [TensorNames.version_number, TensorNames.memory_size] |
|
|
|
TensorNames.ContinuousActionOutput, |
|
|
|
TensorNames.ContinuousActionOutputShape, |
|
|
|
TensorNames.continuous_action_output, |
|
|
|
TensorNames.continuous_action_output_shape, |
|
|
|
self.dynamic_axes.update({TensorNames.ContinuousActionOutput: {0: "batch"}}) |
|
|
|
self.dynamic_axes.update( |
|
|
|
{TensorNames.continuous_action_output: {0: "batch"}} |
|
|
|
) |
|
|
|
TensorNames.DiscreteActionOutput, |
|
|
|
TensorNames.DiscreteActionOutputShape, |
|
|
|
TensorNames.discrete_action_output, |
|
|
|
TensorNames.discrete_action_output_shape, |
|
|
|
self.dynamic_axes.update({TensorNames.DiscreteActionOutput: {0: "batch"}}) |
|
|
|
self.dynamic_axes.update({TensorNames.discrete_action_output: {0: "batch"}}) |
|
|
|
TensorNames.ActionOutputDeprecated, |
|
|
|
TensorNames.IsContinuousControlDeprecated, |
|
|
|
TensorNames.ActionOutputShapeDeprecated, |
|
|
|
TensorNames.action_output_deprecated, |
|
|
|
TensorNames.is_continuous_control_deprecated, |
|
|
|
TensorNames.action_output_shape_deprecated, |
|
|
|
self.dynamic_axes.update({TensorNames.ActionOutputDeprecated: {0: "batch"}}) |
|
|
|
self.dynamic_axes.update( |
|
|
|
{TensorNames.action_output_deprecated: {0: "batch"}} |
|
|
|
) |
|
|
|
self.output_names += [TensorNames.RecurrentOutput] |
|
|
|
self.output_names += [TensorNames.recurrent_output] |
|
|
|
|
|
|
|
def export_policy_model(self, output_filepath: str) -> None: |
|
|
|
""" |
|
|
|