浏览代码

Added a clear error message when a placeholder is missing or is not present in the graph

/develop-generalizationTraining-TrainerController
vincentpierre 7 年前
当前提交
54d85928
共有 1 个文件被更改,包括 38 次插入和 9 次删除
  1. 47
      unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs

47
unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs


public string[] ObservationPlaceholderName;
/// Modify only in inspector : Name of the action node
public string ActionPlaceholderName = "action";
#if ENABLE_TENSORFLOW
#if ENABLE_TENSORFLOW
TFGraph graph;
TFSession session;
bool hasRecurrent;

float[,] inputState;
List<float[,,,]> observationMatrixList;
float[,] inputOldMemories;
#endif
#endif
/// Reference to the brain that uses this CoreBrainInternal
public Brain brain;

foreach (TensorFlowAgentPlaceholder placeholder in graphPlaceholders)
{
if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.FloatingPoint)
try
runner.AddInput(graph[graphScope + placeholder.name][0], new float[] { Random.Range(placeholder.minValue, placeholder.maxValue) });
if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.FloatingPoint)
{
runner.AddInput(graph[graphScope + placeholder.name][0], new float[] { Random.Range(placeholder.minValue, placeholder.maxValue) });
}
else if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.Integer)
{
runner.AddInput(graph[graphScope + placeholder.name][0], new int[] { Random.Range((int)placeholder.minValue, (int)placeholder.maxValue + 1) });
}
else if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.Integer)
catch
runner.AddInput(graph[graphScope + placeholder.name][0], new int[] { Random.Range((int)placeholder.minValue, (int)placeholder.maxValue + 1) });
throw new UnityAgentsException(string.Format(@"One of the Tensorflow placeholder cound nout be found.
In brain {0}, there are no {1} placeholder named {2}.",
brain.gameObject.name, placeholder.valueType.ToString(), graphScope + placeholder.name));
}
}

runner.AddInput(graph[graphScope + ObservationPlaceholderName[obs_number]][0], observationMatrixList[obs_number]);
}
TFTensor[] runned;
try
{
runned = runner.Run();
}
catch (TFException e)
{
string errorMessage = e.Message;
try
{
errorMessage = string.Format(@"The tensorflow graph needs an input for {0} of type {1}",
e.Message.Split(new string[]{ "Node: " }, 0)[1].Split('=')[0],
e.Message.Split(new string[]{ "dtype=" }, 0)[1].Split(',')[0]);
}
finally
{
throw new UnityAgentsException(errorMessage);
}
}
// Create the recurrent tensor
if (hasRecurrent)

runner.AddInput(graph[graphScope + RecurrentInPlaceholderName][0], inputOldMemories);
runner.Fetch(graph[graphScope + RecurrentOutPlaceholderName][0]);
float[,] recurrent_tensor = runner.Run()[1].GetValue() as float[,];
float[,] recurrent_tensor = runned[1].GetValue() as float[,];
int i = 0;
foreach (int k in agentKeys)

if (brain.brainParameters.actionSpaceType == StateType.continuous)
{
float[,] output = runner.Run()[0].GetValue() as float[,];
float[,] output = runned[0].GetValue() as float[,];
int i = 0;
foreach (int k in agentKeys)
{

}
else if (brain.brainParameters.actionSpaceType == StateType.discrete)
{
long[,] output = runner.Run()[0].GetValue() as long[,];
long[,] output = runned[0].GetValue() as long[,];
int i = 0;
foreach (int k in agentKeys)
{

正在加载...
取消
保存