pytorch/test/quantization/jit/test_ondevice_quantization.py
Han Qi (qihqi) b895a0a675 [BE] Move flatbuffer related python C bindings to script_init (#97476)
Summary:
Extra C binding module for flatbuffer was introduced because
not all dependencies of Pytorch want (or can) bundle in flatbuffer.

However, flatbuffer is in by default now so this separate binding is no longer needed.

Test Plan: existing unit tests

Differential Revision: D44352583

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97476
Approved by: https://github.com/dbort
2023-03-28 17:56:32 +00:00

529 lines
22 KiB
Python

# -*- coding: utf-8 -*-
# Owner(s): ["oncall: quantization"]
import torch
import torch._C
from torch.ao.quantization import (
default_dynamic_qconfig,
per_channel_dynamic_qconfig,
)
from torch.ao.quantization.quantize_jit import (
prepare_dynamic_jit,
convert_dynamic_jit,
_prepare_ondevice_dynamic_jit,
_quantize_ondevice_dynamic_jit,
)
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_quantization import (
get_script_module,
LinearAddModel,
)
from torch.jit.mobile import _load_for_lite_interpreter, LiteScriptModule
from torch.testing import FileCheck
from torch.utils import bundled_inputs as bundled_inputs
import io
from typing import Dict
class myMod(torch.nn.Module):
    """Two stacked 5x5 linear layers where fc1's weight is supplied by the caller.

    The externally provided ``weight`` parameter lets the enclosing module
    share a weight tensor with this submodule.
    """

    def __init__(self, weight):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 5).float()
        # Replace fc1's randomly-initialized weight with the caller's parameter.
        self.fc1.weight = weight
        self.fc2 = torch.nn.Linear(5, 5).float()

    def forward(self, x):
        hidden = self.fc1(x)
        return self.fc2(hidden)
class MyConvLinearModule(torch.nn.Module):
    """Conv2d followed by a nested linear submodule and a functional linear.

    Used by the on-device PTQ tests: it exposes three linear weights in
    total (two inside ``myMod``, one functional via ``weight1``), which is
    why the tests below expect three observers/quant nodes for this model.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3)
        # Weight handed to the submodule so fc1's weight lives outside myMod.
        shared_weight = torch.nn.Parameter(torch.ones(5, 5))
        self.weight1 = torch.nn.Parameter(torch.ones(5, 5))
        self.mymod = myMod(shared_weight)

    def forward(self, x):
        features = self.conv(x)
        projected = self.mymod(features)
        return torch.nn.functional.linear(projected, self.weight1)

    def get_example_inputs(self):
        # Single random NCHW batch sized to match the conv above.
        return (torch.rand(1, 3, 12, 7),)
class OnDevicePTQUtils:
    """Static helpers shared by the on-device dynamic PTQ tests below."""

    # Class names of observer modules inserted by dynamic-quantization prepare.
    observer_module_name = ['MinMaxObserver', 'PerChannelMinMaxObserver']

    @staticmethod
    def insert_observers(model, qconfig_dict):
        """Script ``model`` and run on-device prepare to insert observers."""
        inputs = model.get_example_inputs()
        scripted_model = get_script_module(model, False, inputs)
        scripted_model = _prepare_ondevice_dynamic_jit(scripted_model, qconfig_dict)
        return scripted_model

    @staticmethod
    def ptq_dynamic_quantize(model, qconfig_dict):
        """Script ``model`` and run the full on-device dynamic PTQ flow."""
        inputs = model.get_example_inputs()
        m = get_script_module(model, False, inputs)
        m = _quantize_ondevice_dynamic_jit(m, qconfig_dict, 'forward', True)
        return m

    @staticmethod
    def find_observer_modules(m):
        """Return the direct children of ``m`` that are observer modules."""
        observer_modules = []
        for child_module in m.children():
            if child_module.original_name in OnDevicePTQUtils.observer_module_name:
                observer_modules.append(child_module)
        return observer_modules

    @staticmethod
    def is_value_type_observer(value):
        """True if the JIT value's type string names an observer module type."""
        type_name = value.type()
        for observer_type in OnDevicePTQUtils.observer_module_name:
            if observer_type in type_name.str():
                return True
        return False

    @staticmethod
    def is_calculate_qparam(node):
        """True if ``node`` is a method call to ``calculate_qparams``."""
        if node.kind() == "prim::CallMethod":
            if node.s('name') == "calculate_qparams":
                return True
        return False

    @staticmethod
    def get_linear_packed_param_fp_weight(node):
        """Given a linear_prepack node, return the name of the fp weight
        attribute feeding its quantized-weight input.

        Raises ValueError if the input is not produced by a quantize op.
        """
        weight = node.inputsAt(0).node()
        if weight.kind() != "aten::quantize_per_tensor" and weight.kind() != "aten::quantize_per_channel":
            raise ValueError("Quantized weight must be produced.")
        fp_weight = weight.inputsAt(0).node()
        assert fp_weight.kind() == "prim::GetAttr", "Weight must be an attribute of the module."
        fp_weight_name = fp_weight.s('name')
        return fp_weight_name

    @staticmethod
    def is_per_channel_quantized_packed_param(node):
        """True if the weight feeding a linear_prepack node is per-channel
        quantized, False if it is per-tensor quantized."""
        assert node.kind() == 'quantized::linear_prepack', "Node must corresponds to linear_prepack."
        weight = node.inputsAt(0).node()
        # BUG FIX: the original assert used `!= A or != B`, which is a
        # tautology (always true). The intent is that the weight must be
        # produced by exactly one of the two quantize ops.
        assert weight.kind() in (
            "aten::quantize_per_tensor",
            "aten::quantize_per_channel",
        ), "Weight must be produced by a quantize op."
        return weight.kind() != "aten::quantize_per_tensor"
class TestOnDeviceDynamicPTQInsertObservers(TestCase):
    """Checks observer insertion performed by on-device dynamic PTQ prepare."""

    def _check_num_and_type_of_observers(self, model, num_observers):
        """Insert observers with both qconfigs and verify count and type."""
        # Per-tensor dynamic qconfig -> MinMaxObserver instances.
        qconfig_dict = {"": default_dynamic_qconfig}
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
        self.assertTrue(len(observer_modules) == num_observers)
        for observer in observer_modules:
            self.assertTrue(observer.original_name == 'MinMaxObserver')
        # Per-channel dynamic qconfig -> PerChannelMinMaxObserver instances.
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
        self.assertTrue(len(observer_modules) == num_observers)
        for observer in observer_modules:
            self.assertTrue(observer.original_name == 'PerChannelMinMaxObserver')

    def _check_observer_method(self, model, num_observers):
        """Verify observe_forward/reset_observers_forward call the observers,
        and that the transformed forward graph has the same number of lines
        as the inlined original (a proxy for exact graph matching)."""
        qconfig_dict = {"": default_dynamic_qconfig}
        inputs = model.get_example_inputs()
        orig_scripted_model = get_script_module(model, False, inputs)
        torch._C._jit_pass_inline(orig_scripted_model.graph)
        orig_forward_graph = orig_scripted_model.graph.str()
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        quant_forward_graph = scripted_model.graph.str()
        # Exact graph matching is difficult, so just compare # of lines
        # instead of implementing graph matching.
        self.assertEqual(len(orig_forward_graph.splitlines()), len(quant_forward_graph.splitlines()))
        observe_method = scripted_model.observe_forward.graph
        FileCheck().check_count("prim::CallMethod[name=\"forward\"](%_observer",
                                num_observers, exactly=True).run(observe_method)
        reset_observers_method = scripted_model.reset_observers_forward.graph
        FileCheck().check_count(
            "prim::CallMethod[name=\"reset_min_max_vals\"](%_observer", num_observers, exactly=True).run(reset_observers_method)

    def _observer_is_weight_only(self, node):
        """True if ``node`` is an observer forward call whose observed input
        is a module attribute (i.e. a weight rather than an activation)."""
        if (node.kind() == "prim::CallMethod") and node.s("name") == "forward":
            if OnDevicePTQUtils.is_value_type_observer(node.inputsAt(0)):
                return node.inputsAt(1).node().kind() == "prim::GetAttr"
        return False

    def test_num_observers(self):
        model = LinearAddModel()
        self._check_num_and_type_of_observers(model, 2)
        model = MyConvLinearModule()
        self._check_num_and_type_of_observers(model, 3)

    def test_observe_method(self):
        model = MyConvLinearModule()
        self._check_observer_method(model, 3)

    def test_weight_only_observers(self):
        # NOTE: a previously unused `inputs = model.get_example_inputs()`
        # local was removed; observers are found via the graph alone.
        model = MyConvLinearModule()
        qconfig_dict = {"": default_dynamic_qconfig}
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observe_forward_graph = scripted_model.observe_forward.graph
        num_weight_only_observers = 0
        for node in observe_forward_graph.nodes():
            if self._observer_is_weight_only(node):
                num_weight_only_observers += 1
        self.assertEqual(num_weight_only_observers, 3)
class TestOnDeviceDynamicPTQInsertQuantDequant(TestCase):
    """Checks quant/dequant node insertion into quantize_forward."""

    def _validate_quant_dequant_nodes(self, model, num_nodes, per_channel=0):
        """Count quantize_per_tensor/per_channel nodes in quantize_forward.

        ``num_nodes`` is the expected total number of quantize nodes and
        ``per_channel`` the expected per-channel subset of that total.
        """
        quantize_forward_graph = model.quantize_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        for n in quantize_forward_graph.nodes():
            if "aten::quantize_per_tensor" in n.kind():
                quantize_per_tensor += 1
            if "aten::quantize_per_channel" in n.kind():
                quantize_per_channel += 1
        self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)
        # BUG FIX: `per_channel` was accepted but never checked; validate it
        # the same way _validate_packed_params does in the finalize tests.
        self.assertEqual(quantize_per_channel, per_channel)

    def _validate_calculate_qparams(self, model, num_nodes):
        """Assert quantize_forward contains exactly ``num_nodes``
        calculate_qparams calls."""
        quantize_forward_graph = model.quantize_forward.graph
        num_calculate_qparams = 0
        for n in quantize_forward_graph.nodes():
            if OnDevicePTQUtils.is_calculate_qparam(n):
                num_calculate_qparams += 1
        self.assertEqual(num_calculate_qparams, num_nodes)

    def _validate_no_observer_forward(self, model):
        """Return True iff quantize_forward no longer invokes any observer's
        forward method."""
        quantize_forward_graph = model.quantize_forward.graph
        for n in quantize_forward_graph.nodes():
            if (n.kind() == "prim::CallMethod") and n.s("name") == "forward":
                if OnDevicePTQUtils.is_value_type_observer(n.inputsAt(0)):
                    return False
        return True

    def _check_quant_dequant_and_calc_qparams(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quant_dequant_nodes(m, num_nodes)
        self._validate_calculate_qparams(m, num_nodes)
        # BUG FIX: the helper's boolean result was previously discarded,
        # so the "no observer forward" condition was never enforced.
        self.assertTrue(self._validate_no_observer_forward(m))
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quant_dequant_nodes(m, num_nodes, num_nodes)
        self._validate_calculate_qparams(m, num_nodes)
        self.assertTrue(self._validate_no_observer_forward(m))

    def _check_quantize_forward_runs(self, model):
        inputs = model.get_example_inputs()
        for qconfig in (default_dynamic_qconfig, per_channel_dynamic_qconfig):
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, {"": qconfig})
            # observe_forward must run first to record the stats that
            # quantize_forward needs to produce correct scales/zero points.
            m.observe_forward(*inputs)
            m.quantize_forward(*inputs)

    def test_num_quant_dequant_nodes(self):
        model = LinearAddModel()
        self._check_quant_dequant_and_calc_qparams(model, 2)
        model = MyConvLinearModule()
        self._check_quant_dequant_and_calc_qparams(model, 3)

    def test_quantize_forward_runs(self):
        model = LinearAddModel()
        self._check_quantize_forward_runs(model)
        model = MyConvLinearModule()
        self._check_quantize_forward_runs(model)
class TestOnDeviceDynamicPTQFinalize(TestCase):
    """Checks the finalize step of on-device dynamic PTQ: packed params in
    quantize_forward, structure of quantized_forward, numerics against the
    off-device reference flow, serialization round-trips, and the
    lite-interpreter device-side quantization API."""

    def _validate_packed_params(self, model, num_nodes, per_channel=0):
        """Count linear_prepack results stored via SetAttr in
        quantize_forward; check the per-tensor/per-channel split and that
        each packed param value has exactly one use."""
        quantize_forward_graph = model.quantize_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        linear_prepack = 0
        linear_prepack_uses = 0
        for n in quantize_forward_graph.nodes():
            if n.kind() == 'prim::SetAttr':
                maybe_packed_param_value = n.inputsAt(1)
                maybe_packed_param = maybe_packed_param_value.node()
                if maybe_packed_param.kind() == 'quantized::linear_prepack':
                    linear_prepack += 1
                    linear_prepack_uses += len(maybe_packed_param_value.uses())
                    if OnDevicePTQUtils.is_per_channel_quantized_packed_param(maybe_packed_param):
                        quantize_per_channel += 1
                    else:
                        quantize_per_tensor += 1
        self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)
        self.assertEqual(quantize_per_channel, per_channel)
        self.assertEqual(linear_prepack, num_nodes)
        self.assertEqual(linear_prepack_uses, num_nodes)

    def _validate_no_linear_unpack(self, model):
        """Return True iff quantize_forward contains no
        quantized::linear_unpack node."""
        quantize_forward_graph = model.quantize_forward.graph
        for n in quantize_forward_graph.nodes():
            if n.kind() == 'quantized::linear_unpack':
                return False
        return True

    def _validate_setattr_fp_weights(self, model, num_nodes):
        """Check that each fp weight feeding a packed param is reset via a
        SetAttr of a constant (so the original fp weights are freed)."""
        quantize_forward_graph = model.quantize_forward.graph
        fp_weights_setattr = 0
        fp_weight_names = []
        for n in quantize_forward_graph.nodes():
            if n.kind() == 'prim::SetAttr':
                maybe_packed_param = n.inputsAt(1).node()
                if maybe_packed_param.kind() == 'quantized::linear_prepack':
                    weight_name = OnDevicePTQUtils.get_linear_packed_param_fp_weight(maybe_packed_param)
                    fp_weight_names.append(weight_name)
        for n in quantize_forward_graph.nodes():
            # This is basically detecting
            # %x = prim::Constant
            # = prim::SetAttr(<weight_name>)(module_value, x)
            # Thus making sure that the original fp weights are
            # reset
            if n.kind() == 'prim::SetAttr':
                weight_name = n.s('name')
                if weight_name in fp_weight_names:
                    maybe_constant = n.inputsAt(1).node()
                    if maybe_constant.kind() == 'prim::Constant':
                        fp_weights_setattr += 1
        self.assertEqual(fp_weights_setattr, num_nodes)

    def _validate_quantized_forward(self, model, num_nodes):
        """Check quantized_forward uses packed params + linear_dynamic only,
        with no quantize nodes remaining."""
        quantized_forward_graph = model.quantized_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        quantized_linear_dynamic = 0
        linear_packed_params = 0
        num_setattr = 0
        for n in quantized_forward_graph.nodes():
            if "aten::quantize_per_tensor" in n.kind():
                quantize_per_tensor += 1
            if "aten::quantize_per_channel" in n.kind():
                quantize_per_channel += 1
            if "quantized::linear_dynamic" in n.kind():
                quantized_linear_dynamic += 1
            if n.kind() == 'prim::GetAttr':
                output = n.outputsAt(0)
                output_type = output.type()
                if "LinearPackedParamsBase" in output_type.str():
                    linear_packed_params += 1
            if n.kind() == 'prim::SetAttr':
                num_setattr += 1
        self.assertEqual(quantize_per_tensor, 0)
        self.assertEqual(quantize_per_channel, 0)
        self.assertEqual(quantized_linear_dynamic, num_nodes)
        self.assertEqual(linear_packed_params, num_nodes)
        # TODO: enable once quantized_forward is guaranteed SetAttr-free.
        # self.assertEqual(num_setattr, 0)

    def _check_quantize_forward(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_packed_params(m, num_nodes)
        # BUG FIX: the helper's boolean result was previously discarded,
        # so the "no linear_unpack" condition was never enforced.
        self.assertTrue(self._validate_no_linear_unpack(m))
        self._validate_setattr_fp_weights(m, num_nodes)
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_packed_params(m, num_nodes, num_nodes)
        self.assertTrue(self._validate_no_linear_unpack(m))
        self._validate_setattr_fp_weights(m, num_nodes)

    def _check_quantized_forward(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quantized_forward(m, num_nodes)
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quantized_forward(m, num_nodes)

    def _compare_to_ref_dynamic_ptq(self, model, inputs, qconfig_dict):
        """Run the off-device reference dynamic PTQ flow and the on-device
        flow with the same qconfig and compare outputs; also check that the
        plain forward of the on-device module raises."""
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        ref_output = ref_m(*inputs)
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)
        output = m.quantized_forward(*inputs)
        self.assertTrue(torch.allclose(ref_output, output))
        # Plain forward must not work on the on-device quantized module.
        with self.assertRaises(Exception):
            m(*inputs)

    def _check_against_ref_dynamic_ptq(self, model):
        # Decomposed: the original duplicated the whole flow for the two
        # qconfigs; the shared logic now lives in _compare_to_ref_dynamic_ptq.
        model.eval()
        inputs = model.get_example_inputs()
        self._compare_to_ref_dynamic_ptq(model, inputs, {"": default_dynamic_qconfig})
        # test with per channel quant
        self._compare_to_ref_dynamic_ptq(model, inputs, {"": per_channel_dynamic_qconfig})

    def _run_serdes_check(self, model, qconfig_dict, check_device_side_api, check_flatbuffer):
        """Shared body for serialization / device-side API checks.

        Builds the off-device reference (saved+reloaded), then either
        round-trips the on-device module through torch.jit.save/load and
        runs the PTQ methods (check_device_side_api=False), or saves it for
        the lite interpreter, quantizes it via the device-side C binding,
        and compares outputs (check_device_side_api=True). When
        ``check_flatbuffer`` is set, additionally round-trips the quantized
        mobile module through the flatbuffer serializer.
        """
        model.eval()
        inputs = model.get_example_inputs()
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        buffer = io.BytesIO()
        torch.jit.save(ref_m, buffer)
        buffer.seek(0)
        ref_m = torch.jit.load(buffer)
        ref_output = ref_m(*inputs)
        if not check_device_side_api:
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            buffer = io.BytesIO()
            torch.jit.save(m, buffer)
            buffer.seek(0)
            m = torch.jit.load(buffer)
            m.reset_observers_forward()
            m.observe_forward(*inputs)
            m.quantize_forward(*inputs)
            output = m.quantized_forward(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))
        else:
            # check for lite interpreter
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            first_input, = inputs
            rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype)
            m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )])
            buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
            buffer.seek(0)
            m = _load_for_lite_interpreter(buffer)
            # Device-side API: quantize in place and strip the PTQ methods.
            torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
            self.assertFalse(m.find_method("quantized_forward"))
            self.assertFalse(m.find_method("quantize_forward"))
            self.assertFalse(m.find_method("observe_forward"))
            self.assertFalse(m.find_method("reset_observers_forward"))
            output = m(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))
            if check_flatbuffer:
                # Now serialize to flatbuffer and load from fb and check.
                # BUG FIX: locals previously shadowed the builtins
                # `dict` and `bytes`; renamed.
                extra_files: Dict[str, str] = {}
                fb_bytes = torch._C._save_mobile_module_to_bytes(m._c, extra_files)
                m = LiteScriptModule(torch._C._load_mobile_module_from_bytes(fb_bytes))
                fb_output = m(*inputs)
                self.assertTrue(torch.allclose(ref_output, fb_output))

    def _check_serdes_and_device_side_api_helper(self, model, check_device_side_api=False):
        # Decomposed: the original duplicated ~80 lines for the two
        # qconfigs; the shared logic now lives in _run_serdes_check. The
        # flatbuffer round-trip was only exercised for the default qconfig
        # in the original, which is preserved here.
        self._run_serdes_check(model, {"": default_dynamic_qconfig},
                               check_device_side_api, check_flatbuffer=True)
        self._run_serdes_check(model, {"": per_channel_dynamic_qconfig},
                               check_device_side_api, check_flatbuffer=False)

    def _check_serialization_deserialization(self, model):
        self._check_serdes_and_device_side_api_helper(model, False)

    def _check_device_side_api(self, model):
        self._check_serdes_and_device_side_api_helper(model, True)

    def test_quantize_forward(self):
        model = LinearAddModel()
        self._check_quantize_forward(model, 2)
        model = MyConvLinearModule()
        self._check_quantize_forward(model, 3)

    def test_quantized_forward(self):
        model = LinearAddModel()
        self._check_quantized_forward(model, 2)
        model = MyConvLinearModule()
        self._check_quantized_forward(model, 3)

    def test_against_offdevice_dynamic_ptq(self):
        model = LinearAddModel()
        self._check_against_ref_dynamic_ptq(model)
        model = MyConvLinearModule()
        self._check_against_ref_dynamic_ptq(model)

    def test_serialization_deserialization(self):
        model = MyConvLinearModule()
        self._check_serialization_deserialization(model)

    def test_device_side_api(self):
        model = MyConvLinearModule()
        self._check_device_side_api(model)