[quant][fx] add _remove_qconfig flag to convert_fx (#53166)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53166

Context: for fx modules that contain scriptmodules, calling delattr(module, 'qconfig') throws an AttributeError; a separate issue/repro to fix this problem will follow.

This PR adds a temporary flag to the convert_fx API that preserves the qconfig attributes on the converted model. We will remove the flag once we reach a conclusion on calling delattr on scriptmodules.

Test Plan: python test/test_quantization.py test_preserve_qconfig

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D26771518

fbshipit-source-id: 9fd72816576856ffb4aa11f8fde08303d1df10a2
This commit is contained in:
parent 25a3732c8d
commit 7cec4b3d4a
3 changed files with 55 additions and 8 deletions
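In practice, the new keyword gives callers a way to keep the qconfig attributes on the converted module. A minimal usage sketch follows, assuming the torch.quantization FX APIs as they stand in this PR; the Toy module is illustrative rather than taken from the diff.

    import torch
    from torch.quantization import float16_dynamic_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    # Toy eval-mode model; any module that calls torch.nn.functional.linear works.
    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.w = torch.ones(5, 5)
            self.b = torch.zeros(5)

        def forward(self, x):
            return torch.nn.functional.linear(x, self.w, self.b)

    qconfig_dict = {"object_type": [(torch.nn.functional.linear, float16_dynamic_qconfig)]}
    m = prepare_fx(Toy().eval(), qconfig_dict)
    m(torch.rand(5, 5))                       # calibration pass
    m = convert_fx(m, _remove_qconfig=False)  # keep qconfig attributes on the result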
@@ -1762,6 +1762,46 @@ class TestQuantizeFx(QuantizationTestCase):
         checkModel(m, data, ref_weight, ref_bias, ref_res)

+    def test_preserve_qconfig(self):
+        """
+        Test to make sure the temporary config option to preserve qconfig attributes
+        in the model works
+        """
+        class Linear(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w = torch.ones(5, 5)
+                self.b = torch.zeros(5)
+
+            def forward(self, x):
+                return torch.nn.functional.linear(x, self.w, self.b)
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mods1 = torch.nn.Sequential(
+                    Linear(),
+                    Linear()
+                )
+                self.mods2 = torch.nn.Sigmoid()
+
+            def forward(self, x):
+                x = self.mods1(x)
+                x = self.mods2(x)
+                return x
+
+        model = M().eval()
+        qconfig_dict = {
+            "object_type": [
+                (torch.nn.functional.linear, float16_dynamic_qconfig),
+            ],
+        }
+        m = prepare_fx(model, qconfig_dict)
+        m(torch.rand(5, 5))
+        m = convert_fx(m, _remove_qconfig=False)
+
+        self.assertTrue(hasattr(m.mods2, 'qconfig'))
+
 @skipIfNoFBGEMM
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
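For contrast (not part of the PR's test), a sketch that reuses M and qconfig_dict from test_preserve_qconfig above to exercise the default path, where _remove_qconfig stays True and the attribute should be stripped:

    # Sketch: default path, reusing M and qconfig_dict from the test above.
    m_default = prepare_fx(M().eval(), qconfig_dict)
    m_default(torch.rand(5, 5))
    m_default = convert_fx(m_default)  # _remove_qconfig defaults to True
    assert not hasattr(m_default.mods2, 'qconfig')  # qconfig removed by default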
@@ -638,7 +638,8 @@ class Quantizer:

     def _convert(self, model: GraphModule, is_reference: bool = False,
                  convert_custom_config_dict: Dict[str, Any] = None,
-                 is_standalone_module: bool = False) -> QuantizedGraphModule:
+                 is_standalone_module: bool = False,
+                 _remove_qconfig_flag: bool = True) -> QuantizedGraphModule:
         """ standalone_module means it is a submodule that is not inlined in
         parent module, and will be quantized separately as one unit.
@@ -919,7 +920,8 @@ class Quantizer:
                     node, load_arg_simple)

         # removes qconfig and activation_post_process modules
-        _remove_qconfig(model)
+        if _remove_qconfig_flag:
+            _remove_qconfig(model)
         model = QuantizedGraphModule(model, act_post_process_removed_graph)
         return model
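The guarded call above is the piece that breaks on script-module children. A minimal sketch of what a recursive qconfig remover looks like; the actual torch.quantization helper may differ in detail:

    import torch

    # Sketch of a recursive qconfig remover; the real _remove_qconfig helper
    # in torch.quantization may differ in detail.
    def remove_qconfig_sketch(module: torch.nn.Module) -> None:
        for child in module.children():
            remove_qconfig_sketch(child)  # walk the module tree bottom-up
        if hasattr(module, 'qconfig'):
            # Per the PR summary, this delattr raises AttributeError when
            # `module` is a scriptmodule, which is why the flag guard exists.
            delattr(module, 'qconfig')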
@@ -977,9 +979,10 @@ class Quantizer:
     def convert(self, model: GraphModule, is_reference: bool = False,
                 convert_custom_config_dict: Dict[str, Any] = None,
-                is_standalone_module: bool = False) -> QuantizedGraphModule:
+                is_standalone_module: bool = False,
+                _remove_qconfig: bool = True) -> QuantizedGraphModule:
         quantized = self._convert(
-            model, is_reference, convert_custom_config_dict, is_standalone_module)
+            model, is_reference, convert_custom_config_dict, is_standalone_module, _remove_qconfig_flag=_remove_qconfig)
         if not is_reference:
             quantized = self._fold_weight(quantized)
         return quantized
@@ -418,7 +418,8 @@ def prepare_qat_fx(
 def _convert_fx(
         graph_module: GraphModule, is_reference: bool,
         convert_custom_config_dict: Dict[str, Any] = None,
-        is_standalone_module: bool = False) -> QuantizedGraphModule:
+        is_standalone_module: bool = False,
+        _remove_qconfig: bool = True) -> QuantizedGraphModule:
     """ `is_standalone_module`: see docs in :func:`~torch.quantization.prepare_standalone_module_fx`
     """
     if convert_custom_config_dict is None:
@@ -427,7 +428,8 @@ def _convert_fx(
     _check_is_graph_module(graph_module)

     quantizer = Quantizer()
-    quantized = quantizer.convert(graph_module, is_reference, convert_custom_config_dict, is_standalone_module)
+    quantized = quantizer.convert(graph_module, is_reference, convert_custom_config_dict,
+                                  is_standalone_module, _remove_qconfig=_remove_qconfig)

     preserved_attributes = convert_custom_config_dict.get("preserved_attributes", [])
     for attr_name in preserved_attributes:
@@ -436,7 +438,8 @@ def _convert_fx(
 def convert_fx(
         graph_module: GraphModule, is_reference: bool = False,
-        convert_custom_config_dict: Dict[str, Any] = None) -> QuantizedGraphModule:
+        convert_custom_config_dict: Dict[str, Any] = None,
+        _remove_qconfig: bool = True) -> QuantizedGraphModule:
     r""" Convert a calibrated or trained model to a quantized model
     Args:
         `graph_module`: A prepared and calibrated/trained model (GraphModule)
@@ -480,6 +483,7 @@ def convert_fx(
           # not used in the code
           "preserved_attributes": ["preserved_attr"],
         }
+        `_remove_qconfig`: Option to remove the qconfig attributes in the model after convert.

     Return:
         A quantized model (GraphModule)
@@ -491,7 +495,7 @@ def convert_fx(
     ```
     """
     torch._C._log_api_usage_once("quantization_api.quantize_fx.convert_fx")
-    return _convert_fx(graph_module, is_reference, convert_custom_config_dict)
+    return _convert_fx(graph_module, is_reference, convert_custom_config_dict, _remove_qconfig=_remove_qconfig)

 def _convert_standalone_module_fx(
     graph_module: GraphModule, is_reference: bool = False,
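Taken together with `preserved_attributes` from the docstring above, a caller can carry both a custom attribute and the qconfig attributes through convert. A hedged end-to-end sketch: the `preserved_attr` name mirrors the docstring example, and the prepare-side `preserved_attributes` entry is an assumption made so the attribute survives prepare_fx as well.

    import torch
    from torch.quantization import float16_dynamic_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    class Toy(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.preserved_attr = "metadata"  # illustrative attribute to carry through

        def forward(self, x):
            return torch.nn.functional.linear(x, torch.ones(5, 5), torch.zeros(5))

    qconfig_dict = {"object_type": [(torch.nn.functional.linear, float16_dynamic_qconfig)]}
    prepared = prepare_fx(Toy().eval(), qconfig_dict,
                          prepare_custom_config_dict={"preserved_attributes": ["preserved_attr"]})
    prepared(torch.rand(5, 5))  # calibration pass
    quantized = convert_fx(prepared,
                           convert_custom_config_dict={"preserved_attributes": ["preserved_attr"]},
                           _remove_qconfig=False)
    assert quantized.preserved_attr == "metadata"  # custom attribute survived convert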