mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
fix onednn ConvTranspose2d channels last issue when ic=1 (#99539)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99539 Approved by: https://github.com/mingfeima
This commit is contained in:
parent
3af467eff4
commit
233cc34d3b
2 changed files with 25 additions and 61 deletions
|
|
@ -708,7 +708,7 @@ Tensor _mkldnn_convolution_transpose(
|
|||
|
||||
if (bias.defined()) {
|
||||
const ideep::tensor b = itensor_from_tensor(bias);
|
||||
ideep::convolution_transpose_forward::compute(
|
||||
ideep::convolution_transpose_forward::compute_v3(
|
||||
x,
|
||||
w,
|
||||
b,
|
||||
|
|
@ -719,9 +719,10 @@ Tensor _mkldnn_convolution_transpose(
|
|||
padding_r(padding, output_padding),
|
||||
dilation.vec(),
|
||||
groups,
|
||||
use_channels_last,
|
||||
op_attr);
|
||||
} else {
|
||||
ideep::convolution_transpose_forward::compute(
|
||||
ideep::convolution_transpose_forward::compute_v3(
|
||||
x,
|
||||
w,
|
||||
output_sizes,
|
||||
|
|
@ -731,6 +732,7 @@ Tensor _mkldnn_convolution_transpose(
|
|||
padding_r(padding, output_padding),
|
||||
dilation.vec(),
|
||||
groups,
|
||||
use_channels_last,
|
||||
op_attr);
|
||||
}
|
||||
if (input.is_mkldnn()) {
|
||||
|
|
@ -738,7 +740,6 @@ Tensor _mkldnn_convolution_transpose(
|
|||
} else if (!use_channels_last) {
|
||||
return mkldnn_to_dense(MKLDNNTensor(y, input.options()));
|
||||
} else {
|
||||
TORCH_INTERNAL_ASSERT(y.get_desc().is_nhwc());
|
||||
return output;
|
||||
}
|
||||
}
|
||||
|
|
@ -923,65 +924,18 @@ Tensor mkldnn_convolution_transpose(
|
|||
IntArrayRef dilation,
|
||||
int64_t groups)
|
||||
{
|
||||
// See [Note: hacky wrapper removal for optional tensor]
|
||||
c10::MaybeOwned<Tensor> bias_maybe_owned = at::borrow_from_optional_tensor(bias_opt);
|
||||
const Tensor& bias = *bias_maybe_owned;
|
||||
|
||||
if (input.scalar_type() == ScalarType::BFloat16) {
|
||||
TORCH_CHECK(mkldnn_bf16_device_check(),
|
||||
"mkldnn_convolution_transpose: bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
|
||||
}
|
||||
|
||||
bool use_channels_last = mkldnn_conv_use_channels_last(input, weight);
|
||||
auto memory_format = mkldnn_convolution_memory_format(input.ndimension(), use_channels_last);
|
||||
|
||||
auto output_sizes = conv_input_size(input.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups);
|
||||
auto output = at::empty({0}, input.options());
|
||||
|
||||
const ideep::tensor x = itensor_from_tensor(input);
|
||||
ideep::tensor w = itensor_from_tensor(weight);
|
||||
// mkldnn transposed convolution has weight in logical order of OIHW or OIDHW,
|
||||
// while PyTorch has IOHW or IODHW, `._tranpose()` switches strides (no memory copy).
|
||||
w.transpose_(0, 1);
|
||||
|
||||
ideep::tensor y;
|
||||
if (use_channels_last) {
|
||||
output.resize_(output_sizes, memory_format);
|
||||
y = itensor_from_tensor(output);
|
||||
}
|
||||
if (bias.defined()) {
|
||||
const ideep::tensor b = itensor_from_tensor(bias);
|
||||
ideep::convolution_transpose_forward::compute(
|
||||
x,
|
||||
w,
|
||||
b,
|
||||
output_sizes,
|
||||
y,
|
||||
stride.vec(),
|
||||
padding.vec(),
|
||||
padding_r(padding, output_padding),
|
||||
dilation.vec(),
|
||||
groups);
|
||||
} else {
|
||||
ideep::convolution_transpose_forward::compute(
|
||||
x,
|
||||
w,
|
||||
output_sizes,
|
||||
y,
|
||||
stride.vec(),
|
||||
padding.vec(),
|
||||
padding_r(padding, output_padding),
|
||||
dilation.vec(),
|
||||
groups);
|
||||
}
|
||||
|
||||
if (input.is_mkldnn()) {
|
||||
return MKLDNNTensor(y, input.options());
|
||||
} else if (!use_channels_last) {
|
||||
return mkldnn_to_dense(MKLDNNTensor(y, input.options()));
|
||||
} else {
|
||||
return output;
|
||||
}
|
||||
return _mkldnn_convolution_transpose(
|
||||
input,
|
||||
weight,
|
||||
bias_opt,
|
||||
padding,
|
||||
output_padding,
|
||||
stride,
|
||||
dilation,
|
||||
groups,
|
||||
use_channels_last
|
||||
);
|
||||
}
|
||||
|
||||
Tensor mkldnn_convolution_transpose_backward_input(
|
||||
|
|
|
|||
|
|
@ -427,6 +427,16 @@ class TestMkldnn(TestCase):
|
|||
def test_conv_transpose3d(self):
|
||||
self._test_conv_transpose_base(dim=3)
|
||||
|
||||
def test_conv_transposed2d_ic1_nhwc(self):
|
||||
x = torch.ones([1, 1, 2, 2]).contiguous(memory_format=torch.channels_last)
|
||||
model = torch.nn.ConvTranspose2d(in_channels=1, out_channels=2, kernel_size=[5, 5]).eval()
|
||||
model.weight.data = torch.ones([1, 2, 5, 5]).contiguous(memory_format=torch.channels_last)
|
||||
with torch.no_grad():
|
||||
y = model(x)
|
||||
with torch.backends.mkldnn.flags(enabled=False):
|
||||
y_ref = model(x)
|
||||
self.assertEqual(y, y_ref)
|
||||
|
||||
def test_conv2d_legacy_jit_model(self):
|
||||
"""
|
||||
MKLDNN integration used to serialize models with 5d weight for grouped
|
||||
|
|
|
|||
Loading…
Reference in a new issue