[js/webgpu] Add HardSigmoid support (#19215)

### Description
This op is required in mobilenetv3-small-100. With this PR,
mobilenetv3-small-100 model becomes less than 10 ms from over 100 ms on
ADL.
This commit is contained in:
Jiajia Qin 2024-01-23 07:53:26 +08:00 committed by GitHub
parent e283cdb218
commit 2e0a388c36
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 30 additions and 3 deletions

View file

@ -52,6 +52,7 @@ Do not modify directly.*
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| Greater | ai.onnx(7-8,9-12,13+) | |
| GreaterOrEqual | ai.onnx(12-15,16+) | |
| HardSigmoid | ai.onnx(6+) | |
| If | ai.onnx(1-10,11-12,13-18,19+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
| LayerNormalization | ai.onnx(17+) | |

View file

@ -82,6 +82,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['Greater', [binaryOps.greater]],
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
['InstanceNormalization', [instanceNorm]],
['LayerNormalization', [layerNorm]],
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],

View file

@ -242,6 +242,26 @@ export const sigmoid = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`));
};
export interface HardSigmoidAttributes extends AttributeWithCacheKey {
readonly alpha: number;
readonly beta: number;
}
export const parseHardSigmoidAttributes = (attributes: Record<string, unknown>): HardSigmoidAttributes =>
createAttributeWithCacheKey(attributes as {
alpha: number;
beta: number;
});
export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => {
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'HardSigmoid',
a => `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${
attributes.beta})))`,
undefined, attributes.cacheKey));
};
export const sin = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin'));
};

View file

@ -597,9 +597,9 @@
// // "test_hardmax_example",
// // "test_hardmax_negative_axis",
// // "test_hardmax_one_hot",
// // "test_hardsigmoid_default",
// // "test_hardsigmoid_example",
// // "test_hardsigmoid",
"test_hardsigmoid_default",
"test_hardsigmoid_example",
"test_hardsigmoid",
// // "test_hardswish_expanded",
// // "test_hardswish",
"test_if",

View file

@ -98,6 +98,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Erf);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Sigmoid);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Sigmoid);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, HardSigmoid);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Log);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Log);
@ -392,6 +393,7 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
KERNEL_CREATE_INFO(13, Erf),
KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid),
KERNEL_CREATE_INFO(13, Sigmoid),
KERNEL_CREATE_INFO(6, HardSigmoid),
KERNEL_CREATE_INFO_VERSIONED(6, 12, Log),
KERNEL_CREATE_INFO(13, Log),

View file

@ -77,6 +77,9 @@ JSEP_KERNEL_IMPL(Sigmoid, Sigmoid)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid)
JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid)
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(HardSigmoid, HardSigmoid, alpha, 0.2, beta, 0.5)
JSEP_ELEMENTWISE_KERNEL(HardSigmoid, 6, HardSigmoid)
JSEP_KERNEL_IMPL(Log, Log)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log)
JSEP_ELEMENTWISE_KERNEL(Log, 13, Log)