using System;
using System.Linq;
using System.Text;
using NUnit.Framework;
using Unity.MLAgents.Extensions.Sensors;
using UnityEngine;
namespace Unity.MLAgents.Extensions.Tests.Sensors
{
    /// <summary>
    /// Helpers for grid observation tests: pretty-printing observation buffers, building
    /// expected-value arrays, asserting on flattened grid observations, and configuring a
    /// <see cref="GridSensorComponent"/>.
    /// </summary>
    public static class GridObsTestUtils
    {
        /// <summary>
        /// Returns a human readable string of an array. Optional arguments are the index to start
        /// from and the maximum number of elements to include in the string.
        /// </summary>
        /// <typeparam name="T">Element type of the array.</typeparam>
        /// <param name="arr">The array to convert to string.</param>
        /// <param name="initialIndex">The initial index. Default 0.</param>
        /// <param name="maxNumberOfElements">The maximum number of elements to print. Default is all of them.</param>
        /// <returns>The selected elements joined with ", ".</returns>
        public static string Array2Str<T>(T[] arr, int initialIndex = 0, int maxNumberOfElements = int.MaxValue)
        {
            return string.Join(", ", arr.Skip(initialIndex).Take(maxNumberOfElements));
        }

        /// <summary>
        /// Given a flattened 3-dimensional matrix and its shape, returns a string in human readable format.
        /// </summary>
        /// <param name="arr">
        /// Flattened matrix array, read in row-major order (the last dimension varies fastest).
        /// </param>
        /// <param name="shape">Shape of the matrix; must have at least 3 entries, only the first 3 are used.</param>
        /// <returns>Human readable string with one line per entry of the first dimension.</returns>
        /// <exception cref="ArgumentException">Thrown when <paramref name="shape"/> has fewer than 3 entries.</exception>
        public static string Matrix2Str(float[] arr, int[] shape)
        {
            if (shape == null || shape.Length < 3)
            {
                throw new ArgumentException("Matrix2Str expects a shape with at least 3 dimensions.", nameof(shape));
            }

            // StringBuilder avoids re-allocating the whole string on every appended element.
            var log = new StringBuilder("[");
            int t = 0;
            for (int i = 0; i < shape[0]; i++)
            {
                log.Append("\n[");
                for (int j = 0; j < shape[1]; j++)
                {
                    log.Append('[');
                    for (int k = 0; k < shape[2]; k++)
                    {
                        log.Append(arr[t]).Append(", ");
                        t++;
                    }
                    log.Append("],");
                }
                log.Append(']');
            }
            log.Append(']');
            return log.ToString();
        }

        /// <summary>
        /// Utility function to duplicate an array into an array of arrays.
        /// </summary>
        /// <param name="array">Array to duplicate.</param>
        /// <param name="numCopies">Number of times to duplicate.</param>
        /// <returns>
        /// An array of <paramref name="numCopies"/> independent copies of <paramref name="array"/>.
        /// Each copy is a separate clone, so mutating one does not affect the others.
        /// </returns>
        public static float[][] DuplicateArray(float[] array, int numCopies)
        {
            float[][] duplicated = new float[numCopies][];
            for (int i = 0; i < numCopies; i++)
            {
                // Clone so the copies do not alias each other (a null input stays null).
                duplicated[i] = (float[])array?.Clone();
            }
            return duplicated;
        }

        /// <summary>
        /// Asserts that the subarrays of the total array equal specific expected subarrays at specific
        /// subarray indices, and equal a default subarray everywhere else.
        /// </summary>
        /// <param name="total">
        /// Array containing all data of the grid observation. It is a concatenation of N subarrays,
        /// all with the same length as <paramref name="expectedDefaultArray"/>.
        /// </param>
        /// <param name="indicies">The subarray indices whose values differ from the default array.</param>
        /// <param name="expectedArrays">
        /// The expected subarray values, in the same order as <paramref name="indicies"/>.
        /// </param>
        /// <param name="expectedDefaultArray">The default value of a subarray.</param>
        /// <example>
        /// If the total array is data from a 4x4x2 grid observation, total will be an array of size 32,
        /// and each subarray will have a size of 2.
        /// Suppose the 3 cells at indices (0, 1), (2, 2), and (3, 0) have the values [.1, .5], [.9, .7],
        /// and [0, .2], respectively.
        /// If the default value of a cell is [0, 0], then the grid observation will be as follows:
        /// <code>
        /// [ [0, 0],  [ 0, 0 ], [ 0, 0 ], [0, 0],   (row 0, with [.1, .5] at column 1)
        /// </code>
        /// More precisely:
        /// <code>
        /// [ [0, 0],  [.1, .5], [ 0, 0 ], [0, 0],
        ///   [0, 0],  [ 0, 0 ], [ 0, 0 ], [0, 0],
        ///   [0, 0],  [ 0, 0 ], [.9, .7], [0, 0],
        ///   [0, .2], [ 0, 0 ], [ 0, 0 ], [0, 0] ]
        /// </code>
        /// The total array is then the flattened array:
        /// <code>
        /// total = [0, 0, .1, .5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, .9, .7, 0, 0, 0, .2, 0, 0, 0, 0, 0, 0]
        /// </code>
        /// The subarray indices of the activated cells in the flattened array are 1, 10, and 12
        /// (row * 4 + column). So, to verify that the total array is as expected, call:
        /// <code>
        /// AssertSubarraysAtIndex(
        ///     total,
        ///     new int[] { 1, 10, 12 },
        ///     new float[][] { new float[] { .1f, .5f }, new float[] { .9f, .7f }, new float[] { 0f, .2f } },
        ///     new float[] { 0f, 0f });
        /// </code>
        /// </example>
        public static void AssertSubarraysAtIndex(float[] total, int[] indicies, float[][] expectedArrays, float[] expectedDefaultArray)
        {
            int lenOfData = expectedDefaultArray.Length;
            Assert.IsTrue(lenOfData > 0, "expectedDefaultArray must not be empty");
            Assert.IsTrue(expectedArrays.Length >= indicies.Length,
                $"Expected one subarray per index, but got {expectedArrays.Length} subarrays for {indicies.Length} indices");
            // A remainder would mean trailing values in total that are never checked.
            Assert.AreEqual(0, total.Length % lenOfData,
                $"Total length {total.Length} is not a multiple of the subarray length {lenOfData}");

            int numArrays = total.Length / lenOfData;
            // An out-of-range index would otherwise be silently skipped, letting the assertion pass.
            foreach (int index in indicies)
            {
                Assert.IsTrue(index >= 0 && index < numArrays,
                    $"Subarray index {index} is out of range; total contains {numArrays} subarrays");
            }

            for (int i = 0; i < numArrays; i++)
            {
                // Cells listed in indicies must match their expected subarray; all others the default.
                int subarrayIndex = Array.IndexOf(indicies, i);
                bool isSpecified = subarrayIndex >= 0;
                float[] expected = isSpecified ? expectedArrays[subarrayIndex] : expectedDefaultArray;
                string label = isSpecified ? "" : "default value ";

                for (int subIndex = 0; subIndex < lenOfData; subIndex++)
                {
                    int totalIndex = i * lenOfData + subIndex;
                    Assert.AreEqual(expected[subIndex], total[totalIndex],
                        $"Expected {label}{expected[subIndex]} at subarray index {i}, index = {subIndex} (flattened index {totalIndex}) but was {total[totalIndex]}");
                }
            }
        }

        /// <summary>
        /// Configures a <see cref="GridSensorComponent"/> for a test by assigning each of its settings.
        /// </summary>
        /// <param name="gridComponent">The component to configure.</param>
        /// <param name="detectableObjects">Assigned to <c>DetectableObjects</c>.</param>
        /// <param name="channelDepth">Assigned to <c>ChannelDepths</c>.</param>
        /// <param name="gridDepthType">Assigned to <c>DepthType</c>.</param>
        /// <param name="cellScaleX">X component of <c>CellScale</c>.</param>
        /// <param name="cellScaleZ">Z component of <c>CellScale</c>; the Y component is fixed at 0.01.</param>
        /// <param name="gridWidth">X component of <c>GridSize</c>.</param>
        /// <param name="gridHeight">Z component of <c>GridSize</c>; the Y component is fixed at 1.</param>
        /// <param name="colliderMaskInt">Assigned to <c>ColliderMask</c>.</param>
        /// <param name="rotateWithAgent">Assigned to <c>RotateWithAgent</c>.</param>
        /// <param name="debugColors">Assigned to <c>DebugColors</c>.</param>
        public static void SetComponentParameters(GridSensorComponent gridComponent, string[] detectableObjects, int[] channelDepth, GridDepthType gridDepthType,
            float cellScaleX, float cellScaleZ, int gridWidth, int gridHeight, int colliderMaskInt, bool rotateWithAgent, Color[] debugColors)
        {
            gridComponent.DetectableObjects = detectableObjects;
            gridComponent.ChannelDepths = channelDepth;
            gridComponent.DepthType = gridDepthType;
            gridComponent.CellScale = new Vector3(cellScaleX, 0.01f, cellScaleZ);
            gridComponent.GridSize = new Vector3Int(gridWidth, 1, gridHeight);
            gridComponent.ColliderMask = colliderMaskInt;
            gridComponent.RotateWithAgent = rotateWithAgent;
            gridComponent.DebugColors = debugColors;
        }
    }
}