[ghstack-poisoned]
Xia, Weiwen 2025-02-08 05:51:03 -08:00
parent c4267ed1c5
commit ad2dc02f7f
8 changed files with 468 additions and 21 deletions

View file

@@ -10,6 +10,7 @@ import torch
import torch._dynamo.config
import torch._dynamo.config as dynamo_config
import torch._inductor.config as inductor_config
import torch._inductor.cpu_vec_isa
import torch._inductor.select_algorithm as select_algorithm
from torch._dynamo.utils import counters
from torch._inductor import test_operators
@@ -1499,6 +1500,46 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
        vec_amx = VecAMX()
        self._check_amx_counter(vec_amx)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @dtypes(torch.bfloat16)
    @parametrize("batch_size", (32,))
    @parametrize("in_features", (128, 256))
    @parametrize("out_features", (64, 128))
    @parametrize("group_size", (32, 64))
    def test_int4_woq_mm_avx512(
        self, dtype, batch_size, in_features, out_features, group_size
    ):
        class M(torch.nn.Module):
            def __init__(self, K, N, group_size):
                super().__init__()
                self.linear_weight = torch.randint(
                    0, 15, (N, K // 2), dtype=torch.uint8
                )
                self.qscale_and_zeros = torch.rand(K // group_size, N, 2, dtype=dtype)
                self.group_size = group_size

            def forward(self, x):
                x_shape = x.shape
                x = x.reshape(-1, x_shape[-1])
                y = torch._weight_int4pack_mm_for_cpu(
                    x, self.linear_weight, self.group_size, self.qscale_and_zeros
                )
                return y.reshape(*x_shape[:-1], out_features)

        counters.clear()
        seq_len = 8
        x = torch.rand((batch_size, seq_len, in_features), dtype=dtype)
        mod = M(in_features, out_features, group_size).eval()
        self.common(mod, (x,), reference_in_float=False)
        available_isa = torch._inductor.cpu_vec_isa.pick_vec_isa()
        avx512_available = "avx512" in str(available_isa)
        autotune_count = 1 if avx512_available else 0
        self.assertEqual(
            counters["inductor"]["select_algorithm_autotune"], autotune_count
        )
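For reference, here is a minimal eager-mode sketch of the semantics this test exercises, written against the shapes used above (weight packed as [N, K // 2] uint8 with two int4 values per byte; qscale_and_zeros of shape [K // group_size, N, 2] holding interleaved (scale, zero) pairs). The unpacked-weight form and the signed mapping (nibble value minus 8, matching the [-8, 7] lookup table in the micro-kernel later in this commit) are illustrative assumptions, not the authoritative ATen packing:

import torch

def woq_int4_mm_ref(x, w_uint4, group_size, qscale_and_zeros):
    # x: [M, K] bf16 activation
    # w_uint4: [N, K] uint8 holding already-unpacked int4 values in [0, 15]
    # qscale_and_zeros: [K // group_size, N, 2], last dim is (scale, zero)
    scale = qscale_and_zeros[:, :, 0].repeat_interleave(group_size, dim=0)  # [K, N]
    zero = qscale_and_zeros[:, :, 1].repeat_interleave(group_size, dim=0)  # [K, N]
    # map unsigned nibbles [0, 15] to signed int4 [-8, 7], then dequantize per group
    w = (w_uint4.to(torch.float).T - 8.0) * scale.float() + zero.float()  # [K, N]
    return (x.float() @ w).to(x.dtype)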
    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad

View file

@@ -85,6 +85,7 @@ class CppBmmTemplate(CppGemmTemplate):
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
        should_block_weights: bool = False,
        name="bmm",
        **kwargs,
    ):
        """
        In order to simplify the implementation and increase code reuse, the BMM template implements

View file

@@ -102,6 +102,9 @@ GEMM_TEMPLATE_INIT_BLOCKING = r"""
    constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
    constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
{%- endif %}
{%- if is_woq_int4 %}
    int64_t group_size = *q_group_size;
{%- endif %}
    // make sure all partitions are assigned
    {{kernel.assert_function}}(
@@ -170,6 +173,8 @@ GEMM_TEMPLATE_MICROKERNEL_DEF = r"""
GEMM_TEMPLATE_STUB_DEF = r"""
{%- if x_scale is not none %}
{%- set kernel_args = {"X": X, "W": W, "inp": inp, "x_scale": x_scale, "x_zp": x_zp, "w_scale": w_scale, "w_zp": w_zp,} %}
{%- elif is_woq_int4 %}
{%- set kernel_args = {"X": X, "W": W, "q_group_size": q_group_size, "qscale_and_zeros": qscale_and_zeros} %}
{%- else %}
{%- set kernel_args = {"X": X, "W": W, "inp": inp} %}
{%- endif %}
@@ -183,7 +188,7 @@ GEMM_TEMPLATE = r"""
{
    {{ kernel.maybe_codegen_profile() }}
    {{ template.codegen_blocks(
        num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W
        num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W, is_woq_int4
    ) }}

{%- if maybe_k_slicing %}
@@ -226,12 +231,40 @@ GEMM_TEMPLATE = r"""
{%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %}
{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
{%- else %}
{%- if is_woq_int4 %}
{%- set tile_W = kernel.slice_nd(W, [("n_start", "n_start + n_size"), ("k_start * Nr / 2", "k_end * Nr / 2")]) %}
{%- set tile_qparam = kernel.slice_nd(
        qscale_and_zeros, [("k_start / group_size", "k_end / group_size"), ("n_start", "n_start + n_size"), ()]) %}
{%- else %}
{%- set tile_W = kernel.slice_nd(W, [("k_start", "k_end"), ("n_start", "n_start + n_size")]) %}
{%- endif %}
{%- endif %}
                if (kc == k_block_start) {
{%- if is_woq_int4 %}
                    {{ micro_gemm.codegen_call(kernel,
                                               tile_X,
                                               tile_W,
                                               acc_slice,
                                               accum=False,
                                               is_woq_int4=True,
                                               qscale_and_zeros=tile_qparam)|indent(28, false)
                    }}
{%- else %}
                    {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=False)|indent(28, false) }}
{%- endif %}
                } else {
{%- if is_woq_int4 %}
                    {{ micro_gemm.codegen_call(kernel,
                                               tile_X,
                                               tile_W,
                                               acc_slice,
                                               accum=True,
                                               is_woq_int4=True,
                                               qscale_and_zeros=tile_qparam)|indent(28, false)
                    }}
{%- else %}
                    {{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=True)|indent(28, false) }}
{%- endif %}
                }
            }
        }
@@ -501,6 +534,8 @@ class CppGemmTemplate(CppTemplate):
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
        should_block_weights: bool = True,
        name="packed_gemm",
        is_woq_int4=False,
        q_group_size=None,
    ) -> None:
        assert layout.dtype in [torch.float, torch.bfloat16, torch.half, torch.uint8]
        super().__init__(
@@ -522,6 +557,10 @@ class CppGemmTemplate(CppTemplate):
        self.should_block_weights = should_block_weights
        self.thread_blocking = self.make_thread_blocking_cache()
        self.cache_blocking = self.make_cache_blocking_cache()
        self.is_woq_int4 = is_woq_int4
        if is_woq_int4:
            assert not should_block_weights, "Weight is already packed for WOQ int4"
        self.q_group_size = q_group_size

    def make_thread_blocking_cache(self):
        cache = lru_cache()(self._thread_blocking)
@@ -798,6 +837,8 @@ class CppGemmTemplate(CppTemplate):
        input_indices=None,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
        act_mapping: Optional[dict[int, ir.IRNode]] = None,
        is_woq_int4: bool = False,
        q_group_size: Optional[int] = None,
    ):
        if input_indices is None:
            input_indices = list(range(len(input_nodes)))
@@ -877,7 +918,12 @@ class CppGemmTemplate(CppTemplate):
        # TODO(jgong5): decide proper number of threads per problem size
        num_threads = parallel_num_threads()
        new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout))
        m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1])
        m, n, k, *_ = mm_args(
            new_inputs[0],
            new_inputs[1],
            mat2_transposed=is_woq_int4,
            use_4x2_dim=is_woq_int4,
        )
        output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
            new_inputs[0].get_dtype()
        )
@@ -892,6 +938,8 @@ class CppGemmTemplate(CppTemplate):
            compute_dtype=compute_dtype,
            alpha=alpha,
            num_threads=num_threads,
            is_woq_int4=is_woq_int4,
            q_group_size=q_group_size,
        )
        assert micro_gemm is not None
        block_weights = cls.check_if_block_weight(new_inputs[1], micro_gemm)
@@ -947,6 +995,9 @@ class CppGemmTemplate(CppTemplate):
            has_bias=has_bias,
            epilogue_creator=epilogue_creator,
            should_block_weights=block_weights,
            name=micro_gemm.__class__.__name__,
            is_woq_int4=is_woq_int4,
            q_group_size=q_group_size,
        )
        template.maybe_append_choice(choices)
        return template
@@ -989,10 +1040,15 @@ class CppGemmTemplate(CppTemplate):
        """
        W = inputs[1]
        new_inputs = list(inputs)
        if isinstance(W, ir.IRNode):
            k, n = W.get_size()[-2:]
        if micro_gemm.is_woq_int4:
            assert (
                len(W.get_size()) == 2
                if isinstance(W, ir.IRNode)
                else len(W.shape) == 2
            )
            n, k = W.get_size() if isinstance(W, ir.IRNode) else W.shape
        else:
            k, n = W.shape[-2:]
            k, n = W.get_size()[-2:] if isinstance(W, ir.IRNode) else W.shape[-2:]
        _, block_n, _ = micro_gemm.register_blocking
        new_size, padded_n = cls.get_padded_size(n, block_n, k, should_block_weight)
        padding = padded_n - n
@@ -1027,6 +1083,9 @@ class CppGemmTemplate(CppTemplate):
    @staticmethod
    def check_if_block_weight(W, micro_gemm):
        if micro_gemm.is_woq_int4:
            # For WOQ INT4, weight is already packed
            return False
        return True

    @classmethod
@@ -1141,6 +1200,9 @@ class CppGemmTemplate(CppTemplate):
        x_zp = None
        w_scale = None
        w_zp = None
        inp = None
        q_group_size = None
        qscale_and_zeros = None
        if int8_gemm:
            X, W = self.input_nodes[0], self.input_nodes[1]
            bias_idx = 2 if self.has_bias else 1
@@ -1150,6 +1212,11 @@ class CppGemmTemplate(CppTemplate):
            w_scale = self.input_nodes[bias_idx + 3]
            w_zp = self.input_nodes[bias_idx + 4]
            Y = self.output_node
        elif self.is_woq_int4:
            X, W = self.input_nodes[0], self.input_nodes[1]
            Y = self.output_node
            q_group_size = self.input_nodes[2]
            qscale_and_zeros = self.input_nodes[3]
        else:
            X, W = self.input_nodes[0], self.input_nodes[1]
            Y = self.output_node
@@ -1311,6 +1378,8 @@ class CppGemmTemplate(CppTemplate):
            compute_dtype=compute_dtype,
            alpha=self.alpha,
            num_threads=self.num_threads,
            is_woq_int4=self.is_woq_int4,
            q_group_size=self.q_group_size,
        )
        assert micro_gemm is not None
        assert self.register_blocking == micro_gemm.register_blocking
@@ -1359,6 +1428,9 @@ class CppGemmTemplate(CppTemplate):
            L2_cache_size=L2_cache_size,
            config=config,
            fake_buffers=fake_buffers,
            is_woq_int4=self.is_woq_int4,
            q_group_size=q_group_size,
            qscale_and_zeros=qscale_and_zeros,
        )
        return options
@@ -1399,6 +1471,7 @@ class CppGemmTemplate(CppTemplate):
        L2_cache_size,
        X,
        W,
        is_woq_int4=False,
    ):
        options = dict(
            num_threads=num_threads,
@@ -1414,6 +1487,7 @@ class CppGemmTemplate(CppTemplate):
            template=self,
            X=X,
            W=W,
            is_woq_int4=is_woq_int4,
        )
        return self._template_from_string(GEMM_TEMPLATE_INIT_BLOCKING).render(options)

View file

@@ -201,6 +201,8 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
        input_indices: Optional[list[int]] = None,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
        act_mapping: Optional[dict[int, ir.IRNode]] = None,  # gemm idx to its act buf
        is_woq_int4: bool = False,
        q_group_size: Optional[int] = None,
    ) -> DataProcessorTemplateWrapper:
        # Input nodes order: x, optional[x1], ... w0, w1, ... optional[b0], optional[b1], ...
        gemm_grouped_num = len(has_bias)

View file

@@ -81,6 +81,7 @@ inline void {{kernel_name}}(
        compute_dtype,
        register_blocking,
        alpha=1,
        is_woq_int4=False,
    ) -> None:
        self.name = name
        self.input_dtype = input_dtype
@@ -90,6 +91,7 @@ inline void {{kernel_name}}(
        self.compute_dtype = compute_dtype
        self.register_blocking = register_blocking
        self.alpha = alpha
        self.is_woq_int4 = is_woq_int4

    def get_common_options(self):
        if self.input_dtype in [torch.uint8, torch.int8]:
@@ -113,6 +115,7 @@ inline void {{kernel_name}}(
            "vnni_size": 4 if self.input_dtype in [torch.uint8, torch.int8] else 2,
            "restrict_keyword": get_restrict_keyword(),
            "is_msvc_compiler": cpp_builder.is_msvc_cl(),
            "is_woq_int4": self.is_woq_int4,
        }

    def get_kernel_declaration(self):
@@ -122,8 +125,8 @@ inline void {{kernel_name}}(
    def get_kernel_extra_args_declare(self) -> str:
        return ""

    def get_kernel_extra_args(self) -> str:
        return ""

    def get_kernel_extra_args(self, **kwargs) -> list[str]:
        return []

    def codegen_define(self, kernel: CppTemplateKernel) -> str:
        raise NotImplementedError
@@ -135,6 +138,8 @@ inline void {{kernel_name}}(
        B: ir.Buffer,
        C: ir.Buffer,
        accum: bool,
        is_woq_int4: bool = False,
        qscale_and_zeros: Optional[ir.Buffer] = None,
    ) -> str:
        """
        Generate the code for calling the templated kernel that computes
@@ -152,9 +157,14 @@ inline void {{kernel_name}}(
        res = IndentedBuffer()
        res.writeline(f"{self.name}<{value_to_cpp(accum, 'bool')}>(")
        with res.indent():
            extra_args = self.get_kernel_extra_args()
            if extra_args:
                res.writeline(extra_args)
            kwargs_for_extra_args = (
                {"kernel": kernel, "qscale_and_zeros": qscale_and_zeros}
                if is_woq_int4
                else {}
            )
            extra_args = self.get_kernel_extra_args(**kwargs_for_extra_args)
            for arg in extra_args:
                res.writeline(arg)
            res.writeline(f"{A_ptr},")
            res.writeline(f"{B_ptr},")
            res.writeline(f"{C_ptr},")
@@ -803,8 +813,8 @@ inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}(
    def get_kernel_extra_args_declare(self) -> str:
        return "AMXState& amx_state,"

    def get_kernel_extra_args(self) -> str:
        return "amx_state,"

    def get_kernel_extra_args(self, **kwargs) -> list[str]:
        return ["amx_state,"]

    def get_b_layout(self):
        if self.input_dtype in [torch.uint8, torch.int8]:
@@ -876,6 +886,241 @@ class CppMicroBrgemm(CppMicroGemm):
        return LayoutType.VNNI2

@register_micro_gemm(
    *generate_gemm_config(
        VecAVX512,
        [(4, 64, 32), (4, 64, 64), (4, 64, 128)],
        input_dtype=torch.bfloat16,
        input2_dtype=torch.uint8,
        output_dtype=torch.float,
        compute_dtype=torch.float,
    ),
)
class CppMicroGemmWoQInt4Vec(CppMicroGemmFP32Vec):
    """
    This class generates the code for WoQ int4 micro gemm using AVX512 intrinsics.
    It is based on the corresponding ATen kernel.
    Shape of packed weight = [N // 64, K, 32], viewed as [N, K // 2]
    Shape of packed ScalesAndZeros = [K // group_size, N, 2]
    """

    TEMPLATE_ENTRY = r"""
{{declare_kernel}} {
    {{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
    {{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
    auto group_size = q_group_size;
    for (int64_t m = 0; m < M; m += {{block_m}}) {
        int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
        for (int64_t n = 0; n < N; n += {{block_n}}) {
            if (block_m == {{block_m}}) {
                {{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>(
                    A + m * lda,
                    reinterpret_cast<const uint8_t*>(B) + n * ldb,
                    C + m * ldc + n,
                    K,
                    lda,
                    /* ldb */ {{block_n}} / 2,
                    ldc,
                    group_size,
                    ScaleAndZeros + n * 2,
                    lds,
                    k_start
                );
            } else {
                switch (block_m) {
{%- for b in range(block_m - 1, 0, -1) %}
                    case {{b}}:
                        {{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>(
                            A + m * lda,
                            reinterpret_cast<const uint8_t*>(B) + n * ldb,
                            C + m * ldc + n,
                            K,
                            lda,
                            /* ldb */ {{block_n}} / 2,
                            ldc,
                            group_size,
                            ScaleAndZeros + n * 2,
                            lds,
                            k_start
                        );
                        break;
{%- endfor %}
                    default:
                        {{kernel.assert_function}}(false, "Unsupported block_m: ", block_m);
                }
            }
        }
    }
}
"""
    TEMPLATE_KERNEL = r"""
template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum>
inline void {{kernel_name}}_kernel(
    const {{input_t}}* {{restrict_keyword}} A,
    const uint8_t* {{restrict_keyword}} B,
    {{output_t}}* {{restrict_keyword}} C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    int64_t q_group_size,
    const bfloat16* {{restrict_keyword}} ScaleAndZeros,
    int64_t lds, // leading dimension of ScaleAndZeros
    int64_t k_start) {
    constexpr int BLOCK_K = {{block_k}};
    constexpr int ROWS = BLOCK_M;
    constexpr int COLS = BLOCK_N / 16;
    const int PREFETCH_SIZE_K = 16 * 4;
    const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;
    // number of blocks on K
    const int KB = K / BLOCK_K;
    __m512 va;
    __m512 vb[COLS];
    __m512 vc[ROWS * COLS];
    __m512 scale[COLS];
    __m512 zero[COLS];
    // Lookup table to de-quantize int4 values to bf16.
    // Values are dequantized as true signed int4 in the [-8, 7] range:
    //
    //   dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
    //
    static const __m512 lut = _mm512_set_ps(
        7.0f, 6.0f, 5.0f, 4.0f,
        3.0f, 2.0f, 1.0f, 0.0f,
        -1.0f, -2.0f, -3.0f, -4.0f,
        -5.0f, -6.0f, -7.0f, -8.0f);
    // index for transpose
    static const __m512i idx1 = _mm512_set_epi32(
        30, 28, 26, 24, 22, 20, 18, 16,
        14, 12, 10, 8, 6, 4, 2, 0);
    static const __m512i idx2 = _mm512_set_epi32(
        31, 29, 27, 25, 23, 21, 19, 17,
        15, 13, 11, 9, 7, 5, 3, 1);
    // load scale and zero point
    auto load_scale_and_zeros = [&](int i, int _kb) {
        // load 2x bfloat16 vector
        __m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i));
        if (_kb + PREFETCH_SIZE_KB < KB) {
            _mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0);
        }
        // convert to 2x f32 vector
        __m512 a, b;
        at::vec::cvtbf16_fp32(t, a, b);
        // transpose scale_and_zero from {16, 2} to {2, 16}
        // inputs:
        //   a: {s0, z0, s1, z1, ..., s7, z7}
        //   b: {s8, z8, s9, z9, ..., s15, z15}
        // output:
        //   scale: {s0, s1, s2, ..., s15}
        //   zero: {z0, z1, z2, ..., z15}
        scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
        zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
    };
    auto loadc = [&](auto i) {
        if constexpr (accum) {
            constexpr int row = i / COLS;
            constexpr int col = i % COLS;
            vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16);
        } else {
            vc[i] = _mm512_setzero_ps();
        }
    };
    c10::ForcedUnroll<ROWS * COLS>{}(loadc);
    auto compute = [&, COLS](auto i, int k) {
        constexpr int row = i / COLS;
        constexpr int col = i % COLS;
        if constexpr (col == 0) {
            float aa = static_cast<float>(A[row * lda + k]);
            if (k + PREFETCH_SIZE_K < K) {
                _mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
            }
            va = _mm512_set1_ps(aa);
        }
        if constexpr (row == 0) {
            if constexpr (COLS == 4) {
                // when BLOCK_N = 64, handle each row at a time
                // to reduce de-quantize overhead.
                if constexpr (col == 0) {
                    __m256i b4 = _mm256_loadu_si256((__m256i*)(B + k * ldb));
                    if (k + PREFETCH_SIZE_K < K) {
                        _mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0);
                    }
                    __m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4));
                    vb[0] = _mm512_permutexvar_ps(b32, lut);
                    vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]);
                    vb[2] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
                    vb[2] = _mm512_fmadd_ps(vb[2], scale[2], zero[2]);
                    b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1));
                    vb[1] = _mm512_permutexvar_ps(b32, lut);
                    vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]);
                    vb[3] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
                    vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]);
                }
            } else {
                __m128i b8 = convert_int4_to_int8(B + k * ldb + col * 8);
                __m512i b32 = _mm512_cvtepu8_epi32(b8);
                vb[col] = _mm512_permutexvar_ps(b32, lut);
                vb[col] = _mm512_fmadd_ps(vb[col], scale[col], zero[col]);
            }
        }
        constexpr int idx = row * COLS + col;
        vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
    };
    for (int k = 0, kb = 0; k < K; ++k) {
        if (is_block_start(k, k_start, q_group_size)) {
            c10::ForcedUnroll<COLS>{}(load_scale_and_zeros, kb++);
        }
        c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
    }
    // store to C
    auto storec = [&, COLS](auto i) {
        constexpr int row = i / COLS;
        constexpr int col = i % COLS;
        _mm512_storeu_ps(C + row * ldc + col * 16, vc[i]);
    };
    c10::ForcedUnroll<ROWS * COLS>{}(storec);
}
"""
    def get_kernel_extra_args_declare(self) -> str:
        return (
            "const int64_t q_group_size,\n"
            " const bfloat16* __restrict__ ScaleAndZeros,\n"
            " const int64_t lds,\n"
            " int64_t k_start,"
        )

    def get_kernel_extra_args(self, **kwargs) -> list[str]:
        assert "kernel" in kwargs
        assert "qscale_and_zeros" in kwargs
        kernel = kwargs["kernel"]
        qscale_and_zeros = kwargs["qscale_and_zeros"]
        return [
            "group_size,",
            f"&({kernel.index(qscale_and_zeros, [0, 0, 0])}),",
            "N * 2,",
            "k_start,",
        ]
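    # Together with codegen_call above, these four extra argument lines are
    # rendered ahead of the usual pointer arguments, so a generated WOQ int4
    # call starts roughly like this (an illustrative sketch; the actual kernel
    # and buffer names vary per generated kernel):
    #
    #     kernel_micro_gemm<false>(
    #         group_size,
    #         &(qscale_and_zeros[...]),  # from kernel.index(qscale_and_zeros, [0, 0, 0])
    #         N * 2,
    #         k_start,
    #         A_ptr, B_ptr, C_ptr, ...
    #     );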
def create_micro_gemm(
    name,
    m,
@@ -888,6 +1133,8 @@ def create_micro_gemm(
    alpha=1,
    num_threads=-1,
    use_ref=True,
    is_woq_int4=False,
    q_group_size=None,
) -> Optional[CppMicroGemm]:
    def create_from_config(cls, config: CppMicroGemmConfig):
        return cls(
@@ -898,8 +1145,38 @@ def create_micro_gemm(
            config.compute_dtype,
            config.register_blocking,
            alpha,
            is_woq_int4,
        )

    def woq_extra_check(config, m, n, k, alpha, is_woq_int4, q_group_size):
        if input_dtype != torch.bfloat16 or input2_dtype not in [
            torch.uint8,
            torch.int8,
        ]:
            # non-WOQ cases or WOQ with invalid input types
            return True
        if alpha != 1:
            return False
        if is_woq_int4:
            assert q_group_size is not None
            if (
                q_group_size < 32
                or k % q_group_size != 0
                or config.register_blocking.block_k > q_group_size
            ):
                return False
            return k % config.register_blocking.block_k == 0 and n % 64 == 0
        else:  # WOQ INT8
            if (
                config.vec_isa_cls == VecAMX
                # block_m is resolved from the enclosing scope at call time,
                # where it is unpacked from config.register_blocking below
                and m < block_m
                and input_dtype == torch.bfloat16
                and input2_dtype == torch.int8
            ):
                # For int8 WoQ GEMM, AMX micro-kernel may not perform well if m < block_m
                return False
            return True
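
    # A worked example of the checks above (illustrative), using the AVX512 WOQ
    # int4 blockings registered earlier in this file ((4, 64, 32), (4, 64, 64),
    # (4, 64, 128)): with k=128, n=128 and group_size=32, only the block_k=32
    # config survives (block_k must not exceed the group size and must divide k,
    # and n must be a multiple of 64), while group_size=64 admits both
    # block_k=32 and block_k=64.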
    assert isinstance(n, int) or n.is_number, n
    assert isinstance(k, int) or k.is_number, k
    m = V.graph.sizevars.size_hint(m, fallback=1) if isinstance(m, sympy.Expr) else m
@@ -932,13 +1209,9 @@ def create_micro_gemm(
            ):
                continue
            block_m, block_n, block_k = config.register_blocking
            if (
                config.vec_isa_cls == VecAMX
                and m < block_m
                and input_dtype == torch.bfloat16
                and input2_dtype == torch.int8
            if not woq_extra_check(
                config, m, n, k, alpha, is_woq_int4, q_group_size
            ):
                # For int8 WoQ GEMM, AMX micro-kernel may not perform well if m < block_m
                continue
            # Criteria on the ranking of configurations
            # 1. ISA: AMX > VEC
@@ -975,7 +1248,7 @@ def create_micro_gemm(
            )
        )
    if len(matched_configs) == 0:
        if use_ref and not is_woq_int4:
            return CppMicroGemmRef(
                name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha
            )

View file

@@ -484,6 +484,18 @@ inline at::vec::Vectorized<float> vec_shuffle_down(at::vec::Vectorized<float> x,
}
throw std::runtime_error("Unhandled vec_shuffle_down value " + std::to_string(n));
}
// For WOQ INT4
// Unpack 8 bytes (16 packed int4 values) into 16 bytes, one unsigned nibble
// per byte: the low nibble of input byte i goes to output byte 2*i, and the
// high nibble goes to output byte 2*i + 1.
inline __m128i convert_int4_to_int8(const uint8_t* data) {
    // load 8 bytes and zero-extend each one into a 16-bit lane
    __m128i tmp = _mm_loadu_si64((const __m128i*)data);
    __m128i bytes = _mm_cvtepu8_epi16(tmp);
    const __m128i lowMask = _mm_set1_epi8(0xF);
    // split each byte into its low and high nibbles
    __m128i high = _mm_andnot_si128(lowMask, bytes);
    __m128i low = _mm_and_si128(lowMask, bytes);
    // move the high nibble into the upper byte of each 16-bit lane
    high = _mm_slli_epi16(high, 4);
    bytes = _mm_or_si128(low, high);
    return bytes;
}
#endif
template <typename scalar_t>
@@ -945,3 +957,8 @@ class AMXState {
        tile_release();
    }
};
// For group-wise quantization, e.g., WOQ INT4
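// Example: with group_size = 32 and k_start = 16, index 16 starts a new group,
// since (16 + 16) % 32 == 0.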
inline bool is_block_start(int index, int k_start, int group_size) {
    return (k_start + index) % group_size == 0;
}

View file

@@ -2,6 +2,7 @@ import logging
from typing import Any
import torch
from torch._inductor.ir import IRNode
from torch._inductor.kernel.mm_common import mm_args
from . import config as inductor_config, lowering
@@ -133,6 +134,21 @@ def register_woq_mm_ops() -> None:
            if use_aten_gemm_kernels()
            else []
        )
        if use_cpp_gemm_template(
            aten_layout,
            mat1,
            mat2,
            mat2_transposed=True,
            is_woq_int4=True,
            q_group_size=qGroupSize,
        ):
            CppGemmTemplate.add_choices(
                choices,
                aten_layout,
                [mat1, mat2, group_size, qScaleAndZeros],
                is_woq_int4=True,
                q_group_size=qGroupSize,
            )

        if (
            len(choices) == 0
@@ -144,11 +160,25 @@ def register_woq_mm_ops() -> None:
                (mat1, mat2, group_size, qScaleAndZeros), aten_layout
            ).output_node()

        # Define functions that generate example inputs for the weight and the
        # group size; otherwise, the autotuner would generate all-zero example
        # inputs for them.
        def get_example_weight(x: IRNode) -> torch.Tensor:
            shape = x.get_size()
            device = x.get_device()
            return torch.randint(0, 255, shape, dtype=torch.uint8, device=device)

        def get_example_group_size(x: IRNode) -> torch.Tensor:
            return torch.tensor(qGroupSize, dtype=torch.int64)

        return autotune_select_algorithm(
            "_weight_int4pack_mm_for_cpu",
            choices,
            [mat1, mat2, group_size, qScaleAndZeros],
            aten_layout,
            input_gen_fns={
                1: get_example_weight,
                2: get_example_group_size,
            },
        )
lowering.make_fallback(aten._dyn_quant_matmul_4bit)
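
A hedged end-to-end sketch of how this lowering gets exercised (mirroring test_int4_woq_mm_avx512 at the top of this commit; `mod` and `x` stand for the module and input from that test, and the exact config flags needed here are an assumption):

import torch
import torch._inductor.config as inductor_config

# mod: a module calling torch._weight_int4pack_mm_for_cpu on a uint8 [N, K // 2]
# weight, as in the test above; x: a bf16 activation
with torch.no_grad(), inductor_config.patch({"freezing": True, "max_autotune": True}):
    compiled = torch.compile(mod)
    y = compiled(x)  # autotuning may pick the new C++ WOQ int4 GEMM template on AVX512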

View file

@@ -1395,7 +1395,13 @@ def use_cpp_bmm_template(layout, mat1, mat2):
def use_cpp_gemm_template(
    layout, mat1, mat2, mat2_transposed=False, require_constant_mat2=True
    layout,
    mat1,
    mat2,
    mat2_transposed=False,
    require_constant_mat2=True,
    is_woq_int4=False,
    q_group_size=None,
):
from . import ir
from .codegen.cpp_micro_gemm import create_micro_gemm
@@ -1415,6 +1421,7 @@ def use_cpp_gemm_template(
        mat2,
        out_dtype=layout.dtype if int8_gemm else None,
        mat2_transposed=mat2_transposed,
        use_4x2_dim=is_woq_int4,
    )

    # TODO(jgong5): support dynamic shapes for n or k
@@ -1433,6 +1440,8 @@ def use_cpp_gemm_template(
        input2_dtype=mat2.get_dtype(),
        output_dtype=output_dtype,
        num_threads=parallel_num_threads(),
        is_woq_int4=is_woq_int4,
        q_group_size=q_group_size,
    )

def is_last_dim_stride1(x):