Add the quantized batch_norm3d and also batch_norm3d fused with relu operators (#34702)

Summary:
As titled: adds the quantized batch_norm3d operator (and its ReLU-fused variant) needed to bring up the quantized video model. The batch_norm_relu test will be added in a follow-up PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34702

Differential Revision: D20436092

Pulled By: lly-zero-one

fbshipit-source-id: 116bd306f7880bfd763d8575654fbd6c92818338
This commit is contained in:
Lingyi Liu 2020-03-13 20:27:40 -07:00 committed by Facebook GitHub Bot
parent da11646db1
commit 68758b2fa0
6 changed files with 239 additions and 6 deletions

View file

@ -20,8 +20,8 @@ void compute_fused_params(
const float* mean_data,
const float* var_data,
double eps,
float input_scale,
float output_scale,
double input_scale,
double output_scale,
float* alpha_data,
float* beta_data) {
// Batch Normalization
@ -46,7 +46,7 @@ Tensor q_batch_norm_impl(
Tensor mean,
Tensor var,
double eps,
float output_scale,
double output_scale,
int64_t output_zero_point) {
if (qx.numel() == 0) {
@ -112,6 +112,82 @@ Tensor q_batch_norm_impl(
return qy;
}
// Quantized 3D batch normalization (inference only).
//
// Folds the affine parameters (weight/bias) and the frozen running
// statistics (mean/var) into per-channel alpha/beta coefficients, then
// applies y = alpha * x + beta via qbatch_norm_stub over a
// ChannelsLast3d copy of the input.
//
// NOTE(review): the ReluFused template flag is not consulted anywhere in
// this body -- confirm the ReLU fusion is applied inside the stub or in a
// follow-up change, otherwise batch_norm3d_relu silently skips the clamp.
template <bool ReluFused>
Tensor q_batch_norm3d_impl(
    Tensor qx,
    Tensor weight,
    Tensor bias,
    Tensor mean,
    Tensor var,
    double eps,
    double output_scale,
    int64_t output_zero_point) {
  // Empty input: nothing to normalize, return an empty quantized tensor.
  if (qx.numel() == 0) {
    auto out = qx.clone();
    return out;
  }
  int64_t ndim = qx.dim();
  TORCH_CHECK(ndim == 5, "Expecting the input tensor of rank 5.");
  const int64_t N = qx.size(0);
  const int64_t C = qx.size(1);
  const int64_t D = qx.size(2);
  const int64_t H = qx.size(3);
  const int64_t W = qx.size(4);

  TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
  // Fixed: this message previously said "weight" for the bias check.
  TORCH_CHECK(bias.numel() == C, "Expect bias size to match C");

  const float* weight_data = weight.template data<float>();
  const float* bias_data = bias.template data<float>();

  TORCH_CHECK(mean.numel() == C, "Mean size must match channel dimension");
  TORCH_CHECK(var.numel() == C, "Variance size must match channel dimension");

  // Per-channel fused coefficients, filled in by compute_fused_params below.
  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  float* alpha_data = alpha.data_ptr<float>();
  float* beta_data = beta.data_ptr<float>();

  const float* mean_data = mean.template data<float>();
  const float* var_data = var.template data<float>();

  auto oSizes = qx.sizes();
  // The stub expects channels-last layout; for a 5-D tensor that is NDHWC.
  auto qx_ndhwc = qx.contiguous(MemoryFormat::ChannelsLast3d);
  Tensor qy = at::_empty_affine_quantized(
      oSizes,
      at::device(kCPU).dtype(qx_ndhwc.scalar_type()),
      output_scale,
      output_zero_point,
      MemoryFormat::ChannelsLast3d);

  compute_fused_params(
      C,
      weight_data,
      bias_data,
      mean_data,
      var_data,
      eps,
      qx.q_scale(),
      output_scale,
      alpha_data,
      beta_data);

  qbatch_norm_stub(
      qx.device().type(),
      N,
      C,
      // Every spatial position is treated uniformly, so D/H/W collapse
      // into a single per-(N, C) element count.
      D * H * W,
      qx.q_zero_point(),
      output_zero_point,
      qx_ndhwc,
      alpha,
      beta,
      qy);
  return qy;
}
} // namespace
Tensor quantized_batch_norm(
@ -131,6 +207,8 @@ Tensor quantized_batch_norm(
// Keep the registry in the anonymous namespace.
namespace {
template <bool ReLUFused = false>
class QBatchNorm2d final : public torch::OperatorKernel {
public:
Tensor operator()(
@ -142,7 +220,24 @@ class QBatchNorm2d final : public torch::OperatorKernel {
double eps,
double output_scale,
int64_t output_zero_point) {
return q_batch_norm_impl<false>(
return q_batch_norm_impl<ReLUFused>(
qx, weight, bias, mean, var, eps, output_scale, output_zero_point);
}
};
// Operator-kernel functor for quantized 3D batch norm. The ReLU fusion is
// selected at compile time through the ReLUFused template flag so that a
// single implementation backs both quantized::batch_norm3d and
// quantized::batch_norm3d_relu registrations.
template <bool ReLUFused = false>
class QBatchNorm3d final : public torch::OperatorKernel {
 public:
  Tensor operator()(
      Tensor qx, Tensor weight, Tensor bias, Tensor mean, Tensor var,
      double eps, double output_scale, int64_t output_zero_point) {
    // Pure forwarder: all validation and computation live in the impl.
    return q_batch_norm3d_impl<ReLUFused>(
        qx, weight, bias, mean, var, eps, output_scale, output_zero_point);
  }
};
@ -156,7 +251,40 @@ static auto registry = torch::RegisterOperators().op(
"float eps, "
"float output_scale, "
"int output_zero_point) -> Tensor",
torch::RegisterOperators::options().kernel<QBatchNorm2d>(
torch::RegisterOperators::options().kernel<QBatchNorm2d<false>>(
DispatchKey::QuantizedCPUTensorId))
.op(
"quantized::batch_norm2d_relu(Tensor qx, "
"Tensor weight, "
"Tensor bias, "
"Tensor mean, "
"Tensor var, "
"float eps, "
"float output_scale, "
"int output_zero_point) -> Tensor",
torch::RegisterOperators::options().kernel<QBatchNorm2d<true>>(
DispatchKey::QuantizedCPUTensorId))
.op(
"quantized::batch_norm3d(Tensor qx, "
"Tensor weight, "
"Tensor bias, "
"Tensor mean, "
"Tensor var, "
"float eps, "
"float output_scale, "
"int output_zero_point) -> Tensor",
torch::RegisterOperators::options().kernel<QBatchNorm3d<false>>(
DispatchKey::QuantizedCPUTensorId))
.op(
"quantized::batch_norm3d_relu(Tensor qx, "
"Tensor weight, "
"Tensor bias, "
"Tensor mean, "
"Tensor var, "
"float eps, "
"float output_scale, "
"int output_zero_point) -> Tensor",
torch::RegisterOperators::options().kernel<QBatchNorm3d<true>>(
DispatchKey::QuantizedCPUTensorId));
} // namespace

View file

@ -1394,6 +1394,35 @@ class TestQuantizedOps(TestCase):
quantize_ref = torch.quantize_per_tensor(float_ref, Y_scale, Y_zero_point, dtype_x)
self.assertEqual(qy.int_repr().numpy(), quantize_ref.int_repr().numpy())
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=5, max_dims=5,
                                          min_side=1, max_side=32),
                   qparams=hu.qparams()),
       Y_scale=st.floats(0.2, 2.6),
       Y_zero_point=st.integers(0, 5),
       qengine=st.sampled_from(("qnnpack", "fbgemm")))
def test_batch_norm3d(self, X, Y_scale, Y_zero_point, qengine):
    """Checks quantized::batch_norm3d against a dequantize ->
    F.batch_norm (eval mode) -> requantize reference on random 5D inputs.
    """
    if qengine not in torch.backends.quantized.supported_engines:
        return
    with override_quantized_engine(qengine):
        X, (scale_x, zero_point_x, dtype_x) = X
        X = torch.from_numpy(X)
        num_channels = X.shape[1]
        # Per-channel parameters; drawn in the same order as before
        # (mean, var, weight, bias) to keep RNG consumption identical.
        running_mean = torch.rand(num_channels).float()
        running_var = torch.rand(num_channels).float()
        gamma = torch.rand(num_channels).float()
        beta = torch.rand(num_channels).float()
        epsilon = 0.001
        qx = torch.quantize_per_tensor(X, scale_x, zero_point_x, dtype_x)
        qy = torch.ops.quantized.batch_norm3d(
            qx, gamma, beta, running_mean, running_var, epsilon,
            Y_scale, Y_zero_point)
        float_ref = F.batch_norm(qx.dequantize(), weight=gamma, bias=beta,
                                 running_mean=running_mean,
                                 running_var=running_var,
                                 training=False, momentum=0, eps=epsilon)
        quantize_ref = torch.quantize_per_tensor(
            float_ref, Y_scale, Y_zero_point, dtype_x)
        self.assertEqual(qy.int_repr().numpy(),
                         quantize_ref.int_repr().numpy())
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
" with instruction set support avx2 or newer.")

View file

@ -832,5 +832,23 @@ class ModuleAPITest(QuantizationTestCase):
self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(),
message="BatchNorm2d module API failed")
def test_batch_norm3d(self):
    """Tests the correctness of the batchnorm3d module.

    The quantized module's output is compared against quantizing the
    output of the float BatchNorm3d reference run in eval mode.
    """
    num_features = 4
    x = torch.randn((2, num_features, 6, 8, 10), dtype=torch.float)
    reference_bn = torch.nn.BatchNorm3d(num_features)
    reference_bn.training = False
    y_ref = reference_bn(x)
    quant_ref = torch.quantize_per_tensor(y_ref, 1.0, 0, dtype=torch.quint8)
    quantized_bn = nnq.BatchNorm3d(num_features)
    qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8)
    qy = quantized_bn(qx)
    self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(),
                     message="BatchNorm3d module API failed")
if __name__ == '__main__':
run_tests()

View file

@ -4,7 +4,7 @@ import torch
from torch.nn.modules.pooling import MaxPool2d
from .activation import ReLU, ReLU6
from .batchnorm import BatchNorm2d
from .batchnorm import BatchNorm2d, BatchNorm3d
from .conv import Conv2d, Conv3d
from .linear import Linear
@ -79,6 +79,7 @@ class DeQuantize(torch.nn.Module):
__all__ = [
'BatchNorm2d',
'BatchNorm3d',
'Conv2d',
'Conv3d',
'DeQuantize',

View file

@ -61,3 +61,59 @@ class BatchNorm2d(torch.nn.BatchNorm2d):
new_mod.scale = float(scale)
new_mod.zero_point = int(zero_point)
return new_mod
class BatchNorm3d(torch.nn.BatchNorm3d):
    r"""Applies Quantized Batch Normalization over a 5D input (a mini-batch of 3D inputs
    with additional channel dimension) as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> m = nn.quantized.BatchNorm3d(100)
        >>> input = torch.randn(20, 100, 25, 35, 45)
        >>> quantized_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(quantized_input)
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        # NOTE(review): `momentum` is accepted for API parity with the float
        # module but is neither forwarded to the parent nor used here (the
        # running stats are frozen at inference) -- confirm intended.
        super(BatchNorm3d, self).__init__(num_features)
        self.eps = eps
        # Requantization parameters of the output; overwritten by from_float.
        self.scale = 1.0
        self.zero_point = 0

    def forward(self, input):
        # Delegate to the quantized 3d batch-norm kernel, requantizing the
        # result with this module's (scale, zero_point).
        return torch.ops.quantized.batch_norm3d(
            input, self.weight, self.bias, self.running_mean,
            self.running_var, self.eps, self.scale, self.zero_point)

    def _get_name(self):
        return 'QuantizedBatchNorm3d'

    @classmethod
    def from_float(cls, mod):
        """Creates a quantized module from an observed float BatchNorm3d,
        taking the output qparams from its attached observer."""
        assert type(mod) == torch.nn.BatchNorm3d,\
            "QuantizedBatchNorm3d expects an instance of BatchNorm3d"
        scale, zero_point = mod.activation_post_process.calculate_qparams()
        new_mod = BatchNorm3d(mod.num_features, mod.eps)
        new_mod.scale = float(scale)
        new_mod.zero_point = int(zero_point)
        return new_mod

View file

@ -18,6 +18,7 @@ DEFAULT_MODULE_MAPPING = {
nn.Conv2d: nnq.Conv2d,
nn.Conv3d: nnq.Conv3d,
nn.BatchNorm2d: nnq.BatchNorm2d,
nn.BatchNorm3d: nnq.BatchNorm3d,
QuantStub: nnq.Quantize,
DeQuantStub: nnq.DeQuantize,
# Wrapper Modules: