diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index ad18302318..de53f943bc 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -30,6 +30,7 @@ Do not modify directly.* | Cos | ai.onnx(7+) | | | Cosh | ai.onnx(9+) | | | Div | ai.onnx(7-12,13,14+) | | +| Einsum | ai.onnx(12+) | | | Elu | ai.onnx(6+) | | | Equal | ai.onnx(7-10,11-12,13-18,19+) | | | Erf | ai.onnx(9-12,13+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index a8fbf9c00e..9c46b97694 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -6,6 +6,7 @@ import * as binaryOps from './ops/binary-op'; import {concat, parseConcatAttributes} from './ops/concat'; import {conv, parseConvAttributes} from './ops/conv'; import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose'; +import {einsum, parseEinsumAttributes} from './ops/einsum'; import {expand} from './ops/expand'; import {gather, parseGatherAttributes} from './ops/gather'; import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements'; @@ -52,6 +53,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Cos', [unaryOps.cos]], ['Cosh', [unaryOps.cosh]], ['Div', [binaryOps.div]], + ['Einsum', [einsum, parseEinsumAttributes]], ['Elu', [unaryOps.elu, unaryOps.parseAlphaAttributes]], ['Equal', [binaryOps.equal]], ['Erf', [unaryOps.erf]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts new file mode 100644 index 0000000000..f0196f37c3 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -0,0 +1,290 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; + +export interface EinsumAttributes extends AttributeWithCacheKey { + readonly equation: string; +} +// The equation attribute value is a string which consists of left hand side (LHS) and optionally right hand side (RHS) +// separated by '->'. Ex. "ij,jk -> ik" expresses matrix multiplication +// "ij->ji" expresses matrix transpose +// "ii->i" diagonal elements of a square matrix +// LHS consists of a sequence of terms separated by commas. Each term corresponds to an input variable. +// Each symbol corresponds to a dimension in the input variable. The symbol can be either a letter, 'a' to 'z' or 'A' to +// 'Z' or '...' to represent arbitrary dimensions. + +const symbolPattern = + '[a-zA-Z]|\\.\\.\\.'; // The pattern each symbol in each term in the symbolic equation should match +const termPattern = '(' + symbolPattern + ')+'; // The pattern each term in the symbolic equation should match +const termPatternOnly = '^' + termPattern + '$'; // The patterns only matchs a term begin to end. +const lhsPattern = '(' + termPattern + ',)*' + termPattern; // The pattern the LHS should match +const lhsPatternOnly = '^' + lhsPattern + '$'; // The patterns only matchs a LHS begin to end. + +interface SymbolInfo { + count: number; // Symbol corresponding to a dimmension of an input + inputIndices: number[]; // Number of input variables the symbol corresponds to + dimValue: number; // Number of dimensions the symbol corresponds to +} + +class EinsumTerm { + constructor(inputIndex = -1) { + this.symbolToIndices = new Map(); + this.inputIndex = inputIndex; + } + + // Add a symbol to the term + addSymbol(symbol: string, index: number) { + let value = this.symbolToIndices.get(symbol); + if (value === undefined) { + value = [index]; + } else { + value.push(index); + } + this.symbolToIndices.set(symbol, value); + } + + symbolToIndices: Map; // Map from symbol to dimensions of the input corresponding to the term + inputIndex: number; // -1 for output and 0, 1, 2, ... for inputs +} + +class EinsumEquation { + constructor(inputs: readonly TensorView[], equation: string) { + this.hasEllipsis = false; + this.symbolToInfo = new Map(); + this.lhs = new Array(); + this.outputDims = []; + // As rhs needs to be updated allow using let instead of const for both lhs and rhs. + // eslint-disable-next-line prefer-const + let [lhs, rhs] = equation.includes('->') ? equation.split('->', 2) : [equation, '']; + if (!lhs.match(RegExp(lhsPatternOnly))) { + throw new Error('Invalid LHS term'); + } + const inputTerms = lhs.split(','); + inputTerms.forEach((inputTerm, index) => { + const dims = inputs[index].dims.slice(); + if (!inputTerm.match(RegExp(termPatternOnly))) { + throw new Error('Invalid LHS term'); + } + const einsumTerm = this.processTerm(inputTerm, true, dims, index); + this.lhs.push(einsumTerm); + }); + + // Initialize the RHS if not specified + if (rhs === '') { + // Construct RHS from LHS terms/symbols + rhs += [...this.symbolToInfo.entries()] + .filter(([sym, info]) => (info.count === 1 || sym === '...')) + .map(([sym]) => sym) + .join(''); + } else { + if (!rhs.match(RegExp(termPattern))) { + throw new Error('Invalid RHS'); + } + } + + // Compute output dims + const rhsSymbols = rhs.match(RegExp(symbolPattern, 'g')); + rhsSymbols?.forEach((symbol) => { + if (symbol === '...') { + this.outputDims = this.outputDims.concat(this.ellipsisDims); + } else { + const info = this.symbolToInfo.get(symbol); + if (info === undefined) { + throw new Error('Invalid RHS symbol'); + } + this.outputDims.push(info.dimValue); + } + }); + this.rhs = this.processTerm(rhs, true, this.outputDims); + } // End of EinsumEqation constructor + + // Add a symbol to the equation + addSymbol(symbol: string, dimValue: number, inputIndex: number) { + let info = this.symbolToInfo.get(symbol); + if (info !== undefined) { + if (info.dimValue !== dimValue && info.count !== 1) { + throw new Error('Dimension mismatch'); + } else { + info.count++; + info.inputIndices.push(inputIndex); + } + } else { + info = {count: 1, dimValue, inputIndices: [inputIndex]}; + } + this.symbolToInfo.set(symbol, info); + } + + // Process one input/output term + processTerm(term: string, isInput: boolean, dims: readonly number[], index = -1): EinsumTerm { + const rank = dims.length; + let ellipsis = false; + let ellipsisDims = []; + let nextDim = 0; + // For output empty string is allowed because the output may be reduced to a scalar value + if (!term.match(RegExp(termPatternOnly)) && (!isInput && term !== '')) { + throw new Error('Invalid LHS term'); + } + const indexSymbols = term.match(RegExp(symbolPattern, 'g')); + const einsumTerm = new EinsumTerm(index); + // symbol can be either a lettre, 'a' to 'z' or 'A' to 'Z', or '...' + indexSymbols?.forEach((symbol: string, i: number) => { + if (symbol === '...') { + if (ellipsis) { + throw new Error('Only one ellipsis is allowed per input term'); + } + ellipsis = true; + const ellipsisDimLength = rank - indexSymbols.length + 1; + if (ellipsisDimLength < 0) { + throw new Error('Ellipsis out of bounds'); + } + ellipsisDims = dims.slice(nextDim, nextDim + ellipsisDimLength); + if (this.hasEllipsis) { + if (this.ellipsisDims.length !== ellipsisDims.length || + this.ellipsisDims.toString() !== ellipsisDims.toString()) { + throw new Error('Ellipsis dimensions mismatch'); + } + } else if (isInput) { + this.hasEllipsis = true; + this.ellipsisDims = ellipsisDims; + } else { + throw new Error('Ellipsis must be specified in the LHS'); + } + // Add '0', '1', '2', '3', '4', etc to represent ellipsis dimensions to avoid special handling + for (let j = 0; j < ellipsisDims.length; j++) { + const symbol = String.fromCharCode('0'.charCodeAt(0) + i); + einsumTerm.addSymbol(symbol, i + j); + this.addSymbol(symbol, dims[nextDim++], index); + } + } else { + einsumTerm.addSymbol(symbol, i); + this.addSymbol(symbol, dims[nextDim++], index); + } + }); + return einsumTerm; + } + + symbolToInfo: Map; // All symbols in the equation + hasEllipsis: boolean; // The equation has ellipsis or not + ellipsisDims: number[]; // The dimensions of the equation ellipsis corresponds to. + lhs: EinsumTerm[]; // Terms on the left-hand side of the equation + rhs: EinsumTerm; // Term on the right-hand side of the equation + outputDims: number[]; // Output dimensions of the equation +} // End of class EinsumEquation + + +const createEinsumProgramMetadata = (inputCount: number, cacheHint: string): ProgramMetadata => + ({name: 'Einsum', inputTypes: Array(inputCount).fill(GpuDataType.default), cacheHint}); + +const createEinsumProgramInfo = + (metadata: ProgramMetadata, inputs: readonly TensorView[], einsumEquation: EinsumEquation): ProgramInfo => { + const dataType = inputs[0].dataType; + const inputVars = new Array(inputs.length); + for (let i = 0; i < inputs.length; ++i) { + inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims); + } + const outputShape = einsumEquation.outputDims; + const outputSize = ShapeUtil.size(outputShape); + const output = outputVariable('output', dataType, outputShape); + const idxCopy: string[] = []; + const rhsSymbols = Array.from(einsumEquation.rhs.symbolToIndices.keys()); + const initProd = 'var prod = 1.0;'; + const initSum = 'var sum = 0.0;'; + const updateSum = 'sum += prod;'; + const reduceOpsSetIndices: string[] = []; + const reduceOpsLoopHeaders: string[] = []; + const reduceOpsLoopFooters: string[] = []; + const reduceOpCompute: string[] = []; + const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === rhsSymbols.length; + einsumEquation.symbolToInfo.forEach((info, symbol) => { + if (rhsSymbols.includes(symbol)) { + const outputIndex = rhsSymbols.indexOf(symbol); + einsumEquation.lhs.forEach((term, i) => { + if (info.inputIndices.includes(i)) { + const indices = term.symbolToIndices.get(symbol); + if (indices === undefined) { + throw new Error('Invalid symbol error'); + } + indices.forEach((index) => { + idxCopy.push(`${ + inputVars[i].indicesSet( + `input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`); + }); + } + }); + } else { + einsumEquation.lhs.forEach((term, i) => { + const info = einsumEquation.symbolToInfo.get(symbol); + if (info === undefined) { + throw new Error('Invalid symbol error'); + } + if (info.inputIndices.includes(i)) { + const indices = term.symbolToIndices.get(symbol); + if (indices === undefined) { + throw new Error('Invalid symbol error'); + } + indices.forEach((index) => { + reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`); + }); + reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`); + } + }); + reduceOpsLoopHeaders.push(`for(var ${symbol}: u32 = 0; ${symbol} < ${ + einsumEquation.symbolToInfo.get(symbol)?.dimValue}; ${symbol}++) {`); + reduceOpsLoopFooters.push('}'); + } + }); + const reduceOps = isReduceOpsWithoutLoop ? + [ + ...idxCopy, + `let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};` + ] : + [ + ...idxCopy, + initSum, + ...reduceOpsLoopHeaders, + ...reduceOpsSetIndices, + initProd, + ...reduceOpCompute, + updateSum, + ...reduceOpsLoopFooters, + ]; + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${shaderHelper.declareVariables(...inputVars, output)} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + var outputIndices = ${output.offsetToIndices('global_idx')}; + ${inputVars.map((inputVar, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')} + ${reduceOps.join('\n')}; + ${output.setByOffset('global_idx', 'sum')}; + }`; + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + getShaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; + }; + +const createEinsumProgramInfoLoader = + (inputs: readonly TensorView[], einsumEquation: EinsumEquation, attributes: EinsumAttributes): + ProgramInfoLoader => { + const metadata = createEinsumProgramMetadata(inputs.length, attributes.cacheKey); + return {...metadata, get: () => createEinsumProgramInfo(metadata, inputs, einsumEquation)}; + }; + +export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => { + const einsumEquation = new EinsumEquation(context.inputs, attributes.equation); + context.compute(createEinsumProgramInfoLoader(context.inputs, einsumEquation, attributes)); +}; + +export const parseEinsumAttributes = (attributes: Record): EinsumAttributes => { + const equation = (attributes.equation as string).replace(/\s+/g, ''); + return createAttributeWithCacheKey({equation}); +}; diff --git a/js/web/test/data/ops/einsum.jsonc b/js/web/test/data/ops/einsum.jsonc new file mode 100644 index 0000000000..baf30cf982 --- /dev/null +++ b/js/web/test/data/ops/einsum.jsonc @@ -0,0 +1,635 @@ +[ + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i,i", + "type": "string" + } + ], + "cases": [ + { + "name": "Dotproduct/scalar product", + "inputs": [ + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [4, 5, 6], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [32], + "dims": [], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i,i->i", + "type": "string" + } + ], + "cases": [ + { + "name": "elementwise product", + "inputs": [ + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [4, 5, 6], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [4, 10, 18], + "dims": [3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i,j", + "type": "string" + } + ], + "cases": [ + { + "name": "Product without specifying RSH", + "inputs": [ + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [4, 5, 6], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [4, 5, 6, 8, 10, 12, 12, 15, 18], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i,j->ij", + "type": "string" + } + ], + "cases": [ + { + "name": "Product", + "inputs": [ + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [4, 5, 6], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [4, 5, 6, 8, 10, 12, 12, 15, 18], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "ii,jj", + "type": "string" + } + ], + "cases": [ + { + "name": "Diagonal elementwise multiplication", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [1, 0, 0, 0, 1, 0, 0, 0, 1], + "dims": [3, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [45], + "dims": [], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "ii,jj -> ij", + "type": "string" + } + ], + "cases": [ + { + "name": "Dotproduct", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [1, 0, 0, 0, 1, 0, 0, 0, 1], + "dims": [3, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 1, 1, 5, 5, 5, 9, 9, 9], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "ij,jk->ik", + "type": "string" + } + ], + "cases": [ + { + "name": "Multiply", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [38, 44, 50, 56, 83, 98, 113, 128], + "dims": [2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "ij->ji", + "type": "string" + } + ], + "cases": [ + { + "name": "Transpose", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 4, 2, 5, 3, 6], + "dims": [3, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "ij->i", + "type": "string" + } + ], + "cases": [ + { + "name": "ReduceSum", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [6, 15, 24], + "dims": [3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "ii->i", + "type": "string" + } + ], + "cases": [ + { + "name": "Diagonal", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [3, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 5, 9], + "dims": [3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "ij...,jk...->ik...", + "type": "string" + } + ], + "cases": [ + { + "name": "Multiply with ellipsis - A", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 4, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [38, 44, 50, 56, 83, 98, 113, 128], + "dims": [2, 4, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "...ij,...jk->...ik", + "type": "string" + } + ], + "cases": [ + { + "name": "Multiply with ellipsis - B", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [1, 2, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [1, 3, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [38, 44, 50, 56, 83, 98, 113, 128], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "i...j,j...k->i...k", + "type": "string" + } + ], + "cases": [ + { + "name": "Multiply with ellipsis - C", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [2, 1, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [38, 44, 50, 56, 83, 98, 113, 128], + "dims": [2, 1, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "...ij,jk->...ik", + "type": "string" + } + ], + "cases": [ + { + "name": "Multiply with ellipsis - D", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [1, 2, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [38, 44, 50, 56, 83, 98, 113, 128], + "dims": [1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "...ij->...ji", + "type": "string" + } + ], + "cases": [ + { + "name": "Transpose with ellipsis", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [1, 2, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 4, 2, 5, 3, 6], + "dims": [1, 3, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "...ij->...i", + "type": "string" + } + ], + "cases": [ + { + "name": "ReduceSum with ellipsis", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [1, 3, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [6, 15, 24], + "dims": [1, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "einsum", + "operator": "Einsum", + "opset": { + "domain": "", + "version": 12 + }, + "attributes": [ + { + "name": "equation", + "data": "...ii->...i", + "type": "string" + } + ], + "cases": [ + { + "name": "Diagonal with ellipsis", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9], + "dims": [1, 3, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 5, 9], + "dims": [1, 3], + "type": "float32" + } + ] + } + ] + } +] diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index c5b3b1933e..6c79bf6c83 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -319,6 +319,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum); + std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -573,6 +575,8 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/js/js_export.cc b/onnxruntime/core/providers/js/js_export.cc index 2c99e246b6..2402bb33ce 100644 --- a/onnxruntime/core/providers/js/js_export.cc +++ b/onnxruntime/core/providers/js/js_export.cc @@ -9,9 +9,7 @@ const void* JsepOutput(void* context, int index, const void* data) { const uint32_t* data_offset = reinterpret_cast(data); uint32_t dim = *data_offset++; size_t dim_size = static_cast(dim); - std::vector dims; - dims.reserve(dim_size); - dims.resize(dim_size); + std::vector dims(dim_size); for (size_t i = 0; i < dim_size; i++) { dims[i] = static_cast(*data_offset++); } diff --git a/onnxruntime/core/providers/js/operators/einsum.cc b/onnxruntime/core/providers/js/operators/einsum.cc new file mode 100644 index 0000000000..2fdc14fa3a --- /dev/null +++ b/onnxruntime/core/providers/js/operators/einsum.cc @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" +#include "einsum.h" + +namespace onnxruntime { +namespace js { + +ONNX_OPERATOR_TYPED_KERNEL_EX( + Einsum, + kOnnxDomain, + 12, + float, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + Einsum); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/einsum.h b/onnxruntime/core/providers/js/operators/einsum.h new file mode 100644 index 0000000000..ec8b6f5dab --- /dev/null +++ b/onnxruntime/core/providers/js/operators/einsum.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +class Einsum final : public JsKernel { + public: + Einsum(const OpKernelInfo& info) : JsKernel(info) { + std::string equation; + ORT_ENFORCE(info.GetAttr("equation", &equation).IsOK(), + "Missing 'equation' attribute"); + JSEP_INIT_KERNEL_ATTRIBUTE(Einsum, ({ + "equation" : UTF8ToString($1), + }), + equation.c_str()); + } +}; + +} // namespace js +} // namespace onnxruntime