浏览代码

The internal Brain can now effectively modify the value field of the agents (#275)

* Requires the training to have been performed with PPO
* The name of the tensor must be value_estimate
/develop-generalizationTraining-TrainerController
Arthur Juliani 7 年前
当前提交
f2d30f07
共有 1 个文件被更改,包括 44 次插入、10 次删除
  1. 54
      unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs

54
unity-environment/Assets/ML-Agents/Scripts/CoreBrainInternal.cs


/// <summary>
/// Indicates whether a graph placeholder tensor carries integer or
/// floating-point data. (Name kept lowercase for compatibility with
/// existing serialized inspector values and callers.)
/// </summary>
public enum tensorType
{
    Integer,
    FloatingPoint
}

/// Modify only in inspector : Names of the observation placeholder tensors in the graph
public string[] ObservationPlaceholderName;
/// Modify only in inspector : Name of the action node
public string ActionPlaceholderName = "action";
// NOTE(review): the duplicated #if/#endif pair below looks like diff-paste
// residue — confirm against the original file before relying on it.
#if ENABLE_TENSORFLOW
#if ENABLE_TENSORFLOW
// Set when the loaded graph contains a node named "value_estimate"
// (checked at initialization); when true, the value estimate is fetched
// from the network output and broadcast to the agents via SendValues.
bool hasValue;
#endif
#endif
/// Reference to the brain that uses this CoreBrainInternal
public Brain brain;

session = new TFSession(graph);
if ((graphScope.Length > 1) && (graphScope[graphScope.Length - 1] != '/')){
if ((graphScope.Length > 1) && (graphScope[graphScope.Length - 1] != '/'))
{
graphScope = graphScope + '/';
}

{
hasState = true;
}
if (graph[graphScope + "value_estimate"] != null)
{
hasValue = true;
}
}
#endif
}

{
coord.giveBrainInfo(brain);
}
#endif
#endif
}

catch
{
throw new UnityAgentsException(string.Format(@"The node {0} could not be found. Please make sure the graphScope {1} is correct",
graphScope + ActionPlaceholderName, graphScope));
graphScope + ActionPlaceholderName, graphScope));
}
if (hasBatchSize)

if (hasRecurrent)
{
runner.AddInput(graph[graphScope + "sequence_length"][0], 1 );
runner.AddInput(graph[graphScope + "sequence_length"][0], 1);
if (hasValue)
{
runner.Fetch(graph[graphScope + "value_estimate"][0]);
}
TFTensor[] networkOutput;
try
{

try
{
errorMessage = string.Format(@"The tensorflow graph needs an input for {0} of type {1}",
e.Message.Split(new string[]{ "Node: " }, 0)[1].Split('=')[0],
e.Message.Split(new string[]{ "dtype=" }, 0)[1].Split(',')[0]);
e.Message.Split(new string[] { "Node: " }, 0)[1].Split('=')[0],
e.Message.Split(new string[] { "dtype=" }, 0)[1].Split(',')[0]);
}
finally
{

}
brain.SendActions(actions);
if (hasValue)
{
var values = new Dictionary<int, float>();
float[,] value_tensor;
if (hasRecurrent)
{
value_tensor = networkOutput[2].GetValue() as float[,];
}
else
{
value_tensor = networkOutput[1].GetValue() as float[,];
}
var i = 0;
foreach (int k in agentKeys)
{
var v = (float)(value_tensor[i, 0]);
values.Add(k, v);
i++;
}
brain.SendValues(values);
}
#endif
}

正在加载...
取消
保存