|
|
|
|
|
|
/// Names of the observation input nodes in the graph. Entry i is resolved as
/// graphScope + name and is fed the tensor stored in observationMatrixList[i].
public string[] ObservationPlaceholderName; |
|
|
|
/// Name of the action node. Modify only in the inspector; defaults to "action".
|
|
|
|
public string ActionPlaceholderName = "action"; |
|
|
|
#if ENABLE_TENSORFLOW
|
|
|
|
#if ENABLE_TENSORFLOW
|
|
|
|
/// TensorFlow graph; nodes are looked up by name as graph[graphScope + name][0].
TFGraph graph; |
|
|
|
/// TensorFlow session.
/// NOTE(review): presumably the session the runner is created from; its use is not visible in this section - confirm.
TFSession session; |
|
|
|
/// When true, the recurrent input placeholder is fed and the recurrent output node is fetched.
bool hasRecurrent; |
|
|
|
|
|
|
/// NOTE(review): presumably the batched state input for the graph; not used in the visible section - confirm.
float[,] inputState; |
|
|
|
/// One 4-D float tensor per observation placeholder; element i is fed to the
/// node named by ObservationPlaceholderName[i].
List<float[,,,]> observationMatrixList; |
|
|
|
/// Value fed to the RecurrentInPlaceholderName node when hasRecurrent is set.
float[,] inputOldMemories; |
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
|
|
|
|
/// Reference to the brain that uses this CoreBrainInternal.
/// Its brainParameters.actionSpaceType selects whether the network output is
/// read as float[,] (continuous) or long[,] (discrete), and its gameObject
/// name is reported in placeholder error messages.
|
|
|
|
public Brain brain; |
|
|
|
|
|
|
|
|
|
|
foreach (TensorFlowAgentPlaceholder placeholder in graphPlaceholders) |
|
|
|
{ |
|
|
|
if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.FloatingPoint) |
|
|
|
try |
|
|
|
runner.AddInput(graph[graphScope + placeholder.name][0], new float[] { Random.Range(placeholder.minValue, placeholder.maxValue) }); |
|
|
|
if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.FloatingPoint) |
|
|
|
{ |
|
|
|
runner.AddInput(graph[graphScope + placeholder.name][0], new float[] { Random.Range(placeholder.minValue, placeholder.maxValue) }); |
|
|
|
} |
|
|
|
else if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.Integer) |
|
|
|
{ |
|
|
|
runner.AddInput(graph[graphScope + placeholder.name][0], new int[] { Random.Range((int)placeholder.minValue, (int)placeholder.maxValue + 1) }); |
|
|
|
} |
|
|
|
else if (placeholder.valueType == TensorFlowAgentPlaceholder.tensorType.Integer) |
|
|
|
catch |
|
|
|
runner.AddInput(graph[graphScope + placeholder.name][0], new int[] { Random.Range((int)placeholder.minValue, (int)placeholder.maxValue + 1) }); |
|
|
|
throw new UnityAgentsException(string.Format(@"One of the Tensorflow placeholder cound nout be found.
|
|
|
|
In brain {0}, there are no {1} placeholder named {2}.",
|
|
|
|
brain.gameObject.name, placeholder.valueType.ToString(), graphScope + placeholder.name)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
runner.AddInput(graph[graphScope + ObservationPlaceholderName[obs_number]][0], observationMatrixList[obs_number]); |
|
|
|
} |
|
|
|
|
|
|
|
TFTensor[] networkOutput; |
|
|
|
try |
|
|
|
{ |
|
|
|
networkOutput = runner.Run(); |
|
|
|
} |
|
|
|
catch (TFException e) |
|
|
|
{ |
|
|
|
string errorMessage = e.Message; |
|
|
|
try |
|
|
|
{ |
|
|
|
errorMessage = string.Format(@"The tensorflow graph needs an input for {0} of type {1}", |
|
|
|
e.Message.Split(new string[]{ "Node: " }, 0)[1].Split('=')[0], |
|
|
|
e.Message.Split(new string[]{ "dtype=" }, 0)[1].Split(',')[0]); |
|
|
|
} |
|
|
|
finally |
|
|
|
{ |
|
|
|
throw new UnityAgentsException(errorMessage); |
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
// Create the recurrent tensor
|
|
|
|
if (hasRecurrent) |
|
|
|
|
|
|
runner.AddInput(graph[graphScope + RecurrentInPlaceholderName][0], inputOldMemories); |
|
|
|
runner.Fetch(graph[graphScope + RecurrentOutPlaceholderName][0]); |
|
|
|
float[,] recurrent_tensor = runner.Run()[1].GetValue() as float[,]; |
|
|
|
float[,] recurrent_tensor = networkOutput[1].GetValue() as float[,]; |
|
|
|
|
|
|
|
int i = 0; |
|
|
|
foreach (int k in agentKeys) |
|
|
|
|
|
|
|
|
|
|
if (brain.brainParameters.actionSpaceType == StateType.continuous) |
|
|
|
{ |
|
|
|
float[,] output = runner.Run()[0].GetValue() as float[,]; |
|
|
|
float[,] output = networkOutput[0].GetValue() as float[,]; |
|
|
|
int i = 0; |
|
|
|
foreach (int k in agentKeys) |
|
|
|
{ |
|
|
|
|
|
|
} |
|
|
|
else if (brain.brainParameters.actionSpaceType == StateType.discrete) |
|
|
|
{ |
|
|
|
long[,] output = runner.Run()[0].GetValue() as long[,]; |
|
|
|
long[,] output = networkOutput[0].GetValue() as long[,]; |
|
|
|
int i = 0; |
|
|
|
foreach (int k in agentKeys) |
|
|
|
{ |
|
|
|