# Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-14 20:57:59 +00:00).
#
# Summary: Implemented two observers (InputEqualObserver and WeightEqualObserver)
# which will be inserted into the graph during prepare_fx().
# Test Plan: python test/test_quantization.py TestEqualizeFx
# Reviewed By: supriyar
# Differential Revision: D28836954
# fbshipit-source-id: 25517dc82ae67698ed8b2dc334e3323286976104
#
# 231 lines, 10 KiB, Python
import torch
|
|
import torch.nn as nn
|
|
from torch.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
|
|
|
|
import warnings
|
|
|
|
|
|
class _InputEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of input columns, and
    computing the quantization parameters for the overall min/max input values.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme. Must be ``torch.per_tensor_affine`` or
            ``torch.per_tensor_symmetric``; anything else raises ``TypeError``.
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.
        output_obs: For the user to specify what kind of output observer they
            would like to use. When ``None``, a
            :class:`~torch.quantization.observer.MinMaxObserver` is built with
            the same ``dtype``/``qscheme``/``quant_min``/``quant_max``.
        factory_kwargs: Forwarded unchanged to the observers created here.

    The running minimum/maximum :math:`x_\text{min/max}` are computed in the
    same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`,
    with the difference that the running min/max values are stored per column
    (channel axis 1).

    The qparams are calculated by multiplying the min/max input column values
    with the equalization scale, reducing to find the global min/max input
    values, and then calculating in the same way as in
    :class:`~torch.quantization.observer.MinMaxObserver`

    .. note:: If the running minimum equals to the running maximum, the scales
              and zero_points are set to 1.0 and 0.
    """

    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                 quant_min=None, quant_max=None, output_obs=None,
                 factory_kwargs=None) -> None:
        super(_InputEqualizationObserver, self).__init__()

        if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
            raise TypeError("Input qscheme must be per-tensor")

        # Tracks the running min/max of every input column (ch_axis=1).
        self.input_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
                                                  qscheme=qscheme,
                                                  quant_min=quant_min,
                                                  quant_max=quant_max,
                                                  factory_kwargs=factory_kwargs)

        if output_obs is None:
            self.output_obs = MinMaxObserver(dtype=dtype,
                                             qscheme=qscheme,
                                             quant_min=quant_min,
                                             quant_max=quant_max,
                                             factory_kwargs=factory_kwargs)
        else:
            self.output_obs = output_obs

        # Empty until set_equalization_scale() is called; calculate_qparams()
        # returns default qparams while it is still empty.
        self.equalization_scale = torch.empty(0)

    def forward(self, x_orig):
        r"""Records the per-column min/max of ``x_orig``.

        Returns whatever ``self.input_obs`` returns for ``x_orig``.

        Raises:
            ValueError: if ``x_orig`` is not 2-D (only Linear inputs are supported).
        """
        # TODO: Allow for convolutional layers
        if not (x_orig.ndim == 2):
            raise ValueError("InputEqualizationObserver only supports Linear layers")

        return self.input_obs(x_orig)

    def get_input_minmax(self):
        r"""Returns the ``(min_vals, max_vals)`` recorded per input column."""
        return (self.input_obs.min_vals, self.input_obs.max_vals)

    def set_equalization_scale(self, equalization_scale):
        r"""Stores the per-column equalization scale used by calculate_qparams."""
        self.equalization_scale = equalization_scale

    def calculate_qparams(self):
        r"""
        Returns the scale/zero_point for the equalized (scaled) input.

        Every column's min/max is multiplied by the equalization scale of the
        same column, then reduced to one global min/max from which the qparams
        are computed by ``self.input_obs``.

        Returns:
            A ``(scale, zero_point)`` pair. If no equalization scale has been
            set yet, a warning is emitted and ``(tensor([1.0]), tensor([0]))``
            is returned.
        """
        if self.equalization_scale.nelement() == 0:
            warnings.warn(
                "Must call calculate_equalization_scale and set_equalization_scale "
                "before calling calculate_qparams. Returning default scale and zero point."
            )
            # Same (scale, zero_point) arity as the regular path below.
            return torch.tensor([1.0]), torch.tensor([0])

        # Calculate qparams for the scaled min/max inputs
        # Scale the input by the equalization scale located at the same column
        # index
        (min_inputs, max_inputs) = self.get_input_minmax()
        min_input_scaled = torch.min(torch.mul(min_inputs, self.equalization_scale))
        max_input_scaled = torch.max(torch.mul(max_inputs, self.equalization_scale))
        (scale_input, zero_point_input) = self.input_obs._calculate_qparams(min_input_scaled, max_input_scaled)

        return scale_input, zero_point_input
|
|
|
|
|
|
class _WeightEqualizationObserver(nn.Module):
    r"""Observer for tracking the running min/max values of weight columns and
    rows, and computing the quantization parameters for the weight rows.

    Args:
        dtype: Quantized data type
        qscheme: Quantization scheme
        quant_min: Minimum quantization value. If unspecified, it will
            follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will
            follow the 8-bit setup.
        factory_kwargs: Forwarded unchanged to both internal observers.

    This observer is made up of 2 PerChannelMinMaxObservers
        - weight_col_obs: Used to record the running minimum and maximum of
          columns of incoming weight tensors (ch_axis=1)
        - weight_row_obs: Used to record the running minimum and maximum of
          rows of incoming weight tensors (ch_axis=0)

    The running minimum/maximum :math:`w_\text{min/max}` are computed in the
    same way as :class:`~torch.quantization.observer.PerChannelMinMaxObserver`.

    The qparams are calculated by multiplying the min/max weight row values
    with the inverse of the equalization scale, and then calculating in the same
    way as in :class:`~torch.quantization.observer.PerChannelMinMaxObserver`

    .. note:: If the running minimum equals to the running maximum, the scales
              and zero_points are set to 1.0 and 0.
    """

    def __init__(self, dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=None,
                 quant_max=None, factory_kwargs=None) -> None:
        super(_WeightEqualizationObserver, self).__init__()

        self.weight_col_obs = PerChannelMinMaxObserver(ch_axis=1, dtype=dtype,
                                                       qscheme=qscheme,
                                                       quant_min=quant_min,
                                                       quant_max=quant_max,
                                                       factory_kwargs=factory_kwargs)

        self.weight_row_obs = PerChannelMinMaxObserver(ch_axis=0, dtype=dtype,
                                                       qscheme=qscheme,
                                                       quant_min=quant_min,
                                                       quant_max=quant_max,
                                                       factory_kwargs=factory_kwargs)

        # Empty until set_equalization_scale() is called.
        self.equalization_scale = torch.empty(0)

        # Column index of each row's running min/max; filled in by forward().
        # Initialised here so calculate_qparams() can detect "forward not run
        # yet" instead of failing with AttributeError.
        self.min_weights_ind = None
        self.max_weights_ind = None

    def forward(self, w_orig):
        r"""Records row/column statistics of a 2-D weight and returns it.

        Raises:
            ValueError: if ``w_orig`` is not 2-D (only Linear weights are supported).
        """
        # TODO: Allow for convolutional layers
        if not (w_orig.ndim == 2):
            raise ValueError("WeightEqualizationObserver only supports Linear layers")

        return self._forward(w_orig)

    def _forward(self, w_orig):
        r"""
        Calculates the min/max values of each weight column and weight row,
        and the column index at which each row's running min/max occurs.
        """
        w_orig = self.weight_col_obs(w_orig)
        w_orig = self.weight_row_obs(w_orig)

        # Per-row extremes of this tensor and the first column where they occur.
        cur_min, cur_min_ind = torch.min(w_orig, dim=1)
        cur_max, cur_max_ind = torch.max(w_orig, dim=1)

        if self.min_weights_ind is None or self.max_weights_ind is None:
            self.min_weights_ind = cur_min_ind
            self.max_weights_ind = cur_max_ind
        else:
            # Only move a row's stored index when this tensor attained the
            # row's running extreme; otherwise that extreme (and its column)
            # came from an earlier call and the old index is still correct.
            self.min_weights_ind = torch.where(
                cur_min == self.weight_row_obs.min_vals, cur_min_ind, self.min_weights_ind)
            self.max_weights_ind = torch.where(
                cur_max == self.weight_row_obs.max_vals, cur_max_ind, self.max_weights_ind)

        return w_orig

    def get_weight_col_minmax(self):
        r"""Returns the ``(min_vals, max_vals)`` recorded per weight column."""
        return (self.weight_col_obs.min_vals, self.weight_col_obs.max_vals)

    def get_weight_row_minmax(self):
        r"""Returns the ``(min_vals, max_vals)`` recorded per weight row."""
        return (self.weight_row_obs.min_vals, self.weight_row_obs.max_vals)

    def set_equalization_scale(self, equalization_scale):
        r"""Stores the per-column equalization scale used by calculate_qparams."""
        self.equalization_scale = equalization_scale

    def calculate_qparams(self):
        r"""
        Returns the scale/zero_point for the weight rows.

        Each row's min (max) is divided by the equalization scale of the column
        where that min (max) was found, then the qparams are computed by
        ``self.weight_row_obs``.

        Returns:
            A ``(scale, zero_point)`` pair. If no equalization scale has been
            set, or forward() has not recorded the min/max column indices yet,
            a warning is emitted and ``(tensor([1.0]), tensor([0]))`` is returned.
        """
        if self.equalization_scale.nelement() == 0:
            warnings.warn(
                "Must call calculate_equalization_scale and set_equalization_scale "
                "before calling calculate_qparams. Returning default scale and zero point."
            )
            return torch.tensor([1.0]), torch.tensor([0])

        if self.min_weights_ind is None or self.max_weights_ind is None:
            warnings.warn(
                "Must find the column indices of the minimum and maximum of each row "
                "in the weights in order to calculate the qparams. "
                "Returning default scale and zero point."
            )
            return torch.tensor([1.0]), torch.tensor([0])

        # Calculate the qparams for weights by using the rows
        # Scale the weight rows by the reciprocal of the equalization scale
        # located at the same column index
        (min_weights, max_weights) = self.get_weight_row_minmax()
        min_weights_scaled = torch.mul(min_weights, torch.reciprocal(self.equalization_scale[self.min_weights_ind]))
        max_weights_scaled = torch.mul(max_weights, torch.reciprocal(self.equalization_scale[self.max_weights_ind]))
        (scale_weight, zero_point_weight) = self.weight_row_obs._calculate_qparams(min_weights_scaled, max_weights_scaled)

        return scale_weight, zero_point_weight
|
|
|
|
|
|
def calculate_equalization_scale(input_obs: "_InputEqualizationObserver",
                                 weight_obs: "_WeightEqualizationObserver") -> torch.Tensor:
    r""" Calculates the per-column equalization scale from the observed ranges.

    For every column ``j`` the scale is
    :math:`s_j = \sqrt{(w^{max}_j - w^{min}_j) / (x^{max}_j - x^{min}_j)}`.
    Entries that come out as 0, NaN or +/-inf (e.g. a constant input or weight
    column) are replaced by 1, so such columns are left unscaled.

    The scale is only returned; it is not stored on the observers (use their
    ``set_equalization_scale`` for that).

    Args:
        input_obs: Observer that tracks the ranges for the input columns
        weight_obs: Observer that tracks the ranges for the weight columns

    Returns:
        Tensor of equalization scales, one per column.

    Raises:
        ValueError: if the input and weight column ranges have different shapes.
    """
    (min_inputs, max_inputs) = input_obs.get_input_minmax()
    (min_weights, max_weights) = weight_obs.get_weight_col_minmax()

    if min_inputs.shape != min_weights.shape:
        raise ValueError(
            "Input and Weight must have the same column dimension. "
            f"Found {min_inputs.shape} and {min_weights.shape} instead."
        )

    equalization_scale = torch.sqrt((max_weights - min_weights) / (max_inputs - min_inputs))
    # Degenerate columns (zero input range -> inf/NaN, zero weight range -> 0)
    # would zero out or blow up the data; leave those columns unscaled instead.
    equalization_scale = torch.nan_to_num(equalization_scale, nan=1.0, posinf=1.0, neginf=1.0)
    equalization_scale[equalization_scale == 0.] = 1.0

    return equalization_scale
|