mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
[JS/Web] Added Gelu contrib operator support to JSEP (#16909)
### Description Added Gelu operator to JSEP ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
92b6e10d37
commit
dd24d52737
15 changed files with 242 additions and 20 deletions
|
|
@ -93,6 +93,11 @@ file(GLOB_RECURSE onnxruntime_rocm_contrib_ops_cu_srcs CONFIGURE_DEPENDS
|
|||
"${ONNXRUNTIME_ROOT}/contrib_ops/rocm/*.cuh"
|
||||
)
|
||||
|
||||
file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/contrib_ops/js/*.h"
|
||||
"${ONNXRUNTIME_ROOT}/contrib_ops/js/*.cc"
|
||||
)
|
||||
|
||||
file(GLOB onnxruntime_providers_common_srcs CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/*.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/*.cc"
|
||||
|
|
@ -1158,8 +1163,12 @@ if (onnxruntime_USE_JSEP)
|
|||
"${ONNXRUNTIME_ROOT}/core/providers/js/*.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/js/*.cc"
|
||||
)
|
||||
if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
|
||||
source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_js_contrib_ops_cc_srcs})
|
||||
list(APPEND onnxruntime_providers_js_cc_srcs ${onnxruntime_js_contrib_ops_cc_srcs})
|
||||
endif()
|
||||
|
||||
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_js_cc_srcs})
|
||||
source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_providers_js_cc_srcs})
|
||||
onnxruntime_add_static_library(onnxruntime_providers_js ${onnxruntime_providers_js_cc_srcs})
|
||||
onnxruntime_add_include_to_target(onnxruntime_providers_js
|
||||
onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers Boost::mp11
|
||||
|
|
|
|||
|
|
@ -11,16 +11,17 @@
|
|||
|
||||
#ifndef SHARED_PROVIDER
|
||||
#include <functional>
|
||||
|
||||
#include "core/common/exceptions.h"
|
||||
#include "core/common/logging/logging.h"
|
||||
#include "core/common/status.h"
|
||||
#include "core/framework/execution_provider.h"
|
||||
#include "core/framework/kernel_def_builder.h"
|
||||
#include "core/framework/ort_value.h"
|
||||
#include "core/framework/op_kernel_info.h"
|
||||
#include "core/framework/op_node_proto_helper.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/framework/ort_value.h"
|
||||
#include "core/framework/sparse_tensor.h"
|
||||
#include "core/framework/tensor.h"
|
||||
#include "core/graph/constants.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
|
@ -28,9 +29,9 @@
|
|||
#else
|
||||
#include "onnx/defs/data_type_utils.h"
|
||||
#endif
|
||||
#include "onnx/onnx_pb.h"
|
||||
#include "onnx/onnx-operators_pb.h"
|
||||
#include "core/common/gsl.h"
|
||||
#include "onnx/onnx-operators_pb.h"
|
||||
#include "onnx/onnx_pb.h"
|
||||
namespace onnxruntime {
|
||||
class OpKernelContext;
|
||||
}
|
||||
|
|
@ -187,6 +188,13 @@ KernelCreateInfo BuildKernelCreateInfo();
|
|||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
|
||||
namespace contrib {
|
||||
namespace js {
|
||||
template <typename T>
|
||||
KernelCreateInfo BuildKernelCreateInfo();
|
||||
} // namespace js
|
||||
} // namespace contrib
|
||||
|
||||
namespace contrib {
|
||||
namespace rocm {
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ Do not modify directly.*
|
|||
| Expand | ai.onnx(8-12,13+) | |
|
||||
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
|
||||
| Floor | ai.onnx(6-12,13+) | |
|
||||
| Gelu | com.microsoft(1+) | |
|
||||
| Gemm | ai.onnx(7-8,9-10,11+) | |
|
||||
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
|
||||
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import {concat, parseConcatAttributes} from './ops/concat';
|
|||
import {conv, parseConvAttributes} from './ops/conv';
|
||||
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
|
||||
import {expand} from './ops/expand';
|
||||
import {gelu} from './ops/gelu';
|
||||
import {gemm, parseGemmAttributes} from './ops/gemm';
|
||||
import {matMul} from './ops/matmul';
|
||||
import * as pool from './ops/pool';
|
||||
|
|
@ -45,6 +46,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Exp', [unaryOps.exp]],
|
||||
['Expand', [expand]],
|
||||
['Floor', [unaryOps.floor]],
|
||||
['Gelu', [gelu]],
|
||||
['Gemm', [gemm, parseGemmAttributes]],
|
||||
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
|
||||
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@
|
|||
//
|
||||
// modified to fit the needs of the project
|
||||
|
||||
export declare type Activation = 'linear' | 'relu' | 'prelu' | 'elu' | 'relu6' | 'leakyrelu' | 'sigmoid';
|
||||
export declare type Activation = 'linear' | 'relu' | 'prelu' | 'elu' | 'relu6' | 'leakyrelu' | 'sigmoid' | 'gelu';
|
||||
|
||||
export const typeSnippet = (component: number) => {
|
||||
switch (component) {
|
||||
|
|
|
|||
53
js/web/lib/wasm/jsep/webgpu/ops/gelu.ts
Normal file
53
js/web/lib/wasm/jsep/webgpu/ops/gelu.ts
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
|
||||
|
||||
import {ShaderHelper} from './common';
|
||||
import {erfImpl} from './unary-op';
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[]): void => {
|
||||
if (!inputs || inputs.length !== 1) {
|
||||
throw new Error('Gelu requires 1 input');
|
||||
}
|
||||
if (inputs[0].dataType !== DataType.float) {
|
||||
throw new Error('Input must be float');
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
const createGeluProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => {
|
||||
const inputShape = inputs[0].dims;
|
||||
const outputShape = inputShape.slice(0);
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
const dataType = 'f32';
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
${erfImpl('f32')};
|
||||
@group(0) @binding(0) var<storage, read> input: array<${dataType}>;
|
||||
@group(0) @binding(1) var<storage, read_write> output: array<${dataType}>;
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
var x = input[global_id.x];
|
||||
output[global_id.x] = 0.5 * x * (1.0 + erf_vf32(x * 0.7071067811865475));
|
||||
}`;
|
||||
return {
|
||||
...metadata,
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
|
||||
getShaderSource,
|
||||
dispatchGroup: () => ({x: Math.ceil(outputSize / 64)})
|
||||
};
|
||||
};
|
||||
|
||||
const createGeluProgramInfoLoader = (inputs: readonly TensorView[]): ProgramInfoLoader => {
|
||||
const metadata: ProgramMetadata = {name: 'Gelu', inputTypes: [GpuDataType.default]};
|
||||
return {...metadata, get: () => createGeluProgramInfo(metadata, inputs)};
|
||||
};
|
||||
|
||||
export const gelu = (context: ComputeContext): void => {
|
||||
validateInputs(context.inputs);
|
||||
// const erfValue = erfImpl(context.inputs[0].getFloat32Array().map(x => x * 0.7071067811865475));
|
||||
context.compute(createGeluProgramInfoLoader(context.inputs));
|
||||
};
|
||||
|
|
@ -145,20 +145,23 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void
|
|||
attributes.cacheKey));
|
||||
};
|
||||
|
||||
export const erf = (context: ComputeContext): void => {
|
||||
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, `
|
||||
const r0: f32 = 0.3275911;
|
||||
const r1: f32 = 0.254829592;
|
||||
const r2: f32 = -0.284496736;
|
||||
const r3: f32 = 1.421413741;
|
||||
const r4: f32 = -1.453152027;
|
||||
const r5: f32 = 1.061405429;
|
||||
export const erfImpl = (dataType: string) => `
|
||||
const r0: f32 = 0.3275911;
|
||||
const r1: f32 = 0.254829592;
|
||||
const r2: f32 = -0.284496736;
|
||||
const r3: f32 = 1.421413741;
|
||||
const r4: f32 = -1.453152027;
|
||||
const r5: f32 = 1.061405429;
|
||||
|
||||
fn erf_vf32(v: vec4<f32>) -> vec4<f32> {
|
||||
let absv = abs(v);
|
||||
let x = 1.0 / (1.0 + r0 * absv);
|
||||
return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv));
|
||||
}`));
|
||||
fn erf_vf32(v: ${dataType}) -> ${dataType} {
|
||||
let absv = abs(v);
|
||||
let x = 1.0 / (1.0 + r0 * absv);
|
||||
return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv));
|
||||
}`;
|
||||
|
||||
export const erf = (context: ComputeContext): void => {
|
||||
context.compute(
|
||||
createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl('vec4<f32>')));
|
||||
};
|
||||
|
||||
export const exp = (context: ComputeContext): void => {
|
||||
|
|
|
|||
|
|
@ -33,7 +33,8 @@ const ALL_REGISTERED_OPERATORS: Map < string, {
|
|||
|
||||
// parse js_execution_provider.cc
|
||||
const JS_EXECUTION_PROVIDER_CONTENTS =
|
||||
fs.readFileSync(path.join(__dirname, '../../../onnxruntime/core/providers/js/js_execution_provider.cc'), 'utf8');
|
||||
fs.readFileSync(path.join(__dirname, '../../../onnxruntime/core/providers/js/js_execution_provider.cc'), 'utf8') +
|
||||
fs.readFileSync(path.join(__dirname, '../../../onnxruntime/contrib_ops/js/js_contrib_kernels.cc'), 'utf8');
|
||||
MATCHERS.forEach(m => {
|
||||
for (const match of JS_EXECUTION_PROVIDER_CONTENTS.matchAll(m)) {
|
||||
const groups = match.groups!;
|
||||
|
|
@ -50,6 +51,9 @@ MATCHERS.forEach(m => {
|
|||
case 'kMSInternalNHWCDomain':
|
||||
domain = 'com.ms.internal.nhwc';
|
||||
break;
|
||||
case 'kMSDomain':
|
||||
domain = 'com.microsoft';
|
||||
break;
|
||||
default:
|
||||
throw new Error(`not supported domain: ${opsetDomain}`);
|
||||
}
|
||||
|
|
|
|||
44
js/web/test/data/ops/gelu.jsonc
Normal file
44
js/web/test/data/ops/gelu.jsonc
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
[
|
||||
{
|
||||
"name": "gelu",
|
||||
"operator": "Gelu",
|
||||
"opsets": [{ "domain": "com.microsoft", "version": 1 }],
|
||||
"attributes": [],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[0]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1.0, -2.0, 0, 2.0],
|
||||
"dims": [2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1.0, 0, 0, 2.0],
|
||||
"dims": [2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Scalar",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1.0],
|
||||
"dims": [],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1.0],
|
||||
"dims": [],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
@ -1347,6 +1347,7 @@
|
|||
"leaky-relu.jsonc",
|
||||
"reduce-min.jsonc",
|
||||
"relu.jsonc",
|
||||
"gelu.jsonc",
|
||||
//"pad.jsonc",
|
||||
//"pad-big.jsonc",
|
||||
"pow.jsonc",
|
||||
|
|
|
|||
21
onnxruntime/contrib_ops/js/gelu.cc
Normal file
21
onnxruntime/contrib_ops/js/gelu.cc
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gelu.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace js {
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Gelu,
|
||||
kMSDomain,
|
||||
1,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Gelu);
|
||||
|
||||
} // namespace js
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
17
onnxruntime/contrib_ops/js/gelu.h
Normal file
17
onnxruntime/contrib_ops/js/gelu.h
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace js {
|
||||
|
||||
using onnxruntime::js::JsKernel;
|
||||
JSEP_KERNEL_IMPL(Gelu, Gelu);
|
||||
|
||||
} // namespace js
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
34
onnxruntime/contrib_ops/js/js_contrib_kernels.cc
Normal file
34
onnxruntime/contrib_ops/js/js_contrib_kernels.cc
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "contrib_ops/js/js_contrib_kernels.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace js {
|
||||
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
|
||||
|
||||
template <>
|
||||
KernelCreateInfo BuildKernelCreateInfo<void>() {
|
||||
KernelCreateInfo info;
|
||||
return info;
|
||||
}
|
||||
|
||||
Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
KernelCreateInfo info = function_table_entry();
|
||||
if (info.kernel_def != nullptr) { // filter disabled entries where type is void
|
||||
ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info)));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace js
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
17
onnxruntime/contrib_ops/js/js_contrib_kernels.h
Normal file
17
onnxruntime/contrib_ops/js/js_contrib_kernels.h
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/framework/kernel_registry.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace js {
|
||||
|
||||
Status RegisterJsContribKernels(KernelRegistry& kernel_registry);
|
||||
|
||||
} // namespace js
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -8,6 +8,10 @@
|
|||
|
||||
#include "js_execution_provider.h"
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
#include "contrib_ops/js/js_contrib_kernels.h"
|
||||
#endif
|
||||
|
||||
#include "core/graph/function_utils.h"
|
||||
#include "core/framework/compute_capability.h"
|
||||
#include "core/framework/data_transfer_manager.h"
|
||||
|
|
@ -485,6 +489,10 @@ std::vector<std::unique_ptr<ComputeCapability>> JsExecutionProvider::GetCapabili
|
|||
|
||||
std::shared_ptr<KernelRegistry> JsExecutionProvider::GetKernelRegistry() const {
|
||||
static std::shared_ptr<KernelRegistry> registry = js::RegisterKernels();
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
Status status = ::onnxruntime::contrib::js::RegisterJsContribKernels(*registry);
|
||||
ORT_ENFORCE(status.IsOK(), "Failed to register JS contrib kernels: " + status.ErrorMessage());
|
||||
#endif
|
||||
return registry;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue