diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index c40229cde9..f2057df533 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -868,6 +868,7 @@ class ShaderHelperImpl implements ShaderHelper { const paramList = is1DimensionDispatch ? `@builtin(global_invocation_id) global_id : vec3, @builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_index) local_idx : u32, @builtin(local_invocation_id) local_id : vec3` : `@builtin(global_invocation_id) global_id : vec3, @builtin(local_invocation_id) local_id : vec3, @@ -876,7 +877,6 @@ class ShaderHelperImpl implements ShaderHelper { @builtin(num_workgroups) num_workgroups : vec3`; const globalIdxDefinition = is1DimensionDispatch ? `let global_idx = global_id.x; - let local_idx = local_id.x; let workgroup_index = workgroup_id.x;` : `let workgroup_index = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts index 3f4617014e..3e1f1be22e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -266,9 +266,185 @@ export const createMatMulNBitsProgramInfo = ( }; }; +// Currently, only support blockSize = 32. +export const createMatMulNBitsBlockSize32ProgramInfo = ( + inputs: readonly TensorView[], + attributes: MatMulNBitsAttributes, +): ProgramInfo => { + const inputShape = inputs[0].dims; + const aRank = inputShape.length; + const dimAOuter = inputShape[aRank - 2]; + const dimInner = attributes.k; + const dimBOuter = attributes.n; + const batchDims = inputShape.slice(0, aRank - 2); + const batchSize = ShapeUtil.size(batchDims); + const blobSize = inputs[1].dims[2]; + const blobSizeInWords = blobSize / 4; + const dataType = inputs[0].dataType; + const aComponents = getMaxComponents(attributes.k); + const bComponents = getMaxComponents(blobSizeInWords); + const outputShape = batchDims.concat([dimAOuter, dimBOuter]); + + const workgroupSize = 128; + const workgroupY = dimBOuter % 8 === 0 ? 8 : dimBOuter % 4 === 0 ? 4 : 1; + const workgroupX = workgroupSize / workgroupY; + const tileSize = workgroupX * bComponents * 8; // each uint32 has 8 data. + const aLengthPerTile = tileSize / aComponents; + const blocksPerTile = tileSize / attributes.blockSize; + const dispatchSize = ShapeUtil.size(outputShape) / workgroupY; + + const programUniforms: ProgramUniform[] = []; + const inputShapeTemp = [batchSize, dimAOuter, dimInner / aComponents]; + const bShape = ShapeUtil.convertShape(inputs[1].dims).slice(); + bShape.splice(-1, 1, blobSizeInWords / bComponents); + programUniforms.push(...createTensorShapeVariables(inputShapeTemp)); + programUniforms.push(...createTensorShapeVariables(bShape)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + if (inputs.length === 4) { + programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims))); + } + const outputShapeTemp = [batchSize, dimAOuter, dimBOuter]; + programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const inputRank = inputShapeTemp.length; + const a = inputVariable('a', inputs[0].dataType, inputRank, aComponents); + const b = inputVariable('b', DataType.uint32, bShape.length, bComponents); + const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length); + const inputVariables = [a, b, scales]; + const zeroPoints = + inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined; + if (zeroPoints) { + inputVariables.push(zeroPoints); + } + const outputRank = outputShapeTemp.length; + const output = outputVariable('output', inputs[0].dataType, outputRank); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const readA = () => { + switch (aComponents) { + case 1: + return ` + let a_data0 = vec4<${dataType}>(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]); + let a_data1 = vec4<${dataType}>(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]);`; + case 2: + return ` + let a_data0 = vec4<${dataType}>(sub_a[word_offset], sub_a[word_offset + 1]); + let a_data1 = vec4<${dataType}>(sub_a[word_offset + 2], sub_a[word_offset + 3]);`; + case 4: + return ` + let a_data0 = sub_a[word_offset]; + let a_data1 = sub_a[word_offset + 1];`; + default: + throw new Error(`${aComponents}-component is not supported.`); + } + }; + + return ` + var sub_a: array<${a.type.value}, ${aLengthPerTile}>; + var inter_results: array, ${workgroupY}>; + ${shaderHelper.declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart([workgroupX, workgroupY, 1])} + let output_indices = ${output.offsetToIndices(`workgroup_index * ${workgroupY}`)}; + let col = output_indices[2]; + let row = output_indices[1]; + let batch = output_indices[0]; + let n_blocks_per_col = uniforms.b_shape[1]; + let num_tiles = (n_blocks_per_col - 1) / ${blocksPerTile} + 1; + + // Loop over shared dimension. + for (var tile: u32 = 0; tile < num_tiles; tile += 1) { + let a_col_start = tile * ${aLengthPerTile}; + // load one tile A data into shared memory. + for (var a_offset = local_idx; a_offset < ${aLengthPerTile}; a_offset += ${workgroupSize}) + { + let a_col = a_col_start + a_offset; + if (a_col < uniforms.a_shape[2]) + { + sub_a[a_offset] = ${a.getByIndices(`${a.type.indices}(batch, row, a_col)`)}; + } else { + sub_a[a_offset] = ${a.type.value}(0); + } + } + workgroupBarrier(); + + // each thread process one block + let b_row = col + local_id.y; + let block = tile * ${blocksPerTile} + local_id.x; + ${ + zeroPoints + ? ` + let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2; + let zero_point_byte_count = b_row * zero_point_bytes_per_col + (block >> 0x1u); + let zero_point_word_index = zero_point_byte_count >> 0x2u; + let zero_point_byte_offset = zero_point_byte_count & 0x3u; + let zero_point_nibble_offset: u32 = block & 0x1u; + let zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2); + let zero_point_word = ${zeroPoints.getByOffset('zero_point_word_index')} >> zero_point_bits_offset; + let zero_point = ${dataType}((zero_point_word) & 0xFu);` + : ` + // The default zero point is 8 for unsigned 4-bit quantization. + let zero_point = ${dataType}(${8.0});` + } + let scale = ${scales.getByOffset(`b_row * n_blocks_per_col + block`)}; + let b_data = ${b.getByIndices(`${b.type.indices}(b_row, block, 0)`)}; + var word_offset = local_id.x * ${attributes.blockSize / aComponents}; + for (var i: u32 = 0; i < ${bComponents}; i++) { + ${readA()} + let b_value = ${bComponents === 1 ? `b_data` : `b_data[i]`}; + let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu); + let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu); + let b_quantized_values = mat2x4<${dataType}>(${Array.from( + { length: 4 }, + (_, i) => `${dataType}(b_value_lower[${i}]), ${dataType}(b_value_upper[${i}])`, + ).join(', ')}); + let b_dequantized_values = (b_quantized_values - mat2x4<${dataType}>(${Array(8).fill('zero_point').join(',')})) * scale; + inter_results[local_id.y][local_id.x] += ${Array.from( + { length: 2 }, + (_, i) => `${`dot(a_data${i}, b_dequantized_values[${i}])`}`, + ).join(' + ')}; + word_offset += ${8 / aComponents}; + } + workgroupBarrier(); + } + + if (local_idx < ${workgroupY}) { + var output_value: ${output.type.value} = ${output.type.value}(0); + for (var b = 0u; b < ${workgroupX}; b++) { + output_value += inter_results[local_idx][b]; + } + if (col + local_idx < uniforms.output_shape[2]) + { + ${output.setByIndices(`${output.type.indices}(batch, row, col + local_idx)`, 'output_value')} + } + } + }`; + }; + return { + name: 'BlockwiseMatMulNBits32', + shaderCache: { + hint: `${attributes.blockSize};${aComponents};${bComponents};${workgroupX};${workgroupY}`, + inputDependencies: Array(inputs.length).fill('rank'), + }, + getRunData: () => ({ + outputs: [{ dims: outputShape, dataType }], + dispatchGroup: { x: dispatchSize }, + programUniforms, + }), + getShaderSource, + }; +}; + export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => { validateInputs(context.inputs, attributes); - context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes)); + if ( + attributes.blockSize === 32 && + context.adapterInfo.isVendor('intel') && + context.adapterInfo.isArchitecture('gen-12lp') + ) { + context.compute(createMatMulNBitsBlockSize32ProgramInfo(context.inputs, attributes)); + } else { + context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes)); + } }; export const parseMatMulNBitsAttributes = (attributes: Record): MatMulNBitsAttributes => diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 776263b143..3b3c55733c 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -15,7 +15,7 @@ export enum GpuDataType { } export type GpuDataId = number; -export type GpuArchitecture = 'ampere'; +export type GpuArchitecture = 'ampere' | 'gen-12lp'; export type GpuVendor = 'amd' | 'intel' | 'nvidia'; export interface AdapterInfo { isArchitecture: (architecture: GpuArchitecture) => boolean;