mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-12 00:59:23 +00:00
[js/webgpu] Add HardSigmoid support (#19215)
### Description This op is required in mobilenetv3-small-100. With this PR, mobilenetv3-small-100 model becomes less than 10 ms from over 100 ms on ADL.
This commit is contained in:
parent
e283cdb218
commit
2e0a388c36
6 changed files with 30 additions and 3 deletions
|
|
@ -52,6 +52,7 @@ Do not modify directly.*
|
|||
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
|
||||
| Greater | ai.onnx(7-8,9-12,13+) | |
|
||||
| GreaterOrEqual | ai.onnx(12-15,16+) | |
|
||||
| HardSigmoid | ai.onnx(6+) | |
|
||||
| If | ai.onnx(1-10,11-12,13-18,19+) | |
|
||||
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
|
||||
| LayerNormalization | ai.onnx(17+) | |
|
||||
|
|
|
|||
|
|
@ -82,6 +82,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
|
||||
['Greater', [binaryOps.greater]],
|
||||
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
|
||||
['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
|
||||
['InstanceNormalization', [instanceNorm]],
|
||||
['LayerNormalization', [layerNorm]],
|
||||
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
|
||||
|
|
|
|||
|
|
@ -242,6 +242,26 @@ export const sigmoid = (context: ComputeContext): void => {
|
|||
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`));
|
||||
};
|
||||
|
||||
export interface HardSigmoidAttributes extends AttributeWithCacheKey {
|
||||
readonly alpha: number;
|
||||
readonly beta: number;
|
||||
}
|
||||
|
||||
export const parseHardSigmoidAttributes = (attributes: Record<string, unknown>): HardSigmoidAttributes =>
|
||||
createAttributeWithCacheKey(attributes as {
|
||||
alpha: number;
|
||||
beta: number;
|
||||
});
|
||||
|
||||
export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => {
|
||||
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
|
||||
context.compute(createElementwiseProgramInfo(
|
||||
context.inputs[0], 'HardSigmoid',
|
||||
a => `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${
|
||||
attributes.beta})))`,
|
||||
undefined, attributes.cacheKey));
|
||||
};
|
||||
|
||||
export const sin = (context: ComputeContext): void => {
|
||||
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin'));
|
||||
};
|
||||
|
|
|
|||
|
|
@ -597,9 +597,9 @@
|
|||
// // "test_hardmax_example",
|
||||
// // "test_hardmax_negative_axis",
|
||||
// // "test_hardmax_one_hot",
|
||||
// // "test_hardsigmoid_default",
|
||||
// // "test_hardsigmoid_example",
|
||||
// // "test_hardsigmoid",
|
||||
"test_hardsigmoid_default",
|
||||
"test_hardsigmoid_example",
|
||||
"test_hardsigmoid",
|
||||
// // "test_hardswish_expanded",
|
||||
// // "test_hardswish",
|
||||
"test_if",
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Erf);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Sigmoid);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Sigmoid);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, HardSigmoid);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Log);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Log);
|
||||
|
||||
|
|
@ -392,6 +393,7 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
KERNEL_CREATE_INFO(13, Erf),
|
||||
KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid),
|
||||
KERNEL_CREATE_INFO(13, Sigmoid),
|
||||
KERNEL_CREATE_INFO(6, HardSigmoid),
|
||||
KERNEL_CREATE_INFO_VERSIONED(6, 12, Log),
|
||||
KERNEL_CREATE_INFO(13, Log),
|
||||
|
||||
|
|
|
|||
|
|
@ -77,6 +77,9 @@ JSEP_KERNEL_IMPL(Sigmoid, Sigmoid)
|
|||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid)
|
||||
JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid)
|
||||
|
||||
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(HardSigmoid, HardSigmoid, alpha, 0.2, beta, 0.5)
|
||||
JSEP_ELEMENTWISE_KERNEL(HardSigmoid, 6, HardSigmoid)
|
||||
|
||||
JSEP_KERNEL_IMPL(Log, Log)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log)
|
||||
JSEP_ELEMENTWISE_KERNEL(Log, 13, Log)
|
||||
|
|
|
|||
Loading…
Reference in a new issue