浏览代码

name the saved model file

/ai-hw-2021
Ruo-Ping Dong 3 年前
当前提交
19ee6e5e
共有 2 个文件被更改,包括 27 次插入17 次删除
  1. 24
      com.unity.ml-agents/Runtime/Training/MyTimeScaleSetting.cs
  2. 20
      com.unity.ml-agents/Runtime/Training/TrainingModelRunner.cs

24
com.unity.ml-agents/Runtime/Training/MyTimeScaleSetting.cs


s_Instance = null;
}
[SerializeField]
string m_TrainingId = "default";
public string TrainingId
{
get { return m_TrainingId; }
set { m_TrainingId = value; }
}
[SerializeField]
float m_TimeScale = 1f;

}
[SerializeField]
bool m_LoadFile = true;
public bool LoadFile
bool m_LoadTrainedModel = true;
public bool LoadTrainedModel
get { return m_LoadFile; }
set { m_LoadFile = value; }
get { return m_LoadTrainedModel; }
set { m_LoadTrainedModel = value; }
}
[SerializeField]
Object m_Model;
public Object Model
{
get { return m_Model; }
set { m_Model = value; }
}
// Start is called before the first frame update

20
com.unity.ml-agents/Runtime/Training/TrainingModelRunner.cs


using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents;
using Unity.MLAgents.Policies;
using Unity.MLAgents.Inference.Utils;
using System.Runtime.Serialization.Formatters.Binary;
using UnityEditor;
namespace Unity.MLAgents
{

bool m_TrainingObservationsInitialized;
ReplayBuffer m_Buffer;
string m_ModelFileName = "Assets/model.dat";
/// <summary>
/// Initializes the Brain with the Model that it will use when selecting actions for

{
var initState = m_Model.GetTensorByName(TensorNames.InitialTrainingState);
int[] stateShape = initState.shape.ToArray();
if (MyTimeScaleSetting.instance.LoadFile)
if (MyTimeScaleSetting.instance.LoadTrainedModel && MyTimeScaleSetting.instance.Model != null)
initState = LoadModelFromFile(stateShape);
var modelPath = AssetDatabase.GetAssetPath(MyTimeScaleSetting.instance.Model);
initState = LoadModelFromFile(modelPath, stateShape);
}
m_TrainingState = new TensorProxy

float[] array = m_TrainingState.data.ToReadOnlyArray();
var byteArray = new byte[array.Length * 4];
Buffer.BlockCopy(array, 0, byteArray, 0, byteArray.Length);
File.WriteAllBytes(m_ModelFileName, byteArray);
Debug.Log($"Save ModelParam: {m_TrainingState.data[0]}, {m_TrainingState.data[1]}, {m_TrainingState.data[2]}, " +
$"{m_TrainingState.data[3]}, {m_TrainingState.data[4]}, {m_TrainingState.data[5]}, " +
$"{m_TrainingState.data[6]}, {m_TrainingState.data[7]}, {m_TrainingState.data[8]}, {m_TrainingState.data[9]}");
File.WriteAllBytes("Assets/model_" + MyTimeScaleSetting.instance.TrainingId, byteArray);
public Tensor LoadModelFromFile(int[] shape)
public Tensor LoadModelFromFile(string path, int[] shape)
var byteArray = File.ReadAllBytes(m_ModelFileName);
var byteArray = File.ReadAllBytes(path);
float[] array = new float[byteArray.Length / 4];
Buffer.BlockCopy(byteArray, 0, array, 0, byteArray.Length);
return new Tensor(shape, array);

正在加载...
取消
保存