mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Implement memory layout optimizations. Replace the lane_output array with separate vector variables so that it is placed in registers.
This commit is contained in:
parent
97c2bbe3eb
commit
d613a0ff80
1 changed files with 96 additions and 64 deletions
|
|
@ -613,17 +613,14 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
|||
const tile_size_k = 32;
|
||||
const vec_factor = 4;
|
||||
const u32_factor = 4;
|
||||
const tile_size_k_vec = 4;
|
||||
const tile_size_k_vec = 2;
|
||||
const block_size = 32;
|
||||
|
||||
// Shared memory
|
||||
var<workgroup> tile_A : array<array<vec2<u32>, tile_size_k_vec>, tile_size>; // 64 x 32
|
||||
var<workgroup> scale_A : array<output_element_t, tile_size>; // 64 x 1
|
||||
var<workgroup> tile_B : array<array<vec2<u32>, tile_size_k_vec>, tile_size>; // 64 x 32
|
||||
var<workgroup> scale_B : array<output_element_t, tile_size>; // 64 x 1
|
||||
|
||||
// Private memory
|
||||
var<private> lane_output: array<output_element_t, 16>;
|
||||
var<workgroup> tile_A : array<array<vec4<u32>, tile_size>, tile_size_k_vec>; // 64 x 32
|
||||
var<workgroup> scale_A : array<output_element_t, tile_size>; // 64 x 1
|
||||
var<workgroup> tile_B : array<array<vec4<u32>, tile_size>, tile_size_k_vec>; // 64 x 32
|
||||
var<workgroup> scale_B : array<output_element_t, tile_size>; // 64 x 1
|
||||
|
||||
fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32)
|
||||
{
|
||||
|
|
@ -632,11 +629,11 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
|||
{
|
||||
return;
|
||||
}
|
||||
tile_A[row][col] = input_a[a_global*uniforms.K8+kidx_v+col];
|
||||
tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col];
|
||||
if (col == 0)
|
||||
{
|
||||
// kidx_v - covers 8 values of k
|
||||
scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/16];
|
||||
// kidx_v - covers 16 values of k
|
||||
scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8];
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -648,36 +645,45 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
|||
return;
|
||||
}
|
||||
|
||||
let b_value = input_b[b_global*uniforms.K8+kidx_v+col];
|
||||
var b_value_lower = vec4<i32>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<i32>(8);
|
||||
var b_value_upper = vec4<i32>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
|
||||
tile_B[row][col][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
|
||||
tile_B[row][col][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
|
||||
let b_value = input_b[b_global*uniforms.K16+kidx_v+col];
|
||||
var b_value_lower = vec4<i32>(unpack4xU8(b_value[0] & 0x0F0F0F0Fu)) - vec4<i32>(8);
|
||||
var b_value_upper = vec4<i32>(unpack4xU8((b_value[0] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
|
||||
tile_B[col][row][0] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
|
||||
tile_B[col][row][1] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
|
||||
b_value_lower = vec4<i32>(unpack4xU8(b_value[1] & 0x0F0F0F0Fu)) - vec4<i32>(8);
|
||||
b_value_upper = vec4<i32>(unpack4xU8((b_value[1] >> 4) & 0x0F0F0F0Fu)) - vec4<i32>(8);
|
||||
tile_B[col][row][2] = pack4xI8(vec4<i32>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]));
|
||||
tile_B[col][row][3] = pack4xI8(vec4<i32>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]));
|
||||
if (col == 0)
|
||||
{
|
||||
// kidx_v - each kidx_v covers 8 values of k
|
||||
scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/4];
|
||||
// kidx_v - each kidx_v covers 16 values of k
|
||||
scale_B[row] = scales_b[b_global*(uniforms.K/32) + kidx_v/2];
|
||||
}
|
||||
}
|
||||
|
||||
fn DP4AI(a:vec4<u32>, b:vec4<u32>) -> i32
|
||||
// Scaled dot product over 8 packed u32 words (a1 with b1, a2 with b2); each word holds four 8-bit values consumed by dot4I8Packed.
|
||||
fn SDP8AI(a1:vec4<u32>, b1:vec4<u32>, a2:vec4<u32>, b2:vec4<u32>, scale:output_element_t) -> output_element_t
|
||||
{
|
||||
var local_sum = dot4I8Packed(a[0], b[0]);
|
||||
local_sum += dot4I8Packed(a[1], b[1]);
|
||||
local_sum += dot4I8Packed(a[2], b[2]);
|
||||
local_sum += dot4I8Packed(a[3], b[3]);
|
||||
return local_sum;
|
||||
var local_sum = dot4I8Packed(a1[0], b1[0]);
|
||||
local_sum += dot4I8Packed(a1[1], b1[1]);
|
||||
local_sum += dot4I8Packed(a1[2], b1[2]);
|
||||
local_sum += dot4I8Packed(a1[3], b1[3]);
|
||||
local_sum += dot4I8Packed(a2[0], b2[0]);
|
||||
local_sum += dot4I8Packed(a2[1], b2[1]);
|
||||
local_sum += dot4I8Packed(a2[2], b2[2]);
|
||||
local_sum += dot4I8Packed(a2[3], b2[3]);
|
||||
return output_element_t(local_sum) * scale;
|
||||
}
|
||||
|
||||
)ADDNL_FN";
|
||||
|
||||
shader.MainFunctionBody() << R"MAIN_FN(
|
||||
// During the load phase we use all 256 threads to load 64 rows of A/B.
|
||||
// For each row we load 4 vectorized elements, which are 32 elements of K.
|
||||
// For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K.
|
||||
let a_global_base = workgroup_id.x * tile_size;
|
||||
let b_global_base = workgroup_id.y * tile_size;
|
||||
let load_row = u32(local_idx/4);
|
||||
let load_col = u32(local_idx%4);
|
||||
let load_AorB = u32(local_idx/128);
|
||||
let load_row = u32((local_idx%128)/2);
|
||||
let load_col = u32(local_idx%2);
|
||||
|
||||
// During the compute phase, we have the 64x64 tile split into
|
||||
// subtiles of 16x16. We have a grid of 4x4 subtiles.
|
||||
|
|
@ -689,41 +695,68 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
|||
// For each subtile we have 16 threads assigned.
|
||||
let a_idx = u32(local_idx % subtile_size);
|
||||
|
||||
// K's vectrorization is 8 items per index. See input_a/input_b.
|
||||
// tile_size_k_vec - is the k tile size in vectorized k units/space (1/8).
|
||||
for (var kidx_v:u32 = 0; kidx_v < uniforms.K8; kidx_v+=tile_size_k_vec)
|
||||
var lane_output1: vec4<output_element_t>;
|
||||
var lane_output2: vec4<output_element_t>;
|
||||
var lane_output3: vec4<output_element_t>;
|
||||
var lane_output4: vec4<output_element_t>;
|
||||
// K's vectorization is 16 items per index. See input_a/input_b.
|
||||
// tile_size_k_vec - is the k tile size in vectorized space (1/16). That is
|
||||
// k tile size is 32. In vectorized space that is 32/16 = 2.
|
||||
for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec)
|
||||
{
|
||||
// Populate shared memory for the workgroup
|
||||
loadSHMA(a_global_base, kidx_v, load_row, load_col);
|
||||
loadSHMB(b_global_base, kidx_v, load_row, load_col);
|
||||
workgroupBarrier();
|
||||
|
||||
var own_a0: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][0], tile_A[base_A + a_idx][1]);
|
||||
var own_a1: vec4<u32> = vec4<u32>(tile_A[base_A + a_idx][2], tile_A[base_A + a_idx][3]);
|
||||
var own_scale_a = scale_A[base_A + a_idx];
|
||||
if (sg_size == 16)
|
||||
// Load Phase: Populate shared memory for the workgroup.
|
||||
if (load_AorB == 0)
|
||||
{
|
||||
var own_b0: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][0], tile_B[base_B + sg_id][1]);
|
||||
var own_b1: vec4<u32> = vec4<u32>(tile_B[base_B + sg_id][2], tile_B[base_B + sg_id][3]);
|
||||
var own_scale_b = scale_B[base_B + sg_id];
|
||||
for (var col:u32 = 0; col < 16; col++)
|
||||
{
|
||||
var local_scale_b = subgroupShuffle(own_scale_b, col);
|
||||
local_scale_b = local_scale_b * own_scale_a;
|
||||
var local_sum = DP4AI(own_a0, subgroupShuffle(own_b0, col));
|
||||
local_sum += DP4AI(own_a1, subgroupShuffle(own_b1, col));
|
||||
lane_output[col] += (output_element_t(local_sum) * local_scale_b);
|
||||
}
|
||||
loadSHMA(a_global_base, kidx_v, load_row, load_col);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (var col:u32 = 0; col < 16; col++)
|
||||
loadSHMB(b_global_base, kidx_v, load_row, load_col);
|
||||
}
|
||||
workgroupBarrier();
|
||||
|
||||
// Compute phase: Perform matmul for this subtile 16 x 32 x 16.
|
||||
// Step 1: Load from shared memory into registers across entire subgroup.
|
||||
var own_a0: vec4<u32> = tile_A[0][base_A + a_idx];
|
||||
var own_a1: vec4<u32> = tile_A[1][base_A + a_idx];
|
||||
var own_scale_a: output_element_t = scale_A[base_A + a_idx];
|
||||
if (sg_size == 16)
|
||||
{
|
||||
var own_b0: vec4<u32> = tile_B[0][base_B + sg_id];
|
||||
var own_b1: vec4<u32> = tile_B[1][base_B + sg_id];
|
||||
var own_scale_b: output_element_t = scale_B[base_B + sg_id];
|
||||
// Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul.
|
||||
lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a);
|
||||
lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a);
|
||||
lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a);
|
||||
lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a);
|
||||
|
||||
lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a);
|
||||
lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a);
|
||||
lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a);
|
||||
lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a);
|
||||
|
||||
lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a);
|
||||
lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a);
|
||||
lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a);
|
||||
lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a);
|
||||
|
||||
lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a);
|
||||
lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a);
|
||||
lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a);
|
||||
lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Code for other subgroup sizes; it simply doesn't use subgroup operations at all.
|
||||
// Relies on the hardware optimizing reads of a single location,
|
||||
// tile_B[][base_B + col], by all threads.
|
||||
for (var col:u32 = 0; col < 4; col++)
|
||||
{
|
||||
var b0: vec4<u32> = vec4<u32>(tile_B[base_B + col][0], tile_B[base_B + col][1]);
|
||||
var b1: vec4<u32> = vec4<u32>(tile_B[base_B + col][2], tile_B[base_B + col][3]);
|
||||
var local_sum = DP4AI(own_a0, b0);
|
||||
local_sum += DP4AI(own_a1, b1);
|
||||
lane_output[col] += (output_element_t(local_sum) * own_scale_a * scale_B[base_B + col]);
|
||||
lane_output1[col] += SDP8AI(own_a0, tile_B[0][base_B + col], own_a1, tile_B[1][base_B + col], own_scale_a * scale_B[base_B + col]);
|
||||
lane_output2[col] += SDP8AI(own_a0, tile_B[0][base_B + col + 4], own_a1, tile_B[1][base_B + col + 4], own_scale_a * scale_B[base_B + col + 4]);
|
||||
lane_output3[col] += SDP8AI(own_a0, tile_B[0][base_B + col + 8], own_a1, tile_B[1][base_B + col + 8], own_scale_a * scale_B[base_B + col + 8]);
|
||||
lane_output4[col] += SDP8AI(own_a0, tile_B[0][base_B + col + 12], own_a1, tile_B[1][base_B + col + 12], own_scale_a * scale_B[base_B + col + 12]);
|
||||
}
|
||||
}
|
||||
workgroupBarrier();
|
||||
|
|
@ -735,11 +768,10 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
|
|||
// This creates a shader requirement that uniforms.N % 16 == 0
|
||||
if (a_global < uniforms.M && b_global < uniforms.N)
|
||||
{
|
||||
for (var i:u32 = 0; i < 4; i++)
|
||||
{
|
||||
let lidx = i * 4;
|
||||
output[output_idx+i] = vec4<output_element_t>(lane_output[lidx], lane_output[lidx+1] , lane_output[lidx+2], lane_output[lidx+3]);
|
||||
}
|
||||
output[output_idx] = lane_output1;
|
||||
output[output_idx+1] = lane_output2;
|
||||
output[output_idx+2] = lane_output3;
|
||||
output[output_idx+3] = lane_output4;
|
||||
}
|
||||
)MAIN_FN";
|
||||
|
||||
|
|
@ -812,9 +844,9 @@ Status MatMulNBits::ComputeInternal(onnxruntime::webgpu::ComputeContext& context
|
|||
mul_program.SetDispatchGroupSize(
|
||||
(M + kTileSize - 1) / kTileSize,
|
||||
(N + kTileSize - 1) / kTileSize, 1);
|
||||
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components)},
|
||||
mul_program.AddInputs({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec4Components)},
|
||||
{&a_scale, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)},
|
||||
{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kU32Components)},
|
||||
{b, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(kVec2Components * kU32Components)},
|
||||
{scales, ProgramTensorMetadataDependency::TypeAndRank, gsl::narrow<int>(1)}})
|
||||
.AddUniformVariables({{static_cast<uint32_t>(M)},
|
||||
{static_cast<uint32_t>(N)},
|
||||
|
|
|
|||
Loading…
Reference in a new issue