From 0ad0e4bb52c2629425fafe09878220281fe5c5b3 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Sat, 8 Feb 2025 19:13:42 -0800 Subject: [PATCH] Update (base update) [ghstack-poisoned] --- torch/_meta_registrations.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index c1ef520c967..b3ccc4350eb 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -3774,6 +3774,26 @@ def meta__weight_int8pack_mm(x, w, q_scales): return x.new_empty(x.size(0), w.size(0), dtype=x.dtype) +@register_meta([torch.ops.quantized.int4mm_packed_weight_cpu]) +def meta_int4mm_packed_weight_cpu(x, w, q_group_size, q_scale_and_zeros): + torch._check(x.dim() == 2, f"x must be a 2D tensor, got {x.dim()}D") + torch._check(w.dim() == 2, f"w must be a 2D tensor, got {w.dim()}D") + torch._check( + x.dtype in [torch.float32, torch.float16, torch.bfloat16], + f"expected x to be f32/f16/bf16, got {x.dtype}", + ) + torch._check(w.dtype == torch.uint8, f"expected w to be uint8, got {w.dtype}") + torch._check( + q_group_size.dtype == torch.int64, + f"q_group_size must be int64, got {q_group_size.dtype}", + ) + torch._check( + q_scale_and_zeros.dtype == x.dtype, + f"q_scale_and_zeros must have the same dtype as x, got {q_scale_and_zeros.dtype}", + ) + return x.new_empty(x.size(0), w.size(0), dtype=x.dtype) + + @register_meta(aten._cdist_forward.default) def meta_cdist_forward(x1, x2, p, compute_mode): torch._check(