diff --git a/test/quantization/fx/test_quantize_pt2e.py b/test/quantization/fx/test_quantize_pt2e.py index 73395391f59..150df701f38 100644 --- a/test/quantization/fx/test_quantize_pt2e.py +++ b/test/quantization/fx/test_quantize_pt2e.py @@ -26,6 +26,17 @@ from torch.ao.ns.fx.utils import ( compute_sqnr, ) import copy +from torch._decomp import get_decompositions +from torch.fx.experimental.proxy_tensor import make_fx + +quant_decomp = get_decompositions( + [ + torch.ops.quantized_decomposed.quantize_per_tensor, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + ] +) @skipIfNoQNNPACK class TestQuantizePT2E(QuantizationTestCase): @@ -124,7 +135,81 @@ class TestQuantizePT2E(QuantizationTestCase): ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor), ns.call_function(torch.ops.aten.addmm.default), ] - self.checkGraphModuleNodes(m, expected_node_list=node_list) + self.checkGraphModuleNodes( + m, + expected_node_list=node_list, + expected_node_occurrence=node_occurrence + ) + + @xfailIfPython311 + def test_q_dq_decomposition(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(1, 1, 1) + + def forward(self, x): + x = self.conv(x) + return x + + with override_quantized_engine("qnnpack"): + m = M().eval() + example_inputs = (torch.randn(1, 1, 3, 3),) + + # program capture + m, guards = torchdynamo.export( + m, + *copy.deepcopy(example_inputs), + aten_graph=True, + tracing_mode="real", + ) + + qconfig = get_default_qconfig("qnnpack") + qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Conv2d, qconfig) + backend_config = get_qnnpack_pt2e_backend_config() + m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config) + m(*example_inputs) + m = convert_pt2e(m) + m(*example_inputs) + node_occurrence = { + # two for input and weight of the conv, one for output for the conv + ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3, + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 3, + } + node_list = [ + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor), + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor), + ns.call_function(torch.ops.aten.convolution.default), + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor), + ] + self.checkGraphModuleNodes( + m, + expected_node_list=node_list, + expected_node_occurrence=node_occurrence + ) + m = make_fx(m, decomposition_table=quant_decomp)(*copy.deepcopy(example_inputs)) + node_occurrence = { + # check both q/dq are decomposed + ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 0, + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 0, + } + node_list = [ + # ops in quantize + ns.call_function(torch.ops.aten.mul.Tensor), + ns.call_function(torch.ops.aten.round.default), + ns.call_function(torch.ops.aten.add.Tensor), + ns.call_function(torch.ops.aten.clamp.default), + # ops in dequantize + ns.call_function(torch.ops.aten.sub.Tensor), + ns.call_function(torch.ops.aten.mul.Tensor), + # conv op + ns.call_function(torch.ops.aten.convolution.default), + ] + self.checkGraphModuleNodes( + m, + expected_node_list=node_list, + expected_node_occurrence=node_occurrence + ) class TestQuantizePT2EModels(QuantizationTestCase): @skip_if_no_torchvision diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 3ad1866250e..649a292a5b1 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -2645,6 +2645,10 @@ import torch._refs import torch._refs.nn.functional import torch._refs.special +_QUANTIZED_DECOMPOSED_LIB = torch.library.Library( + "quantized_decomposed", "IMPL", "Meta" +) + def activate_meta(): @@ -2698,6 +2702,8 @@ def activate_meta(): _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn) elif "mkl::" in op_overload.name(): _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn) + elif "quantized_decomposed::" in op_overload.name(): + _QUANTIZED_DECOMPOSED_LIB.impl(op_overload, fn) else: _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn) diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py index c6591236b87..53edc4f974d 100644 --- a/torch/ao/quantization/fx/_decomposed.py +++ b/torch/ao/quantization/fx/_decomposed.py @@ -2,6 +2,31 @@ import torch from torch.library import Library, impl from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax from typing import Tuple +from torch._decomp import register_decomposition + +def _quantize_per_tensor_impl( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + inv_scale = 1.0 / scale + return torch.clamp( + torch.round(input * inv_scale) + zero_point, quant_min, quant_max + ).to(dtype) + +def _dequantize_per_tensor_impl( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return (input.to(torch.float32) - zero_point) * scale + # Note: decomposed means decomposed quantized tensor, using decomposed so that the @@ -59,8 +84,18 @@ def quantize_per_tensor( assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" _quant_min_max_bounds_check(quant_min, quant_max, dtype) - inv_scale = 1.0 / scale - return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype) + return _quantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype) + +@register_decomposition(torch.ops.quantized_decomposed.quantize_per_tensor) +def quantize_per_tensor_decomp_impl( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return _quantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype) quantized_decomposed_lib.define( "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, " @@ -82,15 +117,19 @@ def quantize_per_tensor_tensor( """ assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}" assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}" - return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) + return _quantize_per_tensor_impl( + input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) # type: ignore[arg-type] -@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta") -def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype): - assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}" - assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}" - assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}" - _quant_min_max_bounds_check(quant_min, quant_max, dtype) - return torch.empty_like(input, dtype=dtype) +@register_decomposition(torch.ops.quantized_decomposed.quantize_per_tensor.tensor) +def quantize_per_tensor_tensor_decomp_impl( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return _quantize_per_tensor_impl(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) # type: ignore[arg-type] # Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in # the signature as metadata for the input Tensor, this might be useful for pattern @@ -138,11 +177,22 @@ def dequantize_per_tensor( # TODO: investigate why # (input - zero_point).to(torch.float32) * scale # failed the test - return (input.to(torch.float32) - zero_point) * scale + return _dequantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype) else: raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") +@register_decomposition(torch.ops.quantized_decomposed.dequantize_per_tensor) +def dequantize_per_tensor_decomp_impl( + input: torch.Tensor, + scale: float, + zero_point: int, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return _dequantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype) + quantized_decomposed_lib.define( "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, " "int quant_min, int quant_max, ScalarType dtype) -> Tensor") @@ -163,23 +213,26 @@ def dequantize_per_tensor_tensor( """ assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}" assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}" - return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) - -@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta") -def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype): - assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}" - assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}" - assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}" - if dtype in [torch.uint8, torch.int8, torch.int32]: - return torch.empty_like(input, dtype=torch.float32) - else: - raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}") - + return _dequantize_per_tensor_impl( + input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) # type: ignore[arg-type] quantized_decomposed_lib.define( "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, " "ScalarType dtype) -> (Tensor, Tensor)") + +@register_decomposition(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor) +def dequantize_per_tensor_tensor_decomp_impl( + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + quant_min: int, + quant_max: int, + dtype: torch.dtype, +) -> torch.Tensor: + return _dequantize_per_tensor_impl( + input, scale.item(), zero_point.item(), quant_min, quant_max, dtype) # type: ignore[arg-type] + @impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd") def choose_qparams_tensor( input: torch.Tensor,