[quant][graphmode] docstrings for top level APIs (#40328)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40328

Test Plan: Imported from OSS

Differential Revision: D22149708

fbshipit-source-id: 63a1cd229d9e4668fba0ef3977e894cb8984318b


@@ -242,7 +242,7 @@ Layers for the quantization-aware training
 ``torch.quantization``
 ~~~~~~~~~~~~~~~~~~~~~~
-* Functions for quantization:
+* Functions for eager mode quantization:
     * :func:`~torch.quantization.add_observer_` — Adds observer for the leaf
       modules (if quantization configuration is provided)
@@ -259,24 +259,18 @@ Layers for the quantization-aware training
 * :func:`~torch.quantization.propagate_qconfig_` — Propagates quantization
   configurations through the module hierarchy and assign them to each leaf
   module
-* :func:`~torch.quantization.quantize` — Converts a float module to quantized version
-* :func:`~torch.quantization.quantize_dynamic` — Converts a float module to
-  dynamically quantized version
-* :func:`~torch.quantization.quantize_qat` — Converts a float module to
-  quantized version used in quantization aware training
+* :func:`~torch.quantization.quantize` — Function for eager mode post training static quantization
+* :func:`~torch.quantization.quantize_dynamic` — Function for eager mode post training dynamic quantization
+* :func:`~torch.quantization.quantize_qat` — Function for eager mode quantization aware training
 * :func:`~torch.quantization.swap_module` — Swaps the module with its
   quantized counterpart (if quantizable and if it has an observer)
 * :func:`~torch.quantization.default_eval_fn` — Default evaluation function
   used by the :func:`torch.quantization.quantize`
-* :func:`~torch.quantization.fuse_modules`
-* :class:`~torch.quantization.FakeQuantize` — Module for simulating the
-  quantization/dequantization at training time
-* Default Observers. The rest of observers are available from
-  ``torch.quantization.observer``:
-    * :attr:`~torch.quantization.default_observer` — Same as ``MinMaxObserver.with_args(reduce_range=True)``
-    * :attr:`~torch.quantization.default_weight_observer` — Same as ``MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)``
-    * :class:`~torch.quantization.Observer` — Abstract base class for observers
+* :func:`~torch.quantization.fuse_modules`
+* Functions for graph mode quantization:
+    * :func:`~torch.quantization.quantize_jit` — Function for graph mode post training static quantization
+    * :func:`~torch.quantization.quantize_dynamic_jit` — Function for graph mode post training dynamic quantization
 * Quantization configurations
     * :class:`~torch.quantization.QConfig` — Quantization configuration class
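Of the three eager mode entry points listed above, `quantize_dynamic` needs the least ceremony. A minimal sketch, assuming an illustrative toy model (any module containing `nn.Linear` works the same way):

```python
import torch
import torch.nn as nn

# illustrative float model
float_model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4)).eval()

# weights of all Linear modules are quantized to qint8 ahead of time;
# activations are quantized dynamically at inference time
quantized = torch.quantization.quantize_dynamic(
    float_model, {nn.Linear}, dtype=torch.qint8)

out = quantized(torch.randn(2, 16))
```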
@@ -303,22 +297,27 @@ Layers for the quantization-aware training
   quantized. Inserts the :class:`~torch.quantization.QuantStub` and
   :class:`~torch.quantization.DeQuantStub`
-Observers for computing the quantization parameters
+* Observers for computing the quantization parameters
+    * Default Observers. The rest of observers are available from
+      ``torch.quantization.observer``:
+        * :attr:`~torch.quantization.default_observer` — Same as ``MinMaxObserver.with_args(reduce_range=True)``
+        * :attr:`~torch.quantization.default_weight_observer` — Same as ``MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)``
+    * :class:`~torch.quantization.Observer` — Abstract base class for observers
 * :class:`~torch.quantization.MinMaxObserver` — Derives the quantization
   parameters from the running minimum and maximum of the observed tensor inputs
   (per tensor variant)
 * :class:`~torch.quantization.MovingAverageMinMaxObserver` — Derives the
   quantization parameters from the running averages of the minimums and
   maximums of the observed tensor inputs (per tensor variant)
 * :class:`~torch.quantization.PerChannelMinMaxObserver` — Derives the
   quantization parameters from the running minimum and maximum of the observed
   tensor inputs (per channel variant)
 * :class:`~torch.quantization.MovingAveragePerChannelMinMaxObserver` — Derives
   the quantization parameters from the running averages of the minimums and
   maximums of the observed tensor inputs (per channel variant)
 * :class:`~torch.quantization.HistogramObserver` — Derives the quantization
   parameters by creating a histogram of running minimums and maximums.
 * Observers that do not compute the quantization parameters:
     * :class:`~torch.quantization.RecordingObserver` — Records all incoming
       tensors. Used for debugging only.
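To make the observer bullets above concrete, a small hand-driven sketch (shapes and values are illustrative; in the real flow `prepare` inserts observers for you):

```python
import torch
from torch.quantization import MinMaxObserver, default_observer

# feed a few tensors through the observer, then derive quantization
# parameters from the running min/max it has recorded (per tensor variant)
obs = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)
for _ in range(4):
    obs(torch.randn(8, 8))
scale, zero_point = obs.calculate_qparams()

# default_observer is a partially-applied constructor: calling it yields
# a fresh MinMaxObserver configured with reduce_range=True
obs2 = default_observer()
```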
@@ -326,6 +325,10 @@ Observers for computing the quantization parameters
       for situation when there are no quantization parameters (i.e.
       quantization to ``float16``)
+* FakeQuantize module
+    * :class:`~torch.quantization.FakeQuantize` — Module for simulating the
+      quantization/dequantization at training time

 ``torch.nn.quantized``
 ~~~~~~~~~~~~~~~~~~~~~~


@@ -187,17 +187,15 @@ def _remove_qconfig(module):
         del module.qconfig

 def quantize(model, run_fn, run_args, mapping=None, inplace=False):
-    r"""Converts a float model to quantized model.
+    r"""Quantize the input float model with post training static quantization.

-    First it will prepare the model for calibration or training, then it calls
-    `run_fn` which will run the calibration step or training step,
-    after that we will call `convert` which will convert the model to a
-    quantized model.
+    First it will prepare the model for calibration, then it calls
+    `run_fn` which will run the calibration step, after that we will
+    convert the model to a quantized model.

     Args:
-        model: input model
-        run_fn: a function for evaluating the prepared model, can be a
-            function that simply runs the prepared model or a training loop
+        model: input float model
+        run_fn: a calibration function for calibrating the prepared model
         run_args: positional arguments for `run_fn`
         inplace: carry out model transformations in-place, the original module is mutated
         mapping: correspondence between original module types and quantized counterparts
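For context, a minimal end-to-end sketch of the flow this docstring describes. The toy module, the QuantStub/DeQuantStub placement, and the calibration data are illustrative, and `run_args` is assumed to be unpacked into `run_fn` following the convention of the examples in this PR:

```python
import torch
import torch.nn as nn
from torch.quantization import QuantStub, DeQuantStub, default_qconfig, quantize

class M(nn.Module):
    """Toy model; QuantStub/DeQuantStub mark the quantized region."""
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.fc = nn.Linear(8, 8)
        self.dequant = DeQuantStub()

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

def calibrate(model, data):
    # run representative inputs through the prepared model so the
    # observers can record activation statistics
    model.eval()
    with torch.no_grad():
        for x in data:
            model(x)

float_model = M().eval()
float_model.qconfig = default_qconfig
calib_data = [torch.randn(4, 8) for _ in range(8)]
quantized_model = quantize(float_model, calibrate, [calib_data])
```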


@@ -20,14 +20,28 @@ def _check_forward_method(model):
         raise ValueError('input script module does not have forward method')

 def script_qconfig(qconfig):
+    r"""Instantiate the activation and weight observer modules and script
+    them, these observer module instances will be deepcopied during
+    prepare_jit step.
+    """
     return QConfig(
         activation=torch.jit.script(qconfig.activation())._c,
         weight=torch.jit.script(qconfig.weight())._c)

+def script_qconfig_dict(qconfig_dict):
+    r"""Helper function used by `prepare_jit`.
+    Apply `script_qconfig` for all entries in `qconfig_dict` that are
+    not None.
+    """
+    return {k: script_qconfig(v) if v else None for k, v in qconfig_dict.items()}

 def fuse_conv_bn_jit(model):
+    r"""Fuse conv - bn module
+    Works for eval model only.
+    Args:
+        model: TorchScript model from scripting or tracing
+    """
     return torch.jit._recursive.wrap_cpp_module(torch._C._jit_pass_fold_convbn(model._c))

 def _prepare_jit(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC):
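A quick usage sketch for `fuse_conv_bn_jit`. The conv-bn model is illustrative, and the import path assumes the function is re-exported from `torch.quantization` as the docs hunk above suggests (otherwise import it from `torch.quantization.quantize_jit`):

```python
import torch
import torch.nn as nn
from torch.quantization import fuse_conv_bn_jit

# illustrative conv - bn pair; folding bn into conv only makes sense
# once the model is in eval mode and the bn statistics are frozen
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).eval()
ts_model = torch.jit.script(model)
fused = fuse_conv_bn_jit(ts_model)
```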
@@ -87,7 +101,93 @@ def _quantize_jit(model, qconfig_dict, run_fn=None, run_args=None, inplace=False
     return model

 def quantize_jit(model, qconfig_dict, run_fn, run_args, inplace=False, debug=False):
+    r"""Quantize the input float TorchScript model with
+    post training static quantization.
+
+    First it will prepare the model for calibration, then it calls
+    `run_fn` which will run the calibration step, after that we will
+    convert the model to a quantized model.
+
+    Args:
+        `model`: input float TorchScript model
+        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and
+            qconfig for that module as value, empty key means the qconfig will be applied
+            to the whole model unless it's overwritten by more specific configurations, the
+            qconfig for each module is either found in the dictionary or falls back to
+            the qconfig of the parent module.
+
+            Right now qconfig_dict is the only way to configure how the model is quantized,
+            and it is done at the granularity of module, that is, we only support one type
+            of qconfig for each torch.nn.Module, and the qconfig for a sub module will
+            override the qconfig of the parent module, empty string means global configuration.
+        `run_fn`: a calibration function for calibrating the prepared model
+        `run_args`: positional arguments for `run_fn`
+        `inplace`: carry out model transformations in-place, the original module is
+            mutated
+        `debug`: flag for producing a debug friendly model (preserve weight attribute)
+
+    Return:
+        Quantized TorchScript model.
+
+    Example:
+    ```python
+    import torch
+    from torch.quantization import get_default_qconfig
+    from torch.quantization import quantize_jit
+
+    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
+    qconfig = get_default_qconfig('fbgemm')
+    def calibrate(model, data_loader):
+        model.eval()
+        with torch.no_grad():
+            for image, target in data_loader:
+                model(image)
+    quantized_model = quantize_jit(
+        ts_model,
+        {'': qconfig},
+        calibrate,
+        [data_loader_test])
+    ```
+    """
     return _quantize_jit(model, qconfig_dict, run_fn, run_args, inplace, debug, quant_type=QuantType.STATIC)
 def quantize_dynamic_jit(model, qconfig_dict, inplace=False, debug=False):
+    r"""Quantize the input float TorchScript model with
+    post training dynamic quantization.
+    Currently only qint8 quantization of torch.nn.Linear is supported.
+
+    Args:
+        `model`: input float TorchScript model
+        `qconfig_dict`: qconfig_dict is a dictionary with names of sub modules as key and
+            qconfig for that module as value, please see detailed
+            descriptions in :func:`~torch.quantization.quantize_jit`
+        `inplace`: carry out model transformations in-place, the original module is
+            mutated
+        `debug`: flag for producing a debug friendly model (preserve weight attribute)
+
+    Return:
+        Quantized TorchScript model.
+
+    Example:
+    ```python
+    import torch
+    from torch.quantization import per_channel_dynamic_qconfig
+    from torch.quantization import quantize_dynamic_jit
+
+    ts_model = torch.jit.script(float_model.eval())  # or torch.jit.trace(float_model, input)
+    qconfig = per_channel_dynamic_qconfig
+    quantized_model = quantize_dynamic_jit(
+        ts_model,
+        {'': qconfig})
+    ```
+    """
     return _quantize_jit(model, qconfig_dict, inplace=inplace, debug=debug, quant_type=QuantType.DYNAMIC)
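As a sanity check after either entry point, the result is an ordinary ScriptModule, so it can be inspected and serialized as usual; a short sketch (the file name is illustrative):

```python
# quantized_model as returned by quantize_jit / quantize_dynamic_jit
print(quantized_model.graph)  # quantized ops should now appear in the graph
torch.jit.save(quantized_model, 'quantized.pt')
loaded = torch.jit.load('quantized.pt')
```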