diff --git a/test/quantization/bc/test_backward_compatibility.py b/test/quantization/bc/test_backward_compatibility.py
index 65253869ddc..8b488f56fd0 100644
--- a/test/quantization/bc/test_backward_compatibility.py
+++ b/test/quantization/bc/test_backward_compatibility.py
@@ -14,6 +14,8 @@ import torch.nn.intrinsic.quantized as nniq
 from torch.testing._internal.common_utils import TestCase, IS_AVX512_VNNI_SUPPORTED
 from torch.testing._internal.common_quantized import override_qengines, qengine_is_fbgemm
 
+from torch.quantization import MinMaxObserver, PerChannelMinMaxObserver
+
 def remove_prefix(text, prefix):
     if text.startswith(prefix):
         return text[len(prefix):]
@@ -122,6 +124,24 @@ class TestSerialization(TestCase):
         self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec)
         self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec)
 
+    def _test_obs(self, obs, input_size, subname=None, generate=False):
+        """
+        Test observer code can be loaded from state_dict.
+        """
+        input_file, state_dict_file, _, traced_module_file, expected_file = \
+            get_filenames(self, None)
+        if generate:
+            input_tensor = torch.rand(*input_size).float()
+            torch.save(input_tensor, input_file)
+            torch.save(obs(input_tensor), expected_file)
+            torch.save(obs.state_dict(), state_dict_file)
+
+        input_tensor = torch.load(input_file)
+        obs.load_state_dict(torch.load(state_dict_file))
+        expected = torch.load(expected_file)
+
+        self.assertEqual(obs(input_tensor), expected)
+
     @override_qengines
     def test_linear(self):
         module = nnq.Linear(3, 1, bias_=True, dtype=torch.qint8)
@@ -251,3 +271,11 @@ class TestSerialization(TestCase):
         if qengine_is_fbgemm():
             mod = LSTMModule()
             self._test_op(mod, input_size=[4, 4, 3], input_quantized=False, generate=False, new_zipfile_serialization=True)
+
+    def test_per_channel_observer(self):
+        obs = PerChannelMinMaxObserver()
+        self._test_obs(obs, input_size=[5, 5], generate=False)
+
+    def test_per_tensor_observer(self):
+        obs = MinMaxObserver()
+        self._test_obs(obs, input_size=[5, 5], generate=False)
diff --git a/test/quantization/core/test_workflow_module.py b/test/quantization/core/test_workflow_module.py
index e728fb54e8f..2115515b581 100644
--- a/test/quantization/core/test_workflow_module.py
+++ b/test/quantization/core/test_workflow_module.py
@@ -186,8 +186,8 @@ class TestObserver(QuantizationTestCase):
         ]
         per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]]
 
-        self.assertEqual(myobs.min_vals, ref_min_vals[ch_axis])
-        self.assertEqual(myobs.max_vals, ref_max_vals[ch_axis])
+        self.assertEqual(myobs.min_val, ref_min_vals[ch_axis])
+        self.assertEqual(myobs.max_val, ref_max_vals[ch_axis])
         if qscheme == torch.per_channel_symmetric:
             ref_scales = per_channel_symmetric_ref_scales[ch_axis]
             ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128]
@@ -223,8 +223,8 @@ class TestObserver(QuantizationTestCase):
         loaded_obs = PerChannelMinMaxObserver(reduce_range=reduce_range, ch_axis=ch_axis, dtype=qdtype, qscheme=qscheme)
         loaded_obs.load_state_dict(loaded_dict)
         loaded_qparams = loaded_obs.calculate_qparams()
-        self.assertEqual(myobs.min_vals, loaded_obs.min_vals)
-        self.assertEqual(myobs.max_vals, loaded_obs.max_vals)
+        self.assertEqual(myobs.min_val, loaded_obs.min_val)
+        self.assertEqual(myobs.max_val, loaded_obs.max_val)
         self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())
 
@@ -314,9 +314,7 @@ class TestObserver(QuantizationTestCase):
         Tests that we can save and load state_dict for observers that are
         scripted in a quantized model.
         """
-        obs_list = [MinMaxObserver, MovingAverageMinMaxObserver,
-                    PerChannelMinMaxObserver,
-                    MovingAveragePerChannelMinMaxObserver, HistogramObserver]
+        obs_list = [MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver]
 
         for obs in obs_list:
             model = SingleLayerLinearModel().eval()
diff --git a/test/quantization/fx/test_equalize_fx.py b/test/quantization/fx/test_equalize_fx.py
index b11f1a196bf..384584c9830 100644
--- a/test/quantization/fx/test_equalize_fx.py
+++ b/test/quantization/fx/test_equalize_fx.py
@@ -189,8 +189,8 @@ class TestEqualizeFx(QuantizationTestCase):
 
         # Check the min/max weight rows are correct
         ref_min_weights_scaled, ref_max_weights_scaled = self.channel_minmax(ref_w_scaled)
-        self.assertEqual(weight_quant_obs.min_vals, torch.tensor(ref_min_weights_scaled, dtype=torch.float32))
-        self.assertEqual(weight_quant_obs.max_vals, torch.tensor(ref_max_weights_scaled, dtype=torch.float32))
+        self.assertEqual(weight_quant_obs.min_val, torch.tensor(ref_min_weights_scaled, dtype=torch.float32))
+        self.assertEqual(weight_quant_obs.max_val, torch.tensor(ref_max_weights_scaled, dtype=torch.float32))
 
         weight_qparams = weight_quant_obs.calculate_qparams()
diff --git a/test/quantization/serialized/TestSerialization.test_per_channel_observer.expected.pt b/test/quantization/serialized/TestSerialization.test_per_channel_observer.expected.pt
new file mode 100644
index 00000000000..89e93848a6d
Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_per_channel_observer.expected.pt differ
diff --git a/test/quantization/serialized/TestSerialization.test_per_channel_observer.input.pt b/test/quantization/serialized/TestSerialization.test_per_channel_observer.input.pt
new file mode 100644
index 00000000000..89e93848a6d
Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_per_channel_observer.input.pt differ
diff --git a/test/quantization/serialized/TestSerialization.test_per_channel_observer.state_dict.pt b/test/quantization/serialized/TestSerialization.test_per_channel_observer.state_dict.pt
new file mode 100644
index 00000000000..70eed70d0b9
Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_per_channel_observer.state_dict.pt differ
diff --git a/test/quantization/serialized/TestSerialization.test_per_tensor_observer.expected.pt b/test/quantization/serialized/TestSerialization.test_per_tensor_observer.expected.pt
new file mode 100644
index 00000000000..89e93848a6d
Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_per_tensor_observer.expected.pt differ
diff --git a/test/quantization/serialized/TestSerialization.test_per_tensor_observer.input.pt b/test/quantization/serialized/TestSerialization.test_per_tensor_observer.input.pt
new file mode 100644
index 00000000000..89e93848a6d
Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_per_tensor_observer.input.pt differ
diff --git a/test/quantization/serialized/TestSerialization.test_per_tensor_observer.state_dict.pt b/test/quantization/serialized/TestSerialization.test_per_tensor_observer.state_dict.pt
new file mode 100644
index 00000000000..5f3817504e3
Binary files /dev/null and b/test/quantization/serialized/TestSerialization.test_per_tensor_observer.state_dict.pt differ
diff --git a/torch/quantization/fx/_equalize.py b/torch/quantization/fx/_equalize.py
index 6b3080868fc..caf6f852f03 100644
--- a/torch/quantization/fx/_equalize.py
+++ b/torch/quantization/fx/_equalize.py
@@ -80,7 +80,7 @@ class _InputEqualizationObserver(nn.Module):
         return self.input_obs(x_orig)
 
     def get_input_minmax(self):
-        return (self.input_obs.min_vals, self.input_obs.max_vals)
+        return (self.input_obs.min_val, self.input_obs.max_val)
 
     def set_equalization_scale(self, equalization_scale):
         # Reshape the equalization scale along axis=1 so that it can be
@@ -154,7 +154,7 @@ class _WeightEqualizationObserver(nn.Module):
         return self.weight_col_obs(w_orig)
 
     def get_weight_col_minmax(self):
-        return (self.weight_col_obs.min_vals, self.weight_col_obs.max_vals)
+        return (self.weight_col_obs.min_val, self.weight_col_obs.max_val)
 
     def set_equalization_scale(self, equalization_scale):
         self.equalization_scale = equalization_scale
diff --git a/torch/quantization/observer.py b/torch/quantization/observer.py
index 3e76932eeea..0cccf12b273 100644
--- a/torch/quantization/observer.py
+++ b/torch/quantization/observer.py
@@ -149,7 +149,9 @@ class _ObserverBase(ObserverBase):
     # Version 3
     #   for HistogramObserver only, changed the shape of uninitialized
     #   min_val and max_val buffers from torch.Size([0]) to torch.Size([])
-    _version = 2
+    #   for PerChannelObservers, changed the name of the buffers from min_vals
+    #   to min_val and from max_vals to max_val.
+    _version = 3
 
     eps: torch.Tensor
 
@@ -606,8 +608,8 @@ class PerChannelMinMaxObserver(_ObserverBase):
     .. note:: If the running minimum equals to the running maximum, the scales
               and zero_points are set to 1.0 and 0.
     """
-    min_vals: torch.Tensor
-    max_vals: torch.Tensor
+    min_val: torch.Tensor
+    max_val: torch.Tensor
 
     def __init__(
         self,
@@ -629,8 +631,8 @@ class PerChannelMinMaxObserver(_ObserverBase):
         )
         factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
         self.ch_axis = ch_axis
-        self.register_buffer("min_vals", torch.tensor([], **factory_kwargs))
-        self.register_buffer("max_vals", torch.tensor([], **factory_kwargs))
+        self.register_buffer("min_val", torch.tensor([], **factory_kwargs))
+        self.register_buffer("max_val", torch.tensor([], **factory_kwargs))
         if (
             self.qscheme == torch.per_channel_symmetric
             and self.reduce_range
@@ -647,8 +649,8 @@ class PerChannelMinMaxObserver(_ObserverBase):
         if x_orig.numel() == 0:
             return x_orig
         x = x_orig.detach()  # avoid keeping autograd tape
-        min_vals = self.min_vals
-        max_vals = self.max_vals
+        min_val = self.min_val
+        max_val = self.max_val
         x_dim = x.size()
 
         new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
@@ -657,28 +659,27 @@ class PerChannelMinMaxObserver(_ObserverBase):
         y = x.permute(new_axis_list)
         # Need to match dtype of min/max because the updates to buffers
         # are done in place and types need to match for comparisons
-        y = y.to(self.min_vals.dtype)
+        y = y.to(self.min_val.dtype)
         y = torch.flatten(y, start_dim=1)
-        if min_vals.numel() == 0 or max_vals.numel() == 0:
-            min_vals, max_vals = torch._aminmax(y, 1)
+        if min_val.numel() == 0 or max_val.numel() == 0:
+            min_val, max_val = torch._aminmax(y, 1)
         else:
-            min_vals_cur, max_vals_cur = torch._aminmax(y, 1)
-            min_vals = torch.min(min_vals_cur, min_vals)
-            max_vals = torch.max(max_vals_cur, max_vals)
-        self.min_vals.resize_(min_vals.shape)
-        self.max_vals.resize_(max_vals.shape)
-        self.min_vals.copy_(min_vals)
-        self.max_vals.copy_(max_vals)
+            min_val_cur, max_val_cur = torch._aminmax(y, 1)
+            min_val = torch.min(min_val_cur, min_val)
+            max_val = torch.max(max_val_cur, max_val)
+        self.min_val.resize_(min_val.shape)
+        self.max_val.resize_(max_val.shape)
+        self.min_val.copy_(min_val)
+        self.max_val.copy_(max_val)
         return x_orig
 
     @torch.jit.export
     def calculate_qparams(self):
-        return self._calculate_qparams(self.min_vals, self.max_vals)
+        return self._calculate_qparams(self.min_val, self.max_val)
 
     def extra_repr(self):
-        return "min_val={}, max_val={}".format(self.min_vals, self.max_vals)
+        return "min_val={}, max_val={}".format(self.min_val, self.max_val)
 
-    @torch.jit.export
     def _load_from_state_dict(
         self,
         state_dict: Union[Dict[str, torch.Tensor], Dict[str, torch.Tensor]],
@@ -689,26 +690,38 @@ class PerChannelMinMaxObserver(_ObserverBase):
         unexpected_keys: List[str],
         error_msgs: List[str],
     ):
-        local_state = ["min_vals", "max_vals"]
+        version = local_metadata.get("version", None)
+        if version is None or version < 3:
+            local_state = ["min_vals", "max_vals"]
+            expected_min_name = "min_vals"
+            expected_max_name = "max_vals"
+        else:
+            local_state = ["min_val", "max_val"]
+            expected_min_name = "min_val"
+            expected_max_name = "max_val"
         for name in local_state:
             key = prefix + name
            if key in state_dict:
                val = state_dict[key]
-                # Custom handling to allow loading min_vals or max_vals
+                # Custom handling to allow loading min_val or max_val
                 # of size N into uninitialized buffers of size 0. The
                 # buffers are resized here, and the values are copied in
                 # the default state_dict loading code of the parent.
-                if name == "min_vals":
-                    self.min_vals.resize_(val.shape)
+                if name == expected_min_name:
+                    self.min_val.resize_(val.shape)
+                elif name == expected_max_name:
+                    self.max_val.resize_(val.shape)
                 else:
-                    self.max_vals.resize_(val.shape)
+                    warnings.warn("Observer load_from_state_dict got unexpected name {}".format(name))
                 # For torchscript module we need to update the attributes here since we do not
                 # call the `_load_from_state_dict` function defined module.py
                 if torch.jit.is_scripting():
-                    if name == "min_vals":
-                        self.min_vals.copy_(val)
+                    if name == expected_min_name:
+                        self.min_val.copy_(val)
+                    elif name == expected_max_name:
+                        self.max_val.copy_(val)
                     else:
-                        self.max_vals.copy_(val)
+                        warnings.warn("Observer load_from_state_dict got unexpected name {}".format(name))
             elif strict:
                 missing_keys.append(key)
 
@@ -717,13 +730,12 @@ class PerChannelMinMaxObserver(_ObserverBase):
             state_dict,
             prefix,
             local_metadata,
-            strict,
+            False,
             missing_keys,
             unexpected_keys,
             error_msgs,
         )
 
-    @torch.jit.export
     def _load_from_state_dict_script(
         self,
         state_dict: Union[Dict[str, torch.Tensor], Dict[str, torch.Tensor]],
@@ -748,8 +760,8 @@ class PerChannelMinMaxObserver(_ObserverBase):
     @torch.jit.export
     def reset_min_max_vals(self):
         """Resets the min/max values."""
-        self.min_vals = torch.tensor([])
-        self.max_vals = torch.tensor([])
+        self.min_val = torch.tensor([])
+        self.max_val = torch.tensor([])
 
 
 class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver):
@@ -805,9 +817,9 @@ class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver):
         if x_orig.numel() == 0:
             return x_orig
         x = x_orig.detach()  # avoid keeping autograd tape
-        x = x.to(self.min_vals.dtype)
-        min_vals = self.min_vals
-        max_vals = self.max_vals
+        x = x.to(self.min_val.dtype)
+        min_val = self.min_val
+        max_val = self.max_val
         x_dim = x.size()
 
         new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
@@ -815,16 +827,16 @@ class MovingAveragePerChannelMinMaxObserver(PerChannelMinMaxObserver):
         new_axis_list[0] = self.ch_axis
         y = x.permute(new_axis_list)
         y = torch.flatten(y, start_dim=1)
-        if min_vals.numel() == 0 or max_vals.numel() == 0:
-            min_vals, max_vals = torch._aminmax(y, 1)
+        if min_val.numel() == 0 or max_val.numel() == 0:
+            min_val, max_val = torch._aminmax(y, 1)
         else:
-            min_vals_cur, max_vals_cur = torch._aminmax(y, 1)
-            min_vals = min_vals + self.averaging_constant * (min_vals_cur - min_vals)
-            max_vals = max_vals + self.averaging_constant * (max_vals_cur - max_vals)
-        self.min_vals.resize_(min_vals.shape)
-        self.max_vals.resize_(max_vals.shape)
-        self.min_vals.copy_(min_vals)
-        self.max_vals.copy_(max_vals)
+            min_val_cur, max_val_cur = torch._aminmax(y, 1)
+            min_val = min_val + self.averaging_constant * (min_val_cur - min_val)
+            max_val = max_val + self.averaging_constant * (max_val_cur - max_val)
+        self.min_val.resize_(min_val.shape)
+        self.max_val.resize_(max_val.shape)
+        self.min_val.copy_(min_val)
+        self.max_val.copy_(max_val)
         return x_orig
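
For reviewers: the round-trip that the new `test_per_tensor_observer` / `test_per_channel_observer` cases exercise boils down to the sketch below. This is a minimal illustration, not part of the patch, and it assumes a build that already includes this change, so `state_dict()` emits the renamed `min_val` / `max_val` keys. Checkpoints written before this change still carry `min_vals` / `max_vals` keys with `version < 3` metadata, and the version branch added to `_load_from_state_dict` above remaps those onto the new buffers.

```python
# Minimal sketch (not part of the patch) of the serialization round-trip
# exercised by the new BC tests; assumes a build that includes this rename.
import torch
from torch.quantization import PerChannelMinMaxObserver

obs = PerChannelMinMaxObserver()
obs(torch.rand(5, 5))  # forward is a pass-through that records per-channel min/max

state = obs.state_dict()
# Buffers are now serialized under the renamed keys.
assert "min_val" in state and "max_val" in state

# A fresh observer starts with empty (numel == 0) buffers; the overridden
# _load_from_state_dict resizes them to match before the values are copied in.
loaded = PerChannelMinMaxObserver()
loaded.load_state_dict(state)
assert torch.equal(obs.min_val, loaded.min_val)
assert torch.equal(obs.max_val, loaded.max_val)
```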