浏览代码

new format

/ai-hw-2021
Ruo-Ping Dong 3 年前
当前提交
24154ec4
共有 9 个文件被更改,包括 216 次插入286 次删除
  1. 41
      Project/Assets/ML-Agents/Examples/3DBall/Prefabs/3DBall.prefab
  2. 52
      com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs
  3. 118
      com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
  4. 84
      com.unity.ml-agents/Runtime/Inference/TensorNames.cs
  5. 15
      com.unity.ml-agents/Runtime/Inference/TensorProxy.cs
  6. 103
      com.unity.ml-agents/Runtime/Inference/TrainingTensorGenerator.cs
  7. 79
      com.unity.ml-agents/Runtime/Policies/TrainingModelRunner.cs
  8. 5
      com.unity.ml-agents/Runtime/ReplayBuffer.cs
  9. 5
      com.unity.ml-agents/Runtime/Trainer.cs

41
Project/Assets/ML-Agents/Examples/3DBall/Prefabs/3DBall.prefab


m_PrefabInstance: {fileID: 0}
m_PrefabAsset: {fileID: 0}
m_GameObject: {fileID: 1036225416237908}
m_Material: {fileID: 13400000, guid: 56162663048874fd4b10e065f9cf78b7, type: 2}
m_Material: {fileID: 0}
m_IsTrigger: 0
m_Enabled: 1
serializedVersion: 2

m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_RayTracingMode: 2
m_RayTraceProcedural: 0
m_RenderingLayerMask: 1
m_RendererPriority: 0
m_Materials:

m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_ReceiveGI: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0

m_SortingLayerID: 0
m_SortingLayer: 0
m_SortingOrder: 0
m_AdditionalVertexStreams: {fileID: 0}
--- !u!54 &54597526346971362
Rigidbody:
m_ObjectHideFlags: 0

m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_RayTracingMode: 2
m_RayTraceProcedural: 0
m_RenderingLayerMask: 1
m_RendererPriority: 0
m_Materials:

m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_ReceiveGI: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0

m_SortingLayerID: 0
m_SortingLayer: 0
m_SortingOrder: 0
m_AdditionalVertexStreams: {fileID: 0}
--- !u!1 &1321468028730240
GameObject:
m_ObjectHideFlags: 0

VectorObservationSize: 8
NumStackedVectorObservations: 1
m_ActionSpec:
m_NumContinuousActions: 2
BranchSizes:
VectorActionSize: 02000000
m_NumContinuousActions: 0
BranchSizes: 0a0000000a000000
VectorActionSize: 0a0000000a000000
VectorActionSpaceType: 1
VectorActionSpaceType: 0
m_Model: {fileID: 11400000, guid: 20a7b83be6b0c493d9271c65c897eb9b, type: 3}
m_Model: {fileID: 5022602860645237092, guid: 35d5202e6dbc04a50934f20df199b47f, type: 3}
m_BehaviorType: 0
m_BehaviorType: 3
m_BehaviorName: 3DBall
TeamId: 0
m_UseChildSensors: 1

m_ClearFlags: 2
m_BackGroundColor: {r: 0.46666667, g: 0.5647059, b: 0.60784316, a: 1}
m_projectionMatrixMode: 1
m_GateFitMode: 2
m_FOVAxisMode: 0
m_GateFitMode: 2
m_FocalLength: 50
m_NormalizedViewPortRect:
serializedVersion: 2

m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_RayTracingMode: 2
m_RayTraceProcedural: 0
m_RenderingLayerMask: 1
m_RendererPriority: 0
m_Materials:

m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_ReceiveGI: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0

m_SortingLayerID: 0
m_SortingLayer: 0
m_SortingOrder: 0
m_AdditionalVertexStreams: {fileID: 0}
--- !u!1 &1854695166504686
GameObject:
m_ObjectHideFlags: 0

m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_RayTracingMode: 2
m_RayTraceProcedural: 0
m_RenderingLayerMask: 1
m_RendererPriority: 0
m_Materials:

m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_ReceiveGI: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0

m_SortingLayerID: 0
m_SortingLayer: 0
m_SortingOrder: 0
m_AdditionalVertexStreams: {fileID: 0}
--- !u!1 &1859240399150782
GameObject:
m_ObjectHideFlags: 0

m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_RayTracingMode: 2
m_RayTraceProcedural: 0
m_RenderingLayerMask: 1
m_RendererPriority: 0
m_Materials:

m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_ReceiveGI: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0

m_SortingLayerID: 0
m_SortingLayer: 0
m_SortingOrder: 0
m_AdditionalVertexStreams: {fileID: 0}
--- !u!1 &1999020414315134
GameObject:
m_ObjectHideFlags: 0

m_MotionVectors: 1
m_LightProbeUsage: 1
m_ReflectionProbeUsage: 1
m_RayTracingMode: 2
m_RayTraceProcedural: 0
m_RenderingLayerMask: 1
m_RendererPriority: 0
m_Materials:

m_ProbeAnchor: {fileID: 0}
m_LightProbeVolumeOverride: {fileID: 0}
m_ScaleInLightmap: 1
m_ReceiveGI: 1
m_PreserveUVs: 1
m_IgnoreNormalsForChartDetection: 0
m_ImportantGI: 0

m_SortingLayerID: 0
m_SortingLayer: 0
m_SortingOrder: 0
m_AdditionalVertexStreams: {fileID: 0}

52
com.unity.ml-agents/Runtime/Inference/ApplierImpl.cs


public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
var agentIndex = 0;
var actionSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
var actionSpaceSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
for (var i = 0; i < actionIds.Count; i++)
{

var discreteBuffer = actionBuffer.DiscreteActions;
var maxIndex = 0;
var maxValue = 0;
for (var j = 0; j < actionSize; j++)
for (var j = 0; j < actionSpaceSize; j++)
{
var value = (int)tensorProxy.data[agentIndex, j];
if (value > maxValue)

}
var actionSize = discreteBuffer.Length;
}
agentIndex++;
}
}
}
internal class ContinuousFromDiscreteOutputApplier : TensorApplier.IApplier
{
readonly ActionSpec m_ActionSpec;
int m_NumDiscretization;
public ContinuousFromDiscreteOutputApplier(ActionSpec actionSpec, int seed, ITensorAllocator allocator, int numDiscretization)
{
m_ActionSpec = actionSpec;
m_NumDiscretization = numDiscretization;
}
public void Apply(TensorProxy tensorProxy, IList<int> actionIds, Dictionary<int, ActionBuffers> lastActions)
{
var agentIndex = 0;
var actionSpaceSize = tensorProxy.shape[tensorProxy.shape.Length - 1];
for (var i = 0; i < actionIds.Count; i++)
{
var agentId = actionIds[i];
if (lastActions.ContainsKey(agentId))
{
var actionBuffer = lastActions[agentId];
if (actionBuffer.IsEmpty())
{
actionBuffer = new ActionBuffers(m_ActionSpec);
lastActions[agentId] = actionBuffer;
}
var continuousBuffer = actionBuffer.ContinuousActions;
var maxIndex = 0;
var maxValue = 0;
for (var j = 0; j < actionSpaceSize; j++)
{
var value = (int)tensorProxy.data[agentIndex, j];
if (value > maxValue)
{
maxIndex = j;
}
}
continuousBuffer[0] = ((maxIndex/m_NumDiscretization)/(m_NumDiscretization-1)/2)-1;
continuousBuffer[1] = ((maxIndex%m_NumDiscretization)/(m_NumDiscretization-1)/2)-1;
}
agentIndex++;
}

118
com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs


/// <returns>The api version of the model</returns>
public static int GetVersion(this Model model)
{
return (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
// return (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
return 3;
}
/// <summary>

foreach (var input in model.inputs)
{
if (TensorNames.IsInferenceInputNames(input.name))
tensors.Add(new TensorProxy
tensors.Add(new TensorProxy
{
name = input.name,
valueType = TensorProxy.TensorType.FloatingPoint,
data = null,
shape = input.shape.Select(i => (long)i).ToArray()
});
}
name = input.name,
valueType = TensorProxy.TensorType.FloatingPoint,
data = null,
shape = input.shape.Select(i => (long)i).ToArray()
});
}
foreach (var mem in model.memories)

foreach (var input in model.inputs)
{
if (TensorNames.IsTrainingInputNames(input.name))
tensors.Add(new TensorProxy
tensors.Add(new TensorProxy
{
name = input.name,
valueType = TensorProxy.TensorType.FloatingPoint,
data = null,
shape = input.shape.Select(i => (long)i).ToArray()
});
}
}
tensors.Sort((el1, el2) => el1.name.CompareTo(el2.name));
return tensors;
}
public static IReadOnlyList<TensorProxy> GetModelParamTensors(this Model model)
{
var tensors = new List<TensorProxy>();
if (model == null)
return tensors;
foreach (var input in model.inputs)
{
if (TensorNames.IsModelParamNames(input.name))
{
tensors.Add(new TensorProxy
{
name = input.name,
valueType = TensorProxy.TensorType.FloatingPoint,
data = null,
shape = GetShape(input)
});
}
name = input.name,
valueType = TensorProxy.TensorType.FloatingPoint,
data = null,
shape = input.shape.Select(i => (long)i).ToArray()
});
}
tensors.Sort((el1, el2) => el1.name.CompareTo(el2.name));

// hack the shape for now
public static long[] GetShape(Model.Input tensor)
{
if (tensor.name == "b_2")
{
return new long[] {1, 1, 1, 1, 1, 1, 1, 3}; //output
}
else if (tensor.name.StartsWith("b_"))
{
return new long[] {1, 1, 1, 1, 1, 1, 1, 128}; //hidden
}
else
{
return tensor.shape.Select(i => (long)i).ToArray();
}
}
/// <summary>
/// Get number of visual observation inputs to the model.
/// </summary>

return names.ToArray();
}
foreach (var output in model.outputs)
{
if (output.Contains("weight") || output.Contains("bias"))
{
names.Add(output);
}
}
names.Add(TensorNames.TrainingStateOut);
names.Add(TensorNames.OuputLoss);
names.Add(TensorNames.TrainingOutput);
names.Sort();

public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks)
{
// Check the presence of model version
var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
if (modelApiVersionTensor == null)
{
failedModelChecks.Add(
FailedCheck.Warning($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.")
);
return false;
}
// var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
// if (modelApiVersionTensor == null)
// {
// failedModelChecks.Add(
// FailedCheck.Warning($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.")
// );
// return false;
// }
var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
if (memorySizeTensor == null)
{
failedModelChecks.Add(
FailedCheck.Warning($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.")
);
return false;
}
// var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
// if (memorySizeTensor == null)
// {
// failedModelChecks.Add(
// FailedCheck.Warning($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.")
// );
// return false;
// }
// Check the presence of action output tensor
if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&

84
com.unity.ml-agents/Runtime/Inference/TensorNames.cs


// Deprecated TensorNames entries for backward compatibility
public const string IsContinuousControlDeprecated = "is_continuous_control";
public const string ActionOutputDeprecated = "action";
public const string ActionOutputDeprecated = "action_";
public const string ActionInput = "action_in";
public const string Observations = "input";
public const string ActionInput = "action";
public const string NextObservationPlaceholderPrefix = "next_obs_";
public const string TargetInput = "target";
public const string DoneInput = "done";
public const string Gamma = "gamma";
public const string NextObservations = "next_state";
public const string InputWeightsPrefix = "w_";
public const string InputBiasPrefix = "b_";
public const string OutputWeightsPrefix = "nw_";
public const string OutputBiasPrefix = "nb_";
public const string TrainingStateIn = "training_state.1";
public const string TrainingOutput = "output";
public const string OuputLoss = "loss";
public const string TrainingStateOut = "training_state";
public const string InitialTrainingState = "initial_training_state";
/// <summary>
/// Returns the name of the visual observation with a given index

return VisualObservationPlaceholderPrefix + index;
}
static HashSet<string> InferenceInput = new HashSet<string>
{
BatchSizePlaceholder,
SequenceLengthPlaceholder,
VectorObservationPlaceholder,
RecurrentInPlaceholder,
VisualObservationPlaceholderPrefix,
ObservationPlaceholderPrefix,
PreviousActionPlaceholder,
ActionMaskPlaceholder,
RandomNormalEpsilonPlaceholder
};
static HashSet<string> InferenceInputPrefix = new HashSet<string>
{
VisualObservationPlaceholderPrefix,
ObservationPlaceholderPrefix,
};
static HashSet<string> TrainingInput = new HashSet<string>
{
ActionInput,
RewardInput,
TargetInput,
LearningRate,
BatchSizePlaceholder,
};
static HashSet<string> TrainingInputPrefix = new HashSet<string>
{
ObservationPlaceholderPrefix,
NextObservationPlaceholderPrefix,
};
static HashSet<string> ModelParamPrefix = new HashSet<string>
{
InputWeightsPrefix,
InputBiasPrefix,
};
/// <summary>
/// Returns the name of the observation with a given index
/// </summary>

}
public static string GetNextObservationName(int index)
{
return ObservationPlaceholderPrefix + index;
}
public static string GetInputWeightName(int index)
{
return InputWeightsPrefix + index;
}
public static string GetInputBiasName(int index)
{
return InputBiasPrefix + index;
}
public static bool IsInferenceInputNames(string name)
{
return InferenceInput.Contains(name) || InferenceInputPrefix.Any(s=>name.Contains(s));
}
public static bool IsTrainingInputNames(string name)
{
return TrainingInput.Contains(name) || TrainingInputPrefix.Any(s=>name.Contains(s));
}
public static bool IsModelParamNames(string name)
{
return ModelParamPrefix.Any(s=>name.Contains(s));
}
}
}

15
com.unity.ml-agents/Runtime/Inference/TensorProxy.cs


}
}
public static void RandomInitialize(
TensorProxy tensorProxy, RandomNormal randomNormal, ITensorAllocator allocator)
{
if (tensorProxy.data == null)
{
tensorProxy.data = allocator.Alloc(
new TensorShape(tensorProxy.shape.Select(x => (int)x).ToArray()));
}
for (var i = 0; i < tensorProxy.data.length; i++)
{
tensorProxy.data[i] = (float)randomNormal.NextDouble();
}
}
public static void CopyTensor(TensorProxy source, TensorProxy target)
{
for (var b = 0; b < source.data.batch; b++)

103
com.unity.ml-agents/Runtime/Inference/TrainingTensorGenerator.cs


public interface ITrainingGenerator
{
void Generate(
TensorProxy tensorProxy, int batchSize, IList<Transition> transitions);
TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState);
}
readonly Dictionary<string, ITrainingGenerator> m_Dict = new Dictionary<string, ITrainingGenerator>();

int seed,
ITensorAllocator allocator,
object barracudaModel = null)
float learning_rate,
float gamma,
object barracudaModel = null
)
{
// If model is null, no inference to run and exception is thrown before reaching here.
if (barracudaModel == null)

var model = (Model)barracudaModel;
// Generator for Inputs
var obsGen = new CopyObservationTensorsGenerator(allocator);
obsGen.SetSensorIndex(0);
m_Dict[TensorNames.Observations] = obsGen;
var nextObsGen = new CopyNextObservationTensorsGenerator(allocator);
nextObsGen.SetSensorIndex(0);
m_Dict[TensorNames.NextObservations] = nextObsGen;
m_Dict[TensorNames.TargetInput] = new RewardInputGenerator(allocator);
m_Dict[TensorNames.LearningRate] = new ConstantGenerator(allocator, 0.0001f);
m_Dict[TensorNames.DoneInput] = new DoneInputGenerator(allocator);
m_Dict[TensorNames.LearningRate] = new ConstantGenerator(allocator,learning_rate);
m_Dict[TensorNames.Gamma] = new ConstantGenerator(allocator, gamma);
// Generators for Outputs
m_Dict[TensorNames.TrainingStateIn] = new TrainingStateGenerator(allocator);
}
/// <summary>

/// <exception cref="UnityAgentsException"> One of the tensor does not have an
/// associated generator.</exception>
public void GenerateTensors(
IReadOnlyList<TensorProxy> tensors, int currentBatchSize, IList<Transition> transitions, bool training=false)
IReadOnlyList<TensorProxy> tensors, int currentBatchSize, IList<Transition> transitions, TensorProxy trainingState, bool training=false)
{
for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
{

throw new UnityAgentsException(
$"Unknown tensorProxy expected as input : {tensor.name}");
}
if (tensor.name.StartsWith("obs_") || tensor.name == TensorNames.BatchSizePlaceholder)
{
if (training == true)
{
m_Dict[tensor.name].Generate(tensor, currentBatchSize, transitions);
}
}
else
if ((tensor.name == TensorNames.Observations || tensor.name == TensorNames.BatchSizePlaceholder) && training == false)
m_Dict[tensor.name].Generate(tensor, currentBatchSize, transitions);
continue;
}
}
public void InitializeObservations(Transition transition, ITensorAllocator allocator)
{
for (var sensorIndex = 0; sensorIndex < transition.state.Count; sensorIndex++)
{
var obsGen = new CopyObservationTensorsGenerator(allocator);
var obsGenName = TensorNames.GetObservationName(sensorIndex);
obsGen.SetSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;
}
for (var sensorIndex = 0; sensorIndex < transition.nextState.Count; sensorIndex++)
{
var obsGen = new CopyNextObservationTensorsGenerator(allocator);
var obsGenName = TensorNames.GetNextObservationName(sensorIndex);
obsGen.SetSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;
m_Dict[tensor.name].Generate(tensor, currentBatchSize, transitions, trainingState);
}
}

m_Allocator = allocator;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions)
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
for (var index = 0; index < batchSize; index++)

m_Allocator = allocator;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions)
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
for (var index = 0; index < batchSize; index++)

}
}
internal class DoneInputGenerator: TrainingTensorGenerator.ITrainingGenerator
{
readonly ITensorAllocator m_Allocator;
public DoneInputGenerator(ITensorAllocator allocator)
{
m_Allocator = allocator;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
for (var index = 0; index < batchSize; index++)
{
tensorProxy.data[index, 0] = transitions[index].done==true ? 1f : 0f;
}
}
}
internal class CopyObservationTensorsGenerator: TrainingTensorGenerator.ITrainingGenerator
{
readonly ITensorAllocator m_Allocator;

m_SensorIndex = index;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions)
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
for (var index = 0; index < batchSize; index++)

m_SensorIndex = index;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions)
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
for (var index = 0; index < batchSize; index++)

m_Const = c;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions)
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState)
{
TensorUtils.ResizeTensor(tensorProxy, 1, m_Allocator);
for (var index = 0; index < batchSize; index++)

m_Allocator = allocator;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions)
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState)
{
tensorProxy.data?.Dispose();
tensorProxy.data = m_Allocator.Alloc(new TensorShape(1, 1));

internal class TrainingStateGenerator: TrainingTensorGenerator.ITrainingGenerator
{
readonly ITensorAllocator m_Allocator;
public TrainingStateGenerator(ITensorAllocator allocator)
{
m_Allocator = allocator;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
for (var index = 0; index < batchSize; index++)
{
TensorUtils.CopyTensor(trainingState, tensorProxy);
}
}
}
}

79
com.unity.ml-agents/Runtime/Policies/TrainingModelRunner.cs


using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Inference;
using Unity.MLAgents.Policies;

List<AgentInfoSensorsPair> m_Infos = new List<AgentInfoSensorsPair>();
Dictionary<int, ActionBuffers> m_LastActionsReceived = new Dictionary<int, ActionBuffers>();
List<int> m_OrderedAgentsRequestingDecisions = new List<int>();
TensorProxy m_TrainingState;
ITensorAllocator m_TensorAllocator;
TensorGenerator m_TensorGenerator;

Model m_Model;
NNModel m_TargetModel;
string m_ModelName;
InferenceDevice m_InferenceDevice;
IReadOnlyList<TensorProxy> m_InferenceInputs;
IReadOnlyList<TensorProxy> m_ModelParametersInputs;
List<TensorProxy> m_InferenceOutputs;
List<TensorProxy> m_TrainingOutputs;
SensorShapeValidator m_SensorShapeValidator = new SensorShapeValidator();
bool m_ObservationsInitialized;
bool m_TrainingObservationsInitialized;

ActionSpec actionSpec,
NNModel model,
ReplayBuffer buffer,
TrainerConfig config,
int seed = 0)
{
Model barracudaModel;

// barracudaModel = ModelLoader.Load(new NNModel());
m_InferenceInputs = barracudaModel.GetInputTensors();
m_ModelParametersInputs = barracudaModel.GetModelParamTensors();
InitializeModelParam();
seed, m_TensorAllocator, barracudaModel);
seed, m_TensorAllocator, config.learningRate, config.gamma, barracudaModel);
m_InferenceOutputs = new List<TensorProxy>();
m_TrainingOutputs = new List<TensorProxy>();
void InitializeModelParam()
void InitializeTrainingState()
RandomNormal randomNormal = new RandomNormal(10);
foreach (var tensor in m_ModelParametersInputs)
{
TensorUtils.RandomInitialize(tensor, randomNormal, m_TensorAllocator);
}
// TODO: initialize m_TrainingState
void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs, bool training=false)
void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs)
{
m_InputsByName.Clear();
for (var i = 0; i < infInputs.Count; i++)

}
for (var i = 0; i < m_TrainingInputs.Count; i++)
{
var inp = m_TrainingInputs[i];
if (m_InputsByName.ContainsKey(inp.name) && training==false)
{
continue;
}
m_InputsByName[inp.name] = inp.data;
}
for (var i = 0; i < m_ModelParametersInputs.Count; i++)
{
var inp = m_ModelParametersInputs[i];
m_InputsByName[inp.name] = inp.data;
}
}
public void Dispose()

void FetchBarracudaOutputs(string[] names)
{
m_InferenceOutputs.Clear();
m_TrainingOutputs.Clear();
m_InferenceOutputs.Add(TensorUtils.TensorProxyFromBarracuda(output, n));
m_TrainingOutputs.Add(TensorUtils.TensorProxyFromBarracuda(output, n));
}
}

m_TensorGenerator.InitializeObservations(firstInfo.sensors, m_TensorAllocator);
m_ObservationsInitialized = true;
}
if (!m_TrainingObservationsInitialized)
{
// Just grab the first agent in the collection (any will suffice, really).
// We check for an empty Collection above, so this will always return successfully.
m_TrainingTensorGenerator.InitializeObservations(m_Buffer.SampleDummyBatch(1)[0], m_TensorAllocator);
m_TrainingObservationsInitialized = true;
}
m_TensorGenerator.GenerateTensors(m_InferenceInputs, currentBatchSize, m_Infos);
m_TrainingTensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, m_Buffer.SampleDummyBatch(currentBatchSize));
m_TensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, m_Infos);
m_TrainingTensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, m_Buffer.SampleDummyBatch(currentBatchSize), m_TrainingState);
PrepareBarracudaInputs(m_InferenceInputs);
PrepareBarracudaInputs(m_TrainingInputs);
// Execute the Model
m_Engine.Execute(m_InputsByName);

// Update the outputs
m_TensorApplier.ApplyTensors(m_InferenceOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived);
m_TensorApplier.ApplyTensors(m_TrainingOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived);
m_Infos.Clear();

{
return;
}
if (!m_TrainingObservationsInitialized)
{
// Just grab the first agent in the collection (any will suffice, really).
// We check for an empty Collection above, so this will always return successfully.
m_TrainingTensorGenerator.InitializeObservations(transitions[0], m_TensorAllocator);
m_TrainingObservationsInitialized = true;
}
m_TrainingTensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, transitions, true);
m_TrainingTensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, transitions, m_TrainingState, true);
PrepareBarracudaInputs(m_TrainingInputs, true);
PrepareBarracudaInputs(m_TrainingInputs);
// Execute the Model
m_Engine.Execute(m_InputsByName);

// Update the model
// CopyWeights(w_0, nw_0)
// m_TensorApplier.UpdateModel(m_TrainingOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived);
}
public ActionBuffers GetAction(int agentId)

5
com.unity.ml-agents/Runtime/ReplayBuffer.cs


public IReadOnlyList<TensorProxy> state;
public ActionBuffers action;
public float reward;
public bool done;
public IReadOnlyList<TensorProxy> nextState;
}

{
if (m_Buffer.Count < m_MaxSize)
{
m_Buffer.Add(new Transition() {state=state, action=info.storedActions, reward=info.reward, nextState=nextState});
m_Buffer.Add(new Transition() {state=state, action=info.storedActions, reward=info.reward, done=info.done, nextState=nextState});
m_Buffer[m_CurrentIndex] = new Transition() {state=state, action=info.storedActions, reward=info.reward, nextState=nextState};
m_Buffer[m_CurrentIndex] = new Transition() {state=state, action=info.storedActions, reward=info.reward, done=info.done, nextState=nextState};
}
m_CurrentIndex += 1;
m_CurrentIndex = m_CurrentIndex % m_MaxSize;

5
com.unity.ml-agents/Runtime/Trainer.cs


public int bufferSize = 100;
public int batchSize = 4;
public float gamma = 0.99f;
public float learningRate = 0.0005f;
public int updateTargetFreq = 200;
}

m_Config = config ?? new TrainerConfig();
m_behaviorName = behaviorName;
m_Buffer = new ReplayBuffer(m_Config.bufferSize);
m_ModelRunner = new TrainingModelRunner(actionSpec, model, m_Buffer, seed);
m_TargetModelRunner = new TrainingModelRunner(actionSpec, model, m_Buffer, seed);
m_ModelRunner = new TrainingModelRunner(actionSpec, model, m_Buffer, m_Config, seed);
m_TargetModelRunner = new TrainingModelRunner(actionSpec, model, m_Buffer, m_Config, seed);
// copy weights from model to target model
// m_TargetModelRunner.model.weights = m_ModelRunner.model.weights
Academy.Instance.TrainerUpdate += Update;

正在加载...
取消
保存