[JS/Web] Added Gelu contrib operator support to JSEP (#16909)

### Description
Added Gelu operator to JSEP


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
satyajandhyala 2023-07-31 09:18:58 -07:00 committed by GitHub
parent 92b6e10d37
commit dd24d52737
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 242 additions and 20 deletions

View file

@ -93,6 +93,11 @@ file(GLOB_RECURSE onnxruntime_rocm_contrib_ops_cu_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/contrib_ops/rocm/*.cuh"
)
file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/contrib_ops/js/*.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/js/*.cc"
)
file(GLOB onnxruntime_providers_common_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/*.cc"
@ -1158,8 +1163,12 @@ if (onnxruntime_USE_JSEP)
"${ONNXRUNTIME_ROOT}/core/providers/js/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/js/*.cc"
)
if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_js_contrib_ops_cc_srcs})
list(APPEND onnxruntime_providers_js_cc_srcs ${onnxruntime_js_contrib_ops_cc_srcs})
endif()
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_js_cc_srcs})
source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_providers_js_cc_srcs})
onnxruntime_add_static_library(onnxruntime_providers_js ${onnxruntime_providers_js_cc_srcs})
onnxruntime_add_include_to_target(onnxruntime_providers_js
onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers Boost::mp11

View file

@ -11,16 +11,17 @@
#ifndef SHARED_PROVIDER
#include <functional>
#include "core/common/exceptions.h"
#include "core/common/logging/logging.h"
#include "core/common/status.h"
#include "core/framework/execution_provider.h"
#include "core/framework/kernel_def_builder.h"
#include "core/framework/ort_value.h"
#include "core/framework/op_kernel_info.h"
#include "core/framework/op_node_proto_helper.h"
#include "core/framework/tensor.h"
#include "core/framework/ort_value.h"
#include "core/framework/sparse_tensor.h"
#include "core/framework/tensor.h"
#include "core/graph/constants.h"
#include "core/graph/graph_viewer.h"
#if !defined(ORT_MINIMAL_BUILD)
@ -28,9 +29,9 @@
#else
#include "onnx/defs/data_type_utils.h"
#endif
#include "onnx/onnx_pb.h"
#include "onnx/onnx-operators_pb.h"
#include "core/common/gsl.h"
#include "onnx/onnx-operators_pb.h"
#include "onnx/onnx_pb.h"
namespace onnxruntime {
class OpKernelContext;
}
@ -187,6 +188,13 @@ KernelCreateInfo BuildKernelCreateInfo();
} // namespace cuda
} // namespace contrib
namespace contrib {
namespace js {
template <typename T>
KernelCreateInfo BuildKernelCreateInfo();
} // namespace js
} // namespace contrib
namespace contrib {
namespace rocm {
template <typename T>

View file

@ -33,6 +33,7 @@ Do not modify directly.*
| Expand | ai.onnx(8-12,13+) | |
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
| Floor | ai.onnx(6-12,13+) | |
| Gelu | com.microsoft(1+) | |
| Gemm | ai.onnx(7-8,9-10,11+) | |
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |

View file

@ -6,6 +6,7 @@ import {concat, parseConcatAttributes} from './ops/concat';
import {conv, parseConvAttributes} from './ops/conv';
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
import {expand} from './ops/expand';
import {gelu} from './ops/gelu';
import {gemm, parseGemmAttributes} from './ops/gemm';
import {matMul} from './ops/matmul';
import * as pool from './ops/pool';
@ -45,6 +46,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Exp', [unaryOps.exp]],
['Expand', [expand]],
['Floor', [unaryOps.floor]],
['Gelu', [gelu]],
['Gemm', [gemm, parseGemmAttributes]],
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],

View file

@ -19,7 +19,7 @@
//
// modified to fit the needs of the project
export declare type Activation = 'linear' | 'relu' | 'prelu' | 'elu' | 'relu6' | 'leakyrelu' | 'sigmoid';
export declare type Activation = 'linear' | 'relu' | 'prelu' | 'elu' | 'relu6' | 'leakyrelu' | 'sigmoid' | 'gelu';
export const typeSnippet = (component: number) => {
switch (component) {

View file

@ -0,0 +1,53 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
import {ShaderHelper} from './common';
import {erfImpl} from './unary-op';
/**
 * Checks the operator inputs before a Gelu program is created.
 *
 * @param inputs - tensors handed to the operator.
 * @throws Error if the input list is missing or does not hold exactly one tensor.
 * @throws Error if that tensor's data type is not `DataType.float`.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  const hasExactlyOneInput = Boolean(inputs) && inputs.length === 1;
  if (!hasExactlyOneInput) {
    throw new Error('Gelu requires 1 input');
  }

  const isFloatInput = inputs[0].dataType === DataType.float;
  if (!isFloatInput) {
    throw new Error('Input must be float');
  }
};
/**
 * Describes the GPU program that applies Gelu element-wise to the first input.
 *
 * The generated shader evaluates `0.5 * x * (1 + erf(x * 0.7071067811865475))`, where the
 * constant is 1/sqrt(2), using the erf helper emitted by `erfImpl`.
 *
 * @param metadata - program metadata that is spread into the returned object.
 * @param inputs - operator inputs; only `inputs[0]` is read.
 * @returns the metadata extended with one output (same dims and data type as the input),
 *     the shader source factory and the dispatch size.
 */
const createGeluProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => {
  const outputDims = inputs[0].dims.slice(0);
  const elementCount = ShapeUtil.size(outputDims);
  const wgslType = 'f32';

  // The shader text is kept flush-left so the emitted source stays unchanged.
  const buildShaderSource = (shaderHelper: ShaderHelper): string => `
${erfImpl(wgslType)};
@group(0) @binding(0) var<storage, read> input: array<${wgslType}>;
@group(0) @binding(1) var<storage, read_write> output: array<${wgslType}>;
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(elementCount)}
var x = input[global_id.x];
output[global_id.x] = 0.5 * x * (1.0 + erf_vf32(x * 0.7071067811865475));
}`;

  // One invocation per element; 64 is presumably the workgroup size used by
  // shaderHelper.mainStart() -- TODO confirm against ShaderHelper.
  const computeDispatch = () => ({x: Math.ceil(elementCount / 64)});

  return {
    ...metadata,
    outputs: [{dims: outputDims, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}],
    getShaderSource: buildShaderSource,
    dispatchGroup: computeDispatch
  };
};
/**
 * Wraps the Gelu program description in a loader.
 *
 * The loader exposes the program metadata directly and defers building the full
 * program info until `get()` is called.
 *
 * @param inputs - operator inputs forwarded to `createGeluProgramInfo`.
 * @returns the metadata (`name: 'Gelu'`, one default-typed input) plus a `get` factory.
 */
const createGeluProgramInfoLoader = (inputs: readonly TensorView[]): ProgramInfoLoader => {
  const programMetadata: ProgramMetadata = {name: 'Gelu', inputTypes: [GpuDataType.default]};
  const loadProgramInfo = (): ProgramInfo => createGeluProgramInfo(programMetadata, inputs);
  return {...programMetadata, get: loadProgramInfo};
};
/**
 * Entry point of the Gelu operator.
 *
 * Validates the inputs and then submits the Gelu program loader to the compute context.
 *
 * @param context - compute context supplying the inputs and the `compute` method.
 * @throws Error if input validation fails (see `validateInputs`).
 */
export const gelu = (context: ComputeContext): void => {
  const {inputs} = context;
  validateInputs(inputs);

  const programLoader = createGeluProgramInfoLoader(inputs);
  context.compute(programLoader);
};

View file

@ -145,20 +145,23 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void
attributes.cacheKey));
};
export const erf = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, `
const r0: f32 = 0.3275911;
const r1: f32 = 0.254829592;
const r2: f32 = -0.284496736;
const r3: f32 = 1.421413741;
const r4: f32 = -1.453152027;
const r5: f32 = 1.061405429;
export const erfImpl = (dataType: string) => `
const r0: f32 = 0.3275911;
const r1: f32 = 0.254829592;
const r2: f32 = -0.284496736;
const r3: f32 = 1.421413741;
const r4: f32 = -1.453152027;
const r5: f32 = 1.061405429;
fn erf_vf32(v: vec4<f32>) -> vec4<f32> {
let absv = abs(v);
let x = 1.0 / (1.0 + r0 * absv);
return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv));
}`));
fn erf_vf32(v: ${dataType}) -> ${dataType} {
let absv = abs(v);
let x = 1.0 / (1.0 + r0 * absv);
return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv));
}`;
export const erf = (context: ComputeContext): void => {
context.compute(
createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl('vec4<f32>')));
};
export const exp = (context: ComputeContext): void => {

View file

@ -33,7 +33,8 @@ const ALL_REGISTERED_OPERATORS: Map < string, {
// parse js_execution_provider.cc
const JS_EXECUTION_PROVIDER_CONTENTS =
fs.readFileSync(path.join(__dirname, '../../../onnxruntime/core/providers/js/js_execution_provider.cc'), 'utf8');
fs.readFileSync(path.join(__dirname, '../../../onnxruntime/core/providers/js/js_execution_provider.cc'), 'utf8') +
fs.readFileSync(path.join(__dirname, '../../../onnxruntime/contrib_ops/js/js_contrib_kernels.cc'), 'utf8');
MATCHERS.forEach(m => {
for (const match of JS_EXECUTION_PROVIDER_CONTENTS.matchAll(m)) {
const groups = match.groups!;
@ -50,6 +51,9 @@ MATCHERS.forEach(m => {
case 'kMSInternalNHWCDomain':
domain = 'com.ms.internal.nhwc';
break;
case 'kMSDomain':
domain = 'com.microsoft';
break;
default:
throw new Error(`not supported domain: ${opsetDomain}`);
}

View file

@ -0,0 +1,44 @@
[
{
"name": "gelu",
"operator": "Gelu",
"opsets": [{ "domain": "com.microsoft", "version": 1 }],
"attributes": [],
"cases": [
{
"name": "T[0]",
"inputs": [
{
"data": [1.0, -2.0, 0, 2.0],
"dims": [2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [0.8413447141647339, -0.04550027847290039, 0, 1.9544997215270996],
"dims": [2, 2],
"type": "float32"
}
]
},
{
"name": "Scalar",
"inputs": [
{
"data": [1.0],
"dims": [],
"type": "float32"
}
],
"outputs": [
{
"data": [0.8413447141647339],
"dims": [],
"type": "float32"
}
]
}
]
}
]

View file

@ -1347,6 +1347,7 @@
"leaky-relu.jsonc",
"reduce-min.jsonc",
"relu.jsonc",
"gelu.jsonc",
//"pad.jsonc",
//"pad-big.jsonc",
"pow.jsonc",

View file

@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gelu.h"
namespace onnxruntime {
namespace contrib {
namespace js {
// Defines the kernel registration for the Gelu contrib operator on the JS execution provider.
// The arguments below select: operator name Gelu, domain kMSDomain, opset version 1,
// provider kJsExecutionProvider, a kernel definition restricting type "T" to float tensors,
// and the kernel class Gelu (presumably the class produced by JSEP_KERNEL_IMPL in gelu.h --
// confirm against that macro).
ONNX_OPERATOR_KERNEL_EX(
Gelu,
kMSDomain,
1,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),  // only float tensors are accepted
Gelu);
} // namespace js
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/js/js_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace js {
// Bring JsKernel into this namespace; presumably JSEP_KERNEL_IMPL refers to it unqualified
// (the macro is defined in core/providers/js/js_kernel.h -- confirm there).
using onnxruntime::js::JsKernel;
// Declares the Gelu kernel for the JS execution provider via JSEP_KERNEL_IMPL.
// NOTE(review): the two arguments look like the C++ class name and the operator name passed
// to the JavaScript side -- verify against the macro definition.
JSEP_KERNEL_IMPL(Gelu, Gelu);
} // namespace js
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/js/js_contrib_kernels.h"
namespace onnxruntime {
namespace contrib {
namespace js {
// Forward declaration of the Gelu kernel class (JS execution provider, kMSDomain, opset 1).
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);

// Specialization for `void`: yields a default-constructed KernelCreateInfo.
// RegisterJsContribKernels skips entries whose kernel_def is null, so this acts as a
// "disabled" placeholder entry.
template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
  return KernelCreateInfo{};
}
// Registers every contrib-op kernel implemented by the JS execution provider.
//
// Parameters:
//   kernel_registry - registry that receives the kernel definitions.
// Returns:
//   Status::OK() when all entries were registered, otherwise the first error reported by
//   KernelRegistry::Register (propagated through ORT_RETURN_IF_ERROR).
Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
  static const BuildKernelCreateInfoFn function_table[] = {
      // Placeholder entry: keeps the table non-empty even if every real entry below is
      // removed or disabled. It produces a KernelCreateInfo with a null kernel_def, which
      // the loop below filters out.
      BuildKernelCreateInfo<void>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu)>,
  };

  for (auto& function_table_entry : function_table) {
    KernelCreateInfo info = function_table_entry();
    if (info.kernel_def != nullptr) {  // filter disabled entries where type is void
      ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info)));
    }
  }

  return Status::OK();
}
} // namespace js
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/framework/op_kernel.h"
#include "core/framework/kernel_registry.h"
namespace onnxruntime {
namespace contrib {
namespace js {
// Registers the contrib-op kernels of the JS execution provider into `kernel_registry`.
// Returns Status::OK() on success, or the error from the first registration that failed.
Status RegisterJsContribKernels(KernelRegistry& kernel_registry);
} // namespace js
} // namespace contrib
} // namespace onnxruntime

View file

@ -8,6 +8,10 @@
#include "js_execution_provider.h"
#ifndef DISABLE_CONTRIB_OPS
#include "contrib_ops/js/js_contrib_kernels.h"
#endif
#include "core/graph/function_utils.h"
#include "core/framework/compute_capability.h"
#include "core/framework/data_transfer_manager.h"
@ -485,6 +489,10 @@ std::vector<std::unique_ptr<ComputeCapability>> JsExecutionProvider::GetCapabili
std::shared_ptr<KernelRegistry> JsExecutionProvider::GetKernelRegistry() const {
static std::shared_ptr<KernelRegistry> registry = js::RegisterKernels();
#ifndef DISABLE_CONTRIB_OPS
Status status = ::onnxruntime::contrib::js::RegisterJsContribKernels(*registry);
ORT_ENFORCE(status.IsOK(), "Failed to register JS contrib kernels: " + status.ErrorMessage());
#endif
return registry;
}