onnxruntime/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts
Yulong Wang abdc31de40
[js] change default formatter for JavaScript/TypeScript from clang-format to Prettier (#21728)
### Description

See
454996d496
for manual changes (excluded auto-generated formatting changes)

### Why

Because the toolset for the old clang-format is out of date, which reduces
development efficiency.

- The NPM package `clang-format` is already in maintenance mode and has not
been updated for two years.
- The VSCode extension for clang-format has not been maintained for a while,
and a recent Node.js security update stopped it from working at all on
Windows.

No one in community seems interested in fixing those.

Choose Prettier as it is the most popular TS/JS formatter.

### How to merge

It's easy to break the build:
- Be careful of any new commits on main not included in this PR.
- Be careful that after this PR is merged, other PRs that already passed
CI can merge.

So, make sure there are no new commits before merging this one, and
invalidate js PRs that already passed CI, forcing them to merge with the latest main.
2024-08-14 16:51:22 -07:00

314 lines
13 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper } from './common';
/** Attributes of the Einsum operator, as produced by `parseEinsumAttributes`. */
export interface EinsumAttributes extends AttributeWithCacheKey {
// The einsum equation, e.g. "ij,jk->ik". `parseEinsumAttributes` stores it with all whitespace removed.
readonly equation: string;
}
// The equation attribute value is a string which consists of left hand side (LHS) and optionally right hand side (RHS)
// separated by '->'. Ex. "ij,jk -> ik" expresses matrix multiplication
// "ij->ji" expresses matrix transpose
// "ii->i" diagonal elements of a square matrix
// LHS consists of a sequence of terms separated by commas. Each term corresponds to an input variable.
// Each symbol corresponds to a dimension in the input variable. The symbol can be either a letter, 'a' to 'z' or 'A' to
// 'Z' or '...' to represent arbitrary dimensions.
// Regular-expression source fragments used to validate the pieces of an einsum equation.
// A single symbol: one ASCII letter, or a literal ellipsis ('...').
const symbolPattern = '[a-zA-Z]|\\.\\.\\.';
// A term: one or more symbols.
const termPattern = `(${symbolPattern})+`;
// Anchored form: matches only when the whole string is a single term.
const termPatternOnly = `^${termPattern}$`;
// The LHS: one or more terms separated by commas.
const lhsPattern = `(${termPattern},)*${termPattern}`;
// Anchored form: matches only when the whole string is a valid LHS.
const lhsPatternOnly = `^${lhsPattern}$`;
// Bookkeeping kept for each symbol of an equation; maintained by EinsumEquation.addSymbol.
interface SymbolInfo {
count: number; // Number of times the symbol has been recorded (once per occurrence in any processed term)
inputIndices: number[]; // Index of the term for each recorded occurrence (-1 when recorded for the output term)
dimValue: number; // Size of the dimension the symbol stands for, taken from its first occurrence
}
/**
 * One term of an einsum equation: either a single input on the LHS or the output on the RHS.
 * It remembers, for every symbol, which dimension positions of the tensor the symbol labels.
 */
class EinsumTerm {
  /** Maps each symbol to the dimension positions it labels, in the order they were added. */
  symbolToIndices: Map<string, number[]>;
  /** -1 for the output term and 0, 1, 2, ... for input terms. */
  inputIndex: number;

  /**
   * @param inputIndex Position of the input this term describes; defaults to -1 (the output).
   */
  constructor(inputIndex = -1) {
    this.symbolToIndices = new Map<string, number[]>();
    this.inputIndex = inputIndex;
  }

  /**
   * Records that `symbol` labels dimension position `index` of this term.
   * A symbol may be added several times (e.g. "ii"); positions accumulate in call order.
   */
  addSymbol(symbol: string, index: number) {
    const positions = this.symbolToIndices.get(symbol);
    if (positions !== undefined) {
      positions.push(index);
      this.symbolToIndices.set(symbol, positions);
      return;
    }
    this.symbolToIndices.set(symbol, [index]);
  }
}
/**
 * A parsed einsum equation bound to the shapes of a concrete set of inputs.
 *
 * Construction validates the equation, records every symbol together with the dimension size it
 * stands for, and derives the output shape. Dimensions covered by an ellipsis are expanded into
 * the pseudo-symbols '0', '1', '2', ... so that later stages can treat them like ordinary symbols.
 */
class EinsumEquation {
  /**
   * @param inputs Input tensors; only their `dims` are read. Exactly one is required per LHS term.
   * @param equation Equation such as "ij,jk->ik". The validation patterns do not accept whitespace.
   * @throws Error when the equation is malformed, when the number of terms differs from the number
   *     of inputs, when a term does not cover the rank of its input, or when ellipsis dimensions
   *     disagree between terms.
   */
  constructor(
    inputs: readonly TensorView[],
    public readonly equation: string,
  ) {
    this.hasEllipsis = false;
    this.ellipsisDims = [];
    this.symbolToInfo = new Map<string, SymbolInfo>();
    this.lhs = new Array<EinsumTerm>();
    this.outputDims = [];

    // "ij->" (explicit, scalar output) must be told apart from "ij" (implicit output), so remember
    // whether the arrow was present instead of inferring it from an empty RHS string.
    const hasExplicitRhs = equation.includes('->');
    const [lhs, explicitRhs] = hasExplicitRhs ? equation.split('->', 2) : [equation, ''];
    if (!lhs.match(RegExp(lhsPatternOnly))) {
      throw new Error('Invalid LHS term');
    }

    const inputTerms = lhs.split(',');
    if (inputTerms.length !== inputs.length) {
      throw new Error(
        `Einsum equation has ${inputTerms.length} input term(s) but ${inputs.length} input tensor(s) were provided`,
      );
    }
    inputTerms.forEach((inputTerm, index) => {
      if (!inputTerm.match(RegExp(termPatternOnly))) {
        throw new Error('Invalid LHS term');
      }
      this.lhs.push(this.processTerm(inputTerm, true, inputs[index].dims, index));
    });

    let rhs: string;
    if (hasExplicitRhs) {
      rhs = explicitRhs;
      // An empty explicit RHS is valid: every dimension is reduced and the output is a scalar.
      if (rhs !== '' && !rhs.match(RegExp(termPatternOnly))) {
        throw new Error('Invalid RHS');
      }
    } else {
      // Implicit output: the ellipsis dimensions (if any) followed by every letter that was recorded
      // exactly once. The digit pseudo-symbols that stand for ellipsis dimensions are excluded here
      // because they are re-created from the '...' prefix when the RHS term is processed.
      // NOTE(review): letters keep their order of first appearance; the ONNX operator description
      // calls for alphabetical order in implicit mode - confirm which ordering is wanted.
      const letters = [...this.symbolToInfo.entries()]
        .filter(([sym, info]) => info.count === 1 && /^[a-zA-Z]$/.test(sym))
        .map(([sym]) => sym)
        .join('');
      rhs = (this.hasEllipsis ? '...' : '') + letters;
    }

    // Compute output dims
    const rhsSymbols = rhs.match(RegExp(symbolPattern, 'g'));
    rhsSymbols?.forEach((symbol) => {
      if (symbol === '...') {
        this.outputDims = this.outputDims.concat(this.ellipsisDims);
      } else {
        const info = this.symbolToInfo.get(symbol);
        if (info === undefined) {
          throw new Error('Invalid RHS symbol');
        }
        this.outputDims.push(info.dimValue);
      }
    });
    this.rhs = this.processTerm(rhs, false, this.outputDims);
  } // End of EinsumEquation constructor

  /**
   * Records one occurrence of `symbol` with dimension size `dimValue` in term `inputIndex`
   * (-1 for the output term).
   *
   * @throws Error ('Dimension mismatch') under the condition described in the note below.
   */
  addSymbol(symbol: string, dimValue: number, inputIndex: number) {
    let info = this.symbolToInfo.get(symbol);
    if (info !== undefined) {
      // NOTE(review): a size mismatch is only rejected once the symbol has already been recorded
      // more than once; the second occurrence is accepted whatever its size. Confirm whether this
      // is intended (for example to tolerate broadcasting) before tightening it.
      if (info.dimValue !== dimValue && info.count !== 1) {
        throw new Error('Dimension mismatch');
      } else {
        info.count++;
        info.inputIndices.push(inputIndex);
      }
    } else {
      info = { count: 1, dimValue, inputIndices: [inputIndex] };
    }
    this.symbolToInfo.set(symbol, info);
  }

  /**
   * Parses one input or output term and registers its symbols with the equation.
   *
   * @param term The term text, e.g. "ij" or "...ij". May be '' for a scalar output.
   * @param isInput True for an LHS term, false for the RHS term.
   * @param dims Shape of the tensor the term describes.
   * @param index Input position of the term, or -1 for the output.
   * @returns The term with every symbol mapped to the dimension positions it labels.
   * @throws Error when the term is malformed or does not cover exactly `dims.length` dimensions.
   */
  processTerm(term: string, isInput: boolean, dims: readonly number[], index = -1): EinsumTerm {
    const rank = dims.length;
    let ellipsis = false;
    let ellipsisDims: number[] = [];
    // Position of the next dimension of `dims` to be labelled. Because it advances past the
    // dimensions an ellipsis expands to, it is always the correct index for the current symbol.
    let nextDim = 0;
    // For output empty string is allowed because the output may be reduced to a scalar value
    if (!term.match(RegExp(termPatternOnly)) && !isInput && term !== '') {
      throw new Error('Invalid RHS term');
    }
    const indexSymbols = term.match(RegExp(symbolPattern, 'g'));
    const einsumTerm = new EinsumTerm(index);
    // symbol can be either a letter, 'a' to 'z' or 'A' to 'Z', or '...'
    indexSymbols?.forEach((symbol: string) => {
      if (symbol === '...') {
        if (ellipsis) {
          throw new Error('Only one ellipsis is allowed per input term');
        }
        ellipsis = true;
        // The ellipsis absorbs every dimension that is not claimed by an explicit symbol.
        const ellipsisDimLength = rank - indexSymbols.length + 1;
        if (ellipsisDimLength < 0) {
          throw new Error('Ellipsis out of bounds');
        }
        ellipsisDims = dims.slice(nextDim, nextDim + ellipsisDimLength);
        if (this.hasEllipsis) {
          if (
            this.ellipsisDims.length !== ellipsisDims.length ||
            this.ellipsisDims.toString() !== ellipsisDims.toString()
          ) {
            throw new Error('Ellipsis dimensions mismatch');
          }
        } else if (isInput) {
          this.hasEllipsis = true;
          this.ellipsisDims = ellipsisDims;
        } else {
          throw new Error('Ellipsis must be specified in the LHS');
        }
        // Add '0', '1', '2', '3', '4', etc to represent ellipsis dimensions to avoid special handling
        for (let j = 0; j < ellipsisDims.length; j++) {
          const ellipsisSymbol = String.fromCharCode('0'.charCodeAt(0) + j);
          einsumTerm.addSymbol(ellipsisSymbol, nextDim);
          this.addSymbol(ellipsisSymbol, dims[nextDim++], index);
        }
      } else {
        // Use this term's own position. Deriving the offset from the equation-wide ellipsis state
        // would shift the indices of terms that contain no ellipsis themselves.
        einsumTerm.addSymbol(symbol, nextDim);
        this.addSymbol(symbol, dims[nextDim++], index);
      }
    });
    if (nextDim !== rank) {
      throw new Error(`Einsum term '${term}' covers ${nextDim} dimension(s) but the tensor has rank ${rank}`);
    }
    return einsumTerm;
  }

  symbolToInfo: Map<string, SymbolInfo>; // All symbols in the equation
  hasEllipsis: boolean; // The equation has ellipsis or not
  ellipsisDims: number[]; // The dimensions of the equation ellipsis corresponds to.
  lhs: EinsumTerm[]; // Terms on the left-hand side of the equation
  rhs: EinsumTerm; // Term on the right-hand side of the equation
  outputDims: number[]; // Output dimensions of the equation
} // End of class EinsumEquation
/** Returns the name of the uniform that holds the loop bound for `name` (the name with a '_max' suffix). */
const appendMax = (name: string): string => `${name}_max`;
/**
 * Builds the program description (shader generator plus run data) that evaluates an einsum equation.
 *
 * Each shader invocation produces one output element. Symbols that appear in the output take their
 * value from the output indices; every other symbol becomes a loop whose bound is passed in as a
 * `<symbol>_max` uniform, and the product of all inputs is accumulated across those loops.
 *
 * @param inputShapes Shape of every input, in input order.
 * @param dataType Element data type used for all input and output variables.
 * @param einsumEquation The parsed equation.
 * @param outputShape Shape of the output tensor.
 * @returns The program info handed to `context.compute`.
 */
const createEinsumProgramInfo = (
  inputShapes: Array<readonly number[]>,
  dataType: number,
  einsumEquation: EinsumEquation,
  outputShape: readonly number[],
): ProgramInfo => {
  const ranks = inputShapes.map((dims) => dims.length);
  const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank));
  const outputSize = ShapeUtil.size(outputShape);
  const output = outputVariable('output', dataType, outputShape.length);
  // Symbols that do not appear in the output are reduced; each one needs a loop-bound uniform.
  const uniformsSymbols = [...einsumEquation.symbolToInfo.keys()].filter(
    (symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol),
  );
  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const idxCopy: string[] = [];
    // NOTE(review): the accumulators are declared with float literals; confirm this is valid for
    // every `dataType` this operator is registered for.
    const initProd = 'var prod = 1.0;';
    const initSum = 'var sum = 0.0;';
    const updateSum = 'sum += prod;';
    const reduceOpsSetIndices: string[] = [];
    const reduceOpsLoopHeaders: string[] = [];
    const reduceOpsLoopFooters: string[] = [];
    const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === einsumEquation.rhs.symbolToIndices.size;
    einsumEquation.symbolToInfo.forEach((info, symbol) => {
      if (einsumEquation.rhs.symbolToIndices.has(symbol)) {
        // Output symbol: copy the matching output index into every input position it labels.
        const outputIndex = einsumEquation.rhs.symbolToIndices.get(symbol)?.[0];
        if (outputIndex !== undefined) {
          einsumEquation.lhs.forEach((term, i) => {
            if (info.inputIndices.includes(i)) {
              const indices = term.symbolToIndices.get(symbol);
              if (indices === undefined) {
                throw new Error('Invalid symbol error');
              }
              indices.forEach((index) => {
                idxCopy.push(
                  inputVars[i].indicesSet(`input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex)),
                );
              });
            }
          });
        }
      } else {
        // Reduced symbol: drive the labelled input positions from a loop variable named after it.
        einsumEquation.lhs.forEach((term, i) => {
          if (info.inputIndices.includes(i)) {
            const indices = term.symbolToIndices.get(symbol);
            if (indices === undefined) {
              throw new Error('Invalid symbol error');
            }
            indices.forEach((index) => {
              reduceOpsSetIndices.push(inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`));
            });
          }
        });
        reduceOpsLoopHeaders.push(
          `for(var ${symbol}: u32 = 0; ${symbol} < uniforms.${appendMax(symbol)}; ${symbol}++) {`,
        );
        reduceOpsLoopFooters.push('}');
      }
    });
    // Every input contributes exactly one factor to each product term. Emitting the factor once per
    // (reduced symbol, input) pair would square an input that has two reduced symbols and would
    // drop an input that has none.
    const reduceOpCompute = inputVars.map(
      (inputVar, i) => `prod *= ${inputVar.getByIndices(`input${i}Indices`)};`,
    );
    const reduceOps = isReduceOpsWithoutLoop
      ? [
          ...idxCopy,
          `let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};`,
        ]
      : [
          ...idxCopy,
          initSum,
          ...reduceOpsLoopHeaders,
          ...reduceOpsSetIndices,
          initProd,
          ...reduceOpCompute,
          updateSum,
          ...reduceOpsLoopFooters,
        ];
    return `
${shaderHelper
  .registerUniforms(uniformsSymbols.map((symbol) => ({ name: `${appendMax(symbol)}`, type: 'u32' })))
  .registerUniform('outputSize', 'u32')
  .declareVariables(...inputVars, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
var outputIndices = ${output.offsetToIndices('global_idx')};
${inputVars.map((_var, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')}
${reduceOps.join('\n')};
${output.setByOffset('global_idx', 'sum')};
}`;
  };
  return {
    name: 'Einsum',
    shaderCache: { hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank') },
    getRunData: () => {
      // Uniform values are listed in the order they are registered in the shader: one loop bound per
      // reduced symbol, then outputSize, then the shape variables of the inputs and of the output.
      // Every entry of uniformsSymbols is a key of symbolToInfo, so the fallback to 0 is only there
      // to satisfy the type checker.
      const programUniforms: ProgramUniform[] = uniformsSymbols.map((symbol) => ({
        type: DataType.uint32,
        data: einsumEquation.symbolToInfo.get(symbol)?.dimValue ?? 0,
      }));
      programUniforms.push({ type: DataType.uint32, data: outputSize });
      inputShapes.forEach((dims) => {
        programUniforms.push(...createTensorShapeVariables(dims));
      });
      programUniforms.push(...createTensorShapeVariables(outputShape));
      return {
        outputs: [{ dims: outputShape, dataType }],
        dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
        programUniforms,
      };
    },
    getShaderSource,
  };
};
/**
 * Entry point of the Einsum operator: parses the equation against the current inputs, builds the
 * matching program and submits it through `context.compute`.
 *
 * @param context Compute context supplying the input tensors.
 * @param attributes Operator attributes carrying the equation.
 */
export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => {
  const parsedEquation = new EinsumEquation(context.inputs, attributes.equation);
  const shapes = context.inputs.map((tensor) => tensor.dims);
  // The data type of the first input is used for the whole program.
  const elementType = context.inputs[0].dataType;
  const programInfo = createEinsumProgramInfo(shapes, elementType, parsedEquation, parsedEquation.outputDims);
  context.compute(programInfo);
};
/**
 * Converts the raw attribute bag into `EinsumAttributes`.
 * All whitespace is removed from the equation before it is wrapped with a cache key.
 *
 * @param attributes Raw attributes; `equation` is read as a string.
 */
export const parseEinsumAttributes = (attributes: Record<string, unknown>): EinsumAttributes => {
  const rawEquation = attributes.equation as string;
  const equation = rawEquation.replace(/\s/g, '');
  return createAttributeWithCacheKey({ equation });
};