diff --git a/test/quantization/test_quantize_fx.py b/test/quantization/test_quantize_fx.py
index cc39250b745..8bddd15ac90 100644
--- a/test/quantization/test_quantize_fx.py
+++ b/test/quantization/test_quantize_fx.py
@@ -1762,6 +1762,46 @@ class TestQuantizeFx(QuantizationTestCase):
         checkModel(m, data, ref_weight, ref_bias, ref_res)
 
+    def test_preserve_qconfig(self):
+        """
+        Test to make sure the temporary config option to preserve qconfig attributes
+        in the model works
+        """
+        class Linear(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w = torch.ones(5, 5)
+                self.b = torch.zeros(5)
+
+            def forward(self, x):
+                return torch.nn.functional.linear(x, self.w, self.b)
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mods1 = torch.nn.Sequential(
+                    Linear(),
+                    Linear()
+                )
+                self.mods2 = torch.nn.Sigmoid()
+
+            def forward(self, x):
+                x = self.mods1(x)
+                x = self.mods2(x)
+                return x
+
+        model = M().eval()
+        qconfig_dict = {
+            "object_type": [
+                (torch.nn.functional.linear, float16_dynamic_qconfig),
+            ],
+        }
+        m = prepare_fx(model, qconfig_dict)
+        m(torch.rand(5, 5))
+        m = convert_fx(m, _remove_qconfig=False)
+
+        self.assertTrue(hasattr(m.mods2, 'qconfig'))
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
     """
diff --git a/torch/quantization/fx/quantize.py b/torch/quantization/fx/quantize.py
index 65e57d151c0..8851fbe74c1 100644
--- a/torch/quantization/fx/quantize.py
+++ b/torch/quantization/fx/quantize.py
@@ -638,7 +638,8 @@ class Quantizer:
 
     def _convert(self, model: GraphModule, is_reference: bool = False,
                  convert_custom_config_dict: Dict[str, Any] = None,
-                 is_standalone_module: bool = False) -> QuantizedGraphModule:
+                 is_standalone_module: bool = False,
+                 _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
         """ standalone_module means it a submodule that is not inlined in
         parent module, and will be quantized separately as one unit.
@@ -919,7 +920,8 @@ class Quantizer:
                     node, load_arg_simple)
 
         # removes qconfig and activation_post_process modules
-        _remove_qconfig(model)
+        if _remove_qconfig_flag:
+            _remove_qconfig(model)
         model = QuantizedGraphModule(model, act_post_process_removed_graph)
         return model
 
@@ -977,9 +979,10 @@ class Quantizer:
 
     def convert(self, model: GraphModule, is_reference: bool = False,
                 convert_custom_config_dict: Dict[str, Any] = None,
-                is_standalone_module: bool = False) -> QuantizedGraphModule:
+                is_standalone_module: bool = False,
+                _remove_qconfig: bool = True) -> QuantizedGraphModule:
         quantized = self._convert(
-            model, is_reference, convert_custom_config_dict, is_standalone_module)
+            model, is_reference, convert_custom_config_dict, is_standalone_module, _remove_qconfig_flag=_remove_qconfig)
         if not is_reference:
             quantized = self._fold_weight(quantized)
         return quantized
diff --git a/torch/quantization/quantize_fx.py b/torch/quantization/quantize_fx.py
index 242e26cd407..323277b6c62 100644
--- a/torch/quantization/quantize_fx.py
+++ b/torch/quantization/quantize_fx.py
@@ -418,7 +418,8 @@ def prepare_qat_fx(
 def _convert_fx(
         graph_module: GraphModule, is_reference: bool,
         convert_custom_config_dict: Dict[str, Any] = None,
-        is_standalone_module: bool = False) -> QuantizedGraphModule:
+        is_standalone_module: bool = False,
+        _remove_qconfig: bool = True) -> QuantizedGraphModule:
     """ `is_standalone_module`: see docs in :func:`~torch.quantization.prepare_standalone_module_fx`
     """
     if convert_custom_config_dict is None:
@@ -427,7 +428,8 @@ def _convert_fx(
     _check_is_graph_module(graph_module)
 
     quantizer = Quantizer()
-    quantized = quantizer.convert(graph_module, is_reference, convert_custom_config_dict, is_standalone_module)
+    quantized = quantizer.convert(graph_module, is_reference, convert_custom_config_dict,
+                                  is_standalone_module, _remove_qconfig=_remove_qconfig)
 
     preserved_attributes = convert_custom_config_dict.get("preserved_attributes", [])
     for attr_name in preserved_attributes:
@@ -436,7 +438,8 @@ def _convert_fx(
 
 def convert_fx(
         graph_module: GraphModule, is_reference: bool = False,
-        convert_custom_config_dict: Dict[str, Any] = None) -> QuantizedGraphModule:
+        convert_custom_config_dict: Dict[str, Any] = None,
+        _remove_qconfig: bool = True) -> QuantizedGraphModule:
     r""" Convert a calibrated or trained model to a quantized model
     Args:
         `graph_module`: A prepared and calibrated/trained model (GraphModule)
@@ -480,6 +483,7 @@ def convert_fx(
             # not used in the code
             "preserved_attributes": ["preserved_attr"],
         }
+        `_remove_qconfig`: Option to remove the qconfig attributes in the model after convert.
 
     Return:
         A quantized model (GraphModule)
@@ -491,7 +495,7 @@ def convert_fx(
     ```
     """
     torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx")
-    return _convert_fx(graph_module, is_reference, convert_custom_config_dict)
+    return _convert_fx(graph_module, is_reference, convert_custom_config_dict, _remove_qconfig=_remove_qconfig)
 
 def _convert_standalone_module_fx(
     graph_module: GraphModule, is_reference: bool = False,
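For reviewers who want to exercise the new flag outside the test harness, here is a minimal standalone sketch of the intended call pattern. It mirrors the new `test_preserve_qconfig` above; the module structure and qconfig choices are taken directly from that test, and the import paths assume the public `torch.quantization` namespace as of this change:

```python
import torch
from torch.quantization import float16_dynamic_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

# Toy model mirroring the new unit test: two functional linears
# followed by a sigmoid.
class Linear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.ones(5, 5)
        self.b = torch.zeros(5)

    def forward(self, x):
        return torch.nn.functional.linear(x, self.w, self.b)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mods1 = torch.nn.Sequential(Linear(), Linear())
        self.mods2 = torch.nn.Sigmoid()

    def forward(self, x):
        return self.mods2(self.mods1(x))

model = M().eval()
qconfig_dict = {
    "object_type": [
        (torch.nn.functional.linear, float16_dynamic_qconfig),
    ],
}

m = prepare_fx(model, qconfig_dict)
m(torch.rand(5, 5))  # run sample data through the observers

# By default convert_fx strips qconfig attributes; the new flag keeps them.
m = convert_fx(m, _remove_qconfig=False)
assert hasattr(m.mods2, 'qconfig')
```

Since the flag is underscore-prefixed, it is an internal escape hatch rather than part of the stable API; as the test docstring says, downstream code that inspects qconfig attributes after convert should treat it as temporary.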