From 8f7b89bd5bbfce6983dbd1935e7073bad7701921 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Sat, 16 Dec 2023 03:26:15 +0800 Subject: [PATCH] [js/webgpu] Optimize NCHW layout for InstanceNormalization (#18123) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description The changes in this PR includes: 1) Fix f16 errors in InstanceNormalization with NCHW format. 2) Use vec to further optimize the original algorithm. 3) (Removed) Don't do layout conversion for InstanceNormalization for JSEP since InstanceNormalization itself is suitable for NCHW layout and has better performance in our current implementation. Tested on sd-vae-decoder-f16.onnx, it becomes 285 ms from 314 ms. The aggregate gpu profiling data can be found as below (Note the data is based change 3).): Before: Kernel | Time (Ms) | Percentage (%) -- | -- | -- Conv | 201.55 | 69.56 InstanceNormalization | 42.49 | 14.67 Transpose | 28.95 | 9.99 Mul | 5.69 | 1.96 Add | 3.82 | 1.32 MatMul | 3.27 | 1.13 Sigmoid | 2.24 | 0.77 Resize | 1.16 | 0.40 Softmax | 0.34 | 0.12 Cast | 0.24 | 0.08 Sum | 289.75
After: Kernel | Time (Ms) | Percentage (%) -- | -- | -- Conv | 205.44 | 79.43 InstanceNormalization | 18.24 | 7.05 Transpose | 17.64 | 6.82 Mul | 5.69 | 2.20 Add | 3.81 | 1.47 MatMul | 3.56 | 1.38 Sigmoid | 2.24 | 0.86 Resize | 1.19 | 0.46 Softmax | 0.59 | 0.23 Cast | 0.24 | 0.09 Sum | 258.65 |   From above table, we can see that two ops time are greatly reduced. One is InstanceNormalization and the other is Transpose. The reason that the transpose time is reduced is because each InstanceNormalization is surrounded with two reshape ops in sd-vae-decoder-f16.onnx. Due to JSEP is prefer NHWC and InstanceNormalization is layout sensitive op, so two extra transpose ops are inserted dynamically when executing this model. After this change, those inserted transpose ops are not needed anymore. So the overall transpose time is reduced. --- .../lib/wasm/jsep/webgpu/ops/instance-norm.ts | 42 ++++++++++--------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 97f633c7cf..3a84844544 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common'; +import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; export interface InstanceNormAttributes extends AttributeWithCacheKey { epsilon: number; @@ -26,22 +26,25 @@ const createInstanceNormProgramInfo = const axis = 2; const normCount = ShapeUtil.sizeToDimension(xShape, axis); const normSize = ShapeUtil.sizeFromDimension(xShape, axis); + const components = getMaxComponents(normSize); + const normPackedSize = normSize / components; const C = xShape[1]; - const x = inputVariable('x', inputs[0].dataType, [xShape[0], xShape[1], normSize]); + const x = inputVariable('x', inputs[0].dataType, [xShape[0], xShape[1], normPackedSize], components); const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims); const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims); - const output = outputVariable('output', inputs[0].dataType, [xShape[0], xShape[1], normSize]); + const output = outputVariable('output', inputs[0].dataType, [xShape[0], xShape[1], normPackedSize], components); const variables = [x, scale, bias, output]; const dataType = x.type.value; + const f32Type = components === 1 ? 'f32' : `vec${components}`; const workgroupSize = 64; const getShaderSource = (shaderHelper: ShaderHelper) => ` const C: u32 = ${C}; const normSize: u32 = ${normSize}; const epsilon: f32 = ${attributes.epsilon}; - var meanShared : ${dataType}; - var squaredNormShared : ${dataType}; - var workgroupShared : array<${dataType}, ${workgroupSize}>; + var meanShared : f32; + var squaredNormShared : f32; + var workgroupShared : array<${f32Type}, ${workgroupSize}>; const workgroupSize = ${workgroupSize}u; ${shaderHelper.declareVariables(...variables)} ${shaderHelper.mainStart(workgroupSize)} @@ -51,9 +54,9 @@ const createInstanceNormProgramInfo = let localIndex = local_id.x; // initialize workgroup memory - var initial: ${dataType} = 0; - for (var h = localIndex; h < normSize; h += workgroupSize) { - initial = initial + ${x.get('batch', 'channel', 'h')}; + var initial = ${f32Type}(0); + for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) { + initial = initial + ${f32Type}(${x.get('batch', 'channel', 'h')}); } workgroupShared[localIndex] = initial; workgroupBarrier(); @@ -66,14 +69,14 @@ const createInstanceNormProgramInfo = workgroupBarrier(); } if (localIndex == 0) { - meanShared = workgroupShared[0] / ${dataType}(normSize); + meanShared = ${sumVector('workgroupShared[0]', components)} / f32(normSize); } workgroupBarrier(); // reinitialize workgroup memory. - initial = 0; - for (var h = localIndex; h < normSize; h += workgroupSize) { - let deviation = ${x.get('batch', 'channel', 'h')} - meanShared; + initial = ${f32Type}(0); + for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) { + let deviation = ${f32Type}(${x.get('batch', 'channel', 'h')}) - ${f32Type}(meanShared); initial = initial + deviation * deviation; } workgroupShared[localIndex] = initial; @@ -87,15 +90,16 @@ const createInstanceNormProgramInfo = workgroupBarrier(); } if (localIndex == 0) { - squaredNormShared = workgroupShared[0]; + squaredNormShared = ${sumVector('workgroupShared[0]', components)}; } workgroupBarrier(); - let invStdDev = 1 / sqrt(squaredNormShared / ${dataType}(normSize) + epsilon); - let channelScale = invStdDev * ${scale.getByOffset('channel')}; - let channelShift = ${bias.getByOffset('channel')} - meanShared * channelScale; - for (var h = localIndex; h < normSize; h += workgroupSize) { - let value = ${x.get('batch', 'channel', 'h')} * channelScale + channelShift; + let invStdDev = 1 / sqrt(squaredNormShared / f32(normSize) + epsilon); + let channelScale = invStdDev * f32(${scale.getByOffset('channel')}); + let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale; + for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) { + let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${ + f32Type}(channelShift)); ${output.set('batch', 'channel', 'h', 'value')}; } }`;