[js/webgpu] add implementation of Relu, LeakyRelu and ThresholdedRelu (#15668)

### Description
add implementation of Relu, LeakyRelu and ThresholdedRelu
This commit is contained in:
Yulong Wang 2023-04-26 15:11:01 -07:00 committed by GitHub
parent 76ddc92fbd
commit a02c885f86
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 57 additions and 12 deletions

View file

@ -32,12 +32,13 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Cos', [unaryOps.cos]],
['Cosh', [unaryOps.cosh]],
['Div', [binaryOps.div]],
['Elu', [unaryOps.elu, unaryOps.parseEluAttributes]],
['Elu', [unaryOps.elu, unaryOps.parseAlphaAttributes]],
['Erf', [unaryOps.erf]],
['Floor', [unaryOps.floor]],
['Gemm', [gemm, parseGemmAttributes]],
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
['MatMul', [matMul]],
// TODO: support new attributes for MaxPool-8 and MaxPool-10
['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]],
@ -45,6 +46,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Neg', [unaryOps.neg]],
['Pow', [binaryOps.pow]],
['Reciprocal', [unaryOps.reciprocal]],
['Relu', [unaryOps.relu]],
['Sigmoid', [unaryOps.sigmoid]],
['Sin', [unaryOps.sin]],
['Sinh', [unaryOps.sinh]],
@ -52,5 +54,6 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Sub', [binaryOps.sub]],
['Tan', [unaryOps.tan]],
['Tanh', [unaryOps.tanh]],
['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]],
['Transpose', [transpose, parseTransposeAttributes]],
]);

View file

@ -123,11 +123,14 @@ export const cosh = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Cosh', 'cosh'));
};
export interface EluAttributes extends AttributeWithCacheKey {
export interface AlphaAttributes extends AttributeWithCacheKey {
readonly alpha: number;
}
export const elu = (context: ComputeContext, attributes: EluAttributes): void => {
export const parseAlphaAttributes = (attributes: Record<string, unknown>): AlphaAttributes =>
createAttributeWithCacheKey(attributes as {alpha: number});
export const elu = (context: ComputeContext, attributes: AlphaAttributes): void => {
context.compute(createElementwiseProgramInfoLoader(
context.inputs[0], 'Elu', a => `elu_vf32(${a})`, `
const elu_alpha_: f32 = f32(${attributes.alpha});
@ -142,9 +145,6 @@ export const elu = (context: ComputeContext, attributes: EluAttributes): void =>
attributes.cacheKey));
};
export const parseEluAttributes = (attributes: Record<string, unknown>): EluAttributes =>
createAttributeWithCacheKey(attributes as {alpha: number});
export const erf = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, `
const r0: f32 = 0.3275911;
@ -165,6 +165,12 @@ export const floor = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Floor', 'floor'));
};
export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
context.compute(createElementwiseProgramInfoLoader(
context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<f32>(0.0))`,
`const leaky_relu_alpha_: f32 = f32(${attributes.alpha});`));
};
export const neg = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Neg', a => `-${a}`));
};
@ -173,6 +179,11 @@ export const reciprocal = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Reciprocal', a => `1.0/${a}`));
};
export const relu = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(
context.inputs[0], 'Relu', a => `select(vec4<f32>(0.0), ${a}, ${a} > vec4<f32>(0.0))`));
};
export const sigmoid = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`));
};
@ -196,3 +207,10 @@ export const tan = (context: ComputeContext): void => {
export const tanh = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Tanh', 'tanh'));
};
export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => {
context.compute(createElementwiseProgramInfoLoader(
context.inputs[0], 'ThresholdedRelu', a => `select(vec4<f32>(0.0), ${a}, ${a} > thresholded_relu_alpha_)`,
`const thresholded_relu_alpha_: vec4<f32> = vec4<f32>(${attributes.alpha});`));
return 0;
};

View file

@ -644,9 +644,9 @@
// // "test_layer_normalization_4d_axis3",
// // "test_layer_normalization_default_axis_expanded",
// // "test_layer_normalization_default_axis",
// "test_leakyrelu_default",
// "test_leakyrelu_example",
// "test_leakyrelu",
"test_leakyrelu_default",
"test_leakyrelu_example",
"test_leakyrelu",
// "test_less_bcast",
// "test_less_equal_bcast_expanded",
// "test_less_equal_bcast",
@ -1254,9 +1254,9 @@
// // "test_tfidfvectorizer_tf_onlybigrams_levelempty",
// // "test_tfidfvectorizer_tf_onlybigrams_skip5",
// // "test_tfidfvectorizer_tf_uniandbigrams_skip5",
// "test_thresholdedrelu_default",
// "test_thresholdedrelu_example",
// "test_thresholdedrelu",
"test_thresholdedrelu_default",
"test_thresholdedrelu_example",
"test_thresholdedrelu",
// // "test_tile_precomputed",
// // "test_tile",
// // "test_top_k_negative_axis",

View file

@ -109,6 +109,12 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, Clip);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Clip);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, Elu);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Relu);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Relu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Relu);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 15, LeakyRelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, LeakyRelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, ThresholdedRelu);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 12, Add);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Add);
@ -211,6 +217,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
KERNEL_CREATE_INFO_VERSIONED(12, 12, Clip),
KERNEL_CREATE_INFO(13, Clip),
KERNEL_CREATE_INFO(6, Elu),
KERNEL_CREATE_INFO_VERSIONED(6, 12, Relu),
KERNEL_CREATE_INFO_VERSIONED(13, 13, Relu),
KERNEL_CREATE_INFO(14, Relu),
KERNEL_CREATE_INFO_VERSIONED(6, 15, LeakyRelu),
KERNEL_CREATE_INFO(16, LeakyRelu),
KERNEL_CREATE_INFO(10, ThresholdedRelu),
// binary - math
KERNEL_CREATE_INFO_VERSIONED(7, 12, Add),

View file

@ -116,5 +116,17 @@ ONNX_OPERATOR_KERNEL_EX(Clip, kOnnxDomain, 13, kJsExecutionProvider,
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(Elu, Elu, alpha, 1.0)
JSEP_ELEMENTWISE_KERNEL(Elu, 6, float, Elu)
JSEP_KERNEL_IMPL(Relu, Relu)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, float, Relu)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, float, Relu)
JSEP_ELEMENTWISE_KERNEL(Relu, 14, float, Relu)
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(LeakyRelu, LeakyRelu, alpha, 0.01)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, float, LeakyRelu)
JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, float, LeakyRelu)
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(ThresholdedRelu, ThresholdedRelu, alpha, 1.0)
JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, float, ThresholdedRelu)
} // namespace js
} // namespace onnxruntime