fix for webgl lrn (#15236)

fix issue that resulted in wrong results for lrn on webgpu
This commit is contained in:
Guenther Schmuelling 2023-03-30 16:16:57 -07:00 committed by GitHub
parent 9f942e1a3e
commit 4645726d74
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -42,24 +42,14 @@ const lrnProgramMetadata = {
inputTypes: [TextureType.unpacked]
};
function getOutputExpression(attributes: LrnAttributes): string {
let expression = `float(${attributes.bias}) + float(${attributes.alpha}) * square_sum`;
if (attributes.beta === 0.5) {
expression = `inversesqrt(${expression})`;
} else if (attributes.beta === 1.0) {
expression = `1.0/(${expression})`;
} else {
expression = `exp(log(${expression})) * float(-${attributes.beta})`;
}
return `x * ${expression}`;
}
function createLrnProgramInfo(inputs: Tensor[], attributes: LrnAttributes): ProgramInfo {
const C = inputs[0].dims[1];
const rank = inputs[0].dims.length;
const from = -Math.floor((attributes.size - 1) / 2);
const to = Math.ceil((attributes.size - 1) / 2);
const alpha = `float(${attributes.alpha}) / float(${attributes.size})`;
const bias = `float(${attributes.bias})`;
const beta = `float(${attributes.beta})`;
const shaderSource = `
float process(int indices[${rank}]) {
@ -75,8 +65,7 @@ function createLrnProgramInfo(inputs: Tensor[], attributes: LrnAttributes): Prog
square_sum += j * j;
}
}
return ${getOutputExpression(attributes)};
return x / pow(${bias} + ${alpha} * square_sum, ${beta});
}`;
return {
...lrnProgramMetadata,