[quant] update per-channel observer min/max_val attribute names (#62345)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62345

This PR renames the per-channel observer attributes from min_vals/max_vals to min_val/max_val. The motivation is to keep the attribute names consistent with the per-tensor observers, so that dependencies (like FusedMovingAvgObsFakeQuantize) don't need to differentiate between the two observer types to access the attributes. It also adds backward-compatibility tests to make sure that observers saved earlier with min_vals/max_vals can still be loaded, depending on the state_dict version.

Note: Scriptability of the observers isn't fully supported yet, so this PR does not test for it.

Test Plan: python test/test_quantization.py TestSerialization

Imported from OSS

Reviewed By: HDCharles

Differential Revision: D30003700

fbshipit-source-id: 20e673f1bb15e2b209551b6b9d5f8f3be3f85c0a
This commit is contained in:
parent d92301dd02
commit cfd0f5ebc9

11 changed files with 93 additions and 55 deletions
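For context, a minimal sketch (not part of the diff) of what the rename buys downstream code: both observer families can now be read through the same attribute names, with no branching on observer type. The import path matches the one used by the tests below.

import torch
from torch.quantization import MinMaxObserver, PerChannelMinMaxObserver

x = torch.randn(4, 8)
for obs in (MinMaxObserver(), PerChannelMinMaxObserver()):
    obs(x)  # record running min/max statistics
    # Both observer types expose the same attribute names after this change.
    print(type(obs).__name__, obs.min_val, obs.max_val)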
@@ -14,6 +14,8 @@ import torch.nn.intrinsic.quantized as nniq
 from torch.testing._internal.common_utils import TestCase, IS_AVX512_VNNI_SUPPORTED
 from torch.testing._internal.common_quantized import override_qengines, qengine_is_fbgemm
 
+from torch.quantization import MinMaxObserver, PerChannelMinMaxObserver
+
 def remove_prefix(text, prefix):
     if text.startswith(prefix):
         return text[len(prefix):]
@@ -122,6 +124,24 @@ class TestSerialization(TestCase):
         self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec)
         self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec)
 
+    def _test_obs(self, obs, input_size, subname=None, generate=False):
+        """
+        Test observer code can be loaded from state_dict.
+        """
+        input_file, state_dict_file, _, traced_module_file, expected_file = \
+            get_filenames(self, None)
+        if generate:
+            input_tensor = torch.rand(*input_size).float()
+            torch.save(input_tensor, input_file)
+            torch.save(obs(input_tensor), expected_file)
+            torch.save(obs.state_dict(), state_dict_file)
+
+        input_tensor = torch.load(input_file)
+        obs.load_state_dict(torch.load(state_dict_file))
+        expected = torch.load(expected_file)
+
+        self.assertEqual(obs(input_tensor), expected)
+
     @override_qengines
     def test_linear(self):
         module = nnq.Linear(3, 1, bias_=True, dtype=torch.qint8)
@@ -251,3 +271,11 @@ class TestSerialization(TestCase):
         if qengine_is_fbgemm():
             mod = LSTMModule()
             self._test_op(mod, input_size=[4, 4, 3], input_quantized=False, generate=False, new_zipfile_serialization=True)
+
+    def test_per_channel_observer(self):
+        obs = PerChannelMinMaxObserver()
+        self._test_obs(obs, input_size=[5, 5], generate=False)
+
+    def test_per_tensor_observer(self):
+        obs = MinMaxObserver()
+        self._test_obs(obs, input_size=[5, 5], generate=False)
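The generate flag above follows the usual backward-compatibility fixture pattern: run once with generate=True to write the input, state_dict, and expected-output files, commit them, then assert against them with generate=False on later releases. A hedged standalone sketch of the same round trip (the file name here is hypothetical):

import torch
from torch.quantization import PerChannelMinMaxObserver

obs = PerChannelMinMaxObserver()
obs(torch.rand(5, 5))  # populate the per-channel min_val/max_val buffers
torch.save(obs.state_dict(), "per_channel_observer.pt")  # hypothetical path

restored = PerChannelMinMaxObserver()
restored.load_state_dict(torch.load("per_channel_observer.pt"))
assert torch.equal(restored.min_val, obs.min_val)
assert torch.equal(restored.max_val, obs.max_val)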
@@ -186,8 +186,8 @@ class TestObserver(QuantizationTestCase):
         ]
         per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]]
 
-        self.assertEqual(myobs.min_vals, ref_min_vals[ch_axis])
-        self.assertEqual(myobs.max_vals, ref_max_vals[ch_axis])
+        self.assertEqual(myobs.min_val, ref_min_vals[ch_axis])
+        self.assertEqual(myobs.max_val, ref_max_vals[ch_axis])
         if qscheme == torch.per_channel_symmetric:
             ref_scales = per_channel_symmetric_ref_scales[ch_axis]
             ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128]
@@ -223,8 +223,8 @@ class TestObserver(QuantizationTestCase):
         loaded_obs = PerChannelMinMaxObserver(reduce_range=reduce_range, ch_axis=ch_axis, dtype=qdtype, qscheme=qscheme)
         loaded_obs.load_state_dict(loaded_dict)
         loaded_qparams = loaded_obs.calculate_qparams()
-        self.assertEqual(myobs.min_vals, loaded_obs.min_vals)
-        self.assertEqual(myobs.max_vals, loaded_obs.max_vals)
+        self.assertEqual(myobs.min_val, loaded_obs.min_val)
+        self.assertEqual(myobs.max_val, loaded_obs.max_val)
         self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())
@@ -314,9 +314,7 @@ class TestObserver(QuantizationTestCase):
         Tests that we can save and load state_dict for observers that are scripted
         in a quantized model.
         """
-        obs_list = [MinMaxObserver, MovingAverageMinMaxObserver,
-                    PerChannelMinMaxObserver,
-                    MovingAveragePerChannelMinMaxObserver, HistogramObserver]
+        obs_list = [MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver]
 
         for obs in obs_list:
             model = SingleLayerLinearModel().eval()
@@ -189,8 +189,8 @@ class TestEqualizeFx(QuantizationTestCase):
 
         # Check the min/max weight rows are correct
         ref_min_weights_scaled, ref_max_weights_scaled = self.channel_minmax(ref_w_scaled)
-        self.assertEqual(weight_quant_obs.min_vals, torch.tensor(ref_min_weights_scaled, dtype=torch.float32))
-        self.assertEqual(weight_quant_obs.max_vals, torch.tensor(ref_max_weights_scaled, dtype=torch.float32))
+        self.assertEqual(weight_quant_obs.min_val, torch.tensor(ref_min_weights_scaled, dtype=torch.float32))
+        self.assertEqual(weight_quant_obs.max_val, torch.tensor(ref_max_weights_scaled, dtype=torch.float32))
 
         weight_qparams = weight_quant_obs.calculate_qparams()
6 binary files changed (contents not shown).
@@ -80,7 +80,7 @@ class _InputEqualizationObserver(nn.Module):
         return self.input_obs(x_orig)
 
     def get_input_minmax(self):
-        return (self.input_obs.min_vals, self.input_obs.max_vals)
+        return (self.input_obs.min_val, self.input_obs.max_val)
 
     def set_equalization_scale(self, equalization_scale):
         # Reshape the equalization scale along axis=1 so that it can be
@@ -154,7 +154,7 @@ class _WeightEqualizationObserver(nn.Module):
         return self.weight_col_obs(w_orig)
 
     def get_weight_col_minmax(self):
-        return (self.weight_col_obs.min_vals, self.weight_col_obs.max_vals)
+        return (self.weight_col_obs.min_val, self.weight_col_obs.max_val)
 
     def set_equalization_scale(self, equalization_scale):
         self.equalization_scale = equalization_scale
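These accessors feed the equalization-scale computation. A rough sketch, assuming the sqrt-of-range-ratio formula used by the equalization code (the function name here is illustrative, not the library API):

import torch

def equalization_scale_sketch(min_inputs, max_inputs, min_weights, max_weights):
    # Balance the input range against the weight-column range.
    return torch.sqrt((max_weights - min_weights) / (max_inputs - min_inputs))

scale = equalization_scale_sketch(
    torch.tensor([-1.0]), torch.tensor([1.0]),   # input min/max
    torch.tensor([-4.0]), torch.tensor([4.0]))   # weight column min/max
print(scale)  # tensor([2.]) = sqrt(8 / 2)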
@@ -149,7 +149,9 @@ class _ObserverBase(ObserverBase):
     # Version 3
     # for HistogramObserver only, changed the shape of uninitialized
     # min_val and max_val buffers from torch.Size([0]) to torch.Size([])
-    _version = 2
+    # for PerChannelObservers, changed the name of the buffers from min_vals
+    # to min_val and from max_vals to max_val.
+    _version = 3
 
     eps: torch.Tensor
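The _version bump matters because nn.Module records it in the state_dict metadata, which is what _load_from_state_dict later receives as local_metadata. A minimal sketch, assuming a build that includes this change:

import torch
from torch.quantization import PerChannelMinMaxObserver

obs = PerChannelMinMaxObserver()
sd = obs.state_dict()
# The root module's metadata lives under the empty prefix "".
print(sd._metadata[""]["version"])  # 3 after this change, 2 before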
@@ -606,8 +608,8 @@ class PerChannelMinMaxObserver(_ObserverBase):
     .. note:: If the running minimum equals to the running maximum, the scales
               and zero_points are set to 1.0 and 0.
     """
-    min_vals: torch.Tensor
-    max_vals: torch.Tensor
+    min_val: torch.Tensor
+    max_val: torch.Tensor
 
     def __init__(
         self,
@@ -629,8 +631,8 @@ class PerChannelMinMaxObserver(_ObserverBase):
         )
         factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
         self.ch_axis = ch_axis
-        self.register_buffer("min_vals", torch.tensor([], **factory_kwargs))
-        self.register_buffer("max_vals", torch.tensor([], **factory_kwargs))
+        self.register_buffer("min_val", torch.tensor([], **factory_kwargs))
+        self.register_buffer("max_val", torch.tensor([], **factory_kwargs))
         if (
             self.qscheme == torch.per_channel_symmetric
             and self.reduce_range
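register_buffer, as used here, registers min_val/max_val as non-parameter state: they are saved in the state_dict and moved by .to()/.cuda(), but receive no gradients. A self-contained sketch with a hypothetical toy module:

import torch
import torch.nn as nn

class RunningRange(nn.Module):  # hypothetical, for illustration only
    def __init__(self):
        super().__init__()
        self.register_buffer("min_val", torch.tensor([]))
        self.register_buffer("max_val", torch.tensor([]))

m = RunningRange()
print(list(m.state_dict().keys()))  # ['min_val', 'max_val']
print(list(m.parameters()))         # [] -- buffers are not parameters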
@@ -647,8 +649,8 @@ class PerChannelMinMaxObserver(_ObserverBase):
         if x_orig.numel() == 0:
             return x_orig
         x = x_orig.detach()  # avoid keeping autograd tape
-        min_vals = self.min_vals
-        max_vals = self.max_vals
+        min_val = self.min_val
+        max_val = self.max_val
         x_dim = x.size()
 
         new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
@@ -657,28 +659,27 @@ class PerChannelMinMaxObserver(_ObserverBase):
         y = x.permute(new_axis_list)
         # Need to match dtype of min/max because the updates to buffers
         # are done in place and types need to match for comparisons
-        y = y.to(self.min_vals.dtype)
+        y = y.to(self.min_val.dtype)
         y = torch.flatten(y, start_dim=1)
-        if min_vals.numel() == 0 or max_vals.numel() == 0:
-            min_vals, max_vals = torch._aminmax(y, 1)
+        if min_val.numel() == 0 or max_val.numel() == 0:
+            min_val, max_val = torch._aminmax(y, 1)
         else:
-            min_vals_cur, max_vals_cur = torch._aminmax(y, 1)
-            min_vals = torch.min(min_vals_cur, min_vals)
-            max_vals = torch.max(max_vals_cur, max_vals)
-        self.min_vals.resize_(min_vals.shape)
-        self.max_vals.resize_(max_vals.shape)
-        self.min_vals.copy_(min_vals)
-        self.max_vals.copy_(max_vals)
+            min_val_cur, max_val_cur = torch._aminmax(y, 1)
+            min_val = torch.min(min_val_cur, min_val)
+            max_val = torch.max(max_val_cur, max_val)
+        self.min_val.resize_(min_val.shape)
+        self.max_val.resize_(max_val.shape)
+        self.min_val.copy_(min_val)
+        self.max_val.copy_(max_val)
         return x_orig
 
     @torch.jit.export
     def calculate_qparams(self):
-        return self._calculate_qparams(self.min_vals, self.max_vals)
+        return self._calculate_qparams(self.min_val, self.max_val)
 
     def extra_repr(self):
-        return "min_val={}, max_val={}".format(self.min_vals, self.max_vals)
+        return "min_val={}, max_val={}".format(self.min_val, self.max_val)
 
     @torch.jit.export
     def _load_from_state_dict(
         self,
         state_dict: Union[Dict[str, torch.Tensor], Dict[str, torch.Tensor]],
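The forward pass above computes per-channel extrema by moving ch_axis to the front, flattening the remaining dimensions, and reducing along dim 1. An equivalent sketch using public ops (torch._aminmax is internal; transpose stands in for the permute, and plain min/max stand in for the fused reduction):

import torch

def per_channel_minmax(x: torch.Tensor, ch_axis: int):
    # Move the channel axis to the front, collapse everything else.
    y = x.transpose(0, ch_axis).flatten(start_dim=1)
    return y.min(dim=1).values, y.max(dim=1).values

x = torch.randn(2, 3, 5)
min_val, max_val = per_channel_minmax(x, ch_axis=1)
print(min_val.shape)  # torch.Size([3]) -- one value per channel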
@@ -689,26 +690,38 @@ class PerChannelMinMaxObserver(_ObserverBase):
         unexpected_keys: List[str],
         error_msgs: List[str],
     ):
-        local_state = ["min_vals", "max_vals"]
+        version = local_metadata.get("version", None)
+        if version is None or version < 3:
+            local_state = ["min_vals", "max_vals"]
+            expected_min_name = "min_vals"
+            expected_max_name = "max_vals"
+        else:
+            local_state = ["min_val", "max_val"]
+            expected_min_name = "min_val"
+            expected_max_name = "max_val"
         for name in local_state:
             key = prefix + name
             if key in state_dict:
                 val = state_dict[key]
-                # Custom handling to allow loading min_vals or max_vals
+                # Custom handling to allow loading min_val or max_val
                 # of size N into uninitialized buffers of size 0. The
                 # buffers are resized here, and the values are copied in
                 # the default state_dict loading code of the parent.
-                if name == "min_vals":
-                    self.min_vals.resize_(val.shape)
-                else:
-                    self.max_vals.resize_(val.shape)
+                if name == expected_min_name:
+                    self.min_val.resize_(val.shape)
+                elif name == expected_max_name:
+                    self.max_val.resize_(val.shape)
+                else:
+                    warnings.warn("Observer load_from_state_dict got unexpected name {}".format(name))
                 # For torchscript module we need to update the attributes here since we do not
                 # call the `_load_from_state_dict` function defined module.py
                 if torch.jit.is_scripting():
-                    if name == "min_vals":
-                        self.min_vals.copy_(val)
-                    else:
-                        self.max_vals.copy_(val)
+                    if name == expected_min_name:
+                        self.min_val.copy_(val)
+                    elif name == expected_max_name:
+                        self.max_val.copy_(val)
+                    else:
+                        warnings.warn("Observer load_from_state_dict got unexpected name {}".format(name))
             elif strict:
                 missing_keys.append(key)
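The effect of the version gate, in a hedged sketch: a checkpoint written before this change carries min_vals/max_vals keys and no version-3 metadata, and load_state_dict accepts it without missing- or unexpected-key errors, because the loop above consumes the legacy names and the parent loader is called non-strictly.

import torch
from torch.quantization import PerChannelMinMaxObserver

legacy_sd = {"min_vals": torch.tensor([-1.0, -2.0]),  # pre-rename key names
             "max_vals": torch.tensor([1.0, 2.0])}
obs = PerChannelMinMaxObserver()
obs.load_state_dict(legacy_sd)  # a plain dict carries no version metadata
print(obs.min_val.shape)        # buffer resized to the legacy shape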
@@ -717,13 +730,12 @@ class PerChannelMinMaxObserver(_ObserverBase):
             state_dict,
             prefix,
             local_metadata,
-            strict,
+            False,
             missing_keys,
             unexpected_keys,
             error_msgs,
         )
 
     @torch.jit.export
     def _load_from_state_dict_script(
         self,
         state_dict: Union[Dict[str, torch.Tensor], Dict[str, torch.Tensor]],
@@ -748,8 +760,8 @@ class PerChannelMinMaxObserver(_ObserverBase):
     @torch.jit.export
     def reset_min_max_vals(self):
         """Resets the min/max values."""
-        self.min_vals = torch.tensor([])
-        self.max_vals = torch.tensor([])
+        self.min_val = torch.tensor([])
+        self.max_val = torch.tensor([])
 
 
 class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver):
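Note that the method name itself stays plural while the buffers it clears are now singular. Clearing returns the observer to its uninitialized state, e.g. to re-calibrate on fresh data; a small sketch:

import torch
from torch.quantization import PerChannelMinMaxObserver

obs = PerChannelMinMaxObserver()
obs(torch.randn(4, 8))
print(obs.min_val.numel())  # 4 -- one entry per channel
obs.reset_min_max_vals()    # discard statistics before re-calibration
print(obs.min_val.numel())  # 0 -- back to the uninitialized state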
@@ -805,9 +817,9 @@ class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver):
         if x_orig.numel() == 0:
             return x_orig
         x = x_orig.detach()  # avoid keeping autograd tape
-        x = x.to(self.min_vals.dtype)
-        min_vals = self.min_vals
-        max_vals = self.max_vals
+        x = x.to(self.min_val.dtype)
+        min_val = self.min_val
+        max_val = self.max_val
         x_dim = x.size()
 
         new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
@@ -815,16 +827,16 @@ class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver):
         new_axis_list[0] = self.ch_axis
         y = x.permute(new_axis_list)
         y = torch.flatten(y, start_dim=1)
-        if min_vals.numel() == 0 or max_vals.numel() == 0:
-            min_vals, max_vals = torch._aminmax(y, 1)
+        if min_val.numel() == 0 or max_val.numel() == 0:
+            min_val, max_val = torch._aminmax(y, 1)
         else:
-            min_vals_cur, max_vals_cur = torch._aminmax(y, 1)
-            min_vals = min_vals + self.averaging_constant * (min_vals_cur - min_vals)
-            max_vals = max_vals + self.averaging_constant * (max_vals_cur - max_vals)
-        self.min_vals.resize_(min_vals.shape)
-        self.max_vals.resize_(max_vals.shape)
-        self.min_vals.copy_(min_vals)
-        self.max_vals.copy_(max_vals)
+            min_val_cur, max_val_cur = torch._aminmax(y, 1)
+            min_val = min_val + self.averaging_constant * (min_val_cur - min_val)
+            max_val = max_val + self.averaging_constant * (max_val_cur - max_val)
+        self.min_val.resize_(min_val.shape)
+        self.max_val.resize_(max_val.shape)
+        self.min_val.copy_(min_val)
+        self.max_val.copy_(max_val)
         return x_orig
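The moving-average variant replaces the hard running min/max with an exponential moving average of the per-channel extrema. One update step, worked out numerically:

import torch

averaging_constant = 0.01
min_val = torch.tensor([-1.0, -2.0])      # running per-channel minima
min_val_cur = torch.tensor([-3.0, -1.0])  # minima of the current batch
# new = old + c * (cur - old), i.e. torch.lerp(old, cur, c)
min_val = min_val + averaging_constant * (min_val_cur - min_val)
print(min_val)  # tensor([-1.0200, -1.9900])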