From b0df0cd7cc00faae718a456eb831bcb5bdaa6fcc Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 18 Apr 2023 16:45:12 -0700
Subject: [PATCH] [reland][quant][fix] Compare resnet with quantizer api with
 the prepare_fx and decomposed convert flow (#99355)

Summary:
Use a decomposed convert so that the two flows produce an exact match; this verifies that the nodes in resnet are annotated correctly. This is a reland of https://github.com/pytorch/pytorch/pull/98905.

Test Plan:
python test/test_quantization.py TestQuantizePT2EModels.test_resnet18_with_quantizer_api

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D45071168](https://our.internmc.facebook.com/intern/diff/D45071168)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99355
Approved by: https://github.com/kimishpatel
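For context, the FX side of the comparison looks roughly like the sketch below; it is a minimal illustration, not part of the patch. A toy conv module stands in for resnet18, the quantizer-api side (prepare_pt2e/convert_pt2e) is elided, and the names follow the imports already used by the tests in this patch:

```python
import torch
from torch.ao.quantization import (
    QConfigMapping,
    default_per_channel_symmetric_qnnpack_qconfig,
)
from torch.ao.quantization.backend_config import get_qnnpack_backend_config
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    prepare_fx,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x)

example_inputs = (torch.randn(1, 3, 10, 10),)
m = M().eval()
qconfig_mapping = QConfigMapping().set_global(
    default_per_channel_symmetric_qnnpack_qconfig
)
# insert observers, then calibrate on the example input
m = prepare_fx(m, qconfig_mapping, example_inputs,
               backend_config=get_qnnpack_backend_config())
m(*example_inputs)
# lower to a reference model that calls the decomposed
# quantized_decomposed.* quant/dequant ops
m = _convert_to_reference_decomposed_fx(
    m, backend_config=get_qnnpack_backend_config()
)
res_fx = m(*example_inputs)
```

Because the decomposed reference model and the pt2e convert flow emit the same quantized_decomposed quant/dequant ops, the two outputs can be compared with assertEqual instead of a tolerance.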
---
 test/quantization/fx/test_quantize_fx.py     | 53 -------------
 test/quantization/pt2e/test_quantize_pt2e.py | 10 +--
 .../pt2e/test_quantize_pt2e_fx.py            | 75 +++++++++++++++++--
 3 files changed, 69 insertions(+), 69 deletions(-)

diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py
index af4b1ea9340..370f32d6464 100644
--- a/test/quantization/fx/test_quantize_fx.py
+++ b/test/quantization/fx/test_quantize_fx.py
@@ -48,7 +48,6 @@ from torch.ao.quantization import (
     default_reuse_input_qconfig,
     default_symmetric_qnnpack_qconfig,
     default_symmetric_qnnpack_qat_qconfig,
-    default_per_channel_symmetric_qnnpack_qconfig,
     per_channel_dynamic_qconfig,
     float16_dynamic_qconfig,
     float16_static_qconfig,
@@ -192,7 +191,6 @@ from torch.testing._internal.common_quantized import (
 from torch.testing._internal.common_utils import (
     TemporaryFileName,
     IS_ARM64,
-    IS_WINDOWS,
 )
 
 from torch.testing._internal.common_quantization import NodeSpec as ns
@@ -6163,57 +6161,6 @@ class TestQuantizeFx(QuantizationTestCase):
         res = m(*example_inputs)
         self.assertEqual(res, res_ref)
 
-    @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
-    def test__convert_to_reference_decomposed_fx_per_channel_quant_module(self):
-        """ Test the result for per channel weight quant for reference modules
-        """
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.conv = torch.nn.Conv2d(3, 3, 3)
-
-            def forward(self, x):
-                return self.conv(x)
-
-        m = M().eval()
-        qconfig_mapping = QConfigMapping().set_global(default_per_channel_symmetric_qnnpack_qconfig)
-        example_inputs = (torch.randn(1, 3, 10, 10),)
-        m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=get_qnnpack_backend_config())
-        m(*example_inputs)
-        m_ref = copy.deepcopy(m)
-        m_ref = convert_to_reference_fx(m_ref, backend_config=get_qnnpack_backend_config())
-        m = _convert_to_reference_decomposed_fx(m, backend_config=get_qnnpack_backend_config())
-        expected_occurrence = {
-            # for input and output activations
-            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
-            # weight is per channel quantized
-            ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
-            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
-        }
-        import torch._dynamo as torchdynamo
-        m, guards = torchdynamo.export(
-            m,
-            *copy.deepcopy(example_inputs),
-            aten_graph=True,
-            tracing_mode="real",
-        )
-        self.checkGraphModuleNodes(
-            m,
-            expected_node_occurrence=expected_occurrence)
-        # make sure it runs
-        res_ref = m_ref(*example_inputs)
-        res = m(*example_inputs)
-        self.assertEqual(res, res_ref)
-        # check the qmin/qmax for per channel quant
-        for n in m.graph.nodes:
-            if n.op == "call_function" and \
-                    n.target == torch.ops.quantized_decomposed.quantize_per_channel.default:
-                _QUANT_MIN_INDEX = 4
-                _QUANT_MAX_INDEX = 5
-                self.assertEqual(n.args[_QUANT_MIN_INDEX], -127)
-                self.assertEqual(n.args[_QUANT_MAX_INDEX], 127)
-
     def test_change_backend_config_for_fixed_qparam_ops(self):
         """ Making sure we can skip validation of qconfigs for fixedqparam ops
         based on BackendConfig
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 5470ec46fff..b9d4a4b064a 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -265,11 +265,5 @@ class TestQuantizePT2EModels(QuantizationTestCase):
             compute_sqnr(after_prepare_result, after_prepare_result_fx),
             torch.tensor(float("inf")),
         )
-        # there are slight differences after convert due to different implementations
-        # of quant/dequant
-        self.assertTrue(
-            torch.max(after_quant_result - after_quant_result_fx) < 1e-1
-        )
-        self.assertTrue(
-            compute_sqnr(after_quant_result, after_quant_result_fx) > 35
-        )
+        self.assertEqual(after_quant_result, after_quant_result_fx)
+        self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) == torch.tensor(float("inf")))
diff --git a/test/quantization/pt2e/test_quantize_pt2e_fx.py b/test/quantization/pt2e/test_quantize_pt2e_fx.py
index 8b93190173a..1af51002e2b 100644
--- a/test/quantization/pt2e/test_quantize_pt2e_fx.py
+++ b/test/quantization/pt2e/test_quantize_pt2e_fx.py
@@ -7,7 +7,12 @@ import torch._dynamo as torchdynamo
 import torch.nn as nn
 from torch._inductor.compile_fx import compile_fx
 from torch.ao.ns.fx.utils import compute_sqnr
-from torch.ao.quantization import get_default_qconfig, observer, QConfigMapping
+from torch.ao.quantization import (
+    get_default_qconfig,
+    observer,
+    QConfigMapping,
+    default_per_channel_symmetric_qnnpack_qconfig,
+)
 from torch.ao.quantization._quantize_pt2e import (
     convert_pt2e,
     prepare_pt2e,
@@ -23,8 +28,13 @@ from torch.ao.quantization.backend_config.x86 import get_x86_backend_config
 from torch.ao.quantization.quantize_fx import (
     convert_fx,
     convert_to_reference_fx,
+    _convert_to_reference_decomposed_fx,
     prepare_fx,
 )
+
+from torch.testing._internal.common_utils import (
+    IS_WINDOWS,
+)
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
     QuantizationTestCase,
@@ -33,8 +43,10 @@ from torch.testing._internal.common_quantization import (
     skipIfNoX86,
 )
 from torch.testing._internal.common_quantized import override_quantized_engine
+import unittest
 
 
+# TODO: remove after quantizer API is more mature
 @skipIfNoQNNPACK
 class TestQuantizePT2EFX(QuantizationTestCase):
     def test_qconfig_none(self):
@@ -255,6 +267,57 @@ class TestQuantizePT2EFX(QuantizationTestCase):
         }
         self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
 
+    # TODO(jerryzh168): move all _convert_to_reference_decomposed_fx tests here
+    @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
+    def test__convert_to_reference_decomposed_fx_per_channel_quant_module(self):
+        """ Test the result for per channel weight quant for reference modules
+        """
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 3, 3)
+
+            def forward(self, x):
+                return self.conv(x)
+
+        m = M().eval()
+        qconfig_mapping = QConfigMapping().set_global(default_per_channel_symmetric_qnnpack_qconfig)
+        example_inputs = (torch.randn(1, 3, 10, 10),)
+        m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=get_qnnpack_backend_config())
+        m(*example_inputs)
+        m_ref = copy.deepcopy(m)
+        m_ref = convert_to_reference_fx(m_ref, backend_config=get_qnnpack_backend_config())
+        m = _convert_to_reference_decomposed_fx(m, backend_config=get_qnnpack_backend_config())
+        expected_occurrence = {
+            # for input and output activations
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
+            # weight is per channel quantized
+            ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
+            ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
+        }
+        import torch._dynamo as torchdynamo
+        m, guards = torchdynamo.export(
+            m,
+            *copy.deepcopy(example_inputs),
+            aten_graph=True,
+            tracing_mode="real",
+        )
+        self.checkGraphModuleNodes(
+            m,
+            expected_node_occurrence=expected_occurrence)
+        # make sure it runs
+        res_ref = m_ref(*example_inputs)
+        res = m(*example_inputs)
+        self.assertEqual(res, res_ref)
+        # check the qmin/qmax for per channel quant
+        for n in m.graph.nodes:
+            if n.op == "call_function" and \
+                    n.target == torch.ops.quantized_decomposed.quantize_per_channel.default:
+                _QUANT_MIN_INDEX = 4
+                _QUANT_MAX_INDEX = 5
+                self.assertEqual(n.args[_QUANT_MIN_INDEX], -127)
+                self.assertEqual(n.args[_QUANT_MAX_INDEX], 127)
 
 @skipIfNoQNNPACK
 class TestQuantizePT2EFXX86Inductor(QuantizationTestCase):
@@ -425,7 +488,7 @@ class TestQuantizePT2EFXModels(QuantizationTestCase):
             m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
         )
         after_prepare_result_fx = m_fx(*example_inputs)
-        m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)
+        m_fx = _convert_to_reference_decomposed_fx(m_fx, backend_config=backend_config)
         after_quant_result_fx = m_fx(*example_inputs)
 
 
@@ -437,9 +500,5 @@ class TestQuantizePT2EFXModels(QuantizationTestCase):
         )
         # there are slight differences after convert due to different implementations
         # of quant/dequant
-        self.assertTrue(
-            torch.max(after_quant_result - after_quant_result_fx) < 1e-1
-        )
-        self.assertTrue(
-            compute_sqnr(after_quant_result, after_quant_result_fx) > 35
-        )
+        self.assertTrue(torch.max(after_quant_result - after_quant_result_fx) < 1e-1)
+        self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) > 35)
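
A note on the tightened assertion in test_quantize_pt2e.py: compute_sqnr returns +inf exactly when the two tensors match bit-for-bit, because the noise power (the norm of the difference) is zero. A quick self-contained illustration, again not part of the patch:

```python
import torch
from torch.ao.ns.fx.utils import compute_sqnr

x = torch.randn(8)
# identical tensors -> zero noise power -> SQNR of +inf,
# which is what the exact-match test now asserts
print(compute_sqnr(x, x.clone()))                  # tensor(inf)
# a small perturbation yields a large but finite SQNR
print(compute_sqnr(x, x + 1e-3 * torch.randn(8)))
```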