浏览代码

Added helper methods for creating observation placeholder names and removed the _h and _c placeholders

/develop/gail-srl-hack
vincentpierre 4 年前
当前提交
3068ae1f
共有 4 个文件被更改,包括 38 次插入20 次删除
  1. 8
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  2. 4
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  3. 20
      com.unity.ml-agents/Runtime/Inference/TensorNames.cs
  4. 26
      ml-agents/mlagents/trainers/torch/model_serialization.py

8
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


if (sensor.GetObservationShape().Length == 3)
{
if (!tensorsNames.Contains(
TensorNames.VisualObservationPlaceholderPrefix + visObsIndex))
TensorNames.GetVisualObservationName(visObsIndex)))
{
failedModelChecks.Add(
"The model does not contain a Visual Observation Placeholder Input " +

if (sensor.GetObservationShape().Length == 2)
{
if (!tensorsNames.Contains(
TensorNames.ObservationPlaceholderPrefix + sensorIndex))
TensorNames.GetObservationName(sensorIndex)))
{
failedModelChecks.Add(
"The model does not contain an Observation Placeholder Input " +

if (sens.GetObservationShape().Length == 3)
{
tensorTester[TensorNames.VisualObservationPlaceholderPrefix + visObsIndex] =
tensorTester[TensorNames.GetVisualObservationName(visObsIndex)] =
tensorTester[TensorNames.ObservationPlaceholderPrefix + sensorIndex] =
tensorTester[TensorNames.GetObservationName(sensorIndex)] =
(bp, tensor, scs, i) => CheckRankTwoObsShape(tensor, sens);
}
}

4
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


// If the tensor is of rank 2, we use the index of the sensor
// to create the name
obsGen = new ObservationGenerator(allocator);
obsGenName = TensorNames.ObservationPlaceholderPrefix + sensorIndex;
obsGenName = TensorNames.GetObservationName(sensorIndex);
obsGenName = TensorNames.VisualObservationPlaceholderPrefix + visIndex;
obsGenName = TensorNames.GetVisualObservationName(visIndex);
visIndex++;
break;
default:

20
com.unity.ml-agents/Runtime/Inference/TensorNames.cs


// String constants holding tensor / model-constant names.
// First group: names ending in "Placeholder" (presumably model inputs — confirm against the loader).
public const string SequenceLengthPlaceholder = "sequence_length";
public const string VectorObservationPlaceholder = "vector_observation";
public const string RecurrentInPlaceholder = "recurrent_in";
// NOTE(review): the two camelCase "_h" / "_c" constants below break the PascalCase
// naming used by every other constant here; the commit description says the
// _h and _c placeholders are removed — confirm these are the deleted side of the diff.
public const string recurrentInPlaceholderH = "recurrent_in_h";
public const string recurrentInPlaceholderC = "recurrent_in_c";
// Prefixes: an integer index is appended to these to form a full name.
public const string VisualObservationPlaceholderPrefix = "visual_observation_";
public const string ObservationPlaceholderPrefix = "obs_";
public const string PreviousActionPlaceholder = "prev_action";

// Second group: output names.
public const string ValueEstimateOutput = "value_estimate";
public const string RecurrentOutput = "recurrent_out";
// NOTE(review): same camelCase "_h" / "_c" pattern as above — confirm whether these are removed.
public const string recurrentOutputH = "recurrent_out_h";
public const string recurrentOutputC = "recurrent_out_c";
// Remaining non-placeholder names (presumably constants stored in the model — TODO confirm).
public const string MemorySize = "memory_size";
public const string VersionNumber = "version_number";
public const string ContinuousActionOutputShape = "continuous_action_output_shape";

// Names carrying the "Deprecated" suffix.
public const string IsContinuousControlDeprecated = "is_continuous_control";
public const string ActionOutputDeprecated = "action";
public const string ActionOutputShapeDeprecated = "action_output_shape";
/// <summary>
/// Builds the name of a visual observation tensor by appending the given
/// index to <see cref="VisualObservationPlaceholderPrefix"/>.
/// </summary>
/// <param name="index">Index appended to the prefix.</param>
/// <returns>The prefix immediately followed by the index.</returns>
public static string GetVisualObservationName(int index) =>
    string.Concat(VisualObservationPlaceholderPrefix, index.ToString());
/// <summary>
/// Builds the name of an observation tensor by appending the given
/// index to <see cref="ObservationPlaceholderPrefix"/>.
/// </summary>
/// <param name="index">Index appended to the prefix.</param>
/// <returns>The prefix immediately followed by the index.</returns>
public static string GetObservationName(int index) =>
    string.Concat(ObservationPlaceholderPrefix, index.ToString());
}
}

26
ml-agents/mlagents/trainers/torch/model_serialization.py


# String constants holding tensor / model-constant names.
# First group: names ending in "placeholder" (presumably model inputs; confirm against the exporter).
sequence_length_placeholder = "sequence_length"
vector_observation_placeholder = "vector_observation"
recurrent_in_placeholder = "recurrent_in"
# NOTE(review): the commit description says the _h and _c placeholders are removed;
# confirm these two are the deleted side of the diff.
recurrent_in_placeholder_h = "recurrent_in_h"
recurrent_in_placeholder_c = "recurrent_in_c"
# Prefixes: str(index) is appended to these to form a full name.
visual_observation_placeholder_prefix = "visual_observation_"
observation_placeholder_prefix = "obs_"
previous_action_placeholder = "prev_action"

# Second group: output names.
value_estimate_output = "value_estimate"
recurrent_output = "recurrent_out"
# NOTE(review): same "_h" / "_c" pattern as above; confirm whether these are removed.
recurrent_output_h = "recurrent_out_h"
recurrent_output_c = "recurrent_out_c"
# Remaining non-placeholder names (presumably constants stored in the model; TODO confirm).
memory_size = "memory_size"
version_number = "version_number"
continuous_action_output_shape = "continuous_action_output_shape"

# Names carrying the "deprecated" suffix.
is_continuous_control_deprecated = "is_continuous_control"
action_output_deprecated = "action"
action_output_shape_deprecated = "action_output_shape"
@staticmethod
def get_visual_observation_name(index: int) -> str:
    """
    Build the name of a visual observation.

    :param index: Index appended to the visual observation prefix.
    :return: ``TensorNames.visual_observation_placeholder_prefix`` followed
        by ``str(index)``.
    """
    prefix = TensorNames.visual_observation_placeholder_prefix
    return prefix + str(index)
@staticmethod
def get_observation_name(index: int) -> str:
    """
    Build the name of an observation.

    :param index: Index appended to the observation prefix.
    :return: ``TensorNames.observation_placeholder_prefix`` followed by
        ``str(index)``.
    """
    prefix = TensorNames.observation_placeholder_prefix
    return prefix + str(index)
class ModelSerializer:

self.input_names = [TensorNames.vector_observation_placeholder]
for i in range(num_vis_obs):
self.input_names.append(
TensorNames.visual_observation_placeholder_prefix + str(i)
)
self.input_names.append(TensorNames.get_visual_observation_name(i))
self.input_names.append(
TensorNames.observation_placeholder_prefix + str(i)
)
self.input_names.append(TensorNames.get_observation_name(i))
self.input_names += [
TensorNames.action_mask_placeholder,
TensorNames.recurrent_in_placeholder,

正在加载...
取消
保存