mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[reland][quant][fix] Compare resnet with quantizer api with the prepare_fx and decomposed convert flow (#99355)
Summary: Use the decomposed convert flow (`_convert_to_reference_decomposed_fx`) on the prepare_fx side so that the results match the quantizer API flow exactly; an exact match means the nodes in resnet are annotated correctly. This is a reland of https://github.com/pytorch/pytorch/pull/98905. Test Plan: python test/test_quantization.py TestQuantizePT2EModels.test_resnet18_with_quantizer_api Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D45071168](https://our.internmc.facebook.com/intern/diff/D45071168) Pull Request resolved: https://github.com/pytorch/pytorch/pull/99355 Approved by: https://github.com/kimishpatel
This commit is contained in:
parent
391a3add54
commit
b0df0cd7cc
3 changed files with 69 additions and 69 deletions
|
|
@ -48,7 +48,6 @@ from torch.ao.quantization import (
|
|||
default_reuse_input_qconfig,
|
||||
default_symmetric_qnnpack_qconfig,
|
||||
default_symmetric_qnnpack_qat_qconfig,
|
||||
default_per_channel_symmetric_qnnpack_qconfig,
|
||||
per_channel_dynamic_qconfig,
|
||||
float16_dynamic_qconfig,
|
||||
float16_static_qconfig,
|
||||
|
|
@ -192,7 +191,6 @@ from torch.testing._internal.common_quantized import (
|
|||
from torch.testing._internal.common_utils import (
|
||||
TemporaryFileName,
|
||||
IS_ARM64,
|
||||
IS_WINDOWS,
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_quantization import NodeSpec as ns
|
||||
|
|
@ -6163,57 +6161,6 @@ class TestQuantizeFx(QuantizationTestCase):
|
|||
res = m(*example_inputs)
|
||||
self.assertEqual(res, res_ref)
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
|
||||
def test__convert_to_reference_decomposed_fx_per_channel_quant_module(self):
|
||||
""" Test the result for per channel weight quant for reference modules
|
||||
"""
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 3, 3)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
m = M().eval()
|
||||
qconfig_mapping = QConfigMapping().set_global(default_per_channel_symmetric_qnnpack_qconfig)
|
||||
example_inputs = (torch.randn(1, 3, 10, 10),)
|
||||
m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=get_qnnpack_backend_config())
|
||||
m(*example_inputs)
|
||||
m_ref = copy.deepcopy(m)
|
||||
m_ref = convert_to_reference_fx(m_ref, backend_config=get_qnnpack_backend_config())
|
||||
m = _convert_to_reference_decomposed_fx(m, backend_config=get_qnnpack_backend_config())
|
||||
expected_occurrence = {
|
||||
# for input and output activations
|
||||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
|
||||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
|
||||
# weight is per channel quantized
|
||||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
|
||||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
|
||||
}
|
||||
import torch._dynamo as torchdynamo
|
||||
m, guards = torchdynamo.export(
|
||||
m,
|
||||
*copy.deepcopy(example_inputs),
|
||||
aten_graph=True,
|
||||
tracing_mode="real",
|
||||
)
|
||||
self.checkGraphModuleNodes(
|
||||
m,
|
||||
expected_node_occurrence=expected_occurrence)
|
||||
# make sure it runs
|
||||
res_ref = m_ref(*example_inputs)
|
||||
res = m(*example_inputs)
|
||||
self.assertEqual(res, res_ref)
|
||||
# check the qmin/qmax for per channel quant
|
||||
for n in m.graph.nodes:
|
||||
if n.op == "call_function" and \
|
||||
n.target == torch.ops.quantized_decomposed.quantize_per_channel.default:
|
||||
_QUANT_MIN_INDEX = 4
|
||||
_QUANT_MAX_INDEX = 5
|
||||
self.assertEqual(n.args[_QUANT_MIN_INDEX], -127)
|
||||
self.assertEqual(n.args[_QUANT_MAX_INDEX], 127)
|
||||
|
||||
def test_change_backend_config_for_fixed_qparam_ops(self):
|
||||
""" Making sure we can skip validation of qconfigs for fixedqparam ops based
|
||||
on BackendConfig
|
||||
|
|
|
|||
|
|
@ -265,11 +265,5 @@ class TestQuantizePT2EModels(QuantizationTestCase):
|
|||
compute_sqnr(after_prepare_result, after_prepare_result_fx),
|
||||
torch.tensor(float("inf")),
|
||||
)
|
||||
# there are slight differences after convert due to different implementations
|
||||
# of quant/dequant
|
||||
self.assertTrue(
|
||||
torch.max(after_quant_result - after_quant_result_fx) < 1e-1
|
||||
)
|
||||
self.assertTrue(
|
||||
compute_sqnr(after_quant_result, after_quant_result_fx) > 35
|
||||
)
|
||||
self.assertEqual(after_quant_result, after_quant_result_fx)
|
||||
self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) == torch.tensor(float("inf")))
|
||||
|
|
|
|||
|
|
@ -7,7 +7,12 @@ import torch._dynamo as torchdynamo
|
|||
import torch.nn as nn
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
from torch.ao.ns.fx.utils import compute_sqnr
|
||||
from torch.ao.quantization import get_default_qconfig, observer, QConfigMapping
|
||||
from torch.ao.quantization import (
|
||||
get_default_qconfig,
|
||||
observer,
|
||||
QConfigMapping,
|
||||
default_per_channel_symmetric_qnnpack_qconfig,
|
||||
)
|
||||
from torch.ao.quantization._quantize_pt2e import (
|
||||
convert_pt2e,
|
||||
prepare_pt2e,
|
||||
|
|
@ -23,8 +28,13 @@ from torch.ao.quantization.backend_config.x86 import get_x86_backend_config
|
|||
from torch.ao.quantization.quantize_fx import (
|
||||
convert_fx,
|
||||
convert_to_reference_fx,
|
||||
_convert_to_reference_decomposed_fx,
|
||||
prepare_fx,
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_WINDOWS,
|
||||
)
|
||||
from torch.testing._internal.common_quantization import (
|
||||
NodeSpec as ns,
|
||||
QuantizationTestCase,
|
||||
|
|
@ -33,8 +43,10 @@ from torch.testing._internal.common_quantization import (
|
|||
skipIfNoX86,
|
||||
)
|
||||
from torch.testing._internal.common_quantized import override_quantized_engine
|
||||
import unittest
|
||||
|
||||
|
||||
# TODO: remove after quantizer API is more mature
|
||||
@skipIfNoQNNPACK
|
||||
class TestQuantizePT2EFX(QuantizationTestCase):
|
||||
def test_qconfig_none(self):
|
||||
|
|
@ -255,6 +267,57 @@ class TestQuantizePT2EFX(QuantizationTestCase):
|
|||
}
|
||||
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
|
||||
|
||||
# TODO(jerryzh168): move all _convert_to_reference_decomposed_fx tests here
|
||||
@unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on Windows")
|
||||
def test__convert_to_reference_decomposed_fx_per_channel_quant_module(self):
|
||||
""" Test the result for per channel weight quant for reference modules
|
||||
"""
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(3, 3, 3)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
m = M().eval()
|
||||
qconfig_mapping = QConfigMapping().set_global(default_per_channel_symmetric_qnnpack_qconfig)
|
||||
example_inputs = (torch.randn(1, 3, 10, 10),)
|
||||
m = prepare_fx(m, qconfig_mapping, example_inputs, backend_config=get_qnnpack_backend_config())
|
||||
m(*example_inputs)
|
||||
m_ref = copy.deepcopy(m)
|
||||
m_ref = convert_to_reference_fx(m_ref, backend_config=get_qnnpack_backend_config())
|
||||
m = _convert_to_reference_decomposed_fx(m, backend_config=get_qnnpack_backend_config())
|
||||
expected_occurrence = {
|
||||
# for input and output activations
|
||||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2,
|
||||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2,
|
||||
# weight is per channel quantized
|
||||
ns.call_function(torch.ops.quantized_decomposed.quantize_per_channel.default): 1,
|
||||
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1,
|
||||
}
|
||||
import torch._dynamo as torchdynamo
|
||||
m, guards = torchdynamo.export(
|
||||
m,
|
||||
*copy.deepcopy(example_inputs),
|
||||
aten_graph=True,
|
||||
tracing_mode="real",
|
||||
)
|
||||
self.checkGraphModuleNodes(
|
||||
m,
|
||||
expected_node_occurrence=expected_occurrence)
|
||||
# make sure it runs
|
||||
res_ref = m_ref(*example_inputs)
|
||||
res = m(*example_inputs)
|
||||
self.assertEqual(res, res_ref)
|
||||
# check the qmin/qmax for per channel quant
|
||||
for n in m.graph.nodes:
|
||||
if n.op == "call_function" and \
|
||||
n.target == torch.ops.quantized_decomposed.quantize_per_channel.default:
|
||||
_QUANT_MIN_INDEX = 4
|
||||
_QUANT_MAX_INDEX = 5
|
||||
self.assertEqual(n.args[_QUANT_MIN_INDEX], -127)
|
||||
self.assertEqual(n.args[_QUANT_MAX_INDEX], 127)
|
||||
|
||||
@skipIfNoQNNPACK
|
||||
class TestQuantizePT2EFXX86Inductor(QuantizationTestCase):
|
||||
|
|
@ -425,7 +488,7 @@ class TestQuantizePT2EFXModels(QuantizationTestCase):
|
|||
m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
|
||||
)
|
||||
after_prepare_result_fx = m_fx(*example_inputs)
|
||||
m_fx = convert_to_reference_fx(m_fx, backend_config=backend_config)
|
||||
m_fx = _convert_to_reference_decomposed_fx(m_fx, backend_config=backend_config)
|
||||
|
||||
after_quant_result_fx = m_fx(*example_inputs)
|
||||
|
||||
|
|
@ -437,9 +500,5 @@ class TestQuantizePT2EFXModels(QuantizationTestCase):
|
|||
)
|
||||
# there are slight differences after convert due to different implementations
|
||||
# of quant/dequant
|
||||
self.assertTrue(
|
||||
torch.max(after_quant_result - after_quant_result_fx) < 1e-1
|
||||
)
|
||||
self.assertTrue(
|
||||
compute_sqnr(after_quant_result, after_quant_result_fx) > 35
|
||||
)
|
||||
self.assertTrue(torch.max(after_quant_result - after_quant_result_fx) < 1e-1)
|
||||
self.assertTrue(compute_sqnr(after_quant_result, after_quant_result_fx) > 35)
|
||||
|
|
|
|||
Loading…
Reference in a new issue