diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 7fb03487a2..57121edbde 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -93,6 +93,11 @@ file(GLOB_RECURSE onnxruntime_rocm_contrib_ops_cu_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/contrib_ops/rocm/*.cuh" ) +file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/contrib_ops/js/*.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/js/*.cc" +) + file(GLOB onnxruntime_providers_common_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/*.h" "${ONNXRUNTIME_ROOT}/core/providers/*.cc" @@ -1158,8 +1163,12 @@ if (onnxruntime_USE_JSEP) "${ONNXRUNTIME_ROOT}/core/providers/js/*.h" "${ONNXRUNTIME_ROOT}/core/providers/js/*.cc" ) + if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_js_contrib_ops_cc_srcs}) + list(APPEND onnxruntime_providers_js_cc_srcs ${onnxruntime_js_contrib_ops_cc_srcs}) + endif() - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_js_cc_srcs}) + source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_providers_js_cc_srcs}) onnxruntime_add_static_library(onnxruntime_providers_js ${onnxruntime_providers_js_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_js onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers Boost::mp11 diff --git a/include/onnxruntime/core/framework/op_kernel.h b/include/onnxruntime/core/framework/op_kernel.h index 4168c23614..36e931c2df 100644 --- a/include/onnxruntime/core/framework/op_kernel.h +++ b/include/onnxruntime/core/framework/op_kernel.h @@ -11,16 +11,17 @@ #ifndef SHARED_PROVIDER #include + #include "core/common/exceptions.h" #include "core/common/logging/logging.h" #include "core/common/status.h" #include "core/framework/execution_provider.h" #include "core/framework/kernel_def_builder.h" -#include "core/framework/ort_value.h" #include 
"core/framework/op_kernel_info.h" #include "core/framework/op_node_proto_helper.h" -#include "core/framework/tensor.h" +#include "core/framework/ort_value.h" #include "core/framework/sparse_tensor.h" +#include "core/framework/tensor.h" #include "core/graph/constants.h" #include "core/graph/graph_viewer.h" #if !defined(ORT_MINIMAL_BUILD) @@ -28,9 +29,9 @@ #else #include "onnx/defs/data_type_utils.h" #endif -#include "onnx/onnx_pb.h" -#include "onnx/onnx-operators_pb.h" #include "core/common/gsl.h" +#include "onnx/onnx-operators_pb.h" +#include "onnx/onnx_pb.h" namespace onnxruntime { class OpKernelContext; } @@ -187,6 +188,13 @@ KernelCreateInfo BuildKernelCreateInfo(); } // namespace cuda } // namespace contrib +namespace contrib { +namespace js { +template +KernelCreateInfo BuildKernelCreateInfo(); +} // namespace js +} // namespace contrib + namespace contrib { namespace rocm { template diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 722d8d0421..b0cf2ccb3b 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -33,6 +33,7 @@ Do not modify directly.* | Expand | ai.onnx(8-12,13+) | | | Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | +| Gelu | com.microsoft(1+) | | | Gemm | ai.onnx(7-8,9-10,11+) | | | GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index c4bb5cf92a..164436f497 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -6,6 +6,7 @@ import {concat, parseConcatAttributes} from './ops/concat'; import {conv, parseConvAttributes} from './ops/conv'; import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose'; import {expand} from './ops/expand'; +import {gelu} from './ops/gelu'; import {gemm, 
parseGemmAttributes} from './ops/gemm'; import {matMul} from './ops/matmul'; import * as pool from './ops/pool'; @@ -45,6 +46,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Exp', [unaryOps.exp]], ['Expand', [expand]], ['Floor', [unaryOps.floor]], + ['Gelu', [gelu]], ['Gemm', [gemm, parseGemmAttributes]], ['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]], ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts index 5345367ead..dd4f13e76e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts @@ -19,7 +19,7 @@ // // modified to fit the needs of the project -export declare type Activation = 'linear' | 'relu' | 'prelu' | 'elu' | 'relu6' | 'leakyrelu' | 'sigmoid'; +export declare type Activation = 'linear' | 'relu' | 'prelu' | 'elu' | 'relu6' | 'leakyrelu' | 'sigmoid' | 'gelu'; export const typeSnippet = (component: number) => { switch (component) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/gelu.ts new file mode 100644 index 0000000000..7443f665f7 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/gelu.ts @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {ShaderHelper} from './common'; +import {erfImpl} from './unary-op'; + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length !== 1) { + throw new Error('Gelu requires 1 input'); + } + if (inputs[0].dataType !== DataType.float) { + throw new Error('Input must be float'); + } +}; + + +const createGeluProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => { + const inputShape = inputs[0].dims; + const outputShape = inputShape.slice(0); // element-wise op: output dims are a copy of the input dims + const outputSize = ShapeUtil.size(outputShape); + const dataType = 'f32'; + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${erfImpl('f32')}; + @group(0) @binding(0) var input: array<${dataType}>; + @group(0) @binding(1) var output: array<${dataType}>; + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + var x = input[global_id.x]; + output[global_id.x] = 0.5 * x * (1.0 + erf_vf32(x * 0.7071067811865475)); + }`; + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + getShaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64)}) // NOTE(review): 64 presumably matches the workgroup size emitted by shaderHelper.mainStart() - confirm + }; +}; + +const createGeluProgramInfoLoader = (inputs: readonly TensorView[]): ProgramInfoLoader => { + const metadata: ProgramMetadata = {name: 'Gelu', inputTypes: [GpuDataType.default]}; + return {...metadata, get: () => createGeluProgramInfo(metadata, inputs)}; +}; + +export const gelu = (context: ComputeContext): void => { + validateInputs(context.inputs); + // Gelu(x) = 0.5 * x * (1 + erf(x * 0.7071067811865475)); erf is evaluated inside the shader via the WGSL source returned by erfImpl. + context.compute(createGeluProgramInfoLoader(context.inputs)); +}; diff --git 
a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 40d5d848d4..fa914cc78c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -145,20 +145,23 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void attributes.cacheKey)); }; -export const erf = (context: ComputeContext): void => { - context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, ` - const r0: f32 = 0.3275911; - const r1: f32 = 0.254829592; - const r2: f32 = -0.284496736; - const r3: f32 = 1.421413741; - const r4: f32 = -1.453152027; - const r5: f32 = 1.061405429; +export const erfImpl = (dataType: string) => ` +const r0: f32 = 0.3275911; +const r1: f32 = 0.254829592; +const r2: f32 = -0.284496736; +const r3: f32 = 1.421413741; +const r4: f32 = -1.453152027; +const r5: f32 = 1.061405429; - fn erf_vf32(v: vec4) -> vec4 { - let absv = abs(v); - let x = 1.0 / (1.0 + r0 * absv); - return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); - }`)); +fn erf_vf32(v: ${dataType}) -> ${dataType} { + let absv = abs(v); + let x = 1.0 / (1.0 + r0 * absv); + return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); +}`; + +export const erf = (context: ComputeContext): void => { + context.compute( + createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl('vec4'))); }; export const exp = (context: ComputeContext): void => { diff --git a/js/web/script/generate-webgpu-operator-md.ts b/js/web/script/generate-webgpu-operator-md.ts index cc4699398d..dedb1bd536 100644 --- a/js/web/script/generate-webgpu-operator-md.ts +++ b/js/web/script/generate-webgpu-operator-md.ts @@ -33,7 +33,8 @@ const ALL_REGISTERED_OPERATORS: Map < string, { // parse js_execution_provider.cc const JS_EXECUTION_PROVIDER_CONTENTS = - fs.readFileSync(path.join(__dirname, 
'../../../onnxruntime/core/providers/js/js_execution_provider.cc'), 'utf8'); + fs.readFileSync(path.join(__dirname, '../../../onnxruntime/core/providers/js/js_execution_provider.cc'), 'utf8') + + fs.readFileSync(path.join(__dirname, '../../../onnxruntime/contrib_ops/js/js_contrib_kernels.cc'), 'utf8'); MATCHERS.forEach(m => { for (const match of JS_EXECUTION_PROVIDER_CONTENTS.matchAll(m)) { const groups = match.groups!; @@ -50,6 +51,9 @@ MATCHERS.forEach(m => { case 'kMSInternalNHWCDomain': domain = 'com.ms.internal.nhwc'; break; + case 'kMSDomain': + domain = 'com.microsoft'; + break; default: throw new Error(`not supported domain: ${opsetDomain}`); } diff --git a/js/web/test/data/ops/gelu.jsonc b/js/web/test/data/ops/gelu.jsonc new file mode 100644 index 0000000000..79e4335c2d --- /dev/null +++ b/js/web/test/data/ops/gelu.jsonc @@ -0,0 +1,45 @@ +[ + { + "name": "gelu", + "operator": "Gelu", + "opsets": [{ "domain": "com.microsoft", "version": 1 }], + "attributes": [], + "cases": [ + // Expected outputs follow Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), rounded to float32 precision. + { + "name": "T[0]", + "inputs": [ + { + "data": [1.0, -2.0, 0, 2.0], + "dims": [2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.8413447, -0.04550026, 0, 1.9544997], + "dims": [2, 2], + "type": "float32" + } + ] + }, + { + "name": "Scalar", + "inputs": [ + { + "data": [1.0], + "dims": [], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.8413447], + "dims": [], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index c7f09a2768..643a2f1b5a 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1347,6 +1347,7 @@ "leaky-relu.jsonc", "reduce-min.jsonc", "relu.jsonc", + "gelu.jsonc", //"pad.jsonc", //"pad-big.jsonc", "pow.jsonc", diff --git a/onnxruntime/contrib_ops/js/gelu.cc b/onnxruntime/contrib_ops/js/gelu.cc new file mode 100644 index 0000000000..57de4e21a2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/gelu.cc @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +#include "gelu.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +ONNX_OPERATOR_KERNEL_EX( + Gelu, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + Gelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/gelu.h b/onnxruntime/contrib_ops/js/gelu.h new file mode 100644 index 0000000000..ca2677ec0e --- /dev/null +++ b/onnxruntime/contrib_ops/js/gelu.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; +JSEP_KERNEL_IMPL(Gelu, Gelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc new file mode 100644 index 0000000000..3001aae8a9 --- /dev/null +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "contrib_ops/js/js_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); + +template <> +KernelCreateInfo BuildKernelCreateInfo() { + KernelCreateInfo info; + return info; +} + +Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { + static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, + }; + + for (auto& function_table_entry : function_table) { + KernelCreateInfo info = function_table_entry(); + if (info.kernel_def != nullptr) { // filter disabled entries where type is void + ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info))); + } + } + return Status::OK(); +} + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.h b/onnxruntime/contrib_ops/js/js_contrib_kernels.h new file mode 100644 index 0000000000..273065dbde --- /dev/null +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/framework/kernel_registry.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +Status RegisterJsContribKernels(KernelRegistry& kernel_registry); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 6f0a0be780..09a12b178c 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -8,6 +8,10 @@ #include "js_execution_provider.h" +#ifndef DISABLE_CONTRIB_OPS +#include "contrib_ops/js/js_contrib_kernels.h" +#endif + #include "core/graph/function_utils.h" #include "core/framework/compute_capability.h" #include "core/framework/data_transfer_manager.h" @@ -485,6 +489,15 @@ std::vector> JsExecutionProvider::GetCapabili std::shared_ptr JsExecutionProvider::GetKernelRegistry() const { - static std::shared_ptr registry = js::RegisterKernels(); + // Build the registry once: contrib kernels are registered inside the static initializer, + // so repeated calls return the same registry without registering those kernels again. + static std::shared_ptr registry = []() { + auto kernel_registry = js::RegisterKernels(); +#ifndef DISABLE_CONTRIB_OPS + Status status = ::onnxruntime::contrib::js::RegisterJsContribKernels(*kernel_registry); + ORT_ENFORCE(status.IsOK(), "Failed to register JS contrib kernels: " + status.ErrorMessage()); +#endif + return kernel_registry; + }(); return registry; }