|
|
|
|
|
|
bool hasState; |
|
|
|
bool hasBatchSize; |
|
|
|
bool hasPrevAction; |
|
|
|
bool hasValueEstimate; |
|
|
|
float[,] inputState; |
|
|
|
int[] inputPrevAction; |
|
|
|
List<float[,,,]> observationMatrixList; |
|
|
|
|
|
|
if (graph[graphScope + PreviousActionPlaceholderName] != null) |
|
|
|
{ |
|
|
|
hasPrevAction = true; |
|
|
|
} |
|
|
|
if (graph[graphScope + "value_estimate"] != null) |
|
|
|
{ |
|
|
|
hasValueEstimate = true; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
runner.Fetch(graph[graphScope + RecurrentOutPlaceholderName][0]); |
|
|
|
} |
|
|
|
|
|
|
|
if (hasValueEstimate) |
|
|
|
{ |
|
|
|
runner.Fetch(graph[graphScope + "value_estimate"][0]); |
|
|
|
} |
|
|
|
|
|
|
|
TFTensor[] networkOutput; |
|
|
|
try |
|
|
|
{ |
|
|
|
|
|
|
|
|
|
|
agent.UpdateMemoriesAction(m.ToList()); |
|
|
|
i++; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (hasValueEstimate) |
|
|
|
{ |
|
|
|
float[,] value_estimates = new float[currentBatchSize,1]; |
|
|
|
if (hasRecurrent) |
|
|
|
{ |
|
|
|
value_estimates = networkOutput[2].GetValue() as float[,]; |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
value_estimates = networkOutput[1].GetValue() as float[,]; |
|
|
|
} |
|
|
|
|
|
|
|
var i = 0; |
|
|
|
foreach (Agent agent in agentList) |
|
|
|
{ |
|
|
|
agent.UpdateValueAction(value_estimates[i,0]); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|