[quant][fx] add _remove_qconfig flag to convert_fx (#53166)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53166

Context: For FX modules that contain ScriptModules, calling
delattr(module, 'qconfig') throws an AttributeError. We will follow up
with a separate issue/repro to fix this problem.

This PR adds a temporary flag to the convert_fx API to preserve the qconfig attributes on the converted model.
We will remove this flag once we reach a conclusion on calling delattr on ScriptModules.

Test Plan:
python test/test_quantization.py test_preserve_qconfig

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D26771518

fbshipit-source-id: 9fd72816576856ffb4aa11f8fde08303d1df10a2
This commit is contained in:
Supriya Rao 2021-03-03 12:52:49 -08:00 committed by Facebook GitHub Bot
parent 25a3732c8d
commit 7cec4b3d4a
3 changed files with 55 additions and 8 deletions

View file

@ -1762,6 +1762,46 @@ class TestQuantizeFx(QuantizationTestCase):
checkModel(m, data, ref_weight, ref_bias, ref_res)
    def test_preserve_qconfig(self):
        """
        Test to make sure the temporary config option to preserve qconfig attributes
        in the model works.

        With ``convert_fx(m, _remove_qconfig=False)`` the qconfig attributes
        attached during prepare should still be present on the converted model.
        """
        # Minimal functional-linear module so the "object_type" qconfig entry
        # below (torch.nn.functional.linear) applies to it.
        class Linear(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w = torch.ones(5, 5)
                self.b = torch.zeros(5)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.w, self.b)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                # Two functional linears (quantized per qconfig_dict) followed
                # by a sigmoid, which gets no qconfig of its own.
                self.mods1 = torch.nn.Sequential(
                    Linear(),
                    Linear()
                )
                self.mods2 = torch.nn.Sigmoid()

            def forward(self, x):
                x = self.mods1(x)
                x = self.mods2(x)
                return x

        model = M().eval()
        # Only torch.nn.functional.linear call sites receive a qconfig
        # (float16 dynamic quantization).
        qconfig_dict = {
            "object_type": [
                (torch.nn.functional.linear, float16_dynamic_qconfig),
            ],
        }
        m = prepare_fx(model, qconfig_dict)
        # One calibration pass before convert.
        m(torch.rand(5, 5))
        # _remove_qconfig=False is the temporary flag under test: convert
        # must leave the qconfig attributes on the model's submodules
        # instead of delattr-ing them.
        m = convert_fx(m, _remove_qconfig=False)
        self.assertTrue(hasattr(m.mods2, 'qconfig'))
@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
"""Unit tests for individual ops

View file

@ -638,7 +638,8 @@ class Quantizer:
def _convert(self, model: GraphModule, is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None,
is_standalone_module: bool = False) -> QuantizedGraphModule:
is_standalone_module: bool = False,
_remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
""" standalone_module means it a submodule that is not inlined in
parent module, and will be quantized separately as one unit.
@ -919,7 +920,8 @@ class Quantizer:
node, load_arg_simple)
# removes qconfig and activation_post_process modules
_remove_qconfig(model)
if _remove_qconfig_flag:
_remove_qconfig(model)
model = QuantizedGraphModule(model, act_post_process_removed_graph)
return model
@ -977,9 +979,10 @@ class Quantizer:
def convert(self, model: GraphModule, is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None,
is_standalone_module: bool = False) -> QuantizedGraphModule:
is_standalone_module: bool = False,
_remove_qconfig: bool = True) -> QuantizedGraphModule:
quantized = self._convert(
model, is_reference, convert_custom_config_dict, is_standalone_module)
model, is_reference, convert_custom_config_dict, is_standalone_module, _remove_qconfig_flag=_remove_qconfig)
if not is_reference:
quantized = self._fold_weight(quantized)
return quantized

View file

@ -418,7 +418,8 @@ def prepare_qat_fx(
def _convert_fx(
graph_module: GraphModule, is_reference: bool,
convert_custom_config_dict: Dict[str, Any] = None,
is_standalone_module: bool = False) -> QuantizedGraphModule:
is_standalone_module: bool = False,
_remove_qconfig: bool = True) -> QuantizedGraphModule:
""" `is_standalone_module`: see docs in :func:`~torch.quantization.prepare_standalone_module_fx`
"""
if convert_custom_config_dict is None:
@ -427,7 +428,8 @@ def _convert_fx(
_check_is_graph_module(graph_module)
quantizer = Quantizer()
quantized = quantizer.convert(graph_module, is_reference, convert_custom_config_dict, is_standalone_module)
quantized = quantizer.convert(graph_module, is_reference, convert_custom_config_dict,
is_standalone_module, _remove_qconfig=_remove_qconfig)
preserved_attributes = convert_custom_config_dict.get("preserved_attributes", [])
for attr_name in preserved_attributes:
@ -436,7 +438,8 @@ def _convert_fx(
def convert_fx(
graph_module: GraphModule, is_reference: bool = False,
convert_custom_config_dict: Dict[str, Any] = None) -> QuantizedGraphModule:
convert_custom_config_dict: Dict[str, Any] = None,
_remove_qconfig: bool = True) -> QuantizedGraphModule:
r""" Convert a calibrated or trained model to a quantized model
Args:
`graph_module`: A prepared and calibrated/trained model (GraphModule)
@ -480,6 +483,7 @@ def convert_fx(
# not used in the code
"preserved_attributes": ["preserved_attr"],
}
`_remove_qconfig`: Option to remove the qconfig attributes in the model after convert.
Return:
A quantized model (GraphModule)
@ -491,7 +495,7 @@ def convert_fx(
```
"""
torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx")
return _convert_fx(graph_module, is_reference, convert_custom_config_dict)
return _convert_fx(graph_module, is_reference, convert_custom_config_dict, _remove_qconfig=_remove_qconfig)
def _convert_standalone_module_fx(
graph_module: GraphModule, is_reference: bool = False,