mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
### Description Add uinforms to Einsum ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Improve performance.
312 lines
14 KiB
TypeScript
312 lines
14 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {TensorView} from '../../tensor-view';
|
|
import {ShapeUtil} from '../../util';
|
|
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
|
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
|
|
|
|
import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
|
|
|
|
|
|
export interface EinsumAttributes extends AttributeWithCacheKey {
|
|
readonly equation: string;
|
|
}
|
|
// The equation attribute value is a string which consists of left hand side (LHS) and optionally right hand side (RHS)
|
|
// separated by '->'. Ex. "ij,jk -> ik" expresses matrix multiplication
|
|
// "ij->ji" expresses matrix transpose
|
|
// "ii->i" diagonal elements of a square matrix
|
|
// LHS consists of a sequence of terms separated by commas. Each term corresponds to an input variable.
|
|
// Each symbol corresponds to a dimension in the input variable. The symbol can be either a letter, 'a' to 'z' or 'A' to
|
|
// 'Z' or '...' to represent arbitrary dimensions.
|
|
|
|
const symbolPattern =
|
|
'[a-zA-Z]|\\.\\.\\.'; // The pattern each symbol in each term in the symbolic equation should match
|
|
const termPattern = '(' + symbolPattern + ')+'; // The pattern each term in the symbolic equation should match
|
|
const termPatternOnly = '^' + termPattern + '$'; // The patterns only matchs a term begin to end.
|
|
const lhsPattern = '(' + termPattern + ',)*' + termPattern; // The pattern the LHS should match
|
|
const lhsPatternOnly = '^' + lhsPattern + '$'; // The patterns only matchs a LHS begin to end.
|
|
|
|
interface SymbolInfo {
|
|
count: number; // Symbol corresponding to a dimmension of an input
|
|
inputIndices: number[]; // Number of input variables the symbol corresponds to
|
|
dimValue: number; // Number of dimensions the symbol corresponds to
|
|
}
|
|
|
|
class EinsumTerm {
|
|
constructor(inputIndex = -1) {
|
|
this.symbolToIndices = new Map<string, number[]>();
|
|
this.inputIndex = inputIndex;
|
|
}
|
|
|
|
// Add a symbol to the term
|
|
addSymbol(symbol: string, index: number) {
|
|
let value = this.symbolToIndices.get(symbol);
|
|
if (value === undefined) {
|
|
value = [index];
|
|
} else {
|
|
value.push(index);
|
|
}
|
|
this.symbolToIndices.set(symbol, value);
|
|
}
|
|
|
|
symbolToIndices: Map<string, number[]>; // Map from symbol to dimensions of the input corresponding to the term
|
|
inputIndex: number; // -1 for output and 0, 1, 2, ... for inputs
|
|
}
|
|
|
|
class EinsumEquation {
|
|
constructor(inputs: readonly TensorView[], public readonly equation: string) {
|
|
this.hasEllipsis = false;
|
|
this.symbolToInfo = new Map<string, SymbolInfo>();
|
|
this.lhs = new Array<EinsumTerm>();
|
|
this.outputDims = [];
|
|
// As rhs needs to be updated allow using let instead of const for both lhs and rhs.
|
|
// eslint-disable-next-line prefer-const
|
|
let [lhs, rhs] = equation.includes('->') ? equation.split('->', 2) : [equation, ''];
|
|
if (!lhs.match(RegExp(lhsPatternOnly))) {
|
|
throw new Error('Invalid LHS term');
|
|
}
|
|
const inputTerms = lhs.split(',');
|
|
inputTerms.forEach((inputTerm, index) => {
|
|
const dims = inputs[index].dims.slice();
|
|
if (!inputTerm.match(RegExp(termPatternOnly))) {
|
|
throw new Error('Invalid LHS term');
|
|
}
|
|
const einsumTerm = this.processTerm(inputTerm, true, dims, index);
|
|
this.lhs.push(einsumTerm);
|
|
});
|
|
|
|
// Initialize the RHS if not specified
|
|
if (rhs === '') {
|
|
// Construct RHS from LHS terms/symbols
|
|
rhs += [...this.symbolToInfo.entries()]
|
|
.filter(([sym, info]) => (info.count === 1 || sym === '...'))
|
|
.map(([sym]) => sym)
|
|
.join('');
|
|
} else {
|
|
if (!rhs.match(RegExp(termPattern))) {
|
|
throw new Error('Invalid RHS');
|
|
}
|
|
}
|
|
|
|
// Compute output dims
|
|
const rhsSymbols = rhs.match(RegExp(symbolPattern, 'g'));
|
|
rhsSymbols?.forEach((symbol) => {
|
|
if (symbol === '...') {
|
|
this.outputDims = this.outputDims.concat(this.ellipsisDims);
|
|
} else {
|
|
const info = this.symbolToInfo.get(symbol);
|
|
if (info === undefined) {
|
|
throw new Error('Invalid RHS symbol');
|
|
}
|
|
this.outputDims.push(info.dimValue);
|
|
}
|
|
});
|
|
this.rhs = this.processTerm(rhs, false, this.outputDims);
|
|
} // End of EinsumEqation constructor
|
|
|
|
// Add a symbol to the equation
|
|
addSymbol(symbol: string, dimValue: number, inputIndex: number) {
|
|
let info = this.symbolToInfo.get(symbol);
|
|
if (info !== undefined) {
|
|
if (info.dimValue !== dimValue && info.count !== 1) {
|
|
throw new Error('Dimension mismatch');
|
|
} else {
|
|
info.count++;
|
|
info.inputIndices.push(inputIndex);
|
|
}
|
|
} else {
|
|
info = {count: 1, dimValue, inputIndices: [inputIndex]};
|
|
}
|
|
this.symbolToInfo.set(symbol, info);
|
|
}
|
|
|
|
// Process one input/output term
|
|
processTerm(term: string, isInput: boolean, dims: readonly number[], index = -1): EinsumTerm {
|
|
const rank = dims.length;
|
|
let ellipsis = false;
|
|
let ellipsisDims = [];
|
|
let nextDim = 0;
|
|
// For output empty string is allowed because the output may be reduced to a scalar value
|
|
if (!term.match(RegExp(termPatternOnly)) && (!isInput && term !== '')) {
|
|
throw new Error('Invalid LHS term');
|
|
}
|
|
const indexSymbols = term.match(RegExp(symbolPattern, 'g'));
|
|
const einsumTerm = new EinsumTerm(index);
|
|
// symbol can be either a lettre, 'a' to 'z' or 'A' to 'Z', or '...'
|
|
indexSymbols?.forEach((symbol: string, i: number) => {
|
|
if (symbol === '...') {
|
|
if (ellipsis) {
|
|
throw new Error('Only one ellipsis is allowed per input term');
|
|
}
|
|
ellipsis = true;
|
|
const ellipsisDimLength = rank - indexSymbols.length + 1;
|
|
if (ellipsisDimLength < 0) {
|
|
throw new Error('Ellipsis out of bounds');
|
|
}
|
|
ellipsisDims = dims.slice(nextDim, nextDim + ellipsisDimLength);
|
|
if (this.hasEllipsis) {
|
|
if (this.ellipsisDims.length !== ellipsisDims.length ||
|
|
this.ellipsisDims.toString() !== ellipsisDims.toString()) {
|
|
throw new Error('Ellipsis dimensions mismatch');
|
|
}
|
|
} else if (isInput) {
|
|
this.hasEllipsis = true;
|
|
this.ellipsisDims = ellipsisDims;
|
|
} else {
|
|
throw new Error('Ellipsis must be specified in the LHS');
|
|
}
|
|
// Add '0', '1', '2', '3', '4', etc to represent ellipsis dimensions to avoid special handling
|
|
for (let j = 0; j < ellipsisDims.length; j++) {
|
|
const symbol = String.fromCharCode('0'.charCodeAt(0) + j);
|
|
einsumTerm.addSymbol(symbol, i + j);
|
|
this.addSymbol(symbol, dims[nextDim++], index);
|
|
}
|
|
} else {
|
|
einsumTerm.addSymbol(symbol, i + (this.hasEllipsis ? this.ellipsisDims.length - 1 : 0));
|
|
this.addSymbol(symbol, dims[nextDim++], index);
|
|
}
|
|
});
|
|
return einsumTerm;
|
|
}
|
|
|
|
symbolToInfo: Map<string, SymbolInfo>; // All symbols in the equation
|
|
hasEllipsis: boolean; // The equation has ellipsis or not
|
|
ellipsisDims: number[]; // The dimensions of the equation ellipsis corresponds to.
|
|
lhs: EinsumTerm[]; // Terms on the left-hand side of the equation
|
|
rhs: EinsumTerm; // Term on the right-hand side of the equation
|
|
outputDims: number[]; // Output dimensions of the equation
|
|
} // End of class EinsumEquation
|
|
|
|
const appendMax = (name: string): string => name + '_max';
|
|
|
|
const createEinsumProgramInfo =
|
|
(enableInputShapesUniforms: readonly boolean[], inputShapes: Array<readonly number[]>, dataType: number,
|
|
einsumEquation: EinsumEquation, outputShape: readonly number[]): ProgramInfo => {
|
|
const shapeOrRanks = inputShapes.map((dims, index) => enableInputShapesUniforms[index] ? dims.length : dims);
|
|
const inputVars = shapeOrRanks.map((shapeOrRank, index) => inputVariable(`input${index}`, dataType, shapeOrRank));
|
|
const outputSize = ShapeUtil.size(outputShape);
|
|
const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
|
|
const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;
|
|
const output = outputVariable('output', dataType, outputShapeOrRank);
|
|
const uniformsSymbols =
|
|
[...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol));
|
|
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
|
const idxCopy: string[] = [];
|
|
const initProd = 'var prod = 1.0;';
|
|
const initSum = 'var sum = 0.0;';
|
|
const updateSum = 'sum += prod;';
|
|
const reduceOpsSetIndices: string[] = [];
|
|
const reduceOpsLoopHeaders: string[] = [];
|
|
const reduceOpsLoopFooters: string[] = [];
|
|
const reduceOpCompute: string[] = [];
|
|
const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === einsumEquation.rhs.symbolToIndices.size;
|
|
einsumEquation.symbolToInfo.forEach((info, symbol) => {
|
|
if (einsumEquation.rhs.symbolToIndices.has(symbol)) {
|
|
const outputIndex = einsumEquation.rhs.symbolToIndices.get(symbol)?.[0];
|
|
if (outputIndex !== undefined) {
|
|
einsumEquation.lhs.forEach((term, i) => {
|
|
if (info.inputIndices.includes(i)) {
|
|
const indices = term.symbolToIndices.get(symbol);
|
|
if (indices === undefined) {
|
|
throw new Error('Invalid symbol error');
|
|
}
|
|
indices.forEach((index) => {
|
|
idxCopy.push(`${
|
|
inputVars[i].indicesSet(
|
|
`input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`);
|
|
});
|
|
}
|
|
});
|
|
}
|
|
} else {
|
|
einsumEquation.lhs.forEach((term, i) => {
|
|
if (info.inputIndices.includes(i)) {
|
|
const indices = term.symbolToIndices.get(symbol);
|
|
if (indices === undefined) {
|
|
throw new Error('Invalid symbol error');
|
|
}
|
|
indices.forEach((index) => {
|
|
reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`);
|
|
});
|
|
reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`);
|
|
}
|
|
});
|
|
reduceOpsLoopHeaders.push(
|
|
`for(var ${symbol}: u32 = 0; ${symbol} < uniforms.${appendMax(symbol)}; ${symbol}++) {`);
|
|
reduceOpsLoopFooters.push('}');
|
|
}
|
|
});
|
|
const reduceOps = isReduceOpsWithoutLoop ?
|
|
[
|
|
...idxCopy,
|
|
`let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};`
|
|
] :
|
|
[
|
|
...idxCopy,
|
|
initSum,
|
|
...reduceOpsLoopHeaders,
|
|
...reduceOpsSetIndices,
|
|
initProd,
|
|
...reduceOpCompute,
|
|
updateSum,
|
|
...reduceOpsLoopFooters,
|
|
];
|
|
return `
|
|
${
|
|
shaderHelper
|
|
.registerUniforms(uniformsSymbols.map((symbol) => ({name: `${appendMax(symbol)}`, type: 'u32'})))
|
|
.registerUniform('outputSize', 'u32')
|
|
.declareVariables(...inputVars, output)}
|
|
|
|
${shaderHelper.mainStart()}
|
|
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
|
|
var outputIndices = ${output.offsetToIndices('global_idx')};
|
|
${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')}
|
|
${reduceOps.join('\n')};
|
|
${output.setByOffset('global_idx', 'sum')};
|
|
}`;
|
|
};
|
|
return {
|
|
name: 'Einsum',
|
|
shaderCache: {
|
|
hint: einsumEquation.equation,
|
|
inputDependencies: enableInputShapesUniforms.map((enableShapeUniform) => enableShapeUniform ? 'rank' : 'dims')
|
|
},
|
|
getRunData: () => {
|
|
// The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The
|
|
// filter is added to make sure that dimValue is never 0.
|
|
const programUniformsInit: ProgramUniform[] =
|
|
uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol))
|
|
.map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0}));
|
|
programUniformsInit.push({type: 'uint32', data: outputSize});
|
|
const programUniforms: ProgramUniform[] =
|
|
inputShapes.filter((_, index) => enableInputShapesUniforms[index])
|
|
.map((dims, _) => [...createTensorShapeVariables(dims)])
|
|
.reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit);
|
|
if (enableOutputShapesUniforms) {
|
|
programUniforms.push(...createTensorShapeVariables(outputShape));
|
|
}
|
|
return ({
|
|
outputs: [{dims: outputShape, dataType}],
|
|
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
|
programUniforms
|
|
});
|
|
},
|
|
getShaderSource,
|
|
};
|
|
};
|
|
|
|
export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => {
|
|
const einsumEquation = new EinsumEquation(context.inputs, attributes.equation);
|
|
const enableInputShapesUniforms = context.inputs.map((input, _) => enableShapesUniforms(input.dims.length));
|
|
const outputShape = einsumEquation.outputDims;
|
|
const inputShapes = context.inputs.map((input, _) => input.dims);
|
|
context.compute(createEinsumProgramInfo(
|
|
enableInputShapesUniforms, inputShapes, context.inputs[0].dataType, einsumEquation, outputShape));
|
|
};
|
|
|
|
export const parseEinsumAttributes = (attributes: Record<string, unknown>): EinsumAttributes => {
|
|
const equation = (attributes.equation as string).replace(/\s+/g, '');
|
|
return createAttributeWithCacheKey({equation});
|
|
};
|