mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Update (base update)
[ghstack-poisoned]
This commit is contained in:
parent
1933b6b11d
commit
47247b56ef
1 changed file with 19 additions and 20 deletions
|
|
@@ -2658,6 +2658,25 @@ if torch._C._has_mkldnn:
|
|||
memory_format=memory_format,
|
||||
)
|
||||
|
||||
@register_meta(torch.ops.quantized.int4mm_packed_weight_cpu)
def meta_int4mm_packed_weight_cpu(x, w, q_group_size, q_scale_and_zeros):
    """Meta (shape/dtype inference) kernel for quantized.int4mm_packed_weight_cpu.

    Validates operand ranks and dtypes, then allocates an uninitialized
    output of shape (x.size(0), w.size(0)) with x's dtype.  No numeric
    work happens here — meta kernels only propagate metadata.
    """
    # Both operands must be matrices.
    torch._check(x.dim() == 2, f"x must be a 2D tensor, got {x.dim()}D")
    torch._check(w.dim() == 2, f"w must be a 2D tensor, got {w.dim()}D")
    # Activations are restricted to the supported floating-point dtypes.
    supported_x_dtypes = (torch.float32, torch.float16, torch.bfloat16)
    torch._check(
        x.dtype in supported_x_dtypes,
        f"expected x to be f32/f16/bf16, got {x.dtype}",
    )
    # Packed int4 weights are carried in uint8 storage.
    torch._check(w.dtype == torch.uint8, f"expected w to be uint8, got {w.dtype}")
    torch._check(
        q_group_size.dtype == torch.int64,
        f"q_group_size must be int64, got {q_group_size.dtype}",
    )
    # Scales/zero-points must match the activation dtype.
    torch._check(
        q_scale_and_zeros.dtype == x.dtype,
        f"q_scale_and_zeros must have the same dtype as x, got {q_scale_and_zeros.dtype}",
    )
    # Result is (M, N): M rows from x, N rows from the packed weight.
    return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
|
||||
|
||||
|
||||
# from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
|
||||
def check_dim_size(tensor, dim, dim_size, size):
|
||||
|
|
@@ -3774,26 +3793,6 @@ def meta__weight_int8pack_mm(x, w, q_scales):
|
|||
return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
|
||||
|
||||
|
||||
@register_meta([torch.ops.quantized.int4mm_packed_weight_cpu])
def meta_int4mm_packed_weight_cpu(x, w, q_group_size, q_scale_and_zeros):
    """Meta kernel for quantized.int4mm_packed_weight_cpu.

    Only checks metadata (ranks and dtypes of the arguments) and returns
    an empty (x.size(0), w.size(0)) tensor in x's dtype; the real matmul
    runs in the CPU kernel, not here.
    """
    for t, label in ((x, "x"), (w, "w")):
        # Both x and w must be rank-2.
        torch._check(t.dim() == 2, f"{label} must be a 2D tensor, got {t.dim()}D")
    torch._check(
        x.dtype in [torch.float32, torch.float16, torch.bfloat16],
        f"expected x to be f32/f16/bf16, got {x.dtype}",
    )
    # Packed int4 weight data is stored as uint8.
    torch._check(w.dtype == torch.uint8, f"expected w to be uint8, got {w.dtype}")
    torch._check(
        q_group_size.dtype == torch.int64,
        f"q_group_size must be int64, got {q_group_size.dtype}",
    )
    # Quantization scales/zeros follow the activation dtype.
    torch._check(
        q_scale_and_zeros.dtype == x.dtype,
        f"q_scale_and_zeros must have the same dtype as x, got {q_scale_and_zeros.dtype}",
    )
    return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
|
||||
|
||||
|
||||
@register_meta(aten._cdist_forward.default)
|
||||
def meta_cdist_forward(x1, x2, p, compute_mode):
|
||||
torch._check(
|
||||
|
|
|
|||
Loading…
Reference in a new issue