From 68758b2fa0ed03ea92bb2e7c8a6543f71c6ac2d9 Mon Sep 17 00:00:00 2001
From: Lingyi Liu
Date: Fri, 13 Mar 2020 20:27:40 -0700
Subject: [PATCH] Add the quantized batch_norm3d and also batch_norm3d fused
 with relu operators (#34702)

Summary:
As titled; this is for bringing up the quantized video model. The
batch_norm_relu test will be added in another PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/34702

Differential Revision: D20436092

Pulled By: lly-zero-one

fbshipit-source-id: 116bd306f7880bfd763d8575654fbd6c92818338
---
 .../ATen/native/quantized/cpu/qbatch_norm.cpp | 138 +++++++++++++++++-
 test/test_quantized.py                        |  29 ++++
 test/test_quantized_nn_mods.py                |  18 +++
 torch/nn/quantized/modules/__init__.py        |   3 +-
 torch/nn/quantized/modules/batchnorm.py       |  56 +++++++
 torch/quantization/default_mappings.py        |   1 +
 6 files changed, 239 insertions(+), 6 deletions(-)

diff --git a/aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp b/aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp
index 4afd02358af..52b002b7294 100644
--- a/aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qbatch_norm.cpp
@@ -20,8 +20,8 @@ void compute_fused_params(
     const float* mean_data,
     const float* var_data,
     double eps,
-    float input_scale,
-    float output_scale,
+    double input_scale,
+    double output_scale,
     float* alpha_data,
     float* beta_data) {
   // Batch Normalization
@@ -46,7 +46,7 @@ Tensor q_batch_norm_impl(
     Tensor mean,
     Tensor var,
     double eps,
-    float output_scale,
+    double output_scale,
     int64_t output_zero_point) {
 
   if (qx.numel() == 0) {
@@ -112,6 +112,82 @@ Tensor q_batch_norm_impl(
   return qy;
 }
 
+template <bool ReLUFused>
+Tensor q_batch_norm3d_impl(
+    Tensor qx,
+    Tensor weight,
+    Tensor bias,
+    Tensor mean,
+    Tensor var,
+    double eps,
+    double output_scale,
+    int64_t output_zero_point) {
+
+  if (qx.numel() == 0) {
+    auto out = qx.clone();
+    return out;
+  }
+  int64_t ndim = qx.dim();
+  TORCH_CHECK(ndim == 5, "Expecting the input tensor of rank 5.");
+  const int64_t N = qx.size(0);
+  const int64_t C = qx.size(1);
+  const int64_t D = qx.size(2);
+  const int64_t H = qx.size(3);
+  const int64_t W = qx.size(4);
+
+  TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
+  TORCH_CHECK(bias.numel() == C, "Expect bias size to match C");
+
+  const float* weight_data = weight.template data<float>();
+  const float* bias_data = bias.template data<float>();
+
+  TORCH_CHECK(mean.numel() == C, "Mean size must match channel dimension");
+  TORCH_CHECK(var.numel() == C, "Variance size must match channel dimension");
+
+  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  float* alpha_data = alpha.data_ptr<float>();
+  float* beta_data = beta.data_ptr<float>();
+
+  const float* mean_data = mean.template data<float>();
+  const float* var_data = var.template data<float>();
+
+  auto oSizes = qx.sizes();
+  auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast3d);
+  Tensor qy = at::_empty_affine_quantized(
+      oSizes,
+      at::device(kCPU).dtype(qx_nhwc.scalar_type()),
+      output_scale,
+      output_zero_point,
+      MemoryFormat::ChannelsLast3d);
+
+  compute_fused_params(
+      C,
+      weight_data,
+      bias_data,
+      mean_data,
+      var_data,
+      eps,
+      qx.q_scale(),
+      output_scale,
+      alpha_data,
+      beta_data);
+
+  qbatch_norm_stub(
+      qx.device().type(),
+      N,
+      C,
+      D * H * W,
+      qx.q_zero_point(),
+      output_zero_point,
+      qx_nhwc,
+      alpha,
+      beta,
+      qy);
+
+  return qy;
+}
+
 } // namespace
 
 Tensor quantized_batch_norm(
@@ -131,6 +207,8 @@ Tensor quantized_batch_norm(
 
 // Keep the registry in the anonymous namespace.
 namespace {
+
+template <bool ReLUFused>
 class QBatchNorm2d final : public torch::OperatorKernel {
  public:
   Tensor operator()(
@@ -142,7 +220,24 @@ class QBatchNorm2d final : public torch::OperatorKernel {
       double eps,
       double output_scale,
       int64_t output_zero_point) {
-    return q_batch_norm_impl(
+    return q_batch_norm_impl<ReLUFused>(
+        qx, weight, bias, mean, var, eps, output_scale, output_zero_point);
+  }
+};
+
+template <bool ReLUFused>
+class QBatchNorm3d final : public torch::OperatorKernel {
+ public:
+  Tensor operator()(
+      Tensor qx,
+      Tensor weight,
+      Tensor bias,
+      Tensor mean,
+      Tensor var,
+      double eps,
+      double output_scale,
+      int64_t output_zero_point) {
+    return q_batch_norm3d_impl<ReLUFused>(
         qx, weight, bias, mean, var, eps, output_scale, output_zero_point);
   }
 };
@@ -156,7 +251,40 @@ static auto registry = torch::RegisterOperators().op(
     "float eps, "
     "float output_scale, "
     "int output_zero_point) -> Tensor",
-    torch::RegisterOperators::options().kernel<QBatchNorm2d>(
+    torch::RegisterOperators::options().kernel<QBatchNorm2d<false>>(
+        DispatchKey::QuantizedCPUTensorId))
+.op(
+    "quantized::batch_norm2d_relu(Tensor qx, "
+    "Tensor weight, "
+    "Tensor bias, "
+    "Tensor mean, "
+    "Tensor var, "
+    "float eps, "
+    "float output_scale, "
+    "int output_zero_point) -> Tensor",
+    torch::RegisterOperators::options().kernel<QBatchNorm2d<true>>(
+        DispatchKey::QuantizedCPUTensorId))
+.op(
+    "quantized::batch_norm3d(Tensor qx, "
+    "Tensor weight, "
+    "Tensor bias, "
+    "Tensor mean, "
+    "Tensor var, "
+    "float eps, "
+    "float output_scale, "
+    "int output_zero_point) -> Tensor",
+    torch::RegisterOperators::options().kernel<QBatchNorm3d<false>>(
+        DispatchKey::QuantizedCPUTensorId))
+.op(
+    "quantized::batch_norm3d_relu(Tensor qx, "
+    "Tensor weight, "
+    "Tensor bias, "
+    "Tensor mean, "
+    "Tensor var, "
+    "float eps, "
+    "float output_scale, "
+    "int output_zero_point) -> Tensor",
+    torch::RegisterOperators::options().kernel<QBatchNorm3d<true>>(
         DispatchKey::QuantizedCPUTensorId));
 
 } // namespace
diff --git a/test/test_quantized.py b/test/test_quantized.py
index a7c3f8c915d..f5c767be1ba 100644
--- a/test/test_quantized.py
+++ b/test/test_quantized.py
@@ -1394,6 +1394,35 @@ class TestQuantizedOps(TestCase):
             quantize_ref = torch.quantize_per_tensor(float_ref, Y_scale, Y_zero_point, dtype_x)
             self.assertEqual(qy.int_repr().numpy(), quantize_ref.int_repr().numpy())
 
+    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=5, max_dims=5,
+                                              min_side=1, max_side=32),
+                       qparams=hu.qparams()),
+           Y_scale=st.floats(0.2, 2.6),
+           Y_zero_point=st.integers(0, 5),
+           qengine=st.sampled_from(("qnnpack", "fbgemm")))
+    def test_batch_norm3d(self, X, Y_scale, Y_zero_point, qengine):
+        if qengine not in torch.backends.quantized.supported_engines:
+            return
+
+        with override_quantized_engine(qengine):
+            X, (scale_x, zero_point_x, dtype_x) = X
+
+            X = torch.from_numpy(X)
+            c = X.shape[1]
+
+            mean = torch.rand(c).float()
+            var = torch.rand(c).float()
+            weight = torch.rand(c).float()
+            bias = torch.rand(c).float()
+            eps = 0.001
+            qx = torch.quantize_per_tensor(X, scale_x, zero_point_x, dtype_x)
+            qy = torch.ops.quantized.batch_norm3d(qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
+
+            float_ref = F.batch_norm(qx.dequantize(), weight=weight, bias=bias,
+                                     running_mean=mean, running_var=var, training=False, momentum=0, eps=eps)
+            quantize_ref = torch.quantize_per_tensor(float_ref, Y_scale, Y_zero_point, dtype_x)
+            self.assertEqual(qy.int_repr().numpy(), quantize_ref.int_repr().numpy())
+
     @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                          " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                          " with instruction set support avx2 or newer.")
diff --git a/test/test_quantized_nn_mods.py b/test/test_quantized_nn_mods.py
index e747b50f953..c2b26dcd82d 100644
--- a/test/test_quantized_nn_mods.py
+++ b/test/test_quantized_nn_mods.py
@@ -832,5 +832,23 @@ class ModuleAPITest(QuantizationTestCase):
         self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(),
                          message="BatchNorm2d module API failed")
 
+    def test_batch_norm3d(self):
+        """Tests the correctness of the batchnorm3d module.
+        The correctness is defined against the functional implementation.
+        """
+        x = torch.randn((2, 4, 6, 8, 10), dtype=torch.float)
+        float_mod = torch.nn.BatchNorm3d(4)
+        float_mod.training = False
+
+        y_ref = float_mod(x)
+        quant_ref = torch.quantize_per_tensor(y_ref, 1.0, 0, dtype=torch.quint8)
+
+        quant_mod = nnq.BatchNorm3d(4)
+        qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8)
+        qy = quant_mod(qx)
+
+        self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(),
+                         message="BatchNorm3d module API failed")
+
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/nn/quantized/modules/__init__.py b/torch/nn/quantized/modules/__init__.py
index 3979783ae02..b3894fa4e91 100644
--- a/torch/nn/quantized/modules/__init__.py
+++ b/torch/nn/quantized/modules/__init__.py
@@ -4,7 +4,7 @@ import torch
 from torch.nn.modules.pooling import MaxPool2d
 
 from .activation import ReLU, ReLU6
-from .batchnorm import BatchNorm2d
+from .batchnorm import BatchNorm2d, BatchNorm3d
 from .conv import Conv2d, Conv3d
 from .linear import Linear
 
@@ -79,6 +79,7 @@ class DeQuantize(torch.nn.Module):
 
 __all__ = [
     'BatchNorm2d',
+    'BatchNorm3d',
     'Conv2d',
     'Conv3d',
     'DeQuantize',
diff --git a/torch/nn/quantized/modules/batchnorm.py b/torch/nn/quantized/modules/batchnorm.py
index dcab6e1946e..090472939ec 100644
--- a/torch/nn/quantized/modules/batchnorm.py
+++ b/torch/nn/quantized/modules/batchnorm.py
@@ -61,3 +61,59 @@ class BatchNorm2d(torch.nn.BatchNorm2d):
         new_mod.scale = float(scale)
         new_mod.zero_point = int(zero_point)
         return new_mod
+
+class BatchNorm3d(torch.nn.BatchNorm3d):
+    r"""Applies Quantized Batch Normalization over a 5D input (a mini-batch of 3D inputs
+    with additional channel dimension) as described in the paper
+    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .
+
+    .. math::
+
+        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    Because the Batch Normalization is done over the `C` dimension, computing statistics
+    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization.
+
+    Args:
+        num_features: :math:`C` from an expected input of size
+            :math:`(N, C, D, H, W)`
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
+
+    Shape:
+        - Input: :math:`(N, C, D, H, W)`
+        - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+    Examples:
+
+        >>> m = nn.quantized.BatchNorm3d(100)
+        >>> input = torch.randn(20, 100, 25, 35, 45)
+        >>> quantized_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
+        >>> output = m(quantized_input)
+    """
+
+    def __init__(self, num_features, eps=1e-5, momentum=0.1):
+        super(BatchNorm3d, self).__init__(num_features)
+        self.eps = eps
+        self.scale = 1.0
+        self.zero_point = 0
+
+    def forward(self, input):
+        return torch.ops.quantized.batch_norm3d(input, self.weight, self.bias, self.running_mean,
+                                                self.running_var, self.eps, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedBatchNorm3d'
+
+    @classmethod
+    def from_float(cls, mod):
+        assert type(mod) == torch.nn.BatchNorm3d,\
+            "QuantizedBatchNorm3d expects an instance of BatchNorm3d"
+        scale, zero_point = mod.activation_post_process.calculate_qparams()
+        new_mod = BatchNorm3d(mod.num_features, mod.eps)
+        new_mod.scale = float(scale)
+        new_mod.zero_point = int(zero_point)
+        return new_mod
diff --git a/torch/quantization/default_mappings.py b/torch/quantization/default_mappings.py
index e706502b016..4ff90b79d6b 100644
--- a/torch/quantization/default_mappings.py
+++ b/torch/quantization/default_mappings.py
@@ -18,6 +18,7 @@ DEFAULT_MODULE_MAPPING = {
     nn.Conv2d: nnq.Conv2d,
     nn.Conv3d: nnq.Conv3d,
     nn.BatchNorm2d: nnq.BatchNorm2d,
+    nn.BatchNorm3d: nnq.BatchNorm3d,
     QuantStub: nnq.Quantize,
     DeQuantStub: nnq.DeQuantize,
     # Wrapper Modules:
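Note for reviewers: below is a minimal sketch (not part of the patch) of how the new
functional op and module are meant to be exercised once this change is applied. The
op and module names come from the registrations and tests above; the tensor shapes,
scales, and zero points are illustrative values, not taken from the test suite.

    import torch

    # Quantize a float NCDHW tensor, then run the functional op added by this
    # patch (quantized::batch_norm3d) and the module wrapper (nnq.BatchNorm3d).
    x = torch.randn(2, 4, 6, 8, 10)  # (N, C, D, H, W)
    qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)

    c = x.shape[1]
    weight, bias = torch.ones(c), torch.zeros(c)
    running_mean, running_var = torch.zeros(c), torch.ones(c)
    eps, output_scale, output_zero_point = 1e-5, 0.1, 0

    # Functional form: mirrors the call in test_batch_norm3d above.
    qy = torch.ops.quantized.batch_norm3d(
        qx, weight, bias, running_mean, running_var,
        eps, output_scale, output_zero_point)

    # Module form: scale/zero_point are plain attributes, set by hand here; in a
    # full quantization flow they would come from from_float() via the observer.
    m = torch.nn.quantized.BatchNorm3d(c)
    m.scale, m.zero_point = output_scale, output_zero_point
    qy_mod = m(qx)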