浏览代码

get output loss

/ai-hw-2021
Ruo-Ping Dong 3 年前
当前提交
f62c6971
共有 1 个文件被更改,包括 10 次插入5 次删除
  1. 15
      com.unity.ml-agents/Runtime/Training/TrainingModelRunner.cs

15
com.unity.ml-agents/Runtime/Training/TrainingModelRunner.cs


// NOTE(review): presumably toggles extra diagnostic output; it is not read in the visible code — confirm.
bool m_Verbose = false;
// Input tensors handed to the training tensor generator when a batch of transitions
// is prepared for a model update.
IReadOnlyList<TensorProxy> m_TrainingInputs;
// NOTE(review): presumably the input tensors for the inference (decision) pass;
// its use is not visible here — confirm.
IReadOnlyList<TensorProxy> m_InferenceInputs;
// Output names requested after a training step. Assigned during initialisation to
// { TensorNames.TrainingStateOut, TensorNames.OuputLoss }, in that order.
string[] m_TrainingOutputNames;
// Output names requested after an inference step. Assigned during initialisation to
// { TensorNames.TrainingOutput }.
string[] m_InferenceOutputNames;
// Output tensors read back after the engine executes (presumably filled by
// FetchBarracudaOutputs — confirm). UpdateModel copies element 0 into the training
// state and returns element 1's first value, which matches the order of
// m_TrainingOutputNames (state first, loss second).
List<TensorProxy> m_TrainingOutputs;
// Named input tensors passed to m_Engine.Execute.
Dictionary<string, Tensor> m_InputsByName;
// NOTE(review): int-keyed lists of floats, presumably per-agent recurrent memories;
// its use is not visible here — confirm.
Dictionary<int, List<float>> m_Memories = new Dictionary<int, List<float>>();

actionSpec, seed, m_TensorAllocator, barracudaModel);
m_InputsByName = new Dictionary<string, Tensor>();
m_TrainingOutputs = new List<TensorProxy>();
m_TrainingOutputNames = new string[] {TensorNames.TrainingStateOut, TensorNames.OuputLoss};
m_InferenceOutputNames = new string[] {TensorNames.TrainingOutput};
m_Buffer = buffer;
InitializeTrainingState();
}

{
name = TensorNames.InitialTrainingState,
valueType = TensorProxy.TensorType.FloatingPoint,
data = initState,
data = initState.DeepCopy(),
shape = initState.shape.ToArray().Select(i => (long)i).ToArray()
};
}

// Execute the Model
m_Engine.Execute(m_InputsByName);
FetchBarracudaOutputs(new string[] { TensorNames.TrainingOutput });
FetchBarracudaOutputs(m_InferenceOutputNames);
// Update the outputs
m_TensorApplier.ApplyTensors(m_TrainingOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived);

m_OrderedAgentsRequestingDecisions.Clear();
}
public void UpdateModel(List<Transition> transitions)
public float UpdateModel(List<Transition> transitions)
return;
return 0;
}
m_TrainingTensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, transitions, m_TrainingState, true);

m_Engine.Execute(m_InputsByName);
// Update the model
FetchBarracudaOutputs(new string[] { TensorNames.TrainingStateOut });
FetchBarracudaOutputs(m_TrainingOutputNames);
TensorUtils.CopyTensor(m_TrainingOutputs[0], m_TrainingState);
// UnityEngine.Debug.Log(m_TrainingState.data[0]);

// }
// throw new System.Exception("STOP");
return m_TrainingOutputs[1].data[0];
}
public ActionBuffers GetAction(int agentId)

正在加载...
取消
保存