浏览代码

fix input to model

/ai-hw-2021
Ruo-Ping Dong 4 年前
当前提交
27704856
共有 7 个文件被更改,包括 81 次插入34 次删除
  1. 65
      com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
  2. 4
      com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs
  3. 6
      com.unity.ml-agents/Runtime/Inference/TensorNames.cs
  4. 12
      com.unity.ml-agents/Runtime/Training/Trainer.cs
  5. 20
      com.unity.ml-agents/Runtime/Training/TrainingModelRunner.cs
  6. 1
      com.unity.ml-agents/Runtime/Training/TrainingPolicy.cs
  7. 7
      com.unity.ml-agents/Runtime/Training/TrainingTensorGenerator.cs

65
com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs


return tensors;
}
// This is for creating a Tensor to store observations/nextObservations in buffer
public static IReadOnlyList<TensorProxy> GetTrainingObservationInputTensors(this Model model)
{
var tensors = new List<TensorProxy>();
if (model == null)
return tensors;
foreach (var input in model.inputs)
{
if (input.name == TensorNames.Observations)
{
tensors.Add(new TensorProxy
{
name = input.name,
valueType = TensorProxy.TensorType.FloatingPoint,
data = null,
shape = input.shape.Select(i => (long)i).ToArray()
});
}
}
tensors.Sort((el1, el2) => el1.name.CompareTo(el2.name));
return tensors;
}
/// <summary>
/// Get number of visual observation inputs to the model.
/// </summary>

{
if (model == null)
return false;
if (!model.SupportsContinuousAndDiscrete())
{
return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0;
}
else
{
return model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
(int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0;
}
return false;
// if (!model.SupportsContinuousAndDiscrete())
// {
// return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0;
// }
// else
// {
// return model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
// (int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0;
// }
}
/// <summary>

{
if (model == null)
return false;
if (!model.SupportsContinuousAndDiscrete())
{
return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] == 0;
}
else
{
return model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
(int)model.DiscreteOutputSize() > 0;
}
return false;
// if (!model.SupportsContinuousAndDiscrete())
// {
// return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] == 0;
// }
// else
// {
// return model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
// (int)model.DiscreteOutputSize() > 0;
// }
}
/// <summary>

4
com.unity.ml-agents/Runtime/Inference/TensorGenerator.cs


m_Dict[TensorNames.RandomNormalEpsilonPlaceholder] =
new RandomNormalInputGenerator(seed, allocator);
var obsGen = new ObservationGenerator(allocator);
obsGen.AddSensorIndex(0);
m_Dict[TensorNames.Observations] = obsGen;
// Generators for Outputs
if (model.HasContinuousOutputs())

6
com.unity.ml-agents/Runtime/Inference/TensorNames.cs


public const string VectorObservationPlaceholder = "vector_observation";
public const string RecurrentInPlaceholder = "recurrent_in";
public const string VisualObservationPlaceholderPrefix = "visual_observation_";
public const string ObservationPlaceholderPrefix = "obs_";
public const string ObservationPlaceholderPrefix = "iobs_";
public const string PreviousActionPlaceholder = "prev_action";
public const string ActionMaskPlaceholder = "action_masks";
public const string RandomNormalEpsilonPlaceholder = "epsilon";

public const string ActionOutputShapeDeprecated = "action_output_shape";
// Tensors for in-editor training
public const string Observations = "input";
public const string Observations = "obs_0";
public const string NextObservations = "next_state";
public const string NextObservations = "next_obs_0";
public const string LearningRate = "lr";
public const string TrainingStateIn = "training_state.1";

12
com.unity.ml-agents/Runtime/Training/Trainer.cs


m_behaviorName = behaviorName;
m_Buffer = new ReplayBuffer(m_Config.bufferSize);
m_ModelRunner = new TrainingModelRunner(actionSpec, model, m_Buffer, m_Config, seed);
m_TargetModelRunner = new TrainingModelRunner(actionSpec, model, m_Buffer, m_Config, seed);
// m_TargetModelRunner = new TrainingModelRunner(actionSpec, model, m_Buffer, m_Config, seed);
// copy weights from model to target model
// m_TargetModelRunner.model.weights = m_ModelRunner.model.weights
Academy.Instance.TrainerUpdate += Update;

}
var samples = m_Buffer.SampleBatch(m_Config.batchSize);
m_ModelRunner.UpdateModel(samples);
// m_ModelRunner.UpdateModel(samples);
if (m_TrainingStep % m_Config.updateTargetFreq == 0)
{
// copy weights
}
// if (m_TrainingStep % m_Config.updateTargetFreq == 0)
// {
// // copy weights
// }
m_TrainingStep += 1;
}

20
com.unity.ml-agents/Runtime/Training/TrainingModelRunner.cs


bool m_Verbose = false;
string[] m_TrainingOutputNames;
IReadOnlyList<TensorProxy> m_TrainingInputs;
IReadOnlyList<TensorProxy> m_InferenceInputs;
List<TensorProxy> m_TrainingOutputs;
Dictionary<string, Tensor> m_InputsByName;
Dictionary<int, List<float>> m_Memories = new Dictionary<int, List<float>>();

m_Engine = WorkerFactory.CreateWorker(executionDevice, barracudaModel, m_Verbose);
m_TrainingInputs = barracudaModel.GetTrainingInputTensors();
List<TensorProxy> infTensors = new List<TensorProxy>();
foreach(var tensor in m_TrainingInputs)
{
if (tensor.name == TensorNames.Observations || tensor.name == TensorNames.BatchSizePlaceholder)
{
infTensors.Add(tensor);
}
}
m_InferenceInputs = (IReadOnlyList<TensorProxy>) infTensors;
m_TrainingOutputNames = barracudaModel.GetTrainingOutputNames();
m_TensorGenerator = new TensorGenerator(
seed, m_TensorAllocator, m_Memories, barracudaModel);

m_InputsByName = new Dictionary<string, Tensor>();
m_TrainingOutputs = new List<TensorProxy>();
m_Buffer = buffer;
InitializeTrainingState();
}
void InitializeTrainingState()

var inp = infInputs[i];
m_InputsByName[inp.name] = inp.data;
}
}
public ITensorAllocator Allocator
{
get => m_TensorAllocator;
}
public void Dispose()

public IReadOnlyList<TensorProxy> GetInputTensors()
{
return m_Model.GetInputTensors();
return m_Model.GetTrainingObservationInputTensors();
}
public void DecideBatch()

}
// Prepare the input tensors to be feed into the engine
m_TensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, m_Infos);
m_TensorGenerator.GenerateTensors(m_InferenceInputs, currentBatchSize, m_Infos);
m_TrainingTensorGenerator.GenerateTensors(m_TrainingInputs, currentBatchSize, m_Buffer.SampleDummyBatch(currentBatchSize), m_TrainingState);
PrepareBarracudaInputs(m_TrainingInputs);

1
com.unity.ml-agents/Runtime/Training/TrainingPolicy.cs


m_LastInfo = info;
for (var i = 0; i < m_CurrentObservations.Count; i++)
{
TensorUtils.ResizeTensor(m_LastObservations[i], 1, m_ModelRunner.Allocator);
TensorUtils.CopyTensor(m_CurrentObservations[i], m_LastObservations[i]);
}
m_HasLastObservation = true;

7
com.unity.ml-agents/Runtime/Training/TrainingTensorGenerator.cs


public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState)
{
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);
for (var index = 0; index < batchSize; index++)
{
TensorUtils.CopyTensor(trainingState, tensorProxy);
}
TensorUtils.ResizeTensor(tensorProxy, trainingState.data.batch, m_Allocator);
TensorUtils.CopyTensor(trainingState, tensorProxy);
}
}
}
正在加载...
取消
保存