using System.Collections.Generic;
using System.Runtime.InteropServices.ComTypes;
using Barracuda;

namespace MLAgents.InferenceBrain
{
    /// <summary>
    /// Mapping between tensor names and generators.
    /// A TensorGenerator holds a dictionary of strings (node names) to
    /// <see cref="Generator"/> instances. Each generator receives the tensor, the current
    /// batch size and a dictionary of Agent to AgentInfo corresponding to the current batch,
    /// and reshapes and fills the data of the tensor based on the data of the batch.
    /// When the TensorProxy is an input to the model, the shape of the tensor will be modified
    /// depending on the current batch size and the data of the tensor will be filled using the
    /// dictionary of Agent to AgentInfo.
    /// When the TensorProxy is an output of the model, only the shape of the tensor will be
    /// modified using the current batch size. The data will be prefilled with zeros.
    /// </summary>
    public class TensorGenerator
    {
        /// <summary>
        /// Contract implemented by every object able to populate a tensor from a batch
        /// of agents.
        /// </summary>
        public interface Generator
        {
            /// <summary>
            /// Modifies the data inside a tensor according to the information contained in the
            /// AgentInfos contained in the current batch.
            /// </summary>
            /// <param name="tensorProxy">The tensor whose data and shape will be modified.
            /// </param>
            /// <param name="batchSize">The number of agents present in the current batch.
            /// </param>
            /// <param name="agentInfo">Dictionary of Agent to AgentInfo containing the
            /// information that will be used to populate the tensor's data.</param>
            void Generate(
                TensorProxy tensorProxy,
                int batchSize,
                Dictionary<Agent, AgentInfo> agentInfo);
        }

        // Node name -> generator responsible for that node. Filled once by the constructor.
        private readonly Dictionary<string, Generator> _dict =
            new Dictionary<string, Generator>();

        // Allocator handed to every generator created by this instance.
        private readonly ITensorAllocator _allocator;

        /// <summary>
        /// Returns a new TensorGenerator object.
        /// </summary>
        /// <param name="bp">The BrainParameters used to determine what generators will be
        /// used.</param>
        /// <param name="seed">The seed the generators will be initialized with.</param>
        /// <param name="allocator">Tensor allocator passed to every generator.</param>
        /// <param name="barracudaModel">Optional Barracuda <c>Model</c>. When not null it is
        /// cast to <c>Model</c> (the cast throws if the object is of another type) and one
        /// recurrent input generator is registered per entry of its <c>memories</c>, keyed
        /// by that entry's <c>input</c> name.</param>
        public TensorGenerator(
            BrainParameters bp,
            int seed,
            ITensorAllocator allocator,
            object barracudaModel = null)
        {
            _allocator = allocator;

            // Generators for inputs
            _dict[TensorNames.BatchSizePlaceholder] = new BatchSizeGenerator(_allocator);
            _dict[TensorNames.SequenceLengthPlaceholder] =
                new SequenceLengthGenerator(_allocator);
            _dict[TensorNames.VectorObservationPlacholder] =
                new VectorObservationGenerator(_allocator);
            _dict[TensorNames.RecurrentInPlaceholder] =
                new RecurrentInputGenerator(_allocator);

            if (barracudaModel != null)
            {
                // A direct cast of a non-null reference either throws or yields a non-null
                // Model, so no null-conditional access is needed afterwards.
                var model = (Model)barracudaModel;
                var memories = model.memories;
                for (var i = 0; i < memories.Length; i++)
                {
                    _dict[memories[i].input] =
                        new BarracudaRecurrentInputGenerator(i, _allocator);
                }
            }

            _dict[TensorNames.PreviousActionPlaceholder] =
                new PreviousActionInputGenerator(_allocator);
            _dict[TensorNames.ActionMaskPlaceholder] =
                new ActionMaskInputGenerator(_allocator);
            _dict[TensorNames.RandomNormalEpsilonPlaceholder] =
                new RandomNormalInputGenerator(seed, _allocator);

            if (bp.cameraResolutions != null)
            {
                // One visual observation generator per camera, keyed by prefix + index.
                for (var visIndex = 0; visIndex < bp.cameraResolutions.Length; visIndex++)
                {
                    var bw = bp.cameraResolutions[visIndex].blackAndWhite;
                    _dict[TensorNames.VisualObservationPlaceholderPrefix + visIndex] =
                        new VisualObservationInputGenerator(visIndex, bw, _allocator);
                }
            }

            // Generators for outputs
            _dict[TensorNames.ActionOutput] = new BiDimensionalOutputGenerator(_allocator);
            _dict[TensorNames.RecurrentOutput] = new BiDimensionalOutputGenerator(_allocator);
            _dict[TensorNames.ValueEstimateOutput] =
                new BiDimensionalOutputGenerator(_allocator);
        }

        /// <summary>
        /// Populates the data of the tensor inputs given the data contained in the current
        /// batch of agents.
        /// </summary>
        /// <param name="tensors">Enumerable of tensors that will be modified.</param>
        /// <param name="currentBatchSize">The number of agents present in the current batch.
        /// </param>
        /// <param name="agentInfos">Dictionary of Agent to AgentInfo that contains the
        /// data that will be used to modify the tensors.</param>
        /// <exception cref="UnityAgentsException">One of the tensors does not have an
        /// associated generator.</exception>
        public void GenerateTensors(
            IEnumerable<TensorProxy> tensors,
            int currentBatchSize,
            Dictionary<Agent, AgentInfo> agentInfos)
        {
            foreach (var tensor in tensors)
            {
                // Single lookup instead of ContainsKey followed by the indexer.
                Generator generator;
                if (!_dict.TryGetValue(tensor.Name, out generator))
                {
                    throw new UnityAgentsException(
                        "Unknow tensorProxy expected as input : " + tensor.Name);
                }

                generator.Generate(tensor, currentBatchSize, agentInfos);
            }
        }
    }
}