浏览代码

Merge pull request #5018 from Unity-Technologies/develop-use-correct-names-for-recurrent-inputs-and-outputs

Modified the model_serialization to have correct inputs and outputs
/develop/gail-srl-hack
GitHub 4 年前
当前提交
c3c34267
共有 5 个文件被更改,包括 91 次插入24 次删除
  1. 8
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  2. 4
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  3. 20
      com.unity.ml-agents/Runtime/Inference/TensorNames.cs
  4. 81
      ml-agents/mlagents/trainers/torch/model_serialization.py
  5. 2
      ml-agents/mlagents/trainers/torch/networks.py

8
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


if (sensor.GetObservationShape().Length == 3)
{
if (!tensorsNames.Contains(
TensorNames.VisualObservationPlaceholderPrefix + visObsIndex))
TensorNames.GetVisualObservationName(visObsIndex)))
{
failedModelChecks.Add(
"The model does not contain a Visual Observation Placeholder Input " +

if (sensor.GetObservationShape().Length == 2)
{
if (!tensorsNames.Contains(
TensorNames.ObservationPlaceholderPrefix + sensorIndex))
TensorNames.GetObservationName(sensorIndex)))
{
failedModelChecks.Add(
"The model does not contain an Observation Placeholder Input " +

if (sens.GetObservationShape().Length == 3)
{
tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] =
tensorTester[TensorNames.GetVisualObservationName(visObsIndex)] =
tensorTester[TensorNames.ObservationPlaceholderPrefix + sensorIndex] =
tensorTester[TensorNames.GetObservationName(sensorIndex)] =
(bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens);
}
}

4
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


// If the tensor is of rank 2, we use the index of the sensor
// to create the name
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.ObservationPlaceholderPrefix + sensorIndex;
obsGenName = TensorNames.GetObservationName(sensorIndex);
obsGenName = TensorNames.VisualObservationPlaceholderPrefix + visIndex;
obsGenName = TensorNames.GetVisualObservationName(visIndex);
visIndex++;
break;
default:

20
com.unity.ml-agents/Runtime/Inference/TensorNames.cs


public const string SequenceLengthPlaceholder = "sequence_length";
public const string VectorObservationPlaceholder = "vector_observation";
public const string RecurrentInPlaceholder = "recurrent_in";
public const string recurrentInPlaceholderH = "recurrent_in_h";
public const string recurrentInPlaceholderC = "recurrent_in_c";
public const string VisualObservationPlaceholderPrefix = "visual_observation_";
public const string ObservationPlaceholderPrefix = "obs_";
public const string PreviousActionPlaceholder = "prev_action";

public const string ValueEstimateOutput = "value_estimate";
public const string RecurrentOutput = "recurrent_out";
public const string recurrentOutputH = "recurrent_out_h";
public const string recurrentOutputC = "recurrent_out_c";
public const string MemorySize = "memory_size";
public const string VersionNumber = "version_number";
public const string ContinuousActionOutputShape = "continuous_action_output_shape";

public const string IsContinuousControlDeprecated = "is_continuous_control";
public const string ActionOutputDeprecated = "action";
public const string ActionOutputShapeDeprecated = "action_output_shape";
/// <summary>
/// Returns the name of the visual observation with a given index
/// </summary>
public static string GetVisualObservationName(int index)
{
return VisualObservationPlaceholderPrefix + index;
}
/// <summary>
/// Returns the name of the observation with a given index
/// </summary>
public static string GetObservationName(int index)
{
return ObservationPlaceholderPrefix + index;
}
}
}

81
ml-agents/mlagents/trainers/torch/model_serialization.py


return exporting_to_onnx._local_data._is_exporting
class TensorNames:
batch_size_placeholder = "batch_size"
sequence_length_placeholder = "sequence_length"
vector_observation_placeholder = "vector_observation"
recurrent_in_placeholder = "recurrent_in"
visual_observation_placeholder_prefix = "visual_observation_"
observation_placeholder_prefix = "obs_"
previous_action_placeholder = "prev_action"
action_mask_placeholder = "action_masks"
random_normal_epsilon_placeholder = "epsilon"
value_estimate_output = "value_estimate"
recurrent_output = "recurrent_out"
memory_size = "memory_size"
version_number = "version_number"
continuous_action_output_shape = "continuous_action_output_shape"
discrete_action_output_shape = "discrete_action_output_shape"
continuous_action_output = "continuous_actions"
discrete_action_output = "discrete_actions"
# Deprecated TensorNames entries for backward compatibility
is_continuous_control_deprecated = "is_continuous_control"
action_output_deprecated = "action"
action_output_shape_deprecated = "action_output_shape"
@staticmethod
def get_visual_observation_name(index: int) -> str:
"""
Returns the name of the visual observation with a given index
"""
return TensorNames.visual_observation_placeholder_prefix + str(index)
@staticmethod
def get_observation_name(index: int) -> str:
"""
Returns the name of the observation with a given index
"""
return TensorNames.observation_placeholder_prefix + str(index)
class ModelSerializer:
def __init__(self, policy):
# ONNX only support input in NCHW (channel first) format.

dummy_memories,
)
self.input_names = ["vector_observation"]
self.input_names = [TensorNames.vector_observation_placeholder]
self.input_names.append(f"visual_observation_{i}")
self.input_names.append(TensorNames.get_visual_observation_name(i))
self.input_names.append(f"obs_{i}")
self.input_names += ["action_masks", "memories"]
self.input_names.append(TensorNames.get_observation_name(i))
self.input_names += [
TensorNames.action_mask_placeholder,
TensorNames.recurrent_in_placeholder,
]
self.output_names = ["version_number", "memory_size"]
self.output_names = [TensorNames.version_number, TensorNames.memory_size]
"continuous_actions",
"continuous_action_output_shape",
TensorNames.continuous_action_output,
TensorNames.continuous_action_output_shape,
self.dynamic_axes.update({"continuous_actions": {0: "batch"}})
self.dynamic_axes.update(
{TensorNames.continuous_action_output: {0: "batch"}}
)
self.output_names += ["discrete_actions", "discrete_action_output_shape"]
self.dynamic_axes.update({"discrete_actions": {0: "batch"}})
self.output_names += [
TensorNames.discrete_action_output,
TensorNames.discrete_action_output_shape,
]
self.dynamic_axes.update({TensorNames.discrete_action_output: {0: "batch"}})
"action",
"is_continuous_control",
"action_output_shape",
TensorNames.action_output_deprecated,
TensorNames.is_continuous_control_deprecated,
TensorNames.action_output_shape_deprecated,
self.dynamic_axes.update({"action": {0: "batch"}})
self.dynamic_axes.update(
{TensorNames.action_output_deprecated: {0: "batch"}}
)
if self.policy.export_memory_size > 0:
self.output_names += [TensorNames.recurrent_output]
def export_policy_model(self, output_filepath: str) -> None:
"""

2
ml-agents/mlagents/trainers/torch/networks.py


self.is_continuous_int_deprecated,
self.act_size_vector_deprecated,
]
if self.network_body.memory_size > 0:
export_out += [memories_out]
return tuple(export_out)

正在加载...
取消
保存