浏览代码

BASIC WORKS

/ai-hw-2021
vincentpierre 4 年前
当前提交
7cece532
共有 7 个文件被更改，包括 59 次插入和 13 次删除
  1. 3
      Project/Assets/ML-Agents/Examples/Basic/Scenes/Basic.unity
  2. 8
      com.unity.ml-agents/Runtime/Agent.cs
  3. 11
      com.unity.ml-agents/Runtime/Inference/TensorProxy.cs
  4. 4
      com.unity.ml-agents/Runtime/Training/MyTimeScaleSetting.cs
  5. 12
      com.unity.ml-agents/Runtime/Training/Trainer.cs
  6. 2
      com.unity.ml-agents/Runtime/Training/TrainingForwardTensorApplier.cs
  7. 32
      com.unity.ml-agents/Runtime/Training/TrainingModelRunner.cs

3
Project/Assets/ML-Agents/Examples/Basic/Scenes/Basic.unity


m_Name:
m_EditorClassIdentifier:
m_TimeScale: 1
m_Greedy: 1
m_Greedy: 0.5
m_Train: 1

8
com.unity.ml-agents/Runtime/Agent.cs


return;
}
if (m_Info.done)
{
m_Info.ClearActions();
}
// if (m_Info.done)
// {
// m_Info.ClearActions();
// }
else
{
m_Info.CopyActions(m_ActuatorManager.StoredActions);

11
com.unity.ml-agents/Runtime/Inference/TensorProxy.cs


public static void CopyTensor(TensorProxy source, TensorProxy target)
{
if (source.data.batch != target.data.batch ||
source.data.height != target.data.height ||
source.data.width != target.data.width ||
source.data.channels != target.data.channels)
{
UnityEngine.Debug.Log("Error");
}
for (var b = 0; b < source.data.batch; b++)
{
for (var i = 0; i < source.data.height; i++)

for(var k = 0; k < source.data.channels; k++)
for (var k = 0; k < source.data.channels; k++)
{
target.data[b, i, j, k] = source.data[b, i, j, k];
}

name = tensor.name,
valueType = tensor.valueType,
data = tensor.data.DeepCopy(),
shape = (long[]) tensor.shape.Clone()
shape = (long[])tensor.shape.Clone()
};
}
}

4
com.unity.ml-agents/Runtime/Training/MyTimeScaleSetting.cs


// Unity lifecycle hook: keeps this GameObject alive across scene loads and
// removes it again when more than one MyTimeScaleSetting is present.
void Start()
{
    DontDestroyOnLoad(gameObject);

    // The lookup result includes this component itself, so a length above one
    // means at least one other MyTimeScaleSetting was found.
    var settingsInstances = FindObjectsOfType<MyTimeScaleSetting>();
    var hasDuplicate = settingsInstances.Length > 1;
    if (hasDuplicate)
    {
        Destroy(gameObject);
    }
}
// Update is called once per frame

12
com.unity.ml-agents/Runtime/Training/Trainer.cs


public int batchSize = 100;
public float gamma = 0.9f;
public float learningRate = 0.0001f;
public int updatePeriod = 10;
public int updatePeriod = 500;
public int numSamplingAndUpdates = 50;
// public int updateTargetFreq = 200;
}

return;
}
var samples = m_Buffer.SampleBatch(m_Config.batchSize);
m_ModelRunner.UpdateModel(samples);
// UnityEngine.Debug.Log("Update");
for (int i = 0; i < m_Config.numSamplingAndUpdates; i++)
{
var samples = m_Buffer.SampleBatch(m_Config.batchSize);
m_ModelRunner.UpdateModel(samples);
}
UnityEngine.Debug.Log("Update !");
// Update target network
// if (m_TrainingStep % m_Config.updateTargetFreq == 0)

2
com.unity.ml-agents/Runtime/Training/TrainingForwardTensorApplier.cs


var discreteBuffer = actionBuffer.DiscreteActions;
var maxIndex = 0;
var maxValue = float.MinValue;
// UnityEngine.Debug.Log(" ");
// UnityEngine.Debug.Log(value);
if (value > maxValue)
{
maxIndex = j;

32
com.unity.ml-agents/Runtime/Training/TrainingModelRunner.cs


actionSpec, seed, m_TensorAllocator, barracudaModel);
m_InputsByName = new Dictionary<string, Tensor>();
m_TrainingOutputs = new List<TensorProxy>();
m_TrainingOutputNames = new string[] {TensorNames.TrainingStateOut, TensorNames.OuputLoss};
m_InferenceOutputNames = new string[] {TensorNames.TrainingOutput};
m_TrainingOutputNames = new string[] { TensorNames.TrainingStateOut, TensorNames.OuputLoss };
m_InferenceOutputNames = new string[] { TensorNames.TrainingOutput };
m_Buffer = buffer;
InitializeTrainingState();
}

// UnityEngine.Debug.Log(m_TrainingOutputs[0].data[i]);
// }
// throw new System.Exception("STOP");
// UnityEngine.Debug.Log(m_TrainingState.data[m_TrainingState.data.length - 1] );
// m_TrainingState = m_TrainingOutputs[0];
// for (int i = 0; i < transitions.Count; i++){
// string message = "";
// for (int j = 0; j < transitions[i].state[0].data.length; j ++){
// if( transitions[i].state[0].data[j] > 0.5f){
// message += j;
// }
// }
// message += " | ";
// for (int j = 0; j < transitions[i].nextState[0].data.length; j ++){
// if( transitions[i].nextState[0].data[j] > 0.5f){
// message += j;
// }
// }
// message += " | ";
// message += transitions[i].action.DiscreteActions[0];
// message += " | ";
// message += transitions[i].reward;
// message += " | ";
// message += transitions[i].done;
// UnityEngine.Debug.Log(message);
// }
return m_TrainingOutputs[1].data[0];
}

正在加载...
取消
保存