mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
[Web/JS] Added Expand operator support. (#16577)
### Description Added Expand operator support. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
1b07bbceaa
commit
d41bbac7b9
7 changed files with 158 additions and 2 deletions
|
|
@ -30,6 +30,7 @@ Do not modify directly.*
|
|||
| Elu | ai.onnx(6+) | |
|
||||
| Erf | ai.onnx(9-12,13+) | |
|
||||
| Exp | ai.onnx(6-12,13+) | |
|
||||
| Expand | ai.onnx(8-12,13+) | |
|
||||
| Floor | ai.onnx(6-12,13+) | |
|
||||
| Gemm | ai.onnx(7-8,9-10,11+) | |
|
||||
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import * as binaryOps from './ops/binary-op';
|
|||
import {concat, parseConcatAttributes} from './ops/concat';
|
||||
import {conv, parseConvAttributes} from './ops/conv';
|
||||
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
|
||||
import {expand} from './ops/expand';
|
||||
import {gemm, parseGemmAttributes} from './ops/gemm';
|
||||
import {matMul} from './ops/matmul';
|
||||
import * as pool from './ops/pool';
|
||||
|
|
@ -41,6 +42,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Elu', [unaryOps.elu, unaryOps.parseAlphaAttributes]],
|
||||
['Erf', [unaryOps.erf]],
|
||||
['Exp', [unaryOps.exp]],
|
||||
['Expand', [expand]],
|
||||
['Floor', [unaryOps.floor]],
|
||||
['Gemm', [gemm, parseGemmAttributes]],
|
||||
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
|
||||
|
|
|
|||
105
js/web/lib/wasm/jsep/webgpu/ops/expand.ts
Normal file
105
js/web/lib/wasm/jsep/webgpu/ops/expand.ts
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {TensorView} from '../../tensor';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
|
||||
|
||||
import {createIndicesHelper, ShaderHelper} from './common';
|
||||
|
||||
/**
 * Static metadata for the Expand program: its name and the GPU data type of the
 * single tensor that is forwarded to the shader.
 */
export const expandProgramMetadata = {
  name: 'Expand',
  inputTypes: [
    GpuDataType.default,
  ],
};
|
||||
|
||||
/**
 * Checks that Expand received exactly two inputs and that the requested shape
 * (values of input 1) is broadcast-compatible with the dims of input 0.
 *
 * The two shapes are aligned on their trailing dimensions; every aligned pair
 * must either match or contain a 1.
 *
 * @param inputs the operator inputs: [data, shape].
 * @throws Error when the input count is not 2 or the shapes cannot be broadcast.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (!inputs || inputs.length !== 2) {
    throw new Error('Expand requires 2 input.');
  }

  const [dataTensor, shapeTensor] = inputs;
  const dataDims = dataTensor.dims;
  // Only read the shape values when the shape tensor is non-empty.
  const requestedDims: number[] =
      shapeTensor.dims[0] > 0 ? Array.from(shapeTensor.getBigInt64Array(), Number) : [];

  // Number of trailing dimensions the two shapes have in common.
  const overlap = Math.min(dataDims.length, requestedDims.length);
  const dataStart = dataDims.length - overlap;
  const requestedStart = requestedDims.length - overlap;

  for (let k = 0; k < overlap; ++k) {
    const requested = requestedDims[requestedStart + k];
    const actual = dataDims[dataStart + k];
    const compatible = requested === actual || requested === 1 || actual === 1;
    if (!compatible) {
      throw new Error('Expand requires shape to be broadcastable to input');
    }
  }
};
|
||||
|
||||
/**
 * Merges two shapes aligned on their trailing dimensions.
 *
 * The extra leading dimensions of `shape1` are copied unchanged. For every
 * aligned position the dimension from `shape2` is used unless it equals 1, in
 * which case the matching dimension of `shape1` is taken.
 *
 * @param shape1 the shape with the higher (or equal) rank.
 * @param shape2 the shape with the lower (or equal) rank.
 * @returns a new array holding the merged shape.
 */
const getAdjustedShape = (shape1: readonly number[], shape2: readonly number[]): number[] => {
  const rankGap = shape1.length - shape2.length;
  // Clamp at 0 so a negative gap copies nothing instead of slicing from the end.
  const leadingDims = shape1.slice(0, Math.max(rankGap, 0));
  const alignedDims = shape2.map((dim, i) => (dim === 1 ? shape1[i + rankGap] : dim));
  return [...leadingDims, ...alignedDims];
};
|
||||
|
||||
/**
 * Computes the output shape of Expand from the input dims and the requested
 * shape by handing the higher-rank shape to `getAdjustedShape` first.
 * When both ranks are equal the requested shape is passed first.
 */
const calculateOutputShape = (inputShape: readonly number[], shape: readonly number[]): number[] => {
  if (inputShape.length > shape.length) {
    return getAdjustedShape(inputShape, shape);
  }
  return getAdjustedShape(shape, inputShape);
};
|
||||
|
||||
|
||||
/**
 * Builds the program description (shader source, output descriptor and dispatch
 * size) for one Expand invocation.
 *
 * The generated shader converts each output offset into output indices, maps
 * them onto input indices (an input dimension of size 1 always maps to index 0)
 * and copies the addressed input element.
 *
 * @param metadata program metadata that is spread into the returned info.
 * @param inputs the operator inputs: [data, shape]; the shape values are read
 *     through `getBigInt64Array()`.
 * @returns the program info for the computed output shape.
 */
const createExpandProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => {
  const [dataTensor, shapeTensor] = inputs;
  const inputShape = dataTensor.dims;
  // Only read the shape values when the shape tensor is non-empty.
  const requestedShape: number[] =
      shapeTensor.dims[0] > 0 ? Array.from(shapeTensor.getBigInt64Array(), Number) : [];

  const outputShape = calculateOutputShape(inputShape, requestedShape);
  const outputSize = ShapeUtil.size(outputShape);
  const inputIndicesHelper = createIndicesHelper('input', inputShape);
  const outputIndicesHelper = createIndicesHelper('output', outputShape);
  // Element type of the storage buffers is fixed to f32.
  const dataType = 'f32';
  const inputRank = inputShape.length;
  const outputRank = outputShape.length;

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    // WGSL helper translating output indices to input indices. Input dimensions are
    // aligned with the trailing output dimensions, hence the constant index offset.
    // NOTE(review): a rank-0 input would emit `array<u32, 0>` below - confirm scalar
    // inputs are handled before reaching this point.
    const calculateInputIndexImpl = `
  fn calculateInputIndex(outputIndices: array<u32, ${outputRank}>) -> array<u32,${inputRank}> {
    ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')}
    for (var i = 0; i < ${inputRank}; i++) {
      if (inputShape[i] == 1) {
        inputIndices[i] = 0;
      } else {
        inputIndices[i] = outputIndices[i + ${outputRank - inputRank}];
      }
    }
    return inputIndices;
  }`;

    return `
  const inputShape = array<u32, ${inputRank}>(${inputShape.map(i => `${i}u`).join(',')});
  ${calculateInputIndexImpl};
  @group(0) @binding(0) var<storage, read> input : array<${dataType}>;
  @group(0) @binding(1) var<storage, read_write> output : array<${dataType}>;
  ${outputIndicesHelper.o2iImpl}
  ${inputIndicesHelper.i2oImpl}
  ${shaderHelper.mainStart()}
  ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
  ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')}
  ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')}
  ${outputIndicesHelper.o2iCall('global_idx', 'outputIndices')}
  inputIndices = calculateInputIndex(outputIndices);
  output[global_idx] = input[${inputIndicesHelper.i2oExpression('inputIndices')}];
  }`;
  };

  return {
    ...metadata,
    getShaderSource,
    outputs: [{dims: outputShape, dataType: dataTensor.dataType, gpuDataType: GpuDataType.default}],
    dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)})
  };
};
|
||||
|
||||
/**
 * Entry point of the Expand operator: validates the inputs and schedules the
 * program that writes the broadcast result.
 *
 * Only input 0 (the data tensor) is forwarded to the program; the values of
 * input 1 (the target shape) are read on the JS side and baked into the
 * generated shader. Because of that, the target shape is added to the program
 * metadata as `cacheHint`, so two invocations that share the same data tensor
 * shape but request different target shapes are not treated as the same
 * program by a cache keyed on metadata and forwarded inputs.
 *
 * @param context the compute context providing the inputs and `compute()`.
 * @throws Error when the inputs fail validation (see `validateInputs`).
 */
export const expand = (context: ComputeContext): void => {
  validateInputs(context.inputs);

  const shapeTensor = context.inputs[1];
  // Mirror the guard used when the shape is read elsewhere: skip empty shape tensors.
  const requestedShape: number[] =
      shapeTensor.dims[0] > 0 ? Array.from(shapeTensor.getBigInt64Array(), Number) : [];
  const metadata = {...expandProgramMetadata, cacheHint: requestedShape.toString()};

  context.compute(
      {...metadata, get: () => createExpandProgramInfo(metadata, context.inputs)},
      {inputs: [0]});
};
|
||||
|
|
@ -519,8 +519,8 @@
|
|||
"test_erf",
|
||||
"test_exp_example",
|
||||
"test_exp",
|
||||
// "test_expand_dim_changed",
|
||||
// "test_expand_dim_unchanged",
|
||||
"test_expand_dim_changed",
|
||||
"test_expand_dim_unchanged",
|
||||
// "test_eyelike_populate_off_main_diagonal",
|
||||
// "test_eyelike_with_dtype",
|
||||
// "test_eyelike_without_dtype",
|
||||
|
|
|
|||
|
|
@ -243,6 +243,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Split);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Split);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, Split);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand);
|
||||
|
||||
std::unique_ptr<KernelRegistry> RegisterKernels() {
|
||||
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();
|
||||
|
|
@ -424,6 +426,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Split)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Split)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, Split)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
30
onnxruntime/core/providers/js/operators/expand.cc
Normal file
30
onnxruntime/core/providers/js/operators/expand.cc
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
#include "expand.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
Expand,
|
||||
kOnnxDomain,
|
||||
8,
|
||||
12,
|
||||
kJsExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1),
|
||||
Expand);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Expand,
|
||||
kOnnxDomain,
|
||||
13,
|
||||
kJsExecutionProvider,
|
||||
KernelDefBuilder()
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1),
|
||||
Expand);
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
14
onnxruntime/core/providers/js/operators/expand.h
Normal file
14
onnxruntime/core/providers/js/operators/expand.h
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
JSEP_KERNEL_IMPL(Expand, Expand);
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue