diff --git a/test/quantization/fx/test_equalize_fx.py b/test/quantization/fx/test_equalize_fx.py
new file mode 100644
index 00000000000..e49aaef5113
--- /dev/null
+++ b/test/quantization/fx/test_equalize_fx.py
@@ -0,0 +1,126 @@
+import torch
+from torch.testing._internal.common_quantization import QuantizationTestCase
+from torch.quantization.fx._equalize import (
+    _InputEqualizationObserver, _WeightEqualizationObserver, calculate_equalization_scale
+)
+
+# Third-party numerical library (used to build the reference values)
+import numpy as np
+
+# Testing utils
+from hypothesis import given
+from hypothesis import strategies as st
+
+
+class TestEqualizeFx(QuantizationTestCase):
+    r"""Tests for the input/weight equalization observers."""
+
+    @given(input_qdtype=st.sampled_from((torch.qint8, torch.quint8)),
+           input_qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
+           weight_qdtype=st.sampled_from((torch.qint8, torch.quint8)),
+           weight_qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric,
+                                           torch.per_channel_affine_float_qparams)))
+    def test_input_weight_observer(self, input_qdtype, input_qscheme, weight_qdtype, weight_qscheme):
+        r"""Checks the observers' min/max tracking, the equalization scale and
+        the resulting qparams against numpy reference computations.
+        """
+        input_obs = _InputEqualizationObserver(dtype=input_qdtype, qscheme=input_qscheme)
+        weight_obs = _WeightEqualizationObserver(dtype=weight_qdtype, qscheme=weight_qscheme)
+
+        width = np.random.randint(1, 10)
+        x_height = np.random.randint(2, 10)
+        w_height = np.random.randint(2, 10)
+
+        x = (np.random.random(size=(x_height, width)) * 10).round(decimals=2).astype(np.float32)
+        w = (np.random.random(size=(w_height, width)) * 10).round(decimals=2).astype(np.float32)
+
+        ret_x = input_obs(torch.tensor(x))
+        ret_w = weight_obs(torch.tensor(w))
+        self.assertEqual((ret_x, ret_w), (x, w))
+
+        # Check the min/max input columns are correct
+        ref_min_inputs = x.min(axis=0)
+        ref_max_inputs = x.max(axis=0)
+        self.assertEqual(input_obs.get_input_minmax(), (ref_min_inputs, ref_max_inputs))
+
+        # Check the min/max weight columns are correct
+        ref_min_weights_col = w.min(axis=0)
+        ref_max_weights_col = w.max(axis=0)
+        self.assertEqual(weight_obs.get_weight_col_minmax(), (ref_min_weights_col, ref_max_weights_col))
+
+        # Check the min/max weight rows are correct
+        ref_min_weights_row = w.min(axis=1)
+        ref_max_weights_row = w.max(axis=1)
+        self.assertEqual(weight_obs.get_weight_row_minmax(), (ref_min_weights_row, ref_max_weights_row))
+
+        # Check the column indices of the min/max weight rows are correct
+        ref_min_weights_ind = w.argmin(axis=1)
+        ref_max_weights_ind = w.argmax(axis=1)
+        self.assertEqual((weight_obs.min_weights_ind, weight_obs.max_weights_ind),
+                         (ref_min_weights_ind, ref_max_weights_ind))
+
+        # Check the equalization scale is correct
+        equalization_scale = calculate_equalization_scale(input_obs, weight_obs)
+        ref_equalization_scale = np.sqrt((ref_max_weights_col - ref_min_weights_col) /
+                                         (ref_max_inputs - ref_min_inputs))
+        self.assertEqual(equalization_scale, ref_equalization_scale)
+
+        input_obs.set_equalization_scale(equalization_scale)
+        weight_obs.set_equalization_scale(equalization_scale)
+
+        # check the input scale/zero-point values
+        input_qparams = input_obs.calculate_qparams()
+
+        min_input_scaled = np.min(ref_min_inputs * ref_equalization_scale)
+        min_input_scaled = min(0, min_input_scaled)
+        max_input_scaled = np.max(ref_max_inputs * ref_equalization_scale)
+        max_input_scaled = max(0, max_input_scaled)
+
+        if input_qscheme == torch.per_tensor_symmetric:
+            ref_scale = 2 * max(abs(min_input_scaled), max_input_scaled) / 255
+            ref_zero_point = 0 if input_qdtype is torch.qint8 else 128
+        else:
+            ref_scale = (max_input_scaled - min_input_scaled) / 255
+            ref_zero_point = -128 if input_qdtype is torch.qint8 else 0
+
+        self.assertEqual(input_qparams[0].item(), ref_scale, atol=1e-5, rtol=0)
+        self.assertEqual(input_qparams[1].item(), ref_zero_point)
+
+        # check the weight scale/zero-point values
+        weight_qparams = weight_obs.calculate_qparams()
+
+        min_weights_scaled = ref_min_weights_row * (1 / ref_equalization_scale[ref_min_weights_ind])
+        max_weights_scaled = ref_max_weights_row * (1 / ref_equalization_scale[ref_max_weights_ind])
+
+        if weight_qscheme == torch.per_channel_symmetric:
+            min_weights_scaled = np.minimum(np.zeros(min_weights_scaled.shape), min_weights_scaled)
+            max_weights_scaled = np.maximum(np.zeros(max_weights_scaled.shape), max_weights_scaled)
+
+            ref_scales = 2 * np.maximum(np.abs(min_weights_scaled), max_weights_scaled) / 255
+            ref_zero_points = np.zeros_like(
+                ref_scales) if weight_qdtype is torch.qint8 else np.ones_like(ref_scales) * 128
+        elif weight_qscheme == torch.per_channel_affine_float_qparams:
+            ref_scales = (max_weights_scaled - min_weights_scaled) / 255
+            ref_scales = np.where(ref_scales > 1e-7, ref_scales, np.ones_like(ref_scales))
+            ref_zero_points = -1 * min_weights_scaled / ref_scales
+        else:
+            min_weights_scaled = np.minimum(np.zeros_like(min_weights_scaled), min_weights_scaled)
+            max_weights_scaled = np.maximum(np.zeros_like(max_weights_scaled), max_weights_scaled)
+
+            ref_scales = (max_weights_scaled - min_weights_scaled) / 255
+            ref_zero_points = -128 if weight_qdtype is torch.qint8 else 0
+            ref_zero_points = ref_zero_points - np.round(min_weights_scaled / ref_scales)
+
+        self.assertTrue(torch.allclose(weight_qparams[0], torch.tensor(
+            ref_scales, dtype=weight_qparams[0].dtype), atol=0.0001))
+        self.assertTrue(torch.allclose(weight_qparams[1], torch.tensor(
+            ref_zero_points, dtype=weight_qparams[1].dtype), atol=1))
+
+    def test_calculate_qparams_without_equalization_scale(self):
+        r"""Checks that both observers warn and fall back to a default
+        (scale, zero_point) pair when no equalization scale has been set.
+        """
+        for obs in (_InputEqualizationObserver(), _WeightEqualizationObserver()):
+            with self.assertWarns(UserWarning):
+                qparams = obs.calculate_qparams()
+            self.assertEqual(len(qparams), 2)
diff --git a/test/test_quantization.py b/test/test_quantization.py
index 0821fa1f5a8..80ed466ded3 100644
--- a/test/test_quantization.py
+++ b/test/test_quantization.py
@@ -79,10 +79,15 @@ try:
 except ImportError:
     pass
 
+# Equalization for FX mode
+try:
+    from quantization.fx.test_equalize_fx import TestEqualizeFx  # noqa: F401
+except ImportError:
+    pass
+
 # Backward Compatibility. Tests serialization and BC for quantized modules.
 from quantization.bc.test_backward_compatibility import TestSerialization  # noqa: F401
 
-
 # JIT Graph Mode Quantization
 from quantization.jit.test_quantize_jit import TestQuantizeJit  # noqa: F401
 from quantization.jit.test_quantize_jit import TestQuantizeJitPasses  # noqa: F401
diff --git a/torch/quantization/fx/_equalize.py b/torch/quantization/fx/_equalize.py
new file mode 100644
index 00000000000..fdbf6c977be
--- /dev/null
+++ b/torch/quantization/fx/_equalize.py
@@ -0,0 +1,278 @@
+import torch
+import torch.nn as nn
+from torch.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
+
+import warnings
+
+
+class _InputEqualizationObserver(nn.Module):
+    r"""Observer for tracking the running min/max values of input columns, and
+    computing the quantization parameters for the overall min/max input values.
+
+    Args:
+        dtype: Quantized data type
+        qscheme: Quantization scheme
+        quant_min: Minimum quantization value. If unspecified, it will
+            follow the 8-bit setup.
+        quant_max: Maximum quantization value. If unspecified, it will
+            follow the 8-bit setup.
+        output_obs: For the user to specify what kind of output observer they
+            would like to use
+
+    The running minimum/maximum :math:`x_\text{min/max}` are computed in the
+    same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`,
+    with the difference that the running min/max values are stored per column.
+
+    The qparams are calculated by multiplying the min/max input column values
+    with the equalization scale, reducing to find the global min/max input
+    values, and then calculating in the same way as in
+    :class:`~torch.quantization.observer.MinMaxObserver`
+
+    .. note:: If the running minimum equals to the running maximum, the scales
+              and zero_points are set to 1.0 and 0.
+    """
+
+    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
+                 quant_min=None, quant_max=None, output_obs=None,
+                 factory_kwargs=None) -> None:
+        super(_InputEqualizationObserver, self).__init__()
+
+        if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
+            raise TypeError("Input qscheme must be per-tensor")
+
+        self.input_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
+                                                  qscheme=qscheme,
+                                                  quant_min=quant_min,
+                                                  quant_max=quant_max,
+                                                  factory_kwargs=factory_kwargs)
+
+        if output_obs is None:
+            self.output_obs = MinMaxObserver(dtype=dtype,
+                                             qscheme=qscheme,
+                                             quant_min=quant_min,
+                                             quant_max=quant_max,
+                                             factory_kwargs=factory_kwargs)
+        else:
+            self.output_obs = output_obs
+
+        self.equalization_scale = torch.empty(0)
+
+    def forward(self, x_orig):
+        r"""
+        Passes a 2-dimensional input through the per-column min/max observer
+        and returns that observer's output.
+        """
+        # TODO: Allow for convolutional layers
+        if not (x_orig.ndim == 2):
+            raise ValueError("InputEqualizationObserver only supports Linear layers")
+
+        return self.input_obs(x_orig)
+
+    def get_input_minmax(self):
+        r"""Returns the (min, max) values recorded for the input columns."""
+        return (self.input_obs.min_vals, self.input_obs.max_vals)
+
+    def set_equalization_scale(self, equalization_scale):
+        r"""Stores the equalization scale used by calculate_qparams."""
+        self.equalization_scale = equalization_scale
+
+    def calculate_qparams(self):
+        r"""
+        Returns the (scale, zero_point) pair for the equalized input.
+
+        If no equalization scale has been set, a warning is issued and the
+        default pair (1.0, 0) is returned.
+        """
+
+        if self.equalization_scale.nelement() == 0:
+            warnings.warn(
+                "Must call set_equalization_scale before calling calculate_qparams. "
+                "Returning default scale and zero point."
+            )
+            # Same arity as the regular return value below
+            return torch.tensor([1.0]), torch.tensor([0])
+
+        # Calculate qparams for the scaled min/max inputs
+        # Scale the input by the equalization scale located at the same column
+        # index
+        (min_inputs, max_inputs) = self.get_input_minmax()
+        min_input_scaled = torch.min(torch.mul(min_inputs, self.equalization_scale))
+        max_input_scaled = torch.max(torch.mul(max_inputs, self.equalization_scale))
+        (scale_input, zero_point_input) = self.input_obs._calculate_qparams(min_input_scaled, max_input_scaled)
+
+        return scale_input, zero_point_input
+
+
+class _WeightEqualizationObserver(nn.Module):
+    r"""Observer for tracking the running min/max values of weight columns and
+    rows, and computing the quantization parameters for the weight rows.
+
+    Args:
+        dtype: Quantized data type
+        qscheme: Quantization scheme
+        quant_min: Minimum quantization value. If unspecified, it will
+            follow the 8-bit setup.
+        quant_max: Maximum quantization value. If unspecified, it will
+            follow the 8-bit setup.
+
+    This observer is made up of 2 PerChannelMinMaxObservers
+        - weight_col_obs: Used to record the running minimum and maximum of
+        columns of incoming weight tensors
+        - weight_row_obs: Used to record the running minimum and maximum of
+        rows of incoming weight tensors
+
+    The running minimum/maximum :math:`w_\text{min/max}` are computed in the
+    same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`.
+
+    The qparams are calculated by multiplying the min/max weight row values
+    with the inverse of the equalization scale, and then calculating in the same
+    way as in :class:`~torch.quantization.observer.PerChannelMinMaxObserver`
+
+    .. note:: If the running minimum equals to the running maximum, the scales
+              and zero_points are set to 1.0 and 0.
+    """
+
+    def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=None,
+                 quant_max=None, factory_kwargs=None) -> None:
+        super(_WeightEqualizationObserver, self).__init__()
+
+        self.weight_col_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
+                                                       qscheme=qscheme,
+                                                       quant_min=quant_min,
+                                                       quant_max=quant_max,
+                                                       factory_kwargs=factory_kwargs)
+
+        self.weight_row_obs = PerChannelMinMaxObserver(ch_axis=0, dtype=dtype,
+                                                       qscheme=qscheme,
+                                                       quant_min=quant_min,
+                                                       quant_max=quant_max,
+                                                       factory_kwargs=factory_kwargs)
+
+        self.equalization_scale = torch.empty(0)
+
+        # Column indices of the running min/max of each weight row; they stay
+        # None until the first forward pass so calculate_qparams can check them
+        self.min_weights_ind = None
+        self.max_weights_ind = None
+
+    def forward(self, w_orig):
+        r"""
+        Records the min/max statistics of a 2-dimensional weight and returns
+        the output of the wrapped observers.
+        """
+        # TODO: Allow for convolutional layers
+        if not (w_orig.ndim == 2):
+            raise ValueError("WeightEqualizationObserver only supports Linear layers")
+
+        return self._forward(w_orig)
+
+    def _forward(self, w_orig):
+        r"""
+        Calculates the min/max values of each weight column and weight row,
+        and the column index at which each row attains its running min/max.
+        """
+
+        w_orig = self.weight_col_obs(w_orig)
+        w_orig = self.weight_row_obs(w_orig)
+
+        # Calculate the column indices of the min/max weight in each row
+        (min_weights, max_weights) = self.get_weight_row_minmax()
+        w = w_orig.detach()
+        self.min_weights_ind = self._update_extremum_ind(
+            w, torch.argmin(w, dim=1), min_weights, self.min_weights_ind)
+        self.max_weights_ind = self._update_extremum_ind(
+            w, torch.argmax(w, dim=1), max_weights, self.max_weights_ind)
+
+        return w_orig
+
+    @staticmethod
+    def _update_extremum_ind(w, cur_ind, running_vals, prev_ind):
+        r"""
+        Returns the per-row column indices of the running extremum.
+
+        Rows whose running extremum is attained by the current tensor ``w`` take
+        the index from ``cur_ind``; the other rows keep the index recorded by
+        an earlier forward pass (``prev_ind``).
+        """
+        if prev_ind is None:
+            return cur_ind
+
+        cur_vals = torch.gather(w, 1, cur_ind.unsqueeze(1)).squeeze(1)
+        attained = cur_vals.to(running_vals.dtype) == running_vals
+        return torch.where(attained, cur_ind, prev_ind)
+
+    def get_weight_col_minmax(self):
+        r"""Returns the (min, max) values recorded for the weight columns."""
+        return (self.weight_col_obs.min_vals, self.weight_col_obs.max_vals)
+
+    def get_weight_row_minmax(self):
+        r"""Returns the (min, max) values recorded for the weight rows."""
+        return (self.weight_row_obs.min_vals, self.weight_row_obs.max_vals)
+
+    def set_equalization_scale(self, equalization_scale):
+        r"""Stores the equalization scale used by calculate_qparams."""
+        self.equalization_scale = equalization_scale
+
+    def calculate_qparams(self):
+        r"""
+        Returns the (scale, zero_point) pair for the equalized weight rows.
+
+        If no equalization scale has been set, or no weight has been observed
+        yet, a warning is issued and the default pair (1.0, 0) is returned.
+        """
+
+        if self.equalization_scale.nelement() == 0:
+            warnings.warn(
+                "Must call set_equalization_scale before calling calculate_qparams. "
+                "Returning default scale and zero point."
+            )
+            return torch.tensor([1.0]), torch.tensor([0])
+
+        if self.min_weights_ind is None or self.max_weights_ind is None:
+            warnings.warn(
+                "Must find the column indices of the minimum and maximum of each "
+                "row in the weights in order to calculate the qparams. "
+                "Returning default scale and zero point."
+            )
+            return torch.tensor([1.0]), torch.tensor([0])
+
+        # Calculate the qparams for weights by using the rows
+        # Scale the weight rows by the reciprocal of the equalization scale
+        # located at the same column index
+        (min_weights, max_weights) = self.get_weight_row_minmax()
+        min_weights_scaled = torch.mul(min_weights, torch.reciprocal(self.equalization_scale[self.min_weights_ind]))
+        max_weights_scaled = torch.mul(max_weights, torch.reciprocal(self.equalization_scale[self.max_weights_ind]))
+        (scale_weight, zero_point_weight) = self.weight_row_obs._calculate_qparams(min_weights_scaled, max_weights_scaled)
+
+        return scale_weight, zero_point_weight
+
+
+def calculate_equalization_scale(input_obs: _InputEqualizationObserver,
+                                 weight_obs: _WeightEqualizationObserver) -> torch.Tensor:
+    r""" Calculates the equalization scale from the column ranges recorded by
+    the input and weight observers. The observers are not modified; pass the
+    result to their set_equalization_scale methods.
+
+    Args:
+        input_obs: Observer that tracks the ranges for the input columns
+        weight_obs: Observer that tracks the ranges for the weight columns
+
+    Returns:
+        The per-column scale sqrt((w_max - w_min) / (x_max - x_min))
+
+    Raises:
+        ValueError: If the observers recorded a different number of columns
+    """
+
+    (min_inputs, max_inputs) = input_obs.get_input_minmax()
+    (min_weights, max_weights) = weight_obs.get_weight_col_minmax()
+
+    if not (min_inputs.shape == min_weights.shape):
+        raise ValueError(
+            "Input and Weight must have the same column dimension. " +
+            f"Found {min_inputs.shape} and {min_weights.shape} instead."
+        )
+
+    equalization_scale = torch.sqrt((max_weights - min_weights) / (max_inputs - min_inputs))
+
+    return equalization_scale