mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[js/webgpu] add implementation of Relu, LeakyRelu and ThresholdedRelu (#15668)
### Description add implementation of Relu, LeakyRelu and ThresholdedRelu
This commit is contained in:
parent
76ddc92fbd
commit
a02c885f86
5 changed files with 57 additions and 12 deletions
|
|
@ -32,12 +32,13 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Cos', [unaryOps.cos]],
|
||||
['Cosh', [unaryOps.cosh]],
|
||||
['Div', [binaryOps.div]],
|
||||
['Elu', [unaryOps.elu, unaryOps.parseEluAttributes]],
|
||||
['Elu', [unaryOps.elu, unaryOps.parseAlphaAttributes]],
|
||||
['Erf', [unaryOps.erf]],
|
||||
['Floor', [unaryOps.floor]],
|
||||
['Gemm', [gemm, parseGemmAttributes]],
|
||||
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
|
||||
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
|
||||
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
|
||||
['MatMul', [matMul]],
|
||||
// TODO: support new attributes for MaxPool-8 and MaxPool-10
|
||||
['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]],
|
||||
|
|
@ -45,6 +46,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Neg', [unaryOps.neg]],
|
||||
['Pow', [binaryOps.pow]],
|
||||
['Reciprocal', [unaryOps.reciprocal]],
|
||||
['Relu', [unaryOps.relu]],
|
||||
['Sigmoid', [unaryOps.sigmoid]],
|
||||
['Sin', [unaryOps.sin]],
|
||||
['Sinh', [unaryOps.sinh]],
|
||||
|
|
@ -52,5 +54,6 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['Sub', [binaryOps.sub]],
|
||||
['Tan', [unaryOps.tan]],
|
||||
['Tanh', [unaryOps.tanh]],
|
||||
['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]],
|
||||
['Transpose', [transpose, parseTransposeAttributes]],
|
||||
]);
|
||||
|
|
|
|||
|
|
@ -123,11 +123,14 @@ export const cosh = (context: ComputeContext): void => {
|
|||
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Cosh', 'cosh'));
|
||||
};
|
||||
|
||||
export interface EluAttributes extends AttributeWithCacheKey {
|
||||
export interface AlphaAttributes extends AttributeWithCacheKey {
|
||||
readonly alpha: number;
|
||||
}
|
||||
|
||||
export const elu = (context: ComputeContext, attributes: EluAttributes): void => {
|
||||
export const parseAlphaAttributes = (attributes: Record<string, unknown>): AlphaAttributes =>
|
||||
createAttributeWithCacheKey(attributes as {alpha: number});
|
||||
|
||||
export const elu = (context: ComputeContext, attributes: AlphaAttributes): void => {
|
||||
context.compute(createElementwiseProgramInfoLoader(
|
||||
context.inputs[0], 'Elu', a => `elu_vf32(${a})`, `
|
||||
const elu_alpha_: f32 = f32(${attributes.alpha});
|
||||
|
|
@ -142,9 +145,6 @@ export const elu = (context: ComputeContext, attributes: EluAttributes): void =>
|
|||
attributes.cacheKey));
|
||||
};
|
||||
|
||||
export const parseEluAttributes = (attributes: Record<string, unknown>): EluAttributes =>
|
||||
createAttributeWithCacheKey(attributes as {alpha: number});
|
||||
|
||||
export const erf = (context: ComputeContext): void => {
|
||||
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, `
|
||||
const r0: f32 = 0.3275911;
|
||||
|
|
@ -165,6 +165,12 @@ export const floor = (context: ComputeContext): void => {
|
|||
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Floor', 'floor'));
|
||||
};
|
||||
|
||||
export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
|
||||
context.compute(createElementwiseProgramInfoLoader(
|
||||
context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<f32>(0.0))`,
|
||||
`const leaky_relu_alpha_: f32 = f32(${attributes.alpha});`));
|
||||
};
|
||||
|
||||
export const neg = (context: ComputeContext): void => {
|
||||
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Neg', a => `-${a}`));
|
||||
};
|
||||
|
|
@ -173,6 +179,11 @@ export const reciprocal = (context: ComputeContext): void => {
|
|||
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Reciprocal', a => `1.0/${a}`));
|
||||
};
|
||||
|
||||
export const relu = (context: ComputeContext): void => {
|
||||
context.compute(createElementwiseProgramInfoLoader(
|
||||
context.inputs[0], 'Relu', a => `select(vec4<f32>(0.0), ${a}, ${a} > vec4<f32>(0.0))`));
|
||||
};
|
||||
|
||||
export const sigmoid = (context: ComputeContext): void => {
|
||||
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`));
|
||||
};
|
||||
|
|
@ -196,3 +207,10 @@ export const tan = (context: ComputeContext): void => {
|
|||
export const tanh = (context: ComputeContext): void => {
|
||||
context.compute(createElementwiseProgramInfoLoader(context.inputs[0], 'Tanh', 'tanh'));
|
||||
};
|
||||
|
||||
export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => {
|
||||
context.compute(createElementwiseProgramInfoLoader(
|
||||
context.inputs[0], 'ThresholdedRelu', a => `select(vec4<f32>(0.0), ${a}, ${a} > thresholded_relu_alpha_)`,
|
||||
`const thresholded_relu_alpha_: vec4<f32> = vec4<f32>(${attributes.alpha});`));
|
||||
return 0;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -644,9 +644,9 @@
|
|||
// // "test_layer_normalization_4d_axis3",
|
||||
// // "test_layer_normalization_default_axis_expanded",
|
||||
// // "test_layer_normalization_default_axis",
|
||||
// "test_leakyrelu_default",
|
||||
// "test_leakyrelu_example",
|
||||
// "test_leakyrelu",
|
||||
"test_leakyrelu_default",
|
||||
"test_leakyrelu_example",
|
||||
"test_leakyrelu",
|
||||
// "test_less_bcast",
|
||||
// "test_less_equal_bcast_expanded",
|
||||
// "test_less_equal_bcast",
|
||||
|
|
@ -1254,9 +1254,9 @@
|
|||
// // "test_tfidfvectorizer_tf_onlybigrams_levelempty",
|
||||
// // "test_tfidfvectorizer_tf_onlybigrams_skip5",
|
||||
// // "test_tfidfvectorizer_tf_uniandbigrams_skip5",
|
||||
// "test_thresholdedrelu_default",
|
||||
// "test_thresholdedrelu_example",
|
||||
// "test_thresholdedrelu",
|
||||
"test_thresholdedrelu_default",
|
||||
"test_thresholdedrelu_example",
|
||||
"test_thresholdedrelu",
|
||||
// // "test_tile_precomputed",
|
||||
// // "test_tile",
|
||||
// // "test_top_k_negative_axis",
|
||||
|
|
|
|||
|
|
@ -109,6 +109,12 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, Clip);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Clip);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, Elu);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Relu);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Relu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Relu);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 15, LeakyRelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, LeakyRelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, ThresholdedRelu);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 12, Add);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 13, Add);
|
||||
|
|
@ -211,6 +217,12 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
KERNEL_CREATE_INFO_VERSIONED(12, 12, Clip),
|
||||
KERNEL_CREATE_INFO(13, Clip),
|
||||
KERNEL_CREATE_INFO(6, Elu),
|
||||
KERNEL_CREATE_INFO_VERSIONED(6, 12, Relu),
|
||||
KERNEL_CREATE_INFO_VERSIONED(13, 13, Relu),
|
||||
KERNEL_CREATE_INFO(14, Relu),
|
||||
KERNEL_CREATE_INFO_VERSIONED(6, 15, LeakyRelu),
|
||||
KERNEL_CREATE_INFO(16, LeakyRelu),
|
||||
KERNEL_CREATE_INFO(10, ThresholdedRelu),
|
||||
|
||||
// binary - math
|
||||
KERNEL_CREATE_INFO_VERSIONED(7, 12, Add),
|
||||
|
|
|
|||
|
|
@ -116,5 +116,17 @@ ONNX_OPERATOR_KERNEL_EX(Clip, kOnnxDomain, 13, kJsExecutionProvider,
|
|||
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(Elu, Elu, alpha, 1.0)
|
||||
JSEP_ELEMENTWISE_KERNEL(Elu, 6, float, Elu)
|
||||
|
||||
JSEP_KERNEL_IMPL(Relu, Relu)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, float, Relu)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, float, Relu)
|
||||
JSEP_ELEMENTWISE_KERNEL(Relu, 14, float, Relu)
|
||||
|
||||
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(LeakyRelu, LeakyRelu, alpha, 0.01)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, float, LeakyRelu)
|
||||
JSEP_ELEMENTWISE_KERNEL(LeakyRelu, 16, float, LeakyRelu)
|
||||
|
||||
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_DEFAULT(ThresholdedRelu, ThresholdedRelu, alpha, 1.0)
|
||||
JSEP_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, float, ThresholdedRelu)
|
||||
|
||||
} // namespace js
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue