From 4645726d7478e3013e781230df4eb48676dca7da Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Thu, 30 Mar 2023 16:16:57 -0700 Subject: [PATCH] fix for webgl lrn (#15236) fix issue that resulted in wrong results for lrn on webgpu --- js/web/lib/onnxjs/backends/webgl/ops/lrn.ts | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/js/web/lib/onnxjs/backends/webgl/ops/lrn.ts b/js/web/lib/onnxjs/backends/webgl/ops/lrn.ts index 547695844f..21dae1200e 100644 --- a/js/web/lib/onnxjs/backends/webgl/ops/lrn.ts +++ b/js/web/lib/onnxjs/backends/webgl/ops/lrn.ts @@ -42,24 +42,14 @@ const lrnProgramMetadata = { inputTypes: [TextureType.unpacked] }; -function getOutputExpression(attributes: LrnAttributes): string { - let expression = `float(${attributes.bias}) + float(${attributes.alpha}) * square_sum`; - if (attributes.beta === 0.5) { - expression = `inversesqrt(${expression})`; - } else if (attributes.beta === 1.0) { - expression = `1.0/(${expression})`; - } else { - expression = `exp(log(${expression})) * float(-${attributes.beta})`; - } - return `x * ${expression}`; -} - function createLrnProgramInfo(inputs: Tensor[], attributes: LrnAttributes): ProgramInfo { const C = inputs[0].dims[1]; - const rank = inputs[0].dims.length; const from = -Math.floor((attributes.size - 1) / 2); const to = Math.ceil((attributes.size - 1) / 2); + const alpha = `float(${attributes.alpha}) / float(${attributes.size})`; + const bias = `float(${attributes.bias})`; + const beta = `float(${attributes.beta})`; const shaderSource = ` float process(int indices[${rank}]) { @@ -75,8 +65,7 @@ function createLrnProgramInfo(inputs: Tensor[], attributes: LrnAttributes): Prog square_sum += j * j; } } - - return ${getOutputExpression(attributes)}; + return x / pow(${bias} + ${alpha} * square_sum, ${beta}); }`; return { ...lrnProgramMetadata,