# -*- coding: utf-8 -*-
# Owner(s): ["oncall: quantization"]

import torch

import torch._C

from torch.ao.quantization import (
    default_dynamic_qconfig,
    per_channel_dynamic_qconfig,
)

from torch.ao.quantization.quantize_jit import (
    prepare_dynamic_jit,
    convert_dynamic_jit,
    _prepare_ondevice_dynamic_jit,
    _quantize_ondevice_dynamic_jit,
)

from torch.testing._internal.common_utils import TestCase

from torch.testing._internal.common_quantization import (
    get_script_module,
    LinearAddModel,
)

from torch.jit.mobile import _load_for_lite_interpreter, LiteScriptModule

from torch.testing import FileCheck
from torch.utils import bundled_inputs as bundled_inputs

import io
from typing import Dict

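# These tests exercise the on-device dynamic PTQ flow: _prepare_ondevice_dynamic_jit
# inserts observers into a scripted module, observe_forward records weight
# statistics, and the generated quantize_forward/quantized_forward methods
# finalize and run the quantized model (including via the lite interpreter).


# Small two-layer module; fc1's weight is supplied by the caller.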
class myMod(torch.nn.Module):
    def __init__(self, weight):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 5).float()
        self.fc1.weight = weight
        self.fc2 = torch.nn.Linear(5, 5).float()

    def forward(self, x):
        return self.fc2(self.fc1(x))


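# Conv2d feeding a nested myMod plus a functional linear on self.weight1:
# three Linear ops in total (fc1, fc2, and the functional linear), which is
# the observer/quant-node count the tests below expect.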
class MyConvLinearModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3)
        weight = torch.nn.Parameter(torch.ones(5, 5))
        self.weight1 = torch.nn.Parameter(torch.ones(5, 5))
        self.mymod = myMod(weight)

    def forward(self, x):
        conv_output = self.conv(x)
        y = self.mymod(conv_output)
        z = torch.nn.functional.linear(y, self.weight1)
        return z

    def get_example_inputs(self):
        return (torch.rand(1, 3, 12, 7),)


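# Helpers shared by the test classes below: scripting and preparing models,
# and inspecting the resulting JIT graphs.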
class OnDevicePTQUtils:
    observer_module_name = ['MinMaxObserver', 'PerChannelMinMaxObserver']

    @staticmethod
    def insert_observers(model, qconfig_dict):
        inputs = model.get_example_inputs()
        scripted_model = get_script_module(model, False, inputs)
        scripted_model = _prepare_ondevice_dynamic_jit(scripted_model, qconfig_dict)
        return scripted_model

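    # Scripts the model and runs the complete on-device prepare + quantize
    # transformation against its 'forward' method.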
    @staticmethod
    def ptq_dynamic_quantize(model, qconfig_dict):
        inputs = model.get_example_inputs()
        m = get_script_module(model, False, inputs)
        m = _quantize_ondevice_dynamic_jit(m, qconfig_dict, 'forward', True)
        return m

    @staticmethod
    def find_observer_modules(m):
        observer_modules = []
        for child_module in m.children():
            if child_module.original_name in OnDevicePTQUtils.observer_module_name:
                observer_modules.append(child_module)
        return observer_modules

    @staticmethod
    def is_value_type_observer(value):
        type_name = value.type()
        for observer_type in OnDevicePTQUtils.observer_module_name:
            if observer_type in type_name.str():
                return True
        return False

    @staticmethod
    def is_calculate_qparam(node):
        if node.kind() == "prim::CallMethod":
            if node.s('name') == "calculate_qparams":
                return True
        return False

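    # Walks back from a quantized::linear_prepack node to the prim::GetAttr
    # that produced the floating-point weight, and returns that attribute name.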
    @staticmethod
    def get_linear_packed_param_fp_weight(node):
        weight = node.inputsAt(0).node()
        if weight.kind() not in ("aten::quantize_per_tensor", "aten::quantize_per_channel"):
            raise ValueError("Quantized weight must be produced.")
        fp_weight = weight.inputsAt(0).node()
        assert fp_weight.kind() == "prim::GetAttr", "Weight must be an attribute of the module."
        fp_weight_name = fp_weight.s('name')
        return fp_weight_name

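    # True when the weight feeding linear_prepack was quantized per channel
    # rather than per tensor.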
    @staticmethod
    def is_per_channel_quantized_packed_param(node):
        assert node.kind() == 'quantized::linear_prepack', "Node must correspond to linear_prepack."
        weight = node.inputsAt(0).node()
        # The weight must come from one of the two quantize ops.
        assert weight.kind() == "aten::quantize_per_tensor" or weight.kind() == "aten::quantize_per_channel"
        return weight.kind() != "aten::quantize_per_tensor"


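# Checks that _prepare_ondevice_dynamic_jit inserts the right number and kind
# of observers and generates observe_forward/reset_observers_forward methods.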
class TestOnDeviceDynamicPTQInsertObservers(TestCase):
    def _check_num_and_type_of_observers(self, model, num_observers):
        qconfig_dict = {"": default_dynamic_qconfig}
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
        self.assertTrue(len(observer_modules) == num_observers)
        for observer in observer_modules:
            self.assertTrue(observer.original_name == 'MinMaxObserver')

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
        self.assertTrue(len(observer_modules) == num_observers)
        for observer in observer_modules:
            self.assertTrue(observer.original_name == 'PerChannelMinMaxObserver')

    def _check_observer_method(self, model, num_observers):
        qconfig_dict = {"": default_dynamic_qconfig}
        inputs = model.get_example_inputs()
        orig_scripted_model = get_script_module(model, False, inputs)
        torch._C._jit_pass_inline(orig_scripted_model.graph)
        orig_forward_graph = orig_scripted_model.graph.str()
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        quant_forward_graph = scripted_model.graph.str()
        # Exact graph matching is brittle, so just compare the number of lines
        # instead of implementing full graph matching.
        self.assertEqual(len(orig_forward_graph.splitlines()), len(quant_forward_graph.splitlines()))
        observe_method = scripted_model.observe_forward.graph
        FileCheck().check_count("prim::CallMethod[name=\"forward\"](%_observer",
                                num_observers, exactly=True).run(observe_method)
        reset_observers_method = scripted_model.reset_observers_forward.graph
        FileCheck().check_count(
            "prim::CallMethod[name=\"reset_min_max_vals\"](%_observer", num_observers, exactly=True).run(reset_observers_method)

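    # A weight-only observer observes a prim::GetAttr value (a parameter)
    # rather than an activation.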
    def _observer_is_weight_only(self, node):
        if node.kind() == "prim::CallMethod" and node.s("name") == "forward":
            if OnDevicePTQUtils.is_value_type_observer(node.inputsAt(0)):
                return node.inputsAt(1).node().kind() == "prim::GetAttr"
        return False

    def test_num_observers(self):
        model = LinearAddModel()
        self._check_num_and_type_of_observers(model, 2)
        model = MyConvLinearModule()
        self._check_num_and_type_of_observers(model, 3)

    def test_observe_method(self):
        model = MyConvLinearModule()
        self._check_observer_method(model, 3)

    def test_weight_only_observers(self):
        model = MyConvLinearModule()
        qconfig_dict = {"": default_dynamic_qconfig}
        inputs = model.get_example_inputs()
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observe_forward_graph = scripted_model.observe_forward.graph
        num_weight_only_observers = 0
        for node in observe_forward_graph.nodes():
            if self._observer_is_weight_only(node):
                num_weight_only_observers += 1
        self.assertEqual(num_weight_only_observers, 3)


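# Checks the generated quantize_forward method: quant/dequant nodes,
# calculate_qparams calls, and that observer forward calls are gone.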
class TestOnDeviceDynamicPTQInsertQuantDequant(TestCase):
    def _validate_quant_dequant_nodes(self, model, num_nodes, per_channel=0):
        quantize_forward_graph = model.quantize_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        for n in quantize_forward_graph.nodes():
            if "aten::quantize_per_tensor" in n.kind():
                quantize_per_tensor += 1
            if "aten::quantize_per_channel" in n.kind():
                quantize_per_channel += 1
        self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)
        self.assertEqual(quantize_per_channel, per_channel)

    def _validate_calculate_qparams(self, model, num_nodes):
        quantize_forward_graph = model.quantize_forward.graph
        num_calculate_qparams = 0
        for n in quantize_forward_graph.nodes():
            if OnDevicePTQUtils.is_calculate_qparam(n):
                num_calculate_qparams += 1
        self.assertEqual(num_calculate_qparams, num_nodes)

    def _validate_no_observer_forward(self, model):
        quantize_forward_graph = model.quantize_forward.graph
        for n in quantize_forward_graph.nodes():
            if n.kind() == "prim::CallMethod" and n.s("name") == "forward":
                if OnDevicePTQUtils.is_value_type_observer(n.inputsAt(0)):
                    return False
        return True

    def _check_quant_dequant_and_calc_qparams(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quant_dequant_nodes(m, num_nodes)
        self._validate_calculate_qparams(m, num_nodes)
        self.assertTrue(self._validate_no_observer_forward(m))

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quant_dequant_nodes(m, num_nodes, num_nodes)
        self._validate_calculate_qparams(m, num_nodes)
        self.assertTrue(self._validate_no_observer_forward(m))

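    # quantize_forward should be runnable eagerly once observer statistics
    # have been recorded.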
    def _check_quantize_forward_runs(self, model):
        inputs = model.get_example_inputs()
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        # observe_forward must run first to record the stats that produce
        # correct scales and zero points.
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)

    def test_num_quant_dequant_nodes(self):
        model = LinearAddModel()
        self._check_quant_dequant_and_calc_qparams(model, 2)
        model = MyConvLinearModule()
        self._check_quant_dequant_and_calc_qparams(model, 3)

    def test_quantize_forward_runs(self):
        model = LinearAddModel()
        self._check_quantize_forward_runs(model)
        model = MyConvLinearModule()
        self._check_quantize_forward_runs(model)


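# Checks finalization: weights are prepacked, fp weights are reset,
# quantized_forward uses packed params, and the flow survives
# (de)serialization and the on-device C++ API.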
class TestOnDeviceDynamicPTQFinalize(TestCase):
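    # Counts packed params stored via prim::SetAttr that were produced by
    # quantized::linear_prepack, classified as per-tensor vs per-channel.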
    def _validate_packed_params(self, model, num_nodes, per_channel=0):
        quantize_forward_graph = model.quantize_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        linear_prepack = 0
        linear_prepack_uses = 0
        for n in quantize_forward_graph.nodes():
            if n.kind() == 'prim::SetAttr':
                maybe_packed_param_value = n.inputsAt(1)
                maybe_packed_param = maybe_packed_param_value.node()
                if maybe_packed_param.kind() == 'quantized::linear_prepack':
                    linear_prepack += 1
                    linear_prepack_uses += len(maybe_packed_param_value.uses())
                    if OnDevicePTQUtils.is_per_channel_quantized_packed_param(maybe_packed_param):
                        quantize_per_channel += 1
                    else:
                        quantize_per_tensor += 1
        self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)
        self.assertEqual(quantize_per_channel, per_channel)
        self.assertEqual(linear_prepack, num_nodes)
        self.assertEqual(linear_prepack_uses, num_nodes)

    def _validate_no_linear_unpack(self, model):
        quantize_forward_graph = model.quantize_forward.graph
        for n in quantize_forward_graph.nodes():
            if n.kind() == 'quantized::linear_unpack':
                return False
        return True

    def _validate_setattr_fp_weights(self, model, num_nodes):
        quantize_forward_graph = model.quantize_forward.graph
        fp_weights_setattr = 0
        fp_weight_names = []
        for n in quantize_forward_graph.nodes():
            if n.kind() == 'prim::SetAttr':
                maybe_packed_param = n.inputsAt(1).node()
                if maybe_packed_param.kind() == 'quantized::linear_prepack':
                    weight_name = OnDevicePTQUtils.get_linear_packed_param_fp_weight(maybe_packed_param)
                    fp_weight_names.append(weight_name)

        for n in quantize_forward_graph.nodes():
            # Detect the pattern
            #   %x = prim::Constant
            #   = prim::SetAttr[name=<weight_name>](%module, %x)
            # which confirms the original fp weights are reset to a constant.
            if n.kind() == 'prim::SetAttr':
                weight_name = n.s('name')
                if weight_name in fp_weight_names:
                    maybe_constant = n.inputsAt(1).node()
                    if maybe_constant.kind() == 'prim::Constant':
                        fp_weights_setattr += 1
        self.assertEqual(fp_weights_setattr, num_nodes)

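    # quantized_forward must contain no quantize ops of its own; it should
    # load prepacked params via GetAttr and call quantized::linear_dynamic.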
    def _validate_quantized_forward(self, model, num_nodes):
        quantized_forward_graph = model.quantized_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        quantized_linear_dynamic = 0
        linear_packed_params = 0
        num_setattr = 0
        for n in quantized_forward_graph.nodes():
            if "aten::quantize_per_tensor" in n.kind():
                quantize_per_tensor += 1
            if "aten::quantize_per_channel" in n.kind():
                quantize_per_channel += 1
            if "quantized::linear_dynamic" in n.kind():
                quantized_linear_dynamic += 1
            if n.kind() == 'prim::GetAttr':
                output = n.outputsAt(0)
                output_type = output.type()
                if "LinearPackedParamsBase" in output_type.str():
                    linear_packed_params += 1
            if n.kind() == 'prim::SetAttr':
                num_setattr += 1
        self.assertEqual(quantize_per_tensor, 0)
        self.assertEqual(quantize_per_channel, 0)
        self.assertEqual(quantized_linear_dynamic, num_nodes)
        self.assertEqual(linear_packed_params, num_nodes)
        # self.assertEqual(num_setattr, 0)

    def _check_quantize_forward(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_packed_params(m, num_nodes)
        self.assertTrue(self._validate_no_linear_unpack(m))
        self._validate_setattr_fp_weights(m, num_nodes)

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_packed_params(m, num_nodes, num_nodes)
        self.assertTrue(self._validate_no_linear_unpack(m))
        self._validate_setattr_fp_weights(m, num_nodes)

    def _check_quantized_forward(self, model, num_nodes):
        qconfig_dict = {"": default_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quantized_forward(m, num_nodes)

        qconfig_dict = {"": per_channel_dynamic_qconfig}
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        self._validate_quantized_forward(m, num_nodes)

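    # On-device PTQ output must match the reference JIT dynamic PTQ flow
    # (prepare_dynamic_jit + convert_dynamic_jit). Calling the module directly
    # is expected to raise, since quantize_forward resets the original fp
    # weights.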
    def _check_against_ref_dynamic_ptq(self, model):
        model.eval()
        inputs = model.get_example_inputs()
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": default_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        ref_output = ref_m(*inputs)

        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)
        output = m.quantized_forward(*inputs)
        self.assertTrue(torch.allclose(ref_output, output))
        thrown = False
        try:
            m(*inputs)
        except Exception:
            thrown = True
        self.assertTrue(thrown)

        # Repeat with per-channel quantization.
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        ref_output = ref_m(*inputs)

        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)
        output = m.quantized_forward(*inputs)
        self.assertTrue(torch.allclose(ref_output, output))
        thrown = False
        try:
            m(*inputs)
        except Exception:
            thrown = True
        self.assertTrue(thrown)

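    # Round-trips the model through TorchScript serialization (or, when
    # check_device_side_api is True, through the lite interpreter plus
    # torch._C._quantize_ondevice_ptq_dynamic) and compares against the
    # reference output. The second half repeats the flow per-channel.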
    def _check_serdes_and_device_side_api_helper(self, model, check_device_side_api=False):
        model.eval()
        inputs = model.get_example_inputs()
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": default_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        buffer = io.BytesIO()
        torch.jit.save(ref_m, buffer)
        buffer.seek(0)
        ref_m = torch.jit.load(buffer)
        ref_output = ref_m(*inputs)

        if not check_device_side_api:
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            buffer = io.BytesIO()
            torch.jit.save(m, buffer)
            buffer.seek(0)
            m = torch.jit.load(buffer)
            m.reset_observers_forward()
            m.observe_forward(*inputs)
            m.quantize_forward(*inputs)
            output = m.quantized_forward(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))
        else:
            # Check the lite-interpreter path: quantization is finalized on
            # device and the helper methods are removed afterwards.
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            first_input, = inputs
            rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype)
            m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )])
            buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
            buffer.seek(0)
            m = _load_for_lite_interpreter(buffer)
            torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
            self.assertFalse(m.find_method("quantized_forward"))
            self.assertFalse(m.find_method("quantize_forward"))
            self.assertFalse(m.find_method("observe_forward"))
            self.assertFalse(m.find_method("reset_observers_forward"))
            output = m(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))

            # Now serialize to flatbuffer, load back from it, and check.
            extra_files: Dict[str, str] = {}
            module_bytes = torch._C._save_mobile_module_to_bytes(m._c, extra_files)
            m = LiteScriptModule(torch._C._load_mobile_module_from_bytes(module_bytes))
            fb_output = m(*inputs)
            self.assertTrue(torch.allclose(ref_output, fb_output))

        # Repeat the whole flow with per-channel quantization.
        model.eval()
        inputs = model.get_example_inputs()
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        qconfig_dict = {"": per_channel_dynamic_qconfig}
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        ref_m = convert_dynamic_jit(ref_m)
        buffer = io.BytesIO()
        torch.jit.save(ref_m, buffer)
        buffer.seek(0)
        ref_m = torch.jit.load(buffer)
        ref_output = ref_m(*inputs)

        if not check_device_side_api:
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            buffer = io.BytesIO()
            torch.jit.save(m, buffer)
            buffer.seek(0)
            m = torch.jit.load(buffer)
            m.reset_observers_forward()
            m.observe_forward(*inputs)
            m.quantize_forward(*inputs)
            output = m.quantized_forward(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))
        else:
            # Lite-interpreter path, per-channel variant.
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            first_input, = inputs
            rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype)
            m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )])
            buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
            buffer.seek(0)
            m = _load_for_lite_interpreter(buffer)
            torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
            self.assertFalse(m.find_method("quantized_forward"))
            self.assertFalse(m.find_method("quantize_forward"))
            self.assertFalse(m.find_method("observe_forward"))
            self.assertFalse(m.find_method("reset_observers_forward"))
            output = m(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))

    def _check_serialization_deserialization(self, model):
        self._check_serdes_and_device_side_api_helper(model, False)

    def _check_device_side_api(self, model):
        self._check_serdes_and_device_side_api_helper(model, True)

    def test_quantize_forward(self):
        model = LinearAddModel()
        self._check_quantize_forward(model, 2)
        model = MyConvLinearModule()
        self._check_quantize_forward(model, 3)

    def test_quantized_forward(self):
        model = LinearAddModel()
        self._check_quantized_forward(model, 2)
        model = MyConvLinearModule()
        self._check_quantized_forward(model, 3)

    def test_against_offdevice_dynamic_ptq(self):
        model = LinearAddModel()
        self._check_against_ref_dynamic_ptq(model)
        model = MyConvLinearModule()
        self._check_against_ref_dynamic_ptq(model)

    def test_serialization_deserialization(self):
        model = MyConvLinearModule()
        self._check_serialization_deserialization(model)

    def test_device_side_api(self):
        model = MyConvLinearModule()
        self._check_device_side_api(model)