mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Update
[ghstack-poisoned]
This commit is contained in:
parent
c4267ed1c5
commit
ad2dc02f7f
8 changed files with 468 additions and 21 deletions
|
|
@ -10,6 +10,7 @@ import torch
|
|||
import torch._dynamo.config
|
||||
import torch._dynamo.config as dynamo_config
|
||||
import torch._inductor.config as inductor_config
|
||||
import torch._inductor.cpu_vec_isa
|
||||
import torch._inductor.select_algorithm as select_algorithm
|
||||
from torch._dynamo.utils import counters
|
||||
from torch._inductor import test_operators
|
||||
|
|
@ -1499,6 +1500,46 @@ class TestSelectAlgorithm(BaseTestSelectAlgorithm):
|
|||
vec_amx = VecAMX()
|
||||
self._check_amx_counter(vec_amx)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
@torch.no_grad
|
||||
@dtypes(torch.bfloat16)
|
||||
@parametrize("batch_size", (32,))
|
||||
@parametrize("in_features", (128, 256))
|
||||
@parametrize("out_features", (64, 128))
|
||||
@parametrize("group_size", (32, 64))
|
||||
def test_int4_woq_mm_avx512(
|
||||
self, dtype, batch_size, in_features, out_features, group_size
|
||||
):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, K, N, group_size):
|
||||
super().__init__()
|
||||
self.linear_weight = torch.randint(
|
||||
0, 15, (N, K // 2), dtype=torch.uint8
|
||||
)
|
||||
self.qscale_and_zeros = torch.rand(K // group_size, N, 2, dtype=dtype)
|
||||
self.group_size = group_size
|
||||
|
||||
def forward(self, x):
|
||||
x_shape = x.shape
|
||||
x = x.reshape(-1, x_shape[-1])
|
||||
y = torch._weight_int4pack_mm_for_cpu(
|
||||
x, self.linear_weight, self.group_size, self.qscale_and_zeros
|
||||
)
|
||||
return y.reshape(*x_shape[:-1], out_features)
|
||||
|
||||
counters.clear()
|
||||
seq_len = 8
|
||||
x = torch.rand((batch_size, seq_len, in_features), dtype=dtype)
|
||||
mod = M(in_features, out_features, group_size).eval()
|
||||
self.common(mod, (x,), reference_in_float=False)
|
||||
available_isa = torch._inductor.cpu_vec_isa.pick_vec_isa()
|
||||
vax512_available = "avx512" in str(available_isa)
|
||||
autotune_count = 1 if vax512_available else 0
|
||||
self.assertEqual(
|
||||
counters["inductor"]["select_algorithm_autotune"], autotune_count
|
||||
)
|
||||
|
||||
@inductor_config.patch({"freezing": True})
|
||||
@patches
|
||||
@torch.no_grad
|
||||
|
|
|
|||
|
|
@ -85,6 +85,7 @@ class CppBmmTemplate(CppGemmTemplate):
|
|||
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
|
||||
should_block_weights: bool = False,
|
||||
name="bmm",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
In order to simplify the implementation and increase code reuse, the BMM template implements
|
||||
|
|
|
|||
|
|
@ -102,6 +102,9 @@ GEMM_TEMPLATE_INIT_BLOCKING = r"""
|
|||
constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
|
||||
constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;
|
||||
{%- endif %}
|
||||
{%- if is_woq_int4 %}
|
||||
int64_t group_size = *q_group_size;
|
||||
{%- endif %}
|
||||
|
||||
// make sure all partitions are assigned
|
||||
{{kernel.assert_function}}(
|
||||
|
|
@ -170,6 +173,8 @@ GEMM_TEMPLATE_MICROKERNEL_DEF = r"""
|
|||
GEMM_TEMPLATE_STUB_DEF = r"""
|
||||
{%- if x_scale is not none %}
|
||||
{%- set kernel_args = {"X": X, "W": W, "inp": inp, "x_scale": x_scale, "x_zp": x_zp, "w_scale": w_scale, "w_zp": w_zp,} %}
|
||||
{%- elif is_woq_int4 %}
|
||||
{%- set kernel_args = {"X": X, "W": W, "q_group_size": q_group_size, "qscale_and_zeros": qscale_and_zeros} %}
|
||||
{%- else %}
|
||||
{%- set kernel_args = {"X": X, "W": W, "inp": inp} %}
|
||||
{%- endif %}
|
||||
|
|
@ -183,7 +188,7 @@ GEMM_TEMPLATE = r"""
|
|||
{
|
||||
{{ kernel.maybe_codegen_profile() }}
|
||||
{{ template.codegen_blocks(
|
||||
num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W
|
||||
num_threads, N, K, micro_gemm, is_dynamic_M, kernel, GemmOut, config, L1_cache_size, L2_cache_size, X, W, is_woq_int4
|
||||
) }}
|
||||
|
||||
{%- if maybe_k_slicing %}
|
||||
|
|
@ -226,12 +231,40 @@ GEMM_TEMPLATE = r"""
|
|||
{%- set tile_W_3d = kernel.slice_nd(W, [("nci", "nci + 1"), ("k_start", "k_end"), ()]) %}
|
||||
{%- set tile_W = kernel.view(tile_W_3d, ["k_end - k_start", micro_gemm.register_blocking.block_n]) %}
|
||||
{%- else %}
|
||||
{%- set tile_W = kernel.slice_nd(W, [("k_start", "k_end"), ("n_start", "n_start + n_size")]) %}
|
||||
{%- if is_woq_int4 %}
|
||||
{%- set tile_W = kernel.slice_nd(W, [("n_start", "n_start + n_size"), ("k_start * Nr / 2", "k_end * Nr / 2")]) %}
|
||||
{%- set tile_qparam = kernel.slice_nd(
|
||||
qscale_and_zeros, [("k_start / group_size", "k_end / group_size"), ("n_start", "n_start + n_size"), ()]) %}
|
||||
{%- else %}
|
||||
{%- set tile_W = kernel.slice_nd(W, [("k_start", "k_end"), ("n_start", "n_start + n_size")]) %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
if (kc == k_block_start) {
|
||||
{%- if is_woq_int4 %}
|
||||
{{ micro_gemm.codegen_call(kernel,
|
||||
tile_X,
|
||||
tile_W,
|
||||
acc_slice,
|
||||
accum=False,
|
||||
is_woq_int4=True,
|
||||
qscale_and_zeros=tile_qparam)|indent(28, false)
|
||||
}}
|
||||
{%- else %}
|
||||
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=False)|indent(28, false) }}
|
||||
{%- endif %}
|
||||
} else {
|
||||
{%- if is_woq_int4 %}
|
||||
{{ micro_gemm.codegen_call(kernel,
|
||||
tile_X,
|
||||
tile_W,
|
||||
acc_slice,
|
||||
accum=True,
|
||||
is_woq_int4=True,
|
||||
qscale_and_zeros=tile_qparam)|indent(28, false)
|
||||
}}
|
||||
{%- else %}
|
||||
{{ micro_gemm.codegen_call(kernel, tile_X, tile_W, acc_slice, accum=True)|indent(28, false) }}
|
||||
{%- endif %}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -501,6 +534,8 @@ class CppGemmTemplate(CppTemplate):
|
|||
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
|
||||
should_block_weights: bool = True,
|
||||
name="packed_gemm",
|
||||
is_woq_int4=False,
|
||||
q_group_size=None,
|
||||
) -> None:
|
||||
assert layout.dtype in [torch.float, torch.bfloat16, torch.half, torch.uint8]
|
||||
super().__init__(
|
||||
|
|
@ -522,6 +557,10 @@ class CppGemmTemplate(CppTemplate):
|
|||
self.should_block_weights = should_block_weights
|
||||
self.thread_blocking = self.make_thread_blocking_cache()
|
||||
self.cache_blocking = self.make_cache_blocking_cache()
|
||||
self.is_woq_int4 = is_woq_int4
|
||||
if is_woq_int4:
|
||||
assert not should_block_weights, "Weight is already packed for WOQ int4"
|
||||
self.q_group_size = q_group_size
|
||||
|
||||
def make_thread_blocking_cache(self):
|
||||
cache = lru_cache()(self._thread_blocking)
|
||||
|
|
@ -798,6 +837,8 @@ class CppGemmTemplate(CppTemplate):
|
|||
input_indices=None,
|
||||
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
|
||||
act_mapping: Optional[dict[int, ir.IRNode]] = None,
|
||||
is_woq_int4: bool = False,
|
||||
q_group_size: Optional[int] = None,
|
||||
):
|
||||
if input_indices is None:
|
||||
input_indices = list(range(len(input_nodes)))
|
||||
|
|
@ -877,7 +918,12 @@ class CppGemmTemplate(CppTemplate):
|
|||
# TODO(jgong5): decide proper number of threads per problem size
|
||||
num_threads = parallel_num_threads()
|
||||
new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout))
|
||||
m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1])
|
||||
m, n, k, *_ = mm_args(
|
||||
new_inputs[0],
|
||||
new_inputs[1],
|
||||
mat2_transposed=is_woq_int4,
|
||||
use_4x2_dim=is_woq_int4,
|
||||
)
|
||||
output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype(
|
||||
new_inputs[0].get_dtype()
|
||||
)
|
||||
|
|
@ -892,6 +938,8 @@ class CppGemmTemplate(CppTemplate):
|
|||
compute_dtype=compute_dtype,
|
||||
alpha=alpha,
|
||||
num_threads=num_threads,
|
||||
is_woq_int4=is_woq_int4,
|
||||
q_group_size=q_group_size,
|
||||
)
|
||||
assert micro_gemm is not None
|
||||
block_weights = cls.check_if_block_weight(new_inputs[1], micro_gemm)
|
||||
|
|
@ -947,6 +995,9 @@ class CppGemmTemplate(CppTemplate):
|
|||
has_bias=has_bias,
|
||||
epilogue_creator=epilogue_creator,
|
||||
should_block_weights=block_weights,
|
||||
name=micro_gemm.__class__.__name__,
|
||||
is_woq_int4=is_woq_int4,
|
||||
q_group_size=q_group_size,
|
||||
)
|
||||
template.maybe_append_choice(choices)
|
||||
return template
|
||||
|
|
@ -989,10 +1040,15 @@ class CppGemmTemplate(CppTemplate):
|
|||
"""
|
||||
W = inputs[1]
|
||||
new_inputs = list(inputs)
|
||||
if isinstance(W, ir.IRNode):
|
||||
k, n = W.get_size()[-2:]
|
||||
if micro_gemm.is_woq_int4:
|
||||
assert (
|
||||
len(W.get_size()) == 2
|
||||
if isinstance(W, ir.IRNode)
|
||||
else len(W.shape) == 2
|
||||
)
|
||||
n, k = W.get_size() if isinstance(W, ir.IRNode) else W.shape
|
||||
else:
|
||||
k, n = W.shape[-2:]
|
||||
k, n = W.get_size()[-2:] if isinstance(W, ir.IRNode) else W.shape[-2:]
|
||||
_, block_n, _ = micro_gemm.register_blocking
|
||||
new_size, padded_n = cls.get_padded_size(n, block_n, k, should_block_weight)
|
||||
padding = padded_n - n
|
||||
|
|
@ -1027,6 +1083,9 @@ class CppGemmTemplate(CppTemplate):
|
|||
|
||||
@staticmethod
|
||||
def check_if_block_weight(W, micro_gemm):
|
||||
if micro_gemm.is_woq_int4:
|
||||
# For WOQ INT4, weight is already packed
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
|
|
@ -1141,6 +1200,9 @@ class CppGemmTemplate(CppTemplate):
|
|||
x_zp = None
|
||||
w_scale = None
|
||||
w_zp = None
|
||||
inp = None
|
||||
q_group_size = None
|
||||
qscale_and_zeros = None
|
||||
if int8_gemm:
|
||||
X, W = self.input_nodes[0], self.input_nodes[1]
|
||||
bias_idx = 2 if self.has_bias else 1
|
||||
|
|
@ -1150,6 +1212,11 @@ class CppGemmTemplate(CppTemplate):
|
|||
w_scale = self.input_nodes[bias_idx + 3]
|
||||
w_zp = self.input_nodes[bias_idx + 4]
|
||||
Y = self.output_node
|
||||
elif self.is_woq_int4:
|
||||
X, W = self.input_nodes[0], self.input_nodes[1]
|
||||
Y = self.output_node
|
||||
q_group_size = self.input_nodes[2]
|
||||
qscale_and_zeros = self.input_nodes[3]
|
||||
else:
|
||||
X, W = self.input_nodes[0], self.input_nodes[1]
|
||||
Y = self.output_node
|
||||
|
|
@ -1311,6 +1378,8 @@ class CppGemmTemplate(CppTemplate):
|
|||
compute_dtype=compute_dtype,
|
||||
alpha=self.alpha,
|
||||
num_threads=self.num_threads,
|
||||
is_woq_int4=self.is_woq_int4,
|
||||
q_group_size=self.q_group_size,
|
||||
)
|
||||
assert micro_gemm is not None
|
||||
assert self.register_blocking == micro_gemm.register_blocking
|
||||
|
|
@ -1359,6 +1428,9 @@ class CppGemmTemplate(CppTemplate):
|
|||
L2_cache_size=L2_cache_size,
|
||||
config=config,
|
||||
fake_buffers=fake_buffers,
|
||||
is_woq_int4=self.is_woq_int4,
|
||||
q_group_size=q_group_size,
|
||||
qscale_and_zeros=qscale_and_zeros,
|
||||
)
|
||||
return options
|
||||
|
||||
|
|
@ -1399,6 +1471,7 @@ class CppGemmTemplate(CppTemplate):
|
|||
L2_cache_size,
|
||||
X,
|
||||
W,
|
||||
is_woq_int4=False,
|
||||
):
|
||||
options = dict(
|
||||
num_threads=num_threads,
|
||||
|
|
@ -1414,6 +1487,7 @@ class CppGemmTemplate(CppTemplate):
|
|||
template=self,
|
||||
X=X,
|
||||
W=W,
|
||||
is_woq_int4=is_woq_int4,
|
||||
)
|
||||
return self._template_from_string(GEMM_TEMPLATE_INIT_BLOCKING).render(options)
|
||||
|
||||
|
|
|
|||
|
|
@ -201,6 +201,8 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
|
|||
input_indices: Optional[list[int]] = None,
|
||||
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
|
||||
act_mapping: Optional[dict[int, ir.IRNode]] = None, # gemm idx to its act buf
|
||||
is_woq_int4: bool = False,
|
||||
q_group_size: Optional[int] = None,
|
||||
) -> DataProcessorTemplateWrapper:
|
||||
# Input nodes order: x, optional[x1], ... w0, w1, ... optional[b0], optional[b1], ...
|
||||
gemm_grouped_num = len(has_bias)
|
||||
|
|
|
|||
|
|
@ -81,6 +81,7 @@ inline void {{kernel_name}}(
|
|||
compute_dtype,
|
||||
register_blocking,
|
||||
alpha=1,
|
||||
is_woq_int4=False,
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.input_dtype = input_dtype
|
||||
|
|
@ -90,6 +91,7 @@ inline void {{kernel_name}}(
|
|||
self.compute_dtype = compute_dtype
|
||||
self.register_blocking = register_blocking
|
||||
self.alpha = alpha
|
||||
self.is_woq_int4 = is_woq_int4
|
||||
|
||||
def get_common_options(self):
|
||||
if self.input_dtype in [torch.uint8, torch.int8]:
|
||||
|
|
@ -113,6 +115,7 @@ inline void {{kernel_name}}(
|
|||
"vnni_size": 4 if self.input_dtype in [torch.uint8, torch.int8] else 2,
|
||||
"restrict_keyword": get_restrict_keyword(),
|
||||
"is_msvc_compiler": cpp_builder.is_msvc_cl(),
|
||||
"is_woq_int4": self.is_woq_int4,
|
||||
}
|
||||
|
||||
def get_kernel_declaration(self):
|
||||
|
|
@ -122,8 +125,8 @@ inline void {{kernel_name}}(
|
|||
def get_kernel_extra_args_declare(self) -> str:
|
||||
return ""
|
||||
|
||||
def get_kernel_extra_args(self) -> str:
|
||||
return ""
|
||||
def get_kernel_extra_args(self, **kwargs) -> list[str]:
|
||||
return []
|
||||
|
||||
def codegen_define(self, kernel: CppTemplateKernel) -> str:
|
||||
raise NotImplementedError
|
||||
|
|
@ -135,6 +138,8 @@ inline void {{kernel_name}}(
|
|||
B: ir.Buffer,
|
||||
C: ir.Buffer,
|
||||
accum: bool,
|
||||
is_woq_int4: bool = False,
|
||||
qscale_and_zeros: Optional[ir.Buffer] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate the code for calling the templated kernel that computes
|
||||
|
|
@ -152,9 +157,14 @@ inline void {{kernel_name}}(
|
|||
res = IndentedBuffer()
|
||||
res.writeline(f"{self.name}<{value_to_cpp(accum, 'bool')}>(")
|
||||
with res.indent():
|
||||
extra_args = self.get_kernel_extra_args()
|
||||
if extra_args:
|
||||
res.writeline(extra_args)
|
||||
kwargs_for_extra_args = (
|
||||
{"kernel": kernel, "qscale_and_zeros": qscale_and_zeros}
|
||||
if is_woq_int4
|
||||
else {}
|
||||
)
|
||||
extra_args = self.get_kernel_extra_args(**kwargs_for_extra_args)
|
||||
for arg in extra_args:
|
||||
res.writeline(arg)
|
||||
res.writeline(f"{A_ptr},")
|
||||
res.writeline(f"{B_ptr},")
|
||||
res.writeline(f"{C_ptr},")
|
||||
|
|
@ -803,8 +813,8 @@ inline void {{kernel_name}}_amx_kernel_{{num_rows}}_{{num_columns}}(
|
|||
def get_kernel_extra_args_declare(self) -> str:
|
||||
return "AMXState& amx_state,"
|
||||
|
||||
def get_kernel_extra_args(self) -> str:
|
||||
return "amx_state,"
|
||||
def get_kernel_extra_args(self, **kwargs) -> list[str]:
|
||||
return ["amx_state,"]
|
||||
|
||||
def get_b_layout(self):
|
||||
if self.input_dtype in [torch.uint8, torch.int8]:
|
||||
|
|
@ -876,6 +886,241 @@ class CppMicroBrgemm(CppMicroGemm):
|
|||
return LayoutType.VNNI2
|
||||
|
||||
|
||||
@register_micro_gemm(
|
||||
*generate_gemm_config(
|
||||
VecAVX512,
|
||||
[(4, 64, 32), (4, 64, 64), (4, 64, 128)],
|
||||
input_dtype=torch.bfloat16,
|
||||
input2_dtype=torch.uint8,
|
||||
output_dtype=torch.float,
|
||||
compute_dtype=torch.float,
|
||||
),
|
||||
)
|
||||
class CppMicroGemmWoQInt4Vec(CppMicroGemmFP32Vec):
|
||||
"""
|
||||
This class generates the code for WoQ int4 micro gemm using AVX512 intrinsics.
|
||||
It is based on the corresponding ATen kernel.
|
||||
Shape of packed weight = [N // 64, K, 32], viewed as [N, K // 2]
|
||||
Shape of packed ScalesAndZeros = [K // group_size, N, 2]
|
||||
"""
|
||||
|
||||
TEMPLATE_ENTRY = r"""
|
||||
{{declare_kernel}} {
|
||||
{{kernel.assert_function}}(N % {{block_n}} == 0, "N dimension must be multiple of {{block_n}}");
|
||||
{{kernel.assert_function}}(K % {{block_k}} == 0, "K dimension must be multiple of {{block_k}}");
|
||||
auto group_size = q_group_size;
|
||||
for (int64_t m = 0; m < M; m += {{block_m}}) {
|
||||
int64_t block_m = std::min<int64_t>(M - m, {{block_m}});
|
||||
for (int64_t n = 0; n < N; n += {{block_n}}) {
|
||||
if (block_m == {{block_m}}) {
|
||||
{{kernel_name}}_kernel<{{block_m}}, {{block_n}}, accum>(
|
||||
A + m * lda,
|
||||
reinterpret_cast<const uint8_t*>(B) + n * ldb,
|
||||
C + m * ldc + n,
|
||||
K,
|
||||
lda,
|
||||
/* ldb */ {{block_n}} / 2,
|
||||
ldc,
|
||||
group_size,
|
||||
ScaleAndZeros + n * 2,
|
||||
lds,
|
||||
k_start
|
||||
);
|
||||
} else {
|
||||
switch (block_m) {
|
||||
{%- for b in range(block_m - 1, 0, -1) %}
|
||||
case {{b}}:
|
||||
{{kernel_name}}_kernel<{{b}}, {{block_n}}, accum>(
|
||||
A + m * lda,
|
||||
reinterpret_cast<const uint8_t*>(B) + n * ldb,
|
||||
C + m * ldc + n,
|
||||
K,
|
||||
lda,
|
||||
/* ldb */ {{block_n}} / 2,
|
||||
ldc,
|
||||
group_size,
|
||||
ScaleAndZeros + n * 2,
|
||||
lds,
|
||||
k_start
|
||||
);
|
||||
break;
|
||||
{%- endfor %}
|
||||
default:
|
||||
{{kernel.assert_function}}(false, "Unsupported block_m: ", block_m);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
TEMPLATE_KERNEL = r"""
|
||||
template <int64_t BLOCK_M, int64_t BLOCK_N, bool accum>
|
||||
inline void {{kernel_name}}_kernel(
|
||||
const {{input_t}}* {{restrict_keyword}} A,
|
||||
const uint8_t* {{restrict_keyword}} B,
|
||||
{{output_t}}* {{restrict_keyword}} C,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
int64_t q_group_size,
|
||||
const bfloat16* {{restrict_keyword}} ScaleAndZeros,
|
||||
int64_t lds, // leading dimension of ScaleAndZeros
|
||||
int64_t k_start) {
|
||||
constexpr int BLOCK_K = {{block_k}};
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
|
||||
const int PREFETCH_SIZE_K = 16 * 4;
|
||||
const int PREFETCH_SIZE_KB = (PREFETCH_SIZE_K + BLOCK_K - 1) / BLOCK_K;
|
||||
|
||||
// number of blocks on K
|
||||
const int KB = K / BLOCK_K;
|
||||
|
||||
__m512 va;
|
||||
__m512 vb[COLS];
|
||||
__m512 vc[ROWS * COLS];
|
||||
__m512 scale[COLS];
|
||||
__m512 zero[COLS];
|
||||
|
||||
// Lookup table to de-quantize int4 values to bf16.
|
||||
// Values are dequantized as truly int4 [-8, 7] range;
|
||||
//
|
||||
// dequant = (bf16(int4_value) * bf16_scale) + bf16_zero
|
||||
//
|
||||
static const __m512 lut = _mm512_set_ps(
|
||||
7.0f, 6.0f, 5.0f, 4.0f,
|
||||
3.0f, 2.0f, 1.0f, 0.0f,
|
||||
-1.0f, -2.0f, -3.0f, -4.0f,
|
||||
-5.0f, -6.0f, -7.0f, -8.0f);
|
||||
|
||||
// index for transpose
|
||||
static const __m512i idx1 = _mm512_set_epi32(
|
||||
30, 28, 26, 24, 22, 20, 18, 16,
|
||||
14, 12, 10, 8, 6, 4, 2, 0);
|
||||
static const __m512i idx2 = _mm512_set_epi32(
|
||||
31, 29, 27, 25, 23, 21, 19, 17,
|
||||
15, 13, 11, 9, 7, 5, 3, 1);
|
||||
|
||||
// load scale and zero point
|
||||
auto load_scale_and_zeros = [&](int i, int _kb) {
|
||||
// load 2x bfloat16 vector
|
||||
__m512i t = _mm512_loadu_si512((__m512i*)(ScaleAndZeros + _kb * lds + 32 * i));
|
||||
if (_kb + PREFETCH_SIZE_KB < KB) {
|
||||
_mm_prefetch(ScaleAndZeros + (_kb + PREFETCH_SIZE_KB) * lds + 32 * i, _MM_HINT_T0);
|
||||
}
|
||||
|
||||
// convert to 2x f32 vector
|
||||
__m512 a, b;
|
||||
at::vec::cvtbf16_fp32(t, a, b);
|
||||
|
||||
// transpose scale_and_zero from {16, 2} to {2, 16}
|
||||
// inputs:
|
||||
// a: {s0, z0, s1, z1, ..., s7, z7}
|
||||
// b: {s8, z8, s9, z9, ..., s15, z15}
|
||||
// output:
|
||||
// scale: {s0, s1, s2, ..., s15}
|
||||
// zero: {z0, z1, z2, ..., z15}
|
||||
scale[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
|
||||
zero[i] = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
|
||||
};
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
if constexpr (accum) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16);
|
||||
} else {
|
||||
vc[i] = _mm512_setzero_ps();
|
||||
}
|
||||
};
|
||||
c10::ForcedUnroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
auto compute = [&, COLS](auto i, int k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
float aa = static_cast<float>(A[row * lda + k]);
|
||||
if (k + PREFETCH_SIZE_K < K) {
|
||||
_mm_prefetch(A + row * lda + k + PREFETCH_SIZE_K, _MM_HINT_T0);
|
||||
}
|
||||
va = _mm512_set1_ps(aa);
|
||||
}
|
||||
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (COLS == 4) {
|
||||
// when BLOCK_N = 64, handle each row at a time
|
||||
// to reduce de-quantize overhead.
|
||||
if constexpr (col == 0) {
|
||||
__m256i b4 = _mm256_loadu_si256((__m256i*)(B + k * ldb));
|
||||
if (k + PREFETCH_SIZE_K < K) {
|
||||
_mm_prefetch(B + (k + PREFETCH_SIZE_K) * ldb, _MM_HINT_T0);
|
||||
}
|
||||
|
||||
__m512i b32 = _mm512_cvtepu8_epi32(_mm256_castsi256_si128(b4));
|
||||
vb[0] = _mm512_permutexvar_ps(b32, lut);
|
||||
vb[0] = _mm512_fmadd_ps(vb[0], scale[0], zero[0]);
|
||||
vb[2] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
|
||||
vb[2] = _mm512_fmadd_ps(vb[2], scale[2], zero[2]);
|
||||
|
||||
b32 = _mm512_cvtepu8_epi32(_mm256_extracti128_si256(b4, 1));
|
||||
vb[1] = _mm512_permutexvar_ps(b32, lut);
|
||||
vb[1] = _mm512_fmadd_ps(vb[1], scale[1], zero[1]);
|
||||
vb[3] = _mm512_permutexvar_ps(_mm512_srli_epi32(b32, 4), lut);
|
||||
vb[3] = _mm512_fmadd_ps(vb[3], scale[3], zero[3]);
|
||||
}
|
||||
} else {
|
||||
__m128i b8 = convert_int4_to_int8(B + k * ldb + col * 8);
|
||||
__m512i b32 = _mm512_cvtepu8_epi32(b8);
|
||||
vb[col] = _mm512_permutexvar_ps(b32, lut);
|
||||
vb[col] = _mm512_fmadd_ps(vb[col], scale[col], zero[col]);
|
||||
}
|
||||
}
|
||||
|
||||
constexpr int idx = row * COLS + col;
|
||||
vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
|
||||
};
|
||||
|
||||
for (int k = 0, kb = 0; k < K; ++k) {
|
||||
if (is_block_start(k, k_start, q_group_size)) {
|
||||
c10::ForcedUnroll<COLS>{}(load_scale_and_zeros, kb++);
|
||||
}
|
||||
c10::ForcedUnroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
//store to C
|
||||
auto storec = [&, COLS](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
_mm512_storeu_ps(C + row * ldc + col * 16, vc[i]);
|
||||
};
|
||||
c10::ForcedUnroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
"""
|
||||
|
||||
def get_kernel_extra_args_declare(self) -> str:
|
||||
return (
|
||||
"const int64_t q_group_size,\n"
|
||||
" const bfloat16* __restrict__ ScaleAndZeros,\n"
|
||||
" const int64_t lds,\n"
|
||||
" int64_t k_start,"
|
||||
)
|
||||
|
||||
def get_kernel_extra_args(self, **kwargs) -> list[str]:
|
||||
assert "kernel" in kwargs
|
||||
assert "qscale_and_zeros" in kwargs
|
||||
kernel = kwargs["kernel"]
|
||||
qscale_and_zeros = kwargs["qscale_and_zeros"]
|
||||
return [
|
||||
"group_size,",
|
||||
f"&({kernel.index(qscale_and_zeros, [0, 0, 0])}),",
|
||||
"N * 2,",
|
||||
"k_start,",
|
||||
]
|
||||
|
||||
|
||||
def create_micro_gemm(
|
||||
name,
|
||||
m,
|
||||
|
|
@ -888,6 +1133,8 @@ def create_micro_gemm(
|
|||
alpha=1,
|
||||
num_threads=-1,
|
||||
use_ref=True,
|
||||
is_woq_int4=False,
|
||||
q_group_size=None,
|
||||
) -> Optional[CppMicroGemm]:
|
||||
def create_from_config(cls, config: CppMicroGemmConfig):
|
||||
return cls(
|
||||
|
|
@ -898,8 +1145,38 @@ def create_micro_gemm(
|
|||
config.compute_dtype,
|
||||
config.register_blocking,
|
||||
alpha,
|
||||
is_woq_int4,
|
||||
)
|
||||
|
||||
def woq_extra_check(config, m, n, k, alpha, is_woq_int4, q_group_size):
|
||||
if input_dtype != torch.bfloat16 or input2_dtype not in [
|
||||
torch.uint8,
|
||||
torch.int8,
|
||||
]:
|
||||
# non-WOQ cases or WOQ with invalid input types
|
||||
return True
|
||||
if alpha != 1:
|
||||
return False
|
||||
if is_woq_int4:
|
||||
assert q_group_size is not None
|
||||
if (
|
||||
q_group_size < 32
|
||||
or k % q_group_size != 0
|
||||
or config.register_blocking.block_k > q_group_size
|
||||
):
|
||||
return False
|
||||
return k % config.register_blocking.block_k == 0 and n % 64 == 0
|
||||
else: # WOQ INT8
|
||||
if (
|
||||
config.vec_isa_cls == VecAMX
|
||||
and m < block_m
|
||||
and input_dtype == torch.bfloat16
|
||||
and input2_dtype == torch.int8
|
||||
):
|
||||
# For int8 WoQ GEMM, AMX micro-kernel may not perform well if m < block_m
|
||||
return False
|
||||
return True
|
||||
|
||||
assert isinstance(n, int) or n.is_number, n
|
||||
assert isinstance(k, int) or k.is_number, k
|
||||
m = V.graph.sizevars.size_hint(m, fallback=1) if isinstance(m, sympy.Expr) else m
|
||||
|
|
@ -932,13 +1209,9 @@ def create_micro_gemm(
|
|||
):
|
||||
continue
|
||||
block_m, block_n, block_k = config.register_blocking
|
||||
if (
|
||||
config.vec_isa_cls == VecAMX
|
||||
and m < block_m
|
||||
and input_dtype == torch.bfloat16
|
||||
and input2_dtype == torch.int8
|
||||
if not woq_extra_check(
|
||||
config, m, n, k, alpha, is_woq_int4, q_group_size
|
||||
):
|
||||
# For int8 WoQ GEMM, AMX micro-kernel may not perform well if m < block_m
|
||||
continue
|
||||
# Criteria on the ranking of configurations
|
||||
# 1. ISA: AMX > VEC
|
||||
|
|
@ -975,7 +1248,7 @@ def create_micro_gemm(
|
|||
)
|
||||
)
|
||||
if len(matched_configs) == 0:
|
||||
if use_ref:
|
||||
if use_ref and not is_woq_int4:
|
||||
return CppMicroGemmRef(
|
||||
name, input_dtype, input2_dtype, output_dtype, compute_dtype, alpha
|
||||
)
|
||||
|
|
|
|||
|
|
@ -484,6 +484,18 @@ inline at::vec::Vectorized<float> vec_shuffle_down(at::vec::Vectorized<float> x,
|
|||
}
|
||||
throw std::runtime_error("Unhandled vec_shuffle_down value " + std::to_string(n));
|
||||
}
|
||||
|
||||
// For WOQ INT4
|
||||
inline __m128i convert_int4_to_int8(const uint8_t* data) {
|
||||
__m128i tmp = _mm_loadu_si64((const __m128i*)data);
|
||||
__m128i bytes = _mm_cvtepu8_epi16(tmp);
|
||||
const __m128i lowMask = _mm_set1_epi8(0xF);
|
||||
__m128i high = _mm_andnot_si128(lowMask, bytes);
|
||||
__m128i low = _mm_and_si128(lowMask, bytes);
|
||||
high = _mm_slli_epi16(high, 4);
|
||||
bytes = _mm_or_si128(low, high);
|
||||
return bytes;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename scalar_t>
|
||||
|
|
@ -945,3 +957,8 @@ class AMXState {
|
|||
tile_release();
|
||||
}
|
||||
};
|
||||
|
||||
// For group-wise quantization, e.g., WOQ INT4
|
||||
inline bool is_block_start(int index, int k_start, int group_size) {
|
||||
return (k_start + index) % group_size == 0;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch._inductor.ir import IRNode
|
||||
from torch._inductor.kernel.mm_common import mm_args
|
||||
|
||||
from . import config as inductor_config, lowering
|
||||
|
|
@ -133,6 +134,21 @@ def register_woq_mm_ops() -> None:
|
|||
if use_aten_gemm_kernels()
|
||||
else []
|
||||
)
|
||||
if use_cpp_gemm_template(
|
||||
aten_layout,
|
||||
mat1,
|
||||
mat2,
|
||||
mat2_transposed=True,
|
||||
is_woq_int4=True,
|
||||
q_group_size=qGroupSize,
|
||||
):
|
||||
CppGemmTemplate.add_choices(
|
||||
choices,
|
||||
aten_layout,
|
||||
[mat1, mat2, group_size, qScaleAndZeros],
|
||||
is_woq_int4=True,
|
||||
q_group_size=qGroupSize,
|
||||
)
|
||||
|
||||
if (
|
||||
len(choices) == 0
|
||||
|
|
@ -144,11 +160,25 @@ def register_woq_mm_ops() -> None:
|
|||
(mat1, mat2, group_size, qScaleAndZeros), aten_layout
|
||||
).output_node()
|
||||
|
||||
# define functions to generate example inputs for weight and group size
|
||||
# otherwise, autotuner generates example inputs of all zeros for them
|
||||
def get_example_weight(x: IRNode) -> torch.Tensor:
|
||||
shape = x.get_size()
|
||||
device = x.get_device()
|
||||
return torch.randint(0, 255, shape, dtype=torch.uint8, device=device)
|
||||
|
||||
def get_example_group_size(x: IRNode) -> torch.Tensor:
|
||||
return torch.tensor(qGroupSize, dtype=torch.int64)
|
||||
|
||||
return autotune_select_algorithm(
|
||||
"_weight_int4pack_mm_for_cpu",
|
||||
choices,
|
||||
[mat1, mat2, group_size, qScaleAndZeros],
|
||||
aten_layout,
|
||||
input_gen_fns={
|
||||
1: get_example_weight,
|
||||
2: get_example_group_size,
|
||||
},
|
||||
)
|
||||
|
||||
lowering.make_fallback(aten._dyn_quant_matmul_4bit)
|
||||
|
|
|
|||
|
|
@ -1395,7 +1395,13 @@ def use_cpp_bmm_template(layout, mat1, mat2):
|
|||
|
||||
|
||||
def use_cpp_gemm_template(
|
||||
layout, mat1, mat2, mat2_transposed=False, require_constant_mat2=True
|
||||
layout,
|
||||
mat1,
|
||||
mat2,
|
||||
mat2_transposed=False,
|
||||
require_constant_mat2=True,
|
||||
is_woq_int4=False,
|
||||
q_group_size=None,
|
||||
):
|
||||
from . import ir
|
||||
from .codegen.cpp_micro_gemm import create_micro_gemm
|
||||
|
|
@ -1415,6 +1421,7 @@ def use_cpp_gemm_template(
|
|||
mat2,
|
||||
out_dtype=layout.dtype if int8_gemm else None,
|
||||
mat2_transposed=mat2_transposed,
|
||||
use_4x2_dim=is_woq_int4,
|
||||
)
|
||||
|
||||
# TODO(jgong5): support dynamic shapes for n or k
|
||||
|
|
@ -1433,6 +1440,8 @@ def use_cpp_gemm_template(
|
|||
input2_dtype=mat2.get_dtype(),
|
||||
output_dtype=output_dtype,
|
||||
num_threads=parallel_num_threads(),
|
||||
is_woq_int4=is_woq_int4,
|
||||
q_group_size=q_group_size,
|
||||
)
|
||||
|
||||
def is_last_dim_stride1(x):
|
||||
|
|
|
|||
Loading…
Reference in a new issue