[JS/Web] Added Einsum operator support. (#17401)

### Description
Added Einsum operator support to JSEP.



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
satyajandhyala 2023-09-11 15:57:15 -07:00 committed by GitHub
parent 850baced33
commit bf6d6961cc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 979 additions and 3 deletions

View file

@ -30,6 +30,7 @@ Do not modify directly.*
| Cos | ai.onnx(7+) | |
| Cosh | ai.onnx(9+) | |
| Div | ai.onnx(7-12,13,14+) | |
| Einsum | ai.onnx(12+) | |
| Elu | ai.onnx(6+) | |
| Equal | ai.onnx(7-10,11-12,13-18,19+) | |
| Erf | ai.onnx(9-12,13+) | |

View file

@ -6,6 +6,7 @@ import * as binaryOps from './ops/binary-op';
import {concat, parseConcatAttributes} from './ops/concat';
import {conv, parseConvAttributes} from './ops/conv';
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
import {einsum, parseEinsumAttributes} from './ops/einsum';
import {expand} from './ops/expand';
import {gather, parseGatherAttributes} from './ops/gather';
import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
@ -52,6 +53,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Cos', [unaryOps.cos]],
['Cosh', [unaryOps.cosh]],
['Div', [binaryOps.div]],
['Einsum', [einsum, parseEinsumAttributes]],
['Elu', [unaryOps.elu, unaryOps.parseAlphaAttributes]],
['Equal', [binaryOps.equal]],
['Erf', [unaryOps.erf]],

View file

@ -0,0 +1,290 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
export interface EinsumAttributes extends AttributeWithCacheKey {
readonly equation: string;
}
// The equation attribute value is a string which consists of left hand side (LHS) and optionally right hand side (RHS)
// separated by '->'. Ex. "ij,jk -> ik" expresses matrix multiplication
// "ij->ji" expresses matrix transpose
// "ii->i" diagonal elements of a square matrix
// LHS consists of a sequence of terms separated by commas. Each term corresponds to an input variable.
// Each symbol corresponds to a dimension in the input variable. The symbol can be either a letter, 'a' to 'z' or 'A' to
// 'Z' or '...' to represent arbitrary dimensions.
const symbolPattern =
'[a-zA-Z]|\\.\\.\\.'; // The pattern each symbol in each term in the symbolic equation should match
const termPattern = '(' + symbolPattern + ')+'; // The pattern each term in the symbolic equation should match
const termPatternOnly = '^' + termPattern + '$'; // The patterns only matchs a term begin to end.
const lhsPattern = '(' + termPattern + ',)*' + termPattern; // The pattern the LHS should match
const lhsPatternOnly = '^' + lhsPattern + '$'; // The patterns only matchs a LHS begin to end.
interface SymbolInfo {
count: number; // Symbol corresponding to a dimmension of an input
inputIndices: number[]; // Number of input variables the symbol corresponds to
dimValue: number; // Number of dimensions the symbol corresponds to
}
class EinsumTerm {
constructor(inputIndex = -1) {
this.symbolToIndices = new Map<string, number[]>();
this.inputIndex = inputIndex;
}
// Add a symbol to the term
addSymbol(symbol: string, index: number) {
let value = this.symbolToIndices.get(symbol);
if (value === undefined) {
value = [index];
} else {
value.push(index);
}
this.symbolToIndices.set(symbol, value);
}
symbolToIndices: Map<string, number[]>; // Map from symbol to dimensions of the input corresponding to the term
inputIndex: number; // -1 for output and 0, 1, 2, ... for inputs
}
class EinsumEquation {
constructor(inputs: readonly TensorView[], equation: string) {
this.hasEllipsis = false;
this.symbolToInfo = new Map<string, SymbolInfo>();
this.lhs = new Array<EinsumTerm>();
this.outputDims = [];
// As rhs needs to be updated allow using let instead of const for both lhs and rhs.
// eslint-disable-next-line prefer-const
let [lhs, rhs] = equation.includes('->') ? equation.split('->', 2) : [equation, ''];
if (!lhs.match(RegExp(lhsPatternOnly))) {
throw new Error('Invalid LHS term');
}
const inputTerms = lhs.split(',');
inputTerms.forEach((inputTerm, index) => {
const dims = inputs[index].dims.slice();
if (!inputTerm.match(RegExp(termPatternOnly))) {
throw new Error('Invalid LHS term');
}
const einsumTerm = this.processTerm(inputTerm, true, dims, index);
this.lhs.push(einsumTerm);
});
// Initialize the RHS if not specified
if (rhs === '') {
// Construct RHS from LHS terms/symbols
rhs += [...this.symbolToInfo.entries()]
.filter(([sym, info]) => (info.count === 1 || sym === '...'))
.map(([sym]) => sym)
.join('');
} else {
if (!rhs.match(RegExp(termPattern))) {
throw new Error('Invalid RHS');
}
}
// Compute output dims
const rhsSymbols = rhs.match(RegExp(symbolPattern, 'g'));
rhsSymbols?.forEach((symbol) => {
if (symbol === '...') {
this.outputDims = this.outputDims.concat(this.ellipsisDims);
} else {
const info = this.symbolToInfo.get(symbol);
if (info === undefined) {
throw new Error('Invalid RHS symbol');
}
this.outputDims.push(info.dimValue);
}
});
this.rhs = this.processTerm(rhs, true, this.outputDims);
} // End of EinsumEqation constructor
// Add a symbol to the equation
addSymbol(symbol: string, dimValue: number, inputIndex: number) {
let info = this.symbolToInfo.get(symbol);
if (info !== undefined) {
if (info.dimValue !== dimValue && info.count !== 1) {
throw new Error('Dimension mismatch');
} else {
info.count++;
info.inputIndices.push(inputIndex);
}
} else {
info = {count: 1, dimValue, inputIndices: [inputIndex]};
}
this.symbolToInfo.set(symbol, info);
}
// Process one input/output term
processTerm(term: string, isInput: boolean, dims: readonly number[], index = -1): EinsumTerm {
const rank = dims.length;
let ellipsis = false;
let ellipsisDims = [];
let nextDim = 0;
// For output empty string is allowed because the output may be reduced to a scalar value
if (!term.match(RegExp(termPatternOnly)) && (!isInput && term !== '')) {
throw new Error('Invalid LHS term');
}
const indexSymbols = term.match(RegExp(symbolPattern, 'g'));
const einsumTerm = new EinsumTerm(index);
// symbol can be either a lettre, 'a' to 'z' or 'A' to 'Z', or '...'
indexSymbols?.forEach((symbol: string, i: number) => {
if (symbol === '...') {
if (ellipsis) {
throw new Error('Only one ellipsis is allowed per input term');
}
ellipsis = true;
const ellipsisDimLength = rank - indexSymbols.length + 1;
if (ellipsisDimLength < 0) {
throw new Error('Ellipsis out of bounds');
}
ellipsisDims = dims.slice(nextDim, nextDim + ellipsisDimLength);
if (this.hasEllipsis) {
if (this.ellipsisDims.length !== ellipsisDims.length ||
this.ellipsisDims.toString() !== ellipsisDims.toString()) {
throw new Error('Ellipsis dimensions mismatch');
}
} else if (isInput) {
this.hasEllipsis = true;
this.ellipsisDims = ellipsisDims;
} else {
throw new Error('Ellipsis must be specified in the LHS');
}
// Add '0', '1', '2', '3', '4', etc to represent ellipsis dimensions to avoid special handling
for (let j = 0; j < ellipsisDims.length; j++) {
const symbol = String.fromCharCode('0'.charCodeAt(0) + i);
einsumTerm.addSymbol(symbol, i + j);
this.addSymbol(symbol, dims[nextDim++], index);
}
} else {
einsumTerm.addSymbol(symbol, i);
this.addSymbol(symbol, dims[nextDim++], index);
}
});
return einsumTerm;
}
symbolToInfo: Map<string, SymbolInfo>; // All symbols in the equation
hasEllipsis: boolean; // The equation has ellipsis or not
ellipsisDims: number[]; // The dimensions of the equation ellipsis corresponds to.
lhs: EinsumTerm[]; // Terms on the left-hand side of the equation
rhs: EinsumTerm; // Term on the right-hand side of the equation
outputDims: number[]; // Output dimensions of the equation
} // End of class EinsumEquation
const createEinsumProgramMetadata = (inputCount: number, cacheHint: string): ProgramMetadata =>
({name: 'Einsum', inputTypes: Array(inputCount).fill(GpuDataType.default), cacheHint});
const createEinsumProgramInfo =
(metadata: ProgramMetadata, inputs: readonly TensorView[], einsumEquation: EinsumEquation): ProgramInfo => {
const dataType = inputs[0].dataType;
const inputVars = new Array<IndicesHelper>(inputs.length);
for (let i = 0; i < inputs.length; ++i) {
inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims);
}
const outputShape = einsumEquation.outputDims;
const outputSize = ShapeUtil.size(outputShape);
const output = outputVariable('output', dataType, outputShape);
const idxCopy: string[] = [];
const rhsSymbols = Array.from(einsumEquation.rhs.symbolToIndices.keys());
const initProd = 'var prod = 1.0;';
const initSum = 'var sum = 0.0;';
const updateSum = 'sum += prod;';
const reduceOpsSetIndices: string[] = [];
const reduceOpsLoopHeaders: string[] = [];
const reduceOpsLoopFooters: string[] = [];
const reduceOpCompute: string[] = [];
const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === rhsSymbols.length;
einsumEquation.symbolToInfo.forEach((info, symbol) => {
if (rhsSymbols.includes(symbol)) {
const outputIndex = rhsSymbols.indexOf(symbol);
einsumEquation.lhs.forEach((term, i) => {
if (info.inputIndices.includes(i)) {
const indices = term.symbolToIndices.get(symbol);
if (indices === undefined) {
throw new Error('Invalid symbol error');
}
indices.forEach((index) => {
idxCopy.push(`${
inputVars[i].indicesSet(
`input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`);
});
}
});
} else {
einsumEquation.lhs.forEach((term, i) => {
const info = einsumEquation.symbolToInfo.get(symbol);
if (info === undefined) {
throw new Error('Invalid symbol error');
}
if (info.inputIndices.includes(i)) {
const indices = term.symbolToIndices.get(symbol);
if (indices === undefined) {
throw new Error('Invalid symbol error');
}
indices.forEach((index) => {
reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`);
});
reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`);
}
});
reduceOpsLoopHeaders.push(`for(var ${symbol}: u32 = 0; ${symbol} < ${
einsumEquation.symbolToInfo.get(symbol)?.dimValue}; ${symbol}++) {`);
reduceOpsLoopFooters.push('}');
}
});
const reduceOps = isReduceOpsWithoutLoop ?
[
...idxCopy,
`let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};`
] :
[
...idxCopy,
initSum,
...reduceOpsLoopHeaders,
...reduceOpsSetIndices,
initProd,
...reduceOpCompute,
updateSum,
...reduceOpsLoopFooters,
];
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.declareVariables(...inputVars, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
var outputIndices = ${output.offsetToIndices('global_idx')};
${inputVars.map((inputVar, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')}
${reduceOps.join('\n')};
${output.setByOffset('global_idx', 'sum')};
}`;
return {
...metadata,
outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
getShaderSource,
dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
};
};
const createEinsumProgramInfoLoader =
(inputs: readonly TensorView[], einsumEquation: EinsumEquation, attributes: EinsumAttributes):
ProgramInfoLoader => {
const metadata = createEinsumProgramMetadata(inputs.length, attributes.cacheKey);
return {...metadata, get: () => createEinsumProgramInfo(metadata, inputs, einsumEquation)};
};
export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => {
const einsumEquation = new EinsumEquation(context.inputs, attributes.equation);
context.compute(createEinsumProgramInfoLoader(context.inputs, einsumEquation, attributes));
};
export const parseEinsumAttributes = (attributes: Record<string, unknown>): EinsumAttributes => {
const equation = (attributes.equation as string).replace(/\s+/g, '');
return createAttributeWithCacheKey({equation});
};

View file

@ -0,0 +1,635 @@
[
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "i,i",
"type": "string"
}
],
"cases": [
{
"name": "Dotproduct/scalar product",
"inputs": [
{
"data": [1, 2, 3],
"dims": [3],
"type": "float32"
},
{
"data": [4, 5, 6],
"dims": [3],
"type": "float32"
}
],
"outputs": [
{
"data": [32],
"dims": [],
"type": "float32"
}
]
}
]
},
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "i,i->i",
"type": "string"
}
],
"cases": [
{
"name": "elementwise product",
"inputs": [
{
"data": [1, 2, 3],
"dims": [3],
"type": "float32"
},
{
"data": [4, 5, 6],
"dims": [3],
"type": "float32"
}
],
"outputs": [
{
"data": [4, 10, 18],
"dims": [3],
"type": "float32"
}
]
}
]
},
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "i,j",
"type": "string"
}
],
"cases": [
{
"name": "Product without specifying RSH",
"inputs": [
{
"data": [1, 2, 3],
"dims": [3],
"type": "float32"
},
{
"data": [4, 5, 6],
"dims": [3],
"type": "float32"
}
],
"outputs": [
{
"data": [4, 5, 6, 8, 10, 12, 12, 15, 18],
"dims": [3, 3],
"type": "float32"
}
]
}
]
},
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "i,j->ij",
"type": "string"
}
],
"cases": [
{
"name": "Product",
"inputs": [
{
"data": [1, 2, 3],
"dims": [3],
"type": "float32"
},
{
"data": [4, 5, 6],
"dims": [3],
"type": "float32"
}
],
"outputs": [
{
"data": [4, 5, 6, 8, 10, 12, 12, 15, 18],
"dims": [3, 3],
"type": "float32"
}
]
}
]
},
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "ii,jj",
"type": "string"
}
],
"cases": [
{
"name": "Diagonal elementwise multiplication",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
"dims": [3, 3],
"type": "float32"
},
{
"data": [1, 0, 0, 0, 1, 0, 0, 0, 1],
"dims": [3, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [45],
"dims": [],
"type": "float32"
}
]
}
]
},
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "ii,jj -> ij",
"type": "string"
}
],
"cases": [
{
"name": "Dotproduct",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
"dims": [3, 3],
"type": "float32"
},
{
"data": [1, 0, 0, 0, 1, 0, 0, 0, 1],
"dims": [3, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 1, 1, 5, 5, 5, 9, 9, 9],
"dims": [3, 3],
"type": "float32"
}
]
}
]
},
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "ij,jk->ik",
"type": "string"
}
],
"cases": [
{
"name": "Multiply",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6],
"dims": [2, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
"dims": [3, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [38, 44, 50, 56, 83, 98, 113, 128],
"dims": [2, 4],
"type": "float32"
}
]
}
]
},
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "ij->ji",
"type": "string"
}
],
"cases": [
{
"name": "Transpose",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6],
"dims": [2, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 4, 2, 5, 3, 6],
"dims": [3, 2],
"type": "float32"
}
]
}
]
},
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "ij->i",
"type": "string"
}
],
"cases": [
{
"name": "ReduceSum",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
"dims": [3, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [6, 15, 24],
"dims": [3],
"type": "float32"
}
]
}
]
},
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "ii->i",
"type": "string"
}
],
"cases": [
{
"name": "Diagonal",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
"dims": [3, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 5, 9],
"dims": [3],
"type": "float32"
}
]
}
]
},
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "ij...,jk...->ik...",
"type": "string"
}
],
"cases": [
{
"name": "Multiply with ellipsis - A",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6],
"dims": [2, 3, 1],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
"dims": [3, 4, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [38, 44, 50, 56, 83, 98, 113, 128],
"dims": [2, 4, 1],
"type": "float32"
}
]
}
]
},
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "...ij,...jk->...ik",
"type": "string"
}
],
"cases": [
{
"name": "Multiply with ellipsis - B",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6],
"dims": [1, 2, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
"dims": [1, 3, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [38, 44, 50, 56, 83, 98, 113, 128],
"dims": [1, 2, 4],
"type": "float32"
}
]
}
]
},
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "i...j,j...k->i...k",
"type": "string"
}
],
"cases": [
{
"name": "Multiply with ellipsis - C",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6],
"dims": [2, 1, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
"dims": [3, 1, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [38, 44, 50, 56, 83, 98, 113, 128],
"dims": [2, 1, 4],
"type": "float32"
}
]
}
]
},
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "...ij,jk->...ik",
"type": "string"
}
],
"cases": [
{
"name": "Multiply with ellipsis - D",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6],
"dims": [1, 2, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
"dims": [3, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [38, 44, 50, 56, 83, 98, 113, 128],
"dims": [1, 2, 4],
"type": "float32"
}
]
}
]
},
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "...ij->...ji",
"type": "string"
}
],
"cases": [
{
"name": "Transpose with ellipsis",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6],
"dims": [1, 2, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 4, 2, 5, 3, 6],
"dims": [1, 3, 2],
"type": "float32"
}
]
}
]
},
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "...ij->...i",
"type": "string"
}
],
"cases": [
{
"name": "ReduceSum with ellipsis",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
"dims": [1, 3, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [6, 15, 24],
"dims": [1, 3],
"type": "float32"
}
]
}
]
},
{
"name": "einsum",
"operator": "Einsum",
"opset": {
"domain": "",
"version": 12
},
"attributes": [
{
"name": "equation",
"data": "...ii->...i",
"type": "string"
}
],
"cases": [
{
"name": "Diagonal with ellipsis",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
"dims": [1, 3, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 5, 9],
"dims": [1, 3],
"type": "float32"
}
]
}
]
}
]

View file

@ -319,6 +319,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum);
std::unique_ptr<KernelRegistry> RegisterKernels() {
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();
@ -573,6 +575,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -9,9 +9,7 @@ const void* JsepOutput(void* context, int index, const void* data) {
const uint32_t* data_offset = reinterpret_cast<const uint32_t*>(data);
uint32_t dim = *data_offset++;
size_t dim_size = static_cast<size_t>(dim);
std::vector<int64_t> dims;
dims.reserve(dim_size);
dims.resize(dim_size);
std::vector<int64_t> dims(dim_size);
for (size_t i = 0; i < dim_size; i++) {
dims[i] = static_cast<int64_t>(*data_offset++);
}

View file

@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/js/js_kernel.h"
#include "einsum.h"
namespace onnxruntime {
namespace js {
ONNX_OPERATOR_TYPED_KERNEL_EX(
Einsum,
kOnnxDomain,
12,
float,
kJsExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Einsum);
} // namespace js
} // namespace onnxruntime

View file

@ -0,0 +1,25 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/js/js_kernel.h"
namespace onnxruntime {
namespace js {
class Einsum final : public JsKernel {
public:
Einsum(const OpKernelInfo& info) : JsKernel(info) {
std::string equation;
ORT_ENFORCE(info.GetAttr<std::string>("equation", &equation).IsOK(),
"Missing 'equation' attribute");
JSEP_INIT_KERNEL_ATTRIBUTE(Einsum, ({
"equation" : UTF8ToString($1),
}),
equation.c_str());
}
};
} // namespace js
} // namespace onnxruntime