mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
webgpu quickgelu (#20939)
This commit is contained in:
parent
5b87544aab
commit
c749bd997a
7 changed files with 125 additions and 0 deletions
|
|
@ -74,6 +74,7 @@ Do not modify directly.*
|
|||
| Not | ai.onnx(1+) | |
|
||||
| Pad | ai.onnx(2-10,11-12,13-17,18,19+) | |
|
||||
| Pow | ai.onnx(7-11,12,13-14,15+) | |
|
||||
| QuickGelu | com.microsoft(1+) | |
|
||||
| Range | ai.onnx(11+) | |
|
||||
| Reciprocal | ai.onnx(6-12,13+) | |
|
||||
| ReduceL1 | ai.onnx(1-10,11-12,13-17,18+) | |
|
||||
|
|
|
|||
|
|
@ -107,6 +107,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Not', [unaryOps.not]],
|
||||
['Pad', [pad]],
|
||||
['Pow', [binaryOps.pow]],
|
||||
['QuickGelu', [unaryOps.quickgelu, unaryOps.parseAlphaAttributes]],
|
||||
['Range', [range]],
|
||||
['Reciprocal', [unaryOps.reciprocal]],
|
||||
['ReduceMin', [reduceMin]],
|
||||
|
|
|
|||
|
|
@ -314,3 +314,31 @@ export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttrib
|
|||
/**
 * Runs the element-wise `Log` program on the first input of the compute context.
 *
 * @param context Compute context whose first input is handed to `createElementwiseProgramInfo`.
 */
export const log = (context: ComputeContext): void => {
  const programInfo = createElementwiseProgramInfo(context.inputs[0], 'Log', 'log');
  context.compute(programInfo);
};
|
||||
|
||||
/**
 * Builds the WGSL helper source for QuickGelu, i.e. `x * sigmoid(alpha * x)` on a vec4.
 *
 * The returned text declares the constants `alpha`, `one` and `zero` and the function
 * `quick_gelu_impl`. The sigmoid is computed per lane with two branches so that `exp()` is only
 * ever called with a non-positive argument: `exp(-v)` when `v >= 0`, `exp(v)` otherwise.
 *
 * @param varType WGSL scalar type name interpolated into the shader text (for example `f32`).
 * @param alpha Factor the input is multiplied by before the sigmoid is applied.
 * @returns WGSL source text containing only shader code.
 */
export const quickGeluImpl = (varType: string, alpha: number): string => `
  const alpha = vec4<${varType}>(${alpha});
  const one = ${varType}(1.0);
  const zero = ${varType}(0.0);

  fn quick_gelu_impl(x: vec4<${varType}>) -> vec4<${varType}> {
    let v = x *alpha;
    var x1 : vec4<${varType}>;
    for (var i = 0; i < 4; i = i + 1) {
      if (v[i] >= zero) {
        x1[i] = one / (one + exp(-v[i]));
      } else {
        x1[i] = one - one / (one + exp(v[i]));
      }
    }
    return x * x1;
  }
`;
|
||||
|
||||
/**
 * Wraps a WGSL expression in a call to `quick_gelu_impl`.
 *
 * @param x WGSL expression text used as the single call argument.
 * @returns The text `quick_gelu_impl(<x>)`.
 */
export const quickGeluExpression = (x: string): string => {
  return 'quick_gelu_impl(' + x + ')';
};
|
||||
|
||||
/**
 * Runs the QuickGelu operator on the first input of the compute context.
 *
 * The WGSL value type is derived from the input's data type and used, together with
 * `attributes.alpha`, to generate the shader helpers via `quickGeluImpl`. The resulting program
 * is keyed by `attributes.cacheKey` and is given the input's data type as its final argument.
 *
 * @param context Compute context providing the input tensor and the `compute` entry point.
 * @param attributes Operator attributes; `alpha` and `cacheKey` are read.
 */
export const quickgelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
  const input = context.inputs[0];
  const shaderHelpers = quickGeluImpl(tensorTypeToWsglValueType(input.dataType), attributes.alpha);
  const programInfo = createElementwiseProgramInfo(
      input, 'QuickGelu', quickGeluExpression, shaderHelpers, attributes.cacheKey, input.dataType);
  context.compute(programInfo);
};
|
||||
|
|
|
|||
46
js/web/test/data/ops/quick-gelu.jsonc
Normal file
46
js/web/test/data/ops/quick-gelu.jsonc
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
[
|
||||
{
|
||||
"name": "QuickGelu test",
|
||||
"operator": "QuickGelu",
|
||||
"opset": { "domain": "com.microsoft", "version": 1 },
|
||||
"cases": [
|
||||
{
|
||||
"name": "[2x4]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, -0.8],
|
||||
"dims": [2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [0.0542447, 0.116857, 0.187484, 0.265566, 0.350388, 0.441123, 0.53689, -0.163185],
|
||||
"dims": [2, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "[3x5]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, -1.5],
|
||||
"dims": [3, 5],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
0.0542447, 0.116857, 0.187484, 0.265566, 0.350388, 0.845795, 1.9356, 2.98192, 3.99558, 4.99899, 0.953383,
|
||||
1.0622, 1.17178, 1.2817, -0.10834
|
||||
],
|
||||
"dims": [3, 5],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
@ -16,6 +16,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu);
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, QuickGelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, RotaryEmbedding);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization);
|
||||
|
|
@ -38,6 +39,7 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, QuickGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1,
|
||||
SkipLayerNormalization)>,
|
||||
|
|
|
|||
23
onnxruntime/contrib_ops/js/quick_gelu.cc
Normal file
23
onnxruntime/contrib_ops/js/quick_gelu.cc
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "quick_gelu.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace js {
|
||||
|
||||
using onnxruntime::js::JsepSupportedFloatTypes;
|
||||
|
||||
// Registers the QuickGelu kernel class for operator "QuickGelu" in kMSDomain (version 1) on the
// JS execution provider, constraining type "T" to whatever JsepSupportedFloatTypes() returns.
ONNX_OPERATOR_KERNEL_EX(QuickGelu, kMSDomain, 1, kJsExecutionProvider,
                        (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
                        QuickGelu);
|
||||
|
||||
} // namespace js
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
24
onnxruntime/contrib_ops/js/quick_gelu.h
Normal file
24
onnxruntime/contrib_ops/js/quick_gelu.h
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/js/js_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace js {
|
||||
|
||||
using onnxruntime::js::JsKernel;
|
||||
|
||||
class QuickGelu final : public JsKernel {
|
||||
public:
|
||||
explicit QuickGelu(const OpKernelInfo& info) : JsKernel(info) {
|
||||
float alpha = info.GetAttrOrDefault<float>("alpha", 1.0);
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(QuickGelu, ({"alpha" : $1}), alpha);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace js
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue