onnxruntime/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Tensors/ArrayTensorExtensionsTests.cs
Dmitri Smirnov bd4d011142
[C#] Rename unreleased API, add utilities (#16806)
### Description
1. Rename OrtValue.FillStringTensorElement to StringTensorSetElementAt.
To the API user, I think we're conceptually setting the string at an
offset in the tensor, which is roughly equivalent to `List<string> list
... list[index] = "value"`.
2. While working on new inference examples, I noticed that I am still
inclined to use `DenseTensor` for N-D indexing. Added `GetStrides()` and
`GetIndex()` from strides for long dims, so the user can obtain strides
and translate N-D indices into a flat index to operate directly on the
native `OrtValue` buffers. Expose these functions to the user.
3. Make sure we generate docs for C# public static functions.
2023-08-02 10:06:42 -07:00

148 lines
4.8 KiB
C#

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
using Xunit;
using Microsoft.ML.OnnxRuntime.Tensors;
using System.Collections.Generic;
using System.Linq;
namespace Microsoft.ML.OnnxRuntime.Tests.ArrayTensorExtensions
{
    /// <summary>
    /// Tests for converting .NET arrays of rank 1..5 into <see cref="DenseTensor{T}"/> via
    /// <c>ToTensor</c>, and for the <c>long</c>-dimension helpers on <c>ShapeUtils</c>
    /// (<c>GetStrides</c>, <c>GetIndex</c>, <c>GetSizeForShape</c>).
    /// </summary>
    public class ArrayTensorExtensionsTests
    {
        /// <summary>
        /// Asserts that the tensor's backing buffer holds exactly <paramref name="expected"/>:
        /// the same elements, in the same order, and the same count.
        /// </summary>
        /// <param name="expected">Expected element sequence.</param>
        /// <param name="tensor">Tensor whose buffer is inspected.</param>
        static void CheckValues(IEnumerable<int> expected, DenseTensor<int> tensor)
        {
            // Compare whole sequences. Enumerable.Zip stops at the shorter input,
            // so a length mismatch would otherwise pass unnoticed.
            Assert.Equal(expected.ToArray(), tensor.Buffer.ToArray());
        }

        /// <summary>
        /// Checks shared by all construction tests: the tensor built from <paramref name="source"/>
        /// has the same total length, the expected dimensions, and the same elements in the order
        /// produced by enumerating the source array.
        /// </summary>
        /// <param name="source">The (possibly multi-dimensional) array the tensor was built from.</param>
        /// <param name="expectedDims">Expected tensor dimensions.</param>
        /// <param name="tensor">The tensor under test.</param>
        static void CheckTensor(Array source, int[] expectedDims, DenseTensor<int> tensor)
        {
            // xUnit convention is Assert.Equal(expected, actual); keeping that order
            // makes failure messages report the values the right way round.
            Assert.Equal(source.Length, tensor.Length);
            Assert.Equal(expectedDims, tensor.Dimensions.ToArray());

            // Enumerating a multi-dimensional Array varies the rightmost index fastest
            // (row-major order), which is the element order the tensor buffer should have.
            CheckValues(source.Cast<int>(), tensor);
        }

        [Fact]
        public void ConstructFrom1D()
        {
            var array = new int[] { 1, 2, 3, 4 };

            var tensor = array.ToTensor();

            CheckTensor(array, new int[] { 4 }, tensor);
        }

        [Fact]
        public void ConstructFrom2D()
        {
            var array = new int[,] { { 1, 2 }, { 3, 4 } };

            var tensor = array.ToTensor();

            CheckTensor(array, new int[] { 2, 2 }, tensor);
        }

        [Fact]
        public void ConstructFrom3D()
        {
            var array = new int[,,] { { { 1, 2 }, { 3, 4 } },
                                      { { 5, 6 }, { 7, 8 } } };

            var tensor = array.ToTensor();

            CheckTensor(array, new int[] { 2, 2, 2 }, tensor);
        }

        [Fact]
        public void ConstructFrom3DWithDim1()
        {
            // A middle dimension of size 1 must be preserved in the tensor shape.
            var array = new int[,,] { { { 1, 2 } },
                                      { { 3, 4 } } };

            var tensor = array.ToTensor();

            CheckTensor(array, new int[] { 2, 1, 2 }, tensor);
        }

        [Fact]
        public void ConstructFrom4D()
        {
            var array = new int[,,,] {
                { { { 1, 2 }, { 3, 4 } },
                  { { 5, 6 }, { 7, 8 } } }
            };

            var tensor = array.ToTensor();

            CheckTensor(array, new int[] { 1, 2, 2, 2 }, tensor);
        }

        [Fact]
        public void ConstructFrom5D()
        {
            var array = new int[,,,,] {
                { { { { 1, 2 }, { 3, 4 } },
                    { { 5, 6 }, { 7, 8 } } } }
            };

            // 5D requires cast to Array (goes through the generic Array overload).
            Array a = (Array)array;
            var tensor = a.ToTensor<int>();

            CheckTensor(array, new int[] { 1, 1, 2, 2, 2 }, tensor);
        }

        [Fact]
        public void TestLongStrides()
        {
            // An empty shape has no strides.
            long[] emptyStrides = ShapeUtils.GetStrides(Array.Empty<long>());
            Assert.Empty(emptyStrides);

            // Negative dimensions are rejected.
            long[] negativeDims = { 2, -3, 4, 5 };
            Assert.Throws<ArgumentException>(() => ShapeUtils.GetStrides(negativeDims));

            // Each stride is the product of all dimensions to its right: {3*4*5, 4*5, 5, 1}.
            ReadOnlySpan<long> goodDims = stackalloc long[] { 2, 3, 4, 5 };
            long[] expectedStrides = { 60, 20, 5, 1 };
            Assert.Equal(expectedStrides, ShapeUtils.GetStrides(goodDims));
        }

        [Fact]
        public void TestLongGetIndex()
        {
            ReadOnlySpan<long> dims = stackalloc long[] { 2, 3, 4, 5 };
            long size = ShapeUtils.GetSizeForShape(dims);
            Assert.Equal(120, size);

            ReadOnlySpan<long> strides = ShapeUtils.GetStrides(dims);

            // Advances an N-D index like an odometer: the last dimension moves fastest
            // and carries into the preceding dimension when it wraps.
            static void IncDims(ReadOnlySpan<long> dims, Span<long> indices)
            {
                for (int i = dims.Length - 1; i >= 0; i--)
                {
                    indices[i]++;
                    if (indices[i] < dims[i])
                    {
                        break;
                    }
                    indices[i] = 0;
                }
            }

            // Walking every N-D index in odometer order must yield flat indices 0, 1, ..., size - 1.
            Span<long> indices = stackalloc long[] { 0, 0, 0, 0 };
            for (long i = 0; i < size; i++)
            {
                long index = ShapeUtils.GetIndex(strides, indices);
                Assert.Equal(i, index);
                IncDims(dims, indices);
            }
        }
    }
}