浏览代码

save model

/ai-hw-2021
Ruo-Ping Dong 3 年前
当前提交
48b2bb77
共有 3 个文件被更改,包括 44 次插入4 次删除
  1. 8
      com.unity.ml-agents/Runtime/Training/MyTimeScaleSetting.cs
  2. 8
      com.unity.ml-agents/Runtime/Training/Trainer.cs
  3. 32
      com.unity.ml-agents/Runtime/Training/TrainingModelRunner.cs

8
com.unity.ml-agents/Runtime/Training/MyTimeScaleSetting.cs


set { m_Train = value; }
}
[SerializeField]
bool m_LoadFile = true;
public bool LoadFile
{
get { return m_LoadFile; }
set { m_LoadFile = value; }
}
// Start is called before the first frame update
void Start()
{

8
com.unity.ml-agents/Runtime/Training/Trainer.cs


return;
}
for (int i = 0; i < m_Config.numSamplingAndUpdates; i++)
float loss = 0f;
for (var i = 0; i < m_Config.numSamplingAndUpdates; i++)
m_ModelRunner.UpdateModel(samples);
loss += m_ModelRunner.UpdateModel(samples);
UnityEngine.Debug.Log("Update !");
Debug.Log($"Loss: {loss/m_Config.numSamplingAndUpdates}");
m_ModelRunner.SaveModelToFile();
// Update target network
// if (m_TrainingStep % m_Config.updateTargetFreq == 0)

32
com.unity.ml-agents/Runtime/Training/TrainingModelRunner.cs


using UnityEngine;
using Unity.MLAgents.Inference.Utils;
using System.Linq;
using System.Runtime.Serialization.Formatters.Binary;
using System.IO;
using System;
namespace Unity.MLAgents
{

bool m_TrainingObservationsInitialized;
ReplayBuffer m_Buffer;
string m_ModelFileName = "Assets/model.dat";
/// <summary>
/// Initializes the Brain with the Model that it will use when selecting actions for

void InitializeTrainingState()
{
var initState = m_Model.GetTensorByName(TensorNames.InitialTrainingState);
int[] stateShape = initState.shape.ToArray();
if (MyTimeScaleSetting.instance.LoadFile)
{
Debug.Log("load model");
initState = LoadModelFromFile(stateShape);
}
shape = initState.shape.ToArray().Select(i => (long)i).ToArray()
shape = stateShape.Select(i => (long)i).ToArray()
};
}

// Debug.Log(string.Join(", ", message));
// }
// }
public void SaveModelToFile()
{
float[] array = m_TrainingState.data.ToReadOnlyArray();
var byteArray = new byte[array.Length * 4];
Buffer.BlockCopy(array, 0, byteArray, 0, byteArray.Length);
File.WriteAllBytes(m_ModelFileName, byteArray);
Debug.Log($"Save ModelParam: {m_TrainingState.data[0]}, {m_TrainingState.data[1]}, {m_TrainingState.data[2]}, " +
$"{m_TrainingState.data[3]}, {m_TrainingState.data[4]}, {m_TrainingState.data[5]}, " +
$"{m_TrainingState.data[6]}, {m_TrainingState.data[7]}, {m_TrainingState.data[8]}, {m_TrainingState.data[9]}");
}
public Tensor LoadModelFromFile(int[] shape)
{
var byteArray = File.ReadAllBytes(m_ModelFileName);
float[] array = new float[byteArray.Length / 4];
Buffer.BlockCopy(byteArray, 0, array, 0, byteArray.Length);
return new Tensor(shape, array);
}
}
}
正在加载...
取消
保存