mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Revert D20464855: [pytorch][PR] Add the fusion of quantized batchnorm and relu
Test Plan: revert-hammer Differential Revision: D20464855 Original commit changeset: 57090d427053 fbshipit-source-id: e7c50b5e7cd27a479539d7ee17580118377971c5
This commit is contained in:
parent
a4afac6076
commit
e7fc55ef7b
4 changed files with 36 additions and 92 deletions
|
|
@ -1477,7 +1477,6 @@ REGISTER_DISPATCH(qcat_nhwc_stub, &qcat_nhwc_kernel<false>);
|
|||
REGISTER_DISPATCH(qcat_relu_nhwc_stub, &qcat_nhwc_kernel<true>);
|
||||
REGISTER_DISPATCH(qtopk_stub, &qtopk_kernel);
|
||||
REGISTER_DISPATCH(qbatch_norm_stub, &q_batch_norm_kernel<false>);
|
||||
REGISTER_DISPATCH(qbatch_norm_relu_stub, &q_batch_norm_kernel<true>);
|
||||
REGISTER_DISPATCH(fake_quant_tensor_stub, &fake_quantize_tensor_kernel);
|
||||
REGISTER_DISPATCH(fake_quant_grad_tensor_stub, &fake_quantize_grad_tensor_kernel);
|
||||
REGISTER_DISPATCH(fake_quant_per_channel_stub, &fake_quant_per_channel_cpu);
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ namespace at {
|
|||
namespace native {
|
||||
|
||||
DEFINE_DISPATCH(qbatch_norm_stub);
|
||||
DEFINE_DISPATCH(qbatch_norm_relu_stub);
|
||||
|
||||
namespace {
|
||||
void compute_fused_params(
|
||||
|
|
@ -98,31 +97,18 @@ Tensor q_batch_norm_impl(
|
|||
output_scale,
|
||||
alpha_data,
|
||||
beta_data);
|
||||
if (ReluFused) {
|
||||
qbatch_norm_relu_stub(
|
||||
qx.device().type(),
|
||||
N,
|
||||
C,
|
||||
H * W,
|
||||
qx.q_zero_point(),
|
||||
output_zero_point,
|
||||
qx_nhwc,
|
||||
alpha,
|
||||
beta,
|
||||
qy);
|
||||
} else {
|
||||
qbatch_norm_stub(
|
||||
qx.device().type(),
|
||||
N,
|
||||
C,
|
||||
H * W,
|
||||
qx.q_zero_point(),
|
||||
output_zero_point,
|
||||
qx_nhwc,
|
||||
alpha,
|
||||
beta,
|
||||
qy);
|
||||
}
|
||||
|
||||
qbatch_norm_stub(
|
||||
qx.device().type(),
|
||||
N,
|
||||
C,
|
||||
H * W,
|
||||
qx.q_zero_point(),
|
||||
output_zero_point,
|
||||
qx_nhwc,
|
||||
alpha,
|
||||
beta,
|
||||
qy);
|
||||
return qy;
|
||||
}
|
||||
|
||||
|
|
@ -187,31 +173,18 @@ Tensor q_batch_norm3d_impl(
|
|||
alpha_data,
|
||||
beta_data);
|
||||
|
||||
if (ReluFused) {
|
||||
qbatch_norm_relu_stub(
|
||||
qx.device().type(),
|
||||
N,
|
||||
C,
|
||||
D * H * W,
|
||||
qx.q_zero_point(),
|
||||
output_zero_point,
|
||||
qx_nhwc,
|
||||
alpha,
|
||||
beta,
|
||||
qy);
|
||||
} else {
|
||||
qbatch_norm_stub(
|
||||
qx.device().type(),
|
||||
N,
|
||||
C,
|
||||
D * H * W,
|
||||
qx.q_zero_point(),
|
||||
output_zero_point,
|
||||
qx_nhwc,
|
||||
alpha,
|
||||
beta,
|
||||
qy);
|
||||
}
|
||||
qbatch_norm_stub(
|
||||
qx.device().type(),
|
||||
N,
|
||||
C,
|
||||
D * H * W,
|
||||
qx.q_zero_point(),
|
||||
output_zero_point,
|
||||
qx_nhwc,
|
||||
alpha,
|
||||
beta,
|
||||
qy);
|
||||
|
||||
return qy;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -137,7 +137,6 @@ DECLARE_DISPATCH(qcat_nhwc_fn, qcat_nhwc_stub);
|
|||
DECLARE_DISPATCH(qcat_nhwc_fn, qcat_relu_nhwc_stub);
|
||||
DECLARE_DISPATCH(qtopk_fn, qtopk_stub);
|
||||
DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_stub);
|
||||
DECLARE_DISPATCH(qbatch_norm_fn, qbatch_norm_relu_stub);
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
|
|
|||
|
|
@ -1401,10 +1401,13 @@ class TestQuantizedOps(TestCase):
|
|||
min_side=1, max_side=32),
|
||||
qparams=hu.qparams()),
|
||||
Y_scale=st.floats(0.2, 2.6),
|
||||
Y_zero_point=st.integers(0, 5))
|
||||
def test_batch_norm(self, X, Y_scale, Y_zero_point):
|
||||
Y_zero_point=st.integers(0, 5),
|
||||
qengine=st.sampled_from(("qnnpack", "fbgemm")))
|
||||
def test_batch_norm(self, X, Y_scale, Y_zero_point, qengine):
|
||||
if qengine not in torch.backends.quantized.supported_engines:
|
||||
return
|
||||
|
||||
with override_quantized_engine("fbgemm"):
|
||||
with override_quantized_engine(qengine):
|
||||
X, (scale_x, zero_point_x, dtype_x) = X
|
||||
|
||||
X = torch.from_numpy(X)
|
||||
|
|
@ -1423,47 +1426,17 @@ class TestQuantizedOps(TestCase):
|
|||
quantize_ref = torch.quantize_per_tensor(float_ref, Y_scale, Y_zero_point, dtype_x)
|
||||
self.assertEqual(qy.int_repr().numpy(), quantize_ref.int_repr().numpy())
|
||||
|
||||
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=5,
|
||||
min_side=1, max_side=32),
|
||||
qparams=hu.qparams()),
|
||||
Y_scale=st.floats(0.2, 2.6),
|
||||
Y_zero_point=st.integers(0, 5))
|
||||
def test_batch_norm_relu(self, X, Y_scale, Y_zero_point):
|
||||
|
||||
with override_quantized_engine("fbgemm"):
|
||||
X, (scale_x, zero_point_x, dtype_x) = X
|
||||
|
||||
X = torch.from_numpy(X)
|
||||
c = X.shape[1]
|
||||
|
||||
mean = torch.rand(c).float()
|
||||
var = torch.rand(c).float()
|
||||
weight = torch.rand(c).float()
|
||||
bias = torch.rand(c).float()
|
||||
eps = 0.001
|
||||
qx = torch.quantize_per_tensor(X, scale_x, zero_point_x, dtype_x)
|
||||
if len(X.shape) == 4:
|
||||
qy = torch.ops.quantized.batch_norm2d_relu(qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
|
||||
else:
|
||||
qy = torch.ops.quantized.batch_norm3d_relu(qx, weight, bias, mean, var, eps, Y_scale, Y_zero_point)
|
||||
|
||||
|
||||
float_ref = F.batch_norm(qx.dequantize(), weight=weight, bias=bias,
|
||||
running_mean=mean, running_var=var, training=False, momentum=0, eps=eps).numpy()
|
||||
|
||||
float_ref_relu = float_ref.copy()
|
||||
float_ref_relu[float_ref < 0] = 0
|
||||
quantize_ref = torch.quantize_per_tensor(torch.from_numpy(float_ref_relu), Y_scale, Y_zero_point, dtype_x)
|
||||
self.assertEqual(qy.int_repr().numpy(), quantize_ref.int_repr().numpy())
|
||||
|
||||
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=5, max_dims=5,
|
||||
min_side=1, max_side=32),
|
||||
qparams=hu.qparams()),
|
||||
Y_scale=st.floats(0.2, 2.6),
|
||||
Y_zero_point=st.integers(0, 5))
|
||||
def test_batch_norm3d(self, X, Y_scale, Y_zero_point):
|
||||
Y_zero_point=st.integers(0, 5),
|
||||
qengine=st.sampled_from(("qnnpack", "fbgemm")))
|
||||
def test_batch_norm3d(self, X, Y_scale, Y_zero_point, qengine):
|
||||
if qengine not in torch.backends.quantized.supported_engines:
|
||||
return
|
||||
|
||||
with override_quantized_engine("fbgemm"):
|
||||
with override_quantized_engine(qengine):
|
||||
X, (scale_x, zero_point_x, dtype_x) = X
|
||||
|
||||
X = torch.from_numpy(X)
|
||||
|
|
|
|||
Loading…
Reference in a new issue