浏览代码

convert samples to tensors. add tensorGenerator for update.

/ai-hw-2021
Ruo-Ping Dong 4 年前
当前提交
30f60427
共有 7 个文件被更改,包括 335 次插入25 次删除
  1. 63
      com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
  2. 14
      com.unity.ml-agents/Runtime/Inference/TensorNames.cs
  3. 57
      com.unity.ml-agents/Runtime/Policies/TrainingModelRunner.cs
  4. 6
      com.unity.ml-agents/Runtime/ReplayBuffer.cs
  5. 18
      com.unity.ml-agents/Runtime/Trainer.cs
  6. 191
      com.unity.ml-agents/Runtime/Inference/TrainingTensorGenerator.cs
  7. 11
      com.unity.ml-agents/Runtime/Inference/TrainingTensorGenerator.cs.meta

63
com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs


foreach (var input in model.inputs)
{
tensors.Add(new TensorProxy
if (!TensorNames.IsTrainingInputNames(input.name))
name = input.name,
valueType = TensorProxy.TensorType.FloatingPoint,
data = null,
shape = input.shape.Select(i => (long)i).ToArray()
});
tensors.Add(new TensorProxy
{
name = input.name,
valueType = TensorProxy.TensorType.FloatingPoint,
data = null,
shape = input.shape.Select(i => (long)i).ToArray()
});
}
}
foreach (var mem in model.memories)

return tensors;
}
public static IReadOnlyList<TensorProxy> GetTrainingInputTensors(this Model model)
{
var tensors = new List<TensorProxy>();
if (model == null)
return tensors;
foreach (var input in model.inputs)
{
if (TensorNames.IsTrainingInputNames(input.name) || input.name.StartsWith(TensorNames.ObservationPlaceholderPrefix))
{
tensors.Add(new TensorProxy
{
name = input.name,
valueType = TensorProxy.TensorType.FloatingPoint,
data = null,
shape = input.shape.Select(i => (long)i).ToArray()
});
}
}
tensors.Sort((el1, el2) => el1.name.CompareTo(el2.name));
return tensors;
}
/// <summary>
/// Get number of visual observation inputs to the model.
/// </summary>

foreach (var mem in model.memories)
{
names.Add(mem.output);
}
}
names.Sort();
return names.ToArray();
}
public static string[] GetTrainingOutputNames(this Model model)
{
var names = new List<string>();
if (model == null)
{
return names.ToArray();
}
foreach (var output in model.outputs)
{
if (output.Contains("weight") || output.Contains("bias"))
{
names.Add(output);
}
}

14
com.unity.ml-agents/Runtime/Inference/TensorNames.cs


public const string ActionOutputDeprecated = "action";
public const string ActionOutputShapeDeprecated = "action_output_shape";
// Tensors for in-editor training
public const string ActionInput = "action_in";
public const string RewardInput = "reward";
public const string NextObservationPlaceholderPrefix = "next_obs_";
/// <summary>
/// Returns the name of the visual observation with a given index
/// </summary>

public static string GetObservationName(int index)
{
return ObservationPlaceholderPrefix + index;
}
public static string GetNextObservationName(int index)
{
return ObservationPlaceholderPrefix + index;
}
public static bool IsTrainingInputNames(string name)
{
return name == ActionInput || name == RewardInput || name.StartsWith(NextObservationPlaceholderPrefix);
}
}
}

57
com.unity.ml-agents/Runtime/Policies/TrainingModelRunner.cs


using System.Collections.Generic;
using Unity.Barracuda;
using UnityEngine.Profiling;
using UnityEngine;
namespace Unity.MLAgents
{

ITensorAllocator m_TensorAllocator;
TensorGenerator m_TensorGenerator;
TrainingTensorGenerator m_TrainingTensorGenerator;
TensorApplier m_TensorApplier;
Model m_Model;

IWorker m_Engine;
bool m_Verbose = false;
string[] m_OutputNames;
string[] m_TrainingOutputNames;
IReadOnlyList<TensorProxy> m_TrainingInputs;
List<TensorProxy> m_InferenceOutputs;
Dictionary<string, Tensor> m_InputsByName;
Dictionary<int, List<float>> m_Memories = new Dictionary<int, List<float>>();

bool m_ObservationsInitialized;
bool m_TrainingObservationsInitialized;
/// <summary>
/// Initializes the Brain with the Model that it will use when selecting actions for

m_Engine = WorkerFactory.CreateWorker(executionDevice, barracudaModel, m_Verbose);
m_InferenceInputs = barracudaModel.GetInputTensors();
m_TrainingInputs = barracudaModel.GetTrainingInputTensors();
m_TrainingOutputNames = barracudaModel.GetTrainingOutputNames();
m_TrainingTensorGenerator = new TrainingTensorGenerator(
seed, m_TensorAllocator, barracudaModel);
m_TensorApplier = new TensorApplier(
actionSpec, seed, m_TensorAllocator, m_Memories, barracudaModel);
m_InputsByName = new Dictionary<string, Tensor>();

m_OrderedAgentsRequestingDecisions.Clear();
}
public void UpdateModel(List<Transition> transitions)
{
var currentBatchSize = transitions.Count;
if (currentBatchSize == 0)
{
return;
}
if (!m_TrainingObservationsInitialized)
{
// Just grab the first agent in the collection (any will suffice, really).
// We check for an empty Collection above, so this will always return successfully.
m_TrainingTensorGenerator.InitializeObservations(transitions[0], m_TensorAllocator);
m_TrainingObservationsInitialized = true;
}
m_TrainingTensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, transitions);
PrepareBarracudaInputs(m_TrainingInputs);
// Execute the Model
m_Engine.Execute(m_InputsByName);
FetchBarracudaOutputs(m_TrainingOutputNames);
// Update the model
// m_Model.weights = m_InferenceOutputs.weights
}
public ActionBuffers GetAction(int agentId)
{
if (m_LastActionsReceived.ContainsKey(agentId))

return ActionBuffers.Empty;
}
// void PrintTensor(TensorProxy tensor)
// {
// Debug.Log($"Print tensor {tensor.name}");
// for (var b = 0; b < tensor.data.batch; b++)
// {
// var message = new List<float>();
// for (var i = 0; i < tensor.data.height; i++)
// {
// for (var j = 0; j < tensor.data.width; j++)
// {
// for(var k = 0; k < tensor.data.channels; k++)
// {
// message.Add(tensor.data[b, i, j, k]);
// }
// }
// }
// Debug.Log(string.Join(", ", message));
// }
// }
}
}

6
com.unity.ml-agents/Runtime/ReplayBuffer.cs


currentIndex = currentIndex % m_MaxSize;
}
public Transition[] SampleBatch(int batchSize)
public List<Transition> SampleBatch(int batchSize)
var samples = new Transition[batchSize];
var samples = new List<Transition>(batchSize);
samples[i] = m_Buffer[indexList[i]];
samples.Add(m_Buffer[indexList[i]]);
}
return samples;
}

18
com.unity.ml-agents/Runtime/Trainer.cs


using Unity.MLAgents.Actuators;
using Unity.Barracuda;
using UnityEngine;
public int bufferSize = 1024;
public int batchSize = 64;
public int bufferSize = 100;
public int batchSize = 4;
public float gamma = 0.99f;
public int updateTargetFreq = 200;
}

}
var samples = m_Buffer.SampleBatch(m_Config.batchSize);
// states = [s.state for s in samples]
// actions = [s.action for s in samples]
// q_values = policy_net(states).gather(1, actions)
// next_states = [s.next_state for s in samples]
// rewards = [s.reward for s in samples]
// next_state_values = target_net(non_final_next_states).max(1)[0]
// expected_q_values = (next_state_values * GAMMA) + rewards
// loss = MSE(q_values, expected_q_values);
// m_ModelRunner.model = Barracuda.ModelUpdate(m_ModelRunner.model, loss);
m_ModelRunner.UpdateModel(samples);
// Update target network
if (m_TrainingStep % m_Config.updateTargetFreq == 0)

191
com.unity.ml-agents/Runtime/Inference/TrainingTensorGenerator.cs


using System.Collections.Generic;
using Unity.Barracuda;
using Unity.MLAgents.Sensors;
using Unity.MLAgents;
namespace Unity.MLAgents.Inference
{
internal class TrainingTensorGenerator
{
public interface ITrainingGenerator
{
void Generate(
TensorProxy tensorProxy, int batchSize, IList<Transition> transitions);
}
readonly Dictionary<string, ITrainingGenerator> m_Dict = new Dictionary<string, ITrainingGenerator>();
public TrainingTensorGenerator(
int seed,
ITensorAllocator allocator,
object barracudaModel = null)
{
// If model is null, no inference to run and exception is thrown before reaching here.
if (barracudaModel == null)
{
return;
}
var model = (Model)barracudaModel;
// Generator for Inputs
m_Dict[TensorNames.ActionInput] =
new ActionInputGenerator(allocator);
m_Dict[TensorNames.RewardInput] =
new RewardInputGenerator(allocator);
// Generators for Outputs
}
/// <summary>
/// Populates the data of the tensor inputs given the data contained in the current batch
/// of agents.
/// </summary>
/// <param name="tensors"> Enumerable of tensors that will be modified.</param>
/// <param name="currentBatchSize"> The number of agents present in the current batch
/// </param>
/// <param name="infos"> List of AgentsInfos and Sensors that contains the
/// data that will be used to modify the tensors</param>
/// <exception cref="UnityAgentsException"> One of the tensor does not have an
/// associated generator.</exception>
public void GenerateTensors(
IReadOnlyList<TensorProxy> tensors, int currentBatchSize, IList<Transition> transitions)
{
for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
{
var tensor = tensors[tensorIndex];
if (!m_Dict.ContainsKey(tensor.name))
{
throw new UnityAgentsException(
$"Unknown tensorProxy expected as input : {tensor.name}");
}
m_Dict[tensor.name].Generate(tensor, currentBatchSize, transitions);
}
}
public void InitializeObservations(Transition transition, ITensorAllocator allocator)
{
for (var sensorIndex = 0; sensorIndex < transition.state.Count; sensorIndex++)
{
var obsGen = new CopyNextObservationTensorsGenerator(allocator);
var obsGenName = TensorNames.GetObservationName(sensorIndex);
obsGen.SetSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;
}
for (var sensorIndex = 0; sensorIndex < transition.nextState.Count; sensorIndex++)
{
var obsGen = new CopyObservationTensorsGenerator(allocator);
var obsGenName = TensorNames.GetNextObservationName(sensorIndex);
obsGen.SetSensorIndex(sensorIndex);
m_Dict[obsGenName] = obsGen;
}
}
public static void CopyTensorToBatch(TensorProxy source, TensorProxy target, int batchIndex)
{
for (var i = 0; i < source.Height; i++)
{
for (var j = 0; j < source.Width; j++)
{
for(var k = 0; k < source.Channels; k++)
{
target.data[batchIndex, i, j, k] = source.data[0, i, j, k];
}
}
}
}
}
internal class ActionInputGenerator: TrainingTensorGenerator.ITrainingGenerator
{
readonly ITensorAllocator m_Allocator;
public ActionInputGenerator(ITensorAllocator allocator)
{
m_Allocator = allocator;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
for (var index = 0; index < batchSize; index++)
{
var actions = transitions[index].action.DiscreteActions;
for (var j = 0; j < actions.Length; j++)
{
tensorProxy.data[index, j] = actions[j];
}
}
}
}
internal class RewardInputGenerator: TrainingTensorGenerator.ITrainingGenerator
{
readonly ITensorAllocator m_Allocator;
public RewardInputGenerator(ITensorAllocator allocator)
{
m_Allocator = allocator;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
for (var index = 0; index < batchSize; index++)
{
tensorProxy.data[index, 0] = transitions[index].reward;
}
}
}
internal class CopyObservationTensorsGenerator: TrainingTensorGenerator.ITrainingGenerator
{
readonly ITensorAllocator m_Allocator;
int m_SensorIndex;
public CopyObservationTensorsGenerator(ITensorAllocator allocator)
{
m_Allocator = allocator;
}
public void SetSensorIndex(int index)
{
m_SensorIndex = index;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
for (var index = 0; index < batchSize; index++)
{
TrainingTensorGenerator.CopyTensorToBatch(transitions[index].state[m_SensorIndex], tensorProxy, index);
}
}
}
internal class CopyNextObservationTensorsGenerator: TrainingTensorGenerator.ITrainingGenerator
{
readonly ITensorAllocator m_Allocator;
int m_SensorIndex;
public CopyNextObservationTensorsGenerator(ITensorAllocator allocator)
{
m_Allocator = allocator;
}
public void SetSensorIndex(int index)
{
m_SensorIndex = index;
}
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
for (var index = 0; index < batchSize; index++)
{
TrainingTensorGenerator.CopyTensorToBatch(transitions[index].nextState[m_SensorIndex], tensorProxy, index);
}
}
}
}

11
com.unity.ml-agents/Runtime/Inference/TrainingTensorGenerator.cs.meta


fileFormatVersion: 2
guid: cca690e21a2fe49b49f636cd4e76e0b4
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:
正在加载...
取消
保存