pytorch/benchmarks/operator_benchmark/pt/quantization_test.py
zaf 432f037498 [quant][ao_migration] `torch.nn.quantized.modules` → `torch.ao.nn.quantized.modules` (#78713)
Context: To avoid cluttering the `torch.nn` namespace,
the quantized modules namespace is moved to `torch.ao.nn`.

The list of the `nn.quantized` files that are being migrated:

- [ ] `torch.nn.quantized` → `torch.ao.nn.quantized`
    - [X] `torch.nn.quantized.functional` → `torch.ao.nn.quantized.functional`
    - [X] [Current PR] `torch.nn.quantized.modules` → `torch.ao.nn.quantized.modules`
    - [ ] `torch.nn.quantized.dynamic` → `torch.ao.nn.quantized.dynamic`
    - [ ] `torch.nn.quantized._reference` → `torch.ao.nn.quantized._reference`
- [ ] `torch.nn.quantizable` → `torch.ao.nn.quantizable`
- [ ] `torch.nn.qat` → `torch.ao.nn.qat`
    - [ ] `torch.nn.qat.modules` → `torch.ao.nn.qat.modules`
    - [ ] `torch.nn.qat.dynamic` → `torch.ao.nn.qat.dynamic`
- [ ] `torch.nn.intrinsic` → `torch.ao.nn.intrinsic`
    - [ ] `torch.nn.intrinsic.modules` → `torch.ao.nn.intrinsic.modules`
    - [ ] `torch.nn.intrinsic.qat` → `torch.ao.nn.intrinsic.qat`
    - [ ] `torch.nn.intrinsic.quantized` → `torch.ao.nn.intrinsic.quantized`
        - [ ] `torch.nn.intrinsic.quantized.modules` → `torch.ao.nn.intrinsic.quantized.modules`
        - [ ] `torch.nn.intrinsic.quantized.dynamic` → `torch.ao.nn.intrinsic.quantized.dynamic`

The majority of the files are simply moved to the new location.
However, the following files need to be double-checked:

- Documentation @vkuzo
  - docs/source/conf.py
  - docs/source/quantization.rst
- [quantize_fx](torch/ao/quantization/quantize_fx.py) @jerryzh168
- [common test routine](test/quantization/ao_migration/common.py) @HDCharles
- JIT stuff @jamesr66a
  - torch/csrc/jit/passes/hoist_conv_packed_params.cpp
  - torch/csrc/jit/passes/quantization/helper.h
  - torch/csrc/jit/serialization/import_source.cpp

Differential Revision: [D36860145](https://our.internmc.facebook.com/intern/diff/D36860145/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78713
Approved by: https://github.com/jerryzh168
2022-08-22 01:38:55 +00:00

346 lines
11 KiB
Python

import operator_benchmark as op_bench
import torch
import torch.ao.nn.quantized as nnq
import torch.ao.quantization as tq
import torch.nn as nn
"""Microbenchmarks for general quantization operations."""
# The `mode` attribute selects the direction being timed:
# 'Q' benchmarks quantization, otherwise dequantization is benchmarked.
quantize_configs_short_dict = dict(
    attr_names=['C', 'M', 'N', 'dtype', 'mode'],
    attrs=[
        [3, 512, 512, torch.quint8, 'Q'],
        [3, 512, 512, torch.quint8, 'D'],
    ],
    tags=['short'],
)

quantize_configs_long_dict = dict(
    # C is also used as a channel dimension by the per-channel benchmarks,
    # so single-channel shapes are deliberately left out.
    C=[3, 5, 8],
    M=[256, 1024],
    N=[256, 1024],
    dtype=[torch.quint8, torch.qint8, torch.qint32],
    mode=['D', 'Q'],
    tags=['long'],
)

quantize_per_tensor_configs_short = op_bench.config_list(**quantize_configs_short_dict)
quantize_per_tensor_configs_long = op_bench.cross_product_configs(**quantize_configs_long_dict)
class QuantizePerTensorBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks both quantization and dequantization."""

    def init(self, C, M, N, dtype, mode):
        """Set up one per-tensor (de)quantization benchmark.

        Args:
            C, M, N: shape of the random float input tensor.
            dtype: quantized dtype handed to ``nnq.Quantize``.
            mode: ``'Q'`` times ``nnq.Quantize``; ``'D'`` times
                ``nnq.DeQuantize`` on an input quantized during setup.

        Raises:
            ValueError: if ``mode`` is neither ``'Q'`` nor ``'D'``.
        """
        # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
        if mode not in ('Q', 'D'):
            raise ValueError("mode must be 'Q' or 'D', got {!r}".format(mode))
        self.input = torch.rand(C, M, N)
        self.dtype = dtype
        self.op = nnq.Quantize(scale=1.0, zero_point=0, dtype=dtype)
        self.set_module_name('QuantizePerTensor')
        if mode == 'D':
            # Quantize once during setup so that only dequantization is timed.
            self.input = self.op(self.input)
            self.op = nnq.DeQuantize()
            self.set_module_name('DequantizePerTensor')
        self.inputs = {
            "input": self.input
        }

    def forward(self, input):
        """Apply the configured quantize or dequantize module to ``input``."""
        return self.op(input)
op_bench.generate_pt_test(
    quantize_per_tensor_configs_short + quantize_per_tensor_configs_long,
    QuantizePerTensorBenchmark,
)

# === Per Channel quantization ===

# Same shapes/dtypes/modes as the per-tensor configs, plus the channel axis:
# only axis 0 for the short run, every axis for the long run.
quantize_per_channel_configs_short = op_bench.config_list(
    cross_product_configs=dict(axis=(0,)),
    **quantize_configs_short_dict
)
quantize_per_channel_configs_long = op_bench.cross_product_configs(
    axis=(0, 1, 2),
    **quantize_configs_long_dict
)
class QuantizePerChannelBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks both quantization and dequantization."""

    def init(self, C, M, N, dtype, axis, mode):
        """Set up one per-channel (de)quantization benchmark.

        Args:
            C, M, N: shape of the random float input tensor.
            dtype: quantized dtype handed to ``torch.quantize_per_channel``.
            axis: input dimension (index into ``(C, M, N)``) holding the channels.
            mode: ``'Q'`` times ``torch.quantize_per_channel``; ``'D'`` times
                ``dequantize()`` on an input quantized during setup.

        Raises:
            ValueError: if ``mode`` is neither ``'Q'`` nor ``'D'``.
        """
        # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
        if mode not in ('Q', 'D'):
            raise ValueError("mode must be 'Q' or 'D', got {!r}".format(mode))
        self.input = torch.rand(C, M, N)
        self.op = torch.quantize_per_channel
        channel_len = (C, M, N)[axis]
        # Unit scales and zero offsets, one per channel along ``axis``.
        self.kwargs = {
            'scales': torch.tensor([1.0] * channel_len),
            'zero_points': torch.tensor([0] * channel_len),
            'dtype': dtype,
            'axis': axis
        }
        self.set_module_name('QuantizePerChannel')
        if mode == 'D':
            # Quantize once during setup so that only dequantization is timed.
            self.input = self.op(self.input, **self.kwargs)

            def dequant(input, scales, zero_points, axis: int, dtype: int):
                # Mirrors the quantize op's signature so forward() is shared;
                # the qparams are unused because dequantize() reads the ones
                # stored on the quantized tensor.
                return input.dequantize()
            self.op = dequant
            self.set_module_name('DequantizePerChannel')
        # Reuse the qparam tensors built above rather than allocating an
        # identical second set.
        self.inputs = {
            "input": self.input,
            'scales': self.kwargs['scales'],
            'zero_points': self.kwargs['zero_points'],
            'axis': axis,
            'dtype': dtype
        }

    def forward(self, input, scales, zero_points, axis: int, dtype: int):
        """Run the configured per-channel quantize or dequantize op."""
        # NOTE(review): ``dtype: int`` looks like a TorchScript-compatible
        # spelling of a torch dtype — confirm before changing the annotation.
        return self.op(input, scales=scales, zero_points=zero_points, axis=axis, dtype=dtype)
op_bench.generate_pt_test(
    quantize_per_channel_configs_short + quantize_per_channel_configs_long,
    QuantizePerChannelBenchmark,
)

# === Fake Quantization ===
# Names of the generated op benchmarks begin with 'learnable_kernel' or
# 'original_kernel', e.g.
# 'original_kernel_nbits8_cpu_N1_C1_H256_W256_zero_point_dtypetorch.int32_bwdall'.
fake_quantize_configs_short_dict = dict(
    attr_names=['N', 'C', 'H', 'W', 'zero_point_dtype'],
    attrs=[
        [1, 3, 512, 512, torch.int32],
    ],
    tags=['short'],
)

fake_quantize_configs_long_dict = dict(
    N=[1],
    C=[1, 3, 8, 32],
    H=[256, 1024],
    W=[256, 1024],
    zero_point_dtype=[torch.int32],
    tags=['long'],
)

# Both config sets are crossed with the CPU and CUDA devices.
fake_quantize_configs_short = op_bench.config_list(
    cross_product_configs=dict(device=('cpu', 'cuda')),
    **fake_quantize_configs_short_dict
)
fake_quantize_configs_long = op_bench.cross_product_configs(
    device=('cpu', 'cuda'),
    **fake_quantize_configs_long_dict
)
class FakeQuantizeBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks fake quantization with default parameters."""

    def init(self, N, C, H, W, zero_point_dtype, device):
        # zero_point_dtype comes from the shared fake-quantize configs but is
        # not used: the module is built with FakeQuantize()'s defaults.
        sample = torch.rand(N, C, H, W).to(device)
        self.inputs = dict(input=sample)
        self.op = tq.FakeQuantize().to(device)
        self.set_module_name('FakeQuantize')

    def forward(self, input):
        # Run the FakeQuantize module on the prepared input.
        return self.op(input)
# Register the FakeQuantize module benchmark over the short and long configs.
op_bench.generate_pt_test(
    fake_quantize_configs_short + fake_quantize_configs_long,
    FakeQuantizeBenchmark)
# The ``op_name`` attribute of the op lists below names the kernel flavour:
# 'learnable_kernel' is the C++ kernel that can backpropagate into
# scale and zero point;
# 'original_kernel' is the original fake-quantize C++ kernel.
def fakeQuantizePerTensorLearnableKernel(
    input, scale, zero_point,
    quant_min: int, quant_max: int
):
    """Fake-quantize ``input`` with the learnable per-tensor kernel.

    ``scale`` and ``zero_point`` are single-element tensors that are forwarded
    unchanged, together with the ``[quant_min, quant_max]`` clamp range.
    """
    fake_quantized = torch._fake_quantize_learnable_per_tensor_affine(
        input, scale, zero_point, quant_min, quant_max
    )
    return fake_quantized
def fakeQuantizePerTensorOriginalKernel(
    input, scale, zero_point,
    quant_min: int, quant_max: int
):
    """Fake-quantize ``input`` with the original per-tensor kernel.

    ``scale`` and ``zero_point`` exist only so the signature matches the
    learnable kernel; they are NOT used. The scalar-qparam form of the op is
    always called with scale 1.0 and zero point 0, which equal the values
    the per-tensor benchmark puts in its scale/zero-point tensors.
    """
    fixed_scale: float = 1.0
    fixed_zero_point: int = 0
    return torch.fake_quantize_per_tensor_affine(
        input, fixed_scale, fixed_zero_point, quant_min, quant_max
    )
fake_quantize_per_tensor_ops = op_bench.op_list(
    attrs=(
        ('learnable_kernel', fakeQuantizePerTensorLearnableKernel),
        ('original_kernel', fakeQuantizePerTensorOriginalKernel),
    ),
    attr_names=('op_name', 'op_func'),
)

# Operator configs: the fake-quantize shapes crossed with bit width and device.
fake_quantize_operator_configs_short = op_bench.config_list(
    cross_product_configs=dict(nbits=(4, 8), device=('cpu', 'cuda')),
    **fake_quantize_configs_short_dict
)
fake_quantize_operator_configs_long = op_bench.cross_product_configs(
    nbits=(4, 8),
    device=('cpu', 'cuda'),
    **fake_quantize_configs_long_dict
)

# TODO(future PR): merge the floating-point zero_point configs into the ones
# above once every fakeQuant operator and device supports them; see
# https://github.com/pytorch/pytorch/issues/61866.
fake_quantize_configs_long_dict_float_zero_point = dict(
    fake_quantize_configs_long_dict,
    zero_point_dtype=[torch.float32, torch.half],
)
fake_quantize_operator_configs_long_float_zero_point = op_bench.cross_product_configs(
    nbits=(8,),
    device=('cpu', 'cuda'),
    **fake_quantize_configs_long_dict_float_zero_point
)
class FakeQuantizePerTensorBaseOpBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks the per-tensor fake quantize operators (learnable and original kernels)."""

    def init(self, N, C, H, W, zero_point_dtype, nbits, device, op_func):
        """Build the input and qparams for one per-tensor fake-quantize run.

        Args:
            N, C, H, W: shape of the random float input.
            zero_point_dtype: supplied by the shared configs; not applied here
                (the zero point is always a float tensor, see below).
            nbits: bit width; the clamp range is ``[0, 2**nbits - 1]``.
            device: device all tensors are created on.
            op_func: kernel wrapper invoked by ``forward``.
        """
        self.quant_min = 0
        self.quant_max = 2 ** nbits - 1
        self.quant_range = 2 ** nbits
        # auto_set() (from op_bench) supplies each requires_grad flag. It is
        # called once per tensor in a fixed order; presumably that order matters
        # for the gradient benchmarks — TODO confirm before reordering.
        self.input = nn.Parameter(torch.rand(N, C, H, W, dtype=torch.float, device=device), requires_grad=self.auto_set())
        self.scale = nn.Parameter(torch.tensor([1.]).to(device), requires_grad=self.auto_set())
        # Always a float tensor. The per-tensor op list holds only the learnable
        # kernel and the original kernel, which ignores its zero_point argument.
        # Casting to an integer zero_point_dtype would also break gradient runs,
        # since integer tensors cannot be grad-requiring Parameters.
        self.zero_point = nn.Parameter(torch.tensor([0.]).to(device), requires_grad=self.auto_set())
        self.inputs = {
            "input": self.input,
            "scale": self.scale,
            "zero_point": self.zero_point,
            "quant_min": self.quant_min,
            "quant_max": self.quant_max,
        }
        self.op_func = op_func

    def forward(
        self, input, scale, zero_point,
        quant_min: int, quant_max: int
    ):
        """Apply ``op_func`` with the prepared qparams and clamp range."""
        return self.op_func(input, scale, zero_point, quant_min, quant_max)
# Register the per-tensor kernels over the short and long operator configs:
# the first call registers the plain op benchmarks; going by its name, the
# second registers the gradient (backward) variants.
op_bench.generate_pt_tests_from_op_list(
    fake_quantize_per_tensor_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerTensorBaseOpBenchmark
)
op_bench.generate_pt_gradient_tests_from_op_list(
    fake_quantize_per_tensor_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerTensorBaseOpBenchmark
)
def fakeQuantizePerChannelLearnableKernel(
    input, scale, zero_point, axis: int,
    quant_min: int, quant_max: int
):
    """Fake-quantize ``input`` along ``axis`` with the learnable per-channel kernel.

    ``scale`` and ``zero_point`` hold one entry per channel of ``input`` along
    ``axis``; they and the clamp range are forwarded unchanged.
    """
    fake_quantized = torch._fake_quantize_learnable_per_channel_affine(
        input, scale, zero_point, axis, quant_min, quant_max
    )
    return fake_quantized
def fakeQuantizePerChannelOriginalKernel(
    input, scale, zero_point, axis: int,
    quant_min: int, quant_max: int
):
    """Fake-quantize ``input`` along ``axis`` with the original per-channel kernel.

    Unlike the per-tensor original kernel, this one does use the supplied
    per-channel ``scale`` and ``zero_point`` tensors.
    """
    fake_quantized = torch.fake_quantize_per_channel_affine(
        input, scale, zero_point, axis, quant_min, quant_max
    )
    return fake_quantized
# Both per-channel kernels.
fake_quantize_per_channel_ops = op_bench.op_list(
    attr_names=('op_name', 'op_func'),
    attrs=(
        ('learnable_kernel', fakeQuantizePerChannelLearnableKernel),
        ('original_kernel', fakeQuantizePerChannelOriginalKernel),
    ),
)

# Only the original per-channel kernel; paired with the float/half
# zero-point configs.
fake_quantize_per_channel_float_zero_point_ops = op_bench.op_list(
    attr_names=('op_name', 'op_func'),
    attrs=(
        ('original_kernel', fakeQuantizePerChannelOriginalKernel),
    ),
)
class FakeQuantizePerChannelOpBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks the per-channel fake quantize operators (learnable and original kernels)."""

    def init(self, N, C, H, W, zero_point_dtype, nbits, device, op_func):
        """Build the input and per-channel qparams for one fake-quantize run.

        Args:
            N, C, H, W: shape of the random float input; channels are dim 1 (C).
            zero_point_dtype: dtype of the zero points for the original kernel;
                the learnable kernel always gets float32 zero points.
            nbits: bit width; the clamp range is ``[0, 2**nbits - 1]``.
            device: device all tensors are created on.
            op_func: kernel wrapper invoked by ``forward``.
        """
        self.quant_min = 0
        self.quant_max = 2 ** nbits - 1
        self.quant_range = 2 ** nbits
        # Axis is chosen with respect to the number of channels: C.
        self.axis = 1
        # requires_grad must be passed to nn.Parameter itself. Parameter defaults
        # to requires_grad=True, so a flag given only to torch.rand was overridden
        # and the input required grad even when auto_set() returned False.
        # auto_set() is still called exactly once here, in the same order.
        self.input = nn.Parameter(
            torch.rand(N, C, H, W, dtype=torch.float, device=device),
            requires_grad=self.auto_set(),
        )
        if op_func.__name__ == 'fakeQuantizePerChannelOriginalKernel':
            # Fixed, non-differentiable qparams using the configured zero-point dtype.
            self.scale = torch.ones(C, device=device, dtype=torch.float32, requires_grad=False)
            self.zero_point = torch.zeros(C, device=device, dtype=zero_point_dtype, requires_grad=False)
        else:
            # Learnable kernel: float32 Parameters whose requires_grad comes from auto_set().
            self.scale = nn.Parameter(torch.ones(C, device=device, dtype=torch.float32), requires_grad=self.auto_set())
            self.zero_point = nn.Parameter(torch.zeros(C, device=device, dtype=torch.float32), requires_grad=self.auto_set())
        self.inputs = {
            "input": self.input,
            "scale": self.scale,
            "zero_point": self.zero_point,
            "axis": self.axis,
            "quant_min": self.quant_min,
            "quant_max": self.quant_max,
        }
        self.op_func = op_func

    def forward(
        self, input, scale, zero_point,
        axis: int, quant_min: int, quant_max: int
    ):
        """Apply ``op_func`` with the prepared per-channel qparams."""
        return self.op_func(input, scale, zero_point, axis, quant_min, quant_max)
# Register both per-channel kernels over the short and long operator configs.
op_bench.generate_pt_tests_from_op_list(
    fake_quantize_per_channel_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerChannelOpBenchmark
)
# Float/half zero points: only the original kernel, and no gradient variant
# is registered for these configs.
op_bench.generate_pt_tests_from_op_list(
    fake_quantize_per_channel_float_zero_point_ops,
    fake_quantize_operator_configs_long_float_zero_point,
    FakeQuantizePerChannelOpBenchmark
)
# Gradient variants (going by the function name) of both per-channel kernels.
op_bench.generate_pt_gradient_tests_from_op_list(
    fake_quantize_per_channel_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerChannelOpBenchmark
)
# Hand control to the operator_benchmark runner only when run as a script.
if __name__ == "__main__":
    _runner = op_bench.benchmark_runner
    _runner.main()