浏览代码

add model parameters

/ai-hw-2021
Ruo-Ping Dong 4 年前
当前提交
d0616609
共有 11 个文件被更改,包括 322 次插入和 23 次删除
  1. 44
      com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs
  2. 48
      com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
  3. 1
      com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
  4. 3
      com.unity.ml-agents/Runtime/Inference/TensorApplier.cs
  5. 65
      com.unity.ml-agents/Runtime/Inference/TensorNames.cs
  6. 33
      com.unity.ml-agents/Runtime/Inference/TensorProxy.cs
  7. 71
      com.unity.ml-agents/Runtime/Inference/TrainingTensorGenerator.cs
  8. 51
      com.unity.ml-agents/Runtime/Policies/TrainingModelRunner.cs
  9. 5
      com.unity.ml-agents/Runtime/Policies/TrainingPolicy.cs
  10. 19
      com.unity.ml-agents/Runtime/ReplayBuffer.cs
  11. 5
      com.unity.ml-agents/Runtime/Trainer.cs

44
com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs


}
}
internal class MaxActionOutputApplier : TensorApplier.IApplier
{
    readonly ActionSpec m_ActionSpec;

    public MaxActionOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator)
    {
        m_ActionSpec = actionSpec;
    }

    /// <summary>
    /// Writes the argmax over the model's output row into each agent's first
    /// discrete action slot.
    /// </summary>
    /// <param name="tensorProxy">Output tensor; the last shape dimension is the action size.</param>
    /// <param name="actionIds">Agent ids in the same batch order the tensor was generated in.</param>
    /// <param name="lastActions">Per-agent action buffers to fill.</param>
    public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
    {
        var agentIndex = 0;
        var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
        for (var i = 0; i < actionIds.Count; i++)
        {
            var agentId = actionIds[i];
            if (lastActions.ContainsKey(agentId))
            {
                var actionBuffer = lastActions[agentId];
                if (actionBuffer.IsEmpty())
                {
                    actionBuffer = new ActionBuffers(m_ActionSpec);
                    lastActions[agentId] = actionBuffer;
                }
                var discreteBuffer = actionBuffer.DiscreteActions;
                // BUG FIX: the previous loop never updated maxValue, so it
                // effectively selected the last index whose value exceeded 0,
                // and it truncated the float logits to int, discarding the
                // ordering information. Track the running maximum on the raw
                // float values instead.
                var maxIndex = 0;
                var maxValue = float.MinValue;
                for (var j = 0; j < actionSize; j++)
                {
                    var value = tensorProxy.data[agentIndex, j];
                    if (value > maxValue)
                    {
                        maxValue = value;
                        maxIndex = j;
                    }
                }
                discreteBuffer[0] = maxIndex;
            }
            // Advance through the batch even for unknown agent ids so rows stay aligned.
            agentIndex++;
        }
    }
}
/// <summary>
/// The Applier for the Discrete Action output tensor. Uses multinomial to sample discrete

48
com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs


using System.Linq;
using Unity.Barracuda;
using FailedCheck = Unity.MLAgents.Inference.BarracudaModelParamLoader.FailedCheck;
using UnityEngine;
namespace Unity.MLAgents.Inference
{

foreach (var input in model.inputs)
{
if (!TensorNames.IsTrainingInputNames(input.name))
if (TensorNames.IsInferenceInputNames(input.name))
{
tensors.Add(new TensorProxy
{

foreach (var input in model.inputs)
{
if (TensorNames.IsTrainingInputNames(input.name) || input.name.StartsWith(TensorNames.ObservationPlaceholderPrefix))
if (TensorNames.IsTrainingInputNames(input.name))
{
tensors.Add(new TensorProxy
{

tensors.Sort((el1, el2) => el1.name.CompareTo(el2.name));
return tensors;
}
/// <summary>
/// Collects the model inputs that represent trainable parameters
/// (weight/bias tensors), sorted by name.
/// </summary>
public static IReadOnlyList<TensorProxy> GetModelParamTensors(this Model model)
{
    if (model == null)
    {
        return new List<TensorProxy>();
    }
    return model.inputs
        .Where(input => TensorNames.IsModelParamNames(input.name))
        .Select(input => new TensorProxy
        {
            name = input.name,
            valueType = TensorProxy.TensorType.FloatingPoint,
            data = null,
            shape = GetShape(input)
        })
        .OrderBy(t => t.name)
        .ToList();
}
// Hack: bias tensors don't carry usable shape metadata yet, so their shapes
// are hard-coded to this experiment's network layout — TODO remove once the
// model exports real shapes.
public static long[] GetShape(Model.Input tensor)
{
    if (!tensor.name.StartsWith("b_"))
    {
        // Regular inputs: widen the reported shape to long.
        return tensor.shape.Select(dim => (long)dim).ToArray();
    }
    // "b_2" is the output-layer bias (3 units); every other bias tensor is a
    // 128-unit hidden-layer bias.
    return tensor.name == "b_2"
        ? new long[] { 1, 1, 1, 1, 1, 1, 1, 3 }
        : new long[] { 1, 1, 1, 1, 1, 1, 1, 128 };
}
/// <summary>

1
com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs


using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Policies;
using UnityEngine;
namespace Unity.MLAgents.Inference
{

3
com.unity.ml-agents/Runtime/Inference/TensorApplier.cs


}
if (modelVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0)
{
m_Dict[tensorName] = new DiscreteActionOutputApplier(actionSpec, seed, allocator);
// m_Dict[tensorName] = new DiscreteActionOutputApplier(actionSpec, seed, allocator);
m_Dict[tensorName] = new MaxActionOutputApplier(actionSpec, seed, allocator);
}
}
m_Dict[TensorNames.RecurrentOutput] = new MemoryOutputApplier(memories);

65
com.unity.ml-agents/Runtime/Inference/TensorNames.cs


using System.Collections.Generic;
using System.Linq;
using System;
namespace Unity.MLAgents.Inference
{
/// <summary>

// --- Training-only tensor names ---
public const string ActionInput = "action_in";                  // actions replayed from the buffer
public const string RewardInput = "reward";                     // rewards replayed from the buffer
public const string NextObservationPlaceholderPrefix = "next_obs_"; // per-sensor next-state observations
public const string TargetInput = "target";                     // training target values
public const string LearningRate = "lr";                        // scalar learning-rate input
// --- Model-parameter tensor name prefixes (in/out pairs) ---
public const string InputWeightsPrefix = "w_";                  // weights fed into the model
public const string InputBiasPrefix = "b_";                     // biases fed into the model
public const string OutputWeightsPrefix = "nw_";                // updated weights produced by the model
public const string OutputBiasPrefix = "nb_";                   // updated biases produced by the model
/// <summary>
/// Returns the name of the visual observation with a given index

return VisualObservationPlaceholderPrefix + index;
}
// Exact tensor names the inference pass feeds to the model.
static HashSet<string> InferenceInput = new HashSet<string>
{
    BatchSizePlaceholder,
    SequenceLengthPlaceholder,
    VectorObservationPlaceholder,
    RecurrentInPlaceholder,
    VisualObservationPlaceholderPrefix,
    ObservationPlaceholderPrefix,
    PreviousActionPlaceholder,
    ActionMaskPlaceholder,
    RandomNormalEpsilonPlaceholder
};
// Prefixes for per-sensor inference inputs (e.g. "obs_0", "obs_1", ...).
static HashSet<string> InferenceInputPrefix = new HashSet<string>
{
    VisualObservationPlaceholderPrefix,
    ObservationPlaceholderPrefix,
};
// Exact tensor names only the training pass feeds to the model.
static HashSet<string> TrainingInput = new HashSet<string>
{
    ActionInput,
    RewardInput,
    TargetInput,
    LearningRate,
    BatchSizePlaceholder,
};
// Prefixes for per-sensor training inputs (current and next observations).
static HashSet<string> TrainingInputPrefix = new HashSet<string>
{
    ObservationPlaceholderPrefix,
    NextObservationPlaceholderPrefix,
};
// Prefixes marking trainable model-parameter tensors (weights/biases).
static HashSet<string> ModelParamPrefix = new HashSet<string>
{
    InputWeightsPrefix,
    InputBiasPrefix,
};
/// <summary>
/// Returns the name of the observation with a given index
/// </summary>

return ObservationPlaceholderPrefix + index;
}
/// <summary>
/// Returns the name of the input weight tensor for the given layer index ("w_0", "w_1", ...).
/// </summary>
public static string GetInputWeightName(int index)
{
    return $"{InputWeightsPrefix}{index}";
}
/// <summary>
/// Returns the name of the input bias tensor for the given layer index ("b_0", "b_1", ...).
/// </summary>
public static string GetInputBiasName(int index)
{
    return $"{InputBiasPrefix}{index}";
}
/// <summary>
/// True when the tensor name is an inference-time model input, either by
/// exact match or by carrying one of the known observation prefixes.
/// </summary>
public static bool IsInferenceInputNames(string name)
{
    if (InferenceInput.Contains(name))
    {
        return true;
    }
    // NOTE: prefix membership uses Contains (a prefix anywhere in the
    // name counts), matching the original behaviour.
    foreach (var prefix in InferenceInputPrefix)
    {
        if (name.Contains(prefix))
        {
            return true;
        }
    }
    return false;
}
return name == ActionInput || name == RewardInput || name.StartsWith(NextObservationPlaceholderPrefix);
return TrainingInput.Contains(name) || TrainingInputPrefix.Any(s=>name.Contains(s));
}
/// <summary>
/// True when the tensor name carries a model-parameter prefix ("w_" / "b_").
/// </summary>
public static bool IsModelParamNames(string name)
{
    // Contains (not StartsWith) mirrors the original check.
    foreach (var prefix in ModelParamPrefix)
    {
        if (name.Contains(prefix))
        {
            return true;
        }
    }
    return false;
}
}
}

33
com.unity.ml-agents/Runtime/Inference/TensorProxy.cs


using System;
using System.Linq;
using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents.Inference.Utils;

for (var i = 0; i < tensorProxy.data.length; i++)
{
tensorProxy.data[i] = (float)randomNormal.NextDouble();
}
}
/// <summary>
/// Fills the tensor with samples from the given normal distribution,
/// allocating its data first if it has none.
/// </summary>
public static void RandomInitialize(
    TensorProxy tensorProxy, RandomNormal randomNormal, ITensorAllocator allocator)
{
    if (tensorProxy.data == null)
    {
        // TensorShape wants ints; the proxy stores its shape as longs.
        var dims = Array.ConvertAll(tensorProxy.shape, d => (int)d);
        tensorProxy.data = allocator.Alloc(new TensorShape(dims));
    }
    var length = tensorProxy.data.length;
    for (var index = 0; index < length; index++)
    {
        tensorProxy.data[index] = (float)randomNormal.NextDouble();
    }
}
/// <summary>
/// Copies every element of <paramref name="source"/>'s data into
/// <paramref name="target"/>'s data across the batch/height/width/channels
/// axes. Assumes target is already allocated with a compatible shape —
/// TODO confirm callers guarantee this.
/// </summary>
public static void CopyTensor(TensorProxy source, TensorProxy target)
{
    var data = source.data;
    for (var batch = 0; batch < data.batch; batch++)
    {
        for (var row = 0; row < data.height; row++)
        {
            for (var col = 0; col < data.width; col++)
            {
                for (var channel = 0; channel < data.channels; channel++)
                {
                    target.data[batch, row, col, channel] = data[batch, row, col, channel];
                }
            }
        }
    }
}
}

71
com.unity.ml-agents/Runtime/Inference/TrainingTensorGenerator.cs


using Unity.Barracuda;
using Unity.MLAgents.Sensors;
using Unity.MLAgents;
using UnityEngine;
namespace Unity.MLAgents.Inference
{

var model = (Model)barracudaModel;
// Generator for Inputs
m_Dict[TensorNames.ActionInput] =
new ActionInputGenerator(allocator);
m_Dict[TensorNames.RewardInput] =
new RewardInputGenerator(allocator);
m_Dict[TensorNames.ActionInput] = new ActionInputGenerator(allocator);
m_Dict[TensorNames.TargetInput] = new RewardInputGenerator(allocator);
m_Dict[TensorNames.RewardInput] = new RewardInputGenerator(allocator);
m_Dict[TensorNames.LearningRate] = new ConstantGenerator(allocator, 0.0001f);
m_Dict[TensorNames.BatchSizePlaceholder] = new TrainingBatchSizeGenerator(allocator);
// Generators for Outputs
}

/// <exception cref="UnityAgentsException"> One of the tensor does not have an
/// associated generator.</exception>
public void GenerateTensors(
IReadOnlyList<TensorProxy> tensors, int currentBatchSize, IList<Transition> transitions)
IReadOnlyList<TensorProxy> tensors, int currentBatchSize, IList<Transition> transitions, bool training=false)
{
for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
{

throw new UnityAgentsException(
$"Unknown tensorProxy expected as input : {tensor.name}");
}
m_Dict[tensor.name].Generate(tensor, currentBatchSize, transitions);
if (tensor.name.StartsWith("obs_") || tensor.name == TensorNames.BatchSizePlaceholder)
{
if (training == true)
{
m_Dict[tensor.name].Generate(tensor, currentBatchSize, transitions);
}
}
else
{
m_Dict[tensor.name].Generate(tensor, currentBatchSize, transitions);
}
}
}

{
var obsGen = new CopyNextObservationTensorsGenerator(allocator);
var obsGen = new CopyObservationTensorsGenerator(allocator);
var obsGenName = TensorNames.GetObservationName(sensorIndex);
obsGen.SetSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;

{
var obsGen = new CopyObservationTensorsGenerator(allocator);
var obsGen = new CopyNextObservationTensorsGenerator(allocator);
var obsGenName = TensorNames.GetNextObservationName(sensorIndex);
obsGen.SetSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;

}
}
}
internal class RewardInputGenerator: TrainingTensorGenerator.ITrainingGenerator
{
readonly ITensorAllocator m_Allocator;

}
}
}
internal class CopyObservationTensorsGenerator: TrainingTensorGenerator.ITrainingGenerator
{
readonly ITensorAllocator m_Allocator;

}
}
}
internal class CopyNextObservationTensorsGenerator: TrainingTensorGenerator.ITrainingGenerator
{
readonly ITensorAllocator m_Allocator;

}
}
}
internal class ConstantGenerator: TrainingTensorGenerator.ITrainingGenerator
{
    readonly ITensorAllocator m_Allocator;
    float m_Const;

    public ConstantGenerator(ITensorAllocator allocator, float c)
    {
        m_Allocator = allocator;
        m_Const = c;
    }

    /// <summary>
    /// Fills the tensor with a single constant scalar (e.g. the learning
    /// rate), independent of the batch size.
    /// </summary>
    public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions)
    {
        // FIX: the previous version called ResizeTensor only to discard the
        // result, then disposed and re-allocated a (1, 1) tensor once per
        // batch element inside a loop, rewriting the same data[0] each time.
        // A single allocation produces the identical end state.
        tensorProxy.data?.Dispose();
        tensorProxy.data = m_Allocator.Alloc(new TensorShape(1, 1));
        tensorProxy.data[0] = m_Const;
    }
}
internal class TrainingBatchSizeGenerator : TrainingTensorGenerator.ITrainingGenerator
{
    readonly ITensorAllocator m_Allocator;

    public TrainingBatchSizeGenerator(ITensorAllocator allocator)
    {
        m_Allocator = allocator;
    }

    /// <summary>
    /// Writes the current training batch size into a single-element (1, 1)
    /// tensor, replacing any previously allocated data.
    /// </summary>
    public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions)
    {
        if (tensorProxy.data != null)
        {
            tensorProxy.data.Dispose();
        }
        var scalar = m_Allocator.Alloc(new TensorShape(1, 1));
        scalar[0] = batchSize;
        tensorProxy.data = scalar;
    }
}
}

51
com.unity.ml-agents/Runtime/Policies/TrainingModelRunner.cs


using Unity.MLAgents.Policies;
using Unity.MLAgents.Sensors;
using UnityEngine;
using Unity.MLAgents.Inference.Utils;
namespace Unity.MLAgents
{

string[] m_TrainingOutputNames;
IReadOnlyList<TensorProxy> m_InferenceInputs;
IReadOnlyList<TensorProxy> m_TrainingInputs;
IReadOnlyList<TensorProxy> m_ModelParametersInputs;
List<TensorProxy> m_InferenceOutputs;
Dictionary<string, Tensor> m_InputsByName;
Dictionary<int, List<float>> m_Memories = new Dictionary<int, List<float>>();

bool m_ObservationsInitialized;
bool m_TrainingObservationsInitialized;
ReplayBuffer m_Buffer;
/// <summary>
/// Initializes the Brain with the Model that it will use when selecting actions for

public TrainingModelRunner(
ActionSpec actionSpec,
NNModel model,
ReplayBuffer buffer,
int seed = 0)
{
Model barracudaModel;

// barracudaModel = ModelLoader.Load(new NNModel());
barracudaModel = ModelLoader.Load(model);
m_Model = barracudaModel;
WorkerFactory.Type executionDevice = WorkerFactory.Type.CSharp;
WorkerFactory.Type executionDevice = WorkerFactory.Type.CSharpBurst;
m_ModelParametersInputs = barracudaModel.GetModelParamTensors();
InitializeModelParam();
m_OutputNames = barracudaModel.GetOutputNames();
m_TrainingOutputNames = barracudaModel.GetTrainingOutputNames();
m_TensorGenerator = new TensorGenerator(

actionSpec, seed, m_TensorAllocator, m_Memories, barracudaModel);
m_InputsByName = new Dictionary<string, Tensor>();
m_InferenceOutputs = new List<TensorProxy>();
m_Buffer = buffer;
void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs)
// Gives every trainable parameter tensor an initial value drawn from a
// normal distribution. The seed is fixed (10), so each run starts from
// identical random weights.
void InitializeModelParam()
{
    var rng = new RandomNormal(10);
    for (var i = 0; i < m_ModelParametersInputs.Count; i++)
    {
        TensorUtils.RandomInitialize(m_ModelParametersInputs[i], rng, m_TensorAllocator);
    }
}
void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs, bool training=false)
{
m_InputsByName.Clear();
for (var i = 0; i < infInputs.Count; i++)

}
for (var i = 0; i < m_TrainingInputs.Count; i++)
{
var inp = m_TrainingInputs[i];
if (m_InputsByName.ContainsKey(inp.name) && training==false)
{
continue;
}
m_InputsByName[inp.name] = inp.data;
}
for (var i = 0; i < m_ModelParametersInputs.Count; i++)
{
var inp = m_ModelParametersInputs[i];
m_InputsByName[inp.name] = inp.data;
}
}
public void Dispose()

m_TensorGenerator.InitializeObservations(firstInfo.sensors, m_TensorAllocator);
m_ObservationsInitialized = true;
}
if (!m_TrainingObservationsInitialized)
{
// Just grab the first agent in the collection (any will suffice, really).
// We check for an empty Collection above, so this will always return successfully.
m_TrainingTensorGenerator.InitializeObservations(m_Buffer.SampleDummyBatch(1)[0], m_TensorAllocator);
m_TrainingObservationsInitialized = true;
}
m_TrainingTensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, m_Buffer.SampleDummyBatch(currentBatchSize));
PrepareBarracudaInputs(m_InferenceInputs);

m_TrainingTensorGenerator.InitializeObservations(transitions[0], m_TensorAllocator);
m_TrainingObservationsInitialized = true;
}
m_TrainingTensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, transitions);
m_TrainingTensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, transitions, true);
PrepareBarracudaInputs(m_TrainingInputs);
PrepareBarracudaInputs(m_TrainingInputs, true);
// Execute the Model
m_Engine.Execute(m_InputsByName);

// Update the model
// m_Model.weights = m_InferenceOutputs.weights
// CopyWeights(w_0, nw_0)
}
public ActionBuffers GetAction(int agentId)

5
com.unity.ml-agents/Runtime/Policies/TrainingPolicy.cs


{
m_buffer.Push(m_LastInfo, m_LastObservations, m_CurrentObservations);
}
else if (m_buffer.Count == 0)
{
// hack
m_buffer.Push(info, m_CurrentObservations, m_CurrentObservations);
}
m_LastInfo = info;
m_LastObservations = m_CurrentObservations;

19
com.unity.ml-agents/Runtime/ReplayBuffer.cs


internal class ReplayBuffer
{
List<Transition> m_Buffer;
int currentIndex;
int m_CurrentIndex;
int m_MaxSize;
public ReplayBuffer(int maxSize)

}
else
{
m_Buffer[currentIndex] = new Transition() {state=state, action=info.storedActions, reward=info.reward, nextState=nextState};
m_Buffer[m_CurrentIndex] = new Transition() {state=state, action=info.storedActions, reward=info.reward, nextState=nextState};
currentIndex += 1;
currentIndex = currentIndex % m_MaxSize;
m_CurrentIndex += 1;
m_CurrentIndex = m_CurrentIndex % m_MaxSize;
}
public List<Transition> SampleBatch(int batchSize)

for (var i = 0; i < batchSize; i++)
{
samples.Add(m_Buffer[indexList[i]]);
}
return samples;
}
/// <summary>
/// Returns a batch made of <paramref name="batchSize"/> copies of the most
/// recently pushed transition (e.g. for warming up tensor shapes).
/// Requires at least one transition in the buffer.
/// </summary>
public List<Transition> SampleDummyBatch(int batchSize)
{
    // FIX: the previous version computed an index list via SampleIndex but
    // never used it, and indexed m_Buffer[m_CurrentIndex - 1], which goes to
    // -1 once the circular write cursor has wrapped back to 0. Locate the
    // newest element directly instead.
    var lastIndex = m_CurrentIndex > 0 ? m_CurrentIndex - 1 : m_Buffer.Count - 1;
    var samples = new List<Transition>(batchSize);
    for (var i = 0; i < batchSize; i++)
    {
        samples.Add(m_Buffer[lastIndex]);
    }
    return samples;
}

5
com.unity.ml-agents/Runtime/Trainer.cs


using Unity.MLAgents.Actuators;
using Unity.Barracuda;
using UnityEngine;
namespace Unity.MLAgents
{
internal class TrainerConfig

m_Config = config ?? new TrainerConfig();
m_behaviorName = behaviorName;
m_Buffer = new ReplayBuffer(m_Config.bufferSize);
m_ModelRunner = new TrainingModelRunner(actionSpec, model, seed);
m_TargetModelRunner = new TrainingModelRunner(actionSpec, model, seed);
m_ModelRunner = new TrainingModelRunner(actionSpec, model, m_Buffer, seed);
m_TargetModelRunner = new TrainingModelRunner(actionSpec, model, m_Buffer, seed);
// copy weights from model to target model
// m_TargetModelRunner.model.weights = m_ModelRunner.model.weights
Academy.Instance.TrainerUpdate += Update;

正在加载...
取消
保存