mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-19 02:03:52 +00:00
### Description
See
454996d496
for manual changes (excluded auto-generated formatting changes)
### Why
Because the toolsets for old clang-format is out-of-date. This reduces
the development efficiency.
- The NPM package `clang-format` is already in maintenance mode. not
updated since 2 years ago.
- The VSCode extension for clang-format is not maintained for a while,
and a recent Node.js security update made it not working at all in
Windows.
No one in community seems interested in fixing those.
Choose Prettier as it is the most popular TS/JS formatter.
### How to merge
It's easy to break the build:
- Be careful of any new commits on main not included in this PR.
- Be careful that after this PR is merged, other PRs that already passed
CI can merge.
So, make sure there is no new commits before merging this one, and
invalidate js PRs that already passed CI, force them to merge to latest.
314 lines
13 KiB
TypeScript
314 lines
13 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import { DataType } from '../../../wasm-common';
|
|
import { TensorView } from '../../tensor-view';
|
|
import { ShapeUtil } from '../../util';
|
|
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
|
|
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
|
|
|
|
import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper } from './common';
|
|
|
|
export interface EinsumAttributes extends AttributeWithCacheKey {
|
|
readonly equation: string;
|
|
}
|
|
// The equation attribute value is a string which consists of left hand side (LHS) and optionally right hand side (RHS)
|
|
// separated by '->'. Ex. "ij,jk -> ik" expresses matrix multiplication
|
|
// "ij->ji" expresses matrix transpose
|
|
// "ii->i" diagonal elements of a square matrix
|
|
// LHS consists of a sequence of terms separated by commas. Each term corresponds to an input variable.
|
|
// Each symbol corresponds to a dimension in the input variable. The symbol can be either a letter, 'a' to 'z' or 'A' to
|
|
// 'Z' or '...' to represent arbitrary dimensions.
|
|
|
|
const symbolPattern = '[a-zA-Z]|\\.\\.\\.'; // The pattern each symbol in each term in the symbolic equation should match
|
|
const termPattern = '(' + symbolPattern + ')+'; // The pattern each term in the symbolic equation should match
|
|
const termPatternOnly = '^' + termPattern + '$'; // The patterns only matchs a term begin to end.
|
|
const lhsPattern = '(' + termPattern + ',)*' + termPattern; // The pattern the LHS should match
|
|
const lhsPatternOnly = '^' + lhsPattern + '$'; // The patterns only matchs a LHS begin to end.
|
|
|
|
interface SymbolInfo {
|
|
count: number; // Symbol corresponding to a dimmension of an input
|
|
inputIndices: number[]; // Number of input variables the symbol corresponds to
|
|
dimValue: number; // Number of dimensions the symbol corresponds to
|
|
}
|
|
|
|
class EinsumTerm {
|
|
constructor(inputIndex = -1) {
|
|
this.symbolToIndices = new Map<string, number[]>();
|
|
this.inputIndex = inputIndex;
|
|
}
|
|
|
|
// Add a symbol to the term
|
|
addSymbol(symbol: string, index: number) {
|
|
let value = this.symbolToIndices.get(symbol);
|
|
if (value === undefined) {
|
|
value = [index];
|
|
} else {
|
|
value.push(index);
|
|
}
|
|
this.symbolToIndices.set(symbol, value);
|
|
}
|
|
|
|
symbolToIndices: Map<string, number[]>; // Map from symbol to dimensions of the input corresponding to the term
|
|
inputIndex: number; // -1 for output and 0, 1, 2, ... for inputs
|
|
}
|
|
|
|
class EinsumEquation {
|
|
constructor(
|
|
inputs: readonly TensorView[],
|
|
public readonly equation: string,
|
|
) {
|
|
this.hasEllipsis = false;
|
|
this.symbolToInfo = new Map<string, SymbolInfo>();
|
|
this.lhs = new Array<EinsumTerm>();
|
|
this.outputDims = [];
|
|
// As rhs needs to be updated allow using let instead of const for both lhs and rhs.
|
|
// eslint-disable-next-line prefer-const
|
|
let [lhs, rhs] = equation.includes('->') ? equation.split('->', 2) : [equation, ''];
|
|
if (!lhs.match(RegExp(lhsPatternOnly))) {
|
|
throw new Error('Invalid LHS term');
|
|
}
|
|
const inputTerms = lhs.split(',');
|
|
inputTerms.forEach((inputTerm, index) => {
|
|
const dims = inputs[index].dims.slice();
|
|
if (!inputTerm.match(RegExp(termPatternOnly))) {
|
|
throw new Error('Invalid LHS term');
|
|
}
|
|
const einsumTerm = this.processTerm(inputTerm, true, dims, index);
|
|
this.lhs.push(einsumTerm);
|
|
});
|
|
|
|
// Initialize the RHS if not specified
|
|
if (rhs === '') {
|
|
// Construct RHS from LHS terms/symbols
|
|
rhs += [...this.symbolToInfo.entries()]
|
|
.filter(([sym, info]) => info.count === 1 || sym === '...')
|
|
.map(([sym]) => sym)
|
|
.join('');
|
|
} else {
|
|
if (!rhs.match(RegExp(termPattern))) {
|
|
throw new Error('Invalid RHS');
|
|
}
|
|
}
|
|
|
|
// Compute output dims
|
|
const rhsSymbols = rhs.match(RegExp(symbolPattern, 'g'));
|
|
rhsSymbols?.forEach((symbol) => {
|
|
if (symbol === '...') {
|
|
this.outputDims = this.outputDims.concat(this.ellipsisDims);
|
|
} else {
|
|
const info = this.symbolToInfo.get(symbol);
|
|
if (info === undefined) {
|
|
throw new Error('Invalid RHS symbol');
|
|
}
|
|
this.outputDims.push(info.dimValue);
|
|
}
|
|
});
|
|
this.rhs = this.processTerm(rhs, false, this.outputDims);
|
|
} // End of EinsumEqation constructor
|
|
|
|
// Add a symbol to the equation
|
|
addSymbol(symbol: string, dimValue: number, inputIndex: number) {
|
|
let info = this.symbolToInfo.get(symbol);
|
|
if (info !== undefined) {
|
|
if (info.dimValue !== dimValue && info.count !== 1) {
|
|
throw new Error('Dimension mismatch');
|
|
} else {
|
|
info.count++;
|
|
info.inputIndices.push(inputIndex);
|
|
}
|
|
} else {
|
|
info = { count: 1, dimValue, inputIndices: [inputIndex] };
|
|
}
|
|
this.symbolToInfo.set(symbol, info);
|
|
}
|
|
|
|
// Process one input/output term
|
|
processTerm(term: string, isInput: boolean, dims: readonly number[], index = -1): EinsumTerm {
|
|
const rank = dims.length;
|
|
let ellipsis = false;
|
|
let ellipsisDims = [];
|
|
let nextDim = 0;
|
|
// For output empty string is allowed because the output may be reduced to a scalar value
|
|
if (!term.match(RegExp(termPatternOnly)) && !isInput && term !== '') {
|
|
throw new Error('Invalid LHS term');
|
|
}
|
|
const indexSymbols = term.match(RegExp(symbolPattern, 'g'));
|
|
const einsumTerm = new EinsumTerm(index);
|
|
// symbol can be either a lettre, 'a' to 'z' or 'A' to 'Z', or '...'
|
|
indexSymbols?.forEach((symbol: string, i: number) => {
|
|
if (symbol === '...') {
|
|
if (ellipsis) {
|
|
throw new Error('Only one ellipsis is allowed per input term');
|
|
}
|
|
ellipsis = true;
|
|
const ellipsisDimLength = rank - indexSymbols.length + 1;
|
|
if (ellipsisDimLength < 0) {
|
|
throw new Error('Ellipsis out of bounds');
|
|
}
|
|
ellipsisDims = dims.slice(nextDim, nextDim + ellipsisDimLength);
|
|
if (this.hasEllipsis) {
|
|
if (
|
|
this.ellipsisDims.length !== ellipsisDims.length ||
|
|
this.ellipsisDims.toString() !== ellipsisDims.toString()
|
|
) {
|
|
throw new Error('Ellipsis dimensions mismatch');
|
|
}
|
|
} else if (isInput) {
|
|
this.hasEllipsis = true;
|
|
this.ellipsisDims = ellipsisDims;
|
|
} else {
|
|
throw new Error('Ellipsis must be specified in the LHS');
|
|
}
|
|
// Add '0', '1', '2', '3', '4', etc to represent ellipsis dimensions to avoid special handling
|
|
for (let j = 0; j < ellipsisDims.length; j++) {
|
|
const symbol = String.fromCharCode('0'.charCodeAt(0) + j);
|
|
einsumTerm.addSymbol(symbol, i + j);
|
|
this.addSymbol(symbol, dims[nextDim++], index);
|
|
}
|
|
} else {
|
|
einsumTerm.addSymbol(symbol, i + (this.hasEllipsis ? this.ellipsisDims.length - 1 : 0));
|
|
this.addSymbol(symbol, dims[nextDim++], index);
|
|
}
|
|
});
|
|
return einsumTerm;
|
|
}
|
|
|
|
symbolToInfo: Map<string, SymbolInfo>; // All symbols in the equation
|
|
hasEllipsis: boolean; // The equation has ellipsis or not
|
|
ellipsisDims: number[]; // The dimensions of the equation ellipsis corresponds to.
|
|
lhs: EinsumTerm[]; // Terms on the left-hand side of the equation
|
|
rhs: EinsumTerm; // Term on the right-hand side of the equation
|
|
outputDims: number[]; // Output dimensions of the equation
|
|
} // End of class EinsumEquation
|
|
|
|
const appendMax = (name: string): string => name + '_max';
|
|
|
|
const createEinsumProgramInfo = (
|
|
inputShapes: Array<readonly number[]>,
|
|
dataType: number,
|
|
einsumEquation: EinsumEquation,
|
|
outputShape: readonly number[],
|
|
): ProgramInfo => {
|
|
const ranks = inputShapes.map((dims) => dims.length);
|
|
const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank));
|
|
const outputSize = ShapeUtil.size(outputShape);
|
|
const output = outputVariable('output', dataType, outputShape.length);
|
|
const uniformsSymbols = [...einsumEquation.symbolToInfo.keys()].filter(
|
|
(symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol),
|
|
);
|
|
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
|
const idxCopy: string[] = [];
|
|
const initProd = 'var prod = 1.0;';
|
|
const initSum = 'var sum = 0.0;';
|
|
const updateSum = 'sum += prod;';
|
|
const reduceOpsSetIndices: string[] = [];
|
|
const reduceOpsLoopHeaders: string[] = [];
|
|
const reduceOpsLoopFooters: string[] = [];
|
|
const reduceOpCompute: string[] = [];
|
|
const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === einsumEquation.rhs.symbolToIndices.size;
|
|
einsumEquation.symbolToInfo.forEach((info, symbol) => {
|
|
if (einsumEquation.rhs.symbolToIndices.has(symbol)) {
|
|
const outputIndex = einsumEquation.rhs.symbolToIndices.get(symbol)?.[0];
|
|
if (outputIndex !== undefined) {
|
|
einsumEquation.lhs.forEach((term, i) => {
|
|
if (info.inputIndices.includes(i)) {
|
|
const indices = term.symbolToIndices.get(symbol);
|
|
if (indices === undefined) {
|
|
throw new Error('Invalid symbol error');
|
|
}
|
|
indices.forEach((index) => {
|
|
idxCopy.push(
|
|
`${inputVars[i].indicesSet(
|
|
`input${i}Indices`,
|
|
index,
|
|
output.indicesGet('outputIndices', outputIndex),
|
|
)}`,
|
|
);
|
|
});
|
|
}
|
|
});
|
|
}
|
|
} else {
|
|
einsumEquation.lhs.forEach((term, i) => {
|
|
if (info.inputIndices.includes(i)) {
|
|
const indices = term.symbolToIndices.get(symbol);
|
|
if (indices === undefined) {
|
|
throw new Error('Invalid symbol error');
|
|
}
|
|
indices.forEach((index) => {
|
|
reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`);
|
|
});
|
|
reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`);
|
|
}
|
|
});
|
|
reduceOpsLoopHeaders.push(
|
|
`for(var ${symbol}: u32 = 0; ${symbol} < uniforms.${appendMax(symbol)}; ${symbol}++) {`,
|
|
);
|
|
reduceOpsLoopFooters.push('}');
|
|
}
|
|
});
|
|
const reduceOps = isReduceOpsWithoutLoop
|
|
? [
|
|
...idxCopy,
|
|
`let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};`,
|
|
]
|
|
: [
|
|
...idxCopy,
|
|
initSum,
|
|
...reduceOpsLoopHeaders,
|
|
...reduceOpsSetIndices,
|
|
initProd,
|
|
...reduceOpCompute,
|
|
updateSum,
|
|
...reduceOpsLoopFooters,
|
|
];
|
|
return `
|
|
${shaderHelper
|
|
.registerUniforms(uniformsSymbols.map((symbol) => ({ name: `${appendMax(symbol)}`, type: 'u32' })))
|
|
.registerUniform('outputSize', 'u32')
|
|
.declareVariables(...inputVars, output)}
|
|
|
|
${shaderHelper.mainStart()}
|
|
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
|
|
var outputIndices = ${output.offsetToIndices('global_idx')};
|
|
${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')}
|
|
${reduceOps.join('\n')};
|
|
${output.setByOffset('global_idx', 'sum')};
|
|
}`;
|
|
};
|
|
return {
|
|
name: 'Einsum',
|
|
shaderCache: { hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank') },
|
|
getRunData: () => {
|
|
// The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The
|
|
// filter is added to make sure that dimValue is never 0.
|
|
const programUniformsInit: ProgramUniform[] = uniformsSymbols
|
|
.filter((symbol) => einsumEquation.symbolToInfo.has(symbol))
|
|
.map((symbol) => ({ type: DataType.uint32, data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0 }));
|
|
programUniformsInit.push({ type: DataType.uint32, data: outputSize });
|
|
const programUniforms: ProgramUniform[] = inputShapes
|
|
.map((dims, _) => [...createTensorShapeVariables(dims)])
|
|
.reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit);
|
|
programUniforms.push(...createTensorShapeVariables(outputShape));
|
|
return {
|
|
outputs: [{ dims: outputShape, dataType }],
|
|
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
|
|
programUniforms,
|
|
};
|
|
},
|
|
getShaderSource,
|
|
};
|
|
};
|
|
|
|
export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => {
|
|
const einsumEquation = new EinsumEquation(context.inputs, attributes.equation);
|
|
const outputShape = einsumEquation.outputDims;
|
|
const inputShapes = context.inputs.map((input, _) => input.dims);
|
|
context.compute(createEinsumProgramInfo(inputShapes, context.inputs[0].dataType, einsumEquation, outputShape));
|
|
};
|
|
|
|
export const parseEinsumAttributes = (attributes: Record<string, unknown>): EinsumAttributes => {
|
|
const equation = (attributes.equation as string).replace(/\s+/g, '');
|
|
return createAttributeWithCacheKey({ equation });
|
|
};
|