mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[quant][graphmode] docstrings for top level APIs (#40328)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40328

Test Plan: Imported from OSS

Differential Revision: D22149708

fbshipit-source-id: 63a1cd229d9e4668fba0ef3977e894cb8984318b
This commit is contained in:
parent
7a837019a4
commit
59ca1d31ca
3 changed files with 140 additions and 39 deletions
@@ -242,7 +242,7 @@ Layers for the quantization-aware training
 ``torch.quantization``
 ~~~~~~~~~~~~~~~~~~~~~~
 
-* Functions for quantization:
+* Functions for eager mode quantization:
 
     * :func:`~torch.quantization.add_observer_` — Adds observers for the leaf
      modules (if a quantization configuration is provided)
@@ -259,24 +259,18 @@ Layers for the quantization-aware training
     * :func:`~torch.quantization.propagate_qconfig_` — Propagates quantization
      configurations through the module hierarchy and assigns them to each leaf
      module
-    * :func:`~torch.quantization.quantize` — Converts a float module to quantized version
-    * :func:`~torch.quantization.quantize_dynamic` — Converts a float module to
-      dynamically quantized version
-    * :func:`~torch.quantization.quantize_qat` — Converts a float module to
-      quantized version used in quantization aware training
+    * :func:`~torch.quantization.quantize` — Function for eager mode post training static quantization
+    * :func:`~torch.quantization.quantize_dynamic` — Function for eager mode post training dynamic quantization
+    * :func:`~torch.quantization.quantize_qat` — Function for eager mode quantization aware training
     * :func:`~torch.quantization.swap_module` — Swaps the module with its
      quantized counterpart (if quantizable and if it has an observer)
-
-    * :func:`~torch.quantization.default_eval_fn` — Default evaluation function
+    * :func:`~torch.quantization.default_eval_fn` — Default evaluation function
      used by the :func:`torch.quantization.quantize`
-    * :func:`~torch.quantization.fuse_modules`
-* :class:`~torch.quantization.FakeQuantize` — Module for simulating the
-  quantization/dequantization at training time
-* Default Observers. The rest of observers are available from
-  ``torch.quantization.observer``:
-
-    * :attr:`~torch.quantization.default_observer` — Same as ``MinMaxObserver.with_args(reduce_range=True)``
-    * :attr:`~torch.quantization.default_weight_observer` — Same as ``MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)``
-* :class:`~torch.quantization.Observer` — Abstract base class for observers
+    * :func:`~torch.quantization.fuse_modules`
+
+* Functions for graph mode quantization:
+
+    * :func:`~torch.quantization.quantize_jit` - Function for graph mode post training static quantization
+    * :func:`~torch.quantization.quantize_dynamic_jit` - Function for graph mode post training dynamic quantization
 
 * Quantization configurations
 
     * :class:`~torch.quantization.QConfig` — Quantization configuration class
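As context for the eager mode functions listed above, here is a minimal sketch of the post training static quantization flow (prepare, calibrate, convert). The toy model `M` and the single random calibration batch are illustrative assumptions, not from the patch; the flow assumes an x86 build where the ``fbgemm`` backend is available.

```python
import torch
import torch.nn as nn
import torch.quantization as tq

# Toy float model wrapped with Quant/DeQuant stubs (illustrative only).
class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()      # tensors enter the quantized region here
        self.fc = nn.Linear(4, 4)
        self.dequant = tq.DeQuantStub()  # tensors leave the quantized region here

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

model = M().eval()
model.qconfig = tq.get_default_qconfig('fbgemm')
tq.prepare(model, inplace=True)    # inserts observers according to qconfig
model(torch.randn(8, 4))           # calibration pass collects activation stats
tq.convert(model, inplace=True)    # swaps float modules for quantized ones
```

:func:`~torch.quantization.quantize` bundles these three steps into one call, as shown in its docstring later in this diff.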
@@ -303,22 +297,27 @@ Layers for the quantization-aware training
   quantized. Inserts the :class:`~torch.quantization.QuantStub` and
   :class:`~torch.quantization.DeQuantStub`
 
-Observers for computing the quantization parameters
-
-* :class:`~torch.quantization.MinMaxObserver` — Derives the quantization
-  parameters from the running minimum and maximum of the observed tensor inputs
-  (per tensor variant)
-* :class:`~torch.quantization.MovingAverageMinMaxObserver` — Derives the
-  quantization parameters from the running averages of the minimums and
-  maximums of the observed tensor inputs (per tensor variant)
-* :class:`~torch.quantization.PerChannelMinMaxObserver` — Derives the
-  quantization parameters from the running minimum and maximum of the observed
-  tensor inputs (per channel variant)
-* :class:`~torch.quantization.MovingAveragePerChannelMinMaxObserver` — Derives
-  the quantization parameters from the running averages of the minimums and
-  maximums of the observed tensor inputs (per channel variant)
-* :class:`~torch.quantization.HistogramObserver` — Derives the quantization
-  parameters by creating a histogram of running minimums and maximums.
+* Observers for computing the quantization parameters
+
+    * Default Observers. The rest of observers are available from
+      ``torch.quantization.observer``:
+
+        * :attr:`~torch.quantization.default_observer` — Same as ``MinMaxObserver.with_args(reduce_range=True)``
+        * :attr:`~torch.quantization.default_weight_observer` — Same as ``MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)``
+    * :class:`~torch.quantization.Observer` — Abstract base class for observers
+    * :class:`~torch.quantization.MinMaxObserver` — Derives the quantization
+      parameters from the running minimum and maximum of the observed tensor inputs
+      (per tensor variant)
+    * :class:`~torch.quantization.MovingAverageMinMaxObserver` — Derives the
+      quantization parameters from the running averages of the minimums and
+      maximums of the observed tensor inputs (per tensor variant)
+    * :class:`~torch.quantization.PerChannelMinMaxObserver` — Derives the
+      quantization parameters from the running minimum and maximum of the observed
+      tensor inputs (per channel variant)
+    * :class:`~torch.quantization.MovingAveragePerChannelMinMaxObserver` — Derives
+      the quantization parameters from the running averages of the minimums and
+      maximums of the observed tensor inputs (per channel variant)
+    * :class:`~torch.quantization.HistogramObserver` — Derives the quantization
+      parameters by creating a histogram of running minimums and maximums.
+    * Observers that do not compute the quantization parameters:
+
+        * :class:`~torch.quantization.RecordingObserver` — Records all incoming
+          tensors. Used for debugging only.
@@ -326,6 +325,10 @@ Observers for computing the quantization parameters
   for situations when there are no quantization parameters (i.e.
   quantization to ``float16``)
 
+* FakeQuantize module
+
+    * :class:`~torch.quantization.FakeQuantize` — Module for simulating the
+      quantization/dequantization at training time
+
 ``torch.nn.quantized``
 ~~~~~~~~~~~~~~~~~~~~~~
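To make the observer descriptions above concrete, here is a small sketch of the per tensor variant: feeding tensors through a :class:`~torch.quantization.MinMaxObserver` updates its running min/max, from which it derives ``(scale, zero_point)`` for the target dtype. The specific input values are illustrative.

```python
import torch
from torch.quantization import MinMaxObserver

# Observers are callable modules: each forward pass updates the running
# minimum and maximum of the observed tensor.
obs = MinMaxObserver(dtype=torch.quint8)
obs(torch.tensor([-1.0, 0.0, 2.0]))

# The recorded range is then turned into quantization parameters.
scale, zero_point = obs.calculate_qparams()
```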
@@ -187,17 +187,15 @@ def _remove_qconfig(module):
     del module.qconfig
 
 def quantize(model, run_fn, run_args, mapping=None, inplace=False):
-    r"""Converts a float model to quantized model.
+    r"""Quantize the input float model with post training static quantization.
 
-    First it will prepare the model for calibration or training, then it calls
-    `run_fn` which will run the calibration step or training step,
-    after that we will call `convert` which will convert the model to a
-    quantized model.
+    First it will prepare the model for calibration, then it calls
+    `run_fn` which will run the calibration step, after that we will
+    convert the model to a quantized model.
 
     Args:
-        model: input model
-        run_fn: a function for evaluating the prepared model, can be a
-            function that simply runs the prepared model or a training loop
+        model: input float model
+        run_fn: a calibration function for calibrating the prepared model
         run_args: positional arguments for `run_fn`
         inplace: carry out model transformations in-place, the original module is mutated
         mapping: correspondence between original module types and quantized counterparts
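The reworded docstring above can be exercised end to end: :func:`~torch.quantization.quantize` runs prepare, then `run_fn(model, *run_args)` for calibration, then convert. The toy model, calibration function, and random data below are illustrative assumptions, and the sketch assumes the ``fbgemm`` backend is available.

```python
import torch
import torch.nn as nn
from torch.quantization import quantize, get_default_qconfig, QuantStub, DeQuantStub

# Illustrative float model wrapped with Quant/DeQuant stubs.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant, self.dequant = QuantStub(), DeQuantStub()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

def calibrate(model, data):
    # run_fn: simply runs the prepared model over calibration data
    with torch.no_grad():
        for x in data:
            model(x)

float_model = Net().eval()
float_model.qconfig = get_default_qconfig('fbgemm')
data = [torch.randn(2, 4) for _ in range(4)]
qmodel = quantize(float_model, calibrate, [data])  # run_args is a list of positional args
```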
@@ -20,14 +20,28 @@ def _check_forward_method(model):
     raise ValueError('input script module does not have forward method')
 
 def script_qconfig(qconfig):
+    r"""Instantiate the activation and weight observer modules and script
+    them; these observer module instances will be deepcopied during the
+    prepare_jit step.
+    """
     return QConfig(
         activation=torch.jit.script(qconfig.activation())._c,
         weight=torch.jit.script(qconfig.weight())._c)
 
 def script_qconfig_dict(qconfig_dict):
+    r"""Helper function used by `prepare_jit`.
+    Applies `script_qconfig` to every entry in `qconfig_dict` that is
+    not None.
+    """
     return {k: script_qconfig(v) if v else None for k, v in qconfig_dict.items()}
 
 def fuse_conv_bn_jit(model):
+    r"""Fuse conv-bn modules. Works for eval models only.
+
+    Args:
+        model: TorchScript model from scripting or tracing
+    """
     return torch.jit._recursive.wrap_cpp_module(torch._C._jit_pass_fold_convbn(model._c))
 
 def _prepare_jit(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC):
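The conv-bn fusion that `fuse_conv_bn_jit` documents above can be checked numerically: in eval mode, batch norm statistics fold into the conv weight and bias so a single conv reproduces conv followed by bn. This is an illustrative sketch of the math only, not the actual JIT pass; the shapes and random statistics are assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, 3).eval()
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1.0, 1.0)  # pretend the model was trained
bn.running_var.uniform_(0.5, 2.0)

# Fold: w' = w * gamma / sqrt(var + eps); b' = (b - mean) * gamma / sqrt(var + eps) + beta
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
fused = nn.Conv2d(3, 8, 3).eval()
fused.weight.data = conv.weight * scale.reshape(-1, 1, 1, 1)
fused.bias.data = (conv.bias - bn.running_mean) * scale + bn.bias

x = torch.randn(1, 3, 8, 8)
out_ref = bn(conv(x))    # original conv followed by batch norm
out_fused = fused(x)     # single folded conv
```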
@@ -87,7 +101,93 @@ def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False
     return model
 
 def quantize_jit(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False):
+    r"""Quantize the input float TorchScript model with
+    post training static quantization.
+
+    First it will prepare the model for calibration, then it calls
+    `run_fn` which will run the calibration step, after that we will
+    convert the model to a quantized model.
+
+    Args:
+        `model`: input float TorchScript model
+        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and
+            qconfig for that module as value; an empty key means the qconfig will be applied
+            to the whole model unless it is overridden by more specific configurations. The
+            qconfig for each module is either found in the dictionary or falls back to
+            the qconfig of the parent module.
+
+            Right now qconfig_dict is the only way to configure how the model is quantized,
+            and it is done at the granularity of module, that is, we only support one type
+            of qconfig for each torch.nn.Module, and the qconfig for a sub module will
+            override the qconfig for its parent module; the empty string means global configuration.
+        `run_fn`: a calibration function for calibrating the prepared model
+        `run_args`: positional arguments for `run_fn`
+        `inplace`: carry out model transformations in-place, the original module is
+            mutated
+        `debug`: flag for producing a debug friendly model (preserves the weight attribute)
+
+    Return:
+        Quantized TorchScript model.
+
+    Example:
+    ```python
+    import torch
+    from torch.quantization import get_default_qconfig
+    from torch.quantization import quantize_jit
+
+    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
+    qconfig = get_default_qconfig('fbgemm')
+    def calibrate(model, data_loader):
+        model.eval()
+        with torch.no_grad():
+            for image, target in data_loader:
+                model(image)
+
+    quantized_model = quantize_jit(
+        ts_model,
+        {'': qconfig},
+        calibrate,
+        [data_loader_test])
+    ```
+    """
     return _quantize_jit(model, qconfig_dict, run_fn, run_args, inplace, debug, quant_type=QuantType.STATIC)
 
 def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False):
+    r"""Quantize the input float TorchScript model with
+    post training dynamic quantization.
+    Currently only qint8 quantization of torch.nn.Linear is supported.
+
+    Args:
+        `model`: input float TorchScript model
+        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and
+            qconfig for that module as value, please see detailed
+            descriptions in :func:`~torch.quantization.quantize_jit`
+        `inplace`: carry out model transformations in-place, the original module is
+            mutated
+        `debug`: flag for producing a debug friendly model (preserves the weight attribute)
+
+    Return:
+        Quantized TorchScript model.
+
+    Example:
+    ```python
+    import torch
+    from torch.quantization import per_channel_dynamic_qconfig
+    from torch.quantization import quantize_dynamic_jit
+
+    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
+
+    # dynamic quantization needs no calibration step
+    quantized_model = quantize_dynamic_jit(
+        ts_model,
+        {'': per_channel_dynamic_qconfig})
+    ```
+    """
     return _quantize_jit(model, qconfig_dict, inplace=inplace, debug=debug, quant_type=QuantType.DYNAMIC)
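The qconfig_dict lookup rule described in the `quantize_jit` docstring (the empty string is the global fallback; a more specific sub module name overrides its parent) can be sketched as a plain Python helper. `resolve_qconfig` is a hypothetical illustration of the rule, not the actual implementation in the patch.

```python
def resolve_qconfig(qconfig_dict, module_name):
    # Hypothetical helper: the most specific matching entry wins,
    # '' acts as the global fallback, and None disables quantization.
    candidates = [k for k in qconfig_dict
                  if k == '' or module_name == k or module_name.startswith(k + '.')]
    best = max(candidates, key=len, default=None)
    return qconfig_dict.get(best)

qd = {'': 'global_qconfig', 'sub': None}
```

With this rule, `'sub.linear'` resolves to the `'sub'` entry (None, i.e. not quantized) while an unrelated module falls back to the global entry.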