[quant][docs] Additonal fixes for quantize_fx docs (#84587)
Summary: Some more clarifications for the arguments, including linking to object docs (QConfigMapping, BackendConfig) and adding types in the doc

Test Plan:

```
cd docs
make html
```

and visual inspection for the generated docs

Reviewers: Subscribers: Tasks: Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84587
Approved by: https://github.com/vkuzo

This commit is contained in: parent 0a89bdf989, commit 214a6500e3
8 changed files with 213 additions and 190 deletions

@@ -2,7 +2,7 @@ Quantization API Reference
 -------------------------------
 
-torch.quantization
-~~~~~~~~~~~~~~~~~~
+torch.ao.quantization
+~~~~~~~~~~~~~~~~~~~~~
 
 This module contains Eager mode quantization APIs.
 
@@ -52,7 +52,7 @@ Utility functions
     get_observer_dict
 
-torch.quantization.quantize_fx
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+torch.ao.quantization.quantize_fx
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 This module contains FX graph mode quantization APIs (prototype).
 
@@ -68,6 +68,59 @@ This module contains FX graph mode quantization APIs (prototype).
     convert_fx
     fuse_fx
 
+torch.ao.quantization.qconfig_mapping
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+This module contains QConfigMapping for configuring FX graph mode quantization.
+
+.. currentmodule:: torch.ao.quantization.qconfig_mapping
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+    :template: classtemplate.rst
+
+    QConfigMapping
+    get_default_qconfig_mapping
+    get_default_qat_qconfig_mapping
+
+torch.ao.quantization.backend_config
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+This module contains BackendConfig, a config object that defines how quantization is supported
+in a backend. Currently only used by FX Graph Mode Quantization, but we may extend Eager Mode
+Quantization to work with this as well.
+
+.. currentmodule:: torch.ao.quantization.backend_config
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+    :template: classtemplate.rst
+
+    BackendConfig
+    BackendPatternConfig
+    DTypeConfig
+    ObservationType
+
+torch.ao.quantization.fx.custom_config
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+This module contains a few CustomConfig classes that's used in both eager mode and FX graph mode quantization
+
+
+.. currentmodule:: torch.ao.quantization.fx.custom_config
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+    :template: classtemplate.rst
+
+    FuseCustomConfig
+    PrepareCustomConfig
+    ConvertCustomConfig
+    StandaloneModuleConfigEntry
+
 torch (quantization related functions)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -127,7 +180,7 @@ regular full-precision tensor.
 
-torch.quantization.observer
-~~~~~~~~~~~~~~~~~~~~~~~~~~~
+torch.ao.quantization.observer
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 This module contains observers which are used to collect statistics about
 the values observed during calibration (PTQ) or training (QAT).
 
@@ -160,7 +213,7 @@ the values observed during calibration (PTQ) or training (QAT).
     default_float_qparams_observer
 
-torch.quantization.fake_quantize
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+torch.ao.quantization.fake_quantize
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 This module implements modules which are used to perform fake quantization
 during QAT.
 
@@ -189,7 +242,7 @@ during QAT.
     enable_observer
 
-torch.quantization.qconfig
-~~~~~~~~~~~~~~~~~~~~~~~~~~
+torch.ao.quantization.qconfig
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 This module defines `QConfig` objects which are used
 to configure quantization settings for individual ops.
 
@@ -91,15 +91,6 @@
         "Union",
         "get_combined_dict"
     ],
-    "torch.ao.quantization.backend_config.utils": [
-        "Any",
-        "Dict",
-        "Callable",
-        "List",
-        "Union",
-        "Tuple",
-        "Pattern"
-    ],
     "torch.ao.quantization.backend_config.native": [
         "Any",
         "Dict",
 
@@ -47,14 +47,21 @@ OVERWRITE_OUTPUT_OBSERVER_DICT_KEY = "overwrite_output_observer"
 # TODO: maybe rename this to something that's not related to observer
 # e.g. QParamsType
 class ObservationType(Enum):
-    # this means input and output are observed with different observers, based
-    # on qconfig.activation
-    # example: conv, linear, softmax
+    """ An enum that represents different ways of how an operator/operator pattern
+    should be observed
+    """
+
     OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT = 0
-    # this means the output will use the same observer instance as input, based
-    # on qconfig.activation
-    # example: torch.cat, maxpool
+    """this means input and output are observed with different observers, based
+    on qconfig.activation
+    example: conv, linear, softmax
+    """
+
     OUTPUT_SHARE_OBSERVER_WITH_INPUT = 1
+    """this means the output will use the same observer instance as input, based
+    on qconfig.activation
+    example: torch.cat, maxpool
+    """
 
 @dataclass
 class DTypeConfig:
 
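To make the two enum values concrete, here is a minimal sketch of how they attach to pattern configs (not part of the commit; it assumes the `BackendPatternConfig` and `ObservationType` exports of `torch.ao.quantization.backend_config` from this era, and the linear/cat pattern choices are illustrative):

```python
import torch
from torch.ao.quantization.backend_config import BackendPatternConfig, ObservationType

# linear computes new values, so its output gets its own observer
linear_config = BackendPatternConfig(torch.nn.Linear) \
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)

# cat only rearranges existing values, so its output shares the input's observer
cat_config = BackendPatternConfig(torch.cat) \
    .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT)
```
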
@@ -71,7 +78,7 @@ class DTypeConfig:
     @classmethod
     def from_dict(cls, dtype_config_dict: Dict[str, Any]) -> DTypeConfig:
         """
-        Create a `DTypeConfig` from a dictionary with the following items (all optional):
+        Create a ``DTypeConfig`` from a dictionary with the following items (all optional):
 
             "input_dtype": torch.dtype
             "output_dtype": torch.dtype
 
@@ -88,7 +95,7 @@ class DTypeConfig:
 
     def to_dict(self) -> Dict[str, Any]:
         """
-        Convert this `DTypeConfig` to a dictionary with the items described in
+        Convert this ``DTypeConfig`` to a dictionary with the items described in
         :func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`.
         """
         dtype_config_dict: Dict[str, Any] = {}
 
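A hedged sketch of the dict round trip these two methods document, using only the keys shown in this diff (the quint8 values are an illustrative choice, not mandated by the commit):

```python
import torch
from torch.ao.quantization.backend_config import DTypeConfig

# Dict form accepted by from_dict; all keys are optional
dtype_config = DTypeConfig.from_dict({
    "input_dtype": torch.quint8,
    "output_dtype": torch.quint8,
})

round_tripped = dtype_config.to_dict()  # back to the dict form
```
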
@@ -107,16 +114,18 @@ class DTypeConfig:
 
 class BackendConfig:
-    # TODO: refer to NativeBackendConfig once that is implemented
-    """
-    Config that defines the set of patterns that can be quantized on a given backend, and how reference
+    """Config that defines the set of patterns that can be quantized on a given backend, and how reference
     quantized models can be produced from these patterns.
 
     A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph
     of the above. Each pattern supported on the target backend can be individually configured through
     :class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of:
-    (1) The supported input/output activation, weight, and bias data types
-    (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and
-    (3) (Optionally) Fusion, QAT, and reference module mappings.
+
+    (1) The supported input/output activation, weight, and bias data types
+
+    (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and
+
+    (3) (Optionally) Fusion, QAT, and reference module mappings.
 
     The format of the patterns is described in:
     https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md
 
@@ -149,6 +158,7 @@ class BackendConfig:
         backend_config = BackendConfig("my_backend") \
             .set_backend_pattern_config(linear_config) \
             .set_backend_pattern_config(conv_relu_config)
+
     """
     def __init__(self, name: str = ""):
         self.name = name
 
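The docstring example above uses `linear_config` without defining it; as a rough sketch of how such a pattern config might be assembled (not from the commit; the dtype values are illustrative, and the setters are the ones named elsewhere in this file):

```python
import torch
from torch.ao.quantization.backend_config import (
    BackendConfig,
    BackendPatternConfig,
    DTypeConfig,
    ObservationType,
)

dtype_config = DTypeConfig.from_dict({"input_dtype": torch.quint8,
                                      "output_dtype": torch.quint8})

# Configure how nn.Linear is quantized on this hypothetical backend
linear_config = BackendPatternConfig(torch.nn.Linear) \
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
    .add_dtype_config(dtype_config) \
    .set_root_module(torch.nn.Linear)

backend_config = BackendConfig("my_backend") \
    .set_backend_pattern_config(linear_config)
```
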
@@ -181,10 +191,12 @@ class BackendConfig:
     @classmethod
     def from_dict(cls, backend_config_dict: Dict[str, Any]) -> BackendConfig:
         """
-        Create a `BackendConfig` from a dictionary with the following items:
+        Create a ``BackendConfig`` from a dictionary with the following items:
 
             "name": the name of the target backend
+
             "configs": a list of dictionaries that each represents a `BackendPatternConfig`
 
         """
         conf = cls(backend_config_dict.get(NAME_DICT_KEY, ""))
         for d in backend_config_dict.get(CONFIGS_DICT_KEY, []):
 
@@ -198,7 +210,7 @@ class BackendConfig:
 
     def to_dict(self) -> Dict[str, Any]:
         """
-        Convert this `BackendConfig` to a dictionary with the items described in
+        Convert this ``BackendConfig`` to a dictionary with the items described in
         :func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`.
         """
         return {
 
@@ -210,19 +222,8 @@ class BackendConfig:
 class BackendPatternConfig:
     """
     Config for ops defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`.
-
-    The user can configure how a operator pattern graph is handled on a given backend using the following methods:
-        `set_observation_type`: sets how observers should be inserted for this pattern.
-            See :class:`~torch.ao.quantization.backend_config.ObservationType`
-        `add_dtype_config`: add a set of supported data types for this pattern
-        `set_root_module`: sets the module that represents the root for this pattern
-        `set_qat_module`: sets the module that represents the QAT implementation for this pattern
-        `set_reference_quantized_module`: sets the module that represents the reference quantized
-            implementation for this pattern's root module.
-        `set_fused_module`: sets the module that represents the fused implementation for this pattern
-        `set_fuser_method`: sets the function that specifies how to fuse the pattern for this pattern
 
     For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`.
-
     """
     def __init__(self, pattern: Pattern):
         self.pattern = pattern
 
@@ -246,13 +247,15 @@ class BackendPatternConfig:
     def set_observation_type(self, observation_type: ObservationType) -> BackendPatternConfig:
         """
         Set how observers should be inserted for this pattern.
+        See :class:`~torch.ao.quantization.backend_config.ObservationType` for details
+
         """
         self.observation_type = observation_type
         return self
 
     def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig:
         """
-        Register a set of supported input/output activation, weight, and bias data types for this pattern.
+        Add a set of supported input/output activation, weight, and bias data types for this pattern.
         """
         self.dtype_configs.append(dtype_config)
         return self
 
@@ -333,22 +336,23 @@ class BackendPatternConfig:
     @classmethod
     def from_dict(cls, backend_pattern_config_dict: Dict[str, Any]) -> BackendPatternConfig:
         """
-        Create a `BackendPatternConfig` from a dictionary with the following items:
+        Create a ``BackendPatternConfig`` from a dictionary with the following items:
 
             "pattern": the pattern being configured
             "observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how
-            observers should be inserted for this pattern
-            "dtype_configs": a list of dictionaries that represents :class:`~torch.ao.quantization.backend_config.DTypeConfig`s
+                observers should be inserted for this pattern
+            "dtype_configs": a list of dictionaries that represents :class:`~torch.ao.quantization.backend_config.DTypeConfig` s
             "root_module": a :class:`torch.nn.Module` that represents the root for this pattern
             "qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern
             "reference_quantized_module": a :class:`torch.nn.Module` that represents the reference quantized
-            implementation for this pattern's root module.
+                implementation for this pattern's root module.
             "fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern
             "fuser_method": a function that specifies how to fuse the pattern for this pattern
 
         """
         def _get_dtype_config(obj: Any) -> DTypeConfig:
             """
-            Convert the given object into a `DTypeConfig` if possible, else throw an exception.
+            Convert the given object into a ``DTypeConfig`` if possible, else throw an exception.
             """
             if isinstance(obj, DTypeConfig):
                 return obj
 
@@ -381,7 +385,7 @@ class BackendPatternConfig:
 
     def to_dict(self) -> Dict[str, Any]:
         """
-        Convert this `BackendPatternConfig` to a dictionary with the items described in
+        Convert this ``BackendPatternConfig`` to a dictionary with the items described in
         :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`.
         """
         backend_pattern_config_dict: Dict[str, Any] = {
 
@@ -6,6 +6,21 @@ import torch.nn.functional as F
 from .backend_config import BackendConfig, DTypeConfig
 from ..quantization_types import Pattern
 
+__all__ = [
+    "get_pattern_to_dtype_configs",
+    "get_qat_module_classes",
+    "get_fused_module_classes",
+    "get_pattern_to_input_type_to_index",
+    "get_root_module_to_quantized_reference_module",
+    "get_fuser_method_mapping",
+    "get_module_to_qat_module",
+    "get_fusion_pattern_to_root_node_getter",
+    "get_fusion_pattern_to_extra_inputs_getter",
+    "remove_boolean_dispatch_from_name",
+    "pattern_to_human_readable",
+    "entry_to_pretty_str",
+]
+
 def get_pattern_to_dtype_configs(backend_config: BackendConfig) -> Dict[Pattern, List[DTypeConfig]]:
     pattern_to_dtype_configs: Dict[Pattern, List[DTypeConfig]] = {}
     for pattern, config in backend_config.configs.items():
 
@@ -42,18 +42,6 @@ class PrepareCustomConfig:
     Custom configuration for :func:`~torch.ao.quantization.quantize_fx.prepare_fx` and
     :func:`~torch.ao.quantization.quantize_fx.prepare_qat_fx`.
 
-    The user can set custom configuration using the following methods:
-
-    `set_standalone_module_name`: sets the config for preparing a standalone module for quantization, identified by name
-    `set_standalone_module_class`: sets the config for preparing a standalone module for quantization, identified by class
-    `set_float_to_observed_mapping`: sets the mapping from a float module class to an observed module class
-    `set_non_traceable_module_names`: sets modules that are not symbolically traceable, identified by name
-    `set_non_traceable_module_classes`: sets modules that are not symbolically traceable, identified by class
-    `set_input_quantized_indexes`: sets the indexes of the inputs of the graph that should be quantized.
-    `set_output_quantized_indexes`: sets the indexes of the outputs of the graph that should be quantized.
-    `set_preserved_attributes`: sets the names of the attributes that will persist in the graph module even
-        if they are not used in the model's `forward` method
-
     Example usage::
 
         prepare_custom_config = PrepareCustomConfig() \
 
@@ -86,11 +74,11 @@ class PrepareCustomConfig:
                                    prepare_custom_config: Optional[PrepareCustomConfig],
                                    backend_config: Optional[BackendConfig]) -> PrepareCustomConfig:
         """
-        Set the configuration for running a standalone module identified by `module_name`.
+        Set the configuration for running a standalone module identified by ``module_name``.
 
-        If `qconfig_mapping` is None, the parent `qconfig_mapping` will be used instead.
-        If `prepare_custom_config` is None, an empty `PrepareCustomConfig` will be used.
-        If `backend_config` is None, the parent `backend_config` will be used instead.
+        If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead.
+        If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used.
+        If ``backend_config`` is None, the parent ``backend_config`` will be used instead.
         """
         self.standalone_module_names[module_name] = \
             StandaloneModuleConfigEntry(qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
 
@@ -104,11 +92,11 @@ class PrepareCustomConfig:
                                     prepare_custom_config: Optional[PrepareCustomConfig],
                                     backend_config: Optional[BackendConfig]) -> PrepareCustomConfig:
         """
-        Set the configuration for running a standalone module identified by `module_class`.
+        Set the configuration for running a standalone module identified by ``module_class``.
 
-        If `qconfig_mapping` is None, the parent `qconfig_mapping` will be used instead.
-        If `prepare_custom_config` is None, an empty `PrepareCustomConfig` will be used.
-        If `backend_config` is None, the parent `backend_config` will be used instead.
+        If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead.
+        If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used.
+        If ``backend_config`` is None, the parent ``backend_config`` will be used instead.
         """
         self.standalone_module_classes[module_class] = \
             StandaloneModuleConfigEntry(qconfig_mapping, example_inputs, prepare_custom_config, backend_config)
 
@@ -122,7 +110,7 @@ class PrepareCustomConfig:
         """
         Set the mapping from a custom float module class to a custom observed module class.
 
-        The observed module class must have a `from_float` class method that converts the float module class
+        The observed module class must have a ``from_float`` class method that converts the float module class
         to the observed module class. This is currently only supported for static quantization.
         """
         if quant_type != QuantType.STATIC:
 
@@ -165,7 +153,7 @@ class PrepareCustomConfig:
     def set_preserved_attributes(self, attributes: List[str]) -> PrepareCustomConfig:
         """
         Set the names of the attributes that will persist in the graph module even if they are not used in
-        the model's `forward` method.
+        the model's ``forward`` method.
         """
         self.preserved_attributes = attributes
         return self
 
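As a rough usage sketch chaining several of the `PrepareCustomConfig` setters documented in the hunks above (not part of the commit; module and attribute names are placeholders):

```python
from torch.ao.quantization.fx.custom_config import PrepareCustomConfig

# "submodule0", "attr1", "attr2" are placeholder names
prepare_custom_config = PrepareCustomConfig() \
    .set_non_traceable_module_names(["submodule0"]) \
    .set_input_quantized_indexes([0]) \
    .set_output_quantized_indexes([0]) \
    .set_preserved_attributes(["attr1", "attr2"])
```
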
@@ -174,23 +162,23 @@ class PrepareCustomConfig:
     @classmethod
     def from_dict(cls, prepare_custom_config_dict: Dict[str, Any]) -> PrepareCustomConfig:
         """
-        Create a `PrepareCustomConfig` from a dictionary with the following items:
+        Create a ``PrepareCustomConfig`` from a dictionary with the following items:
 
             "standalone_module_name": a list of (module_name, qconfig_mapping, example_inputs,
-            child_prepare_custom_config, backend_config) tuples
+                child_prepare_custom_config, backend_config) tuples
 
             "standalone_module_class" a list of (module_class, qconfig_mapping, example_inputs,
-            child_prepare_custom_config, backend_config) tuples
+                child_prepare_custom_config, backend_config) tuples
 
             "float_to_observed_custom_module_class": a nested dictionary mapping from quantization
-            mode to an inner mapping from float module classes to observed module classes, e.g.
-            {"static": {FloatCustomModule: ObservedCustomModule}}
+                mode to an inner mapping from float module classes to observed module classes, e.g.
+                {"static": {FloatCustomModule: ObservedCustomModule}}
 
             "non_traceable_module_name": a list of modules names that are not symbolically traceable
             "non_traceable_module_class": a list of module classes that are not symbolically traceable
             "input_quantized_idxs": a list of indexes of graph inputs that should be quantized
             "output_quantized_idxs": a list of indexes of graph outputs that should be quantized
-            "preserved_attributes": a list of attributes that persist even if they are not used in `forward`
+            "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``
 
         This function is primarily for backward compatibility and may be removed in the future.
         """
 
@@ -255,7 +243,7 @@ class PrepareCustomConfig:
 
     def to_dict(self) -> Dict[str, Any]:
         """
-        Convert this `PrepareCustomConfig` to a dictionary with the items described in
+        Convert this ``PrepareCustomConfig`` to a dictionary with the items described in
         :func:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig.from_dict`.
         """
         def _make_tuple(key: Any, e: StandaloneModuleConfigEntry):
 
@@ -293,12 +281,6 @@ class ConvertCustomConfig:
     """
     Custom configuration for :func:`~torch.ao.quantization.quantize_fx.convert_fx`.
 
-    The user can set custom configuration using the following methods:
-
-    `set_observed_to_quantized_mapping`: sets the mapping from an observed module class to a quantized module class
-    `set_preserved_attributes`: sets the names of the attributes that will persist in the graph module even if they
-        are not used in the model's `forward` method
-
     Example usage::
 
         convert_custom_config = ConvertCustomConfig() \
 
@@ -318,7 +300,7 @@ class ConvertCustomConfig:
         """
         Set the mapping from a custom observed module class to a custom quantized module class.
 
-        The quantized module class must have a `from_observed` class method that converts the observed module class
+        The quantized module class must have a ``from_observed`` class method that converts the observed module class
         to the quantized module class.
         """
         if quant_type not in self.observed_to_quantized_mapping:
 
@@ -329,7 +311,7 @@ class ConvertCustomConfig:
     def set_preserved_attributes(self, attributes: List[str]) -> ConvertCustomConfig:
         """
         Set the names of the attributes that will persist in the graph module even if they are not used in
-        the model's `forward` method.
+        the model's ``forward`` method.
         """
         self.preserved_attributes = attributes
         return self
 
@@ -338,17 +320,16 @@ class ConvertCustomConfig:
     @classmethod
     def from_dict(cls, convert_custom_config_dict: Dict[str, Any]) -> ConvertCustomConfig:
         """
-        Create a `ConvertCustomConfig` from a dictionary with the following items:
+        Create a ``ConvertCustomConfig`` from a dictionary with the following items:
 
             "observed_to_quantized_custom_module_class": a nested dictionary mapping from quantization
-            mode to an inner mapping from observed module classes to quantized module classes, e.g.
-            {
-                "static": {FloatCustomModule: ObservedCustomModule},
-                "dynamic": {FloatCustomModule: ObservedCustomModule},
-                "weight_only": {FloatCustomModule: ObservedCustomModule}
-            }
-
-            "preserved_attributes": a list of attributes that persist even if they are not used in `forward`
+                mode to an inner mapping from observed module classes to quantized module classes, e.g.::
+
+                {
+                    "static": {FloatCustomModule: ObservedCustomModule},
+                    "dynamic": {FloatCustomModule: ObservedCustomModule},
+                    "weight_only": {FloatCustomModule: ObservedCustomModule}
+                }
+
+            "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``
 
         This function is primarily for backward compatibility and may be removed in the future.
         """
 
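A short sketch of the equivalent dict and builder forms described above (not from the commit; "attr1" is a placeholder attribute name):

```python
from torch.ao.quantization.fx.custom_config import ConvertCustomConfig

# Dict form, as accepted by from_dict
convert_custom_config = ConvertCustomConfig.from_dict(
    {"preserved_attributes": ["attr1"]}
)

# Equivalent builder form
convert_custom_config = ConvertCustomConfig().set_preserved_attributes(["attr1"])
```
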
@@ -362,7 +343,7 @@ class ConvertCustomConfig:
 
     def to_dict(self) -> Dict[str, Any]:
         """
-        Convert this `ConvertCustomConfig` to a dictionary with the items described in
+        Convert this ``ConvertCustomConfig`` to a dictionary with the items described in
         :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`.
         """
         d: Dict[str, Any] = {}
 
@@ -379,11 +360,6 @@ class FuseCustomConfig:
     """
     Custom configuration for :func:`~torch.ao.quantization.quantize_fx.fuse_fx`.
 
-    The user can set custom configuration using the following method:
-
-    `set_preserved_attributes`: sets the names of the attributes that will persist in the graph module
-        even if they are not used in the model's `forward` method
-
     Example usage::
 
         fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"])
 
@@ -395,7 +371,7 @@ class FuseCustomConfig:
     def set_preserved_attributes(self, attributes: List[str]) -> FuseCustomConfig:
         """
         Set the names of the attributes that will persist in the graph module even if they are not used in
-        the model's `forward` method.
+        the model's ``forward`` method.
         """
         self.preserved_attributes = attributes
         return self
 
@@ -404,9 +380,9 @@ class FuseCustomConfig:
     @classmethod
     def from_dict(cls, fuse_custom_config_dict: Dict[str, Any]) -> FuseCustomConfig:
         """
-        Create a `ConvertCustomConfig` from a dictionary with the following items:
+        Create a ``ConvertCustomConfig`` from a dictionary with the following items:
 
-            "preserved_attributes": a list of attributes that persist even if they are not used in `forward`
+            "preserved_attributes": a list of attributes that persist even if they are not used in ``forward``
 
         This function is primarily for backward compatibility and may be removed in the future.
         """
 
@@ -416,7 +392,7 @@ class FuseCustomConfig:
 
     def to_dict(self) -> Dict[str, Any]:
         """
-        Convert this `FuseCustomConfig` to a dictionary with the items described in
+        Convert this ``FuseCustomConfig`` to a dictionary with the items described in
         :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`.
         """
         d: Dict[str, Any] = {}
 
@@ -116,28 +116,39 @@ def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QC
 def get_default_qconfig_mapping(backend="fbgemm", version=0) -> QConfigMapping:
+    """
+    Return the default QConfigMapping for post training quantization.
+
+    Args:
+      * ``backend`` : the quantization backend for the default qconfig mapping, should be
+        one of ["fbgemm", "qnnpack"]
+      * ``version`` : the version for the default qconfig mapping
+    """
     # TODO: add assert for backend choices
     return _get_default_qconfig_mapping(False, backend, version)
 
 def get_default_qat_qconfig_mapping(backend="fbgemm", version=1) -> QConfigMapping:
+    """
+    Return the default QConfigMapping for quantization aware training.
+
+    Args:
+      * ``backend`` : the quantization backend for the default qconfig mapping, should be
+        one of ["fbgemm", "qnnpack"]
+      * ``version`` : the version for the default qconfig mapping
+    """
     return _get_default_qconfig_mapping(True, backend, version)
 
 
 class QConfigMapping:
     """
-    Mapping from model ops to :class:`torch.ao.quantization.QConfig`s.
+    Mapping from model ops to :class:`torch.ao.quantization.QConfig` s.
 
     The user can specify QConfigs using the following methods (in increasing match priority):
 
-        `set_global`: sets the global (default) QConfig
-        `set_object_type`: sets the QConfig for a given module type, function, or method name
-        `set_module_name_regex`: sets the QConfig for modules matching the given regex string
-        `set_module_name`: sets the QConfig for modules matching the given module name
-        `set_module_name_object_type_order`: sets the QConfig for modules matching a combination
-        of the given module name, object type, and the index at which the module appears
+        ``set_global`` : sets the global (default) QConfig
+        ``set_object_type`` : sets the QConfig for a given module type, function, or method name
+        ``set_module_name_regex`` : sets the QConfig for modules matching the given regex string
+        ``set_module_name`` : sets the QConfig for modules matching the given module name
+        ``set_module_name_object_type_order`` : sets the QConfig for modules matching a combination
+        of the given module name, object type, and the index at which the module appears
 
     Example usage::
 
@@ -150,6 +161,7 @@ class QConfigMapping:
             .set_module_name("module1", qconfig1)
             .set_module_name("module2", qconfig2)
             .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, qconfig3)
+
     """
 
     def __init__(self):
 
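To illustrate the match-priority rule documented above, a minimal sketch (not from the commit; "module1" is a placeholder name, and the default fbgemm qconfig is an illustrative choice):

```python
import torch
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.qconfig_mapping import QConfigMapping

qconfig = get_default_qconfig("fbgemm")

# Later, more specific setters win: the global entry covers everything,
# the object_type entry overrides it for Linear modules, and the
# module_name entry disables quantization for one submodule.
qconfig_mapping = QConfigMapping() \
    .set_global(qconfig) \
    .set_object_type(torch.nn.Linear, qconfig) \
    .set_module_name("module1", None)
```
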
@@ -224,12 +236,16 @@ class QConfigMapping:
     # TODO: remove this
     def to_dict(self) -> Dict[str, Any]:
         """
-        Convert this `QConfigMapping` to a dictionary with the following keys:
+        Convert this ``QConfigMapping`` to a dictionary with the following keys:
 
             "" (for global QConfig)
+
             "object_type"
+
             "module_name_regex"
+
             "module_name"
+
             "module_name_object_type_order"
 
         The values of this dictionary are lists of tuples.
 
@@ -248,12 +264,16 @@ class QConfigMapping:
     @classmethod
     def from_dict(cls, qconfig_dict: Dict[str, Any]) -> QConfigMapping:
         """
-        Create a `QConfigMapping` from a dictionary with the following keys (all optional):
+        Create a ``QConfigMapping`` from a dictionary with the following keys (all optional):
 
             "" (for global QConfig)
+
             "object_type"
+
             "module_name_regex"
+
             "module_name"
+
             "module_name_object_type_order"
 
         The values of this dictionary are expected to be lists of tuples.
 
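A short sketch of the dict round trip described by these two docstrings (not from the commit; the fbgemm default is an illustrative choice):

```python
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.qconfig_mapping import QConfigMapping

qconfig_mapping = QConfigMapping().set_global(get_default_qconfig("fbgemm"))

d = qconfig_mapping.to_dict()           # the "" key holds the global QConfig
restored = QConfigMapping.from_dict(d)  # rebuild an equivalent mapping
```
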
@@ -236,13 +236,9 @@ def fuse_fx(
 
     Args:
 
-        * `model`: a torch.nn.Module model
-        * `fuse_custom_config`: custom configurations for fuse_fx.
-            See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more detail::
-
-                from torch.ao.quantization.fx.custom_config import FuseCustomConfig
-                fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["preserved_attr"])
-
+        * `model` (torch.nn.Module): a torch.nn.Module model
+        * `fuse_custom_config` (FuseCustomConfig): custom configurations for fuse_fx.
+            See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more details
 
     Example::
 
         from torch.ao.quantization import fuse_fx
 
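Expanding the truncated example above into a runnable sketch (not from the commit; the conv/bn model is illustrative):

```python
import torch
from torch.ao.quantization import fuse_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(self.conv(x))

m = fuse_fx(M().eval())   # conv + bn are folded into a single fused module
```
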
@@ -280,52 +276,24 @@ def prepare_fx(
     r""" Prepare a model for post training static quantization
 
     Args:
-      * `model` (required): torch.nn.Module model, must be in eval mode
+      * `model` (torch.nn.Module): torch.nn.Module model
 
-      * `qconfig_mapping` (required): mapping from model ops to qconfigs::
-
-          from torch.quantization import QConfigMapping
-
-          qconfig_mapping = QConfigMapping() \
-              .set_global(global_qconfig) \
-              .set_object_type(torch.nn.Linear, qconfig1) \
-              .set_object_type(torch.nn.functional.linear, qconfig1) \
-              .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1) \
-              .set_module_name_regex("foo.*bar.*", qconfig2) \
-              .set_module_name_regex("foo.*", qconfig3) \
-              .set_module_name("module1", qconfig1) \
-              .set_module_name("module2", qconfig2) \
-              .set_module_name_object_type_order("module3", torch.nn.functional.linear, 0, qconfig3)
-
-          The precedence of different settings:
-          set_global < set_object_type < set_module_name_regex < set_module_name < set_module_name_object_type_order
-
-      * `example_inputs`: (required) Example inputs for forward function of the model
-
-      * `prepare_custom_config`: customization configuration for quantization tool.
-          See :class:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig` for more detail::
-
-              from torch.ao.quantization.fx.custom_config import PrepareCustomConfig
-
-              prepare_custom_config = PrepareCustomConfig() \
-                  .set_standalone_module_name("module1", qconfig_mapping, example_inputs, \
-                      child_prepare_custom_config, backend_config) \
-                  .set_standalone_module_class(MyStandaloneModule, qconfig_mapping, example_inputs, \
-                      child_prepare_custom_config, backend_config) \
-                  .set_float_to_observed_mapping(FloatCustomModule, ObservedCustomModule) \
-                  .set_non_traceable_module_names(["module2", "module3"]) \
-                  .set_non_traceable_module_classes([NonTraceableModule1, NonTraceableModule2]) \
-                  .set_input_quantized_indexes([0]) \
-                  .set_output_quantized_indexes([0]) \
-                  .set_preserved_attributes(["attr1", "attr2"])
+      * `qconfig_mapping` (QConfigMapping): QConfigMapping object to configure how a model is
+        quantized, see :class:`~torch.ao.quantization.qconfig_mapping.QConfigMapping`
+        for more details
+
+      * `example_inputs` (Tuple[Any, ...]): Example inputs for forward function of the model,
+        Tuple of positional args (keyword args can be passed as positional args as well)
+
+      * `prepare_custom_config` (PrepareCustomConfig): customization configuration for quantization tool.
+        See :class:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig` for more details
 
       * `_equalization_config`: config for specifying how to perform equalization on the model
 
-      * `backend_config`: config that specifies how operators are quantized
+      * `backend_config` (BackendConfig): config that specifies how operators are quantized
           in a backend, this includes how the operaetors are observed,
           supported fusion patterns, how quantize/dequantize ops are
-          inserted, supported dtypes etc. The structure of the dictionary is still WIP
-          and will change in the future, please don't use right now.
+          inserted, supported dtypes etc. See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details
 
     Return:
       A GraphModule with observer (configured by qconfig_mapping), ready for calibration
 
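A minimal end-to-end sketch of the prepare step documented here (not from the commit; the one-layer model and fbgemm defaults are illustrative):

```python
import torch
from torch.ao.quantization.qconfig_mapping import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx

model = torch.nn.Sequential(torch.nn.Linear(4, 4)).eval()  # eval mode for PTQ
example_inputs = (torch.randn(1, 4),)                      # tuple of positional args
qconfig_mapping = get_default_qconfig_mapping("fbgemm")

prepared = prepare_fx(model, qconfig_mapping, example_inputs)
prepared(*example_inputs)   # run calibration batches through the observers
```
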
@@ -458,14 +426,14 @@ def prepare_qat_fx(
     r""" Prepare a model for quantization aware training
 
     Args:
-      * `model`: torch.nn.Module model, must be in train mode
-      * `qconfig_mapping`: see :func:`~torch.ao.quantization.prepare_fx`
-      * `example_inputs`: see :func:`~torch.ao.quantization.prepare_fx`
-      * `prepare_custom_config`: see :func:`~torch.ao.quantization.prepare_fx`
-      * `backend_config`: see :func:`~torch.ao.quantization.prepare_fx`
+      * `model` (torch.nn.Module): torch.nn.Module model
+      * `qconfig_mapping` (QConfigMapping): see :func:`~torch.ao.quantization.prepare_fx`
+      * `example_inputs` (Tuple[Any, ...]): see :func:`~torch.ao.quantization.prepare_fx`
+      * `prepare_custom_config` (PrepareCustomConfig): see :func:`~torch.ao.quantization.prepare_fx`
+      * `backend_config` (BackendConfig): see :func:`~torch.ao.quantization.prepare_fx`
 
     Return:
-      A GraphModule with fake quant modules (configured by qconfig_mapping), ready for
+      A GraphModule with fake quant modules (configured by qconfig_mapping and backend_config), ready for
       quantization aware training
 
     Example::
 
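The QAT counterpart, as a hedged sketch (not from the commit; model and defaults are illustrative):

```python
import torch
from torch.ao.quantization.qconfig_mapping import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx

model = torch.nn.Sequential(torch.nn.Linear(4, 4)).train()  # train mode for QAT
example_inputs = (torch.randn(1, 4),)
qconfig_mapping = get_default_qat_qconfig_mapping("fbgemm")

prepared = prepare_qat_fx(model, qconfig_mapping, example_inputs)
# train `prepared` as usual; fake-quant modules simulate int8 during training
```
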
@@ -602,23 +570,14 @@ def convert_fx(
     r""" Convert a calibrated or trained model to a quantized model
 
     Args:
-        * `graph_module`: A prepared and calibrated/trained model (GraphModule)
-        * `is_reference`: flag for whether to produce a reference quantized model,
-          which will be a common interface between pytorch quantization with
-          other backends like accelerators
+        * `graph_module` (torch.fx.GraphModule): A prepared and calibrated/trained model (GraphModule)
 
-        * `convert_custom_config`: custom configurations for convert function.
-            See :class:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig` for more detail::
+        * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
+            See :class:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig` for more details
 
-                from torch.ao.quantization.fx.custom_config import ConvertCustomConfig
-
-                convert_custom_config = ConvertCustomConfig() \
-                    .set_observed_to_quantized_mapping(ObservedCustomModule, QuantizedCustomModule) \
-                    .set_preserved_attributes(["attr1", "attr2"])
-
-        * `_remove_qconfig`: Option to remove the qconfig attributes in the model after convert.
+        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.
 
-        * `qconfig_mapping`: config for specifying how to convert a model for quantization.
+        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
 
         The keys must include the ones in the qconfig_mapping passed to `prepare_fx` or `prepare_qat_fx`,
         with the same values or `None`. Additional keys can be specified with values set to `None`.
 
@@ -631,14 +590,14 @@ def convert_fx(
             .set_object_type(torch.nn.functional.linear, qconfig_from_prepare)
             .set_module_name("foo.bar", None)  # skip quantizing module "foo.bar"
 
-        * `backend_config`: A configuration for the backend which describes how
+        * `backend_config` (BackendConfig): A configuration for the backend which describes how
           operators should be quantized in the backend, this includes quantization
           mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.),
-          observer placement for each operators and fused operators. Detailed
-          documentation can be found in torch/ao/quantization/backend_config/README.md
+          observer placement for each operators and fused operators.
+          See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details
 
     Return:
-        A quantized model (GraphModule)
+        A quantized model (torch.nn.Module)
 
     Example::
 
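Putting prepare, calibration, and convert together, a minimal sketch of the flow this docstring describes (not from the commit; model and defaults are illustrative):

```python
import torch
from torch.ao.quantization.qconfig_mapping import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

model = torch.nn.Sequential(torch.nn.Linear(4, 4)).eval()
example_inputs = (torch.randn(1, 4),)
prepared = prepare_fx(model, get_default_qconfig_mapping("fbgemm"), example_inputs)

with torch.no_grad():
    prepared(*example_inputs)        # calibration pass

quantized = convert_fx(prepared)     # quantized model with int8 ops
```
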
@@ -682,19 +641,19 @@ def convert_to_reference_fx(
     hardware, like accelerators
 
     Args:
-        * `graph_module`: A prepared and calibrated/trained model (GraphModule)
+        * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule)
 
-        * `convert_custom_config`: custom configurations for convert function.
-          See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more detail.
+        * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function.
+          See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
 
-        * `_remove_qconfig`: Option to remove the qconfig attributes in the model after convert.
+        * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert.
 
-        * `qconfig_mapping`: config for specifying how to convert a model for quantization.
-          See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more detail.
+        * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization.
+          See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
 
-        * `backend_config`: A configuration for the backend which describes how
+        * `backend_config` (BackendConfig): A configuration for the backend which describes how
           operators should be quantized in the backend. See
-          :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more detail.
+          :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details.
 
     Return:
         A reference quantized model (GraphModule)
 
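And the reference-model variant, sketched under the same assumptions as the convert_fx example above:

```python
from torch.ao.quantization.quantize_fx import convert_to_reference_fx

# `prepared` is a calibrated model from prepare_fx, as in the sketch above.
# Unlike convert_fx, the result keeps explicit quantize/dequantize ops around
# reference modules instead of lowering to fused int8 kernels.
reference_model = convert_to_reference_fx(prepared)
```
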
@@ -1,3 +1,8 @@
+__all__ = [
+    "get_static_sparse_quantized_mapping",
+    "get_dynamic_sparse_quantized_mapping",
+]
+
 def get_static_sparse_quantized_mapping():
     import torch.ao.nn.sparse
     _static_sparse_quantized_mapping = dict({