<#@ import namespace="System.Linq" #>
<#@ import namespace="System.Text" #>
<#@ import namespace="System.Collections.Generic" #>
<#+
    /// <summary>
    /// Describes one element type listed in <c>typeConfiguration</c>: its C# type name,
    /// the prefix used for class names, the literal text for one and zero, and which
    /// groups of methods apply to it.
    /// </summary>
    public class TypeConfiguration
    {
        /// <param name="typeName">C# keyword/name of the type, e.g. "int". Must not be null or empty.</param>
        /// <param name="classPrefix">Prefix for class names; defaults to <paramref name="typeName"/> with its first character upper-cased.</param>
        /// <param name="oneLiteral">Source text of the literal for one (e.g. "1", "1.0f", "true").</param>
        /// <param name="zeroLiteral">Source text of the literal for zero (e.g. "0", "false").</param>
        /// <param name="supportsNumeric">Flag stored in <see cref="SupportsNumeric"/>.</param>
        /// <param name="supportsBitwise">Flag stored in <see cref="SupportsBitwise"/>.</param>
        /// <param name="unsupportedMethods">Method names to exclude for this type; null means none.</param>
        /// <exception cref="ArgumentException"><paramref name="typeName"/> is null or empty.</exception>
        public TypeConfiguration(string typeName, string classPrefix = null, string oneLiteral = "1", string zeroLiteral = "0", bool supportsNumeric = true, bool supportsBitwise = true, IEnumerable<string> unsupportedMethods = null)
        {
            if (string.IsNullOrEmpty(typeName))
            {
                throw new ArgumentException("A type name is required.", nameof(typeName));
            }

            TypeName = typeName;
            // Invariant casing: the prefix becomes an identifier, so it must not depend on the
            // culture of the machine running the template (e.g. Turkish 'i' -> 'İ').
            ClassPrefix = classPrefix ?? char.ToUpperInvariant(typeName[0]) + typeName.Substring(1);
            OneLiteral = oneLiteral;
            ZeroLiteral = zeroLiteral;
            SupportsNumeric = supportsNumeric;
            SupportsBitwise = supportsBitwise;
            UnsupportedMethods = new HashSet<string>(unsupportedMethods ?? Enumerable.Empty<string>());
        }

        public string TypeName { get; }
        public string ClassPrefix { get; }
        public string OneLiteral { get; }
        public string ZeroLiteral { get; }
        // NOTE(review): presumably matched against MethodConfiguration.IsNumeric / IsBitwise by the
        // including template; the consumer is not part of this file.
        public bool SupportsNumeric { get; }
        public bool SupportsBitwise { get; }
        public ISet<string> UnsupportedMethods { get; }
    }

    /// <summary>
    /// Builds the header of a <c>typeof(T)</c> dispatch branch: "if" for the first entry of
    /// <c>typeConfiguration</c> (compared by reference) and "else if" for every other entry.
    /// </summary>
    public string GenerateIfStatementHeader(TypeConfiguration type)
    {
        string keyword = (type == typeConfiguration[0]) ? "if" : "else if";
        return $"{keyword} (typeof(T) == typeof({type.TypeName}))";
    }

    public TypeConfiguration[] typeConfiguration = new []
    {
        new TypeConfiguration("bool", oneLiteral:"true", zeroLiteral:"false", supportsNumeric: false, unsupportedMethods: new[] {"LeftShift", "RightShift"}),
        new TypeConfiguration("byte"),
        new TypeConfiguration("char", oneLiteral:"(char)1", zeroLiteral:"(char)0"),
        new TypeConfiguration("decimal", supportsBitwise: false),
        new TypeConfiguration("double", oneLiteral:"1.0", supportsBitwise: false),
        new TypeConfiguration("float", oneLiteral:"1.0f", supportsBitwise: false),
        new TypeConfiguration("int"),
        new TypeConfiguration("long"),
        new TypeConfiguration("sbyte", classPrefix:"SByte"),
        new TypeConfiguration("short"),
        new TypeConfiguration("uint", classPrefix:"UInt", unsupportedMethods: new[] {"UnaryMinus"}),
        new TypeConfiguration("ulong", classPrefix:"ULong", unsupportedMethods: new[] {"UnaryMinus"}),
        new TypeConfiguration("ushort", classPrefix:"UShort", unsupportedMethods: new[] {"UnaryMinus"})
    };

    /// <summary>
    /// Shape of a generated method; selects the operand names, argument list,
    /// validation call and element expression produced by <see cref="MethodConfiguration"/>.
    /// </summary>
    public enum MethodType
    {
        Unary,
        UnaryInPlace,
        BinaryScalar,
        BinaryInt,
        Binary,
        Comparison,
        Contraction
    }

    /// <summary>
    /// Describes one method to generate (name, shape and operator token) and produces the
    /// source-text fragments for it: signatures, argument lists, validation calls,
    /// result initialization and the per-element expression.
    /// </summary>
    public class MethodConfiguration
    {
        /// <param name="methodName">Name of the generated method.</param>
        /// <param name="methodType">Shape of the method.</param>
        /// <param name="op">Operator token inserted into the element expression; may be null (as for "Contract").</param>
        /// <param name="isNumeric">Flag stored in <see cref="IsNumeric"/>.</param>
        /// <param name="isBitwise">Flag stored in <see cref="IsBitwise"/>.</param>
        public MethodConfiguration(string methodName, MethodType methodType, string op = null, bool isNumeric = false, bool isBitwise = false)
        {
            MethodName = methodName;
            MethodType = methodType;
            Operator = op;
            IsNumeric = isNumeric;
            IsBitwise = isBitwise;
        }

        /// <summary>Name of the result variable/parameter in generated code.</summary>
        public string ResultName => "result";

        /// <summary>Name of the first operand: "tensor" for single-tensor shapes, "left" for two-tensor shapes.</summary>
        public string Op1Name
        {
            get
            {
                switch (MethodType)
                {
                    case MethodType.Unary:
                    case MethodType.UnaryInPlace:
                    case MethodType.BinaryScalar:
                    case MethodType.BinaryInt:
                        return "tensor";
                    case MethodType.Binary:
                    case MethodType.Comparison:
                    case MethodType.Contraction:
                        return "left";
                    default:
                        throw UnsupportedMethodType(nameof(Op1Name));
                }
            }
        }

        /// <summary>Name of the second operand.</summary>
        /// <exception cref="ArgumentException">The method type has no second operand (Unary, UnaryInPlace) or is unknown.</exception>
        public string Op2Name
        {
            get
            {
                switch (MethodType)
                {
                    case MethodType.BinaryScalar:
                        return "scalar";
                    case MethodType.BinaryInt:
                        return "value";
                    case MethodType.Binary:
                    case MethodType.Comparison:
                    case MethodType.Contraction:
                        return "right";
                    case MethodType.Unary:
                    case MethodType.UnaryInPlace:
                    default:
                        throw UnsupportedMethodType(nameof(Op2Name));
                }
            }
        }

        public string MethodName { get; }
        public MethodType MethodType { get; }
        public string Operator { get; }

        /// <summary>Signature of the generic overload that returns its result.</summary>
        public string GetGenericMethodSignature(string tensorType, string genericType)
        {
            var resultType = GetResultType(tensorType, genericType);
            var arguments = GetMethodArguments(tensorType, genericType);

            return $"{resultType} {MethodName}<{genericType}>({arguments})";
        }

        /// <summary>Signature of the generic void overload that takes the result as a trailing parameter.</summary>
        public string GetGenericResultMethodSignature(string tensorType, string genericType)
        {
            var resultType = GetResultType(tensorType, genericType);
            var arguments = GetMethodArguments(tensorType, genericType);

            return $"void {MethodName}<{genericType}>({arguments}, {resultType} {ResultName})";
        }

        /// <summary>Signature of the non-generic void overload that takes the result as a trailing parameter.</summary>
        public string GetResultMethodSignature(string tensorType, string genericType)
        {
            var resultType = GetResultType(tensorType, genericType);
            var arguments = GetMethodArguments(tensorType, genericType);

            return $"void {MethodName}({arguments}, {resultType} {ResultName})";
        }

        /// <summary>Parameter list (types and names) for the operands of this method.</summary>
        public string GetMethodArguments(string tensorType, string genericType)
        {
            switch (MethodType)
            {
                case MethodType.Unary:
                case MethodType.UnaryInPlace:
                    return $"{tensorType}<{genericType}> {Op1Name}";
                case MethodType.BinaryScalar:
                    return $"{tensorType}<{genericType}> {Op1Name}, {genericType} {Op2Name}";
                case MethodType.BinaryInt:
                    return $"{tensorType}<{genericType}> {Op1Name}, int {Op2Name}";
                case MethodType.Binary:
                case MethodType.Comparison:
                    return $"{tensorType}<{genericType}> {Op1Name}, {tensorType}<{genericType}> {Op2Name}";
                case MethodType.Contraction:
                    return $"{tensorType}<{genericType}> {Op1Name}, {tensorType}<{genericType}> {Op2Name}, int[] leftAxes, int[] rightAxes";
                default:
                    throw UnsupportedMethodType(nameof(GetMethodArguments));
            }
        }

        /// <summary>Argument list (names only) used when forwarding a call to another overload.</summary>
        public string GetCallArguments()
        {
            switch (MethodType)
            {
                case MethodType.Unary:
                case MethodType.UnaryInPlace:
                    return $"{Op1Name}";
                case MethodType.BinaryScalar:
                case MethodType.BinaryInt:
                case MethodType.Binary:
                case MethodType.Comparison:
                    return $"{Op1Name}, {Op2Name}";
                case MethodType.Contraction:
                    return $"{Op1Name}, {Op2Name}, leftAxes, rightAxes";
                default:
                    throw UnsupportedMethodType(nameof(GetCallArguments));
            }
        }

        /// <summary>
        /// Statement that validates the operands, optionally including the result argument.
        /// For contractions the statement also declares <c>resultDimensions</c>.
        /// </summary>
        public string GetValidationMethod(bool includeResult)
        {
            var suffix = includeResult ? $", {ResultName}" : "";
            switch (MethodType)
            {
                case MethodType.Unary:
                case MethodType.UnaryInPlace:
                case MethodType.BinaryScalar:
                case MethodType.BinaryInt:
                    return $"ValidateArgs({Op1Name}{suffix});";
                case MethodType.Binary:
                case MethodType.Comparison:
                    return $"ValidateBinaryArgs({Op1Name}, {Op2Name}{suffix});";
                case MethodType.Contraction:
                    return $"var resultDimensions = ValidateContractArgs({Op1Name}, {Op2Name}, leftAxes, rightAxes{suffix});";
                default:
                    throw UnsupportedMethodType(nameof(GetValidationMethod));
            }
        }

        /// <summary>
        /// Type of the result: a tensor of <paramref name="typeName"/> for every shape except
        /// comparisons, whose element expression is boolean and therefore yields a tensor of bool.
        /// </summary>
        public string GetResultType(string tensorType, string typeName)
        {
            switch (MethodType)
            {
                case MethodType.Unary:
                case MethodType.UnaryInPlace:
                case MethodType.BinaryScalar:
                case MethodType.BinaryInt:
                case MethodType.Binary:
                case MethodType.Contraction:
                    return $"{tensorType}<{typeName}>";
                case MethodType.Comparison:
                    return $"{tensorType}<bool>";
                default:
                    throw UnsupportedMethodType(nameof(GetResultType));
            }
        }

        /// <summary>
        /// Boolean expression comparing <c>IsReversedStride</c> of the result with that of each tensor operand.
        /// </summary>
        /// <exception cref="ArgumentException">The method type is UnaryInPlace, Contraction or unknown.</exception>
        public string GetLinearOperationCheck()
        {
            switch (MethodType)
            {
                case MethodType.Unary:
                case MethodType.BinaryScalar:
                case MethodType.BinaryInt:
                    return $"({ResultName}.IsReversedStride == {Op1Name}.IsReversedStride)";
                case MethodType.Binary:
                case MethodType.Comparison:
                    return $"(({ResultName}.IsReversedStride == {Op1Name}.IsReversedStride) && ({ResultName}.IsReversedStride == {Op2Name}.IsReversedStride))";
                case MethodType.UnaryInPlace:
                default:
                    throw UnsupportedMethodType(nameof(GetLinearOperationCheck));
            }
        }

        /// <summary>Element expression using the same accessor text for result and both operands.</summary>
        public string GetElementOperation(string typeName, string access)
        {
            return GetElementOperation(typeName, access, access, access);
        }

        /// <summary>
        /// Element expression for this method. Each "access" argument is the accessor text
        /// appended to the corresponding variable name (result, first operand, second operand).
        /// </summary>
        /// <exception cref="ArgumentException">The method type is Contraction or unknown.</exception>
        public string GetElementOperation(string typeName, string resultAccess, string leftAccess, string rightAccess)
        {
            switch (MethodType)
            {
                case MethodType.Unary:
                    return $"{ResultName}{resultAccess} = ({typeName}){Operator}{Op1Name}{leftAccess}";
                case MethodType.UnaryInPlace:
                    return $"{ResultName}{resultAccess}{Operator}";
                case MethodType.BinaryScalar:
                case MethodType.BinaryInt:
                    return $"{ResultName}{resultAccess} = ({typeName})({Op1Name}{leftAccess} {Operator} {Op2Name})";
                case MethodType.Binary:
                    return $"{ResultName}{resultAccess} = ({typeName})({Op1Name}{leftAccess} {Operator} {Op2Name}{rightAccess})";
                case MethodType.Comparison:
                    return $"{ResultName}{resultAccess} = {Op1Name}{leftAccess} {Operator} {Op2Name}{rightAccess}";
                default:
                    throw UnsupportedMethodType(nameof(GetElementOperation));
            }
        }

        /// <summary>
        /// Expression that creates the result from the first operand. <paramref name="typeName"/>
        /// is unused and kept only so existing callers continue to compile.
        /// </summary>
        public string InitializeResult(string typeName)
        {
            switch (MethodType)
            {
                case MethodType.UnaryInPlace:
                    return $"{Op1Name}.Clone()";
                case MethodType.Unary:
                case MethodType.BinaryScalar:
                case MethodType.BinaryInt:
                case MethodType.Binary:
                    return $"{Op1Name}.CloneEmpty()";
                case MethodType.Comparison:
                    // Comparisons store booleans, so the empty clone needs a bool element type.
                    return $"{Op1Name}.CloneEmpty<bool>()";
                case MethodType.Contraction:
                    return $"{Op1Name}.CloneEmpty(resultDimensions)";
                default:
                    throw UnsupportedMethodType(nameof(InitializeResult));
            }
        }

        public bool IsNumeric { get; }
        public bool IsBitwise { get; }

        // Builds the exception thrown when a member is asked about a method type it does not cover.
        private ArgumentException UnsupportedMethodType(string member)
        {
            return new ArgumentException($"{member} is not defined for method type {MethodType}.");
        }
    }

    public MethodConfiguration[] methodConfiguration = new []
    {
        new MethodConfiguration("Add", MethodType.Binary, "+", isNumeric:true),
        new MethodConfiguration("Add", MethodType.BinaryScalar, "+", isNumeric:true),
        new MethodConfiguration("UnaryPlus", MethodType.Unary, "+", isNumeric:true),
        new MethodConfiguration("Subtract", MethodType.Binary, "-", isNumeric:true),
        new MethodConfiguration("Subtract", MethodType.BinaryScalar, "-", isNumeric:true),
        new MethodConfiguration("UnaryMinus", MethodType.Unary, "-", isNumeric:true),
        new MethodConfiguration("Increment", MethodType.UnaryInPlace, "++", isNumeric:true),
        new MethodConfiguration("Decrement", MethodType.UnaryInPlace, "--", isNumeric:true),
        // Element-wise product, not matrix product.
        new MethodConfiguration("Multiply", MethodType.Binary, "*", isNumeric:true),
        new MethodConfiguration("Multiply", MethodType.BinaryScalar, "*", isNumeric:true),
        new MethodConfiguration("Divide", MethodType.Binary, "/", isNumeric:true),
        new MethodConfiguration("Divide", MethodType.BinaryScalar, "/", isNumeric:true),
        new MethodConfiguration("Modulo", MethodType.Binary, "%", isNumeric:true),
        new MethodConfiguration("Modulo", MethodType.BinaryScalar, "%", isNumeric:true),
        new MethodConfiguration("And", MethodType.Binary, "&", isBitwise: true),
        new MethodConfiguration("And", MethodType.BinaryScalar, "&", isBitwise: true),
        new MethodConfiguration("Or", MethodType.Binary, "|", isBitwise: true),
        new MethodConfiguration("Or", MethodType.BinaryScalar, "|", isBitwise: true),
        new MethodConfiguration("Xor", MethodType.Binary, "^", isBitwise: true),
        new MethodConfiguration("Xor", MethodType.BinaryScalar, "^", isBitwise: true),
        new MethodConfiguration("LeftShift", MethodType.BinaryInt, "<<", isBitwise: true),
        new MethodConfiguration("RightShift", MethodType.BinaryInt, ">>", isBitwise: true),

        // Note: all of these are element-wise operations; they do not test the operation on the entire Tensor.
        new MethodConfiguration("Equals", MethodType.Comparison, "=="),
        new MethodConfiguration("NotEquals", MethodType.Comparison, "!="),
        new MethodConfiguration("GreaterThanOrEqual", MethodType.Comparison, ">=", isNumeric:true),
        new MethodConfiguration("LessThanOrEqual", MethodType.Comparison, "<=", isNumeric:true),
        new MethodConfiguration("GreaterThan", MethodType.Comparison, ">", isNumeric:true),
        new MethodConfiguration("LessThan", MethodType.Comparison, "<", isNumeric:true),

        new MethodConfiguration("Contract", MethodType.Contraction, isNumeric:true),
    }.OrderBy(m => m.MethodName).ToArray();
#>