public enum tensorType
{
    Integer,
    FloatingPoint
};
/// Modify only in inspector : Names of the observation placeholders
public string[] ObservationPlaceholderName;

/// Modify only in inspector : Name of the action node
public string ActionPlaceholderName = "action";
#if ENABLE_TENSORFLOW
bool hasValue;
#endif
/// Reference to the brain that uses this CoreBrainInternal
public Brain brain;
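// Create a TensorFlowSharp session for the loaded graph.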
session = new TFSession(graph);
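// Make sure the graph scope ends with '/' so node names can be appended to it directly.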
if ((graphScope.Length > 1) && (graphScope[graphScope.Length - 1] != '/'))
{
    graphScope = graphScope + '/';
}

// Assumed condition (not in this excerpt): hasState is set only when a state placeholder
// node exists under the scope, mirroring the value_estimate check below;
// StatePlaceholderName is an assumed field name.
if (graph[graphScope + StatePlaceholderName] != null)
{
    hasState = true;
}
if (graph[graphScope + "value_estimate"] != null)
{
    hasValue = true;
}
}
#endif
}
// Assumed guard (not shown in this excerpt): only forward brain info when an external coordinator is attached.
if (coord != null)
{
    coord.giveBrainInfo(brain);
}
#endif
}
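// Fetching the action node fails when graphScope is wrong; surface a clearer error in that case.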
catch
{
    throw new UnityAgentsException(string.Format(@"The node {0} could not be found. Please make sure the graphScope {1} is correct",
        graphScope + ActionPlaceholderName, graphScope));
}
if (hasBatchSize)
{
    // Assumed body (elided in this excerpt): feed the current batch size into its placeholder.
    // BatchSizePlaceholderName and currentBatchSize are assumed names, not shown above.
    runner.AddInput(graph[graphScope + BatchSizePlaceholderName][0], new int[] { currentBatchSize });
}
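// Recurrent models also need a sequence_length input; it is set to 1 here.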
if (hasRecurrent)
{
    runner.AddInput(graph[graphScope + "sequence_length"][0], 1);
}
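// Fetch the value estimate as well when the graph provides one.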
if (hasValue)
{
    runner.Fetch(graph[graphScope + "value_estimate"][0]);
}
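// Run the graph and translate TensorFlow input errors into a clearer exception.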
TFTensor[] networkOutput;
try
{
    // Assumed from context: execute the fetches queued on the session runner.
    networkOutput = runner.Run();
}
catch (TFException e)
{
    // Assumed from context: start from the raw TensorFlow message and try to
    // extract the offending placeholder name and dtype below.
    string errorMessage = e.Message;
    try
    {
        errorMessage = string.Format(@"The tensorflow graph needs an input for {0} of type {1}",
            e.Message.Split(new string[] { "Node: " }, StringSplitOptions.None)[1].Split('=')[0],
            e.Message.Split(new string[] { "dtype=" }, StringSplitOptions.None)[1].Split(',')[0]);
    }
    finally
    {
        // Assumed from context: always rethrow, using the reformatted message when parsing succeeded.
        throw new UnityAgentsException(errorMessage);
    }
}
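// Send the decided actions back to the brain.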
brain.SendActions(actions);
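// When a value estimate was fetched, hand the per-agent values back to the brain as well.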
if (hasValue)
{
    var values = new Dictionary<int, float>();
    float[,] value_tensor;
    // The value estimate's index in the fetched outputs depends on the model:
    // 2 when hasRecurrent, 1 otherwise.
    if (hasRecurrent)
    {
        value_tensor = networkOutput[2].GetValue() as float[,];
    }
    else
    {
        value_tensor = networkOutput[1].GetValue() as float[,];
    }
    // One row per agent, in the same order as agentKeys.
    var i = 0;
    foreach (int k in agentKeys)
    {
        var v = (float)(value_tensor[i, 0]);
        values.Add(k, v);
        i++;
    }
    brain.SendValues(values);
}
#endif
}