onnxruntime/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/Tensors/TensorTemplate.ttinclude
Scott McKay b5a652c578
Add Xamarin support (#9436)
Add Xamarin support to the ORT nuget packages.
  - Update C# code to support Xamarin builds for iOS and Android
  - refactor some things to split out common code
  - include iOS and Android ORT native shared library in native nuget package
2021-10-27 20:07:07 +10:00

328 lines
14 KiB
Text

<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#+
/// <summary>
/// Describes one element type the template generates code for: the C# type
/// keyword, the prefix used when building generated identifiers, the source
/// text of its "one" and "zero" literals, and which method families apply.
/// </summary>
public class TypeConfiguration
{
    /// <summary>
    /// Creates the configuration for a single element type.
    /// </summary>
    /// <param name="typeName">C# type keyword, e.g. <c>int</c>. Must be non-empty when <paramref name="classPrefix"/> is omitted.</param>
    /// <param name="classPrefix">Identifier prefix; when null it is derived from <paramref name="typeName"/> by upper-casing its first character.</param>
    /// <param name="oneLiteral">Source text of the value one for this type.</param>
    /// <param name="zeroLiteral">Source text of the value zero for this type.</param>
    /// <param name="supportsNumeric">Stored in <see cref="SupportsNumeric"/>; presumably matched against numeric methods by the consuming template.</param>
    /// <param name="supportsBitwise">Stored in <see cref="SupportsBitwise"/>; presumably matched against bitwise methods by the consuming template.</param>
    /// <param name="unsupportedMethods">Method names recorded in <see cref="UnsupportedMethods"/>; null means none.</param>
    /// <exception cref="ArgumentException">
    /// <paramref name="classPrefix"/> is null and <paramref name="typeName"/> is null or empty,
    /// so no prefix can be derived.
    /// </exception>
    public TypeConfiguration(string typeName, string classPrefix = null, string oneLiteral = "1", string zeroLiteral = "0", bool supportsNumeric = true, bool supportsBitwise = true, IEnumerable<string> unsupportedMethods = null)
    {
        if (classPrefix == null)
        {
            if (string.IsNullOrEmpty(typeName))
            {
                throw new ArgumentException("A non-empty type name is required to derive the class prefix.", nameof(typeName));
            }

            // Invariant casing: char.ToUpper honours the current culture, and
            // under e.g. tr-TR it maps 'i' to U+0130, turning "int" into an
            // identifier prefix that is not plain "Int".
            classPrefix = char.ToUpperInvariant(typeName[0]) + typeName.Substring(1);
        }

        TypeName = typeName;
        ClassPrefix = classPrefix;
        OneLiteral = oneLiteral;
        ZeroLiteral = zeroLiteral;
        SupportsNumeric = supportsNumeric;
        SupportsBitwise = supportsBitwise;
        // Copy into a set so later membership checks do not re-enumerate the caller's sequence.
        UnsupportedMethods = new HashSet<string>(unsupportedMethods ?? Enumerable.Empty<string>());
    }

    /// <summary>C# type keyword as passed to the constructor.</summary>
    public string TypeName { get; }

    /// <summary>Explicit prefix, or the type name with its first character upper-cased.</summary>
    public string ClassPrefix { get; }

    /// <summary>Source text of the value one (default "1").</summary>
    public string OneLiteral { get; }

    /// <summary>Source text of the value zero (default "0").</summary>
    public string ZeroLiteral { get; }

    /// <summary>Flag supplied at construction (default true).</summary>
    public bool SupportsNumeric { get; }

    /// <summary>Flag supplied at construction (default true).</summary>
    public bool SupportsBitwise { get; }

    /// <summary>Method names supplied at construction; empty when none were given.</summary>
    public ISet<string> UnsupportedMethods { get; }
}
// Builds the header of one branch in a generated "typeof(T) == typeof(X)"
// dispatch chain. The configuration that is the first entry of
// typeConfiguration opens the chain with "if"; every other one continues it
// with "else if".
public string GenerateIfStatementHeader(TypeConfiguration type)
{
    bool opensChain = type == typeConfiguration[0];

    string branchKeyword;
    if (opensChain)
    {
        branchKeyword = "if";
    }
    else
    {
        branchKeyword = "else if";
    }

    return branchKeyword + " (typeof(T) == typeof(" + type.TypeName + "))";
}
// Element types the template generates code for. The first entry matters:
// GenerateIfStatementHeader emits "if" for it and "else if" for the rest.
public TypeConfiguration[] typeConfiguration = new TypeConfiguration[]
{
    // bool has no arithmetic and cannot be shifted.
    new TypeConfiguration("bool", oneLiteral: "true", zeroLiteral: "false", supportsNumeric: false, unsupportedMethods: new string[] { "LeftShift", "RightShift" }),
    new TypeConfiguration("byte"),
    new TypeConfiguration("char", oneLiteral: "(char)1", zeroLiteral: "(char)0"),
    // decimal, double and float are numeric only.
    new TypeConfiguration("decimal", supportsBitwise: false),
    new TypeConfiguration("double", oneLiteral: "1.0", supportsBitwise: false),
    new TypeConfiguration("float", oneLiteral: "1.0f", supportsBitwise: false),
    new TypeConfiguration("int"),
    new TypeConfiguration("long"),
    // Explicit prefixes where upper-casing only the first letter is not enough.
    new TypeConfiguration("sbyte", classPrefix: "SByte"),
    new TypeConfiguration("short"),
    // The unsigned types below exclude UnaryMinus.
    new TypeConfiguration("uint", classPrefix: "UInt", unsupportedMethods: new string[] { "UnaryMinus" }),
    new TypeConfiguration("ulong", classPrefix: "ULong", unsupportedMethods: new string[] { "UnaryMinus" }),
    new TypeConfiguration("ushort", classPrefix: "UShort", unsupportedMethods: new string[] { "UnaryMinus" })
};
// Shape of a generated tensor method. MethodConfiguration switches on this
// value to pick operand names, argument lists, result types and the
// per-element expression it emits.
public enum MethodType
{
// One tensor operand named "tensor"; the operator is emitted as a prefix on each element.
Unary,
// One tensor operand named "tensor"; the result starts as tensor.Clone() and
// the operator is emitted as a postfix on each result element.
UnaryInPlace,
// A tensor operand plus a "scalar" operand of the tensor's element type.
BinaryScalar,
// A tensor operand plus an int operand named "value".
BinaryInt,
// Two tensor operands named "left" and "right"; result has the same element type.
Binary,
// Two tensor operands named "left" and "right"; result is a tensor of bool.
Comparison,
// Two tensor operands named "left" and "right" plus int[] leftAxes and int[] rightAxes.
Contraction
}
/// <summary>
/// Describes one tensor method to generate and produces the C# source
/// fragments for it: signatures, argument lists, validation calls, result
/// initialisation and the per-element expression. Which fragments are
/// available depends on <see cref="MethodType"/>; asking for a fragment that
/// is not defined for the configured type throws <see cref="ArgumentException"/>.
/// </summary>
public class MethodConfiguration
{
    /// <summary>
    /// Creates the configuration for one generated method.
    /// </summary>
    /// <param name="methodName">Name of the generated method.</param>
    /// <param name="methodType">Shape of the method (operands and result).</param>
    /// <param name="op">Operator token emitted in the element expression; may be null when no element expression is generated.</param>
    /// <param name="isNumeric">Stored in <see cref="IsNumeric"/>.</param>
    /// <param name="isBitwise">Stored in <see cref="IsBitwise"/>.</param>
    public MethodConfiguration(string methodName, MethodType methodType, string op = null, bool isNumeric = false, bool isBitwise = false)
    {
        MethodName = methodName;
        MethodType = methodType;
        Operator = op;
        IsNumeric = isNumeric;
        IsBitwise = isBitwise;
    }

    /// <summary>Name of the generated method.</summary>
    public string MethodName { get; }

    /// <summary>Shape of the generated method.</summary>
    public MethodType MethodType { get; }

    /// <summary>Operator token supplied at construction (may be null).</summary>
    public string Operator { get; }

    /// <summary>Flag supplied at construction (default false).</summary>
    public bool IsNumeric { get; }

    /// <summary>Flag supplied at construction (default false).</summary>
    public bool IsBitwise { get; }

    /// <summary>Identifier used for the result tensor in generated code.</summary>
    public string ResultName => "result";

    /// <summary>
    /// Identifier of the first operand: "tensor" for single-tensor shapes,
    /// "left" for two-tensor shapes.
    /// </summary>
    /// <exception cref="ArgumentException">The method type is not a known value.</exception>
    public string Op1Name
    {
        get
        {
            switch (MethodType)
            {
                case MethodType.Unary:
                case MethodType.UnaryInPlace:
                case MethodType.BinaryScalar:
                case MethodType.BinaryInt:
                    return "tensor";
                case MethodType.Binary:
                case MethodType.Comparison:
                case MethodType.Contraction:
                    return "left";
                default:
                    throw Unsupported(nameof(Op1Name));
            }
        }
    }

    /// <summary>
    /// Identifier of the second operand: "scalar", "value" or "right"
    /// depending on the method type.
    /// </summary>
    /// <exception cref="ArgumentException">The method type has no second operand (unary shapes) or is not a known value.</exception>
    public string Op2Name
    {
        get
        {
            switch (MethodType)
            {
                case MethodType.BinaryScalar:
                    return "scalar";
                case MethodType.BinaryInt:
                    return "value";
                case MethodType.Binary:
                case MethodType.Comparison:
                case MethodType.Contraction:
                    return "right";
                case MethodType.Unary:
                case MethodType.UnaryInPlace:
                default:
                    throw Unsupported(nameof(Op2Name));
            }
        }
    }

    /// <summary>
    /// Signature of the generic method that returns its result, e.g.
    /// <c>Tensor&lt;T&gt; Add&lt;T&gt;(Tensor&lt;T&gt; left, Tensor&lt;T&gt; right)</c>.
    /// </summary>
    public string GetGenericMethodSignature(string tensorType, string genericType)
    {
        var resultType = GetResultType(tensorType, genericType);
        var arguments = GetMethodArguments(tensorType, genericType);
        return $"{resultType} {MethodName}<{genericType}>({arguments})";
    }

    /// <summary>
    /// Signature of the generic void method that takes the result tensor as
    /// a trailing parameter.
    /// </summary>
    public string GetGenericResultMethodSignature(string tensorType, string genericType)
    {
        var resultType = GetResultType(tensorType, genericType);
        var arguments = GetMethodArguments(tensorType, genericType);
        return $"void {MethodName}<{genericType}>({arguments}, {resultType} {ResultName})";
    }

    /// <summary>
    /// Signature of the non-generic void method that takes the result tensor
    /// as a trailing parameter.
    /// </summary>
    public string GetResultMethodSignature(string tensorType, string genericType)
    {
        var resultType = GetResultType(tensorType, genericType);
        var arguments = GetMethodArguments(tensorType, genericType);
        return $"void {MethodName}({arguments}, {resultType} {ResultName})";
    }

    /// <summary>
    /// Parameter list (types and names, without the result parameter) for
    /// the configured method type.
    /// </summary>
    /// <exception cref="ArgumentException">The method type is not a known value.</exception>
    public string GetMethodArguments(string tensorType, string genericType)
    {
        switch (MethodType)
        {
            case MethodType.Unary:
            case MethodType.UnaryInPlace:
                return $"{tensorType}<{genericType}> {Op1Name}";
            case MethodType.BinaryScalar:
                return $"{tensorType}<{genericType}> {Op1Name}, {genericType} {Op2Name}";
            case MethodType.BinaryInt:
                return $"{tensorType}<{genericType}> {Op1Name}, int {Op2Name}";
            case MethodType.Binary:
            case MethodType.Comparison:
                return $"{tensorType}<{genericType}> {Op1Name}, {tensorType}<{genericType}> {Op2Name}";
            case MethodType.Contraction:
                return $"{tensorType}<{genericType}> {Op1Name}, {tensorType}<{genericType}> {Op2Name}, int[] leftAxes, int[] rightAxes";
            default:
                throw Unsupported(nameof(GetMethodArguments));
        }
    }

    /// <summary>
    /// Argument list (names only) for forwarding a call with the same
    /// parameters as <see cref="GetMethodArguments"/>.
    /// </summary>
    /// <exception cref="ArgumentException">The method type is not a known value.</exception>
    public string GetCallArguments()
    {
        switch (MethodType)
        {
            case MethodType.Unary:
            case MethodType.UnaryInPlace:
                return $"{Op1Name}";
            case MethodType.BinaryScalar:
            case MethodType.BinaryInt:
            case MethodType.Binary:
            case MethodType.Comparison:
                return $"{Op1Name}, {Op2Name}";
            case MethodType.Contraction:
                // Same text as before ("left, right, ..."), now derived from the operand names.
                return $"{Op1Name}, {Op2Name}, leftAxes, rightAxes";
            default:
                throw Unsupported(nameof(GetCallArguments));
        }
    }

    /// <summary>
    /// Statement that calls the validation helper matching the method type.
    /// For contractions the statement also declares <c>resultDimensions</c>.
    /// </summary>
    /// <param name="includeResult">When true the result tensor is appended to the validation arguments.</param>
    /// <exception cref="ArgumentException">The method type is not a known value.</exception>
    public string GetValidationMethod(bool includeResult)
    {
        var resultArgument = includeResult ? $", {ResultName}" : string.Empty;
        switch (MethodType)
        {
            case MethodType.Unary:
            case MethodType.UnaryInPlace:
            case MethodType.BinaryScalar:
            case MethodType.BinaryInt:
                // Only the tensor operand is validated; the scalar/int operand is not passed.
                return $"ValidateArgs({Op1Name}{resultArgument});";
            case MethodType.Binary:
            case MethodType.Comparison:
                return $"ValidateBinaryArgs({Op1Name}, {Op2Name}{resultArgument});";
            case MethodType.Contraction:
                return $"var resultDimensions = ValidateContractArgs({Op1Name}, {Op2Name}, leftAxes, rightAxes{resultArgument});";
            default:
                throw Unsupported(nameof(GetValidationMethod));
        }
    }

    /// <summary>
    /// Type of the result tensor: a tensor of bool for comparisons, otherwise
    /// a tensor of <paramref name="typeName"/>.
    /// </summary>
    /// <exception cref="ArgumentException">The method type is not a known value.</exception>
    public string GetResultType(string tensorType, string typeName)
    {
        switch (MethodType)
        {
            case MethodType.Unary:
            case MethodType.UnaryInPlace:
            case MethodType.BinaryScalar:
            case MethodType.BinaryInt:
            case MethodType.Binary:
            case MethodType.Contraction:
                return $"{tensorType}<{typeName}>";
            case MethodType.Comparison:
                return $"{tensorType}<bool>";
            default:
                throw Unsupported(nameof(GetResultType));
        }
    }

    /// <summary>
    /// Boolean expression that is true when the result and every tensor
    /// operand report the same <c>IsReversedStride</c> value.
    /// </summary>
    /// <exception cref="ArgumentException">The method type is UnaryInPlace, Contraction or not a known value.</exception>
    public string GetLinearOperationCheck()
    {
        switch (MethodType)
        {
            case MethodType.Unary:
            case MethodType.BinaryScalar:
            case MethodType.BinaryInt:
                return $"({ResultName}.IsReversedStride == {Op1Name}.IsReversedStride)";
            case MethodType.Binary:
            case MethodType.Comparison:
                return $"(({ResultName}.IsReversedStride == {Op1Name}.IsReversedStride) && ({ResultName}.IsReversedStride == {Op2Name}.IsReversedStride))";
            case MethodType.UnaryInPlace:
            default:
                throw Unsupported(nameof(GetLinearOperationCheck));
        }
    }

    /// <summary>
    /// Per-element expression using the same accessor text for the result
    /// and both operands.
    /// </summary>
    public string GetElementOperation(string typeName, string access)
    {
        return GetElementOperation(typeName, access, access, access);
    }

    /// <summary>
    /// Per-element expression (without trailing semicolon). Each access
    /// string is appended verbatim to the corresponding tensor identifier.
    /// </summary>
    /// <param name="typeName">Element type; used to cast the computed value back for non-comparison shapes.</param>
    /// <param name="resultAccess">Accessor text appended to the result identifier.</param>
    /// <param name="leftAccess">Accessor text appended to the first operand.</param>
    /// <param name="rightAccess">Accessor text appended to the second tensor operand (unused for scalar/int shapes).</param>
    /// <exception cref="ArgumentException">The method type is Contraction or not a known value.</exception>
    public string GetElementOperation(string typeName, string resultAccess, string leftAccess, string rightAccess)
    {
        switch (MethodType)
        {
            case MethodType.Unary:
                return $"{ResultName}{resultAccess} = ({typeName}){Operator}{Op1Name}{leftAccess}";
            case MethodType.UnaryInPlace:
                // Operator is applied as a postfix directly on the result element.
                return $"{ResultName}{resultAccess}{Operator}";
            case MethodType.BinaryScalar:
            case MethodType.BinaryInt:
                return $"{ResultName}{resultAccess} = ({typeName})({Op1Name}{leftAccess} {Operator} {Op2Name})";
            case MethodType.Binary:
                return $"{ResultName}{resultAccess} = ({typeName})({Op1Name}{leftAccess} {Operator} {Op2Name}{rightAccess})";
            case MethodType.Comparison:
                // No cast: the comparison already yields the bool result element.
                return $"{ResultName}{resultAccess} = {Op1Name}{leftAccess} {Operator} {Op2Name}{rightAccess}";
            default:
                throw Unsupported(nameof(GetElementOperation));
        }
    }

    /// <summary>
    /// Expression that creates the result tensor from the first operand.
    /// The contraction form refers to the <c>resultDimensions</c> local
    /// declared by <see cref="GetValidationMethod"/>.
    /// </summary>
    /// <param name="typeName">Not used by the current implementation.</param>
    /// <exception cref="ArgumentException">The method type is not a known value.</exception>
    public string InitializeResult(string typeName)
    {
        switch (MethodType)
        {
            case MethodType.UnaryInPlace:
                return $"{Op1Name}.Clone()";
            case MethodType.Unary:
            case MethodType.BinaryScalar:
            case MethodType.BinaryInt:
            case MethodType.Binary:
                return $"{Op1Name}.CloneEmpty()";
            case MethodType.Comparison:
                return $"{Op1Name}.CloneEmpty<bool>()";
            case MethodType.Contraction:
                return $"{Op1Name}.CloneEmpty(resultDimensions)";
            default:
                throw Unsupported(nameof(InitializeResult));
        }
    }

    // Builds the exception for a fragment that is not defined for this method
    // type, naming the member, method and type so a failed template
    // transformation can be traced to the offending configuration entry.
    private ArgumentException Unsupported(string member)
    {
        return new ArgumentException($"{member} is not defined for method '{MethodName}' of type {MethodType}.");
    }
}
// Every tensor method the template generates, sorted by method name.
// The sort uses an ordinal comparer so the order of the generated code does
// not depend on the culture of the machine running the template; OrderBy is
// stable, so entries sharing a name keep the relative order listed here.
public MethodConfiguration[] methodConfiguration = new MethodConfiguration[]
{
    new MethodConfiguration("Add", MethodType.Binary, "+", isNumeric: true),
    new MethodConfiguration("Add", MethodType.BinaryScalar, "+", isNumeric: true),
    new MethodConfiguration("UnaryPlus", MethodType.Unary, "+", isNumeric: true),
    new MethodConfiguration("Subtract", MethodType.Binary, "-", isNumeric: true),
    new MethodConfiguration("Subtract", MethodType.BinaryScalar, "-", isNumeric: true),
    new MethodConfiguration("UnaryMinus", MethodType.Unary, "-", isNumeric: true),
    new MethodConfiguration("Increment", MethodType.UnaryInPlace, "++", isNumeric: true),
    new MethodConfiguration("Decrement", MethodType.UnaryInPlace, "--", isNumeric: true),
    new MethodConfiguration("Multiply", MethodType.Binary, "*", isNumeric: true), // element-wise product, not matrix product
    new MethodConfiguration("Multiply", MethodType.BinaryScalar, "*", isNumeric: true),
    new MethodConfiguration("Divide", MethodType.Binary, "/", isNumeric: true),
    new MethodConfiguration("Divide", MethodType.BinaryScalar, "/", isNumeric: true),
    new MethodConfiguration("Modulo", MethodType.Binary, "%", isNumeric: true),
    new MethodConfiguration("Modulo", MethodType.BinaryScalar, "%", isNumeric: true),
    new MethodConfiguration("And", MethodType.Binary, "&", isBitwise: true),
    new MethodConfiguration("And", MethodType.BinaryScalar, "&", isBitwise: true),
    new MethodConfiguration("Or", MethodType.Binary, "|", isBitwise: true),
    new MethodConfiguration("Or", MethodType.BinaryScalar, "|", isBitwise: true),
    new MethodConfiguration("Xor", MethodType.Binary, "^", isBitwise: true),
    new MethodConfiguration("Xor", MethodType.BinaryScalar, "^", isBitwise: true),
    new MethodConfiguration("LeftShift", MethodType.BinaryInt, "<<", isBitwise: true),
    new MethodConfiguration("RightShift", MethodType.BinaryInt, ">>", isBitwise: true),
    // Note: the comparisons below are element-wise; they produce a result per
    // element rather than a single answer for the tensor as a whole.
    new MethodConfiguration("Equals", MethodType.Comparison, "=="),
    new MethodConfiguration("NotEquals", MethodType.Comparison, "!="),
    new MethodConfiguration("GreaterThanOrEqual", MethodType.Comparison, ">=", isNumeric: true),
    new MethodConfiguration("LessThanOrEqual", MethodType.Comparison, "<=", isNumeric: true),
    new MethodConfiguration("GreaterThan", MethodType.Comparison, ">", isNumeric: true),
    new MethodConfiguration("LessThan", MethodType.Comparison, "<", isNumeric: true),
    new MethodConfiguration("Contract", MethodType.Contraction, isNumeric: true),
}.OrderBy(m => m.MethodName, StringComparer.Ordinal).ToArray();
#>