fix onednn ConvTranspose2d channels last issue when ic=1 (#99539)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99539
Approved by: https://github.com/mingfeima
Author: XiaobingSuper (2023-04-19 21:09:47 -04:00); committed by: PyTorch MergeBot
parent 3af467eff4
commit 233cc34d3b
2 changed files with 25 additions and 61 deletions
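
For context, a minimal repro sketch of the case this commit fixes, mirroring the regression test added below: a ConvTranspose2d with a single input channel (ic=1) fed a channels-last input, compared against the reference path with mkldnn disabled.

    import torch

    # ic=1, channels-last input: the configuration this commit fixes.
    x = torch.ones([1, 1, 2, 2]).contiguous(memory_format=torch.channels_last)
    model = torch.nn.ConvTranspose2d(in_channels=1, out_channels=2, kernel_size=[5, 5]).eval()

    with torch.no_grad():
        y = model(x)  # onednn (mkldnn) path
    with torch.backends.mkldnn.flags(enabled=False), torch.no_grad():
        y_ref = model(x)  # reference path, onednn disabled

    print(torch.allclose(y, y_ref))  # True once the fix is in place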


@@ -708,7 +708,7 @@ Tensor _mkldnn_convolution_transpose(
   if (bias.defined()) {
     const ideep::tensor b = itensor_from_tensor(bias);
-    ideep::convolution_transpose_forward::compute(
+    ideep::convolution_transpose_forward::compute_v3(
         x,
         w,
         b,
@@ -719,9 +719,10 @@ Tensor _mkldnn_convolution_transpose(
         padding_r(padding, output_padding),
         dilation.vec(),
         groups,
+        use_channels_last,
         op_attr);
   } else {
-    ideep::convolution_transpose_forward::compute(
+    ideep::convolution_transpose_forward::compute_v3(
         x,
         w,
         output_sizes,
@@ -731,6 +732,7 @@ Tensor _mkldnn_convolution_transpose(
         padding_r(padding, output_padding),
         dilation.vec(),
         groups,
+        use_channels_last,
         op_attr);
   }
   if (input.is_mkldnn()) {
@@ -738,7 +740,6 @@ Tensor _mkldnn_convolution_transpose(
   } else if (!use_channels_last) {
     return mkldnn_to_dense(MKLDNNTensor(y, input.options()));
   } else {
-    TORCH_INTERNAL_ASSERT(y.get_desc().is_nhwc());
     return output;
   }
 }
@@ -923,65 +924,18 @@ Tensor mkldnn_convolution_transpose(
     IntArrayRef dilation,
     int64_t groups)
 {
-  // See [Note: hacky wrapper removal for optional tensor]
-  c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
-  const Tensor& bias = *bias_maybe_owned;
-  if (input.scalar_type() == ScalarType::BFloat16) {
-    TORCH_CHECK(mkldnn_bf16_device_check(),
-        "mkldnn_convolution_transpose: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
-  }
   bool use_channels_last = mkldnn_conv_use_channels_last(input, weight);
-  auto memory_format = mkldnn_convolution_memory_format(input.ndimension(), use_channels_last);
-  auto output_sizes = conv_input_size(input.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups);
-  auto output = at::empty({0}, input.options());
-  const ideep::tensor x = itensor_from_tensor(input);
-  ideep::tensor w = itensor_from_tensor(weight);
-  // mkldnn transposed convolution has weight in logical order of OIHW or OIDHW,
-  // while PyTorch has IOHW or IODHW, `._tranpose()` switches strides (no memory copy).
-  w.transpose_(0, 1);
-  ideep::tensor y;
-  if (use_channels_last) {
-    output.resize_(output_sizes, memory_format);
-    y = itensor_from_tensor(output);
-  }
-  if (bias.defined()) {
-    const ideep::tensor b = itensor_from_tensor(bias);
-    ideep::convolution_transpose_forward::compute(
-        x,
-        w,
-        b,
-        output_sizes,
-        y,
-        stride.vec(),
-        padding.vec(),
-        padding_r(padding, output_padding),
-        dilation.vec(),
-        groups);
-  } else {
-    ideep::convolution_transpose_forward::compute(
-        x,
-        w,
-        output_sizes,
-        y,
-        stride.vec(),
-        padding.vec(),
-        padding_r(padding, output_padding),
-        dilation.vec(),
-        groups);
-  }
-  if (input.is_mkldnn()) {
-    return MKLDNNTensor(y, input.options());
-  } else if (!use_channels_last) {
-    return mkldnn_to_dense(MKLDNNTensor(y, input.options()));
-  } else {
-    return output;
-  }
+  return _mkldnn_convolution_transpose(
+      input,
+      weight,
+      bias_opt,
+      padding,
+      output_padding,
+      stride,
+      dilation,
+      groups,
+      use_channels_last
+  );
 }
 
 Tensor mkldnn_convolution_transpose_backward_input(
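
Background on why ic=1 is special (a sketch, not part of the diff): with a single channel, a 4-D tensor's strides satisfy both the contiguous and the channels-last layout checks, so the memory format cannot be inferred from strides alone. The compute_v3 interface above passes the explicitly computed use_channels_last flag down to ideep, presumably so the onednn primitive no longer has to guess the layout from strides — which is why the removed TORCH_INTERNAL_ASSERT could fire before this change.

    import torch

    # With C == 1, the size-1 dimensions make one stride pattern satisfy
    # both layout predicates, so layout is ambiguous from strides alone.
    x = torch.ones([1, 1, 2, 2]).contiguous(memory_format=torch.channels_last)
    print(x.stride())                                          # (4, 1, 2, 1)
    print(x.is_contiguous())                                   # True
    print(x.is_contiguous(memory_format=torch.channels_last))  # True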


@@ -427,6 +427,16 @@ class TestMkldnn(TestCase):
     def test_conv_transpose3d(self):
         self._test_conv_transpose_base(dim=3)
 
+    def test_conv_transposed2d_ic1_nhwc(self):
+        x = torch.ones([1, 1, 2, 2]).contiguous(memory_format=torch.channels_last)
+        model = torch.nn.ConvTranspose2d(in_channels=1, out_channels=2, kernel_size=[5, 5]).eval()
+        model.weight.data = torch.ones([1, 2, 5, 5]).contiguous(memory_format=torch.channels_last)
+        with torch.no_grad():
+            y = model(x)
+        with torch.backends.mkldnn.flags(enabled=False):
+            y_ref = model(x)
+        self.assertEqual(y, y_ref)
+
     def test_conv2d_legacy_jit_model(self):
         """
         MKLDNN integration used to serialize models with 5d weight for grouped