mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[webgpu] Optimize matmulnbits with M > 1 (#23102)
This is the WebGPU native EP implementation of #23092. I tested it with https://github.com/fs-eire/ort-webgpu-nodejs-chatapp-prototype, with https://github.com/fs-eire/ort-webgpu-nodejs-chatapp-prototype/pull/2 applied to print the first-token time. The results are shown below. The latest main branch: Intel Arc Graphics ``` 659 tokens in 24.8sec, 26.57 tokens/sec Decoding first token with input 449 tokens: 13.0 sec Decoding remaining 210 tokens: 11.8 sec 17.79 tokens/sec ``` NV RTX 2000 ``` 659 tokens in 14.4sec, 45.85 tokens/sec Decoding first token with input 449 tokens: 7.3 sec Decoding remaining 210 tokens: 7.0 sec 29.81 tokens/sec ``` ------------------------------------------------------------------------- With this PR: Intel Arc Graphics ``` 657 tokens in 20.6sec, 31.92 tokens/sec Decoding first token with input 449 tokens: 8.5 sec Decoding remaining 208 tokens: 12.1 sec 17.23 tokens/sec ``` NV RTX 2000 ``` 659 tokens in 11.4sec, 57.93 tokens/sec Decoding first token with input 449 tokens: 4.1 sec Decoding remaining 210 tokens: 7.2 sec 28.98 tokens/sec ``` The data above shows that, with this PR, the first-token time improves on both the Intel (13.0 s -> 8.5 s) and NV (7.3 s -> 4.1 s) GPUs, while the decoding speed for the remaining tokens stays roughly the same.
This commit is contained in:
parent
9115682d69
commit
0981bbf4ca
2 changed files with 152 additions and 243 deletions
|
|
@ -39,7 +39,7 @@ std::string QuantizedDataType(int components) {
|
|||
}
|
||||
}
|
||||
|
||||
constexpr unsigned int kMinSequenceLengthForPrefillOptimization = 16;
|
||||
constexpr unsigned int kMinMForTileOptimization = 4;
|
||||
} // namespace
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
|
|
@ -60,33 +60,59 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
|||
const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform);
|
||||
const auto& y = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
|
||||
|
||||
if (use_block32_) {
|
||||
if ((is_intel_ || tile_m_ > 1) && block_size_ == 32) {
|
||||
const uint32_t workgroup_size = WorkgroupSizeX() * WorkgroupSizeY();
|
||||
const uint32_t tile_size = WorkgroupSizeX() * components_b_ * 8; // each uint32 has 8 data.
|
||||
const uint32_t a_length_per_tile = tile_size / a.NumComponents();
|
||||
constexpr uint32_t block_size = 32;
|
||||
const uint32_t blocks_per_tile = tile_size / block_size;
|
||||
shader.AdditionalImplementation() << "var<workgroup> sub_a: array<input_a_value_t, " << a_length_per_tile << ">;\n"
|
||||
<< "var<workgroup> inter_results: array<array<output_value_t, " << WorkgroupSizeX() << ">, " << WorkgroupSizeY() << ">;\n";
|
||||
std::string offset = "workgroup_idx * " + std::to_string(WorkgroupSizeY());
|
||||
shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n"
|
||||
<< " let col = output_indices[2];\n"
|
||||
" let row = output_indices[1];\n"
|
||||
" let batch = output_indices[0];\n"
|
||||
" let n_blocks_per_col = uniforms.input_b_shape[1];\n"
|
||||
const uint32_t blocks_per_tile = tile_size / block_size_;
|
||||
if (tile_m_ == 1) {
|
||||
shader.AdditionalImplementation() << "fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {\n"
|
||||
" if (col < uniforms.input_a_shape[2]) {\n"
|
||||
<< " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n"
|
||||
<< " } else {\n"
|
||||
" return input_a_value_t(0);\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
<< "var<workgroup> sub_a: array<input_a_value_t, " << a_length_per_tile << ">;\n"
|
||||
<< "var<workgroup> inter_results: array<array<output_value_t, " << WorkgroupSizeX() << ">, " << WorkgroupSizeY() << ">;\n";
|
||||
std::string offset = "workgroup_idx * " + std::to_string(WorkgroupSizeY());
|
||||
shader.MainFunctionBody() << " let output_indices = " << y.OffsetToIndices(offset) << ";\n"
|
||||
<< " let col = output_indices[2];\n"
|
||||
" let row = output_indices[1];\n"
|
||||
" let batch = output_indices[0];\n";
|
||||
} else {
|
||||
ORT_ENFORCE(tile_m_ < WorkgroupSizeY(), "tile_m must be less than or equal to WorkgroupSizeY.");
|
||||
ORT_ENFORCE(WorkgroupSizeX() == WorkgroupSizeY(), "WorkgroupSizeX must be equal to WorkgroupSizeY.");
|
||||
|
||||
shader.AdditionalImplementation() << "fn mm_readA(batch : u32, row : u32, col : u32) -> input_a_value_t {\n"
|
||||
" if (row < uniforms.input_a_shape[1] && col < uniforms.input_a_shape[2]) {\n"
|
||||
<< " return " << a.GetByIndices("input_a_indices_t(batch, row, col)") << ";\n"
|
||||
<< " } else {\n"
|
||||
" return input_a_value_t(0);\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
<< "var<workgroup> sub_a: array<array<input_a_value_t, " << a_length_per_tile << ">," << tile_m_ << ">;\n"
|
||||
<< "var<workgroup> inter_results: array<array<array<output_value_t, " << WorkgroupSizeX() << ">, " << WorkgroupSizeY() << ">," << tile_m_ << ">;\n";
|
||||
shader.MainFunctionBody() << " let col = workgroup_id.x * " << WorkgroupSizeY() << ";\n"
|
||||
<< " let row = workgroup_id.y * " << tile_m_ << ";\n"
|
||||
<< " let batch = workgroup_id.z;\n";
|
||||
}
|
||||
shader.MainFunctionBody() << " let n_blocks_per_col = uniforms.input_b_shape[1];\n"
|
||||
<< " let num_tiles = (n_blocks_per_col - 1) / " << blocks_per_tile << " + 1;\n"
|
||||
// Loop over shared dimension.
|
||||
<< " for (var tile: u32 = 0; tile < num_tiles; tile += 1) {\n"
|
||||
<< " let a_col_start = tile * " << a_length_per_tile << ";\n"
|
||||
<< " // load one tile A data into shared memory.\n"
|
||||
<< " for (var a_offset = local_idx; a_offset < " << a_length_per_tile << "; a_offset += " << workgroup_size << ") {\n"
|
||||
<< " let a_col = a_col_start + a_offset;\n"
|
||||
" if (a_col < uniforms.input_a_shape[2]) {\n"
|
||||
<< " sub_a[a_offset] = " << a.GetByIndices("input_a_indices_t(batch, row, a_col)") << ";\n"
|
||||
<< " } else {\n"
|
||||
" sub_a[a_offset] = input_a_value_t(0);\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
<< " let a_col = a_col_start + a_offset;\n";
|
||||
if (tile_m_ == 1) {
|
||||
shader.MainFunctionBody() << " sub_a[a_offset] = mm_readA(batch, row, a_col);\n";
|
||||
} else {
|
||||
for (uint32_t i = 0; i < tile_m_; i++) {
|
||||
shader.MainFunctionBody() << " sub_a[" << i << "][a_offset] = mm_readA(batch, row + " << i << ", a_col);\n";
|
||||
}
|
||||
}
|
||||
shader.MainFunctionBody() << " }\n"
|
||||
" workgroupBarrier();\n"
|
||||
// Each thread processes one block.
|
||||
" let b_row = col + local_id.y;\n"
|
||||
|
|
@ -111,24 +137,8 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
|||
<< " scale = " << scales.GetByOffset("b_row * n_blocks_per_col + block") << ";\n"
|
||||
<< " b_data = " << b.GetByIndices("input_b_indices_t(b_row, block, 0)") << ";\n"
|
||||
<< " }\n"
|
||||
<< " var word_offset = local_id.x * " << block_size / a.NumComponents() << ";\n"
|
||||
<< " var word_offset = local_id.x * " << block_size_ / a.NumComponents() << ";\n"
|
||||
<< " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n";
|
||||
switch (a.NumComponents()) {
|
||||
case 1:
|
||||
shader.MainFunctionBody() << " let a_data0 = vec4<output_element_t>(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]);\n"
|
||||
" let a_data1 = vec4<output_element_t>(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]);\n";
|
||||
break;
|
||||
case 2:
|
||||
shader.MainFunctionBody() << " let a_data0 = vec4<output_element_t>(sub_a[word_offset], sub_a[word_offset + 1]);\n"
|
||||
" let a_data1 = vec4<output_element_t>(sub_a[word_offset + 2], sub_a[word_offset + 3]);\n";
|
||||
break;
|
||||
case 4:
|
||||
shader.MainFunctionBody() << " let a_data0 = sub_a[word_offset];\n"
|
||||
" let a_data1 = sub_a[word_offset + 1];\n";
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
shader.MainFunctionBody() << " let b_value = b_data";
|
||||
if (components_b_ > 1) {
|
||||
shader.MainFunctionBody() << "[i]";
|
||||
|
|
@ -144,21 +154,63 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
|||
shader.MainFunctionBody() << ", ";
|
||||
}
|
||||
}
|
||||
shader.MainFunctionBody() << ")) * scale;\n"
|
||||
" inter_results[local_id.y][local_id.x] += dot(a_data0, b_dequantized_values[0]) + dot(a_data1, b_dequantized_values[1]);\n"
|
||||
<< " word_offset += " << 8 / a.NumComponents() << ";\n"
|
||||
shader.MainFunctionBody() << ")) * scale;\n";
|
||||
if (tile_m_ == 1) {
|
||||
switch (a.NumComponents()) {
|
||||
case 1:
|
||||
shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(vec4<output_element_t>(sub_a[word_offset], sub_a[word_offset + 1], sub_a[word_offset + 2], sub_a[word_offset + 3]), b_dequantized_values[0]) + dot(vec4<output_element_t>(sub_a[word_offset + 4], sub_a[word_offset + 5], sub_a[word_offset + 6], sub_a[word_offset + 7]), b_dequantized_values[1]);\n";
|
||||
break;
|
||||
case 2:
|
||||
shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(vec4<output_element_t>(sub_a[word_offset], sub_a[word_offset + 1]), b_dequantized_values[0]) + dot(vec4<output_element_t>(sub_a[word_offset + 2], sub_a[word_offset + 3]), b_dequantized_values[1]);\n";
|
||||
break;
|
||||
case 4:
|
||||
shader.MainFunctionBody() << " inter_results[local_id.y][local_id.x] += dot(sub_a[word_offset], b_dequantized_values[0]) + dot(sub_a[word_offset + 1], b_dequantized_values[1]);\n";
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
for (uint32_t i = 0; i < tile_m_; i++) {
|
||||
switch (a.NumComponents()) {
|
||||
case 1:
|
||||
shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(vec4<output_element_t>(sub_a[" << i << "][word_offset], sub_a[" << i << "][word_offset + 1], sub_a[" << i << "][word_offset + 2], sub_a[" << i << "][word_offset + 3]), b_dequantized_values[0]) + dot(vec4<output_element_t>(sub_a[" << i << "][word_offset + 4], sub_a[" << i << "][word_offset + 5], sub_a[" << i << "][word_offset + 6], sub_a[" << i << "][word_offset + 7]), b_dequantized_values[1]);\n";
|
||||
break;
|
||||
case 2:
|
||||
shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(vec4<output_element_t>(sub_a[" << i << "][word_offset], sub_a[" << i << "][word_offset + 1]), b_dequantized_values[0]) + dot(vec4<output_element_t>(sub_a[" << i << "][word_offset + 2], sub_a[" << i << "][word_offset + 3]), b_dequantized_values[1]);\n";
|
||||
break;
|
||||
case 4:
|
||||
shader.MainFunctionBody() << " inter_results[" << i << "][local_id.y][local_id.x] += dot(sub_a[" << i << "][word_offset], b_dequantized_values[0]) + dot(sub_a[" << i << "][word_offset + 1], b_dequantized_values[1]);\n";
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
shader.MainFunctionBody() << " word_offset += " << 8 / a.NumComponents() << ";\n"
|
||||
<< " }\n"
|
||||
" workgroupBarrier();\n"
|
||||
" }\n"
|
||||
<< " if (local_idx < " << WorkgroupSizeY() << ") {\n"
|
||||
<< " var output_value = output_value_t(0);\n"
|
||||
<< " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n"
|
||||
<< " output_value += inter_results[local_idx][b];\n"
|
||||
" }\n"
|
||||
" if (col + local_idx < uniforms.output_shape[2]) {\n"
|
||||
<< " " << y.SetByIndices("output_indices_t(batch, row, col + local_idx)", "output_value") << ";\n"
|
||||
<< " }\n"
|
||||
" }\n";
|
||||
if (tile_m_ == 1) {
|
||||
shader.MainFunctionBody() << " if (local_idx < " << WorkgroupSizeY() << ") {\n"
|
||||
<< " var output_value = output_value_t(0);\n"
|
||||
<< " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n"
|
||||
<< " output_value += inter_results[local_idx][b];\n"
|
||||
" }\n"
|
||||
" if (col + local_idx < uniforms.output_shape[2]) {\n"
|
||||
<< " " << y.SetByIndices("output_indices_t(batch, row, col + local_idx)", "output_value") << ";\n"
|
||||
<< " }\n"
|
||||
" }\n";
|
||||
} else {
|
||||
shader.MainFunctionBody() << " if (local_id.y < " << tile_m_ << ") {\n"
|
||||
<< " var output_value = output_value_t(0);\n"
|
||||
<< " for (var b = 0u; b < " << WorkgroupSizeX() << "; b++) {\n"
|
||||
<< " output_value += inter_results[local_id.y][local_id.x][b];\n"
|
||||
" }\n"
|
||||
" if (row + local_id.y < uniforms.output_shape[1] && col + local_id.x < uniforms.output_shape[2]) {\n"
|
||||
<< " " << y.SetByIndices("output_indices_t(batch, row + local_id.y, col + local_id.x)", "output_value") << ";\n"
|
||||
<< " }\n"
|
||||
" }\n";
|
||||
}
|
||||
} else {
|
||||
const std::string quantized_data_type = QuantizedDataType(a.NumComponents());
|
||||
const int output_element_number = y.NumComponents() * gsl::narrow<int>(output_number_);
|
||||
|
|
@ -322,121 +374,6 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MatMulNBitsProgramPrefill::GenerateShaderCode(ShaderHelper& shader) const {
|
||||
shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
|
||||
shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
|
||||
shader.AddInput("scales", ShaderUsage::UseUniform);
|
||||
shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias);
|
||||
// This shader uses uniforms with the M,N,K convention from traditional matrix multiplication
|
||||
// M is the number of rows in A and M rows in the output.
|
||||
// N is the number of columns in B and N columns in the output.
|
||||
// K is the hidden/shared dimension: the number of columns in A and the number of rows in B.
|
||||
// Note in matmulnbits, B matrix is already transposed, however the following remains true
|
||||
// for the shader below M describes A, N describes B and K is the hidden/shared dimension.
|
||||
// K4/K8 are simply K divided by 4 or 8 respectively.
|
||||
shader.AdditionalImplementation() << R"INIT_SECTION(
|
||||
// Matrix dimensions and quantization parameters
|
||||
const TILE_SIZE : u32 = 16u;
|
||||
const VALUES_PER_VEC4 : u32 = 4u;
|
||||
const QUANTIZATION_BLOCK_SIZE : u32 = 32;
|
||||
// We want INNER_DIMENSION_ITEMS_PER_CYCLE to be the number of lanes in an EU/SM,
|
||||
// so we use BLOCKS_PER_CYCLE as 2u, or process weights 2 blocks at a time.
|
||||
// This uses all 16 lanes on 12th gen intel chips.
|
||||
const BLOCKS_PER_CYCLE : u32 = 2u;
|
||||
const INNER_DIMENSION_ITEMS_PER_CYCLE : u32 = 16u; // (QUANTIZATION_BLOCK_SIZE/VALUES_PER_VEC4)*BLOCKS_PER_CYCLE
|
||||
const VECTORIZED_QUANTIZATION_BLOCK_SIZE: u32 = 8u; // QUANTIZATION_BLOCK_SIZE / VALUES_PER_VEC4;
|
||||
|
||||
//Shared memory
|
||||
var<workgroup> tile_A : array<array<input_a_value_t, INNER_DIMENSION_ITEMS_PER_CYCLE>, TILE_SIZE>;
|
||||
var<workgroup> tile_B : array<array<input_a_value_t, INNER_DIMENSION_ITEMS_PER_CYCLE>, TILE_SIZE>;
|
||||
var<workgroup> tile_O : array<array<output_value_t, TILE_SIZE>, TILE_SIZE>;
|
||||
|
||||
fn loadA(slot: u32, a_global : u32, step_idx : u32, parallel_id : u32)
|
||||
{
|
||||
if (a_global >= uniforms.M) {
|
||||
return;
|
||||
}
|
||||
let local_A = input_a[a_global*uniforms.K4+step_idx*INNER_DIMENSION_ITEMS_PER_CYCLE+parallel_id];
|
||||
tile_A[slot][parallel_id] = local_A;
|
||||
}
|
||||
|
||||
fn getBScale(slot: u32, b_global : u32, vec_step_idx : u32, scale_idx: u32) -> output_value_t
|
||||
{
|
||||
// Since scales are output_value_t holding 1 for every 32 values, vec_step_idx jumps over 64 weights at
|
||||
// a time or 2 scales at every step.
|
||||
let scale_offset = vec_step_idx*2;
|
||||
let idx = u32(b_global*(uniforms.K/QUANTIZATION_BLOCK_SIZE)+scale_offset);
|
||||
return scales[idx+scale_idx];
|
||||
}
|
||||
|
||||
fn loadB(slot: u32, b_global : u32, vec_step_idx : u32, parallel_id : u32)
|
||||
{
|
||||
if (b_global >= uniforms.N) {
|
||||
return;
|
||||
}
|
||||
let scale = getBScale(slot, b_global, vec_step_idx, u32(parallel_id/VECTORIZED_QUANTIZATION_BLOCK_SIZE));
|
||||
let idx:u32 = parallel_id;
|
||||
if (idx % 2 == 0)
|
||||
{
|
||||
// Weights are u32 holding 8 values each, each step (vec_step_idx) jumps over 64 weights at a time.
|
||||
// Therefore the weight_offset begin for the current step would be vec_step_idx * 64 if weight
|
||||
// elements were holding one element each. For the case of each element holding 8 values, begin
|
||||
// would become vec_step_idx * 64/8 or vec_step_idx * 8.
|
||||
var weight_offset:u32 = (vec_step_idx*8)+ u32(idx/2);
|
||||
let b_value = input_b[b_global*uniforms.K8+weight_offset];
|
||||
let b_value_lower = unpack4xU8(b_value & 0x0F0F0F0Fu);
|
||||
let b_value_upper = unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu);
|
||||
tile_B[slot][idx].x = (output_value_t(b_value_lower[0]) - 8.0) * scale;
|
||||
tile_B[slot][idx].y = (output_value_t(b_value_upper[0]) - 8.0) * scale;
|
||||
tile_B[slot][idx].z = (output_value_t(b_value_lower[1]) - 8.0) * scale;
|
||||
tile_B[slot][idx].w = (output_value_t(b_value_upper[1]) - 8.0) * scale;
|
||||
tile_B[slot][idx+1].x = (output_value_t(b_value_lower[2]) - 8.0)* scale;
|
||||
tile_B[slot][idx+1].y = (output_value_t(b_value_upper[2]) - 8.0)* scale;
|
||||
tile_B[slot][idx+1].z = (output_value_t(b_value_lower[3]) - 8.0)* scale;
|
||||
tile_B[slot][idx+1].w = (output_value_t(b_value_upper[3]) - 8.0)* scale;
|
||||
}
|
||||
}
|
||||
|
||||
fn computeDotProduct(slot_a: u32, slot_b:u32) -> output_value_t
|
||||
{
|
||||
var sum:output_value_t = 0;
|
||||
for (var idx:u32 = 0 ; idx < INNER_DIMENSION_ITEMS_PER_CYCLE; idx++)
|
||||
{
|
||||
sum += dot(tile_A[slot_a][idx], tile_B[slot_b][idx]);
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
)INIT_SECTION";
|
||||
|
||||
shader.MainFunctionBody() << R"MAIN_FN(
|
||||
// Indexing with idx,idy instead of using a 2d dispatch of TILE_SIZE, TILE_SIZE
|
||||
// appears to give a performance win on Intel Gen12LP architecture.
|
||||
// This is likley because of locality of memory access, idy below in this approach
|
||||
// is the same as subgroup_id or lane id, while idx is the wave_id.
|
||||
// The work distribution therefore keeps memory accesses close together in
|
||||
// a single wave in this approach of indexing.
|
||||
let idx = u32(local_idx / TILE_SIZE);
|
||||
let idy = u32(local_idx % TILE_SIZE);
|
||||
let a_global_base = workgroup_id.x * TILE_SIZE;
|
||||
let b_global_base = workgroup_id.y * TILE_SIZE;
|
||||
let step_count:u32 = u32(uniforms.K/(BLOCKS_PER_CYCLE*QUANTIZATION_BLOCK_SIZE));
|
||||
for (var vec_step:u32 = 0; vec_step < step_count; vec_step++)
|
||||
{
|
||||
workgroupBarrier();
|
||||
loadA(idx, a_global_base+idx, vec_step, idy);
|
||||
loadB(idx, b_global_base+idx, vec_step, idy);
|
||||
workgroupBarrier();
|
||||
let result = computeDotProduct(idx, idy);
|
||||
tile_O[idx][idy]+=result;
|
||||
}
|
||||
workgroupBarrier();
|
||||
if (a_global_base+idx < uniforms.M && b_global_base+idy < uniforms.N) {
|
||||
output[(a_global_base+idx) * uniforms.N + b_global_base + idy] = tile_O[idx][idy];
|
||||
}
|
||||
)MAIN_FN";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
|
||||
const Tensor* a = context.Input(0);
|
||||
const Tensor* b = context.Input(1);
|
||||
|
|
@ -471,70 +408,52 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
|
|||
const uint32_t components_b = GetMaxComponents(blob_size_in_words);
|
||||
uint32_t components = GetMaxComponents(N);
|
||||
|
||||
// Use block32 for Intel Gen12LP architecture.
|
||||
const bool use_block32 = context.AdapterInfo().vendor == std::string_view{"intel"} &&
|
||||
context.AdapterInfo().architecture == std::string_view{"gen-12lp"} &&
|
||||
block_size == 32;
|
||||
const bool is_intel = context.AdapterInfo().vendor == std::string_view{"intel"} &&
|
||||
context.AdapterInfo().architecture == std::string_view{"gen-12lp"};
|
||||
const bool has_zero_points = zero_points != nullptr;
|
||||
|
||||
if (use_block32 && batch_count == 1 &&
|
||||
components_a == 4 && components_b == 4 &&
|
||||
!has_zero_points && M >= kMinSequenceLengthForPrefillOptimization) {
|
||||
MatMulNBitsProgramPrefill program;
|
||||
constexpr int32_t tile_size = 16;
|
||||
// subgroup_size here controls how many elements of the hidden dimension we load in a cycle.
|
||||
// MatMulNBitsProgramPrefill does not use any of the subgroup wgsl instructions. The subgroup
|
||||
// size just helps with optimal lane usage in the shader.
|
||||
constexpr int32_t subgroup_size = 16;
|
||||
program.SetWorkgroupSize(tile_size * subgroup_size);
|
||||
program.SetDispatchGroupSize((M + tile_size - 1) / tile_size,
|
||||
(N + tile_size - 1) / tile_size,
|
||||
1);
|
||||
program
|
||||
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(4)},
|
||||
{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(4)},
|
||||
{scales, ProgramTensorMetadataDependency::None}})
|
||||
.AddUniformVariables({{static_cast<uint32_t>(M)},
|
||||
{static_cast<uint32_t>(N)},
|
||||
{static_cast<uint32_t>(K)},
|
||||
{static_cast<uint32_t>(K / 4)},
|
||||
{static_cast<uint32_t>(K / 8)}})
|
||||
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)});
|
||||
return context.RunProgram(program);
|
||||
// TODO: Support output_number > 1. Some cases fail when output_number > 1.
|
||||
constexpr uint32_t output_number = 1;
|
||||
const uint32_t tile_m = M > kMinMForTileOptimization ? 4 : 1;
|
||||
MatMulNBitsProgram program{output_number, block_size, tile_m, gsl::narrow<int>(components_b), has_zero_points, is_intel};
|
||||
if (M > kMinMForTileOptimization && block_size == 32) {
|
||||
components = 1;
|
||||
constexpr uint32_t workgroup_size = 64;
|
||||
constexpr uint32_t workgroup_y = 8;
|
||||
constexpr uint32_t workgroup_x = workgroup_size / workgroup_y;
|
||||
program.SetWorkgroupSize(workgroup_x, workgroup_y, 1);
|
||||
program.SetDispatchGroupSize((N + workgroup_y - 1) / workgroup_y,
|
||||
(M + tile_m - 1) / tile_m,
|
||||
batch_count);
|
||||
program.CacheHint("T_M" + std::to_string(tile_m));
|
||||
} else if (is_intel && block_size == 32) {
|
||||
components = 1;
|
||||
constexpr uint32_t workgroup_size = 128;
|
||||
const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4
|
||||
: 1;
|
||||
const uint32_t workgroup_x = workgroup_size / workgroup_y;
|
||||
program.SetWorkgroupSize(workgroup_x, workgroup_y, 1);
|
||||
program.SetDispatchGroupSize(data_size / components / workgroup_y);
|
||||
program.CacheHint("T_M" + std::to_string(tile_m));
|
||||
} else {
|
||||
// TODO: Support output_number > 1. Some cases fail when output_number > 1.
|
||||
// const uint32_t output_number = M > 1 && (N / components) % 2 == 0 ? 2 : 1;
|
||||
constexpr uint32_t output_number = 1;
|
||||
MatMulNBitsProgram program{output_number, gsl::narrow<int>(components_b), has_zero_points, use_block32};
|
||||
|
||||
if (use_block32) {
|
||||
components = 1;
|
||||
constexpr uint32_t workgroup_size = 128;
|
||||
const uint32_t workgroup_y = N % 8 == 0 ? 8 : N % 4 == 0 ? 4
|
||||
: 1;
|
||||
const uint32_t workgroup_x = workgroup_size / workgroup_y;
|
||||
program.SetWorkgroupSize(workgroup_x, workgroup_y, 1);
|
||||
program.SetDispatchGroupSize(data_size / components / workgroup_y);
|
||||
} else {
|
||||
program.SetDispatchGroupSize(data_size / components / output_number);
|
||||
}
|
||||
|
||||
TensorShape reshaped_a_shape{batch_count, M, K / components_a};
|
||||
TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b};
|
||||
TensorShape reshaped_y_shape{batch_count, M, N / components};
|
||||
|
||||
program
|
||||
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow<int>(components_a)},
|
||||
{b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow<int>(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)},
|
||||
{scales, ProgramTensorMetadataDependency::None}})
|
||||
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(components)})
|
||||
.AddUniformVariable({block_size})
|
||||
.CacheHint(std::to_string(output_number));
|
||||
if (has_zero_points) {
|
||||
program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
|
||||
}
|
||||
return context.RunProgram(program);
|
||||
program.SetDispatchGroupSize(data_size / components / output_number);
|
||||
program.CacheHint("O_N" + std::to_string(output_number));
|
||||
}
|
||||
|
||||
TensorShape reshaped_a_shape{batch_count, M, K / components_a};
|
||||
TensorShape reshaped_b_shape{N, n_blocks_per_col, blob_size_in_words / components_b};
|
||||
TensorShape reshaped_y_shape{batch_count, M, N / components};
|
||||
|
||||
program
|
||||
.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, reshaped_a_shape, gsl::narrow<int>(components_a)},
|
||||
{b, ProgramTensorMetadataDependency::TypeAndRank, reshaped_b_shape, gsl::narrow<int>(components_b * 4 /** b will be accessed as uint32 which includs 4 uint8. So here we need to multiply 4.*/)},
|
||||
{scales, ProgramTensorMetadataDependency::None}})
|
||||
.AddOutput({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, gsl::narrow<int>(components)})
|
||||
.AddUniformVariable({block_size});
|
||||
if (has_zero_points) {
|
||||
program.AddInput({zero_points, ProgramTensorMetadataDependency::None, {(zero_points->Shape().Size() + 3) / 4}, 4});
|
||||
}
|
||||
return context.RunProgram(program);
|
||||
}
|
||||
|
||||
} // namespace webgpu
|
||||
|
|
|
|||
|
|
@ -14,11 +14,13 @@ using namespace onnxruntime::webgpu;
|
|||
|
||||
class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
|
||||
public:
|
||||
MatMulNBitsProgram(uint32_t output_number, int components_b, bool has_zero_points, bool use_block32) : Program{"MatMulNBits"},
|
||||
output_number_{output_number},
|
||||
components_b_{components_b},
|
||||
has_zero_points_{has_zero_points},
|
||||
use_block32_{use_block32} {
|
||||
MatMulNBitsProgram(uint32_t output_number, uint32_t block_size, uint32_t tile_m, int components_b, bool has_zero_points, bool is_intel) : Program{"MatMulNBits"},
|
||||
output_number_{output_number},
|
||||
block_size_{block_size},
|
||||
tile_m_{tile_m},
|
||||
components_b_{components_b},
|
||||
has_zero_points_{has_zero_points},
|
||||
is_intel_{is_intel} {
|
||||
}
|
||||
|
||||
Status GenerateShaderCode(ShaderHelper& sh) const override;
|
||||
|
|
@ -26,23 +28,11 @@ class MatMulNBitsProgram final : public Program<MatMulNBitsProgram> {
|
|||
|
||||
private:
|
||||
uint32_t output_number_;
|
||||
uint32_t block_size_;
|
||||
uint32_t tile_m_;
|
||||
int components_b_;
|
||||
bool has_zero_points_;
|
||||
bool use_block32_;
|
||||
};
|
||||
|
||||
class MatMulNBitsProgramPrefill final : public Program<MatMulNBitsProgramPrefill> {
|
||||
public:
|
||||
MatMulNBitsProgramPrefill() : Program{"MatMulNBitsPrefill"} {
|
||||
}
|
||||
|
||||
Status GenerateShaderCode(ShaderHelper& sh) const override;
|
||||
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
|
||||
{"M", ProgramUniformVariableDataType::Uint32},
|
||||
{"N", ProgramUniformVariableDataType::Uint32},
|
||||
{"K", ProgramUniformVariableDataType::Uint32},
|
||||
{"K4", ProgramUniformVariableDataType::Uint32},
|
||||
{"K8", ProgramUniformVariableDataType::Uint32});
|
||||
bool is_intel_;
|
||||
};
|
||||
|
||||
class MatMulNBits final : public WebGpuKernel {
|
||||
|
|
|
|||
Loading…
Reference in a new issue