diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst
index 2bbc7163b53..e974df655af 100644
--- a/docs/source/quantization-support.rst
+++ b/docs/source/quantization-support.rst
@@ -99,6 +99,7 @@ Quantization to work with this as well.
     BackendConfig
     BackendPatternConfig
     DTypeConfig
+    DTypeWithConstraints
     ObservationType
 
 torch.ao.quantization.fx.custom_config
diff --git a/torch/ao/quantization/backend_config/backend_config.py b/torch/ao/quantization/backend_config/backend_config.py
index af1ac554de2..faf2fd03ade 100644
--- a/torch/ao/quantization/backend_config/backend_config.py
+++ b/torch/ao/quantization/backend_config/backend_config.py
@@ -70,6 +70,25 @@ class DTypeWithConstraints:
     Config for specifying additional constraints for a given dtype, such as
     quantization value ranges, scale value ranges, and fixed quantization params,
     to be used in :class:`~torch.ao.quantization.backend_config.DTypeConfig`.
+
+    The constraints currently supported are:
+
+    * `quant_min_lower_bound` and `quant_max_upper_bound`: Lower and upper
+      bounds for the minimum and maximum quantized values respectively. If
+      the QConfig’s `quant_min` and `quant_max` fall outside this range,
+      then the QConfig will be ignored.
+
+    * `scale_min_lower_bound` and `scale_max_upper_bound`: Lower and upper
+      bounds for the minimum and maximum scale values respectively. If the
+      QConfig’s minimum scale value (currently exposed as `eps`) falls below
+      the lower bound, then the QConfig will be ignored. Note that the upper
+      bound is currently not enforced.
+
+    * `scale_exact_match` and `zero_point_exact_match`: Exact match requirements
+      for scale and zero point, to be used for operators with fixed quantization
+      parameters such as sigmoid and tanh. If the observer specified in the QConfig
+      is neither `FixedQParamsObserver` nor `FixedQParamsFakeQuantize`, or if
+      the quantization parameters don't match, then the QConfig will be ignored.
     """
     dtype: Optional[torch.dtype] = None
     quant_min_lower_bound: Union[int, float, None] = None
@@ -83,8 +102,35 @@ class DTypeWithConstraints:
 @dataclass
 class DTypeConfig:
     """
-    Config for the set of supported input/output activation, weight, and bias data types for the
-    patterns defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`.
+    Config object that specifies the supported data types passed as arguments to
+    quantize ops in the reference model spec, for input and output activations,
+    weights, and biases.
+
+    For example, consider the following reference model:
+
+      quant1 - [dequant1 - fp32_linear - quant2] - dequant2
+
+    The pattern in the square brackets refers to the reference pattern of
+    statically quantized linear. Setting the input dtype as `torch.quint8`
+    in the DTypeConfig means we pass in `torch.quint8` as the dtype argument
+    to the first quantize op (quant1). Similarly, setting the output dtype as
+    `torch.quint8` means we pass in `torch.quint8` as the dtype argument to
+    the second quantize op (quant2).
+
+    Note that the dtype here does not refer to the interface dtypes of the
+    op. For example, the "input dtype" here is not the dtype of the input
+    tensor passed to the quantized linear op. Though it can still be the
+    same as the interface dtype, this is not always the case, e.g. the
+    interface dtype is fp32 in dynamic quantization but the "input dtype"
+    specified in the DTypeConfig would still be quint8. The semantics of
+    dtypes here are the same as the semantics of the dtypes specified in
+    the observers.
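+
+    For example, in dynamic quantization the fp32 activation is quantized
+    on the fly at runtime, so the interface dtype of the linear op is fp32
+    while the "input dtype" here is still a quantized dtype. A minimal
+    sketch of such a config, using the `is_dynamic` field to mark the
+    activation as dynamically quantized (the variable name is illustrative)::
+
+        import torch
+        from torch.ao.quantization.backend_config import DTypeConfig
+
+        dynamic_int8_dtype_config = DTypeConfig(
+            input_dtype=torch.quint8,
+            output_dtype=torch.float,
+            weight_dtype=torch.qint8,
+            bias_dtype=torch.float,
+            is_dynamic=True)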
+
+    These dtypes are matched against the ones specified in the user’s
+    QConfig. If there is a match, and the QConfig satisfies the constraints
+    specified in the DTypeConfig (if any), then we will quantize the given
+    pattern using this DTypeConfig. Otherwise, the QConfig is ignored and
+    the pattern will not be quantized.
 
     Example usage::
 
@@ -353,9 +399,8 @@ class BackendConfig:
 
 class BackendPatternConfig:
     """
-    Config for ops defined in :class:`~torch.ao.quantization.backend_config.BackendConfig`.
+    Config object that specifies quantization behavior for a given operator pattern.
     For a detailed example usage, see :class:`~torch.ao.quantization.backend_config.BackendConfig`.
-
     """
     def __init__(self, pattern: Optional[Pattern] = None):
         self.pattern: Optional[Pattern] = pattern
@@ -401,31 +446,39 @@ class BackendPatternConfig:
     def set_observation_type(self, observation_type: ObservationType) -> BackendPatternConfig:
         """
         Set how observers should be inserted in the graph for this pattern.
+
+        Observation type here refers to how observers (or quant-dequant ops) will be placed
+        in the graph. This is used to produce the desired reference patterns understood by
+        the backend. Weighted ops such as linear and conv require different observers
+        (or quantization parameters passed to quantize ops in the reference model) for the
+        input and the output.
+
         There are two observation types:
 
-        `OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` (default): the output observer instance will be
-        different from the input. This is the most common observation type.
+        `OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT` (default): the output observer instance
+        will be different from the input. This is the most common observation type.
 
-        `OUTPUT_SHARE_OBSERVER_WITH_INPUT`: the output observer instance will be the same as the input.
-        This is useful for operators like `cat`.
+        `OUTPUT_SHARE_OBSERVER_WITH_INPUT`: the output observer instance will be the
+        same as the input. This is useful for operators like `cat`.
 
-        Note: This will be renamed in the near future, since we will soon insert QuantDeQuantStubs with
-        observers (and fake quantizes) attached instead of observers themselves.
+        Note: This will be renamed in the near future, since we will soon insert QuantDeQuantStubs
+        with observers (and fake quantizes) attached instead of observers themselves.
         """
         self.observation_type = observation_type
         return self
 
     def add_dtype_config(self, dtype_config: DTypeConfig) -> BackendPatternConfig:
         """
-        Add a set of supported input/output activation, weight, and bias data types for this pattern.
+        Add a set of supported data types passed as arguments to quantize ops in the
+        reference model spec.
         """
         self.dtype_configs.append(dtype_config)
         return self
 
     def set_dtype_configs(self, dtype_configs: List[DTypeConfig]) -> BackendPatternConfig:
         """
-        Set the supported input/output activation, weight, and bias data types for this pattern,
-        overriding all previously registered data types.
+        Set the supported data types passed as arguments to quantize ops in the
+        reference model spec, overriding all previously registered data types.
         """
         self.dtype_configs = dtype_configs
         return self
@@ -433,7 +486,15 @@ class BackendPatternConfig:
     def set_root_module(self, root_module: Type[torch.nn.Module]) -> BackendPatternConfig:
         """
         Set the module that represents the root for this pattern.
-        For example, the root module for :class:`torch.nn.intrinsic.LinearReLU` should be :class:`torch.nn.Linear`.
+
+        When we construct the reference quantized model during the convert phase,
+        the root modules (e.g. torch.nn.Linear for torch.ao.nn.intrinsic.LinearReLU)
+        will be swapped to the corresponding reference quantized modules (e.g.
+        torch.ao.nn.quantized.reference.Linear). This allows custom backends to
+        specify custom reference quantized module implementations to match the
+        numerics of their lowered operators. Since this is a one-to-one mapping,
+        both the root module and the reference quantized module must be specified
+        in the same BackendPatternConfig in order for the conversion to take place.
         """
         self.root_module = root_module
         return self
@@ -447,7 +508,10 @@ class BackendPatternConfig:
 
     def set_reference_quantized_module(self, reference_quantized_module: Type[torch.nn.Module]) -> BackendPatternConfig:
         """
-        Set the module that represents the reference quantized implementation for this pattern's root module.
+        Set the module that represents the reference quantized implementation for
+        this pattern's root module.
+
+        For more detail, see :func:`~torch.ao.quantization.backend_config.BackendPatternConfig.set_root_module`.
         """
         self.reference_quantized_module = reference_quantized_module
         return self
@@ -461,7 +525,7 @@ class BackendPatternConfig:
 
     def set_fuser_method(self, fuser_method: Callable) -> BackendPatternConfig:
         """
-        Set the function that specifies how to fuse the pattern for this pattern.
+        Set the function that specifies how to fuse this BackendPatternConfig's pattern.
 
         The first argument of this function should be `is_qat`, and the rest of the arguments
         should be the items in the tuple pattern. The return value of this function should be
@@ -471,6 +535,8 @@ class BackendPatternConfig:
 
             def fuse_linear_relu(is_qat, linear, relu):
                 return torch.ao.nn.intrinsic.LinearReLU(linear, relu)
+
+        For a more complicated example, see https://gist.github.com/jerryzh168/8bea7180a8ba3c279f2c9b050f2a69a6.
         """
         self.fuser_method = fuser_method
         return self
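
As a companion to the constraint fields described in the `DTypeWithConstraints`
docstring above, the following is a minimal sketch of a fixed-qparams config.
The variable names are illustrative, attaching the constraints to the output
activation is an assumption, and the values assume the standard quint8 fixed
qparams used for sigmoid (scale = 2 ** -8, zero_point = 0)::

    import torch
    from torch.ao.quantization.backend_config import (
        DTypeConfig,
        DTypeWithConstraints,
    )

    # Constraints for an op with fixed quantization parameters, e.g. sigmoid.
    # A QConfig is compatible only if its activation observer is
    # FixedQParamsObserver or FixedQParamsFakeQuantize with matching qparams;
    # otherwise this dtype config will not match and the QConfig is ignored.
    sigmoid_constraints = DTypeWithConstraints(
        dtype=torch.quint8,
        quant_min_lower_bound=0,
        quant_max_upper_bound=255,
        scale_exact_match=2 ** -8,
        zero_point_exact_match=0,
    )

    # Attach the constraints to the output activation dtype (assumed placement).
    fixed_qparams_dtype_config = DTypeConfig(
        input_dtype=torch.quint8,
        output_dtype=sigmoid_constraints,
    )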
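Similarly, the `BackendPatternConfig` setters documented above compose into a
full backend configuration. The following sketch mirrors the structure of the
`BackendConfig` example referenced in the docstrings; the backend name and the
variable names are illustrative::

    import torch
    from torch.ao.quantization.backend_config import (
        BackendConfig,
        BackendPatternConfig,
        DTypeConfig,
        ObservationType,
    )

    weighted_int8_dtype_config = DTypeConfig(
        input_dtype=torch.quint8,
        output_dtype=torch.quint8,
        weight_dtype=torch.qint8,
        bias_dtype=torch.float)

    def fuse_linear_relu(is_qat, linear, relu):
        return torch.ao.nn.intrinsic.LinearReLU(linear, relu)

    # Fuse (Linear, ReLU) into LinearReLU during the fusion phase.
    linear_relu_fusion_config = BackendPatternConfig((torch.nn.Linear, torch.nn.ReLU)) \
        .set_dtype_configs([weighted_int8_dtype_config]) \
        .set_fuser_method(fuse_linear_relu) \
        .set_fused_module(torch.ao.nn.intrinsic.LinearReLU)

    # Quantize LinearReLU: observe the input and output with different
    # observers, and swap the root nn.Linear to the reference quantized
    # linear during the convert phase.
    linear_relu_quant_config = BackendPatternConfig(torch.ao.nn.intrinsic.LinearReLU) \
        .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
        .add_dtype_config(weighted_int8_dtype_config) \
        .set_root_module(torch.nn.Linear) \
        .set_reference_quantized_module(torch.ao.nn.quantized.reference.Linear)

    backend_config = BackendConfig("my_backend") \
        .set_backend_pattern_config(linear_relu_fusion_config) \
        .set_backend_pattern_config(linear_relu_quant_config)

During prepare, the dtype configs and observation type drive how observers are
inserted for each matched pattern; during convert, the root module is swapped
for the reference quantized module to produce the reference pattern the backend
can lower.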