Add quantized batch_norm3d and batch_norm3d fused with relu operators (#34702)
Summary: As the title says; this is needed to bring up the quantized video model. The batch_norm_relu test will be added in another PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/34702
Differential Revision: D20436092
Pulled By: lly-zero-one
fbshipit-source-id: 116bd306f7880bfd763d8575654fbd6c92818338
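For orientation, both new operators take pre-computed running statistics plus explicit output qparams. A minimal usage sketch (shapes and qparam values here are illustrative, not taken from this PR):

    import torch

    x = torch.randn(2, 4, 6, 8, 10)  # (N, C, D, H, W)
    qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0, dtype=torch.quint8)

    c = x.shape[1]
    weight, bias = torch.ones(c), torch.zeros(c)
    mean, var = torch.zeros(c), torch.ones(c)

    # signature: (qx, weight, bias, mean, var, eps, output_scale, output_zero_point)
    qy = torch.ops.quantized.batch_norm3d(qx, weight, bias, mean, var, 1e-5, 0.1, 0)
    # fused variant applies ReLU as part of the same kernel
    qy_relu = torch.ops.quantized.batch_norm3d_relu(qx, weight, bias, mean, var, 1e-5, 0.1, 0)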
This commit is contained in:
parent da11646db1 · commit 68758b2fa0

6 changed files with 239 additions and 6 deletions
@@ -20,8 +20,8 @@ void compute_fused_params(
     const float* mean_data,
     const float* var_data,
     double eps,
-    float input_scale,
-    float output_scale,
+    double input_scale,
+    double output_scale,
     float* alpha_data,
     float* beta_data) {
   // Batch Normalization
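The float-to-double switch above concerns the per-channel scalar math done once when batch norm is folded into an affine transform; `float` in the op schema maps to C++ `double`, so keeping these parameters as double avoids a narrowing conversion. A rough NumPy sketch of the folding this function performs (my notation, not the kernel's exact code):

    import numpy as np

    def fused_params(weight, bias, mean, var, eps, input_scale, output_scale):
        # batch norm y = (x - mean) / sqrt(var + eps) * weight + bias,
        # rewritten as y = alpha * x + beta and rescaled between quant scales
        inv_sigma = 1.0 / np.sqrt(var + eps)
        alpha = (weight * inv_sigma) * (input_scale / output_scale)
        beta = (bias - weight * inv_sigma * mean) / output_scale
        return alpha.astype(np.float32), beta.astype(np.float32)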
@@ -46,7 +46,7 @@ Tensor q_batch_norm_impl(
     Tensor mean,
     Tensor var,
     double eps,
-    float output_scale,
+    double output_scale,
     int64_t output_zero_point) {

   if (qx.numel() == 0) {
@@ -112,6 +112,82 @@ Tensor q_batch_norm_impl(
   return qy;
 }

+template <bool ReluFused>
+Tensor q_batch_norm3d_impl(
+    Tensor qx,
+    Tensor weight,
+    Tensor bias,
+    Tensor mean,
+    Tensor var,
+    double eps,
+    double output_scale,
+    int64_t output_zero_point) {
+
+  if (qx.numel() == 0) {
+    auto out = qx.clone();
+    return out;
+  }
+  int64_t ndim = qx.dim();
+  TORCH_CHECK(ndim == 5, "Expecting the input tensor of rank 5.");
+  const int64_t N = qx.size(0);
+  const int64_t C = qx.size(1);
+  const int64_t D = qx.size(2);
+  const int64_t H = qx.size(3);
+  const int64_t W = qx.size(4);
+
+  TORCH_CHECK(weight.numel() == C, "Expect weight size to match C");
+  TORCH_CHECK(bias.numel() == C, "Expect bias size to match C");
+
+  const float* weight_data = weight.template data<float>();
+  const float* bias_data = bias.template data<float>();
+
+  TORCH_CHECK(mean.numel() == C, "Mean size must match channel dimension");
+  TORCH_CHECK(var.numel() == C, "Variance size must match channel dimension");
+
+  Tensor alpha = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  Tensor beta = at::empty_like(mean, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
+  float* alpha_data = alpha.data_ptr<float>();
+  float* beta_data = beta.data_ptr<float>();
+
+  const float* mean_data = mean.template data<float>();
+  const float* var_data = var.template data<float>();
+
+  auto oSizes = qx.sizes();
+  auto qx_nhwc = qx.contiguous(MemoryFormat::ChannelsLast3d);
+  Tensor qy = at::_empty_affine_quantized(
+      oSizes,
+      at::device(kCPU).dtype(qx_nhwc.scalar_type()),
+      output_scale,
+      output_zero_point,
+      MemoryFormat::ChannelsLast3d);
+
+  compute_fused_params(
+      C,
+      weight_data,
+      bias_data,
+      mean_data,
+      var_data,
+      eps,
+      qx.q_scale(),
+      output_scale,
+      alpha_data,
+      beta_data);
+
+  qbatch_norm_stub(
+      qx.device().type(),
+      N,
+      C,
+      D * H * W,
+      qx.q_zero_point(),
+      output_zero_point,
+      qx_nhwc,
+      alpha,
+      beta,
+      qy);
+
+  return qy;
+}
+
 } // namespace

 Tensor quantized_batch_norm(
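Everything shape-specific happens above; the arithmetic itself is delegated to the shared qbatch_norm_stub, which applies the folded alpha/beta per channel over the D * H * W inner dimension. A scalar reference of what that amounts to, assuming quint8 (a sketch, not the vectorized kernel):

    import numpy as np

    def qbatch_norm_ref(qx_int, x_zero_point, y_zero_point, alpha, beta):
        # qx_int: integer representation; alpha/beta: per-channel fused params
        y = alpha * (qx_int.astype(np.float32) - x_zero_point) + beta
        qy = np.rint(y) + y_zero_point
        return np.clip(qy, 0, 255).astype(np.uint8)  # quint8 range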
@@ -131,6 +207,8 @@ Tensor quantized_batch_norm(

 // Keep the registry in the anonymous namespace.
 namespace {
+
+template <bool ReLUFused = false>
 class QBatchNorm2d final : public torch::OperatorKernel {
  public:
   Tensor operator()(
@@ -142,7 +220,24 @@ class QBatchNorm2d final : public torch::OperatorKernel {
       double eps,
       double output_scale,
       int64_t output_zero_point) {
-    return q_batch_norm_impl<false>(
+    return q_batch_norm_impl<ReLUFused>(
         qx, weight, bias, mean, var, eps, output_scale, output_zero_point);
   }
 };
+
+template <bool ReLUFused = false>
+class QBatchNorm3d final : public torch::OperatorKernel {
+ public:
+  Tensor operator()(
+      Tensor qx,
+      Tensor weight,
+      Tensor bias,
+      Tensor mean,
+      Tensor var,
+      double eps,
+      double output_scale,
+      int64_t output_zero_point) {
+    return q_batch_norm3d_impl<ReLUFused>(
+        qx, weight, bias, mean, var, eps, output_scale, output_zero_point);
+  }
+};
@@ -156,7 +251,40 @@ static auto registry = torch::RegisterOperators().op(
     "float eps, "
    "float output_scale, "
    "int output_zero_point) -> Tensor",
-    torch::RegisterOperators::options().kernel<QBatchNorm2d>(
-        DispatchKey::QuantizedCPUTensorId));
+    torch::RegisterOperators::options().kernel<QBatchNorm2d<false>>(
+        DispatchKey::QuantizedCPUTensorId))
+    .op(
+        "quantized::batch_norm2d_relu(Tensor qx, "
+        "Tensor weight, "
+        "Tensor bias, "
+        "Tensor mean, "
+        "Tensor var, "
+        "float eps, "
+        "float output_scale, "
+        "int output_zero_point) -> Tensor",
+        torch::RegisterOperators::options().kernel<QBatchNorm2d<true>>(
+            DispatchKey::QuantizedCPUTensorId))
+    .op(
+        "quantized::batch_norm3d(Tensor qx, "
+        "Tensor weight, "
+        "Tensor bias, "
+        "Tensor mean, "
+        "Tensor var, "
+        "float eps, "
+        "float output_scale, "
+        "int output_zero_point) -> Tensor",
+        torch::RegisterOperators::options().kernel<QBatchNorm3d<false>>(
+            DispatchKey::QuantizedCPUTensorId))
+    .op(
+        "quantized::batch_norm3d_relu(Tensor qx, "
+        "Tensor weight, "
+        "Tensor bias, "
+        "Tensor mean, "
+        "Tensor var, "
+        "float eps, "
+        "float output_scale, "
+        "int output_zero_point) -> Tensor",
+        torch::RegisterOperators::options().kernel<QBatchNorm3d<true>>(
+            DispatchKey::QuantizedCPUTensorId));

 } // namespace
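Each schema registered above becomes callable as torch.ops.quantized.<name>; the ReLUFused=true kernels differ only in clamping the result at the output zero point. A quick smoke check along those lines (illustrative values; with identical output qparams the fused op should match relu applied afterwards):

    import torch

    qx = torch.quantize_per_tensor(torch.randn(1, 3, 4, 4, 4), 0.1, 0, torch.quint8)
    args = (qx, torch.ones(3), torch.zeros(3), torch.zeros(3), torch.ones(3), 1e-5, 0.1, 0)

    qy = torch.ops.quantized.batch_norm3d(*args)
    qy_relu = torch.ops.quantized.batch_norm3d_relu(*args)
    # fusion vs. post-hoc relu: both clamp at the output zero point
    assert torch.equal(qy_relu.int_repr(), torch.relu(qy).int_repr())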
@@ -1394,6 +1394,35 @@ class TestQuantizedOps(TestCase):
             quantize_ref = torch.quantize_per_tensor(float_ref, Y_scale, Y_zero_point, dtype_x)
             self.assertEqual(qy.int_repr().numpy(), quantize_ref.int_repr().numpy())

+    @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=5, max_dims=5,
+                                              min_side=1, max_side=32),
+                       qparams=hu.qparams()),
+           Y_scale=st.floats(0.2, 2.6),
+           Y_zero_point=st.integers(0, 5),
+           qengine=st.sampled_from(("qnnpack", "fbgemm")))
+    def test_batch_norm3d(self, X, Y_scale, Y_zero_point, qengine):
+        if qengine not in torch.backends.quantized.supported_engines:
+            return
+
+        with override_quantized_engine(qengine):
+            X, (scale_x, zero_point_x, dtype_x) = X
+
+            X = torch.from_numpy(X)
+            c = X.shape[1]
+
+            mean = torch.rand(c).float()
+            var = torch.rand(c).float()
+            weight = torch.rand(c).float()
+            bias = torch.rand(c).float()
+            eps = 0.001
+            qx = torch.quantize_per_tensor(X, scale_x, zero_point_x, dtype_x)
+            qy = torch.ops.quantized.batch_norm3d(qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
+
+            float_ref = F.batch_norm(qx.dequantize(), weight=weight, bias=bias,
+                                     running_mean=mean, running_var=var, training=False, momentum=0, eps=eps)
+            quantize_ref = torch.quantize_per_tensor(float_ref, Y_scale, Y_zero_point, dtype_x)
+            self.assertEqual(qy.int_repr().numpy(), quantize_ref.int_repr().numpy())
+
     @unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
                          " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
                          " with instruction set support avx2 or newer.")
@@ -832,5 +832,23 @@ class ModuleAPITest(QuantizationTestCase):
         self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(),
                          message="BatchNorm2d module API failed")

+    def test_batch_norm3d(self):
+        """Tests the correctness of the batchnorm3d module.
+        The correctness is defined against the functional implementation.
+        """
+        x = torch.randn((2, 4, 6, 8, 10), dtype=torch.float)
+        float_mod = torch.nn.BatchNorm3d(4)
+        float_mod.training = False
+
+        y_ref = float_mod(x)
+        quant_ref = torch.quantize_per_tensor(y_ref, 1.0, 0, dtype=torch.quint8)
+
+        quant_mod = nnq.BatchNorm3d(4)
+        qx = torch.quantize_per_tensor(x, 1.0, 0, dtype=torch.quint8)
+        qy = quant_mod(qx)
+
+        self.assertEqual(quant_ref.int_repr().numpy(), qy.int_repr().numpy(),
+                         message="BatchNorm3d module API failed")
+
 if __name__ == '__main__':
     run_tests()
@ -4,7 +4,7 @@ import torch
|
|||
from torch.nn.modules.pooling import MaxPool2d
|
||||
|
||||
from .activation import ReLU, ReLU6
|
||||
from .batchnorm import BatchNorm2d
|
||||
from .batchnorm import BatchNorm2d, BatchNorm3d
|
||||
from .conv import Conv2d, Conv3d
|
||||
from .linear import Linear
|
||||
|
||||
|
|
@@ -79,6 +79,7 @@ class DeQuantize(torch.nn.Module):

 __all__ = [
     'BatchNorm2d',
+    'BatchNorm3d',
     'Conv2d',
     'Conv3d',
     'DeQuantize',
@@ -61,3 +61,59 @@ class BatchNorm2d(torch.nn.BatchNorm2d):
         new_mod.scale = float(scale)
         new_mod.zero_point = int(zero_point)
         return new_mod
+
+class BatchNorm3d(torch.nn.BatchNorm3d):
+    r"""Applies Quantized Batch Normalization over a 5D input (a mini-batch of 3D inputs
+    with additional channel dimension) as described in the paper
+    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .
+
+    .. math::
+
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    Because Batch Normalization is done over the `C` dimension, computing statistics
+    on `(N, D, H, W)` slices, it is common terminology to call this Spatial Batch Normalization.
+
+    Args:
+        num_features: :math:`C` from an expected input of size
+            :math:`(N, C, D, H, W)`
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
+
+    Shape:
+        - Input: :math:`(N, C, D, H, W)`
+        - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+    Examples::
+
+        >>> m = nn.quantized.BatchNorm3d(100)
+        >>> input = torch.randn(20, 100, 25, 35, 45)
+        >>> quantized_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
+        >>> output = m(quantized_input)
+    """
+
+    def __init__(self, num_features, eps=1e-5, momentum=0.1):
+        super(BatchNorm3d, self).__init__(num_features)
+        self.eps = eps
+        self.scale = 1.0
+        self.zero_point = 0
+
+    def forward(self, input):
+        return torch.ops.quantized.batch_norm3d(input, self.weight, self.bias, self.running_mean,
+                                                self.running_var, self.eps, self.scale, self.zero_point)
+
+    def _get_name(self):
+        return 'QuantizedBatchNorm3d'
+
+    @classmethod
+    def from_float(cls, mod):
+        assert type(mod) == torch.nn.BatchNorm3d,\
+            "QuantizedBatchNorm3d expects an instance of BatchNorm3d"
+        scale, zero_point = mod.activation_post_process.calculate_qparams()
+        new_mod = BatchNorm3d(mod.num_features, mod.eps)
+        new_mod.scale = float(scale)
+        new_mod.zero_point = int(zero_point)
+        return new_mod
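A short sketch of using the new module directly in eager mode (scale and zero_point are normally set by from_float; hard-coded here for illustration):

    import torch
    import torch.nn.quantized as nnq

    m = nnq.BatchNorm3d(4)
    m.scale, m.zero_point = 0.2, 0   # output qparams, usually set by from_float
    qx = torch.quantize_per_tensor(torch.randn(2, 4, 6, 8, 10), 1.0, 0, torch.quint8)
    qy = m(qx)                       # dispatches to torch.ops.quantized.batch_norm3d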
@@ -18,6 +18,7 @@ DEFAULT_MODULE_MAPPING = {
     nn.Conv2d: nnq.Conv2d,
     nn.Conv3d: nnq.Conv3d,
     nn.BatchNorm2d: nnq.BatchNorm2d,
+    nn.BatchNorm3d: nnq.BatchNorm3d,
     QuantStub: nnq.Quantize,
     DeQuantStub: nnq.DeQuantize,
     # Wrapper Modules:
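With this mapping entry in place, the eager-mode prepare/convert flow should pick up BatchNorm3d modules automatically. A minimal sketch of that flow under the default qconfig (observer details elided; assumes BatchNorm3d is also covered by the qconfig propagation list):

    import torch
    import torch.quantization as tq

    model = torch.nn.Sequential(torch.nn.BatchNorm3d(4))
    model.qconfig = tq.default_qconfig
    tq.prepare(model, inplace=True)            # attach observers
    model(torch.randn(2, 4, 6, 8, 10))         # calibrate qparams
    tq.convert(model, inplace=True)            # swaps nn.BatchNorm3d -> nnq.BatchNorm3d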