mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
Add Xamarin support to the ORT nuget packages. - Update C# code to support Xamarin builds for iOS and Android - refactor some things to split out common code - include iOS and Android ORT native shared library in native nuget package
328 lines
14 KiB
Text
328 lines
14 KiB
Text
<#@ import namespace="System.Linq" #>
|
|
<#@ import namespace="System.Text" #>
|
|
<#@ import namespace="System.Collections.Generic" #>
|
|
<#+
|
|
/// <summary>
/// Describes one primitive element type that the tensor templates emit code for:
/// its C# keyword, the prefix used when forming identifiers, the source text of
/// its "one" and "zero" literals, and which groups of methods apply to it.
/// </summary>
public class TypeConfiguration
{
    /// <summary>
    /// Creates the description of a single element type.
    /// </summary>
    /// <param name="typeName">C# keyword of the type, e.g. "int". Must not be null or empty.</param>
    /// <param name="classPrefix">
    /// Identifier prefix for the type. When omitted it is derived from
    /// <paramref name="typeName"/> by upper-casing its first character.
    /// </param>
    /// <param name="oneLiteral">Source text of the literal for one (defaults to "1").</param>
    /// <param name="zeroLiteral">Source text of the literal for zero (defaults to "0").</param>
    /// <param name="supportsNumeric">Stored in <see cref="SupportsNumeric"/>.</param>
    /// <param name="supportsBitwise">Stored in <see cref="SupportsBitwise"/>.</param>
    /// <param name="unsupportedMethods">
    /// Names of individual methods that do not apply to this type; null means none.
    /// </param>
    /// <exception cref="ArgumentException"><paramref name="typeName"/> is null or empty.</exception>
    public TypeConfiguration(string typeName, string classPrefix = null, string oneLiteral = "1", string zeroLiteral = "0", bool supportsNumeric = true, bool supportsBitwise = true, IEnumerable<string> unsupportedMethods = null)
    {
        // Fail with a clear message instead of an incidental IndexOutOfRangeException /
        // NullReferenceException when the prefix is derived from the name below.
        if (string.IsNullOrEmpty(typeName))
        {
            throw new ArgumentException("A non-empty type name is required.", nameof(typeName));
        }

        TypeName = typeName;

        // The prefix ends up in generated source, so it must not depend on the culture of
        // the machine running the template: culture-sensitive upper-casing turns "int"
        // into "İnt" under tr-TR.
        ClassPrefix = classPrefix ?? char.ToUpperInvariant(typeName[0]) + typeName.Substring(1);
        OneLiteral = oneLiteral;
        ZeroLiteral = zeroLiteral;
        SupportsNumeric = supportsNumeric;
        SupportsBitwise = supportsBitwise;

        // Copy into a set so duplicates collapse and later changes to the caller's
        // sequence are not observed.
        UnsupportedMethods = new HashSet<string>(unsupportedMethods ?? Enumerable.Empty<string>());
    }

    /// <summary>C# keyword of the element type, e.g. "int".</summary>
    public string TypeName { get; }

    /// <summary>Identifier prefix for the type, e.g. "Int" or "UShort".</summary>
    public string ClassPrefix { get; }

    /// <summary>Source text of the literal representing one for this type.</summary>
    public string OneLiteral { get; }

    /// <summary>Source text of the literal representing zero for this type.</summary>
    public string ZeroLiteral { get; }

    /// <summary>
    /// Flag supplied at construction. NOTE(review): presumably matched by the templates
    /// against MethodConfiguration.IsNumeric - confirm in the consuming template.
    /// </summary>
    public bool SupportsNumeric { get; }

    /// <summary>
    /// Flag supplied at construction. NOTE(review): presumably matched by the templates
    /// against MethodConfiguration.IsBitwise - confirm in the consuming template.
    /// </summary>
    public bool SupportsBitwise { get; }

    /// <summary>Names of methods that do not apply to this type; never null.</summary>
    public ISet<string> UnsupportedMethods { get; }
}
|
|
|
|
/// <summary>
/// Builds the header of one branch in a generated <c>typeof(T)</c> dispatch chain.
/// The first entry of <c>typeConfiguration</c> opens the chain with "if"; every
/// other entry continues it with "else if".
/// </summary>
/// <param name="type">The element type the branch tests for.</param>
/// <returns>Text of the form <c>if (typeof(T) == typeof(int))</c>.</returns>
public string GenerateIfStatementHeader(TypeConfiguration type)
{
    // Identity comparison against the first configured type decides the keyword.
    bool opensChain = object.ReferenceEquals(type, typeConfiguration[0]);

    var header = new StringBuilder();
    header.Append(opensChain ? "if" : "else if");
    header.Append(" (typeof(T) == typeof(");
    header.Append(type.TypeName);
    header.Append("))");
    return header.ToString();
}
|
|
|
|
// Element types the templates generate code for. Order matters: the first entry is
// the one GenerateIfStatementHeader opens the if-chain with.
public TypeConfiguration[] typeConfiguration = new TypeConfiguration[]
{
    new TypeConfiguration(
        "bool",
        oneLiteral: "true",
        zeroLiteral: "false",
        supportsNumeric: false,
        unsupportedMethods: new[] { "LeftShift", "RightShift" }),
    new TypeConfiguration("byte"),
    new TypeConfiguration(
        "char",
        oneLiteral: "(char)1",
        zeroLiteral: "(char)0"),
    new TypeConfiguration("decimal", supportsBitwise: false),
    new TypeConfiguration(
        "double",
        oneLiteral: "1.0",
        supportsBitwise: false),
    new TypeConfiguration(
        "float",
        oneLiteral: "1.0f",
        supportsBitwise: false),
    new TypeConfiguration("int"),
    new TypeConfiguration("long"),
    new TypeConfiguration("sbyte", classPrefix: "SByte"),
    new TypeConfiguration("short"),

    // Unsigned types: explicit prefixes keep the second letter upper-case, and unary
    // minus is excluded for them.
    new TypeConfiguration(
        "uint",
        classPrefix: "UInt",
        unsupportedMethods: new[] { "UnaryMinus" }),
    new TypeConfiguration(
        "ulong",
        classPrefix: "ULong",
        unsupportedMethods: new[] { "UnaryMinus" }),
    new TypeConfiguration(
        "ushort",
        classPrefix: "UShort",
        unsupportedMethods: new[] { "UnaryMinus" })
};
|
|
|
|
/// <summary>
/// Shape of a generated tensor method: which operands it takes and what kind of
/// result it produces. MethodConfiguration switches on this value to build
/// signatures, argument lists and per-element statements.
/// </summary>
public enum MethodType
{
    /// <summary>One tensor operand; the operator is applied to each element and stored in a separate result.</summary>
    Unary,

    /// <summary>One tensor operand; the result starts as a clone of it and the operator is applied to the result's elements.</summary>
    UnaryInPlace,

    /// <summary>A tensor and a scalar of the tensor's element type.</summary>
    BinaryScalar,

    /// <summary>A tensor and an <c>int</c> value.</summary>
    BinaryInt,

    /// <summary>Two tensors combined element by element into a tensor of the same element type.</summary>
    Binary,

    /// <summary>Two tensors compared element by element into a tensor of <c>bool</c>.</summary>
    Comparison,

    /// <summary>Two tensors plus the <c>leftAxes</c> / <c>rightAxes</c> arrays.</summary>
    Contraction
}
|
|
|
|
/// <summary>
/// Describes one tensor method the templates emit and produces the C# source
/// fragments for it: signatures, argument lists, the validation call, the
/// per-element statement and the expression that creates the result tensor.
/// Members that have no meaning for the configured <see cref="MethodType"/>
/// throw <see cref="ArgumentException"/>.
/// </summary>
public class MethodConfiguration
{
    /// <param name="methodName">Name of the generated method, e.g. "Add".</param>
    /// <param name="methodType">Operand/result shape of the method.</param>
    /// <param name="op">C# operator text spliced into the element statement; may be null.</param>
    /// <param name="isNumeric">Stored in <see cref="IsNumeric"/>.</param>
    /// <param name="isBitwise">Stored in <see cref="IsBitwise"/>.</param>
    public MethodConfiguration(string methodName, MethodType methodType, string op = null, bool isNumeric = false, bool isBitwise = false)
    {
        MethodName = methodName;
        MethodType = methodType;
        Operator = op;
        IsNumeric = isNumeric;
        IsBitwise = isBitwise;
    }

    /// <summary>Name of the generated method.</summary>
    public string MethodName { get; }

    /// <summary>Operand/result shape of the generated method.</summary>
    public MethodType MethodType { get; }

    /// <summary>Operator text used in the element statement (null when none was given).</summary>
    public string Operator { get; }

    /// <summary>Flag supplied at construction.</summary>
    public bool IsNumeric { get; }

    /// <summary>Flag supplied at construction.</summary>
    public bool IsBitwise { get; }

    /// <summary>Name of the result variable/parameter in generated code.</summary>
    public string ResultName
    {
        get { return "result"; }
    }

    // Unary or UnaryInPlace: a single tensor operand and nothing else.
    private bool IsSingleOperand =>
        MethodType == MethodType.Unary || MethodType == MethodType.UnaryInPlace;

    // BinaryScalar or BinaryInt: a tensor plus one non-tensor value.
    private bool IsTensorAndValue =>
        MethodType == MethodType.BinaryScalar || MethodType == MethodType.BinaryInt;

    // Binary or Comparison: two tensors processed element by element.
    private bool IsTensorPair =>
        MethodType == MethodType.Binary || MethodType == MethodType.Comparison;

    // Contraction: two tensors plus the axes arrays.
    private bool IsContraction => MethodType == MethodType.Contraction;

    /// <summary>
    /// Name of the first operand: "tensor" for single-tensor shapes, "left" for
    /// two-tensor shapes.
    /// </summary>
    public string Op1Name
    {
        get
        {
            if (IsSingleOperand || IsTensorAndValue)
            {
                return "tensor";
            }

            if (IsTensorPair || IsContraction)
            {
                return "left";
            }

            throw new ArgumentException();
        }
    }

    /// <summary>
    /// Name of the second operand: "scalar", "value" or "right" depending on the
    /// shape. Throws for the unary shapes, which have no second operand.
    /// </summary>
    public string Op2Name
    {
        get
        {
            if (MethodType == MethodType.BinaryScalar)
            {
                return "scalar";
            }

            if (MethodType == MethodType.BinaryInt)
            {
                return "value";
            }

            if (IsTensorPair || IsContraction)
            {
                return "right";
            }

            throw new ArgumentException();
        }
    }

    /// <summary>Signature of the generic overload that returns its result.</summary>
    public string GetGenericMethodSignature(string tensorType, string genericType)
    {
        string returnType = GetResultType(tensorType, genericType);
        string parameterList = GetMethodArguments(tensorType, genericType);

        return returnType + " " + MethodName + "<" + genericType + ">(" + parameterList + ")";
    }

    /// <summary>Signature of the generic overload that writes into a caller-supplied result.</summary>
    public string GetGenericResultMethodSignature(string tensorType, string genericType)
    {
        string returnType = GetResultType(tensorType, genericType);
        string parameterList = GetMethodArguments(tensorType, genericType);

        return "void " + MethodName + "<" + genericType + ">(" + parameterList + ", " + returnType + " " + ResultName + ")";
    }

    /// <summary>Signature of the non-generic overload that writes into a caller-supplied result.</summary>
    public string GetResultMethodSignature(string tensorType, string genericType)
    {
        string returnType = GetResultType(tensorType, genericType);
        string parameterList = GetMethodArguments(tensorType, genericType);

        return "void " + MethodName + "(" + parameterList + ", " + returnType + " " + ResultName + ")";
    }

    /// <summary>Parameter declarations (types and names) for the operands of the method.</summary>
    public string GetMethodArguments(string tensorType, string genericType)
    {
        string tensorOfElement = tensorType + "<" + genericType + ">";

        if (IsSingleOperand)
        {
            return tensorOfElement + " " + Op1Name;
        }

        if (MethodType == MethodType.BinaryScalar)
        {
            return tensorOfElement + " " + Op1Name + ", " + genericType + " " + Op2Name;
        }

        if (MethodType == MethodType.BinaryInt)
        {
            return tensorOfElement + " " + Op1Name + ", int " + Op2Name;
        }

        if (IsTensorPair || IsContraction)
        {
            string bothTensors = tensorOfElement + " " + Op1Name + ", " + tensorOfElement + " " + Op2Name;

            // A contraction additionally receives the axes to contract over.
            return IsContraction
                ? bothTensors + ", int[] leftAxes, int[] rightAxes"
                : bothTensors;
        }

        throw new ArgumentException();
    }

    /// <summary>Argument names used when forwarding a call to another overload.</summary>
    public string GetCallArguments()
    {
        if (IsSingleOperand)
        {
            return Op1Name;
        }

        if (IsTensorAndValue || IsTensorPair)
        {
            return Op1Name + ", " + Op2Name;
        }

        if (IsContraction)
        {
            return "left, right, leftAxes, rightAxes";
        }

        throw new ArgumentException();
    }

    /// <summary>
    /// Statement that validates the operands, optionally including the result.
    /// For a contraction the statement also captures the result dimensions.
    /// </summary>
    public string GetValidationMethod(bool includeResult)
    {
        string resultArgument = includeResult ? ", result" : "";

        if (IsSingleOperand || IsTensorAndValue)
        {
            return "ValidateArgs(" + Op1Name + resultArgument + ");";
        }

        if (IsTensorPair)
        {
            return "ValidateBinaryArgs(" + Op1Name + ", " + Op2Name + resultArgument + ");";
        }

        if (IsContraction)
        {
            return "var resultDimensions = ValidateContractArgs(" + Op1Name + ", " + Op2Name + ", leftAxes, rightAxes" + resultArgument + ");";
        }

        throw new ArgumentException();
    }

    /// <summary>Type of the result tensor: bool elements for comparisons, otherwise the operand's element type.</summary>
    public string GetResultType(string tensorType, string typeName)
    {
        if (MethodType == MethodType.Comparison)
        {
            return tensorType + "<bool>";
        }

        if (IsSingleOperand || IsTensorAndValue || MethodType == MethodType.Binary || IsContraction)
        {
            return tensorType + "<" + typeName + ">";
        }

        throw new ArgumentException();
    }

    /// <summary>
    /// Condition text comparing IsReversedStride of the result with that of each
    /// tensor operand. Not available for UnaryInPlace or Contraction.
    /// </summary>
    public string GetLinearOperationCheck()
    {
        if (MethodType == MethodType.Unary || IsTensorAndValue)
        {
            return StrideMatches(Op1Name);
        }

        if (IsTensorPair)
        {
            return "(" + StrideMatches(Op1Name) + " && " + StrideMatches(Op2Name) + ")";
        }

        throw new ArgumentException();
    }

    // Parenthesised comparison of the result's IsReversedStride with one operand's.
    private string StrideMatches(string operandName)
    {
        return "(" + ResultName + ".IsReversedStride == " + operandName + ".IsReversedStride)";
    }

    /// <summary>Element statement using the same accessor text for result and operands.</summary>
    public string GetElementOperation(string typeName, string access)
    {
        return GetElementOperation(typeName, access, access, access);
    }

    /// <summary>
    /// Statement text that computes one result element. The accessor strings are
    /// appended verbatim to the result and operand names. Not available for Contraction.
    /// </summary>
    public string GetElementOperation(string typeName, string resultAccess, string leftAccess, string rightAccess)
    {
        string target = ResultName + resultAccess;
        string castPrefix = " = (" + typeName + ")";

        if (MethodType == MethodType.Unary)
        {
            return target + castPrefix + Operator + Op1Name + leftAccess;
        }

        if (MethodType == MethodType.UnaryInPlace)
        {
            // The operator is applied directly to the result element.
            return target + Operator;
        }

        if (IsTensorAndValue)
        {
            return target + castPrefix + "(" + Op1Name + leftAccess + " " + Operator + " " + Op2Name + ")";
        }

        if (MethodType == MethodType.Binary)
        {
            return target + castPrefix + "(" + Op1Name + leftAccess + " " + Operator + " " + Op2Name + rightAccess + ")";
        }

        if (MethodType == MethodType.Comparison)
        {
            // No cast: the comparison itself yields the bool stored in the result.
            return target + " = " + Op1Name + leftAccess + " " + Operator + " " + Op2Name + rightAccess;
        }

        throw new ArgumentException();
    }

    /// <summary>Expression text that creates the result tensor from the first operand.</summary>
    public string InitializeResult(string typeName)
    {
        string factoryCall;

        if (MethodType == MethodType.UnaryInPlace)
        {
            factoryCall = ".Clone()";
        }
        else if (MethodType == MethodType.Unary || IsTensorAndValue || MethodType == MethodType.Binary)
        {
            factoryCall = ".CloneEmpty()";
        }
        else if (MethodType == MethodType.Comparison)
        {
            factoryCall = ".CloneEmpty<bool>()";
        }
        else if (IsContraction)
        {
            factoryCall = ".CloneEmpty(resultDimensions)";
        }
        else
        {
            throw new ArgumentException();
        }

        return Op1Name + factoryCall;
    }
}
|
|
|
|
|
|
// Every method the templates generate, sorted by method name. The sort is stable,
// so overloads sharing a name keep the relative order in which they are listed here.
public MethodConfiguration[] methodConfiguration =
    (from method in new MethodConfiguration[]
     {
         // Arithmetic
         new MethodConfiguration("Add", MethodType.Binary, "+", isNumeric: true),
         new MethodConfiguration("Add", MethodType.BinaryScalar, "+", isNumeric: true),
         new MethodConfiguration("UnaryPlus", MethodType.Unary, "+", isNumeric: true),
         new MethodConfiguration("Subtract", MethodType.Binary, "-", isNumeric: true),
         new MethodConfiguration("Subtract", MethodType.BinaryScalar, "-", isNumeric: true),
         new MethodConfiguration("UnaryMinus", MethodType.Unary, "-", isNumeric: true),
         new MethodConfiguration("Increment", MethodType.UnaryInPlace, "++", isNumeric: true),
         new MethodConfiguration("Decrement", MethodType.UnaryInPlace, "--", isNumeric: true),
         // Element-wise product, not matrix product.
         new MethodConfiguration("Multiply", MethodType.Binary, "*", isNumeric: true),
         new MethodConfiguration("Multiply", MethodType.BinaryScalar, "*", isNumeric: true),
         new MethodConfiguration("Divide", MethodType.Binary, "/", isNumeric: true),
         new MethodConfiguration("Divide", MethodType.BinaryScalar, "/", isNumeric: true),
         new MethodConfiguration("Modulo", MethodType.Binary, "%", isNumeric: true),
         new MethodConfiguration("Modulo", MethodType.BinaryScalar, "%", isNumeric: true),

         // Bitwise
         new MethodConfiguration("And", MethodType.Binary, "&", isBitwise: true),
         new MethodConfiguration("And", MethodType.BinaryScalar, "&", isBitwise: true),
         new MethodConfiguration("Or", MethodType.Binary, "|", isBitwise: true),
         new MethodConfiguration("Or", MethodType.BinaryScalar, "|", isBitwise: true),
         new MethodConfiguration("Xor", MethodType.Binary, "^", isBitwise: true),
         new MethodConfiguration("Xor", MethodType.BinaryScalar, "^", isBitwise: true),
         new MethodConfiguration("LeftShift", MethodType.BinaryInt, "<<", isBitwise: true),
         new MethodConfiguration("RightShift", MethodType.BinaryInt, ">>", isBitwise: true),

         // Comparisons. Note: all of these are element-wise operations; they do not
         // test the operation on the entire Tensor.
         new MethodConfiguration("Equals", MethodType.Comparison, "=="),
         new MethodConfiguration("NotEquals", MethodType.Comparison, "!="),
         new MethodConfiguration("GreaterThanOrEqual", MethodType.Comparison, ">=", isNumeric: true),
         new MethodConfiguration("LessThanOrEqual", MethodType.Comparison, "<=", isNumeric: true),
         new MethodConfiguration("GreaterThan", MethodType.Comparison, ">", isNumeric: true),
         new MethodConfiguration("LessThan", MethodType.Comparison, "<", isNumeric: true),

         // Contraction carries no operator text.
         new MethodConfiguration("Contract", MethodType.Contraction, isNumeric: true),
     }
     orderby method.MethodName
     select method).ToArray();
|
|
#>
|