From a02c885f861b7bf18d5caeebb6308bcfec08b89b Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Wed, 26 Apr 2023 15:11:01 -0700
Subject: [PATCH] [js/webgpu] add implementation of Relu, LeakyRelu and ThresholdedRelu (#15668)

### Description

add implementation of Relu, LeakyRelu and ThresholdedRelu
---
Review notes (text between the "---" line and the first "diff --git" is
ignored by `git am`):

  * leakyRelu and thresholdedRelu now pass `attributes.cacheKey` as the last
    argument of createElementwiseProgramInfoLoader, exactly as elu does.
    Both ops interpolate `attributes.alpha` into the generated WGSL source,
    so a program built for one alpha must not be reused for another.
    NOTE(review): presumably the loader uses this argument to key cached
    programs -- confirm against createElementwiseProgramInfoLoader.
  * Template arguments that had been stripped from this copy of the patch
    (`vec4<f32>`, `Record<string, unknown>`, and the hunk-header contexts
    `Map<string, OperatorImplementation>` / `std::unique_ptr<KernelRegistry>`)
    are restored.
  * NOTE(review): thresholdedRelu is declared `: number` and ends with
    `return 0;` while every sibling op returns `void`. Left untouched here
    because it changes an exported signature; consider aligning it.
  * Line breaks and indentation were reconstructed from hunk line counts and
    the project's formatting conventions; if context does not match, apply
    with `git apply --ignore-whitespace`. The `index` blob ids are the ones
    from the original commit and were not recomputed.

 .../lib/wasm/jsep/webgpu/op-resolve-rules.ts  |  5 +++-
 js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts   | 28 +++++++++++++++----
 js/web/test/suite-test-list.jsonc             | 12 ++++----
 .../providers/js/js_execution_provider.cc     | 12 ++++++++
 .../core/providers/js/operators/unary.cc      | 12 ++++++++
 5 files changed, 57 insertions(+), 12 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
index f25e308439..77f362bb6d 100644
--- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
+++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -32,12 +32,13 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
   ['Cos', [unaryOps.cos]],
   ['Cosh', [unaryOps.cosh]],
   ['Div', [binaryOps.div]],
-  ['Elu', [unaryOps.elu, unaryOps.parseEluAttributes]],
+  ['Elu', [unaryOps.elu, unaryOps.parseAlphaAttributes]],
   ['Erf', [unaryOps.erf]],
   ['Floor', [unaryOps.floor]],
   ['Gemm', [gemm, parseGemmAttributes]],
   ['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
   ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
+  ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
   ['MatMul', [matMul]],
   // TODO: support new attributes for MaxPool-8 and MaxPool-10
   ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]],
@@ -45,6 +46,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
   ['Neg', [unaryOps.neg]],
   ['Pow', [binaryOps.pow]],
   ['Reciprocal', [unaryOps.reciprocal]],
+  ['Relu', [unaryOps.relu]],
   ['Sigmoid', [unaryOps.sigmoid]],
   ['Sin', [unaryOps.sin]],
   ['Sinh', [unaryOps.sinh]],
@@ -52,5 +54,6 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
   ['Sub', [binaryOps.sub]],
   ['Tan', [unaryOps.tan]],
   ['Tanh', [unaryOps.tanh]],
+  ['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]],
   ['Transpose', [transpose, parseTransposeAttributes]],
 ]);
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
index 6ea0007ba0..6f0c6519fe 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
@@ -123,11 +123,14 @@ export const cosh = (context: ComputeContext): void => {
   context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Cosh', 'cosh'));
 };
 
-export interface EluAttributes extends AttributeWithCacheKey {
+export interface AlphaAttributes extends AttributeWithCacheKey {
   readonly alpha: number;
 }
 
-export const elu = (context: ComputeContext, attributes: EluAttributes): void => {
+export const parseAlphaAttributes = (attributes: Record<string, unknown>): AlphaAttributes =>
+    createAttributeWithCacheKey(attributes as {alpha: number});
+
+export const elu = (context: ComputeContext, attributes: AlphaAttributes): void => {
   context.compute(createElementwiseProgramInfoLoader(
       context.inputs[0], 'Elu', a => `elu_vf32(${a})`, `
   const elu_alpha_: f32 = f32(${attributes.alpha});
@@ -142,9 +145,6 @@ export const elu = (context: ComputeContext, attributes: EluAttributes): void =>
       attributes.cacheKey));
 };
 
-export const parseEluAttributes = (attributes: Record<string, unknown>): EluAttributes =>
-    createAttributeWithCacheKey(attributes as {alpha: number});
-
 export const erf = (context: ComputeContext): void => {
   context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, `
   const r0: f32 = 0.3275911;
@@ -165,6 +165,12 @@ export const floor = (context: ComputeContext): void => {
   context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Floor', 'floor'));
 };
 
+export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
+  context.compute(createElementwiseProgramInfoLoader(
+      context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<f32>(0.0))`,
+      `const leaky_relu_alpha_: f32 = f32(${attributes.alpha});`, attributes.cacheKey));
+};
+
 export const neg = (context: ComputeContext): void => {
   context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Neg', a => `-${a}`));
 };
@@ -173,6 +179,11 @@ export const reciprocal = (context: ComputeContext): void => {
   context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Reciprocal', a => `1.0/${a}`));
 };
 
+export const relu = (context: ComputeContext): void => {
+  context.compute(createElementwiseProgramInfoLoader(
+      context.inputs[0], 'Relu', a => `select(vec4<f32>(0.0), ${a}, ${a} > vec4<f32>(0.0))`));
+};
+
 export const sigmoid = (context: ComputeContext): void => {
   context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`));
 };
@@ -196,3 +207,10 @@ export const tan = (context: ComputeContext): void => {
 export const tanh = (context: ComputeContext): void => {
   context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Tanh', 'tanh'));
 };
+
+export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => {
+  context.compute(createElementwiseProgramInfoLoader(
+      context.inputs[0], 'ThresholdedRelu', a => `select(vec4<f32>(0.0), ${a}, ${a} > thresholded_relu_alpha_)`,
+      `const thresholded_relu_alpha_: vec4<f32> = vec4<f32>(${attributes.alpha});`, attributes.cacheKey));
+  return 0;
+};
diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc
index 17928899c9..90a5c084ad 100644
--- a/js/web/test/suite-test-list.jsonc
+++ b/js/web/test/suite-test-list.jsonc
@@ -644,9 +644,9 @@
       // // "test_layer_normalization_4d_axis3",
       // // "test_layer_normalization_default_axis_expanded",
       // // "test_layer_normalization_default_axis",
-      // "test_leakyrelu_default",
-      // "test_leakyrelu_example",
-      // "test_leakyrelu",
+      "test_leakyrelu_default",
+      "test_leakyrelu_example",
+      "test_leakyrelu",
       // "test_less_bcast",
       // "test_less_equal_bcast_expanded",
       // "test_less_equal_bcast",
@@ -1254,9 +1254,9 @@
       // // "test_tfidfvectorizer_tf_onlybigrams_levelempty",
       // // "test_tfidfvectorizer_tf_onlybigrams_skip5",
       // // "test_tfidfvectorizer_tf_uniandbigrams_skip5",
-      // "test_thresholdedrelu_default",
-      // "test_thresholdedrelu_example",
-      // "test_thresholdedrelu",
+      "test_thresholdedrelu_default",
+      "test_thresholdedrelu_example",
+      "test_thresholdedrelu",
       // // "test_tile_precomputed",
       // // "test_tile",
       // // "test_top_k_negative_axis",
diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc
index df5679bda7..2fba425b32 100644
--- a/onnxruntime/core/providers/js/js_execution_provider.cc
+++ b/onnxruntime/core/providers/js/js_execution_provider.cc
@@ -109,6 +109,12 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, Clip);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Clip);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, Elu);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Relu);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Relu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Relu);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 15, LeakyRelu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, LeakyRelu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, ThresholdedRelu);
 
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 12, Add);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Add);
@@ -211,6 +217,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
       KERNEL_CREATE_INFO_VERSIONED(12, 12, Clip),
       KERNEL_CREATE_INFO(13, Clip),
       KERNEL_CREATE_INFO(6, Elu),
+      KERNEL_CREATE_INFO_VERSIONED(6, 12, Relu),
+      KERNEL_CREATE_INFO_VERSIONED(13, 13, Relu),
+      KERNEL_CREATE_INFO(14, Relu),
+      KERNEL_CREATE_INFO_VERSIONED(6, 15, LeakyRelu),
+      KERNEL_CREATE_INFO(16, LeakyRelu),
+      KERNEL_CREATE_INFO(10, ThresholdedRelu),
 
       // binary - math
       KERNEL_CREATE_INFO_VERSIONED(7, 12, Add),
diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc
index df8c9760c1..ab1d0da625 100644
--- a/onnxruntime/core/providers/js/operators/unary.cc
+++ b/onnxruntime/core/providers/js/operators/unary.cc
@@ -116,5 +116,17 @@ ONNX_OPERATOR_KERNEL_EX(Clip, kOnnxDomain, 13, kJsExecutionProvider,
 JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(Elu, Elu, alpha, 1.0)
 JSEP_ELEMENTWISE_KERNEL(Elu, 6, float, Elu)
 
+JSEP_KERNEL_IMPL(Relu, Relu)
+JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, float, Relu)
+JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, float, Relu)
+JSEP_ELEMENTWISE_KERNEL(Relu, 14, float, Relu)
+
+JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(LeakyRelu, LeakyRelu, alpha, 0.01)
+JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, float, LeakyRelu)
+JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, float, LeakyRelu)
+
+JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(ThresholdedRelu, ThresholdedRelu, alpha, 1.0)
+JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, float, ThresholdedRelu)
+
 }  // namespace js
 }  // namespace onnxruntime