diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index ed174bc098..286dfaf968 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -17,7 +17,7 @@ import {gather, parseGatherAttributes} from './ops/gather'; import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements'; import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm'; -import {layerNorm} from './ops/layer-norm'; +import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; import {matMul} from './ops/matmul'; import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; import {pad} from './ops/pad'; @@ -83,7 +83,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], ['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]], - ['LayerNormalization', [layerNorm]], + ['LayerNormalization', [layerNorm, parseLayerNormAttributes]], ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]], ['Less', [binaryOps.less]], ['LessOrEqual', [binaryOps.lessOrEqual]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 3a84844544..056dd54d54 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -4,58 +4,56 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common'; +import {createTensorShapeVariables, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; -export interface InstanceNormAttributes extends AttributeWithCacheKey { +export interface InstanceNormAttributes { epsilon: number; format: 'NHWC'|'NCHW'; } -const metadata = { - name: 'InstanceNormalization' -}; - const createInstanceNormProgramInfo = (inputs: readonly TensorView[], attributes: InstanceNormAttributes): ProgramInfo => { const xShape = inputs[0].dims; - const outputShape = xShape; const axis = 2; const normCount = ShapeUtil.sizeToDimension(xShape, axis); const normSize = ShapeUtil.sizeFromDimension(xShape, axis); const components = getMaxComponents(normSize); const normPackedSize = normSize / components; - const C = xShape[1]; - const x = inputVariable('x', inputs[0].dataType, [xShape[0], xShape[1], normPackedSize], components); - const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims); - const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims); - const output = outputVariable('output', inputs[0].dataType, [xShape[0], xShape[1], normPackedSize], components); - const variables = [x, scale, bias, output]; - const dataType = x.type.value; - const f32Type = components === 1 ? 'f32' : `vec${components}`; - const workgroupSize = 64; - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const inputShape = [xShape[0], xShape[1], normPackedSize]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; + const programUniforms: ProgramUniform[] = + [{type: 'uint32', data: normSize}, {type: 'uint32', data: normPackedSize}]; + programUniforms.push(...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape)); - const C: u32 = ${C}; - const normSize: u32 = ${normSize}; - const epsilon: f32 = ${attributes.epsilon}; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const x = inputVariable('x', inputs[0].dataType, inputShape.length, components); + const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims); + const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims); + const output = outputVariable('output', inputs[0].dataType, inputShape.length, components); + const variables = [x, scale, bias, output]; + const dataType = x.type.value; + const f32Type = components === 1 ? 'f32' : `vec${components}`; + const workgroupSize = 64; + + const uniforms: UniformsArrayType = [{name: 'normSize', type: 'u32'}, {name: 'normPackedSize', type: 'u32'}]; + return ` var meanShared : f32; var squaredNormShared : f32; var workgroupShared : array<${f32Type}, ${workgroupSize}>; const workgroupSize = ${workgroupSize}u; - ${shaderHelper.declareVariables(...variables)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)} ${shaderHelper.mainStart(workgroupSize)} let norm = global_idx / workgroupSize; - let batch = norm / C; - let channel = norm % C; + let batch = norm / uniforms.x_shape[1]; + let channel = norm % uniforms.x_shape[1]; let localIndex = local_id.x; // initialize workgroup memory var initial = ${f32Type}(0); - for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) { + for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { initial = initial + ${f32Type}(${x.get('batch', 'channel', 'h')}); } workgroupShared[localIndex] = initial; @@ -69,13 +67,13 @@ const createInstanceNormProgramInfo = workgroupBarrier(); } if (localIndex == 0) { - meanShared = ${sumVector('workgroupShared[0]', components)} / f32(normSize); + meanShared = ${sumVector('workgroupShared[0]', components)} / f32(uniforms.normSize); } workgroupBarrier(); // reinitialize workgroup memory. initial = ${f32Type}(0); - for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) { + for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { let deviation = ${f32Type}(${x.get('batch', 'channel', 'h')}) - ${f32Type}(meanShared); initial = initial + deviation * deviation; } @@ -94,23 +92,26 @@ const createInstanceNormProgramInfo = } workgroupBarrier(); - let invStdDev = 1 / sqrt(squaredNormShared / f32(normSize) + epsilon); + let invStdDev = 1 / sqrt(squaredNormShared / f32(uniforms.normSize) + f32(${attributes.epsilon})); let channelScale = invStdDev * f32(${scale.getByOffset('channel')}); let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale; - for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) { + for (var h = localIndex; h < uniforms.normPackedSize; h += workgroupSize) { let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${ - f32Type}(channelShift)); + f32Type}(channelShift)); ${output.set('batch', 'channel', 'h', 'value')}; } }`; + }; return { - ...metadata, - shaderCache: {hint: attributes.cacheKey}, + ...{name: 'InstanceNormalization'}, + // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. + shaderCache: {hint: `${attributes.epsilon};${components}`, inputDependencies}, getRunData: () => ({ outputs: [ {dims: outputShape, dataType: inputs[0].dataType}, ], - dispatchGroup: {x: normCount} + dispatchGroup: {x: normCount}, + programUniforms }), getShaderSource, }; @@ -120,10 +121,6 @@ const computeMean = (context: ComputeContext, input: TensorView, scale: TensorView, bias: TensorView, n: number, h: number, c: number, epsilon: number) => { const components = getMaxComponents(c); - const inputHelper = inputVariable('input', input.dataType, input.dims, components); - const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components); - const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components); - const WG = 64; // we will store channel scale and channel shift in [2, components] matrix // or in vec2 when components == 1 @@ -133,65 +130,79 @@ const computeMean = const unitsOfWork = n * c / components; const wgSize = Math.ceil(h / WG); - const getMeanShaderSource = (shaderHelper: ShaderHelper) => ` - const H: u32 = ${h}; - const C: u32 = ${c / components}; - const imageSize: u32 = ${h * c / components}; + const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type']; + const meanProgramUniforms: ProgramUniform[] = [ + {type: 'uint32', data: wgSize}, {type: 'uint32', data: h}, {type: 'uint32', data: Math.floor(c / components)}, + {type: 'uint32', data: Math.floor(h * c / components)} + ]; + const getMeanShaderSource = (shaderHelper: ShaderHelper) => { + const inputHelper = inputVariable('input', input.dataType, input.dims, components); + return ` ${shaderHelper.declareVariables(inputHelper)} @group(0) @binding(1) var output : array<${outputType}>; + struct Uniforms {wg_size:u32, H:u32, C:u32, image_size:u32}; + @group(0) @binding(2) var uniforms: Uniforms; ${shaderHelper.mainStart(WG)} - let currentImageNumber = global_idx / ${WG} / C; - let currentChannelNumber = (global_idx / ${WG}) % C; + let currentImageNumber = global_idx / ${WG} / uniforms.C; + let currentChannelNumber = (global_idx / ${WG}) % uniforms.C; let wgId = global_idx % ${WG}; - let wgOffset = wgId * ${wgSize}; - if (wgOffset >= H) { + let wgOffset = wgId * uniforms.wg_size; + if (wgOffset >= uniforms.H) { return; } - let wgMax = min(wgOffset + ${wgSize}, H); + let wgMax = min(wgOffset + uniforms.wg_size, uniforms.H); - let offset = currentImageNumber * imageSize + currentChannelNumber; + let offset = currentImageNumber * uniforms.image_size + currentChannelNumber; var sum = ${fillVector('f32', components)}; var squaredSum = ${fillVector('f32', components)}; for (var i: u32 = wgOffset; i < wgMax; i++) { - let value = ${sumCastType}(input[offset + i * C]); + let value = ${sumCastType}(input[offset + i * uniforms.C]); sum += value; squaredSum += value * value; } output[global_idx] = ${setOutputValue('sum', 'squaredSum')}; }`; + }; const meanValues = context.compute( { name: 'InstanceNormComputeMean', - shaderCache: {hint: JSON.stringify({components, n, h, c})}, + shaderCache: {hint: `${components}`, inputDependencies: meanInputDependencies}, getRunData: () => ({ outputs: [ {dims: [n, c, WG, 2], dataType: DataType.float}, ], dispatchGroup: {x: n * c / components}, + programUniforms: meanProgramUniforms }), getShaderSource: getMeanShaderSource, }, {inputs: [input], outputs: [-1]})[0]; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const H: u32 = ${h}; - const C: u32 = ${c / components}; - const imageSize: u32 = ${WG * c / components}; - const epsilon: f32 = ${epsilon}; + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: unitsOfWork}, {type: 'uint32', data: h}, + {type: 'uint32', data: Math.floor(c / components)}, {type: 'uint32', data: Math.floor(WG * c / components)} + ]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type']; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const scaleHelper = inputVariable('scale', scale.dataType, scale.dims, components); + const biasHelper = inputVariable('bias', bias.dataType, bias.dims, components); + return ` @group(0) @binding(0) var input : array<${outputType}>; @group(0) @binding(1) var scale : array<${scaleHelper.type.storage}>; @group(0) @binding(2) var bias : array<${biasHelper.type.storage}>; @group(0) @binding(3) var output : array<${outputType}>; + struct Uniforms {units_of_work : u32, H: u32, C : u32, image_size : u32}; + @group(0) @binding(4) var uniforms: Uniforms; ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(unitsOfWork)} - let currentImageNumber = global_idx / C; - let currentChannelNumber = global_idx % C; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.units_of_work')} + let currentImageNumber = global_idx / uniforms.C; + let currentChannelNumber = global_idx % uniforms.C; - let offset = currentImageNumber * imageSize; + let offset = currentImageNumber * uniforms.image_size; var sum = ${fillVector('f32', components)}; var squaredSum = ${fillVector('f32', components)}; for (var i: u32 = 0; i < ${WG}; i++) { @@ -199,24 +210,26 @@ const computeMean = sum += value[0]; squaredSum += value[1]; } - sum = sum / f32(H); - squaredSum = squaredSum / f32(H); - let invStdDev = 1 / sqrt(squaredSum - sum * sum + epsilon); + sum = sum / f32(uniforms.H); + squaredSum = squaredSum / f32(uniforms.H); + let invStdDev = 1 / sqrt(squaredSum - sum * sum + f32(${epsilon})); let channelScale = invStdDev * ${sumCastType}(scale[currentChannelNumber]); let channelShift = ${sumCastType}(bias[currentChannelNumber]) - sum * channelScale; output[global_idx] = ${setOutputValue('channelScale', 'channelShift')}; }`; - + }; return context.compute( { name: 'InstanceNormComputeChannelScaleShift', - shaderCache: {hint: JSON.stringify({components, n, h, c, epsilon})}, + // TODO: use epsilon as uniform. Currently epsilon as uniform fails test_instancenorm_epsilon. + shaderCache: {hint: `${components};${epsilon}`, inputDependencies}, getRunData: () => ({ outputs: [ {dims: [n, c, 2], dataType: DataType.float}, ], dispatchGroup: {x: Math.ceil(unitsOfWork / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }, @@ -230,50 +243,51 @@ const createInstanceNormNHWCProgramInfo = const N = xShape[0]; const C = xShape[xShape.length - 1]; const H = ShapeUtil.sizeFromDimension(xShape, 1) / C; - const components = getMaxComponents(C); const outputSize = ShapeUtil.size(outputShape) / components; - const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components); - const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components); - - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`; - const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`; + const programUniforms: ProgramUniform[] = + [{type: 'uint32', data: H}, {type: 'uint32', data: Math.floor(C / components)}]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; // first compute mean const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const scaleType = components === 1 ? 'vec2f' : `mat2x${components}f`; + const scaleCastType = components === 1 ? dataType : `vec${components}<${dataType}>`; - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const H: u32 = ${H}; - const C: u32 = ${C / components}; + const inputHelper = inputVariable('input', inputs[0].dataType, inputs[0].dims, components); + const outputHelper = outputVariable('output', inputs[0].dataType, outputShape, components); + return ` @group(0) @binding(0) var input : array<${inputHelper.type.storage}>; @group(0) @binding(1) var scaleInput : array<${scaleType}>; @group(0) @binding(2) var output : array<${outputHelper.type.storage}>; + struct Uniforms {H: u32, C : u32}; + @group(0) @binding(3) var uniforms: Uniforms; ${shaderHelper.mainStart()} - let currentImageNumber = global_idx / (C * H); - let currentChannelNumber = global_idx % C; + let currentImageNumber = global_idx / (uniforms.C * uniforms.H); + let currentChannelNumber = global_idx % uniforms.C; - let scaleOffset = currentImageNumber * C + currentChannelNumber; + let scaleOffset = currentImageNumber * uniforms.C + currentChannelNumber; let scale = scaleInput[scaleOffset]; output[global_idx] = fma(input[global_idx], ${scaleCastType}(scale[0]), ${scaleCastType}(scale[1])); }`; + }; context.compute( { - name: 'InstanceNormalization', - shaderCache: {hint: `${attributes.cacheKey}`}, + name: 'InstanceNormalizationNHWC', + shaderCache: {hint: `${components}`, inputDependencies}, getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)} + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms }), getShaderSource, }, {inputs: [inputs[0], channelScaleShift]}); }; -export const parseInstanceNormAttributes = (attributes: InstanceNormAttributes): InstanceNormAttributes => - createAttributeWithCacheKey({epsilon: attributes.epsilon, format: attributes.format}); - export const instanceNorm = (context: ComputeContext, attributes: InstanceNormAttributes): void => { if (attributes.format === 'NHWC') { createInstanceNormNHWCProgramInfo(context, context.inputs, attributes); diff --git a/js/web/test/data/ops/instance-norm.jsonc b/js/web/test/data/ops/instance-norm.jsonc index 6a4e691240..e89ac2da37 100644 --- a/js/web/test/data/ops/instance-norm.jsonc +++ b/js/web/test/data/ops/instance-norm.jsonc @@ -38,6 +38,79 @@ } ] }, + { + "name": "Simple test with NHWC, components 1", + "operator": "InstanceNormalization", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5], + "dims": [1, 5, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + }, + { + "data": [4, 5, 6, 7, 8], + "dims": [5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.775264263153076, 4, 5.224735260009766, 2.5505285263061523, 5, 7.449470520019531, 2.325794219970703, 6, + 9.674205780029297, 11.898944854736328, 7, 2.1010589599609375, 14.123676300048828, 8, 1.876321792602539 + ], + "dims": [1, 5, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Simple test with NHWC, components 2", + "operator": "InstanceNormalization", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8], + "dims": [2, 6, 1, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6], + "dims": [6], + "type": "float32" + }, + { + "data": [4, 5, 6, 7, 8, 9], + "dims": [6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [4, 5, 6, 7, 8, 9, 4, 5, 6, 7, 8, 9], + "dims": [2, 6, 1, 1], + "type": "float32" + } + ] + } + ] + }, { "name": "Simple test with NCHW", "operator": "InstanceNormalization", @@ -75,5 +148,81 @@ ] } ] + }, + { + "name": "Simple test with NCHW, components 1", + "operator": "InstanceNormalization", + "opset": { "domain": "", "version": 17 }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5], + "dims": [1, 5, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5], + "dims": [5], + "type": "float32" + }, + { + "data": [4, 5, 6, 7, 8], + "dims": [5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.775264263153076, 4, 5.224735260009766, 2.5505285263061523, 5, 7.449470520019531, 2.325794219970703, 6, + 9.674205780029297, 11.898944854736328, 7, 2.1010589599609375, 14.123676300048828, 8, 1.876321792602539 + ], + "dims": [1, 5, 3, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Simple test with NCHW, components 2", + "operator": "InstanceNormalization", + "opset": { "domain": "", "version": 17 }, + "cases": [ + { + "name": "Simple test", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4, 3, 2], + "dims": [1, 3, 6, 1], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + }, + { + "data": [4, 5, 6], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.5361523628234863, 3.1216912269592285, 3.70723032951355, 4.292769432067871, 4.878308296203613, + 5.4638471603393555, 1.8666191101074219, 3.9555397033691406, 6.044460296630859, 8.133380889892578, + 6.044460296630859, 3.9555397033691406, 10.3915433883667, 8.634925842285156, 6.878308296203613, + 5.121691703796387, 3.365074634552002, 1.6084575653076172 + ], + "dims": [1, 3, 6, 1], + "type": "float32" + } + ] + } + ] } ]