[reland][quant][fix] Compare resnet with quantizer api with the prepare_fx and decomposed convert flow (#99355)

Summary:
Using a decomposed convert to make sure we get exact match, this means the nodes in resnet are
annotated correctly, reland for https://github.com/pytorch/pytorch/pull/98905

Test Plan:
python test/test_quantization.py TestQuantizePT2EModels.test_resnet18_with_quantizer_api

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D45071168](https://our.internmc.facebook.com/intern/diff/D45071168)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99355
Approved by: https://github.com/kimishpatel
This commit is contained in:
Jerry Zhang 2023-04-18 16:45:12 -07:00 committed by PyTorch MergeBot
parent 391a3add54
commit b0df0cd7cc
3 changed files with 69 additions and 69 deletions

View file

@ -48,7 +48,6 @@ from torch.ao.quantization import (
default_reuse_input_qconfig,
default_symmetric_qnnpack_qconfig,
default_symmetric_qnnpack_qat_qconfig,
default_per_channel_symmetric_qnnpack_qconfig,
per_channel_dynamic_qconfig,
float16_dynamic_qconfig,
float16_static_qconfig,
@ -192,7 +191,6 @@ from torch.testing._internal.common_quantized import (
from torch.testing._internal.common_utils import (
TemporaryFileName,
IS_ARM64,
IS_WINDOWS,
)
from torch.testing._internal.common_quantization import NodeSpec as ns
@ -6163,57 +6161,6 @@ class TestQuantizeFx(QuantizationTestCase):
res = m(*example_inputs)
self.assertEqual(res, res_ref)
@unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
def test__convert_to_reference_decomposed_fx_per_channel_quant_module(self):
""" Test the result for per channel weight quant for reference modules
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
def forward(self, x):
return self.conv(x)
m = M().eval()
qconfig_mapping = QConfigMapping().set_global(default_per_channel_symmetric_qnnpack_qconfig)
example_inputs = (torch.randn(1, 3, 10, 10),)
m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=get_qnnpack_backend_config())
m(*example_inputs)
m_ref = copy.deepcopy(m)
m_ref = convert_to_reference_fx(m_ref, backend_config=get_qnnpack_backend_config())
m = _convert_to_reference_decomposed_fx(m, backend_config=get_qnnpack_backend_config())
expected_occurrence = {
# for input and output activations
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
# weight is per channel quantized
ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
}
import torch._dynamo as torchdynamo
m, guards = torchdynamo.export(
m,
*copy.deepcopy(example_inputs),
aten_graph=True,
tracing_mode="real",
)
self.checkGraphModuleNodes(
m,
expected_node_occurrence=expected_occurrence)
# make sure it runs
res_ref = m_ref(*example_inputs)
res = m(*example_inputs)
self.assertEqual(res, res_ref)
# check the qmin/qmax for per channel quant
for n in m.graph.nodes:
if n.op == "call_function" and \
n.target == torch.ops.quantized_decomposed.quantize_per_channel.default:
_QUANT_MIN_INDEX = 4
_QUANT_MAX_INDEX = 5
self.assertEqual(n.args[_QUANT_MIN_INDEX], -127)
self.assertEqual(n.args[_QUANT_MAX_INDEX], 127)
def test_change_backend_config_for_fixed_qparam_ops(self):
""" Making sure we can skip validation of qconfigs for fixedqparam ops based
on BackendConfig

View file

@ -265,11 +265,5 @@ class TestQuantizePT2EModels(QuantizationTestCase):
compute_sqnr(after_prepare_result, after_prepare_result_fx),
torch.tensor(float("inf")),
)
# there are slight differences after convert due to different implementations
# of quant/dequant
self.assertTrue(
torch.max(after_quant_result - after_quant_result_fx) < 1e-1
)
self.assertTrue(
compute_sqnr(after_quant_result, after_quant_result_fx) > 35
)
self.assertEqual(after_quant_result, after_quant_result_fx)
self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) == torch.tensor(float("inf")))

View file

@ -7,7 +7,12 @@ import torch._dynamo as torchdynamo
import torch.nn as nn
from torch._inductor.compile_fx import compile_fx
from torch.ao.ns.fx.utils import compute_sqnr
from torch.ao.quantization import get_default_qconfig, observer, QConfigMapping
from torch.ao.quantization import (
get_default_qconfig,
observer,
QConfigMapping,
default_per_channel_symmetric_qnnpack_qconfig,
)
from torch.ao.quantization._quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
@ -23,8 +28,13 @@ from torch.ao.quantization.backend_config.x86 import get_x86_backend_config
from torch.ao.quantization.quantize_fx import (
convert_fx,
convert_to_reference_fx,
_convert_to_reference_decomposed_fx,
prepare_fx,
)
from torch.testing._internal.common_utils import (
IS_WINDOWS,
)
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
@ -33,8 +43,10 @@ from torch.testing._internal.common_quantization import (
skipIfNoX86,
)
from torch.testing._internal.common_quantized import override_quantized_engine
import unittest
# TODO: remove after quantizer API is more mature
@skipIfNoQNNPACK
class TestQuantizePT2EFX(QuantizationTestCase):
def test_qconfig_none(self):
@ -255,6 +267,57 @@ class TestQuantizePT2EFX(QuantizationTestCase):
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
# TODO(jerryzh168): move all _convert_to_reference_decomposed_fx tests here
@unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
def test__convert_to_reference_decomposed_fx_per_channel_quant_module(self):
""" Test the result for per channel weight quant for reference modules
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
def forward(self, x):
return self.conv(x)
m = M().eval()
qconfig_mapping = QConfigMapping().set_global(default_per_channel_symmetric_qnnpack_qconfig)
example_inputs = (torch.randn(1, 3, 10, 10),)
m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=get_qnnpack_backend_config())
m(*example_inputs)
m_ref = copy.deepcopy(m)
m_ref = convert_to_reference_fx(m_ref, backend_config=get_qnnpack_backend_config())
m = _convert_to_reference_decomposed_fx(m, backend_config=get_qnnpack_backend_config())
expected_occurrence = {
# for input and output activations
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
# weight is per channel quantized
ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
}
import torch._dynamo as torchdynamo
m, guards = torchdynamo.export(
m,
*copy.deepcopy(example_inputs),
aten_graph=True,
tracing_mode="real",
)
self.checkGraphModuleNodes(
m,
expected_node_occurrence=expected_occurrence)
# make sure it runs
res_ref = m_ref(*example_inputs)
res = m(*example_inputs)
self.assertEqual(res, res_ref)
# check the qmin/qmax for per channel quant
for n in m.graph.nodes:
if n.op == "call_function" and \
n.target == torch.ops.quantized_decomposed.quantize_per_channel.default:
_QUANT_MIN_INDEX = 4
_QUANT_MAX_INDEX = 5
self.assertEqual(n.args[_QUANT_MIN_INDEX], -127)
self.assertEqual(n.args[_QUANT_MAX_INDEX], 127)
@skipIfNoQNNPACK
class TestQuantizePT2EFXX86Inductor(QuantizationTestCase):
@ -425,7 +488,7 @@ class TestQuantizePT2EFXModels(QuantizationTestCase):
m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
)
after_prepare_result_fx = m_fx(*example_inputs)
m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)
m_fx = _convert_to_reference_decomposed_fx(m_fx, backend_config=backend_config)
after_quant_result_fx = m_fx(*example_inputs)
@ -437,9 +500,5 @@ class TestQuantizePT2EFXModels(QuantizationTestCase):
)
# there are slight differences after convert due to different implementations
# of quant/dequant
self.assertTrue(
torch.max(after_quant_result - after_quant_result_fx) < 1e-1
)
self.assertTrue(
compute_sqnr(after_quant_result, after_quant_result_fx) > 35
)
self.assertTrue(torch.max(after_quant_result - after_quant_result_fx) < 1e-1)
self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) > 35)