mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
[JS/Web] Added Einsum operator support. (#17401)
### Description Added Einsum operator support to JSEP. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
850baced33
commit
bf6d6961cc
8 changed files with 979 additions and 3 deletions
|
|
@ -30,6 +30,7 @@ Do not modify directly.*
|
|||
| Cos | ai.onnx(7+) | |
|
||||
| Cosh | ai.onnx(9+) | |
|
||||
| Div | ai.onnx(7-12,13,14+) | |
|
||||
| Einsum | ai.onnx(12+) | |
|
||||
| Elu | ai.onnx(6+) | |
|
||||
| Equal | ai.onnx(7-10,11-12,13-18,19+) | |
|
||||
| Erf | ai.onnx(9-12,13+) | |
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import * as binaryOps from './ops/binary-op';
|
|||
import {concat, parseConcatAttributes} from './ops/concat';
|
||||
import {conv, parseConvAttributes} from './ops/conv';
|
||||
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
|
||||
import {einsum, parseEinsumAttributes} from './ops/einsum';
|
||||
import {expand} from './ops/expand';
|
||||
import {gather, parseGatherAttributes} from './ops/gather';
|
||||
import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
|
||||
|
|
@ -52,6 +53,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Cos', [unaryOps.cos]],
|
||||
['Cosh', [unaryOps.cosh]],
|
||||
['Div', [binaryOps.div]],
|
||||
['Einsum', [einsum, parseEinsumAttributes]],
|
||||
['Elu', [unaryOps.elu, unaryOps.parseAlphaAttributes]],
|
||||
['Equal', [binaryOps.equal]],
|
||||
['Erf', [unaryOps.erf]],
|
||||
|
|
|
|||
290
js/web/lib/wasm/jsep/webgpu/ops/einsum.ts
Normal file
290
js/web/lib/wasm/jsep/webgpu/ops/einsum.ts
Normal file
|
|
@ -0,0 +1,290 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {TensorView} from '../../tensor';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
|
||||
|
||||
import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
|
||||
export interface EinsumAttributes extends AttributeWithCacheKey {
|
||||
readonly equation: string;
|
||||
}
|
||||
// The equation attribute value is a string which consists of left hand side (LHS) and optionally right hand side (RHS)
|
||||
// separated by '->'. Ex. "ij,jk -> ik" expresses matrix multiplication
|
||||
// "ij->ji" expresses matrix transpose
|
||||
// "ii->i" diagonal elements of a square matrix
|
||||
// LHS consists of a sequence of terms separated by commas. Each term corresponds to an input variable.
|
||||
// Each symbol corresponds to a dimension in the input variable. The symbol can be either a letter, 'a' to 'z' or 'A' to
|
||||
// 'Z' or '...' to represent arbitrary dimensions.
|
||||
|
||||
const symbolPattern =
|
||||
'[a-zA-Z]|\\.\\.\\.'; // The pattern each symbol in each term in the symbolic equation should match
|
||||
const termPattern = '(' + symbolPattern + ')+'; // The pattern each term in the symbolic equation should match
|
||||
const termPatternOnly = '^' + termPattern + '$'; // The patterns only matchs a term begin to end.
|
||||
const lhsPattern = '(' + termPattern + ',)*' + termPattern; // The pattern the LHS should match
|
||||
const lhsPatternOnly = '^' + lhsPattern + '$'; // The patterns only matchs a LHS begin to end.
|
||||
|
||||
interface SymbolInfo {
|
||||
count: number; // Symbol corresponding to a dimmension of an input
|
||||
inputIndices: number[]; // Number of input variables the symbol corresponds to
|
||||
dimValue: number; // Number of dimensions the symbol corresponds to
|
||||
}
|
||||
|
||||
class EinsumTerm {
|
||||
constructor(inputIndex = -1) {
|
||||
this.symbolToIndices = new Map<string, number[]>();
|
||||
this.inputIndex = inputIndex;
|
||||
}
|
||||
|
||||
// Add a symbol to the term
|
||||
addSymbol(symbol: string, index: number) {
|
||||
let value = this.symbolToIndices.get(symbol);
|
||||
if (value === undefined) {
|
||||
value = [index];
|
||||
} else {
|
||||
value.push(index);
|
||||
}
|
||||
this.symbolToIndices.set(symbol, value);
|
||||
}
|
||||
|
||||
symbolToIndices: Map<string, number[]>; // Map from symbol to dimensions of the input corresponding to the term
|
||||
inputIndex: number; // -1 for output and 0, 1, 2, ... for inputs
|
||||
}
|
||||
|
||||
class EinsumEquation {
|
||||
constructor(inputs: readonly TensorView[], equation: string) {
|
||||
this.hasEllipsis = false;
|
||||
this.symbolToInfo = new Map<string, SymbolInfo>();
|
||||
this.lhs = new Array<EinsumTerm>();
|
||||
this.outputDims = [];
|
||||
// As rhs needs to be updated allow using let instead of const for both lhs and rhs.
|
||||
// eslint-disable-next-line prefer-const
|
||||
let [lhs, rhs] = equation.includes('->') ? equation.split('->', 2) : [equation, ''];
|
||||
if (!lhs.match(RegExp(lhsPatternOnly))) {
|
||||
throw new Error('Invalid LHS term');
|
||||
}
|
||||
const inputTerms = lhs.split(',');
|
||||
inputTerms.forEach((inputTerm, index) => {
|
||||
const dims = inputs[index].dims.slice();
|
||||
if (!inputTerm.match(RegExp(termPatternOnly))) {
|
||||
throw new Error('Invalid LHS term');
|
||||
}
|
||||
const einsumTerm = this.processTerm(inputTerm, true, dims, index);
|
||||
this.lhs.push(einsumTerm);
|
||||
});
|
||||
|
||||
// Initialize the RHS if not specified
|
||||
if (rhs === '') {
|
||||
// Construct RHS from LHS terms/symbols
|
||||
rhs += [...this.symbolToInfo.entries()]
|
||||
.filter(([sym, info]) => (info.count === 1 || sym === '...'))
|
||||
.map(([sym]) => sym)
|
||||
.join('');
|
||||
} else {
|
||||
if (!rhs.match(RegExp(termPattern))) {
|
||||
throw new Error('Invalid RHS');
|
||||
}
|
||||
}
|
||||
|
||||
// Compute output dims
|
||||
const rhsSymbols = rhs.match(RegExp(symbolPattern, 'g'));
|
||||
rhsSymbols?.forEach((symbol) => {
|
||||
if (symbol === '...') {
|
||||
this.outputDims = this.outputDims.concat(this.ellipsisDims);
|
||||
} else {
|
||||
const info = this.symbolToInfo.get(symbol);
|
||||
if (info === undefined) {
|
||||
throw new Error('Invalid RHS symbol');
|
||||
}
|
||||
this.outputDims.push(info.dimValue);
|
||||
}
|
||||
});
|
||||
this.rhs = this.processTerm(rhs, true, this.outputDims);
|
||||
} // End of EinsumEqation constructor
|
||||
|
||||
// Add a symbol to the equation
|
||||
addSymbol(symbol: string, dimValue: number, inputIndex: number) {
|
||||
let info = this.symbolToInfo.get(symbol);
|
||||
if (info !== undefined) {
|
||||
if (info.dimValue !== dimValue && info.count !== 1) {
|
||||
throw new Error('Dimension mismatch');
|
||||
} else {
|
||||
info.count++;
|
||||
info.inputIndices.push(inputIndex);
|
||||
}
|
||||
} else {
|
||||
info = {count: 1, dimValue, inputIndices: [inputIndex]};
|
||||
}
|
||||
this.symbolToInfo.set(symbol, info);
|
||||
}
|
||||
|
||||
// Process one input/output term
|
||||
processTerm(term: string, isInput: boolean, dims: readonly number[], index = -1): EinsumTerm {
|
||||
const rank = dims.length;
|
||||
let ellipsis = false;
|
||||
let ellipsisDims = [];
|
||||
let nextDim = 0;
|
||||
// For output empty string is allowed because the output may be reduced to a scalar value
|
||||
if (!term.match(RegExp(termPatternOnly)) && (!isInput && term !== '')) {
|
||||
throw new Error('Invalid LHS term');
|
||||
}
|
||||
const indexSymbols = term.match(RegExp(symbolPattern, 'g'));
|
||||
const einsumTerm = new EinsumTerm(index);
|
||||
// symbol can be either a lettre, 'a' to 'z' or 'A' to 'Z', or '...'
|
||||
indexSymbols?.forEach((symbol: string, i: number) => {
|
||||
if (symbol === '...') {
|
||||
if (ellipsis) {
|
||||
throw new Error('Only one ellipsis is allowed per input term');
|
||||
}
|
||||
ellipsis = true;
|
||||
const ellipsisDimLength = rank - indexSymbols.length + 1;
|
||||
if (ellipsisDimLength < 0) {
|
||||
throw new Error('Ellipsis out of bounds');
|
||||
}
|
||||
ellipsisDims = dims.slice(nextDim, nextDim + ellipsisDimLength);
|
||||
if (this.hasEllipsis) {
|
||||
if (this.ellipsisDims.length !== ellipsisDims.length ||
|
||||
this.ellipsisDims.toString() !== ellipsisDims.toString()) {
|
||||
throw new Error('Ellipsis dimensions mismatch');
|
||||
}
|
||||
} else if (isInput) {
|
||||
this.hasEllipsis = true;
|
||||
this.ellipsisDims = ellipsisDims;
|
||||
} else {
|
||||
throw new Error('Ellipsis must be specified in the LHS');
|
||||
}
|
||||
// Add '0', '1', '2', '3', '4', etc to represent ellipsis dimensions to avoid special handling
|
||||
for (let j = 0; j < ellipsisDims.length; j++) {
|
||||
const symbol = String.fromCharCode('0'.charCodeAt(0) + i);
|
||||
einsumTerm.addSymbol(symbol, i + j);
|
||||
this.addSymbol(symbol, dims[nextDim++], index);
|
||||
}
|
||||
} else {
|
||||
einsumTerm.addSymbol(symbol, i);
|
||||
this.addSymbol(symbol, dims[nextDim++], index);
|
||||
}
|
||||
});
|
||||
return einsumTerm;
|
||||
}
|
||||
|
||||
symbolToInfo: Map<string, SymbolInfo>; // All symbols in the equation
|
||||
hasEllipsis: boolean; // The equation has ellipsis or not
|
||||
ellipsisDims: number[]; // The dimensions of the equation ellipsis corresponds to.
|
||||
lhs: EinsumTerm[]; // Terms on the left-hand side of the equation
|
||||
rhs: EinsumTerm; // Term on the right-hand side of the equation
|
||||
outputDims: number[]; // Output dimensions of the equation
|
||||
} // End of class EinsumEquation
|
||||
|
||||
|
||||
const createEinsumProgramMetadata = (inputCount: number, cacheHint: string): ProgramMetadata =>
|
||||
({name: 'Einsum', inputTypes: Array(inputCount).fill(GpuDataType.default), cacheHint});
|
||||
|
||||
const createEinsumProgramInfo =
|
||||
(metadata: ProgramMetadata, inputs: readonly TensorView[], einsumEquation: EinsumEquation): ProgramInfo => {
|
||||
const dataType = inputs[0].dataType;
|
||||
const inputVars = new Array<IndicesHelper>(inputs.length);
|
||||
for (let i = 0; i < inputs.length; ++i) {
|
||||
inputVars[i] = inputVariable(`input${i}`, dataType, inputs[i].dims);
|
||||
}
|
||||
const outputShape = einsumEquation.outputDims;
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
const output = outputVariable('output', dataType, outputShape);
|
||||
const idxCopy: string[] = [];
|
||||
const rhsSymbols = Array.from(einsumEquation.rhs.symbolToIndices.keys());
|
||||
const initProd = 'var prod = 1.0;';
|
||||
const initSum = 'var sum = 0.0;';
|
||||
const updateSum = 'sum += prod;';
|
||||
const reduceOpsSetIndices: string[] = [];
|
||||
const reduceOpsLoopHeaders: string[] = [];
|
||||
const reduceOpsLoopFooters: string[] = [];
|
||||
const reduceOpCompute: string[] = [];
|
||||
const isReduceOpsWithoutLoop = einsumEquation.symbolToInfo.size === rhsSymbols.length;
|
||||
einsumEquation.symbolToInfo.forEach((info, symbol) => {
|
||||
if (rhsSymbols.includes(symbol)) {
|
||||
const outputIndex = rhsSymbols.indexOf(symbol);
|
||||
einsumEquation.lhs.forEach((term, i) => {
|
||||
if (info.inputIndices.includes(i)) {
|
||||
const indices = term.symbolToIndices.get(symbol);
|
||||
if (indices === undefined) {
|
||||
throw new Error('Invalid symbol error');
|
||||
}
|
||||
indices.forEach((index) => {
|
||||
idxCopy.push(`${
|
||||
inputVars[i].indicesSet(
|
||||
`input${i}Indices`, index, output.indicesGet('outputIndices', outputIndex))}`);
|
||||
});
|
||||
}
|
||||
});
|
||||
} else {
|
||||
einsumEquation.lhs.forEach((term, i) => {
|
||||
const info = einsumEquation.symbolToInfo.get(symbol);
|
||||
if (info === undefined) {
|
||||
throw new Error('Invalid symbol error');
|
||||
}
|
||||
if (info.inputIndices.includes(i)) {
|
||||
const indices = term.symbolToIndices.get(symbol);
|
||||
if (indices === undefined) {
|
||||
throw new Error('Invalid symbol error');
|
||||
}
|
||||
indices.forEach((index) => {
|
||||
reduceOpsSetIndices.push(`${inputVars[i].indicesSet(`input${i}Indices`, index, `${symbol}`)}`);
|
||||
});
|
||||
reduceOpCompute.push(`prod *= ${inputVars[i].getByIndices(`input${i}Indices`)};`);
|
||||
}
|
||||
});
|
||||
reduceOpsLoopHeaders.push(`for(var ${symbol}: u32 = 0; ${symbol} < ${
|
||||
einsumEquation.symbolToInfo.get(symbol)?.dimValue}; ${symbol}++) {`);
|
||||
reduceOpsLoopFooters.push('}');
|
||||
}
|
||||
});
|
||||
const reduceOps = isReduceOpsWithoutLoop ?
|
||||
[
|
||||
...idxCopy,
|
||||
`let sum = ${inputVars.map((inputVar, i) => inputVar.getByIndices(`input${i}Indices`)).join(' * ')};`
|
||||
] :
|
||||
[
|
||||
...idxCopy,
|
||||
initSum,
|
||||
...reduceOpsLoopHeaders,
|
||||
...reduceOpsSetIndices,
|
||||
initProd,
|
||||
...reduceOpCompute,
|
||||
updateSum,
|
||||
...reduceOpsLoopFooters,
|
||||
];
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
${shaderHelper.declareVariables(...inputVars, output)}
|
||||
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
var outputIndices = ${output.offsetToIndices('global_idx')};
|
||||
${inputVars.map((inputVar, i) => `var input${i}Indices: ${inputVars[i].type.indices};`).join('\n')}
|
||||
${reduceOps.join('\n')};
|
||||
${output.setByOffset('global_idx', 'sum')};
|
||||
}`;
|
||||
return {
|
||||
...metadata,
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
|
||||
getShaderSource,
|
||||
dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
|
||||
};
|
||||
};
|
||||
|
||||
const createEinsumProgramInfoLoader =
|
||||
(inputs: readonly TensorView[], einsumEquation: EinsumEquation, attributes: EinsumAttributes):
|
||||
ProgramInfoLoader => {
|
||||
const metadata = createEinsumProgramMetadata(inputs.length, attributes.cacheKey);
|
||||
return {...metadata, get: () => createEinsumProgramInfo(metadata, inputs, einsumEquation)};
|
||||
};
|
||||
|
||||
export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => {
|
||||
const einsumEquation = new EinsumEquation(context.inputs, attributes.equation);
|
||||
context.compute(createEinsumProgramInfoLoader(context.inputs, einsumEquation, attributes));
|
||||
};
|
||||
|
||||
export const parseEinsumAttributes = (attributes: Record<string, unknown>): EinsumAttributes => {
|
||||
const equation = (attributes.equation as string).replace(/\s+/g, '');
|
||||
return createAttributeWithCacheKey({equation});
|
||||
};
|
||||
635
js/web/test/data/ops/einsum.jsonc
Normal file
635
js/web/test/data/ops/einsum.jsonc
Normal file
|
|
@ -0,0 +1,635 @@
|
|||
[
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "i,i",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "Dotproduct/scalar product",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [4, 5, 6],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [32],
|
||||
"dims": [],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "i,i->i",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "elementwise product",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [4, 5, 6],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [4, 10, 18],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "i,j",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "Product without specifying RSH",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [4, 5, 6],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [4, 5, 6, 8, 10, 12, 12, 15, 18],
|
||||
"dims": [3, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "i,j->ij",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "Product",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [4, 5, 6],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [4, 5, 6, 8, 10, 12, 12, 15, 18],
|
||||
"dims": [3, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "ii,jj",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "Diagonal elementwise multiplication",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
|
||||
"dims": [3, 3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 0, 0, 0, 1, 0, 0, 0, 1],
|
||||
"dims": [3, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [45],
|
||||
"dims": [],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "ii,jj -> ij",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "Dotproduct",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
|
||||
"dims": [3, 3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 0, 0, 0, 1, 0, 0, 0, 1],
|
||||
"dims": [3, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 1, 1, 5, 5, 5, 9, 9, 9],
|
||||
"dims": [3, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "ij,jk->ik",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "Multiply",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6],
|
||||
"dims": [2, 3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
"dims": [3, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [38, 44, 50, 56, 83, 98, 113, 128],
|
||||
"dims": [2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "ij->ji",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "Transpose",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6],
|
||||
"dims": [2, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 4, 2, 5, 3, 6],
|
||||
"dims": [3, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "ij->i",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "ReduceSum",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
|
||||
"dims": [3, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [6, 15, 24],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "ii->i",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "Diagonal",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
|
||||
"dims": [3, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 5, 9],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "ij...,jk...->ik...",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "Multiply with ellipsis - A",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6],
|
||||
"dims": [2, 3, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
"dims": [3, 4, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [38, 44, 50, 56, 83, 98, 113, 128],
|
||||
"dims": [2, 4, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "...ij,...jk->...ik",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "Multiply with ellipsis - B",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6],
|
||||
"dims": [1, 2, 3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
"dims": [1, 3, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [38, 44, 50, 56, 83, 98, 113, 128],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "i...j,j...k->i...k",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "Multiply with ellipsis - C",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6],
|
||||
"dims": [2, 1, 3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
"dims": [3, 1, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [38, 44, 50, 56, 83, 98, 113, 128],
|
||||
"dims": [2, 1, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "...ij,jk->...ik",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "Multiply with ellipsis - D",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6],
|
||||
"dims": [1, 2, 3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
|
||||
"dims": [3, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [38, 44, 50, 56, 83, 98, 113, 128],
|
||||
"dims": [1, 2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "...ij->...ji",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "Transpose with ellipsis",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6],
|
||||
"dims": [1, 2, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 4, 2, 5, 3, 6],
|
||||
"dims": [1, 3, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "...ij->...i",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "ReduceSum with ellipsis",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
|
||||
"dims": [1, 3, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [6, 15, 24],
|
||||
"dims": [1, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "einsum",
|
||||
"operator": "Einsum",
|
||||
"opset": {
|
||||
"domain": "",
|
||||
"version": 12
|
||||
},
|
||||
"attributes": [
|
||||
{
|
||||
"name": "equation",
|
||||
"data": "...ii->...i",
|
||||
"type": "string"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "Diagonal with ellipsis",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
|
||||
"dims": [1, 3, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1, 5, 9],
|
||||
"dims": [1, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
@ -319,6 +319,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum);
|
||||
|
||||
std::unique_ptr<KernelRegistry> RegisterKernels() {
|
||||
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();
|
||||
|
||||
|
|
@ -573,6 +575,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum)>,
|
||||
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -9,9 +9,7 @@ const void* JsepOutput(void* context, int index, const void* data) {
|
|||
const uint32_t* data_offset = reinterpret_cast<const uint32_t*>(data);
|
||||
uint32_t dim = *data_offset++;
|
||||
size_t dim_size = static_cast<size_t>(dim);
|
||||
std::vector<int64_t> dims;
|
||||
dims.reserve(dim_size);
|
||||
dims.resize(dim_size);
|
||||
std::vector<int64_t> dims(dim_size);
|
||||
for (size_t i = 0; i < dim_size; i++) {
|
||||
dims[i] = static_cast<int64_t>(*data_offset++);
|
||||
}
|
||||
|
|
|
|||
21
onnxruntime/core/providers/js/operators/einsum.cc
Normal file
21
onnxruntime/core/providers/js/operators/einsum.cc
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
#include "einsum.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX(
|
||||
Einsum,
|
||||
kOnnxDomain,
|
||||
12,
|
||||
float,
|
||||
kJsExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Einsum);
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
25
onnxruntime/core/providers/js/operators/einsum.h
Normal file
25
onnxruntime/core/providers/js/operators/einsum.h
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
class Einsum final : public JsKernel {
|
||||
public:
|
||||
Einsum(const OpKernelInfo& info) : JsKernel(info) {
|
||||
std::string equation;
|
||||
ORT_ENFORCE(info.GetAttr<std::string>("equation", &equation).IsOK(),
|
||||
"Missing 'equation' attribute");
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(Einsum, ({
|
||||
"equation" : UTF8ToString($1),
|
||||
}),
|
||||
equation.c_str());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue