mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-12 00:59:23 +00:00
[js/webgpu] support customop FastGelu (#19392)
### Description Support WebGPU custom operator FastGelu.
This commit is contained in:
parent
a4cfdc1c28
commit
5ff27ef02a
10 changed files with 353 additions and 8 deletions
|
|
@ -41,6 +41,7 @@ Do not modify directly.*
|
|||
| Erf | ai.onnx(9-12,13+) | |
|
||||
| Exp | ai.onnx(6-12,13+) | |
|
||||
| Expand | ai.onnx(8-12,13+) | |
|
||||
| FastGelu | com.microsoft(1+) | |
|
||||
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
|
||||
| Floor | ai.onnx(6-12,13+) | |
|
||||
| FusedConv | com.microsoft(1+) | |
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose'
|
|||
import {cumsum, parseCumSumAttributes} from './ops/cumsum';
|
||||
import {einsum, parseEinsumAttributes} from './ops/einsum';
|
||||
import {expand} from './ops/expand';
|
||||
import {fastGelu} from './ops/fast-gelu';
|
||||
import {gather, parseGatherAttributes} from './ops/gather';
|
||||
import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
|
||||
import {gemm, parseGemmAttributes} from './ops/gemm';
|
||||
|
|
@ -72,6 +73,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Erf', [unaryOps.erf]],
|
||||
['Exp', [unaryOps.exp]],
|
||||
['Expand', [expand]],
|
||||
['FastGelu', [fastGelu]],
|
||||
['Floor', [unaryOps.floor]],
|
||||
['FusedConv', [conv, parseConvAttributes]],
|
||||
['Gather', [gather, parseGatherAttributes]],
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI
|
|||
|
||||
${shaderHelper.declareVariables(input, bias, output)}
|
||||
|
||||
${erfImpl(`vec4<${dataType}>`, dataType)}
|
||||
${erfImpl(dataType)}
|
||||
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
|
|
|
|||
69
js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts
Normal file
69
js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {ComputeContext, ProgramInfo} from '../types';
|
||||
|
||||
import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType, UniformsArrayType, WORKGROUP_SIZE} from './common';
|
||||
import * as unary from './unary-op';
|
||||
|
||||
// FastGelu (the tanh approximation of GELU) is defined as Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X)), where an optional bias is added to X first.
|
||||
|
||||
/**
 * Builds the WebGPU program description for FastGelu with a bias input ('FastGeluWithBias').
 *
 * inputTensors[0] is the data tensor x and inputTensors[1] is the bias. The generated shader handles
 * four elements per invocation, and the bias index wraps around modulo the bias length, so the bias
 * is repeated along the flattened data.
 *
 * @param inputTensors - the data tensor followed by the bias tensor.
 * @returns the ProgramInfo: name, shader cache key, shader source generator and run data.
 */
const createFastGeluProgramInfo = (inputTensors: readonly TensorView[]): ProgramInfo => {
|
||||
const dataType = inputTensors[0].dataType;
|
||||
// element counts are captured here and reused by both the shader generator and getRunData below
const outputSize = ShapeUtil.size(inputTensors[0].dims);
|
||||
const biasLength = ShapeUtil.size(inputTensors[1].dims);
|
||||
// A whole 4-component bias value can be loaded at once only when the bias length is a multiple of 4;
// otherwise each of the four lanes is fetched separately (see singleElementBias below).
|
||||
const useVec4 = biasLength % 4 === 0;
|
||||
// Generates the WGSL source; the text depends on dataType and useVec4.
const getShaderSource = (shaderHelper: ShaderHelper): string => {
|
||||
// x, bias and y are all declared with 4 components, so offsets below are in units of 4 elements
const x = inputVariable('x', dataType, [1], 4);
|
||||
const bias = inputVariable('bias', dataType, [1], 4);
|
||||
const y = outputVariable('y', dataType, [1], 4);
|
||||
|
||||
// output_vec_size: number of 4-element groups to process; bias_size: bias length in scalar elements
const uniforms: UniformsArrayType = [{name: 'output_vec_size', type: 'u32'}, {name: 'bias_size', type: 'u32'}];
|
||||
|
||||
// Emits WGSL that reads the single bias scalar for lane i of the current invocation:
// the flat element index (global_idx * 4 + i) is wrapped modulo bias_size, then the scalar is
// picked out of the 4-component bias value that contains it.
const singleElementBias = (i: 0|1|2|3) => `
|
||||
let bias${i}_offset: u32 = (global_idx * 4 + ${i}) % uniforms.bias_size;
|
||||
let bias${i} = ${bias.getByOffset(`bias${i}_offset / 4`)}[bias${i}_offset % 4];`;
|
||||
// Either one wrapped 4-component load (bias length divisible by 4), or four scalar loads
// reassembled into a value of x's 4-component type.
const biasGetExpression = useVec4 ?
|
||||
`
|
||||
let bias = ${bias.getByOffset('global_idx % (uniforms.bias_size / 4)')};` :
|
||||
`${singleElementBias(0)}${singleElementBias(1)}${singleElementBias(2)}${singleElementBias(3)}
|
||||
let bias = ${x.type.value}(bias0, bias1, bias2, bias3);`;
|
||||
|
||||
// Shader layout: declarations, the shared FastGelu constants/tanh helper from unary-op,
// then a main body that computes y = fastGelu(x + bias) for one 4-element group.
return `${shaderHelper.registerUniforms(uniforms).declareVariables(x, bias, y)}
|
||||
|
||||
${unary.fastGeluImpl(tensorTypeToWsglValueType(dataType))}
|
||||
|
||||
${shaderHelper.mainStart(WORKGROUP_SIZE)}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_vec_size')}
|
||||
|
||||
let x = ${x.getByOffset('global_idx')};
|
||||
${biasGetExpression}
|
||||
|
||||
let x_in = x + bias;
|
||||
${y.setByOffset('global_idx', unary.fastGeluExpression('x_in'))}
|
||||
}`;
|
||||
};
|
||||
|
||||
return {
|
||||
name: 'FastGeluWithBias',
|
||||
// useVec4 selects between two different shader bodies, so it is part of the cache hint
shaderCache: {hint: `${useVec4}`, inputDependencies: ['type', 'type']},
|
||||
getShaderSource,
|
||||
getRunData: (inputs) => ({
|
||||
// the output has the same shape and data type as the first input
outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}],
|
||||
// values for the uniforms registered in the shader, in order: output_vec_size, bias_size
programUniforms:
|
||||
[{type: DataType.uint32, data: Math.ceil(outputSize / 4)}, {type: DataType.uint32, data: biasLength}],
|
||||
// each invocation covers 4 elements, hence the extra division by 4
dispatchGroup: {x: Math.ceil(outputSize / WORKGROUP_SIZE / 4)}
|
||||
})
|
||||
};
|
||||
};
|
||||
|
||||
/**
 * FastGelu operator entry point.
 *
 * When there is no second input, or the second input has zero elements, the bias-free elementwise
 * implementation unary.fastGelu is used; otherwise the bias-adding program from
 * createFastGeluProgramInfo is run.
 */
export const fastGelu = (context: ComputeContext): void => {
|
||||
// a missing or empty bias tensor means there is nothing to add to the input
if (context.inputs.length < 2 || ShapeUtil.size(context.inputs[1].dims) === 0) {
|
||||
unary.fastGelu(context);
|
||||
} else {
|
||||
context.compute(createFastGeluProgramInfo(context.inputs));
|
||||
}
|
||||
};
|
||||
|
|
@ -178,7 +178,7 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void
|
|||
attributes.cacheKey));
|
||||
};
|
||||
|
||||
export const erfImpl = (dataType: string, varType = 'f32') => `
|
||||
export const erfImpl = (varType = 'f32') => `
|
||||
const r0: ${varType} = 0.3275911;
|
||||
const r1: ${varType} = 0.254829592;
|
||||
const r2: ${varType} = -0.284496736;
|
||||
|
|
@ -186,7 +186,7 @@ const r3: ${varType} = 1.421413741;
|
|||
const r4: ${varType} = -1.453152027;
|
||||
const r5: ${varType} = 1.061405429;
|
||||
|
||||
fn erf_vf32(v: ${dataType}) -> ${dataType} {
|
||||
fn erf_vf32(v: vec4<${varType}>) -> vec4<${varType}> {
|
||||
let absv = abs(v);
|
||||
let x = 1.0 / (1.0 + r0 * absv);
|
||||
return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv));
|
||||
|
|
@ -194,8 +194,7 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} {
|
|||
|
||||
export const erf = (context: ComputeContext): void => {
|
||||
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
|
||||
context.compute(createElementwiseProgramInfo(
|
||||
context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType)));
|
||||
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(dataType)));
|
||||
};
|
||||
|
||||
export const exp = (context: ComputeContext): void => {
|
||||
|
|
@ -209,8 +208,7 @@ export const floor = (context: ComputeContext): void => {
|
|||
export const gelu = (context: ComputeContext): void => {
|
||||
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
|
||||
context.compute(createElementwiseProgramInfo(
|
||||
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
|
||||
erfImpl(`vec4<${dataType}>`, dataType)));
|
||||
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, erfImpl(dataType)));
|
||||
};
|
||||
|
||||
export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
|
||||
|
|
@ -278,10 +276,31 @@ export const tan = (context: ComputeContext): void => {
|
|||
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tan', 'tan'));
|
||||
};
|
||||
|
||||
/**
 * Builds a WGSL expression that evaluates tanh of `a` without calling the built-in tanh.
 *
 * tanh(|a|) is written as (1 - e^(-2|a|)) / (1 + e^(-2|a|)) and the sign of `a` is reapplied.
 *
 * @param a - WGSL source text of the operand; it is substituted verbatim.
 * @returns WGSL source text of the tanh expression.
 */
export const tanhExpression = (a: string) => {
  const decay = `exp(-2 * abs(${a}))`;
  return `sign(${a}) * (1 - ${decay}) / (1 + ${decay})`;
};
|
||||
|
||||
/** Runs elementwise Tanh on the first input, using the WGSL expression produced by tanhExpression. */
export const tanh = (context: ComputeContext): void => {
|
||||
// TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved
|
||||
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', tanhExpression));
|
||||
};
|
||||
|
||||
/**
 * Returns the WGSL declarations needed by fastGeluExpression.
 *
 * The snippet declares three constants typed as `varType` (fast_gelu_a = 0.5,
 * fast_gelu_b ~= sqrt(2/pi), fast_gelu_c ~= 0.044715 * sqrt(2/pi)) and a helper `tanh_v`
 * operating on a 4-component vector of `varType`, implemented with tanhExpression.
 *
 * @param varType - WGSL scalar type name used for the constants and the vector element type.
 */
export const fastGeluImpl = (varType = 'f32') => `
|
||||
const fast_gelu_a: ${varType} = 0.5;
|
||||
const fast_gelu_b: ${varType} = 0.7978845608028654;
|
||||
const fast_gelu_c: ${varType} = 0.035677408136300125;
|
||||
|
||||
fn tanh_v(v: vec4<${varType}>) -> vec4<${varType}> {
|
||||
return ${tanhExpression('v')};
|
||||
}
|
||||
`;
|
||||
|
||||
/**
 * Builds the WGSL expression for FastGelu of `x`:
 * (a + a * tanh_v(x * (c * x * x + b))) * x, i.e. 0.5 * x * (1 + tanh(b * x + c * x^3)).
 *
 * The expression refers to fast_gelu_a/b/c and tanh_v, so the shader must also include the
 * declarations returned by fastGeluImpl.
 *
 * @param x - WGSL source text of the operand; it is substituted verbatim (four times).
 */
export const fastGeluExpression = (x: string) =>
|
||||
`(fast_gelu_a + fast_gelu_a * tanh_v(${x} * (fast_gelu_c * ${x} * ${x} + fast_gelu_b))) * ${x}`;
|
||||
|
||||
export const fastGelu = (context: ComputeContext): void => {
|
||||
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
|
||||
context.compute(createElementwiseProgramInfo(
|
||||
context.inputs[0], 'Tanh', a => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`));
|
||||
context.inputs[0], 'FastGelu', fastGeluExpression, fastGeluImpl(dataType), undefined,
|
||||
context.inputs[0].dataType));
|
||||
};
|
||||
|
||||
export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => {
|
||||
|
|
|
|||
211
js/web/test/data/ops/fast-gelu.jsonc
Normal file
211
js/web/test/data/ops/fast-gelu.jsonc
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
[
|
||||
{
|
||||
"name": "FastGelu test without bias",
|
||||
"operator": "FastGelu",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"cases": [
|
||||
{
|
||||
"name": "scalar",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [0.841192],
|
||||
"dims": [],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "[2x4]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
|
||||
"dims": [2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.435415, 0.53057, 0.630432],
|
||||
"dims": [2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "[3x5]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5],
|
||||
"dims": [3, 5],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.841192, 1.9546, 2.99636, 3.99993, 5, 0.950581,
|
||||
1.0617, 1.17393, 1.28671, 1.39957
|
||||
],
|
||||
"dims": [3, 5],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "FastGelu test with bias",
|
||||
"operator": "FastGelu",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"cases": [
|
||||
{
|
||||
"name": "scalar",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [0.5],
|
||||
"dims": [],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1.39957],
|
||||
"dims": [],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "[2x4], [4]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
|
||||
"dims": [2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3, 4],
|
||||
"dims": [4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [0.950581, 2.16968, 3.29869, 4.39999, 1.39957, 2.58835, 3.69973, 4.8],
|
||||
"dims": [2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "[2x4], [3]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
|
||||
"dims": [2, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1, 2, 3],
|
||||
"dims": [3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [0.950581, 2.16968, 3.29869, 1.28671, 2.48492, 3.59959, 1.62411, 2.79331],
|
||||
"dims": [2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "[3x5], [2]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5],
|
||||
"dims": [3, 5],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [2, 3],
|
||||
"dims": [2],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
2.06267, 3.19813, 2.27567, 3.39909, 2.48492, 3.99993, 3.99993, 6, 6, 8, 3.09737, 4.19997, 3.29869,
|
||||
4.39999, 3.49938
|
||||
],
|
||||
"dims": [3, 5],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "[3x5], [7]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5],
|
||||
"dims": [3, 5],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7],
|
||||
"dims": [7],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
2.16968, 2.38072, 2.58835, 2.79331, 2.99636, 3.59959, 4.7, 5.1, 6.2, 7.3, 3.49938, 3.69973, 3.89989,
|
||||
4.09996, 3.59959
|
||||
],
|
||||
"dims": [3, 5],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "[4x4], [8]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.8, -0.5, 0.0, 1, 1.3, 2.1, -0.2, 1.1, 0.5, 0.2, 0.3, -0.6, 3.1, 2.2, -1.1, 0.0],
|
||||
"dims": [4, 4],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [-0.5, 0.6, 1.2, 2.1, 1.3, -1, 0, 3.1],
|
||||
"dims": [8],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
0.185371, 0.0539828, 1.0617, 3.09737, 2.58835, 0.950581, -0.0841486, 4.19997, 0, 0.630432, 1.39957,
|
||||
1.39957, 4.39999, 1.0617, -0.149419, 3.09737
|
||||
],
|
||||
"dims": [4, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
@ -1352,6 +1352,7 @@
|
|||
"equal.jsonc",
|
||||
"exp.jsonc",
|
||||
"expand.jsonc",
|
||||
"fast-gelu.jsonc",
|
||||
"floor.jsonc",
|
||||
"gather-elements.jsonc",
|
||||
"gemm.jsonc",
|
||||
|
|
|
|||
23
onnxruntime/contrib_ops/js/fast_gelu.cc
Normal file
23
onnxruntime/contrib_ops/js/fast_gelu.cc
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "fast_gelu.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace js {
|
||||
|
||||
using onnxruntime::js::JsepSupportedFloatTypes;
|
||||
|
||||
// Registers the FastGelu kernel for the JS execution provider: operator FastGelu in kMSDomain,
// version 1, with type parameter "T" constrained to JsepSupportedFloatTypes(). The final argument
// names the kernel class (FastGelu, declared in fast_gelu.h).
ONNX_OPERATOR_KERNEL_EX(
|
||||
FastGelu,
|
||||
kMSDomain,
|
||||
1,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", JsepSupportedFloatTypes()),
|
||||
FastGelu);
|
||||
|
||||
} // namespace js
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
17
onnxruntime/contrib_ops/js/fast_gelu.h
Normal file
17
onnxruntime/contrib_ops/js/fast_gelu.h
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace js {
|
||||
|
||||
using onnxruntime::js::JsKernel;
|
||||
// Declares the FastGelu kernel class used by the registration in fast_gelu.cc.
// NOTE(review): presumably the second macro argument is the operator name dispatched to the
// JavaScript side — confirm against JSEP_KERNEL_IMPL in core/providers/js/js_kernel.h.
JSEP_KERNEL_IMPL(FastGelu, FastGelu);
|
||||
|
||||
} // namespace js
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -8,6 +8,7 @@ namespace contrib {
|
|||
namespace js {
|
||||
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu);
|
||||
|
|
@ -24,6 +25,7 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
|
|||
Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd)>,
|
||||
|
|
|
|||
Loading…
Reference in a new issue