diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 87b13cba8f..f6b82b4ac7 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -30,6 +30,7 @@ Do not modify directly.* | Elu | ai.onnx(6+) | | | Erf | ai.onnx(9-12,13+) | | | Exp | ai.onnx(6-12,13+) | | +| Expand | ai.onnx(8-12,13+) | | | Floor | ai.onnx(6-12,13+) | | | Gemm | ai.onnx(7-8,9-10,11+) | | | GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 7a6d2927eb..6a43cca4e1 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -5,6 +5,7 @@ import * as binaryOps from './ops/binary-op'; import {concat, parseConcatAttributes} from './ops/concat'; import {conv, parseConvAttributes} from './ops/conv'; import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose'; +import {expand} from './ops/expand'; import {gemm, parseGemmAttributes} from './ops/gemm'; import {matMul} from './ops/matmul'; import * as pool from './ops/pool'; @@ -41,6 +42,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Elu', [unaryOps.elu, unaryOps.parseAlphaAttributes]], ['Erf', [unaryOps.erf]], ['Exp', [unaryOps.exp]], + ['Expand', [expand]], ['Floor', [unaryOps.floor]], ['Gemm', [gemm, parseGemmAttributes]], ['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts new file mode 100644 index 0000000000..e8f213fe88 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; + +import {createIndicesHelper, ShaderHelper} from './common'; + +export const expandProgramMetadata = { + name: 'Expand', + inputTypes: [GpuDataType.default] +}; + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length !== 2) { + throw new Error('Expand requires 2 input.'); + } + const inputShape = inputs[0].dims; + + const shape: number[] = []; + if (inputs[1].dims[0] > 0) { + inputs[1].getBigInt64Array().forEach(v => shape.push(Number(v))); + } + let shapeIndex = shape.length < inputShape.length ? 0 : shape.length - inputShape.length; + let inputShapeIndex = inputShape.length < shape.length ? 0 : inputShape.length - shape.length; + for (; shapeIndex < shape.length && inputShapeIndex < inputShape.length; ++shapeIndex, ++inputShapeIndex) { + if (shape[shapeIndex] !== inputShape[inputShapeIndex] && shape[shapeIndex] !== 1 && + inputShape[inputShapeIndex] !== 1) { + throw new Error('Expand requires shape to be broadcastable to input'); + } + } +}; + +const getAdjustedShape = (shape1: readonly number[], shape2: readonly number[]): number[] => { + const diff = shape1.length - shape2.length; + const shape: number[] = []; + for (let i = 0; i < diff; ++i) { + shape.push(shape1[i]); + } + for (let i = 0; i < shape2.length; ++i) { + shape.push(shape2[i] === 1 ? shape1[i + diff] : shape2[i]); + } + return shape; +}; + +const calculateOutputShape = (inputShape: readonly number[], shape: readonly number[]): number[] => + (inputShape.length > shape.length) ? getAdjustedShape(inputShape, shape) : getAdjustedShape(shape, inputShape); + + +const createExpandProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => { + const inputShape = inputs[0].dims; + + const shape: number[] = []; + if (inputs[1].dims[0] > 0) { + inputs[1].getBigInt64Array().forEach(v => shape.push(Number(v))); + } + const outputShape: number[] = calculateOutputShape(inputShape, shape); + const outputSize = ShapeUtil.size(outputShape); + const inputIndicesHelper = createIndicesHelper('input', inputShape); + const outputIndicesHelper = createIndicesHelper('output', outputShape); + const dataType = 'f32'; + + const calculateInputIndexImpl = (): string => ` + fn calculateInputIndex(outputIndices: array) -> array { + ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')} + for (var i = 0; i < ${inputShape.length}; i++) { + if (inputShape[i] == 1) { + inputIndices[i] = 0; + } else { + inputIndices[i] = outputIndices[i + ${outputShape.length - inputShape.length}]; + } + } + return inputIndices; +}`; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); + ${calculateInputIndexImpl()}; + @group(0) @binding(0) var input : array<${dataType}>; + @group(0) @binding(1) var output : array<${dataType}>; + ${outputIndicesHelper.o2iImpl} + ${inputIndicesHelper.i2oImpl} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + ${inputIndicesHelper.indicesVariableDeclaration('inputIndices')} + ${outputIndicesHelper.indicesVariableDeclaration('outputIndices')} + ${outputIndicesHelper.o2iCall('global_idx', 'outputIndices')} + inputIndices = calculateInputIndex(outputIndices); + output[global_idx] = input[${inputIndicesHelper.i2oExpression('inputIndices')}]; +}`; + return { + ...metadata, + getShaderSource, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; +}; + +export const expand = (context: ComputeContext): void => { + validateInputs(context.inputs); + context.compute( + {...expandProgramMetadata, get: () => createExpandProgramInfo(expandProgramMetadata, context.inputs)}, + {inputs: [0]}); +}; diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index d45e461994..e5ca91e2f1 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -519,8 +519,8 @@ "test_erf", "test_exp_example", "test_exp", - // "test_expand_dim_changed", - // "test_expand_dim_unchanged", + "test_expand_dim_changed", + "test_expand_dim_unchanged", // "test_eyelike_populate_off_main_diagonal", // "test_eyelike_with_dtype", // "test_eyelike_without_dtype", diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 112501e193..9a0c2a0c64 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -243,6 +243,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Split); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Split); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 12, Expand); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Expand); std::unique_ptr RegisterKernels() { auto kernel_registry = std::make_unique(); @@ -424,6 +426,8 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/js/operators/expand.cc b/onnxruntime/core/providers/js/operators/expand.cc new file mode 100644 index 0000000000..61d6511a37 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/expand.cc @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" +#include "expand.h" + +namespace onnxruntime { +namespace js { +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Expand, + kOnnxDomain, + 8, + 12, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1), + Expand); + +ONNX_OPERATOR_KERNEL_EX( + Expand, + kOnnxDomain, + 13, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1), + Expand); +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/expand.h b/onnxruntime/core/providers/js/operators/expand.h new file mode 100644 index 0000000000..b259fe80d1 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/expand.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +JSEP_KERNEL_IMPL(Expand, Expand); + +} // namespace js +} // namespace onnxruntime