|
|
|
|
|
|
/// <summary>
/// Contract implemented by every generator that fills one training input
/// tensor from a batch of transitions.
/// </summary>
public interface ITrainingGenerator
{
    /// <summary>
    /// Writes the data for one input tensor into <paramref name="tensorProxy"/>.
    /// </summary>
    /// <param name="tensorProxy">The tensor to fill.</param>
    /// <param name="batchSize">The number of entries requested for the batch.</param>
    /// <param name="transitions">The transitions the tensor content is derived from.</param>
    /// <param name="trainingState">
    /// The training-state tensor handed to every generator; implementers
    /// that do not need it ignore it.
    /// </param>
    void Generate(
        TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState);
}
|
|
|
|
|
|
|
readonly Dictionary<string, ITrainingGenerator> m_Dict = new Dictionary<string, ITrainingGenerator>(); |
|
|
|
|
|
|
int seed, |
|
|
|
ITensorAllocator allocator, |
|
|
|
object barracudaModel = null) |
|
|
|
float learning_rate, |
|
|
|
float gamma, |
|
|
|
object barracudaModel = null |
|
|
|
) |
|
|
|
{ |
|
|
|
// If model is null, no inference to run and exception is thrown before reaching here.
|
|
|
|
if (barracudaModel == null) |
|
|
|
|
|
|
var model = (Model)barracudaModel; |
|
|
|
|
|
|
|
// Generator for Inputs
|
|
|
|
var obsGen = new CopyObservationTensorsGenerator(allocator); |
|
|
|
obsGen.SetSensorIndex(0); |
|
|
|
m_Dict[TensorNames.Observations] = obsGen; |
|
|
|
var nextObsGen = new CopyNextObservationTensorsGenerator(allocator); |
|
|
|
nextObsGen.SetSensorIndex(0); |
|
|
|
m_Dict[TensorNames.NextObservations] = nextObsGen; |
|
|
|
m_Dict[TensorNames.TargetInput] = new RewardInputGenerator(allocator); |
|
|
|
m_Dict[TensorNames.LearningRate] = new ConstantGenerator(allocator, 0.0001f); |
|
|
|
m_Dict[TensorNames.DoneInput] = new DoneInputGenerator(allocator); |
|
|
|
m_Dict[TensorNames.LearningRate] = new ConstantGenerator(allocator,learning_rate); |
|
|
|
m_Dict[TensorNames.Gamma] = new ConstantGenerator(allocator, gamma); |
|
|
|
|
|
|
|
// Generators for Outputs
|
|
|
|
m_Dict[TensorNames.TrainingStateIn] = new TrainingStateGenerator(allocator); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
|
|
|
/// <exception cref="UnityAgentsException"> One of the tensors does not have an
|
|
|
|
/// associated generator.</exception>
|
|
|
|
public void GenerateTensors( |
|
|
|
IReadOnlyList<TensorProxy> tensors, int currentBatchSize, IList<Transition> transitions, bool training=false) |
|
|
|
IReadOnlyList<TensorProxy> tensors, int currentBatchSize, IList<Transition> transitions, TensorProxy trainingState, bool training=false) |
|
|
|
{ |
|
|
|
for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++) |
|
|
|
{ |
|
|
|
|
|
|
throw new UnityAgentsException( |
|
|
|
$"Unknown tensorProxy expected as input : {tensor.name}"); |
|
|
|
} |
|
|
|
if (tensor.name.StartsWith("obs_") || tensor.name == TensorNames.BatchSizePlaceholder) |
|
|
|
{ |
|
|
|
if (training == true) |
|
|
|
{ |
|
|
|
m_Dict[tensor.name].Generate(tensor, currentBatchSize, transitions); |
|
|
|
} |
|
|
|
} |
|
|
|
else |
|
|
|
if ((tensor.name == TensorNames.Observations || tensor.name == TensorNames.BatchSizePlaceholder) && training == false) |
|
|
|
m_Dict[tensor.name].Generate(tensor, currentBatchSize, transitions); |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
/// Registers one observation generator per sensor of the given transition,
/// both for the current state and for the next state.
/// </summary>
/// <param name="transition">
/// Transition whose <c>state</c> and <c>nextState</c> counts decide how many
/// generators are registered.
/// </param>
/// <param name="allocator">Allocator passed to every generator that is created.</param>
public void InitializeObservations(Transition transition, ITensorAllocator allocator)
{
    // Generators for the current-state observations, keyed by the
    // per-sensor observation tensor name.
    for (var sensorIndex = 0; sensorIndex < transition.state.Count; sensorIndex++)
    {
        var obsGen = new CopyObservationTensorsGenerator(allocator);
        var obsGenName = TensorNames.GetObservationName(sensorIndex);
        obsGen.SetSensorIndex(sensorIndex);
        m_Dict[obsGenName] = obsGen;
    }

    // Generators for the next-state observations, keyed by the
    // per-sensor next-observation tensor name.
    for (var sensorIndex = 0; sensorIndex < transition.nextState.Count; sensorIndex++)
    {
        var obsGen = new CopyNextObservationTensorsGenerator(allocator);
        var obsGenName = TensorNames.GetNextObservationName(sensorIndex);
        obsGen.SetSensorIndex(sensorIndex);
        m_Dict[obsGenName] = obsGen;
    }
}
|
|
|
|
|
|
|
|
|
|
m_Allocator = allocator; |
|
|
|
} |
|
|
|
|
|
|
|
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions) |
|
|
|
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState) |
|
|
|
{ |
|
|
|
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator); |
|
|
|
for (var index = 0; index < batchSize; index++) |
|
|
|
|
|
|
m_Allocator = allocator; |
|
|
|
} |
|
|
|
|
|
|
|
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions) |
|
|
|
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState) |
|
|
|
{ |
|
|
|
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator); |
|
|
|
for (var index = 0; index < batchSize; index++) |
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
/// Fills the "done" input tensor with one value per transition of the batch:
/// 1 when the transition's <c>done</c> flag equals true, 0 otherwise.
/// </summary>
internal class DoneInputGenerator : TrainingTensorGenerator.ITrainingGenerator
{
    readonly ITensorAllocator m_Allocator;

    /// <summary>
    /// Creates the generator.
    /// </summary>
    /// <param name="allocator">Allocator forwarded to <c>TensorUtils.ResizeTensor</c> on every call.</param>
    public DoneInputGenerator(ITensorAllocator allocator)
    {
        m_Allocator = allocator;
    }

    /// <summary>
    /// Resizes <paramref name="tensorProxy"/> for <paramref name="batchSize"/> entries
    /// and writes the done flag of each transition into column 0 of its row.
    /// </summary>
    /// <param name="tensorProxy">The tensor to fill.</param>
    /// <param name="batchSize">Number of rows written; <paramref name="transitions"/> is indexed up to this count.</param>
    /// <param name="transitions">Transitions supplying the <c>done</c> flags.</param>
    /// <param name="trainingState">Not used by this generator.</param>
    public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState)
    {
        TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);

        var row = 0;
        while (row < batchSize)
        {
            // The explicit comparison against true is kept as written; the
            // declared type of `done` is not visible from this file section.
            var isDone = transitions[row].done == true;
            if (isDone)
            {
                tensorProxy.data[row, 0] = 1f;
            }
            else
            {
                tensorProxy.data[row, 0] = 0f;
            }

            row++;
        }
    }
}
|
|
|
|
|
|
|
internal class CopyObservationTensorsGenerator: TrainingTensorGenerator.ITrainingGenerator |
|
|
|
{ |
|
|
|
readonly ITensorAllocator m_Allocator; |
|
|
|
|
|
|
m_SensorIndex = index; |
|
|
|
} |
|
|
|
|
|
|
|
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions) |
|
|
|
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState) |
|
|
|
{ |
|
|
|
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator); |
|
|
|
for (var index = 0; index < batchSize; index++) |
|
|
|
|
|
|
m_SensorIndex = index; |
|
|
|
} |
|
|
|
|
|
|
|
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions) |
|
|
|
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState) |
|
|
|
{ |
|
|
|
TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator); |
|
|
|
for (var index = 0; index < batchSize; index++) |
|
|
|
|
|
|
m_Const = c; |
|
|
|
} |
|
|
|
|
|
|
|
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions) |
|
|
|
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState) |
|
|
|
{ |
|
|
|
TensorUtils.ResizeTensor(tensorProxy, 1, m_Allocator); |
|
|
|
for (var index = 0; index < batchSize; index++) |
|
|
|
|
|
|
m_Allocator = allocator; |
|
|
|
} |
|
|
|
|
|
|
|
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions) |
|
|
|
public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState) |
|
|
|
{ |
|
|
|
tensorProxy.data?.Dispose(); |
|
|
|
tensorProxy.data = m_Allocator.Alloc(new TensorShape(1, 1)); |
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Fills the training-state input tensor by copying the caller-supplied
/// training-state tensor into it.
/// </summary>
internal class TrainingStateGenerator : TrainingTensorGenerator.ITrainingGenerator
{
    readonly ITensorAllocator m_Allocator;

    /// <summary>
    /// Creates the generator.
    /// </summary>
    /// <param name="allocator">Allocator forwarded to <c>TensorUtils.ResizeTensor</c> on every call.</param>
    public TrainingStateGenerator(ITensorAllocator allocator)
    {
        m_Allocator = allocator;
    }

    /// <summary>
    /// Resizes <paramref name="tensorProxy"/> for <paramref name="batchSize"/> entries
    /// and copies <paramref name="trainingState"/> into it.
    /// </summary>
    /// <param name="tensorProxy">The tensor to fill.</param>
    /// <param name="batchSize">Size passed to the resize; when it is 0 no copy is performed.</param>
    /// <param name="transitions">Not used by this generator.</param>
    /// <param name="trainingState">Source tensor copied into <paramref name="tensorProxy"/>.</param>
    public void Generate(TensorProxy tensorProxy, int batchSize, IList<Transition> transitions, TensorProxy trainingState)
    {
        TensorUtils.ResizeTensor(tensorProxy, batchSize, m_Allocator);

        // The copy takes no per-row argument, so repeating it once per batch
        // entry only redid the same work. One call is enough; the guard keeps
        // the previous behavior of not copying for an empty batch.
        // NOTE(review): assumes TensorUtils.CopyTensor is a plain overwrite of
        // the target with the source (repeatable without side effects) - confirm.
        if (batchSize > 0)
        {
            TensorUtils.CopyTensor(trainingState, tensorProxy);
        }
    }
}
|
|
|
} |