pytorch/test/quantization/test_quantize_fx.py
Supriya Rao 4c3f59b70e [quant][fx] Make scale, zero_point buffers in the model and use FQN (for quantized ops) (#51166)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51166

Currently scale and zero_point values are stored as constants in the graph.
This prevents these values from being updated in the graph and also prevents
them from being saved to the state_dict.

After this PR we store scale/zero_point values for quantized ops as buffers in
the root module and create get_attr nodes for them in the graph.

We also use the FQN of the module where the quantized op is present to name
these attributes, so that they can be uniquely identified and mapped to
quantized ops.
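As an illustration (a sketch with hypothetical attribute names; the actual
names are derived from the FQN of the module containing the quantized op),
after convert_fx the converted model looks roughly like:

    # qparams registered as buffers on the root module, e.g.:
    #   model.mods1_0_scale_0, model.mods1_0_zero_point_0
    # referenced in the graph through get_attr nodes:
    #   scale = self.mods1_0_scale_0
    #   out = torch.ops.quantized.linear(x, packed_weight, scale, zero_point)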

Test Plan:
python test/test_quantization.py TestQuantizeFx.test_qparams_buffers

Imported from OSS

Reviewed By: jerryzh168

Differential Revision: D26092965

fbshipit-source-id: b549b2d3dccb45c5d38415ce95a09c26f5bd590b
2021-01-28 08:35:42 -08:00


import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.quantized as nniq
import torch.multiprocessing as mp
# graph mode quantization based on fx
from torch.quantization.quantize_fx import (
    prepare_fx,
    convert_fx,
    prepare_qat_fx,
)
from torch.quantization.fx.pattern_utils import (
    is_match,
    MatchAllNode,
)
from torch.quantization import (
    QuantType,
    QuantStub,
    DeQuantStub,
    QuantWrapper,
    quant_type_to_str,
    default_qconfig,
    default_dynamic_qconfig,
    default_qat_qconfig,
    per_channel_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    get_default_qconfig,
    get_default_qat_qconfig,
    fuse_modules,
    prepare,
    prepare_qat,
    convert,
    quantize_dynamic,
    default_placeholder_observer,
    PerChannelMinMaxObserver,
    QConfigDynamic,
    FixedQParamsFakeQuantize,
)
# test utils
from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    skipIfNoFBGEMM,
    skip_if_no_torchvision,
    train_one_epoch,
    run_ddp,
    test_only_eval_fn,
    test_only_train_fn,
)
from torch.testing._internal.common_quantization import (
    LinearModelWithSubmodule,
    ResNetBase,
    RNNDynamicModel,
    RNNCellDynamicModel,
)
from torch.testing._internal.common_quantized import (
    supported_qengines,
    override_qengines,
    override_quantized_engine,
)
from torch.testing._internal.common_distributed import skip_if_not_multigpu
from torch.testing._internal.common_quantization import NodeSpec as ns
from torch.testing import FileCheck
import copy
import itertools
import operator
import unittest
import io
from typing import Callable


class TestFuseFx(QuantizationTestCase):
    def test_fuse_conv_bn_relu(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1d = nn.Conv1d(1, 1, 1)
                self.conv2d = nn.Conv2d(1, 1, 1)
                self.conv3d = nn.Conv3d(1, 1, 1)
                self.bn1d = nn.BatchNorm1d(1)
                self.bn2d = nn.BatchNorm2d(1)
                self.bn3d = nn.BatchNorm3d(1)
                self.conv1d2 = nn.Conv1d(1, 1, 1)
                self.conv2d2 = nn.Conv2d(1, 1, 1)
                self.conv3d2 = nn.Conv3d(1, 1, 1)
                self.bn1d2 = nn.BatchNorm1d(1)
                self.bn2d2 = nn.BatchNorm2d(1)
                self.bn3d2 = nn.BatchNorm3d(1)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.conv1d(x)
                x = self.bn1d(x)
                x = self.conv2d(x)
                x = self.bn2d(x)
                x = self.conv3d(x)
                x = self.bn3d(x)
                x = self.conv1d2(x)
                x = self.bn1d2(x)
                x = self.relu(x)
                x = self.conv2d2(x)
                x = self.bn2d2(x)
                x = self.relu(x)
                x = self.conv3d2(x)
                x = self.bn3d2(x)
                x = self.relu(x)
                return x

        # test train mode
        m = M().train()
        # currently we don't check whether the modules are configured with a
        # qconfig before fusion
        # TODO: if we decide to do that in the future, this test needs to
        # be updated
        # in train mode, fuse_fx is called inside prepare_qat_fx
        m = prepare_qat_fx(m, {})
        expected_nodes = [
            ns.call_module(nni.ConvBn1d),
            ns.call_module(nni.ConvBn2d),
            ns.call_module(nni.ConvBn3d),
            ns.call_module(nni.ConvBnReLU1d),
            ns.call_module(nni.ConvBnReLU2d),
            ns.call_module(nni.ConvBnReLU3d),
        ]
        expected_occurrence = {
            ns.call_module(nn.ReLU): 0
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_list=expected_nodes,
            expected_node_occurrence=expected_occurrence)

        # test eval mode
        m = M().eval()
        from torch.quantization.quantize_fx import fuse_fx
        # fuse_fx is a top level api and only supports eval mode
        m = fuse_fx(m)
        expected_nodes = [
            ns.call_module(nn.Conv1d),
            ns.call_module(nn.Conv2d),
            ns.call_module(nn.Conv3d),
            ns.call_module(nni.ConvReLU1d),
            ns.call_module(nni.ConvReLU2d),
            ns.call_module(nni.ConvReLU3d),
        ]
        # all standalone ReLU modules are fused away
        expected_occurrence = {
            ns.call_module(nn.ReLU): 0
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_list=expected_nodes,
            expected_node_occurrence=expected_occurrence)

    def test_fuse_module_relu(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1d = nn.Conv1d(1, 1, 1)
                self.conv2d = nn.Conv2d(1, 1, 1)
                self.conv3d = nn.Conv3d(1, 1, 1)
                self.bn1d = nn.BatchNorm1d(1)
                self.bn2d = nn.BatchNorm2d(1)
                self.bn3d = nn.BatchNorm3d(1)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.conv1d(x)
                x = self.relu(x)
                x = self.conv2d(x)
                x = self.relu(x)
                x = self.conv3d(x)
                x = self.relu(x)
                x = self.bn1d(x)
                x = self.relu(x)
                x = self.bn2d(x)
                x = self.relu(x)
                x = self.bn3d(x)
                x = self.relu(x)
                return x

        m = M().eval()
        from torch.quantization.quantize_fx import fuse_fx
        m = fuse_fx(m)
        expected_nodes = [
            ns.call_module(nni.ConvReLU1d),
            ns.call_module(nni.ConvReLU2d),
            ns.call_module(nni.ConvReLU3d),
            ns.call_module(nni.BNReLU2d),
            ns.call_module(nni.BNReLU3d),
        ]
        self.checkGraphModuleNodes(m, expected_node_list=expected_nodes)


@skipIfNoFBGEMM
class TestQuantizeFx(QuantizationTestCase):
    def test_pattern_match(self):
        """ test MatchAllNode with
            conv - bn - add - relu pattern
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 1)
                self.bn = nn.BatchNorm2d(1)
                self.relu = nn.ReLU()

            def forward(self, x, y):
                x = self.conv(x)
                x = self.bn(x)
                x = x + y
                x = self.relu(x)
                return x

        pattern = (nn.ReLU, (operator.add, (nn.BatchNorm2d, nn.Conv2d), MatchAllNode))
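        # the pattern above is written outside-in, starting from the last op:
        # a ReLU fed by an add of BatchNorm2d(Conv2d(...)) and any other node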
        m = torch.fx.symbolic_trace(M())
        modules = dict(m.named_modules())
        for n in m.graph.nodes:
            if n.op == 'call_module' and type(modules[n.target]) == nn.ReLU:
                self.assertTrue(is_match(modules, n, pattern))

    def _get_conv_linear_test_cases(self):
        ''' Returns a list of test cases, with format:
            is_dynamic, ModuleClass, module_constructor_inputs,
            inputs, quantized_node, weight_prepack_op
        '''
        class Conv(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x):
                return F.conv2d(x, self.weight, None, self.stride, self.padding, self.dilation, self.groups)

        conv_input = torch.rand(1, 3, 224, 224)
        conv_weight = torch.rand(3, 3, 3, 3)

        class Linear(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)

            def forward(self, x):
                return F.linear(x, self.weight)

        linear_input = torch.rand(8, 5)
        linear_weight = torch.rand(10, 5)

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        tests = [
            (False, Conv, (conv_weight,), (conv_input,),
             ns.call_function(torch.ops.quantized.conv2d),
             ns.call_function(torch.ops.quantized.conv2d_prepack)),
            (True, Linear, (linear_weight,), (linear_input,),
             ns.call_function(torch.ops.quantized.linear_dynamic),
             ns.call_function(torch.ops.quantized.linear_prepack)),
            (False, Linear, (linear_weight,), (linear_input,),
             ns.call_function(torch.ops.quantized.linear),
             ns.call_function(torch.ops.quantized.linear_prepack)),
            (True, LinearModule, (), (linear_module_input,),
             ns.call_module(nnqd.Linear),
             None),
            (False, LinearModule, (), (linear_module_input,),
             ns.call_module(nnq.Linear),
             None),
        ]
        return tests

    """
    Unit tests for functionalities
    """
    @skipIfNoFBGEMM
    def test_functional_no_debug(self):
        """ Test quantizing functional conv and linear
        """
        tests = self._get_conv_linear_test_cases()
        for (is_dynamic, ModuleClass, module_constructor_inputs,
             inputs, quantized_node, weight_prepack_node) in tests:
            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
            node_occurrence = dict()
            if weight_prepack_node:
                node_occurrence[weight_prepack_node] = 0
            self.checkGraphModeFxOp(
                ModuleClass(*module_constructor_inputs),
                inputs, quant_type,
                expected_node=quantized_node,
                expected_node_occurrence=node_occurrence,
                debug=False)

    @skipIfNoFBGEMM
    def test_functional_debug(self):
        """ Test quantizing functional conv and linear with debug option
        """
        tests = self._get_conv_linear_test_cases()
        for (is_dynamic, ModuleClass, module_constructor_inputs,
             inputs, quantized_node, weight_prepack_node) in tests:
            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
            node_occurrence = dict()
            if weight_prepack_node:
                node_occurrence[weight_prepack_node] = 0
                node_occurrence[quantized_node] = 0
            self.checkGraphModeFxOp(
                ModuleClass(*module_constructor_inputs),
                inputs, quant_type,
                expected_node_occurrence=node_occurrence,
                debug=True)

    @skipIfNoFBGEMM
    def test_dynamic_quant_weight_observer(self):
        ''' Test that weight observer is run in convert step
        '''
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)

            def forward(self, x):
                return F.linear(x, self.weight)

        m = M(torch.rand(1, 1)).eval()
        qconfig = default_dynamic_qconfig
        qconfig_dict = {'': qconfig}
        prepared = prepare_fx(m, qconfig_dict)
        quantized = convert_fx(prepared, debug=True)
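        # in debug mode the scale/zero_point computed during convert are kept
        # as attributes on the converted model (e.g. _scale_0 below)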
        qparams = (quantized._scale_0, quantized._zero_point_0)
        weight_obs = qconfig.weight()
        weight_obs(quantized.weight)
        ref_qparams = weight_obs.calculate_qparams()
        self.assertEqual(qparams, ref_qparams)

    def test_conv_bn_relu(self):
        convs = {
            1: nn.Conv1d,
            2: nn.Conv2d,
            3: nn.Conv3d,
        }
        bns = {
            1: nn.BatchNorm1d,
            2: nn.BatchNorm2d,
            3: nn.BatchNorm3d,
        }
        quantized_convs = {
            1: nnq.Conv1d,
            2: nnq.Conv2d,
            3: nnq.Conv3d,
        }
        quantized_conv_relus = {
            1: nniq.ConvReLU1d,
            2: nniq.ConvReLU2d,
            3: nniq.ConvReLU3d,
        }

        class M(torch.nn.Module):
            def __init__(self, dim, has_relu):
                super().__init__()
                self.conv = convs[dim](3, 3, 3)
                self.bn = bns[dim](3)
                self.relu = nn.ReLU() if has_relu else nn.Identity()
                self.has_relu = has_relu
                self.quant = QuantStub()
                self.dequant = DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv(x)
                x = self.bn(x)
                if self.has_relu:
                    x = self.relu(x)
                x = self.dequant(x)
                return x

        options = itertools.product([1, 2], [True, False], self.static_quant_types)
        for dim, has_relu, quant_type in options:
            expected_node = ns.call_module(
                quantized_conv_relus[dim] if has_relu
                else quantized_convs[dim])
            m = M(dim, has_relu)
            m_eager = copy.deepcopy(m)
            result = self.checkGraphModeFxOp(
                m,
                self.img_data_dict[dim],
                quant_type,
                expected_node=expected_node,
            )

            # check numerics
            qengine = torch.backends.quantized.engine
            if quant_type == QuantType.STATIC:
                m_eager.eval()
                qconfig = get_default_qconfig(qengine)
                prepare_fn = prepare
            else:
                m_eager.train()
                qconfig = get_default_qat_qconfig(qengine)
                prepare_fn = prepare_qat

            fuse_list = ["conv", "bn"]
            if has_relu:
                fuse_list.append("relu")
            fuse_modules(m_eager, fuse_list, inplace=True)
            m_eager.qconfig = qconfig
            m_eager = prepare_fn(m_eager)
            m_eager(*self.img_data_dict[dim][0])
            m_eager = convert(m_eager)
            result_eager = m_eager(*self.img_data_dict[dim][0])
            self.assertEqual(result, result_eager)

    @skipIfNoFBGEMM
    def test_dynamic_quant_fp16(self):
        class Linear(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = torch.nn.Parameter(weight)

            def forward(self, x):
                return F.linear(x, self.weight)

        linear_input = torch.rand(8, 5)
        linear_weight = torch.rand(10, 5)

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        tests = [
            (Linear, (linear_weight,), (linear_input,),
             ns.call_function(torch.ops.quantized.linear_dynamic),
             ns.call_function(torch.ops.quantized.linear_prepack_fp16)),
            (LinearModule, (), (linear_module_input,),
             ns.call_module(nnqd.Linear),
             None),
        ]
        for (ModuleClass, module_constructor_inputs,
             inputs, quantized_node, weight_prepack_node) in tests:
            for debug in [True, False]:
                node_occurrence = dict()
                if weight_prepack_node:
                    node_occurrence[weight_prepack_node] = 0
                m = ModuleClass(*module_constructor_inputs).eval()
                qconfig_dict = {"": float16_dynamic_qconfig}
                m = prepare_fx(m, qconfig_dict)
                m = convert_fx(m, debug=debug)
                self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
    @override_qengines
    def test_qat_prepare_device_affinity(self):
        """
        Tests that FX QAT prepare pass respects device affinity
        """
        class Model(nn.Module):
            def __init__(self):
                super(Model, self).__init__()
                self.conv = nn.Conv2d(1, 1, 1)
                self.bn = nn.BatchNorm2d(1)
                self.relu = nn.ReLU()

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                x = self.relu(x)
                return x

        model = Model()
        qengine = torch.backends.quantized.engine
        qconfig_dict = {'': torch.quantization.get_default_qat_qconfig(qengine)}
        device = torch.device('cuda:0')
        model.to(device)

        # QAT prepare
        model = prepare_qat_fx(model, qconfig_dict)

        # ensure that running an input on CUDA works without any needed changes
        input = torch.randn(4, 1, 4, 4, device=device)
        model(input)

        # ensure all buffers and parameters are on the device we expect
        model_devices = {p.device for p in model.parameters()} | \
            {p.device for p in model.buffers()}
        self.assertEqual(len(model_devices), 1)
        model_device = next(iter(model_devices))
        self.assertEqual(model_device, device)

    @skipIfNoFBGEMM
    def test_dict_output(self):
        """ Make sure quantization runs for models with dictionary output
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                return {"output": self.conv(x["input"])}

        dict_input = {"input": torch.randn(1, 1, 1, 1)}
        m = M().eval()
        qconfig_dict = {"": default_qconfig}
        m = prepare_fx(m, qconfig_dict)
        m(dict_input)
        m = convert_fx(m)
        m(dict_input)

    @override_qengines
    def test_attention(self):
        """ Make sure quantization runs for a corner case in attention module
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv(x)
                q, k, v = x.chunk(3, dim=0)
                q = q.contiguous().view(-1, 1).transpose(0, 1)
                k = k.contiguous().view(-1, 1).transpose(0, 1)
                v = v.contiguous().view(-1, 1).transpose(0, 1)
                torch._assert(
                    k.size(1) == 1, "key size should be equal to 1"
                )
                r = torch.mm(k, v)
                return q * k + r

        tensor_input = torch.randn(3, 1, 1, 1)
        m = M().eval()
        qconfig_dict = {
            "": None,
            "object_type": [
                (nn.Conv2d, default_qconfig),
            ]
        }
        # make sure it runs
        m = prepare_fx(m, qconfig_dict)
        m(tensor_input)
        m = convert_fx(m)
        m(tensor_input)

    def _test_standalone_module(
            self,
            interface_config,
            prepare_count_check,
            standalone_prepare_count_check,
            convert_count_check,
            standalone_convert_count_check):
        """ Test standalone module with different quantized input/quantized output
        configurations
        """
        class StandaloneModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                return self.conv(x)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.standalone = StandaloneModule()

            def forward(self, x):
                x = self.conv(x)
                x = self.standalone(x)
                return x

        class RefM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        data = torch.randn(1, 1, 1, 1)
        # instantiate M and RefM and align the parameters
        original_m = M().eval()
        original_ref_m = RefM().eval()
        original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
        original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
        original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach())
        original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())

        for is_name in [True, False]:
            if is_name:
                prepare_config = {
                    "standalone_module_name": [("standalone", None, interface_config)]
                }
            else:
                prepare_config = {
                    "standalone_module_class": [(StandaloneModule, None, interface_config)]
                }

            original_m_copy = copy.deepcopy(original_m)
            original_ref_m_copy = copy.deepcopy(original_ref_m)
            qconfig_dict = {"": default_qconfig}
            # check prepared model
            m = prepare_fx(
                original_m_copy, qconfig_dict, prepare_custom_config_dict=prepare_config)
            # calibration
            m(data)
            self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
            self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check)

            # check converted/quantized model
            m = convert_fx(m)
            self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
            self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check)
            res = m(data)

            # quantize the reference model
            ref_m = prepare_fx(original_ref_m_copy, qconfig_dict)
            ref_m(data)
            ref_m = convert_fx(ref_m)
            ref_res = ref_m(data)
            self.assertEqual(res, ref_res)

    def test_standalone_module_float_interface(self):
        float_interface_config = {
            "input_quantized_idxs": [],  # float input
            "output_quantized_idxs": [],  # float output
        }
        interface_config = float_interface_config
        # observers for input and output of the first conv; observers for the
        # standalone module will be inserted in the standalone module itself
        prepare_count_check = {
            ns.call_module(torch.quantization.MinMaxObserver): 2
        }
        # for input and output of conv in the standalone module
        standalone_prepare_count_check = {
            ns.call_module(torch.quantization.MinMaxObserver): 2
        }
        convert_count_check = {
            ns.call_function(torch.quantize_per_tensor) : 1,
            ns.call_module(nnq.Conv2d) : 1,
            ns.call_method("dequantize") : 1,
        }
        standalone_convert_count_check = {
            # standalone module will take float as input and output
            # so we'll see quantize and dequantize in the module
            ns.call_function(torch.quantize_per_tensor) : 1,
            ns.call_module(nnq.Conv2d): 1,
            ns.call_method("dequantize") : 1,
        }
        self._test_standalone_module(
            interface_config,
            prepare_count_check,
            standalone_prepare_count_check,
            convert_count_check,
            standalone_convert_count_check)

    def test_standalone_module_quantized_interface(self):
        quantized_interface_config = {
            "input_quantized_idxs": [0],  # quantized input
            "output_quantized_idxs": [0],  # quantized output
        }
        interface_config = quantized_interface_config
        # observer for input and output of first conv
        prepare_count_check = {
            ns.call_module(torch.quantization.MinMaxObserver): 2
        }
        # for output of conv in the standalone module
        standalone_prepare_count_check = {
            ns.call_module(torch.quantization.MinMaxObserver): 1
        }
        convert_count_check = {
            # quantizing input for conv
            ns.call_function(torch.quantize_per_tensor) : 1,
            ns.call_module(nnq.Conv2d) : 1,
            # dequantizing output of standalone module
            ns.call_method("dequantize") : 1,
        }
        standalone_convert_count_check = {
            # quantization of input happens in parent module
            # quantization of output happens in the quantized conv module
            ns.call_function(torch.quantize_per_tensor) : 0,
            ns.call_module(nnq.Conv2d): 1,
            # dequantization for output happens in parent module
            ns.call_method("dequantize") : 0,
        }
        self._test_standalone_module(
            interface_config,
            prepare_count_check,
            standalone_prepare_count_check,
            convert_count_check,
            standalone_convert_count_check)

    @skipIfNoFBGEMM
    def test_qconfig_none(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        m = M().eval()
        qconfig_dict = {"": default_qconfig,
                        "module_name": [("conv2", None)]}
        m = prepare_fx(m, qconfig_dict)
        data = torch.randn(1, 1, 1, 1)
        m(data)
        m = convert_fx(m)
        m(data)
        # first conv is quantized, second conv is not quantized
        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method("dequantize"),
            ns.call_module(nn.Conv2d),
        ]
        self.checkGraphModuleNodes(m, expected_node_list=node_list)

    def test_qconfig_module_type(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        m = M().eval()
        qconfig_dict = {"object_type": [(torch.nn.Conv2d, default_qconfig)]}
        m = prepare_fx(m, qconfig_dict)
        data = torch.randn(1, 1, 1, 1)
        m(data)
        m = convert_fx(m)
        m(data)
        # both convs are quantized, since the qconfig is set by module type
        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method("dequantize"),
        ]
        self.checkGraphModuleNodes(m, expected_node_list=node_list)

    def test_qconfig_function(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()

            def forward(self, x, y):
                return x + y

        m = M().eval()
        qconfig_dict = {"object_type": [(operator.add, default_qconfig)]}
        m = prepare_fx(m, qconfig_dict)
        data = torch.randn(1, 1, 1, 1)
        m(data, data)
        m = convert_fx(m)
        m(data, data)
        # the add is quantized, since its qconfig is set by object type
        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.add),
            ns.call_method("dequantize"),
        ]
        self.checkGraphModuleNodes(m, expected_node_list=node_list)

    def test_qconfig_module_name_regex(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv1 = nn.Conv2d(1, 1, 1)
                self.conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        m = M().eval()
        qconfig_dict = {"module_name_regex": [("conv*", default_qconfig)]}
        m = prepare_fx(m, qconfig_dict)
        data = torch.randn(1, 1, 1, 1)
        m(data)
        m = convert_fx(m)
        m(data)
        # both convs are quantized, since both names match the regex
        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method("dequantize"),
        ]
        self.checkGraphModuleNodes(m, expected_node_list=node_list)

    def test_qconfig_precedence(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear = nn.Linear(1, 1)
                self.conv = nn.Conv2d(1, 1, 1)
                self.module_conv1 = nn.Conv2d(1, 1, 1)
                self.module_conv2 = nn.Conv2d(1, 1, 1)

            def forward(self, x):
                # global
                x = self.linear(x)
                # global + object_type --> object_type
                x = self.conv(x)
                # global + object_type + module_name_regex --> module_name_regex
                x = self.module_conv1(x)
                # global + object_type + module_name_regex + module_name --> module_name
                x = self.module_conv2(x)
                return x

        m = M().eval()
        global_qconfig = default_qconfig
        object_type_qconfig = default_dynamic_qconfig
        module_name_regex_qconfig = float16_dynamic_qconfig
        module_name_qconfig = default_qat_qconfig
        qconfig_dict = {
            "": global_qconfig,
            "object_type": [(nn.Conv2d, object_type_qconfig)],
            "module_name_regex": [("module_conv*", module_name_regex_qconfig)],
            "module_name": [("module_conv2", module_name_qconfig)]}
        m = prepare_fx(m, qconfig_dict)
        self.assertEqual(m.linear.qconfig, global_qconfig)
        self.assertEqual(m.conv.qconfig, object_type_qconfig)
        self.assertEqual(m.module_conv1.qconfig, module_name_regex_qconfig)
        self.assertEqual(m.module_conv2.qconfig, module_name_qconfig)

    def test_remove_qconfig(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.avg_pool = torch.nn.AvgPool2d(1)

            def forward(self, x):
                return self.avg_pool(x)

        m = M().eval()
        qconfig_dict = {'': default_qconfig}
        m = prepare_fx(m, qconfig_dict)
        data = torch.randn(1, 1, 1, 1)
        m(data)
        m = convert_fx(m)
        m(data)
        for name, module in m.named_modules():
            self.assertFalse(hasattr(module, 'qconfig'),
                             'qconfig is not removed for ' + name)

    def test_default_quant_after_none_qconfig(self):
        """ Make sure default quant is inserted properly"""
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = x.transpose(1, 2)
                x = self.conv2(x)
                return x

        m = M().eval()
        qconfig_dict = {
            "": default_qconfig,
            "module_name": [
                ("conv1", None)
            ]
        }
        m = prepare_fx(m, qconfig_dict)
        m = convert_fx(m)

    def test_qconfig_for_call_method(self):
        class Sub(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = x.transpose(2, 3)
                x = self.conv(x)
                return x.transpose(2, 3)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.sub = Sub()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.sub(x)
                x = self.conv2(x)
                return x.transpose(2, 3)

        qconfig_dict1 = {"": default_qconfig, "module_name": [("sub", None)]}
        # since sub is configured to have qconfig None, we should dequantize the
        # output of self.conv1 and quantize the input of self.conv2
        # dequantize after conv2 should happen after transpose since
        # it is configured with default_qconfig
        # nodes in the Sub module instance are not quantized
        node_list1 = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method("dequantize"),
            ns.call_method("transpose"),
            ns.call_module(nn.Conv2d),
            ns.call_method("transpose"),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method("transpose"),
            ns.call_method("dequantize")
        ]

        qconfig_dict2 = {"": None, "module_name": [("sub", default_qconfig)]}
        # only nodes in the Sub module instance are quantized
        # the first transpose is not quantized because the input is not quantized
        node_list2 = [
            ns.call_module(nn.Conv2d),
            ns.call_method("transpose"),
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_method("transpose"),
            ns.call_method("dequantize"),
            ns.call_module(nn.Conv2d),
            ns.call_method("transpose"),
        ]

        for qconfig_dict, node_list in [
                (qconfig_dict1, node_list1),
                (qconfig_dict2, node_list2)
        ]:
            m = M().eval()
            m = prepare_fx(m, qconfig_dict)
            m(torch.randn(2, 1, 3, 3))
            m = convert_fx(m)
            self.checkGraphModuleNodes(m, expected_node_list=node_list)
            # make sure it runs
            m(torch.randn(2, 1, 3, 3))

    def test_qconfig_for_call_func(self):
        class Linear(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w = torch.ones(5, 5)
                self.b = torch.zeros(5)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.w, self.b)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mods1 = torch.nn.Sequential(
                    Linear(),
                    Linear()
                )
                self.mods2 = Linear()

            def forward(self, x):
                x = self.mods1(x)
                x = self.mods2(x)
                return x

        model = M().eval()
        qconfig_dict = {"": default_qconfig, "module_name": [("mods2", None)]}
        m = prepare_fx(model, qconfig_dict)
        m(torch.rand(5, 5))
        m = convert_fx(m)
        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_function(torch.ops.quantized.linear),
            ns.call_method('dequantize'),
            ns.call_function(torch.nn.functional.linear)
        ]
        self.checkGraphModuleNodes(m, expected_node_list=node_list)
        m(torch.rand(5, 5))

    def test_preserve_attributes(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                return self.conv(x)

        m = M()
        m.eval()
        m.preserved_attr = 3
        prepare_custom_config_dict = {
            "preserved_attributes": ["preserved_attr"]
        }
        m = prepare_fx(m, {"": default_qconfig}, prepare_custom_config_dict)

        def assertAttrPreserved(m):
            self.assertTrue(hasattr(m, "preserved_attr"))
            self.assertEqual(m.preserved_attr, 3)

        assertAttrPreserved(m)
        convert_custom_config_dict = {
            "preserved_attributes": ["preserved_attr"]
        }
        m = convert_fx(m, convert_custom_config_dict=convert_custom_config_dict)
        assertAttrPreserved(m)

    @skipIfNoFBGEMM
    def test_qat_and_script(self):
        model = LinearModelWithSubmodule().train()
        qengine = torch.backends.quantized.engine
        qconfig_dict = {'': torch.quantization.get_default_qat_qconfig(qengine)}
        model = prepare_qat_fx(model, qconfig_dict)

        # ensure scripting works
        scripted = torch.jit.script(model)
        # run one round to make sure model runs
        x = torch.randn(5, 5)
        scripted(x)
        FileCheck().check_count('FakeQuantize = prim::GetAttr[name="', 4, exactly=True) \
                   .run(scripted.graph)

        # disable fake_quant and observer
        for epoch in range(3):
            if epoch == 1:
                scripted.apply(torch.quantization.disable_observer)
            if epoch == 2:
                scripted.apply(torch.quantization.disable_fake_quant)

        # ensure the fake_quant and observer have been disabled.
        matches = ['.fake_quant_enabled', '.observer_enabled']
        for key, v in scripted.state_dict().items():
            if any(x in key for x in matches):
                self.assertEqual(v, torch.tensor([0], dtype=torch.uint8))

        # enable them back
        scripted.apply(torch.quantization.enable_fake_quant)
        scripted.apply(torch.quantization.enable_observer)
        for key, v in scripted.state_dict().items():
            if any(x in key for x in matches):
                self.assertEqual(v, torch.tensor([1], dtype=torch.uint8))

    @skipIfNoFBGEMM
    def test_save_observer_state_dict(self):
        orig = LinearModelWithSubmodule().eval()
        model = orig
        qconfig_dict = {'': torch.quantization.get_default_qconfig('fbgemm')}
        model = prepare_fx(model, qconfig_dict)

        # run it through input
        x = torch.randn(5, 5)
        model(x)

        quant = convert_fx(model)

        # save state_dict of model
        obs_dict = torch.quantization.get_observer_state_dict(model)
        b = io.BytesIO()
        torch.save(obs_dict, b)
        b.seek(0)

        # Load the stats into new model
        model_2 = orig
        model_2 = prepare_fx(model_2, qconfig_dict)
        loaded_dict = torch.load(b)
        torch.quantization.load_observer_state_dict(model_2, loaded_dict)

        quant_2 = convert_fx(model_2)

        # Verify that loaded state dict produces same results.
        self.assertEqual(quant(x), quant_2(x))

    @skipIfNoFBGEMM
    def test_custom_module_class(self):
        class CustomModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def forward(self, x):
                return self.linear(x)

        class ObservedCustomModule(torch.nn.Module):
            def __init__(self, linear):
                super().__init__()
                self.linear = linear

            def forward(self, x):
                return self.linear(x)

            @classmethod
            def from_float(cls, float_module):
                assert hasattr(float_module, 'qconfig')
                observed = cls(float_module.linear)
                observed.qconfig = float_module.qconfig
                return observed

        class StaticQuantCustomModule(torch.nn.Module):
            def __init__(self, linear):
                super().__init__()
                self.linear = linear

            def forward(self, x):
                return self.linear(x)

            @classmethod
            def from_observed(cls, observed_module):
                assert hasattr(observed_module, 'qconfig')
                assert hasattr(observed_module, 'activation_post_process')
                observed_module.linear.activation_post_process = \
                    observed_module.activation_post_process
                quantized = cls(nnq.Linear.from_float(observed_module.linear))
                return quantized

        class DynamicQuantCustomModule(torch.nn.Module):
            def __init__(self, linear):
                super().__init__()
                self.linear = linear

            def forward(self, x):
                return self.linear(x)

            @classmethod
            def from_observed(cls, observed_module):
                assert hasattr(observed_module, 'qconfig')
                quantized = cls(nnqd.Linear.from_float(observed_module.linear))
                return quantized

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)
                self.custom = CustomModule()

            def forward(self, x):
                x = self.linear(x)
                x = self.custom(x)
                return x

        class RefM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(3, 3)
                self.linear2 = torch.nn.Linear(3, 3)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                return x

        data = torch.randn(3, 3)
        # instantiate M and RefM and align the parameters
        original_m = M().eval()
        original_ref_m = RefM().eval()
        original_ref_m.linear1.weight = torch.nn.Parameter(original_m.linear.weight.detach())
        original_ref_m.linear1.bias = torch.nn.Parameter(original_m.linear.bias.detach())
        original_ref_m.linear2.weight = torch.nn.Parameter(original_m.custom.linear.weight.detach())
        original_ref_m.linear2.bias = torch.nn.Parameter(original_m.custom.linear.bias.detach())

        test_configs = {
            "static": (default_qconfig, StaticQuantCustomModule, 3),
            "dynamic": (default_dynamic_qconfig, DynamicQuantCustomModule, 0)
        }

        for quant_type in [QuantType.DYNAMIC]:
            key = quant_type_to_str(quant_type)
            qconfig, quantized_module_class, num_observers = test_configs[key]
            qconfig_dict = {"": qconfig}
            if key == "static":
                prepare_custom_config_dict = {
                    "float_to_observed_custom_module_class": {
                        "static": {
                            CustomModule: ObservedCustomModule
                        }
                    }
                }
                convert_custom_config_dict = {
                    "observed_to_quantized_custom_module_class": {
                        "static": {
                            ObservedCustomModule: quantized_module_class
                        }
                    }
                }
            else:
                prepare_custom_config_dict = {
                    "non_traceable_module_class": [
                        CustomModule
                    ]
                }
                convert_custom_config_dict = {
                    "observed_to_quantized_custom_module_class": {
                        "dynamic": {
                            CustomModule: quantized_module_class
                        }
                    }
                }

            # check prepared model
            m = prepare_fx(
                original_m,
                qconfig_dict,
                prepare_custom_config_dict=prepare_custom_config_dict)
            # calibration
            m(data)
            # all activation observers are inserted in the top level module
            count_check = {
                ns.call_module(torch.quantization.MinMaxObserver): num_observers
            }
            self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)

            # check converted/quantized model
            m = convert_fx(
                m,
                convert_custom_config_dict=convert_custom_config_dict)
            if quant_type == QuantType.STATIC:
                count_check = {
                    ns.call_function(torch.quantize_per_tensor) : 1,
                    ns.call_module(nnq.Linear) : 1,
                    ns.call_method('dequantize') : 1,
                }
                self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
            self.assertEqual(type(m.custom), quantized_module_class)
            res = m(data)

            # quantize the reference model
            ref_m = prepare_fx(original_ref_m, qconfig_dict)
            ref_m(data)
            ref_m = convert_fx(ref_m)
            ref_res = ref_m(data)
            self.assertEqual(res, ref_res)

    @skipIfNoFBGEMM
    def test_non_traceable_module(self):
        class NonTraceable(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                for k in x.keys():
                    print(x[k])
                return x

        class NonTraceable2(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                # data dependent control flow is not traceable
                for i in x:
                    print(i)
                return x

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.m1 = NonTraceable()
                self.m2 = NonTraceable2()

            def forward(self, x):
                x = self.m1(x)
                x = self.m2(x)
                return x

        m = M().eval()
        qconfig_dict = {"": default_qconfig}
        prepare_custom_config_dict = {
            "non_traceable_module_name": [
                "m1"
            ],
            "non_traceable_module_class": [
                NonTraceable2
            ]
        }
        m = prepare_fx(
            m, qconfig_dict,
            prepare_custom_config_dict=prepare_custom_config_dict)
        node_occurrence = {
            ns.call_module(NonTraceable) : 1,
            ns.call_module(NonTraceable2) : 1,
        }
        # make sure these modules are not traced
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)

    def test_prepared_model_deepcopy(self):
        """Ensures that copy.deepcopy works correctly on a prepared model.
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self._foobar = 'foobar'
                self.foobar2 = 'foobar2'

            def forward(self, x):
                x = self.conv(x)
                return x

        m = M()
        m.eval()
        qconfig_dict = {'': torch.quantization.default_qconfig}
        prepared = prepare_fx(m, qconfig_dict)
        # calibrate
        prepared(torch.randn(4, 1, 4, 4))
        # copy
        prepared_copy = copy.deepcopy(prepared)
        # quantize, should run with no errors
        quantized = convert_fx(prepared_copy)

    def test_dequantize(self):
        r""" Test to make sure dequantize nodes are placed before
        non-quantizable nodes
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.act = torch.nn.GELU()

            def forward(self, x):
                x = self.conv(x)
                return self.act(x)

        data = torch.rand(5, 1, 3, 3, dtype=torch.float)
        for quant_type in self.static_quant_types:
            node_list = [
                ns.call_module(nnq.Conv2d),
                ns.call_method("dequantize"),
                ns.call_module(nn.GELU),
            ]
            self.checkGraphModeFxOp(
                M().eval(), (data,), quant_type, expected_node_list=node_list)

    def test_sequential(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.convs = torch.nn.Sequential(
                    torch.nn.Conv2d(1, 1, 1),
                    torch.nn.Conv2d(1, 1, 1)
                )

            def forward(self, x):
                x = self.convs(x)
                return x

        data = torch.rand(5, 1, 3, 3, dtype=torch.float)
        for quant_type in self.static_quant_types:
            node_list = [
                ns.call_module(nnq.Conv2d),
                ns.call_module(nnq.Conv2d),
            ]
            self.checkGraphModeFxOp(
                M().eval(), (data,), quant_type, expected_node_list=node_list)

    def _test_quantized_inputs_outputs(
            self, prepare_custom_config_dict, prepare_count_check,
            convert_count_check):
        """
        Test the option to have inputs and outputs of the graph quantized
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        # quantized input, quantized output
        m = M()
        qconfig_dict = {'': torch.quantization.default_qconfig}
        m.eval()
        mp = torch.quantization.quantize_fx.prepare_fx(
            m, qconfig_dict,
            prepare_custom_config_dict=prepare_custom_config_dict)
        self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check)
        mp(torch.randn(1, 1, 4, 4))
        mq = torch.quantization.quantize_fx.convert_fx(mp)
        self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check)

    def test_quantized_input_quantized_output(self):
        prepare_custom_config_dict = {
            'input_quantized_idxs': [0], 'output_quantized_idxs': [0]}
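        # both sides of the graph boundary are quantized, so no observer is
        # needed for the input, and convert inserts no quantize/dequantize
        # at the boundary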
        prepare_count_check = {
            ns.call_module(torch.quantization.MinMaxObserver): 2,
        }
        convert_count_check = {
            ns.call_function(torch.quantize_per_tensor): 0,
            ns.call_method('dequantize'): 0,
        }
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict, prepare_count_check, convert_count_check)

    def test_fp32_input_quantized_output(self):
        prepare_custom_config_dict = {
            'output_quantized_idxs': [0]}
        prepare_count_check = {
            ns.call_module(torch.quantization.MinMaxObserver): 3,
        }
        convert_count_check = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method('dequantize'): 0,
        }
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict, prepare_count_check, convert_count_check)

    def test_quantized_input_fp32_output(self):
        prepare_custom_config_dict = {
            'input_quantized_idxs': [0]}
        prepare_count_check = {
            ns.call_module(torch.quantization.MinMaxObserver): 2,
        }
        convert_count_check = {
            ns.call_function(torch.quantize_per_tensor): 0,
            ns.call_method('dequantize'): 1,
        }
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict, prepare_count_check, convert_count_check)

    def test_fp32_input_fp32_output(self):
        prepare_custom_config_dict = {}
        prepare_count_check = {
            ns.call_module(torch.quantization.MinMaxObserver): 3,
        }
        convert_count_check = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_method('dequantize'): 1,
        }
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict, prepare_count_check, convert_count_check)

    @skipIfNoFBGEMM
    def test_convtranspose_per_channel_fails_early(self):
        r"""
        Verifies that attempting to quantize a ConvTranspose module with per-channel
        weight observers fails in the prepare step, as opposed to the convert step.
        """
        m = torch.nn.Sequential(torch.nn.ConvTranspose2d(1, 1, 1))
        m.eval()
        qconfig_dict = {'': torch.quantization.get_default_qconfig('fbgemm')}
        with self.assertRaises(AssertionError) as context:
            mp = prepare_fx(m, qconfig_dict)
        self.assertTrue(
            str(context.exception) ==
            'Per channel weight observer is not supported yet for ConvTranspose{n}d.')

    @skipIfNoFBGEMM
    def test_qparams_buffers(self):
        class Linear(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w = torch.ones(5, 5)
                self.b = torch.zeros(5)

            def forward(self, x):
                return torch.nn.functional.linear(x, self.w, self.b)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.mods1 = torch.nn.Sequential(
                    Linear(),
                    Linear()
                )
                self.mods2 = Linear()

            def forward(self, x):
                x = self.mods1(x)
                x = self.mods2(x)
                return x

        model = M().eval()
        qconfig_dict = {"": default_qconfig}
        m = prepare_fx(model, qconfig_dict)
        m(torch.rand(5, 5))
        m = convert_fx(m)
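        # after convert, the scale/zero_point of each quantized linear are
        # registered as buffers (named using the module FQN), so they show up
        # in the state_dict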
        keys = m.state_dict().keys()
        scale_count = 0
        zero_point_count = 0
        for k in keys:
            if 'scale' in k:
                scale_count = scale_count + 1
            elif 'zero_point' in k:
                zero_point_count = zero_point_count + 1
        # Expect each quantized linear op to have a scale and zero point
        self.assertTrue(scale_count == 3, "Expect each quantized linear op to have a scale in state_dict")
        self.assertTrue(zero_point_count == 3, "Expect each quantized linear op to have a zero_point in state_dict")
        # ensure it runs
        m(torch.rand(5, 5))
        # ensure it is scriptable
        scripted = torch.jit.script(m)
        scripted_keys = scripted.state_dict().keys()
        self.assertTrue(scripted_keys == keys, "Expected the scripted model to preserve the state_dict")


@skipIfNoFBGEMM
class TestQuantizeFxOps(QuantizationTestCase):
    """Unit tests for individual ops
    """
    @skipIfNoFBGEMM
    def test_linear_module(self):
        class ModuleLinear(torch.nn.Module):
            def __init__(self, has_relu=False, f_relu=False):
                super(ModuleLinear, self).__init__()
                self.linear = torch.nn.Linear(30, 4).float()
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                return self.relu(self.linear(x))

        data = (torch.rand((1, 30), dtype=torch.float),)
        options = itertools.product(
            [ModuleLinear(has_relu=False)],
            self.all_quant_types)
        quantized_nodes = {
            # quant_type:
            QuantType.DYNAMIC: ns.call_module(nnqd.Linear),
            QuantType.STATIC: ns.call_module(nnq.Linear),
            # note that we are checking the final result
            QuantType.QAT: ns.call_module(nnq.Linear),
        }
        for model, quant_type in options:
            self.checkGraphModeFxOp(
                model, data, quant_type, quantized_nodes[quant_type])

        for f_relu, quant_type in itertools.product([True, False], [QuantType.STATIC, QuantType.QAT]):
            for model, quantized_node in [
                    (ModuleLinear(has_relu=True, f_relu=f_relu), ns.call_module(nniq.LinearReLU))]:
                self.checkGraphModeFxOp(model, data, quant_type, quantized_node)

    @skipIfNoFBGEMM
    def test_functional_linear(self):
        class FuncLinear(torch.nn.Module):
            def __init__(self, use_bias, has_relu, f_relu):
                super(FuncLinear, self).__init__()
                self.w = torch.randn(4, 30)
                self.b = torch.randn(4)
                self.use_bias = use_bias
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                if self.use_bias:
                    x = F.linear(x, self.w, self.b)
                else:
                    x = F.linear(x, self.w)
                x = self.relu(x)
                return x

        data = (torch.rand((1, 30), dtype=torch.float),)
        quant_type_to_prepare_expected_node_occurrence = {
            QuantType.DYNAMIC: {},
            # There should be 3 observers: after input, weight and activation.
            QuantType.STATIC: {
                ns.call_module(torch.quantization.HistogramObserver): 2,
                ns.call_module(torch.quantization.PerChannelMinMaxObserver): 1,
            },
            # There should be 3 observers: after input, weight and activation.
            QuantType.QAT: {
                ns.call_module(torch.quantization.FakeQuantize): 3,
            },
        }
        quant_type_to_qlinear_fun = {
            QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic),
            QuantType.STATIC: ns.call_function(torch.ops.quantized.linear),
            QuantType.QAT: ns.call_function(torch.ops.quantized.linear),
        }
        quant_type_to_qlinear_relu_fun = {
            # we don't have linear_relu_dynamic
            QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic),
            QuantType.STATIC: ns.call_function(torch.ops.quantized.linear_relu),
            QuantType.QAT: ns.call_function(torch.ops.quantized.linear_relu),
        }

        options = itertools.product(
            self.all_quant_types,
            (True, False),  # use_bias
            (True, False),  # has_relu
            (True, False),  # functional relu
        )
        for quant_type, use_bias, has_relu, f_relu in options:
            model = FuncLinear(use_bias, has_relu, f_relu)
            if has_relu:
                qlinear_fun = quant_type_to_qlinear_relu_fun[quant_type]
            else:
                qlinear_fun = quant_type_to_qlinear_fun[quant_type]
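            # dynamic quantization takes fp32 input and produces fp32 output,
            # so no quantize/dequantize nodes are expected in that case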
            convert_node_occurrence = {
                ns.call_function(torch.quantize_per_tensor): 1 if quant_type != QuantType.DYNAMIC else 0,
                qlinear_fun: 1,
                ns.call_method("dequantize"): 1 if quant_type != QuantType.DYNAMIC else 0
            }
            prepare_expected_node_occurrence = \
                quant_type_to_prepare_expected_node_occurrence[quant_type]
            self.checkGraphModeFxOp(
                model, data, quant_type, qlinear_fun,
                prepare_expected_node_occurrence=prepare_expected_node_occurrence,
                expected_node_occurrence=convert_node_occurrence)

    @skipIfNoFBGEMM
    def test_conv_module(self):
        conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}

        class ConvWrapper(torch.nn.Module):
            def __init__(self, dim):
                super(ConvWrapper, self).__init__()
                self.conv = conv_module[dim](3, 3, 3).float()

            def forward(self, x):
                return self.conv(x)

        options = itertools.product([1, 2, 3], self.static_quant_types)
        quantized_nodes = {
            # dim
            1: ns.call_module(nnq.Conv1d),
            2: ns.call_module(nnq.Conv2d),
            3: ns.call_module(nnq.Conv3d),
        }
        for dim, quant_type in options:
            model = self.checkGraphModeFxOp(
                ConvWrapper(dim), self.img_data_dict[dim], quant_type,
                quantized_nodes[dim])

    @skipIfNoFBGEMM
    def test_functional_conv(self):
        """ Test for functional conv and functional conv + relu
        """
        class FuncConv(torch.nn.Module):
            def __init__(self, use_bias, has_relu, f_relu):
                super().__init__()
                self.w = torch.randn(3, 3, 3, 3)
                self.b = torch.randn(3) if use_bias else None
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1
                self.use_bias = use_bias
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                x = F.conv2d(x, self.w, self.b, self.stride, self.padding, self.dilation, self.groups)
                x = self.relu(x)
                return x

        data = (torch.randn((2, 3, 4, 4), dtype=torch.float),)
        quant_type_to_prepare_expected_node_occurrence = {
            QuantType.DYNAMIC: {},
            # There should be 3 observers: after input, weight and activation.
            QuantType.STATIC: {
                ns.call_module(torch.quantization.HistogramObserver): 2,
                ns.call_module(torch.quantization.PerChannelMinMaxObserver): 1,
            },
            # There should be 3 observers: after input, weight and activation.
            QuantType.QAT: {
                ns.call_module(torch.quantization.FakeQuantize): 3,
            },
        }
        quant_type_to_qconv_fun = {
            QuantType.STATIC: ns.call_function(torch.ops.quantized.conv2d),
            QuantType.QAT: ns.call_function(torch.ops.quantized.conv2d),
        }
        quant_type_to_qconv_relu_fun = {
            QuantType.STATIC: ns.call_function(torch.ops.quantized.conv2d_relu),
            QuantType.QAT: ns.call_function(torch.ops.quantized.conv2d_relu),
        }

        options = itertools.product(
            self.static_quant_types,
            (True, False),  # use_bias
            (True, False),  # has_relu
            (True, False),  # functional relu
        )
        for quant_type, use_bias, has_relu, f_relu in options:
            model = FuncConv(use_bias, has_relu, f_relu)
            if has_relu:
                qconv_fun = quant_type_to_qconv_relu_fun[quant_type]
            else:
                qconv_fun = quant_type_to_qconv_fun[quant_type]
            convert_node_occurrence = {
                ns.call_function(torch.quantize_per_tensor): 1 if quant_type != QuantType.DYNAMIC else 0,
                qconv_fun: 1,
                ns.call_method("dequantize"): 1 if quant_type != QuantType.DYNAMIC else 0
            }
            prepare_expected_node_occurrence = \
                quant_type_to_prepare_expected_node_occurrence[quant_type]
            self.checkGraphModeFxOp(
                model, data, quant_type, qconv_fun,
                prepare_expected_node_occurrence=prepare_expected_node_occurrence,
                expected_node_occurrence=convert_node_occurrence)

    @skipIfNoFBGEMM
    def test_quantized_conv_relu(self):
        """tests for conv1d_relu/conv2d_relu/conv3d_relu"""
        conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}

        class ConvNdRelu(torch.nn.Module):
            def __init__(self, dim, inplace):
                super(ConvNdRelu, self).__init__()
                self.conv = conv_module[dim](3, 3, 3).float()
                self.relu = torch.nn.ReLU(inplace)

            def forward(self, x):
                return self.relu(self.conv(x))

        class ConvNdFunctionalRelu(torch.nn.Module):
            def __init__(self, dim):
                super(ConvNdFunctionalRelu, self).__init__()
                self.conv = conv_module[dim](3, 3, 3).float()

            def forward(self, x):
                return F.relu(self.conv(x))

        class ConvNdInplaceFunctionalRelu(torch.nn.Module):
            def __init__(self, dim):
                super(ConvNdInplaceFunctionalRelu, self).__init__()
                self.conv = conv_module[dim](3, 3, 3).float()

            def forward(self, x):
                return F.relu(self.conv(x), True)

        options = itertools.product([1, 2, 3], self.static_quant_types)
        quantized_nodes = {
            # dim
            1: ns.call_module(nniq.ConvReLU1d),
            2: ns.call_module(nniq.ConvReLU2d),
            3: ns.call_module(nniq.ConvReLU3d),
        }
        for dim, quant_type in options:
            for m in [ConvNdRelu(dim, True),
                      ConvNdRelu(dim, False),
                      ConvNdFunctionalRelu(dim),
                      ConvNdInplaceFunctionalRelu(dim)]:
                self.checkGraphModeFxOp(
                    m, self.img_data_dict[dim], quant_type,
                    quantized_nodes[dim])

    def _test_quantized_binary_op_impl(self, binary_op, ibinary_op, quantized_op):
        class Op(torch.nn.Module):
            def __init__(self, is_inplace, is_scalar):
                super(Op, self).__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
                self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
                self.is_scalar = is_scalar
                self.op = ibinary_op if is_inplace else binary_op

            def forward(self, x, y):
                x = self.conv1(x)
                y = 3 if self.is_scalar else self.conv2(y)
                # x = x + y
                x = self.op(x, y)
                # x = y + x
                x = self.op(y, x)
                return x

        # TODO: decide whether we want to quantize or not
        # in this case
        # class NonQuantizedOp(torch.nn.Module):
        #     def __init__(self, is_inplace, is_scalar):
        #         super(NonQuantizedOp, self).__init__()
        #         self.is_scalar = is_scalar
        #         self.op = ibinary_op if is_inplace else binary_op
        #     def forward(self, x, y):
        #         y = 3 if self.is_scalar else y
        #         x = self.op(x, y)
        #         return x

        data = (torch.randn(1, 1, 1, 1, dtype=torch.float),
                torch.randn(1, 1, 1, 1, dtype=torch.float))
        quantized_node = ns.call_function(quantized_op)
        options = itertools.product([True, False], [True, False])
        quant_type = QuantType.STATIC
        for is_inplace, is_scalar in options:
            self.checkGraphModeFxOp(
                Op(is_inplace, is_scalar), data, quant_type, quantized_node)

    def _test_quantized_binary_op_relu_impl(self, binary_op, ibinary_op, quantized_op):
        class OpRelu(torch.nn.Module):
            def __init__(self, is_inplace, is_functional_relu,
                         is_scalar):
                super(OpRelu, self).__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
                self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
                self.op = ibinary_op if is_inplace else binary_op
                self.is_functional_relu = is_functional_relu
                self.is_scalar = is_scalar
                self.relu = F.relu if self.is_functional_relu \
                    else torch.nn.ReLU()

            def forward(self, x, y):
                x = self.conv1(x)
                y = 3 if self.is_scalar else self.conv2(y)
                x = self.op(x, y)
                x = self.relu(x)
                x = self.op(y, x)
                x = self.relu(x)
                return x

        data = (torch.rand((1, 1, 1, 1), dtype=torch.float),
                torch.rand((1, 1, 1, 1), dtype=torch.float))
        quant_type = QuantType.STATIC
        quantized_node = ns.call_function(quantized_op)
        options = itertools.product(
            [True, False], [True, False], [True, False])
        for is_inplace_op, is_functional_relu, is_scalar in options:
            self.checkGraphModeFxOp(
                OpRelu(is_inplace_op, is_functional_relu, is_scalar),
                data, quant_type, quantized_node)

    @skipIfNoFBGEMM
    def test_quantized_add(self):
        self._test_quantized_binary_op_impl(
            operator.add, operator.iadd, torch.ops.quantized.add)

    @skipIfNoFBGEMM
    def test_quantized_mul(self):
        self._test_quantized_binary_op_impl(
            operator.mul, operator.imul, torch.ops.quantized.mul)

    @skipIfNoFBGEMM
    def test_quantized_add_relu(self):
        self._test_quantized_binary_op_relu_impl(
            operator.add, operator.iadd, torch.ops.quantized.add_relu)

    @skipIfNoFBGEMM
    def test_quantized_mul_relu(self):
        self._test_quantized_binary_op_relu_impl(
            operator.mul, operator.imul, torch.ops.quantized.mul_relu)

    # TODO(future PR): make more generic
    def _test_quantized_add_mul_qat(self, model, expected_node_occurrence):
        qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
        mp = torch.quantization.quantize_fx.prepare_qat_fx(model, qconfig_dict)
        self.checkGraphModuleNodes(
            mp, expected_node_occurrence=expected_node_occurrence)

    @skipIfNoFBGEMM
    def test_quantized_add_qat(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = torch.add(x, 1.0)
                x = self.conv1(x)
                x = torch.add(x, 1.0)
                x = torch.relu(x)
                x = self.conv2(x)
                return x

        m = M()
        expected_node_occurrence = {
            ns.call_module(torch.quantization.FakeQuantize): 4,
        }
        self._test_quantized_add_mul_qat(m, expected_node_occurrence)

    @skipIfNoFBGEMM
    def test_quantized_mul_qat(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = torch.mul(x, 1.0)
                x = self.conv1(x)
                x = torch.mul(x, 1.0)
                x = torch.relu(x)
                x = self.conv2(x)
                return x

        m = M()
        expected_node_occurrence = {
            ns.call_module(torch.quantization.FakeQuantize): 4,
        }
        self._test_quantized_add_mul_qat(m, expected_node_occurrence)

    def test_int8_input_no_unnecessary_fq(self):
        """
        If the graph inputs are quantized and the only node does not need
        an activation observer, verifies that no activation observer is
        inserted.
        """
        class M(nn.Module):
            def __init__(self, scalar):
                super().__init__()
                self.scalar = scalar
                self.add_func = torch.nn.quantized.FloatFunctional()

            def forward(self, x):
                return self.add_func.add_scalar(x, self.scalar)

        m = M(0.5)
        mp = torch.quantization.quantize_fx.prepare_qat_fx(
            m, {'': torch.quantization.get_default_qat_qconfig('fbgemm')},
            prepare_custom_config_dict={"input_quantized_idxs": [0]})
        expected_node_occurrence = {
            ns.call_module(torch.quantization.FakeQuantize): 0,
        }
        self.checkGraphModuleNodes(
            mp, expected_node_occurrence=expected_node_occurrence)

    def test_quant_output_always_observed(self):
        """
        If the output is hardcoded to be quantized, ensure that
        there is always an observer, even if the last non-output node is not
        quantizable.
        """
        qconfig_dict = {'': torch.quantization.get_default_qat_qconfig('fbgemm')}
        prepare_custom_config_dict = {'output_quantized_idxs': [0]}
        data = (torch.randn(4, 1, 4, 4),)

        # non-quantizable node, quantized output
        class M1(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.identity = torch.nn.Identity()

            def forward(self, x):
                x = self.identity(x)
                return x

        m1 = M1()
        self.checkGraphModeFxOp(
            m1, data, QuantType.QAT,
            prepare_expected_node_occurrence={
                ns.call_module(torch.quantization.FakeQuantize): 1,
            },
            expected_node_occurrence={
                ns.call_function(torch.quantize_per_tensor): 1,
            },
            prepare_custom_config_dict=prepare_custom_config_dict)

        # quantizable node, quantized output
        class M2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv(x)
                return x

        m2 = M2()
        self.checkGraphModeFxOp(
            m2, data, QuantType.QAT,
            prepare_expected_node_occurrence={
                # one for weights, one for activations
                ns.call_module(torch.quantization.FakeQuantize): 2,
            },
            expected_node_occurrence={
                ns.call_function(torch.quantize_per_tensor): 1,
            },
            prepare_custom_config_dict=prepare_custom_config_dict)

    @skipIfNoFBGEMM
    def test_quantized_cat(self):
        """ Quantization of the output of cat depends on the inputs of cat:
        we only quantize the output of cat when its inputs are quantized.
        """
        class QuantizedCat(torch.nn.Module):
            def __init__(self):
                super(QuantizedCat, self).__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                return torch.cat([x, y], 1)

        # TODO: decide whether to quantize in this case
        # class NonQuantizedCat(torch.nn.Module):
        #     def __init__(self):
        #         super(NonQuantizedCat, self).__init__()
        #     def forward(self, x, y):
        #         return torch.cat([x, y], 1)

        data = (torch.randn(1, 2, 5, 5, dtype=torch.float),
                torch.randn(1, 2, 5, 5, dtype=torch.float))
        quantized_node = ns.call_function(torch.ops.quantized.cat)
        for quant_type in self.static_quant_types:
            self.checkGraphModeFxOp(QuantizedCat(), data, quant_type, quantized_node)

    @skipIfNoFBGEMM
    def test_qbatch_norm(self):
        bn_module = {
            # TODO: quantized batchnorm 1d module is missing
            # 1 : torch.nn.BatchNorm1d,
            2 : torch.nn.BatchNorm2d,
            3 : torch.nn.BatchNorm3d,
        }

        class M(torch.nn.Module):
            def __init__(self, dim):
                super(M, self).__init__()
                self.bn = bn_module[dim](3).to(torch.float)

            def forward(self, x):
                return self.bn(x)

        options = itertools.product(self.static_quant_types, [2, 3])
        quantized_nodes = {
            # 1: ns.call_module(nnq.BatchNorm1d),
            2: ns.call_module(nnq.BatchNorm2d),
            3: ns.call_module(nnq.BatchNorm3d),
        }
        for quant_type, dim in options:
            model = self.checkGraphModeFxOp(
                M(dim), self.img_data_dict[dim], quant_type, quantized_nodes[dim])

    @skipIfNoFBGEMM
    def test_qbatch_norm_relu(self):
        bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}

        class BNRelu(torch.nn.Module):
            def __init__(self, dim, inplace):
                super(BNRelu, self).__init__()
                self.bn = bn_module[dim](3).to(torch.float)
                self.relu = torch.nn.ReLU(inplace=inplace)

            def forward(self, x):
                return self.relu(self.bn(x))

        class BNFuncRelu(torch.nn.Module):
            def __init__(self, dim):
                super(BNFuncRelu, self).__init__()
                self.bn = bn_module[dim](3).to(torch.float)

            def forward(self, x):
                return F.relu(self.bn(x), False)

        class BNFuncInplaceRelu(torch.nn.Module):
            def __init__(self, dim):
                super(BNFuncInplaceRelu, self).__init__()
                self.bn = bn_module[dim](3).to(torch.float)

            def forward(self, x):
                return F.relu(self.bn(x), True)

        options = itertools.product(self.static_quant_types, [2, 3])
        quantized_nodes = {
            2: ns.call_module(nniq.BNReLU2d),
            3: ns.call_module(nniq.BNReLU3d),
        }
        for quant_type, dim in options:
            for instance in [BNRelu(dim, True), BNRelu(dim, False),
                             BNFuncRelu(dim), BNFuncInplaceRelu(dim)]:
                self.checkGraphModeFxOp(
                    instance, self.img_data_dict[dim], quant_type,
                    quantized_nodes[dim])

def _test_activation_impl(
self, float_module, float_op, quantized_module, quantized_op):
        ''' Test for activation ops (with inplace options); float_op can be a
        torch op or a functional op
        '''
class M(torch.nn.Module):
def __init__(self, is_module, inplace):
super(M, self).__init__()
self.is_module = is_module
self.inplace = inplace
if self.is_module:
self.op = float_module(self.inplace)
else:
self.op = float_op
def forward(self, input):
if self.is_module:
return self.op(input)
else:
return self.op(input, self.inplace)
options = itertools.product([True, False], [True, False], self.static_quant_types)
quantized_nodes = {
# is_module
True: ns.call_module(quantized_module),
False: ns.call_function(quantized_op),
}
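        # these activations change values, so the quantized variants take
        # output qparams (e.g. quantized hardswish accepts an output
        # scale/zero_point) and each output gets its own observer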
for is_module, is_inplace, quant_type in options:
self.checkGraphModeFxOp(
M(is_module, is_inplace), self.img_data_2d,
quant_type, quantized_nodes[is_module])
def test_hardswish(self):
self._test_activation_impl(nn.Hardswish, F.hardswish, nnq.Hardswish, torch.ops.quantized.hardswish)
def test_elu(self):
self._test_activation_impl(nn.ELU, F.elu, nnq.ELU, torch.ops.quantized.elu)
def test_leaky_relu(self):
self._test_activation_impl(nn.LeakyReLU, F.leaky_relu, nnq.LeakyReLU, torch.ops.quantized.leaky_relu)
def _test_norm_impl(
self, float_module, float_op, op_args, data, quantized_module, quantized_op,
skip_op_arg_for_functional=False):
        ''' Test for normalization ops; float_op can be a torch op or a functional
        op, op_args is a list of positional arguments for the module/op
        '''
class M(torch.nn.Module):
def __init__(self, is_module):
super(M, self).__init__()
self.is_module = is_module
if self.is_module:
self.op = float_module(*op_args)
else:
self.op = float_op
def forward(self, input):
if self.is_module:
return self.op(input)
else:
args = [input]
if not skip_op_arg_for_functional:
args += op_args
return self.op(*args)
options = itertools.product([True, False], self.static_quant_types)
quantized_nodes = {
# is_module
True: ns.call_module(quantized_module),
False: ns.call_function(quantized_op),
}
for is_module, quant_type in options:
self.checkGraphModeFxOp(
M(is_module), data, quant_type, quantized_nodes[is_module])
def test_layer_norm(self):
data = (torch.rand((1, 2, 5, 5), dtype=torch.float),)
self._test_norm_impl(
nn.LayerNorm, F.layer_norm, [[2, 5, 5]], data, nnq.LayerNorm, torch.ops.quantized.layer_norm)
def test_instance_norm(self):
data_1d = (torch.rand((1, 4, 5), dtype=torch.float),)
data_2d = (torch.rand((1, 4, 5, 1), dtype=torch.float),)
data_3d = (torch.rand((1, 4, 5, 1, 1), dtype=torch.float),)
data_dict = {1 : data_1d, 2 : data_2d, 3 : data_3d}
instance_norm_modules = {1 : nn.InstanceNorm1d,
2 : nn.InstanceNorm2d,
3 : nn.InstanceNorm3d}
quantized_instance_norm_modules = {
1 : nnq.InstanceNorm1d,
2 : nnq.InstanceNorm2d,
3 : nnq.InstanceNorm3d
}
for dim in [1, 2, 3]:
data = data_dict[dim]
module = instance_norm_modules[dim]
quantized_module = quantized_instance_norm_modules[dim]
self._test_norm_impl(
module, F.instance_norm, [4], data,
quantized_module, torch.ops.quantized.instance_norm,
skip_op_arg_for_functional=True)
@skipIfNoFBGEMM
def test_clamp(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
self.relu6 = torch.nn.ReLU6()
self.relu6_ = torch.nn.ReLU6(True)
self.hardtanh = torch.nn.Hardtanh()
self.hardtanh_ = torch.nn.Hardtanh(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.relu6(x)
self.relu6_(x)
x = F.relu6(x)
x = torch.clamp(x, -3, 3)
x = x.clamp(-2.5, 2.5)
# x = x.clamp_(-2, 2) # Enable when quantized `clamp_` is ready
x = self.hardtanh(x)
self.hardtanh_(x)
x = F.hardtanh(x)
F.hardtanh_(x)
return x
data = (torch.rand((1, 2, 5, 5), dtype=torch.float),)
        # list of nodes that should occur in order
node_list = [
ns.call_function(torch.quantize_per_tensor),
ns.call_module(nnq.Conv2d),
ns.call_function(F.hardtanh_),
ns.call_method('dequantize')
]
for quant_type in self.static_quant_types:
m = self.checkGraphModeFxOp(
M(), data, quant_type, expected_node_list=node_list)
@skipIfNoFBGEMM
def test_general_shape_ops(self):
""" A test that checks dequantize will be swapped for
all supported general shape ops like aten::flatten
without actually checking for execution of these ops
"""
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
self.dropout = torch.nn.Dropout()
self.conv1 = torch.nn.Conv2d(3, 3, 3)
self.conv2 = torch.nn.Conv2d(3, 3, 3)
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.conv1(x)
# add_scalar
x = x + 3
# mul_scalar
x = x * 3
# add_scalar_out
x += 3
# mul_scalar_out
x *= 3
# add_scalar_relu
x = x + 3
x = F.relu(x)
# add_scalar_relu_out
x += 3
x = F.relu(x)
# mul_scalar_relu
x = x * 3
x = F.relu(x)
# mul_scalar_relu_out
x *= 3
x = F.relu(x)
x = self.maxpool1d(x)
x = self.maxpool2d(x)
x = self.maxpool3d(x)
x = torch.flatten(x)
x = torch.max(x)
x = torch.min(x)
x = x.reshape([-1])
x = x.resize_(1, 1, x.numel())
x = x.view(-1)
# prim::ListConstruct
xs = [x, x]
# prim::ListUnpack
x, y = xs
# prim::TupleConstruct
xs = (x, x)
# prim::TupleUnpack
x, y = xs
x = x.transpose(1, 2)
x = x.contiguous()
x, y = torch.chunk(x, 2)
x = F.dropout(x)
x = self.dropout(x)
x, _ = torch.sort(x)
x = x.permute(0, 2, 3, 1)
x = x.repeat_interleave(3, 1)
x = torch.repeat_interleave(x, 3, 1)
x = self.relu(x)
x = F.relu(x)
x = F.relu(x, inplace=True)
x = x.relu()
x.relu_()
x = x.squeeze(0)
x.squeeze_(0)
x = torch.squeeze(x, 0)
x = x.unsqueeze(0)
x.unsqueeze_(0)
x = torch.unsqueeze(x, 0)
x = x.detach()
x.detach_()
x = x.repeat(4, 2)
y = []
y.append(x)
z = torch.stack(y, 0)
z = [z, z]
x, _ = z
x = self.conv2(x)
return x
data = torch.rand(1, 3, 10, 10)
# This model is not executable since we just put all ops
# in the same forward
m = M().eval()
# nothing to fuse so skipping the fuse step
qconfig_dict = {'': default_qconfig}
prepared = prepare_fx(m, qconfig_dict)
# not runnable
quantized = convert_fx(prepared)
# This checks that the dequantize from the output of first conv
# is being propagated to the end, so that we don't insert extra
# observers and also successfully fused two quantized::conv2d
# patterns
# one quantize_per_tensor for input
# check exact counts of quantize and dequantize
count_check = {
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_method('dequantize') : 1
}
order_check = [
ns.call_function(torch.quantize_per_tensor),
ns.call_module(nnq.Conv2d),
ns.call_module(nnq.Conv2d),
ns.call_method('dequantize'),
]
self.checkGraphModuleNodes(
quantized,
expected_node_occurrence=count_check,
expected_node_list=order_check)
@skipIfNoFBGEMM
def test_general_value_ops(self):
""" A test that checks correct patterns are produced for
all supported general value ops like aten::avg_pool2d \
without actually checking for execution of these ops
"""
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.avg_pool1d = torch.nn.AvgPool1d(3)
self.avg_pool2d = torch.nn.AvgPool2d(3)
self.avg_pool3d = torch.nn.AvgPool3d(3)
self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
def forward(self, x):
x = self.conv(x)
x = self.avg_pool1d(x)
x = self.avg_pool2d(x)
x = self.avg_pool3d(x)
x = self.adaptive_avg_pool1d(x)
x = self.adaptive_avg_pool2d(x)
x = self.adaptive_avg_pool3d(x)
x = F.avg_pool1d(x, 3)
x = F.avg_pool2d(x, 3)
x = F.avg_pool3d(x, 3)
x = F.adaptive_avg_pool1d(x, (1))
x = F.adaptive_avg_pool2d(x, (1, 1))
x = F.adaptive_avg_pool3d(x, (1, 1, 1))
x = torch.mean(x)
x = torch.mean(x, [2, 3], False)
x = x.mean()
x = x.mean([2, 3], True)
x = F.interpolate(x, 4, mode='nearest')
x = F.interpolate(x, 4, mode='linear')
x = self.conv(x)
return x
# This model is not executable since we just put all ops
# in the same forward
m = M().eval()
# nothing to fuse so skipping the fuse step
qconfig_dict = {'': default_qconfig}
prepared = prepare_fx(m, qconfig_dict)
# not runnable
quantized = convert_fx(prepared)
# This checks that the dequantize from the output of first conv
# is being propagated to the end, so that we don't insert extra
# observers
# check exact counts of quantize and dequantize
count_check = {
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_method('dequantize') : 1
}
order_check = [
ns.call_function(torch.quantize_per_tensor),
ns.call_module(nnq.Conv2d),
ns.call_module(nnq.Conv2d),
ns.call_method('dequantize'),
]
self.checkGraphModuleNodes(
quantized,
expected_node_occurrence=count_check,
expected_node_list=order_check)
@skipIfNoFBGEMM
def test_fixed_qparams_ops(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.sigmoid = torch.nn.Sigmoid()
self.hardsigmoid = torch.nn.Hardsigmoid()
self.tanh = torch.nn.Tanh()
def forward(self, x):
x = self.conv(x)
# F.sigmoid is deprecated
x = self.sigmoid(x)
x = torch.sigmoid(x)
x = x.sigmoid()
x.sigmoid_()
x = self.hardsigmoid(x)
x = F.hardsigmoid(x)
x = F.hardsigmoid(x, inplace=True)
x = x.hardsigmoid()
x.hardsigmoid_()
x = self.tanh(x)
# F.tanh is deprecated
x = torch.tanh(x)
x = x.tanh()
x.tanh_()
x = self.conv(x)
return x
for eval_mode in [True, False]:
# This model is not executable since we just put all ops
# in the same forward
m = M()
if eval_mode:
m.eval()
qconfig = default_qconfig
prepare = prepare_fx
fq_count = 0
else:
m.train()
qconfig = default_qat_qconfig
prepare = prepare_qat_fx
fq_count = 13
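                # 13 = 4 sigmoid + 5 hardsigmoid + 4 tanh calls in M.forward;
                # FakeQuantize modules are only inserted in QAT mode, so the
                # eval (post training quant) path above expects 0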
# nothing to fuse so skipping the fuse step
qconfig_dict = {'': qconfig}
prepared = prepare(m, qconfig_dict)
# check the correct number of activation_post_process is inserted
count_check = {
ns.call_module(FixedQParamsFakeQuantize) : fq_count,
}
self.checkGraphModuleNodes(
prepared,
expected_node_occurrence=count_check)
# not runnable
quantized = convert_fx(prepared)
# This checks that the dequantize from the output of first conv
# is being propagated to the end, so that we don't insert extra
# observers
# check exact counts of quantize and dequantize
count_check = {
ns.call_function(torch.quantize_per_tensor) : 1,
ns.call_method('dequantize') : 1
}
order_check = [
ns.call_function(torch.quantize_per_tensor),
ns.call_module(nnq.Conv2d),
ns.call_module(nn.Sigmoid),
ns.call_module(nnq.Conv2d),
ns.call_method('dequantize'),
]
self.checkGraphModuleNodes(
quantized,
expected_node_occurrence=count_check,
expected_node_list=order_check)
def test_float_functional(self):
class TorchAdd(nn.Module):
"""Wrapper around torch.add so that all ops can be found at build"""
def __init__(self):
super().__init__()
self.add_func = nnq.FloatFunctional()
def forward(self, x, y):
return self.add_func.add(x, y)
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.ff1 = TorchAdd()
self.ff2 = nnq.FloatFunctional()
self.ff3 = nnq.FloatFunctional()
self.ff4 = nnq.FloatFunctional()
self.ff5 = nnq.FloatFunctional()
self.ff6 = nnq.FloatFunctional()
def forward(self, x):
x = self.ff1(x, x)
x = self.ff2.add_scalar(x, 3)
x = self.ff3.mul(x, x)
x = self.ff4.mul_scalar(x, 3)
x = self.ff5.add_relu(x, x)
x = self.ff6.cat([x])
return x
data = torch.rand(3, 3)
        # Note: the QAT test succeeded by chance; to make it actually work
        # we need to fix eager mode FloatFunctional by removing
        # activation_post_process in add_scalar and mul_scalar
for quant_type in self.static_quant_types:
m = M()
ref_m = torch.quantization.QuantWrapper(M())
is_qat = quant_type == QuantType.QAT
if is_qat:
m.train()
ref_m.train()
qconfig = default_qat_qconfig
expected_act_post_process = torch.quantization.FakeQuantize
else:
m.eval()
ref_m.eval()
qconfig = default_qconfig
expected_act_post_process = torch.quantization.MinMaxObserver
prepare_fx_function = prepare_qat_fx if is_qat else prepare_fx
qconfig_dict = {"": qconfig}
m = prepare_fx_function(m, qconfig_dict)
node_occurrence = {
ns.call_module(expected_act_post_process): 5,
ns.call_module(torch.nn.quantized.FloatFunctional): 0
}
self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
m(data)
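            # add_scalar/mul_scalar lower to the same quantized add/mul ops as
            # their tensor variants, hence two quantized.add and two
            # quantized.mul below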
node_list = [
ns.call_function(torch.quantize_per_tensor),
ns.call_function(torch.ops.quantized.add),
ns.call_function(torch.ops.quantized.add),
ns.call_function(torch.ops.quantized.mul),
ns.call_function(torch.ops.quantized.mul),
ns.call_function(torch.ops.quantized.add_relu),
ns.call_function(torch.ops.quantized.cat),
ns.call_method('dequantize')
]
m = convert_fx(m)
self.checkGraphModuleNodes(m, expected_node_list=node_list)
# make sure numerics match with eager mode
ref_m.qconfig = qconfig
prepare_function = prepare_qat if is_qat else prepare
ref_m = prepare_function(ref_m)
ref_m(data)
ref_m = convert(ref_m)
self.assertEqual(m(data), ref_m(data))
def test_embedding(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=12)
def forward(self, indices):
return self.emb(indices)
model = M().eval()
indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
configs = [
(float_qparams_weight_only_qconfig, ns.call_module(nnq.Embedding)),
(None, ns.call_module(nn.Embedding)),
(default_qconfig, ns.call_module(nn.Embedding)),
]
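        # Embedding is weight-only quantized: only float_qparams_weight_only_qconfig
        # swaps in nnq.Embedding; None and a regular static qconfig leave the
        # float nn.Embedding in place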
for qconfig, node in configs:
qconfig_dict = {"": qconfig}
m = prepare_fx(model, qconfig_dict)
self.checkGraphModuleNodes(m, expected_node_occurrence={
ns.call_module(torch.quantization.MinMaxObserver): 0
})
m = convert_fx(m)
self.checkGraphModuleNodes(m, expected_node=node)
# make sure it runs
m(indices)
def test_embedding_bag(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, include_last_offset=True)
def forward(self, indices, offsets):
return self.emb(indices, offsets)
indices = torch.tensor([9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8, 3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3])
offsets = torch.tensor([0, 19, 20, 28, 28, 32])
quantized_node = ns.call_module(nnq.EmbeddingBag)
inputs = (indices, offsets)
for dtype in [torch.quint8, torch.quint4x2]:
model = M().eval()
float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype,
qscheme=torch.per_channel_affine_float_qparams,
ch_axis=0)
float_qparams_qconfig = QConfigDynamic(activation=default_placeholder_observer,
weight=float_qparams_observer)
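            # only the weight is quantized (per-channel with float qparams);
            # the placeholder observer leaves activations untouched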
self.checkGraphModeFxOp(
model,
inputs,
QuantType.DYNAMIC,
quantized_node,
custom_qconfig=float_qparams_qconfig
)
        # check that both None and a regular static qconfig leave EmbeddingBag unquantized
        for qconfig in [None, default_qconfig]:
            qconfig_dict = {"": qconfig}
            m = M().eval()
            m = prepare_fx(m, qconfig_dict)
self.checkGraphModuleNodes(m, expected_node_occurrence={
ns.call_module(torch.quantization.MinMaxObserver): 0
})
m = convert_fx(m)
self.checkGraphModuleNodes(m, expected_node=ns.call_module(nn.EmbeddingBag))
# make sure it runs
m(*inputs)
def _test_rnn_impl(self, qconfigs, M, module_type_strs, module_types, sample_input):
options = itertools.product(qconfigs, module_type_strs)
for qconfig, module_type_str in options:
model_eager = M(module_type_str).eval()
model_graph = copy.deepcopy(model_eager)
            if torch.backends.quantized.engine == 'qnnpack' and \
                    qconfig is float16_dynamic_qconfig:
                # fp16 dynamic quant is not supported for qnnpack
                continue
eager_qconfig_dict = {x : qconfig for x in module_types}
model_eager = quantize_dynamic(model_eager, qconfig_spec=eager_qconfig_dict)
graph_qconfig_dict = {
"object_type": [
(x, qconfig) for x in module_types
]
}
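            # "object_type" entries apply a qconfig per module type, mirroring
            # the eager mode qconfig_spec dict above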
model_graph = prepare_fx(model_graph, graph_qconfig_dict)
model_graph = convert_fx(model_graph)
self.assertEqual(model_eager(sample_input), model_graph(sample_input))
self.checkScriptable(model_graph, [[sample_input]], True)
def test_rnn_cell(self):
qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig]
module_type_strs = ['LSTMCell', 'GRUCell', 'RNNTanh', 'RNNReLU']
module_types = [torch.nn.LSTMCell, torch.nn.GRUCell, torch.nn.RNNCell]
sample_input = torch.tensor([[100, -155],
[-155, 100],
[100, -155]], dtype=torch.float)
self._test_rnn_impl(qconfigs, RNNCellDynamicModel, module_type_strs, module_types, sample_input)
def test_rnn(self):
qconfigs = [per_channel_dynamic_qconfig, default_dynamic_qconfig, float16_dynamic_qconfig]
module_type_strs = ['LSTM']
module_types = [torch.nn.LSTM]
niter = 10
sample_input = torch.tensor([[100, -155],
[-155, 100],
[100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
self._test_rnn_impl(qconfigs, RNNDynamicModel, module_type_strs, module_types, sample_input)
def _test_conv_transpose_impl(
self, float_cls: Callable, q_cls: Callable, data: torch.Tensor):
with override_quantized_engine('qnnpack'):
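            # quantized ConvTranspose is only supported by the qnnpack engine,
            # hence the engine override (see also the skipUnless decorators below)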
# Create fp32 versions of FX and Eager models
m1 = torch.nn.Sequential(float_cls(1, 1, 1))
m2 = torch.nn.Sequential(float_cls(1, 1, 1))
m2.load_state_dict(m1.state_dict())
m2 = torch.quantization.QuantWrapper(m2)
# FX graph
q_result1 = self.checkGraphModeFxOp(
m1, (data,), QuantType.STATIC,
expected_node_occurrence={
ns.call_module(q_cls): 1,
})
# Eager
m2.qconfig = get_default_qconfig(torch.backends.quantized.engine)
m2.eval()
m2p = torch.quantization.prepare(m2)
m2p(data)
m2q = torch.quantization.convert(m2p)
q_result2 = m2q(data)
# verify results match
self.assertTrue(torch.allclose(q_result1, q_result2))
    @unittest.skipUnless('qnnpack' in supported_qengines,
                         "This PyTorch build has not been built with or does not support QNNPACK")
def test_conv_transpose_1d(self):
self._test_conv_transpose_impl(
torch.nn.ConvTranspose1d, nnq.ConvTranspose1d, torch.randn(4, 1, 4))
    @unittest.skipUnless('qnnpack' in supported_qengines,
                         "This PyTorch build has not been built with or does not support QNNPACK")
def test_conv_transpose_2d(self):
self._test_conv_transpose_impl(
torch.nn.ConvTranspose2d, nnq.ConvTranspose2d, torch.randn(4, 1, 4, 4))
class TestQuantizeFxModels(QuantizationTestCase):
def _test_model_impl(
self, mode, name, model, eager_quantizable_model,
check_with_eager=True,
diff_of_quant=None,
diff_from_eager=None):
if diff_of_quant is None or diff_from_eager is None:
diff_of_quant = {}
diff_from_eager = {}
if mode not in diff_of_quant or mode not in diff_from_eager:
diff_of_quant[mode] = {}
diff_from_eager[mode] = {}
input_tensor = torch.rand(1, 3, 224, 224)
input_tensor_inception = torch.rand(1, 3, 299, 299)
output_value = torch.randint(0, 1, (1,))
# print('quantizing:', name, ' mode:', mode)
if name == 'inception_v3':
input_value = input_tensor_inception
else:
input_value = input_tensor
qconfig = default_qconfig if mode == 'static' else default_qat_qconfig
qconfig_dict = {'': qconfig}
# print('graph module:', graph_module.src)
script = torch.jit.script(model)
        # make sure graph module and script module are both runnable
original_out = model(input_value)
is_not_tuple_out = not isinstance(original_out, tuple)
script_out = script(input_value)
# set to train just before quantization
prepare_fx_fn = prepare_fx
if mode != 'static':
model.train()
prepare_fx_fn = prepare_qat_fx
prepared = prepare_fx_fn(model, qconfig_dict)
        if mode == 'ddp':
            # assumption: run DDP with one process per available GPU
            world_size = torch.cuda.device_count()
            mp.spawn(run_ddp,
                     args=(world_size, prepared),
                     nprocs=world_size,
                     join=True)
elif mode == 'qat':
assert prepared.training, 'prepared must be in training mode for qat'
optimizer = torch.optim.SGD(prepared.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()
train_one_epoch(prepared, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1)
else:
for i in range(10):
prepared(input_value)
# print('after observation root:', prepared.root)
qgraph = convert_fx(prepared)
# print('after quantization root:', qgraph.root)
# print('after quantization code:', qgraph.src)
qgraph.eval()
qgraph_script = torch.jit.script(qgraph)
# print('quantized and scripted:', qgraph_script.graph)
qgraph_out = qgraph(input_value)
        qgraph_script_out = qgraph_script(input_value)
        if is_not_tuple_out:
            diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max()
            assert torch.allclose(qgraph_out, qgraph_script_out), \
                'graph and scripted graph outputs should match'
else:
print('tuple output')
if eager_quantizable_model is not None:
# comparing to eager mode quantization
qeager = eager_quantizable_model
ref_out = qeager(input_value)
qeager.qconfig = qconfig
if mode == 'static':
qeager.fuse_model()
prepare(qeager, inplace=True)
else:
qeager.train()
qeager.fuse_model()
prepare_qat(qeager, inplace=True)
# calibration
if mode == 'ddp':
mp.spawn(run_ddp,
args=(world_size, qeager),
nprocs=world_size,
join=True)
elif mode == 'qat':
assert qeager.training, 'qeager should be in training mode for qat'
optimizer = torch.optim.SGD(qeager.parameters(), lr=0.0001)
train_one_epoch(qeager, criterion, optimizer, [(input_value, output_value)], torch.device('cpu'), 1)
else:
for i in range(10):
qeager(input_value)
# print('ref after observation:', qeager)
convert(qeager, inplace=True)
qeager.eval()
# print('ref after quantization:', qeager)
qeager_out = qeager(input_value)
qeager_script = torch.jit.script(qeager)
qscript_out = qeager_script(input_value)
if is_not_tuple_out:
diff_from_eager[mode][name] = (qeager_out - qgraph_out).abs().max()
if check_with_eager:
self.assertEqual(diff_from_eager[mode][name], 0,
'Result of graph mode quantization and ' +
'eager mode quantization on model: ' + name +
' should match. Mode: ' + mode +
' diff:' + str(diff_from_eager[mode][name]))
def _test_building_block(self, quant_type, BB):
eager = BB().float()
graph = copy.deepcopy(eager)
if quant_type == QuantType.STATIC:
qconfig = default_qconfig
eager_prepare = prepare
graph_prepare = prepare_fx
eager.eval()
graph.eval()
calibrate_or_train = test_only_eval_fn
data = self.img_data_2d
else:
assert quant_type == QuantType.QAT
qconfig = default_qat_qconfig
eager_prepare = prepare_qat
graph_prepare = prepare_qat_fx
eager.train()
graph.train()
calibrate_or_train = test_only_train_fn
data = self.img_data_2d_train
if hasattr(eager, "fuse_model"):
eager.fuse_model()
eager = QuantWrapper(eager)
eager.qconfig = qconfig
eager = eager_prepare(eager)
qconfig_dict = {"": qconfig}
graph = graph_prepare(graph, qconfig_dict)
eager_out = eager(data[0][0])
graph_out = graph(data[0][0])
self.assertEqual(eager_out, graph_out)
calibrate_or_train(eager, data)
calibrate_or_train(graph, data)
eager = convert(eager)
graph = convert_fx(graph)
eager_out = eager(data[0][0])
graph_out = graph(data[0][0])
self.assertEqual(eager_out, graph_out)
@override_qengines
def test_resnet_base(self):
models = [ResNetBase]
options = itertools.product(self.static_quant_types, models)
for quant_type, M in options:
self._test_building_block(quant_type, M)
@skip_if_no_torchvision
@skipIfNoFBGEMM
@unittest.skip("skip for now since tbb failed")
def test_torchvision(self):
from torchvision import models
from torchvision.models import quantization as quantized_models
def get_available_classification_models(models):
return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
model_list = get_available_classification_models(models)
quantized_model_list = get_available_classification_models(quantized_models)
no_pretrained_model = set(['shufflenet_v2_x0_5', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'])
quantized_model_list = set(quantized_model_list) - no_pretrained_model
# test eager and graph consistency
model_list = quantized_model_list
# inception_v3 is not symbolically traceable: https://github.com/pytorch/pytorch/issues/48813
model_list = set(model_list) - {'inception_v3'}
        # mobilenet: dropout error, RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'QUInt8'
        # inception_v3: looks like there is some problem with AuxLogits
quantized_not_working = [('qat', 'inception_v3'),
('static', 'inception_v3')]
fx_eager_not_matching = ['googlenet', # because _transform_input is not quantized in eager
'mobilenet_v2'] # because relu6 is replaced as relu in mobilenetv2
diff_of_quant = {}
diff_from_eager = {}
modes = ['static', 'qat']
options = itertools.product(modes, model_list)
for mode, name in options:
pretrained = name in quantized_model_list # load pretrained model to compare with quantized model
if name in quantized_model_list:
if (mode, name) in quantized_not_working:
eager_quantizable_model = None
else:
eager_quantizable_model = quantized_models.__dict__[name](pretrained=True, quantize=False).eval().float()
# compare with eager mode quantized model when it is available
pretrained = eager_quantizable_model is not None
model = models.__dict__[name](pretrained=pretrained).eval().float()
check_with_eager = name not in fx_eager_not_matching
self._test_model_impl(
mode, name, model, eager_quantizable_model,
check_with_eager,
diff_of_quant, diff_from_eager)
def print_diffs(diffs):
for mode, diffs_for_mode in diffs.items():
print('mode:', mode)
for name, diff in diffs_for_mode.items():
print(name, ':', diff)
# print('differences between float and quantized')
# print_diffs(diff_of_quant)
# print('----------------------')
# print('differences between graph mode and eager mode')
# print_diffs(diff_from_eager)
# print('----------------------')
@skip_if_no_torchvision
@skip_if_not_multigpu
@skipIfNoFBGEMM
def test_resnet18_ddp(self):
from torchvision import models
from torchvision.models import quantization as quantized_models
        eager_quantizable_model = quantized_models.__dict__['resnet18'](pretrained=True, quantize=False).eval().float()
        model = models.__dict__['resnet18'](pretrained=True).eval().float()
self._test_model_impl(
'ddp', 'resnet18', model, eager_quantizable_model)