From eebdcb5a2ef6a117a608b9ca5ca1eb2fd4f72fbd Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Sat, 27 Aug 2022 16:06:16 -0700
Subject: [PATCH] [Pytorch][quantization][ondevice] Add a wrapper API for server side prep for ondevice quantization (#83742)

Summary:
This diff just wraps the existing APIs for on-device quantization.

Test Plan:
test/quantization/jit/test_ondevice_quantization.py

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D38868647](https://our.internmc.facebook.com/intern/diff/D38868647)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83742
Approved by: https://github.com/jerryzh168
---
 .../jit/test_ondevice_quantization.py | 29 +++++++-
 torch/ao/quantization/quantize_jit.py | 66 +++++++++++++++++++
 torch/quantization/__init__.py        |  3 +-
 3 files changed, 94 insertions(+), 4 deletions(-)

diff --git a/test/quantization/jit/test_ondevice_quantization.py b/test/quantization/jit/test_ondevice_quantization.py
index cd4abf9479c..8b453a50f1a 100644
--- a/test/quantization/jit/test_ondevice_quantization.py
+++ b/test/quantization/jit/test_ondevice_quantization.py
@@ -12,7 +12,7 @@ from torch.ao.quantization.quantize_jit import (
     prepare_dynamic_jit,
     convert_dynamic_jit,
     _prepare_ondevice_dynamic_jit,
-    _convert_ondevice_dynamic_jit,
+    _quantize_ondevice_dynamic_jit,
 )
 
 from torch.testing._internal.common_utils import TestCase
@@ -22,6 +22,8 @@ from torch.testing._internal.common_quantization import (
     LinearAddModel,
 )
 
+from torch.jit.mobile import _load_for_lite_interpreter
+
 from torch.testing import FileCheck
 import io
@@ -69,8 +71,7 @@ class OnDevicePTQUtils(object):
     def ptq_dynamic_quantize(model, qconfig_dict):
         inputs = model.get_example_inputs()
         m = get_script_module(model, False, inputs)
-        m = _prepare_ondevice_dynamic_jit(m, qconfig_dict)
-        m = _convert_ondevice_dynamic_jit(m, 'forward', True, False)
+        m = _quantize_ondevice_dynamic_jit(m, qconfig_dict, 'forward', True)
         return m
 
     @staticmethod
@@ -420,6 +421,17 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
         output = m.quantized_forward(*inputs)
         self.assertTrue(torch.allclose(ref_output, output))
 
+        # check for lite interpreter
+        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
+        buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        m = _load_for_lite_interpreter(buffer)
+        m.run_method("reset_observers_forward")
+        m.run_method("observe_forward", *inputs)
+        m.run_method("quantize_forward", *inputs)
+        output = m.run_method("quantized_forward", *inputs)
+        self.assertTrue(torch.allclose(ref_output, output))
+
         model.eval()
         inputs = model.get_example_inputs()
         ref_m = torch.jit.script(model)
@@ -444,6 +456,17 @@ class TestOnDeviceDynamicPTQFinalize(TestCase):
         output = m.quantized_forward(*inputs)
         self.assertTrue(torch.allclose(ref_output, output))
 
+        # check for lite interpreter
+        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
+        buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
+        buffer.seek(0)
+        m = _load_for_lite_interpreter(buffer)
+        m.run_method("reset_observers_forward")
+        m.run_method("observe_forward", *inputs)
+        m.run_method("quantize_forward", *inputs)
+        output = m.run_method("quantized_forward", *inputs)
+        self.assertTrue(torch.allclose(ref_output, output))
+
     def test_quantize_forward(self):
         model = LinearAddModel()
 
diff --git a/torch/ao/quantization/quantize_jit.py b/torch/ao/quantization/quantize_jit.py
index 17854a7e639..b9b19e59bae 100644
--- a/torch/ao/quantization/quantize_jit.py
+++ b/torch/ao/quantization/quantize_jit.py
@@ -145,6 +145,12 @@ def convert_dynamic_jit(model, inplace=False, debug=False, preserved_attrs=None)
 def _convert_ondevice_dynamic_jit(model, method_name, inplace=False, debug=False):
     return _convert_ondevice_jit(model, method_name, inplace, debug, quant_type=QuantType.DYNAMIC)
 
+
+def _quantize_ondevice_dynamic_jit_impl(model, qconfig_dict, method_name, inplace=False):
+    model = _prepare_ondevice_dynamic_jit(model, qconfig_dict, method_name, inplace)
+    model = _convert_ondevice_dynamic_jit(model, method_name, inplace)
+    return model
+
 def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False, debug=False, quant_type=QuantType.STATIC):
     # Always do inplace convert because the Tensor is already
     # copied in prepare_jit when inplace is False
@@ -255,3 +261,63 @@ def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False):
     """
     torch._C._log_api_usage_once("quantization_api.quantize_jit.quantize_dynamic_jit")
     return _quantize_jit(model, qconfig_dict, inplace=inplace, debug=debug, quant_type=QuantType.DYNAMIC)
+
+
+def _quantize_ondevice_dynamic_jit(model, qconfig_dict, method_name='forward', inplace=False):
+    r"""Prepares the input float TorchScript model for *on-device*
+    post-training dynamic quantization.
+    Currently only qint8 quantization of torch.nn.Linear is supported.
+
+    Args:
+        `model`: input float TorchScript model
+        `qconfig_dict`: dictionary with names of sub modules as keys and the
+            qconfig for each module as values; please see detailed
+            descriptions in :func:`~torch.ao.quantization.quantize_jit`
+        `method_name`: name of the method within the model to be prepared for quantization
+        `inplace`: carry out model transformations in-place; the original module is mutated
+
+    Return:
+        TorchScript model that is ready for on-device quantization.
+        This means that the returned model has:
+        - The method method_name inlined.
+        - Observer modules inserted in the model.
+        - Packed params inserted in the model. However, they are empty, i.e. they do not
+          contain valid quantized weights.
+        - observe_<method_name> added, which observes the values to be quantized.
+        - reset_observers_<method_name> added, which resets the observers.
+        - quantize_<method_name> added to the model. This method
+          - extracts scales and zero points,
+          - quantizes the observed weights,
+          - creates packed params from them and updates the model's attributes with the
+            new values for the packed params, and
+          - resets the original fp32 weights to empty tensors using SetAttr.
+        - quantized_<method_name> added to the model. This method
+          - uses quantized weights and quantized linear ops instead of the fp32 ops, and
+          - should be used for inference post PTQ.
+        - Note that all of these methods have the same signature as method_name.
+
+    Later on device:
+        - Run reset_observers_<method_name>
+        - Run observe_<method_name>
+        - Run quantize_<method_name>
+        - Now the model can be saved and loaded later.
+        - Run the model with quantized_<method_name>
+
+    Example:
+    ```python
+    import torch
+    from torch.ao.quantization import per_channel_dynamic_qconfig
+    from torch.ao.quantization.quantize_jit import _quantize_ondevice_dynamic_jit
+
+    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
+    qconfig = per_channel_dynamic_qconfig
+    quant_ready_model = _quantize_ondevice_dynamic_jit(
+        ts_model,
+        {'': qconfig},
+        'forward',
+        True)
+    ```
+    """
+    return _quantize_ondevice_dynamic_jit_impl(model, qconfig_dict, method_name, inplace=inplace)
diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py
index 48ba1abdd90..df9a75d0226 100644
--- a/torch/quantization/__init__.py
+++ b/torch/quantization/__init__.py
@@ -25,7 +25,8 @@ _all__ = [
     'quantize', 'quantize_dynamic', 'quantize_qat',
     'prepare', 'convert', 'prepare_qat',
     # Top level API for graph mode quantization on TorchScript
-    'quantize_jit', 'quantize_dynamic_jit', '_prepare_ondevice_dynamic_jit', '_convert_ondevice_dynamic_jit',
+    'quantize_jit', 'quantize_dynamic_jit', '_prepare_ondevice_dynamic_jit',
+    '_convert_ondevice_dynamic_jit', '_quantize_ondevice_dynamic_jit',
     # Top level API for graph mode quantization on GraphModule(torch.fx)
     # 'fuse_fx', 'quantize_fx',  # TODO: add quantize_dynamic_fx
     # 'prepare_fx', 'prepare_dynamic_fx', 'convert_fx',
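
For readers who want the whole flow in one place, below is a minimal usage sketch combining the new wrapper with the lite-interpreter path exercised in the test above. The toy `MyModel` module, its sizes, and the example input are hypothetical stand-ins; the underscore-prefixed functions (`_quantize_ondevice_dynamic_jit`, `_save_to_buffer_for_lite_interpreter`, `_load_for_lite_interpreter`) are the internal APIs from this patch and its test, and may change without notice.

```python
# Minimal sketch of the on-device dynamic PTQ flow from this patch.
# Assumptions: MyModel is a hypothetical toy model; all underscore-prefixed
# calls are internal PyTorch APIs taken from the patch and its test.
import io

import torch
from torch.ao.quantization import per_channel_dynamic_qconfig
from torch.ao.quantization.quantize_jit import _quantize_ondevice_dynamic_jit
from torch.jit.mobile import _load_for_lite_interpreter


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


inputs = (torch.randn(1, 8),)
ts_model = torch.jit.script(MyModel().eval())

# Server side: one call prepares and converts, yielding a "quant ready" model
# with observers and (empty) packed params for its 'forward' method.
quant_ready = _quantize_ondevice_dynamic_jit(
    ts_model, {'': per_channel_dynamic_qconfig}, 'forward', False)

# Ship the model in lite-interpreter format.
buffer = io.BytesIO(quant_ready._save_to_buffer_for_lite_interpreter())
buffer.seek(0)

# Device side: reset observers, observe, quantize, then run the quantized path.
m = _load_for_lite_interpreter(buffer)
m.run_method("reset_observers_forward")
m.run_method("observe_forward", *inputs)
m.run_method("quantize_forward", *inputs)
output = m.run_method("quantized_forward", *inputs)
```

Note the single wrapper call replaces the previous two-step `_prepare_ondevice_dynamic_jit` plus `_convert_ondevice_dynamic_jit` sequence, which is exactly the simplification this diff makes in `OnDevicePTQUtils.ptq_dynamic_quantize`.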