onnxruntime/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Tensors/TensorArithmetic.tt
Scott McKay b5a652c578
Add Xamarin support (#9436)
Add Xamarin support to the ORT nuget packages.
  - Update C# code to support Xamarin builds for iOS and Android
  - refactor some things to split out common code
  - include iOS and Android ORT native shared library in native nuget package
2021-10-27 20:07:07 +10:00

249 lines
12 KiB
Text

<#@ template debug="false" hostspecific="false" language="C#" #>
<#@ assembly name="System.Core" #>
<#@ output extension=".cs" #>
<#@ include file="TensorTemplate.ttinclude" #>// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file is copied and adapted from the following git repository -
// https://github.com/dotnet/corefx
// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1
// Path: /src/System.Numerics.Tensors/tests/TensorArithmetic.cs
// Original license statement below -
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
namespace Microsoft.ML.OnnxRuntime.Tensors
{
/// <summary>
/// Arithmetic contract for tensors with elements of type <typeparamref name="T"/>.
/// Besides the two constants, the template declares one member per entry in
/// methodConfiguration, using the signature produced for the general Tensor type.
/// </summary>
/// <typeparam name="T">Element type the arithmetic operates on.</typeparam>
internal interface ITensorArithmetic<T>
{
    /// <summary>The value an implementation reports as zero for <typeparamref name="T"/>.</summary>
    T Zero { get; }

    /// <summary>The value an implementation reports as one for <typeparamref name="T"/>.</summary>
    T One { get; }

    // One method declaration is emitted for every configured operation.
<# foreach (MethodConfiguration operation in methodConfiguration) { #>
    <#= operation.GetResultMethodSignature("Tensor", "T")#>;
<# } #>
}
/// <summary>
/// Per-element-type access point for the generated arithmetic implementations.
/// </summary>
/// <typeparam name="T">Element type whose arithmetic is requested.</typeparam>
internal static class TensorArithmetic<T>
{
    /// <summary>
    /// The arithmetic implementation for <typeparamref name="T"/>, or null when
    /// <c>TensorArithmetic.GetArithmetic</c> has no implementation for that type.
    /// </summary>
    // Resolved once per closed generic type and then reused. The generated arithmetic
    // classes declare no instance fields, so sharing one instance is safe and avoids
    // allocating a new object every time the property is read.
    public static ITensorArithmetic<T> Instance { get; } = TensorArithmetic.GetArithmetic<T>();
}
/// <summary>
/// Non-generic entry point that maps an element type to its generated arithmetic class.
/// </summary>
internal static class TensorArithmetic
{
    /// <summary>
    /// Creates the arithmetic implementation generated for <typeparamref name="T"/>.
    /// </summary>
    /// <typeparam name="T">Element type whose arithmetic is requested.</typeparam>
    /// <returns>
    /// A new instance of the matching generated class, or null when none of the
    /// branches generated from typeConfiguration returns.
    /// </returns>
    public static ITensorArithmetic<T> GetArithmetic<T>()
    {
        // One guarded branch is emitted per configured element type. The guard text comes
        // from GenerateIfStatementHeader (presumably a check on typeof(T) - confirm in
        // TensorTemplate.ttinclude).
<# foreach (TypeConfiguration elementType in typeConfiguration) { #>
        <#=GenerateIfStatementHeader(elementType)#>
        {
            var arithmetic = new <#=elementType.ClassPrefix#>Arithmetic();
            return (ITensorArithmetic<T>)arithmetic;
        }
<# } #>

        // No generated branch returned for T.
        return null;
    }
}
<# foreach (TypeConfiguration type in typeConfiguration) { #>
// Arithmetic implementation for a single element type. The template emits one class like
// this per entry in typeConfiguration. Inside it, every entry in methodConfiguration gets
// two overloads: a Tensor overload that goes through the tensor indexer, followed by a
// DenseTensor overload that works directly on the backing buffer spans.
internal class <#=type.ClassPrefix#>Arithmetic : ITensorArithmetic<<#=type.TypeName#>>
{
// Constants come from the OneLiteral / ZeroLiteral configured for the element type.
public <#=type.TypeName#> One => <#=type.OneLiteral#>;
public <#=type.TypeName#> Zero => <#=type.ZeroLiteral#>;
<# foreach (MethodConfiguration method in methodConfiguration) { #>
// Tensor overload: elements are addressed through multi-dimensional indices.
public <#= method.GetResultMethodSignature("Tensor", type.TypeName)#>
{
<# if ((method.IsNumeric && !type.SupportsNumeric) || (method.IsBitwise && !type.SupportsBitwise) || (type.UnsupportedMethods.Contains(method.MethodName))) { #>
// Emitted when the operation is numeric or bitwise but the element type is configured
// without that support, or when the type lists this method in UnsupportedMethods.
throw new NotSupportedException();
<# } else if (method.Operator != null) { #>
// Operator-based method: apply the element operation at every linear position of the result.
// indices is a scratch buffer allocated once and reused for each position; GetIndices receives it
// together with the linear position i (presumably to fill in the matching multi-dimensional
// index - confirm against ArrayUtilities.GetIndices).
Span<int> indices = new Span<int>(new int[result.Rank]);
for(int i = 0; i < <#= method.ResultName #>.Length; i++)
{
ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, i, indices);
<#=method.GetElementOperation(type.TypeName, "[indices]")#>;
}
<# } else if (method.MethodName == "Contract") {#>
// Contract: every result element is the sum of left * right products taken over the
// dimensions named by leftAxes / rightAxes.
var leftIndices = new int[left.Rank];
var rightIndices = new int[right.Rank];
var resultIndices = new int[result.Rank];
// Sizes of the summed dimensions are read from left only; nothing in this method compares them
// with right's dimensions along rightAxes (NOTE(review): presumably validated by the caller - confirm).
var summingDimensions = new int[leftAxes.Length];
for(int i = 0; i < leftAxes.Length; i++)
{
summingDimensions[i] = left.dimensions[leftAxes[i]];
}
var summingStrides = ArrayUtilities.GetStrides(summingDimensions);
int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions);
var resultStrides = result.strides;
// translates from result index to left non-summing dimensions' index portion
// since left non-summing dimensions are given precedence in result, the end is zero-padded
int[] leftNonSummingStrides = new int[result.Rank];
// translates from summing index to left summing dimensions' index portion
int[] leftSummingStrides = new int[leftAxes.Length];
ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0);
// translates from result index to right non-summing dimensions' index portion
int[] rightNonSummingStrides = new int[result.Rank];
// right non-summing dimensions appear after left non-summing dimensions.
int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length);
// translates from summing index to right summing dimensions' index portion
int[] rightSummingStrides = new int[rightAxes.Length];
ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0);
// Outer loop visits each result element; inner loop accumulates over the summed dimensions.
for (int resultIndex = 0; resultIndex < result.Length; resultIndex++)
{
<#=type.TypeName#> sum = (<#=type.TypeName#>)0;
ArrayUtilities.GetIndices(result.strides, result.IsReversedStride, resultIndex, resultIndices);
// The non-summing parts of the left/right offsets depend only on the result position,
// so they are computed once per result element, outside the summing loop.
int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides);
int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides);
for (int summingIndex = 0; summingIndex < summingLength; summingIndex++)
{
// The summing index is decomposed with summingStrides and a reversed-stride flag of false.
int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides);
int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides);
int leftIndex = leftIndexNonSumming + leftIndexSumming;
int rightIndex = rightIndexNonSumming + rightIndexSumming;
// todo, make this more efficient
// The linear offsets are turned back into index arrays because this overload reads
// the operands through their indexers.
ArrayUtilities.GetIndices(left.strides, left.IsReversedStride, leftIndex, leftIndices);
ArrayUtilities.GetIndices(right.strides, right.IsReversedStride, rightIndex, rightIndices);
sum += (<#=type.TypeName#>)(left[leftIndices] * right[rightIndices]);
}
result[resultIndices] = sum;
}
<# } #>
}
<# } #>
<# foreach (MethodConfiguration method in methodConfiguration) { #>
// DenseTensor overload: elements are addressed by linear offset into the Buffer spans.
public <#= method.GetResultMethodSignature("DenseTensor", type.TypeName)#>
{
<# if ((method.IsNumeric && !type.SupportsNumeric) || (method.IsBitwise && !type.SupportsBitwise) || (type.UnsupportedMethods.Contains(method.MethodName))) { #>
// Same support check as the Tensor overload: the operation is not available for this element type.
throw new NotSupportedException();
<# } else if (method.Operator != null) { #>
<# if (method.MethodType == MethodType.UnaryInPlace) { #>
// In-place unary operation: walk the buffer linearly and apply the element operation at each offset.
var <#=method.ResultName #>Span = <#=method.ResultName #>.Buffer.Span;
var <#=method.Op1Name #>Span = <#=method.Op1Name #>.Buffer.Span;
for(int i = 0; i < <#=method.ResultName #>Span.Length; i++)
{
<#=method.GetElementOperation(type.TypeName, "Span[i]")#>;
}
<# } else {#>
var <#=method.ResultName #>Span = <#=method.ResultName #>.Buffer.Span;
var <#=method.Op1Name #>Span = <#=method.Op1Name #>.Buffer.Span;
<# if ((method.MethodType == MethodType.Binary) || (method.MethodType == MethodType.Comparison)) {#>
// A second operand span exists only for binary and comparison methods.
var <#=method.Op2Name #>Span = <#=method.Op2Name #>.Buffer.Span;
<# } #>
// Fast path: when the condition produced by GetLinearOperationCheck holds, every span is indexed
// with the same offset i (presumably the operands share a stride layout - confirm in the include file).
if <#= method.GetLinearOperationCheck() #>
{
for(int i = 0; i < <#= method.ResultName #>Span.Length; i++)
{
<#=method.GetElementOperation(type.TypeName, "Span[i]")#>;
}
}
else
{
// Slow path: rowMajorIndex is the loop counter and colMajorIndex is recomputed from it on each
// iteration. Each operand's index is a ref alias to whichever of the two counters matches its
// IsReversedStride flag, so updating the two counters updates every operand index.
int rowMajorIndex = 0;
int colMajorIndex = 0;
ref int resultIndex = ref <#= method.ResultName #>.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex;
ref int op1Index = ref <#= method.Op1Name #>.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex;
<# if ((method.MethodType == MethodType.Binary) || (method.MethodType == MethodType.Comparison)) {#>
ref int op2Index = ref <#= method.Op2Name #>.IsReversedStride ? ref colMajorIndex : ref rowMajorIndex;
// Row-major strides come from the first of result / operand 1 that is not reversed-stride, falling back
// to operand 2; column-major strides come from the first of result / operand 1 that is reversed-stride,
// with the same fallback.
var rowMajorStrides = !<#= method.ResultName #>.IsReversedStride ? <#= method.ResultName #>.strides :
!<#= method.Op1Name #>.IsReversedStride ? <#= method.Op1Name #>.strides :
<#= method.Op2Name #>.strides;
var columnMajorStrides = <#= method.ResultName #>.IsReversedStride ? <#= method.ResultName #>.strides :
<#= method.Op1Name #>.IsReversedStride ? <#= method.Op1Name #>.strides :
<#= method.Op2Name #>.strides;
<# } else {#>
// Only two tensors take part here: each stride set comes from the result when its IsReversedStride flag
// matches the wanted layout, otherwise from operand 1.
var rowMajorStrides = !<#= method.ResultName #>.IsReversedStride ? <#= method.ResultName #>.strides :
<#= method.Op1Name #>.strides;
var columnMajorStrides = <#= method.ResultName #>.IsReversedStride ? <#= method.ResultName #>.strides :
<#= method.Op1Name #>.strides;
<# } #>
for(;rowMajorIndex < <#= method.ResultName #>Span.Length; rowMajorIndex++)
{
colMajorIndex = ArrayUtilities.TransformIndexByStrides(rowMajorIndex, rowMajorStrides, false, columnMajorStrides);
// NOTE(review): the "Span[op2Index]" suffix is passed for every method type, but op2Index is declared
// only for binary/comparison methods; presumably GetElementOperation ignores it otherwise - confirm.
<#=method.GetElementOperation(type.TypeName, "Span[resultIndex]", "Span[op1Index]", "Span[op2Index]")#>;
}
}
<# } #>
<# } else if (method.MethodName == "Contract") {#>
// Contract: same stride bookkeeping as the Tensor overload, but operands and result are read and
// written through their buffer spans by linear offset, so no index arrays are needed.
// Sizes of the summed dimensions are read from left only (NOTE(review): no check against right's
// dimensions is visible here - presumably validated by the caller).
var summingDimensions = new int[leftAxes.Length];
for(int i = 0; i < leftAxes.Length; i++)
{
summingDimensions[i] = left.dimensions[leftAxes[i]];
}
var summingStrides = ArrayUtilities.GetStrides(summingDimensions);
int summingLength = (int)ArrayUtilities.GetProduct(summingDimensions);
var resultStrides = result.strides;
// translates from result index to left non-summing dimensions' index portion
// since left non-summing dimensions are given precedence in result, the end is zero-padded
int[] leftNonSummingStrides = new int[result.Rank];
// translates from summing index to left summing dimensions' index portion
int[] leftSummingStrides = new int[leftAxes.Length];
ArrayUtilities.SplitStrides(left.strides, leftAxes, leftNonSummingStrides, 0, leftSummingStrides, 0);
// translates from result index to right non-summing dimensions' index portion
int[] rightNonSummingStrides = new int[result.Rank];
// right non-summing dimensions appear after left non-summing dimensions.
int rightNonSummingStridesOffset = (left.Rank - leftAxes.Length);
// translates from summing index to right summing dimensions' index portion
int[] rightSummingStrides = new int[rightAxes.Length];
ArrayUtilities.SplitStrides(right.strides, rightAxes, rightNonSummingStrides, rightNonSummingStridesOffset, rightSummingStrides, 0);
var resultSpan = result.Buffer.Span;
var leftSpan = left.Buffer.Span;
var rightSpan = right.Buffer.Span;
// Outer loop visits each result buffer slot; inner loop accumulates over the summed dimensions.
for (int resultIndex = 0; resultIndex < resultSpan.Length; resultIndex++)
{
<#=type.TypeName#> sum = (<#=type.TypeName#>)0;
// The non-summing parts of the offsets depend only on the result position, so they are computed
// once per result element.
int leftIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, leftNonSummingStrides);
int rightIndexNonSumming = ArrayUtilities.TransformIndexByStrides(resultIndex, resultStrides, result.IsReversedStride, rightNonSummingStrides);
for (int summingIndex = 0; summingIndex < summingLength; summingIndex++)
{
int leftIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, leftSummingStrides);
int rightIndexSumming = ArrayUtilities.TransformIndexByStrides(summingIndex, summingStrides, false, rightSummingStrides);
int leftIndex = leftIndexNonSumming + leftIndexSumming;
int rightIndex = rightIndexNonSumming + rightIndexSumming;
sum += (<#=type.TypeName#>)(leftSpan[leftIndex] * rightSpan[rightIndex]);
}
resultSpan[resultIndex] = sum;
}
<# } #>
}
<# } #>
}
<# } #>
}