浏览代码

switching from CamelCase to snake_case

/develop/gail-srl-hack
vincentpierre 4 年前
当前提交
5b5171f8
共有 1 个文件被更改,包括 45 次插入41 次删除
  1. 86
      ml-agents/mlagents/trainers/torch/model_serialization.py

86
ml-agents/mlagents/trainers/torch/model_serialization.py


class TensorNames:
BatchSizePlaceholder = "batch_size"
SequenceLengthPlaceholder = "sequence_length"
VectorObservationPlaceholder = "vector_observation"
RecurrentInPlaceholder = "recurrent_in"
recurrentInPlaceholderH = "recurrent_in_h"
recurrentInPlaceholderC = "recurrent_in_c"
VisualObservationPlaceholderPrefix = "visual_observation_"
ObservationPlaceholderPrefix = "obs_"
PreviousActionPlaceholder = "prev_action"
ActionMaskPlaceholder = "action_masks"
RandomNormalEpsilonPlaceholder = "epsilon"
batch_size_placeholder = "batch_size"
sequence_length_placeholder = "sequence_length"
vector_observation_placeholder = "vector_observation"
recurrent_in_placeholder = "recurrent_in"
recurrent_in_placeholder_h = "recurrent_in_h"
recurrent_in_placeholder_c = "recurrent_in_c"
visual_observation_placeholder_prefix = "visual_observation_"
observation_placeholder_prefix = "obs_"
previous_action_placeholder = "prev_action"
action_mask_placeholder = "action_masks"
random_normal_epsilon_placeholder = "epsilon"
ValueEstimateOutput = "value_estimate"
RecurrentOutput = "recurrent_out"
recurrentOutputH = "recurrent_out_h"
recurrentOutputC = "recurrent_out_c"
MemorySize = "memory_size"
VersionNumber = "version_number"
ContinuousActionOutputShape = "continuous_action_output_shape"
DiscreteActionOutputShape = "discrete_action_output_shape"
ContinuousActionOutput = "continuous_actions"
DiscreteActionOutput = "discrete_actions"
value_estimate_output = "value_estimate"
recurrent_output = "recurrent_out"
recurrent_output_h = "recurrent_out_h"
recurrent_output_c = "recurrent_out_c"
memory_size = "memory_size"
version_number = "version_number"
continuous_action_output_shape = "continuous_action_output_shape"
discrete_action_output_shape = "discrete_action_output_shape"
continuous_action_output = "continuous_actions"
discrete_action_output = "discrete_actions"
IsContinuousControlDeprecated = "is_continuous_control"
ActionOutputDeprecated = "action"
ActionOutputShapeDeprecated = "action_output_shape"
is_continuous_control_deprecated = "is_continuous_control"
action_output_deprecated = "action"
action_output_shape_deprecated = "action_output_shape"
class ModelSerializer:

dummy_memories,
)
self.input_names = [TensorNames.VectorObservationPlaceholder]
self.input_names = [TensorNames.vector_observation_placeholder]
TensorNames.VisualObservationPlaceholderPrefix + str(i)
TensorNames.visual_observation_placeholder_prefix + str(i)
TensorNames.ObservationPlaceholderPrefix + str(i)
TensorNames.observation_placeholder_prefix + str(i)
TensorNames.ActionMaskPlaceholder,
TensorNames.RecurrentInPlaceholder,
TensorNames.action_mask_placeholder,
TensorNames.recurrent_in_placeholder,
self.output_names = [TensorNames.VersionNumber, TensorNames.MemorySize]
self.output_names = [TensorNames.version_number, TensorNames.memory_size]
TensorNames.ContinuousActionOutput,
TensorNames.ContinuousActionOutputShape,
TensorNames.continuous_action_output,
TensorNames.continuous_action_output_shape,
self.dynamic_axes.update({TensorNames.ContinuousActionOutput: {0: "batch"}})
self.dynamic_axes.update(
{TensorNames.continuous_action_output: {0: "batch"}}
)
TensorNames.DiscreteActionOutput,
TensorNames.DiscreteActionOutputShape,
TensorNames.discrete_action_output,
TensorNames.discrete_action_output_shape,
self.dynamic_axes.update({TensorNames.DiscreteActionOutput: {0: "batch"}})
self.dynamic_axes.update({TensorNames.discrete_action_output: {0: "batch"}})
TensorNames.ActionOutputDeprecated,
TensorNames.IsContinuousControlDeprecated,
TensorNames.ActionOutputShapeDeprecated,
TensorNames.action_output_deprecated,
TensorNames.is_continuous_control_deprecated,
TensorNames.action_output_shape_deprecated,
self.dynamic_axes.update({TensorNames.ActionOutputDeprecated: {0: "batch"}})
self.dynamic_axes.update(
{TensorNames.action_output_deprecated: {0: "batch"}}
)
self.output_names += [TensorNames.RecurrentOutput]
self.output_names += [TensorNames.recurrent_output]
def export_policy_model(self, output_filepath: str) -> None:
"""

正在加载...
取消
保存