diff --git a/test/quantization/fx/test_quantize_pt2e.py b/test/quantization/fx/test_quantize_pt2e.py
index 1fe8714bce4..4a88627b727 100644
--- a/test/quantization/fx/test_quantize_pt2e.py
+++ b/test/quantization/fx/test_quantize_pt2e.py
@@ -26,17 +26,6 @@ from torch.ao.ns.fx.utils import (
     compute_sqnr,
 )
 import copy
-from torch._decomp import get_decompositions
-from torch.fx.experimental.proxy_tensor import make_fx
-
-quant_decomp = get_decompositions(
-    [
-        torch.ops.quantized_decomposed.quantize_per_tensor,
-        torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
-        torch.ops.quantized_decomposed.dequantize_per_tensor,
-        torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
-    ]
-)
 
 @skipIfNoQNNPACK
 class TestQuantizePT2E(QuantizationTestCase):
@@ -135,81 +124,7 @@ class TestQuantizePT2E(QuantizationTestCase):
             ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
             ns.call_function(torch.ops.aten.addmm.default),
         ]
-        self.checkGraphModuleNodes(
-            m,
-            expected_node_list=node_list,
-            expected_node_occurrence=node_occurrence
-        )
-
-    @xfailIfPython311
-    def test_q_dq_decomposition(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.conv = nn.Conv2d(1, 1, 1)
-
-            def forward(self, x):
-                x = self.conv(x)
-                return x
-
-        with override_quantized_engine("qnnpack"):
-            m = M().eval()
-            example_inputs = (torch.randn(1, 1, 3, 3),)
-
-            # program capture
-            m, guards = torchdynamo.export(
-                m,
-                *copy.deepcopy(example_inputs),
-                aten_graph=True,
-                tracing_mode="real",
-            )
-
-            qconfig = get_default_qconfig("qnnpack")
-            qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Conv2d, qconfig)
-            backend_config = get_qnnpack_pt2e_backend_config()
-            m = prepare_pt2e(m, qconfig_mapping, example_inputs, backend_config)
-            m(*example_inputs)
-            m = convert_pt2e(m)
-            m(*example_inputs)
-            node_occurrence = {
-                # two for input and weight of the conv, one for output for the conv
-                ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 3,
-                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 3,
-            }
-            node_list = [
-                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
-                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
-                ns.call_function(torch.ops.aten.convolution.default),
-                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor),
-            ]
-            self.checkGraphModuleNodes(
-                m,
-                expected_node_list=node_list,
-                expected_node_occurrence=node_occurrence
-            )
-            m = make_fx(m, decomposition_table=quant_decomp)(*copy.deepcopy(example_inputs))
-            node_occurrence = {
-                # check both q/dq are decomposed
-                ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 0,
-                ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 0,
-            }
-            node_list = [
-                # ops in quantize
-                ns.call_function(torch.ops.aten.mul.Tensor),
-                ns.call_function(torch.ops.aten.round.default),
-                ns.call_function(torch.ops.aten.add.Tensor),
-                ns.call_function(torch.ops.aten.clamp.default),
-                # ops in dequantize
-                ns.call_function(torch.ops.aten.sub.Tensor),
-                ns.call_function(torch.ops.aten.mul.Tensor),
-                # conv op
-                ns.call_function(torch.ops.aten.convolution.default),
-            ]
-            self.checkGraphModuleNodes(
-                m,
-                expected_node_list=node_list,
-                expected_node_occurrence=node_occurrence
-            )
+        self.checkGraphModuleNodes(m, expected_node_list=node_list)
 
 class TestQuantizePT2EModels(QuantizationTestCase):
     @skip_if_no_torchvision
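Note on the retained assertion: checkGraphModuleNodes(m, expected_node_list=node_list) verifies that the listed call_function targets appear in that relative order in the traced graph. A minimal sketch of the idea; the helper below is illustrative, not PyTorch's actual implementation:

    import torch
    import torch.fx

    def check_node_list(gm: torch.fx.GraphModule, expected_targets) -> bool:
        # True if expected_targets appear, in order, among the call_function nodes
        it = iter(expected_targets)
        want = next(it, None)
        for node in gm.graph.nodes:
            if want is None:
                break
            if node.op == "call_function" and node.target is want:
                want = next(it, None)
        return want is None

    def f(x):
        return torch.add(torch.mul(x, 2.0), 1.0)

    gm = torch.fx.symbolic_trace(f)
    assert check_node_list(gm, [torch.mul, torch.add])      # order matches
    assert not check_node_list(gm, [torch.add, torch.mul])  # order violated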
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 649a292a5b1..3ad1866250e 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -2645,10 +2645,6 @@ import torch._refs
 import torch._refs.nn.functional
 import torch._refs.special
 
-_QUANTIZED_DECOMPOSED_LIB = torch.library.Library(
-    "quantized_decomposed", "IMPL", "Meta"
-)
-
 
 def activate_meta():
@@ -2702,8 +2698,6 @@ def activate_meta():
             _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
         elif "mkl::" in op_overload.name():
             _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
-        elif "quantized_decomposed::" in op_overload.name():
-            _QUANTIZED_DECOMPOSED_LIB.impl(op_overload, fn)
         else:
             _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)
 
diff --git a/torch/ao/quantization/fx/_decomposed.py b/torch/ao/quantization/fx/_decomposed.py
index 8518fa9f030..6d7d834f2ea 100644
--- a/torch/ao/quantization/fx/_decomposed.py
+++ b/torch/ao/quantization/fx/_decomposed.py
@@ -2,31 +2,6 @@ import torch
 from torch.library import Library, impl
 from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
 from typing import Tuple
-from torch._decomp import register_decomposition
-
-def _quantize_per_tensor_impl(
-    input: torch.Tensor,
-    scale: float,
-    zero_point: int,
-    quant_min: int,
-    quant_max: int,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    inv_scale = 1.0 / scale
-    return torch.clamp(
-        torch.round(input * inv_scale) + zero_point, quant_min, quant_max
-    ).to(dtype)
-
-def _dequantize_per_tensor_impl(
-    input: torch.Tensor,
-    scale: float,
-    zero_point: int,
-    quant_min: int,
-    quant_max: int,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    return (input.to(torch.float32) - zero_point) * scale
-
 
 
 # Note: decomposed means decomposed quantized tensor, using decomposed so that the
@@ -84,18 +59,8 @@ def quantize_per_tensor(
     assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
 
-    return _quantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
-
-@register_decomposition(torch.ops.quantized_decomposed.quantize_per_tensor)
-def quantize_per_tensor_decomp_impl(
-    input: torch.Tensor,
-    scale: float,
-    zero_point: int,
-    quant_min: int,
-    quant_max: int,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    return _quantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
+    inv_scale = 1.0 / scale
+    return torch.clamp(torch.round(input * inv_scale) + zero_point, quant_min, quant_max).to(dtype)
 
 quantized_decomposed_lib.define(
     "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
@@ -117,19 +82,15 @@ def quantize_per_tensor_tensor(
     """
     assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
     assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
-    return _quantize_per_tensor_impl(
-        input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]
+    return quantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
 
-@register_decomposition(torch.ops.quantized_decomposed.quantize_per_tensor.tensor)
-def quantize_per_tensor_tensor_decomp_impl(
-    input: torch.Tensor,
-    scale: torch.Tensor,
-    zero_point: torch.Tensor,
-    quant_min: int,
-    quant_max: int,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    return _quantize_per_tensor_impl(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]
+@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
+def quantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
+    assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+    assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
+    assert input.dtype == torch.float32, f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
+    return torch.empty_like(input, dtype=dtype)
 
 # Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
 # the signature as metadata for the input Tensor, this might be useful for pattern
@@ -177,22 +138,11 @@ def dequantize_per_tensor(
         # TODO: investigate why
         # (input - zero_point).to(torch.float32) * scale
         # failed the test
-        return _dequantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
+        return (input.to(torch.float32) - zero_point) * scale
     else:
         raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
 
 
-@register_decomposition(torch.ops.quantized_decomposed.dequantize_per_tensor)
-def dequantize_per_tensor_decomp_impl(
-    input: torch.Tensor,
-    scale: float,
-    zero_point: int,
-    quant_min: int,
-    quant_max: int,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    return _dequantize_per_tensor_impl(input, scale, zero_point, quant_min, quant_max, dtype)
-
 quantized_decomposed_lib.define(
     "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
     "int quant_min, int quant_max, ScalarType dtype) -> Tensor")
@@ -213,26 +163,23 @@ def dequantize_per_tensor_tensor(
     """
     assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
     assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
-    return _dequantize_per_tensor_impl(
-        input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]
+    return dequantize_per_tensor(input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)
+
+@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
+def dequantize_per_tensor_tensor_meta(input, scale, zero_point, quant_min, quant_max, dtype):
+    assert zero_point.numel() == 1, f"Exepecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+    assert scale.numel() == 1, f"Exepecting scale tensor to be one element, but received : {scale.numel()}"
+    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
+    if dtype in [torch.uint8, torch.int8, torch.int32]:
+        return torch.empty_like(input, dtype=torch.float32)
+    else:
+        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
+
 
 
 quantized_decomposed_lib.define(
     "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
     "ScalarType dtype) -> (Tensor, Tensor)")
-
-@register_decomposition(torch.ops.quantized_decomposed.dequantize_per_tensor.tensor)
-def dequantize_per_tensor_tensor_decomp_impl(
-    input: torch.Tensor,
-    scale: torch.Tensor,
-    zero_point: torch.Tensor,
-    quant_min: int,
-    quant_max: int,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    return _dequantize_per_tensor_impl(
-        input, scale.item(), zero_point.item(), quant_min, quant_max, dtype)  # type: ignore[arg-type]
-
 @impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd")
 def choose_qparams_tensor(
     input: torch.Tensor,