diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst index aa8d4e1ec93..be3110addb0 100644 --- a/docs/source/quantization-support.rst +++ b/docs/source/quantization-support.rst @@ -2,7 +2,7 @@ Quantization API Reference ------------------------------- torch.quantization -~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~ This module contains Eager mode quantization APIs. @@ -52,7 +52,7 @@ Utility functions get_observer_dict torch.quantization.quantize_fx -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This module contains FX graph mode quantization APIs (prototype). @@ -68,6 +68,59 @@ This module contains FX graph mode quantization APIs (prototype). convert_fx fuse_fx +torch.ao.quantization.qconfig_mapping +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This module contains QConfigMapping for configuring FX graph mode quantization. + +.. currentmodule:: torch.ao.quantization.qconfig_mapping + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + QConfigMapping + get_default_qconfig_mapping + get_default_qat_qconfig_mapping + +torch.ao.quantization.backend_config +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This module contains BackendConfig, a config object that defines how quantization is supported +in a backend. Currently only used by FX Graph Mode Quantization, but we may extend Eager Mode +Quantization to work with this as well. + +.. currentmodule:: torch.ao.quantization.backend_config + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + BackendConfig + BackendPatternConfig + DTypeConfig + ObservationType + +torch.ao.quantization.fx.custom_config +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This module contains a few CustomConfig classes that are used in both eager mode and FX graph mode quantization. + + +.. currentmodule:: torch.ao.quantization.fx.custom_config + +.. autosummary:: + :toctree: generated + :nosignatures: + :template: classtemplate.rst + + FuseCustomConfig + PrepareCustomConfig + ConvertCustomConfig + StandaloneModuleConfigEntry + torch (quantization related functions) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -127,7 +180,7 @@ regular full-precision tensor. torch.quantization.observer -~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This module contains observers which are used to collect statistics about the values observed during calibration (PTQ) or training (QAT). @@ -160,7 +213,7 @@ the values observed during calibration (PTQ) or training (QAT). default_float_qparams_observer torch.quantization.fake_quantize -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This module implements modules which are used to perform fake quantization during QAT. @@ -189,7 +242,7 @@ during QAT. enable_observer torch.quantization.qconfig -~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This module defines `QConfig` objects which are used to configure quantization settings for individual ops.
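A minimal sketch of how the ``qconfig_mapping`` APIs documented above compose (illustrative, not part of this patch; the ``"fbgemm"`` backend string and the submodule name ``"block1.conv"`` are assumptions for the example)::

    import torch
    from torch.ao.quantization import (
        QConfigMapping,
        get_default_qconfig,
        get_default_qconfig_mapping,
    )

    # Start from the default post training quantization mapping for fbgemm,
    # then pin one submodule to an explicit qconfig by name
    qconfig_mapping = get_default_qconfig_mapping("fbgemm") \
        .set_module_name("block1.conv", get_default_qconfig("fbgemm"))

    # A mapping can also be built from scratch with the same builder pattern
    manual_mapping = QConfigMapping().set_global(get_default_qconfig("fbgemm"))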
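Likewise, a sketch of the ``backend_config`` classes listed above, mirroring the conv/maxpool examples from the ``ObservationType`` docstrings (the dtype and pattern choices here are illustrative, not a real backend definition)::

    import torch
    from torch.ao.quantization.backend_config import (
        BackendConfig,
        BackendPatternConfig,
        DTypeConfig,
        ObservationType,
    )

    quint8_config = DTypeConfig(
        input_dtype=torch.quint8,
        output_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float,
    )

    # conv output is observed separately from its input;
    # maxpool output shares its input's observer instance
    conv_config = BackendPatternConfig(torch.nn.Conv2d) \
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
        .add_dtype_config(quint8_config)
    maxpool_config = BackendPatternConfig(torch.nn.MaxPool2d) \
        .set_observation_type(ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT) \
        .add_dtype_config(quint8_config)

    backend_config = BackendConfig("my_backend") \
        .set_backend_pattern_config(conv_config) \
        .set_backend_pattern_config(maxpool_config)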
diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 14221aff64e..352ffc3625f 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -91,15 +91,6 @@ "Union", "get_combined_dict" ], - "torch.ao.quantization.backend_config.utils": [ - "Any", - "Dict", - "Callable", - "List", - "Union", - "Tuple", - "Pattern" - ], "torch.ao.quantization.backend_config.native": [ "Any", "Dict", diff --git a/torch/ao/quantization/backend_config/backend_config.py b/torch/ao/quantization/backend_config/backend_config.py index 223fc5ae39b..a2bda2250d5 100644 --- a/torch/ao/quantization/backend_config/backend_config.py +++ b/torch/ao/quantization/backend_config/backend_config.py @@ -47,14 +47,21 @@ OVERWRITE_OUTPUT_OBSERVER_DICT_KEY = "overwrite_output_observer" # TODO: maybe rename this to something that's not related to observer # e.g. QParamsType class ObservationType(Enum): - # this means input and output are observed with different observers, based - # on qconfig.activation - # example: conv, linear, softmax + """ An enum that represents the different ways in which an operator/operator pattern + should be observed + """ + OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT = 0 - # this means the output will use the same observer instance as input, based - # on qconfig.activation - # example: torch.cat, maxpool + """this means input and output are observed with different observers, based + on qconfig.activation + example: conv, linear, softmax + """ + OUTPUT_SHARE_OBSERVER_WITH_INPUT = 1 + """this means the output will use the same observer instance as input, based + on qconfig.activation + example: torch.cat, maxpool + """ @dataclass class DTypeConfig: @@ -71,7 +78,7 @@ class DTypeConfig: @classmethod def from_dict(cls, dtype_config_dict: Dict[str, Any]) -> DTypeConfig: """ - Create a `DTypeConfig` from a dictionary with the following items (all optional): + Create a ``DTypeConfig`` from a dictionary with the following items (all optional): "input_dtype": torch.dtype "output_dtype": torch.dtype @@ -88,7 +95,7 @@ class DTypeConfig: def to_dict(self) -> Dict[str, Any]: """ - Convert this `DTypeConfig` to a dictionary with the items described in + Convert this ``DTypeConfig`` to a dictionary with the items described in :func:`~torch.ao.quantization.backend_config.DTypeConfig.from_dict`. """ dtype_config_dict: Dict[str, Any] = {} @@ -107,16 +114,18 @@ class DTypeConfig: class BackendConfig: # TODO: refer to NativeBackendConfig once that is implemented - """ - Config that defines the set of patterns that can be quantized on a given backend, and how reference + """Config that defines the set of patterns that can be quantized on a given backend, and how reference quantized models can be produced from these patterns. A pattern in this context refers to a module, a functional, an operator, or a directed acyclic graph of the above. Each pattern supported on the target backend can be individually configured through :class:`~torch.ao.quantization.backend_config.BackendPatternConfig` in terms of: - (1) The supported input/output activation, weight, and bias data types - (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and - (3) (Optionally) Fusion, QAT, and reference module mappings.
+ + (1) The supported input/output activation, weight, and bias data types + + (2) How observers and quant/dequant ops are inserted in order to construct the reference pattern, and + + (3) (Optionally) Fusion, QAT, and reference module mappings. The format of the patterns is described in: https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/backend_config/README.md @@ -149,6 +158,7 @@ class BackendConfig: backend_config = BackendConfig("my_backend") \ .set_backend_pattern_config(linear_config) \ .set_backend_pattern_config(conv_relu_config) + """ def __init__(self, name: str = ""): self.name = name @@ -181,10 +191,12 @@ class BackendConfig: @classmethod def from_dict(cls, backend_config_dict: Dict[str, Any]) -> BackendConfig: """ - Create a `BackendConfig` from a dictionary with the following items: + Create a ``BackendConfig`` from a dictionary with the following items: "name": the name of the target backend + "configs": a list of dictionaries, each of which represents a ``BackendPatternConfig`` + """ conf = cls(backend_config_dict.get(NAME_DICT_KEY, "")) for d in backend_config_dict.get(CONFIGS_DICT_KEY, []): @@ -198,7 +210,7 @@ class BackendConfig: def to_dict(self) -> Dict[str, Any]: """ - Convert this `BackendConfig` to a dictionary with the items described in + Convert this ``BackendConfig`` to a dictionary with the items described in :func:`~torch.ao.quantization.backend_config.BackendConfig.from_dict`. """ return { @@ -210,19 +222,8 @@ class BackendConfig: class BackendPatternConfig: """ Config for ops defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`. - - The user can configure how a operator pattern graph is handled on a given backend using the following methods: - `set_observation_type`: sets how observers should be inserted for this pattern. - See :class:`~torch.ao.quantization.backend_config.ObservationType` - `add_dtype_config`: add a set of supported data types for this pattern - `set_root_module`: sets the module that represents the root for this pattern - `set_qat_module`: sets the module that represents the QAT implementation for this pattern - `set_reference_quantized_module`: sets the module that represents the reference quantized - implementation for this pattern's root module. - `set_fused_module`: sets the module that represents the fused implementation for this pattern - `set_fuser_method`: sets the function that specifies how to fuse the pattern for this pattern - For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`. + """ def __init__(self, pattern: Pattern): self.pattern = pattern @@ -246,13 +247,15 @@ class BackendPatternConfig: def set_observation_type(self, observation_type: ObservationType) -> BackendPatternConfig: """ Set how observers should be inserted for this pattern. + See :class:`~torch.ao.quantization.backend_config.ObservationType` for details + """ self.observation_type = observation_type return self def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig: """ - Register a set of supported input/output activation, weight, and bias data types for this pattern. + Add a set of supported input/output activation, weight, and bias data types for this pattern.
""" self.dtype_configs.append(dtype_config) return self @@ -333,22 +336,23 @@ class BackendPatternConfig: @classmethod def from_dict(cls, backend_pattern_config_dict: Dict[str, Any]) -> BackendPatternConfig: """ - Create a `BackendPatternConfig` from a dictionary with the following items: + Create a ``BackendPatternConfig`` from a dictionary with the following items: "pattern": the pattern being configured "observation_type": the :class:`~torch.ao.quantization.backend_config.ObservationType` that specifies how - observers should be inserted for this pattern - "dtype_configs": a list of dictionaries that represents :class:`~torch.ao.quantization.backend_config.DTypeConfig`s + observers should be inserted for this pattern + "dtype_configs": a list of dictionaries that represents :class:`~torch.ao.quantization.backend_config.DTypeConfig` s "root_module": a :class:`torch.nn.Module` that represents the root for this pattern "qat_module": a :class:`torch.nn.Module` that represents the QAT implementation for this pattern "reference_quantized_module": a :class:`torch.nn.Module` that represents the reference quantized - implementation for this pattern's root module. + implementation for this pattern's root module. "fused_module": a :class:`torch.nn.Module` that represents the fused implementation for this pattern "fuser_method": a function that specifies how to fuse the pattern for this pattern + """ def _get_dtype_config(obj: Any) -> DTypeConfig: """ - Convert the given object into a `DTypeConfig` if possible, else throw an exception. + Convert the given object into a ``DTypeConfig`` if possible, else throw an exception. """ if isinstance(obj, DTypeConfig): return obj @@ -381,7 +385,7 @@ class BackendPatternConfig: def to_dict(self) -> Dict[str, Any]: """ - Convert this `BackendPatternConfig` to a dictionary with the items described in + Convert this ``BackendPatternConfig`` to a dictionary with the items described in :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.from_dict`. 
""" backend_pattern_config_dict: Dict[str, Any] = { diff --git a/torch/ao/quantization/backend_config/utils.py b/torch/ao/quantization/backend_config/utils.py index 6645cea8010..187fd7e1b70 100644 --- a/torch/ao/quantization/backend_config/utils.py +++ b/torch/ao/quantization/backend_config/utils.py @@ -6,6 +6,21 @@ import torch.nn.functional as F from .backend_config import BackendConfig, DTypeConfig from ..quantization_types import Pattern +__all__ = [ + "get_pattern_to_dtype_configs", + "get_qat_module_classes", + "get_fused_module_classes", + "get_pattern_to_input_type_to_index", + "get_root_module_to_quantized_reference_module", + "get_fuser_method_mapping", + "get_module_to_qat_module", + "get_fusion_pattern_to_root_node_getter", + "get_fusion_pattern_to_extra_inputs_getter", + "remove_boolean_dispatch_from_name", + "pattern_to_human_readable", + "entry_to_pretty_str", +] + def get_pattern_to_dtype_configs(backend_config: BackendConfig) -> Dict[Pattern, List[DTypeConfig]]: pattern_to_dtype_configs: Dict[Pattern, List[DTypeConfig]] = {} for pattern, config in backend_config.configs.items(): diff --git a/torch/ao/quantization/fx/custom_config.py b/torch/ao/quantization/fx/custom_config.py index 1fdde5e51a3..0f5f5bfe8d1 100644 --- a/torch/ao/quantization/fx/custom_config.py +++ b/torch/ao/quantization/fx/custom_config.py @@ -42,18 +42,6 @@ class PrepareCustomConfig: Custom configuration for :func:`~torch.ao.quantization.quantize_fx.prepare_fx` and :func:`~torch.ao.quantization.quantize_fx.prepare_qat_fx`. - The user can set custom configuration using the following methods: - - `set_standalone_module_name`: sets the config for preparing a standalone module for quantization, identified by name - `set_standalone_module_class`: sets the config for preparing a standalone module for quantization, identified by class - `set_float_to_observed_mapping`: sets the mapping from a float module class to an observed module class - `set_non_traceable_module_names`: sets modules that are not symbolically traceable, identified by name - `set_non_traceable_module_classes`: sets modules that are not symbolically traceable, identified by class - `set_input_quantized_indexes`: sets the indexes of the inputs of the graph that should be quantized. - `set_output_quantized_indexes`: sets the indexes of the outputs of the graph that should be quantized. - `set_preserved_attributes`: sets the names of the attributes that will persist in the graph module even - if they are not used in the model's `forward` method - Example usage:: prepare_custom_config = PrepareCustomConfig() \ @@ -86,11 +74,11 @@ class PrepareCustomConfig: prepare_custom_config: Optional[PrepareCustomConfig], backend_config: Optional[BackendConfig]) -> PrepareCustomConfig: """ - Set the configuration for running a standalone module identified by `module_name`. + Set the configuration for running a standalone module identified by ``module_name``. - If `qconfig_mapping` is None, the parent `qconfig_mapping` will be used instead. - If `prepare_custom_config` is None, an empty `PrepareCustomConfig` will be used. - If `backend_config` is None, the parent `backend_config` will be used instead. + If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead. + If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used. + If ``backend_config`` is None, the parent ``backend_config`` will be used instead. 
""" self.standalone_module_names[module_name] = \ StandaloneModuleConfigEntry(qconfig_mapping, example_inputs, prepare_custom_config, backend_config) @@ -104,11 +92,11 @@ class PrepareCustomConfig: prepare_custom_config: Optional[PrepareCustomConfig], backend_config: Optional[BackendConfig]) -> PrepareCustomConfig: """ - Set the configuration for running a standalone module identified by `module_class`. + Set the configuration for running a standalone module identified by ``module_class``. - If `qconfig_mapping` is None, the parent `qconfig_mapping` will be used instead. - If `prepare_custom_config` is None, an empty `PrepareCustomConfig` will be used. - If `backend_config` is None, the parent `backend_config` will be used instead. + If ``qconfig_mapping`` is None, the parent ``qconfig_mapping`` will be used instead. + If ``prepare_custom_config`` is None, an empty ``PrepareCustomConfig`` will be used. + If ``backend_config`` is None, the parent ``backend_config`` will be used instead. """ self.standalone_module_classes[module_class] = \ StandaloneModuleConfigEntry(qconfig_mapping, example_inputs, prepare_custom_config, backend_config) @@ -122,7 +110,7 @@ class PrepareCustomConfig: """ Set the mapping from a custom float module class to a custom observed module class. - The observed module class must have a `from_float` class method that converts the float module class + The observed module class must have a ``from_float`` class method that converts the float module class to the observed module class. This is currently only supported for static quantization. """ if quant_type != QuantType.STATIC: @@ -165,7 +153,7 @@ class PrepareCustomConfig: def set_preserved_attributes(self, attributes: List[str]) -> PrepareCustomConfig: """ Set the names of the attributes that will persist in the graph module even if they are not used in - the model's `forward` method. + the model's ``forward`` method. """ self.preserved_attributes = attributes return self @@ -174,23 +162,23 @@ class PrepareCustomConfig: @classmethod def from_dict(cls, prepare_custom_config_dict: Dict[str, Any]) -> PrepareCustomConfig: """ - Create a `PrepareCustomConfig` from a dictionary with the following items: + Create a ``PrepareCustomConfig`` from a dictionary with the following items: "standalone_module_name": a list of (module_name, qconfig_mapping, example_inputs, - child_prepare_custom_config, backend_config) tuples + child_prepare_custom_config, backend_config) tuples "standalone_module_class" a list of (module_class, qconfig_mapping, example_inputs, - child_prepare_custom_config, backend_config) tuples + child_prepare_custom_config, backend_config) tuples "float_to_observed_custom_module_class": a nested dictionary mapping from quantization - mode to an inner mapping from float module classes to observed module classes, e.g. - {"static": {FloatCustomModule: ObservedCustomModule}} + mode to an inner mapping from float module classes to observed module classes, e.g. 
+ {"static": {FloatCustomModule: ObservedCustomModule}} "non_traceable_module_name": a list of modules names that are not symbolically traceable "non_traceable_module_class": a list of module classes that are not symbolically traceable "input_quantized_idxs": a list of indexes of graph inputs that should be quantized "output_quantized_idxs": a list of indexes of graph outputs that should be quantized - "preserved_attributes": a list of attributes that persist even if they are not used in `forward` + "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` This function is primarily for backward compatibility and may be removed in the future. """ @@ -255,7 +243,7 @@ class PrepareCustomConfig: def to_dict(self) -> Dict[str, Any]: """ - Convert this `PrepareCustomConfig` to a dictionary with the items described in + Convert this ``PrepareCustomConfig`` to a dictionary with the items described in :func:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig.from_dict`. """ def _make_tuple(key: Any, e: StandaloneModuleConfigEntry): @@ -293,12 +281,6 @@ class ConvertCustomConfig: """ Custom configuration for :func:`~torch.ao.quantization.quantize_fx.convert_fx`. - The user can set custom configuration using the following methods: - - `set_observed_to_quantized_mapping`: sets the mapping from an observed module class to a quantized module class - `set_preserved_attributes`: sets the names of the attributes that will persist in the graph module even if they - are not used in the model's `forward` method - Example usage:: convert_custom_config = ConvertCustomConfig() \ @@ -318,7 +300,7 @@ class ConvertCustomConfig: """ Set the mapping from a custom observed module class to a custom quantized module class. - The quantized module class must have a `from_observed` class method that converts the observed module class + The quantized module class must have a ``from_observed`` class method that converts the observed module class to the quantized module class. """ if quant_type not in self.observed_to_quantized_mapping: @@ -329,7 +311,7 @@ class ConvertCustomConfig: def set_preserved_attributes(self, attributes: List[str]) -> ConvertCustomConfig: """ Set the names of the attributes that will persist in the graph module even if they are not used in - the model's `forward` method. + the model's ``forward`` method. """ self.preserved_attributes = attributes return self @@ -338,17 +320,16 @@ class ConvertCustomConfig: @classmethod def from_dict(cls, convert_custom_config_dict: Dict[str, Any]) -> ConvertCustomConfig: """ - Create a `ConvertCustomConfig` from a dictionary with the following items: + Create a ``ConvertCustomConfig`` from a dictionary with the following items: "observed_to_quantized_custom_module_class": a nested dictionary mapping from quantization - mode to an inner mapping from observed module classes to quantized module classes, e.g. 
- { - "static": {FloatCustomModule: ObservedCustomModule}, - "dynamic": {FloatCustomModule: ObservedCustomModule}, - "weight_only": {FloatCustomModule: ObservedCustomModule} - } - - "preserved_attributes": a list of attributes that persist even if they are not used in `forward` + mode to an inner mapping from observed module classes to quantized module classes, e.g.:: + { + "static": {FloatCustomModule: ObservedCustomModule}, + "dynamic": {FloatCustomModule: ObservedCustomModule}, + "weight_only": {FloatCustomModule: ObservedCustomModule} + } + "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` This function is primarily for backward compatibility and may be removed in the future. """ @@ -362,7 +343,7 @@ class ConvertCustomConfig: def to_dict(self) -> Dict[str, Any]: """ - Convert this `ConvertCustomConfig` to a dictionary with the items described in + Convert this ``ConvertCustomConfig`` to a dictionary with the items described in :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`. """ d: Dict[str, Any] = {} @@ -379,11 +360,6 @@ class FuseCustomConfig: """ Custom configuration for :func:`~torch.ao.quantization.quantize_fx.fuse_fx`. - The user can set custom configuration using the following method: - - `set_preserved_attributes`: sets the names of the attributes that will persist in the graph module - even if they are not used in the model's `forward` method - Example usage:: fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["attr1", "attr2"]) @@ -395,7 +371,7 @@ class FuseCustomConfig: def set_preserved_attributes(self, attributes: List[str]) -> FuseCustomConfig: """ Set the names of the attributes that will persist in the graph module even if they are not used in - the model's `forward` method. + the model's ``forward`` method. """ self.preserved_attributes = attributes return self @@ -404,9 +380,9 @@ class FuseCustomConfig: @classmethod def from_dict(cls, fuse_custom_config_dict: Dict[str, Any]) -> FuseCustomConfig: """ - Create a `ConvertCustomConfig` from a dictionary with the following items: + Create a ``ConvertCustomConfig`` from a dictionary with the following items: - "preserved_attributes": a list of attributes that persist even if they are not used in `forward` + "preserved_attributes": a list of attributes that persist even if they are not used in ``forward`` This function is primarily for backward compatibility and may be removed in the future. """ @@ -416,7 +392,7 @@ class FuseCustomConfig: def to_dict(self) -> Dict[str, Any]: """ - Convert this `FuseCustomConfig` to a dictionary with the items described in + Convert this ``FuseCustomConfig`` to a dictionary with the items described in :func:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig.from_dict`. """ d: Dict[str, Any] = {} diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py index b4597222869..efcf0bfab94 100644 --- a/torch/ao/quantization/qconfig_mapping.py +++ b/torch/ao/quantization/qconfig_mapping.py @@ -116,28 +116,39 @@ def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QC def get_default_qconfig_mapping(backend="fbgemm", version=0) -> QConfigMapping: """ Return the default QConfigMapping for post training quantization. 
+ + Args: + * ``backend`` : the quantization backend for the default qconfig mapping, should be + one of ["fbgemm", "qnnpack"] + * ``version`` : the version for the default qconfig mapping """ + # TODO: add assert for backend choices return _get_default_qconfig_mapping(False, backend, version) def get_default_qat_qconfig_mapping(backend="fbgemm", version=1) -> QConfigMapping: """ Return the default QConfigMapping for quantization aware training. + + Args: + * ``backend`` : the quantization backend for the default qconfig mapping, should be + one of ["fbgemm", "qnnpack"] + * ``version`` : the version for the default qconfig mapping """ return _get_default_qconfig_mapping(True, backend, version) class QConfigMapping: """ - Mapping from model ops to :class:`torch.ao.quantization.QConfig`s. + Mapping from model ops to :class:`torch.ao.quantization.QConfig` s. The user can specify QConfigs using the following methods (in increasing match priority): - `set_global`: sets the global (default) QConfig - `set_object_type`: sets the QConfig for a given module type, function, or method name - `set_module_name_regex`: sets the QConfig for modules matching the given regex string - `set_module_name`: sets the QConfig for modules matching the given module name - `set_module_name_object_type_order`: sets the QConfig for modules matching a combination - of the given module name, object type, and the index at which the module appears + ``set_global`` : sets the global (default) QConfig + ``set_object_type`` : sets the QConfig for a given module type, function, or method name + ``set_module_name_regex`` : sets the QConfig for modules matching the given regex string + ``set_module_name`` : sets the QConfig for modules matching the given module name + ``set_module_name_object_type_order`` : sets the QConfig for modules matching a combination + of the given module name, object type, and the index at which the module appears Example usage:: @@ -150,6 +161,7 @@ class QConfigMapping: .set_module_name("module1", qconfig1) .set_module_name("module2", qconfig2) .set_module_name_object_type_order("foo.bar", torch.nn.functional.linear, 0, qconfig3) + """ def __init__(self): @@ -224,12 +236,16 @@ class QConfigMapping: # TODO: remove this def to_dict(self) -> Dict[str, Any]: """ - Convert this `QConfigMapping` to a dictionary with the following keys: + Convert this ``QConfigMapping`` to a dictionary with the following keys: "" (for global QConfig) + "object_type" + "module_name_regex" + "module_name" + "module_name_object_type_order" The values of this dictionary are lists of tuples. @@ -248,12 +264,16 @@ class QConfigMapping: @classmethod def from_dict(cls, qconfig_dict: Dict[str, Any]) -> QConfigMapping: """ - Create a `QConfigMapping` from a dictionary with the following keys (all optional): + Create a ``QConfigMapping`` from a dictionary with the following keys (all optional): "" (for global QConfig) + "object_type" + "module_name_regex" + "module_name" + "module_name_object_type_order" The values of this dictionary are expected to be lists of tuples. diff --git a/torch/ao/quantization/quantize_fx.py b/torch/ao/quantization/quantize_fx.py index 8d572e0c4f2..fb6f3dc1fe5 100644 --- a/torch/ao/quantization/quantize_fx.py +++ b/torch/ao/quantization/quantize_fx.py @@ -236,13 +236,9 @@ def fuse_fx( Args: - * `model`: a torch.nn.Module model - * `fuse_custom_config`: custom configurations for fuse_fx. 
- See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more detail:: - - from torch.ao.quantization.fx.custom_config import FuseCustomConfig - fuse_custom_config = FuseCustomConfig().set_preserved_attributes(["preserved_attr"]) - + * `model` (torch.nn.Module): a torch.nn.Module model + * `fuse_custom_config` (FuseCustomConfig): custom configurations for fuse_fx. + See :class:`~torch.ao.quantization.fx.custom_config.FuseCustomConfig` for more details Example:: from torch.ao.quantization import fuse_fx @@ -280,52 +276,24 @@ def prepare_fx( r""" Prepare a model for post training static quantization Args: - * `model` (required): torch.nn.Module model, must be in eval mode + * `model` (torch.nn.Module): torch.nn.Module model - * `qconfig_mapping` (required): mapping from model ops to qconfigs:: + * `qconfig_mapping` (QConfigMapping): QConfigMapping object to configure how a model is + quantized, see :class:`~torch.ao.quantization.qconfig_mapping.QConfigMapping` + for more details - from torch.quantization import QConfigMapping + * `example_inputs` (Tuple[Any, ...]): Example inputs for the forward function of the model, + a tuple of positional args (keyword args can be passed as positional args as well) - qconfig_mapping = QConfigMapping() \ - .set_global(global_qconfig) \ - .set_object_type(torch.nn.Linear, qconfig1) \ - .set_object_type(torch.nn.functional.linear, qconfig1) \ - .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig1) \ - .set_module_name_regex("foo.*bar.*", qconfig2) \ - .set_module_name_regex("foo.*", qconfig3) \ - .set_module_name("module1", qconfig1) \ - .set_module_name("module2", qconfig2) \ - .set_module_name_object_type_order("module3", torch.nn.functional.linear, 0, qconfig3) - - - The precedence of different settings: - set_global < set_object_type < set_module_name_regex < set_module_name < set_module_name_object_type_order - * `example_inputs`: (required) Example inputs for forward function of the model - - * `prepare_custom_config`: customization configuration for quantization tool. - See :class:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig` for more detail:: - - from torch.ao.quantization.fx.custom_config import PrepareCustomConfig - - prepare_custom_config = PrepareCustomConfig() \ - .set_standalone_module_name("module1", qconfig_mapping, example_inputs, \ - child_prepare_custom_config, backend_config) \ - .set_standalone_module_class(MyStandaloneModule, qconfig_mapping, example_inputs, \ - child_prepare_custom_config, backend_config) \ - .set_float_to_observed_mapping(FloatCustomModule, ObservedCustomModule) \ - .set_non_traceable_module_names(["module2", "module3"]) \ - .set_non_traceable_module_classes([NonTraceableModule1, NonTraceableModule2]) \ - .set_input_quantized_indexes([0]) \ - .set_output_quantized_indexes([0]) \ - .set_preserved_attributes(["attr1", "attr2"]) + * `prepare_custom_config` (PrepareCustomConfig): customization configuration for the quantization tool. + See :class:`~torch.ao.quantization.fx.custom_config.PrepareCustomConfig` for more details * `_equalization_config`: config for specifying how to perform equalization on the model - * `backend_config`: config that specifies how operators are quantized + * `backend_config` (BackendConfig): config that specifies how operators are quantized in a backend, this includes how the operators are observed, supported fusion patterns, how quantize/dequantize ops are - inserted, supported dtypes etc.
The structure of the dictionary is still WIP - and will change in the future, please don't use right now. + inserted, supported dtypes etc. See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details Return: A GraphModule with observer (configured by qconfig_mapping), ready for calibration @@ -458,14 +426,14 @@ def prepare_qat_fx( r""" Prepare a model for quantization aware training Args: - * `model`: torch.nn.Module model, must be in train mode - * `qconfig_mapping`: see :func:`~torch.ao.quantization.prepare_fx` - * `example_inputs`: see :func:`~torch.ao.quantization.prepare_fx` - * `prepare_custom_config`: see :func:`~torch.ao.quantization.prepare_fx` - * `backend_config`: see :func:`~torch.ao.quantization.prepare_fx` + * `model` (torch.nn.Module): torch.nn.Module model + * `qconfig_mapping` (QConfigMapping): see :func:`~torch.ao.quantization.prepare_fx` + * `example_inputs` (Tuple[Any, ...]): see :func:`~torch.ao.quantization.prepare_fx` + * `prepare_custom_config` (PrepareCustomConfig): see :func:`~torch.ao.quantization.prepare_fx` + * `backend_config` (BackendConfig): see :func:`~torch.ao.quantization.prepare_fx` Return: - A GraphModule with fake quant modules (configured by qconfig_mapping), ready for + A GraphModule with fake quant modules (configured by qconfig_mapping and backend_config), ready for quantization aware training Example:: @@ -602,23 +570,14 @@ def convert_fx( r""" Convert a calibrated or trained model to a quantized model Args: - * `graph_module`: A prepared and calibrated/trained model (GraphModule) - * `is_reference`: flag for whether to produce a reference quantized model, - which will be a common interface between pytorch quantization with - other backends like accelerators + * `graph_module` (torch.fx.GraphModule): A prepared and calibrated/trained model (GraphModule) - * `convert_custom_config`: custom configurations for convert function. - See :class:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig` for more detail:: + * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function. + See :class:`~torch.ao.quantization.fx.custom_config.ConvertCustomConfig` for more details - from torch.ao.quantization.fx.custom_config import ConvertCustomConfig + * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert. - convert_custom_config = ConvertCustomConfig() \ - .set_observed_to_quantized_mapping(ObservedCustomModule, QuantizedCustomModule) \ - .set_preserved_attributes(["attr1", "attr2"]) - - * `_remove_qconfig`: Option to remove the qconfig attributes in the model after convert. - - * `qconfig_mapping`: config for specifying how to convert a model for quantization. + * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization. The keys must include the ones in the qconfig_mapping passed to `prepare_fx` or `prepare_qat_fx`, with the same values or `None`. Additional keys can be specified with values set to `None`. 
@@ -631,14 +590,14 @@ def convert_fx( .set_object_type(torch.nn.functional.linear, qconfig_from_prepare) .set_module_name("foo.bar", None) # skip quantizing module "foo.bar" - * `backend_config`: A configuration for the backend which describes how + * `backend_config` (BackendConfig): A configuration for the backend which describes how operators should be quantized in the backend, this includes quantization mode support (static/dynamic/weight_only), dtype support (quint8/qint8 etc.), - observer placement for each operators and fused operators. Detailed - documentation can be found in torch/ao/quantization/backend_config/README.md + observer placement for each operator and fused operator. + See :class:`~torch.ao.quantization.backend_config.BackendConfig` for more details Return: - A quantized model (GraphModule) + A quantized model (torch.nn.Module) Example:: @@ -682,19 +641,19 @@ def convert_to_reference_fx( hardware, like accelerators Args: - * `graph_module`: A prepared and calibrated/trained model (GraphModule) + * `graph_module` (GraphModule): A prepared and calibrated/trained model (GraphModule) - * `convert_custom_config`: custom configurations for convert function. - See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more detail. + * `convert_custom_config` (ConvertCustomConfig): custom configurations for convert function. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. - * `_remove_qconfig`: Option to remove the qconfig attributes in the model after convert. + * `_remove_qconfig` (bool): Option to remove the qconfig attributes in the model after convert. - * `qconfig_mapping`: config for specifying how to convert a model for quantization. - See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more detail. + * `qconfig_mapping` (QConfigMapping): config for specifying how to convert a model for quantization. + See :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. - * `backend_config`: A configuration for the backend which describes how + * `backend_config` (BackendConfig): A configuration for the backend which describes how operators should be quantized in the backend. See - :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more detail. + :func:`~torch.ao.quantization.quantize_fx.convert_fx` for more details. Return: A reference quantized model (GraphModule) diff --git a/torch/ao/sparsity/_mappings.py b/torch/ao/sparsity/_mappings.py index c831b3ddce2..281450bcb29 100644 --- a/torch/ao/sparsity/_mappings.py +++ b/torch/ao/sparsity/_mappings.py @@ -1,3 +1,8 @@ +__all__ = [ + "get_static_sparse_quantized_mapping", + "get_dynamic_sparse_quantized_mapping", +] + def get_static_sparse_quantized_mapping(): import torch.ao.nn.sparse _static_sparse_quantized_mapping = dict({
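Taken together, the revised ``quantize_fx`` docstrings describe the following PTQ flow; a minimal end-to-end sketch under default settings (the toy module and single calibration pass are illustrative)::

    import torch
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.linear(x)

    model = M().eval()  # prepare_fx expects an eval-mode model for PTQ
    example_inputs = (torch.randn(1, 4),)
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")

    prepared = prepare_fx(model, qconfig_mapping, example_inputs)
    prepared(*example_inputs)         # calibration pass(es)
    quantized = convert_fx(prepared)  # produce the quantized model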
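The ``CustomConfig`` classes plug into that flow through the ``prepare_custom_config`` and ``convert_custom_config`` arguments; a sketch using setters named in ``custom_config.py`` (the module and attribute names are placeholders taken from the docstring examples)::

    from torch.ao.quantization.fx.custom_config import (
        ConvertCustomConfig,
        PrepareCustomConfig,
    )

    # "module2"/"module3" and "attr1"/"attr2" are placeholder names that
    # would refer to submodules and attributes of the model being prepared
    prepare_custom_config = PrepareCustomConfig() \
        .set_non_traceable_module_names(["module2", "module3"]) \
        .set_input_quantized_indexes([0]) \
        .set_output_quantized_indexes([0]) \
        .set_preserved_attributes(["attr1", "attr2"])

    convert_custom_config = ConvertCustomConfig() \
        .set_preserved_attributes(["attr1", "attr2"])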
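Each config class also keeps the documented ``from_dict``/``to_dict`` pair for backward compatibility; a round-trip sketch for ``DTypeConfig``, whose four keys are all optional per its docstring::

    import torch
    from torch.ao.quantization.backend_config import DTypeConfig

    dtype_config = DTypeConfig.from_dict({
        "input_dtype": torch.quint8,
        "output_dtype": torch.quint8,
        "weight_dtype": torch.qint8,
        "bias_dtype": torch.float,
    })
    assert dtype_config.to_dict()["weight_dtype"] == torch.qint8  # items round-trip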