From 90f19fee8ac1135a6a22feaedd1920b92ce7c982 Mon Sep 17 00:00:00 2001 From: Roy Hvaara Date: Sun, 1 Dec 2024 18:36:53 +0000 Subject: [PATCH] [MPS] Convert `channels_last_3d` to `contiguous` for input tensor in `nn.Conv3d` (#141780) When the input tensor to Conv3d is in the channels_last_3d memory format, the Conv3d op generates incorrect output (see example image in #141471). This PR checks whether the convolution is 3D and, if so, converts the input tensor to contiguous before running the op. Added a regression test that compares the MPS output against the same op run on the CPU. I'm unsure whether Conv3d supports the channels-last memory format after #128393. If it does, we should consider updating the logic to use it, as that would be more efficient. Perhaps @DenisVieriu97 knows or has more context? Fixes #141471 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141780 Approved by: https://github.com/malfet --- aten/src/ATen/native/mps/operations/Convolution.mm | 4 ++-- test/test_mps.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm index 83e33f7f610..5852be8fb74 100644 --- a/aten/src/ATen/native/mps/operations/Convolution.mm +++ b/aten/src/ATen/native/mps/operations/Convolution.mm @@ -127,13 +127,13 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_, const bool is_macOS_13_2_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS); const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); Tensor input_t = input_t_; - if (!is_macOS_15_0_or_newer) { + bool is3DConv = input_t.dim() == 5; + if (!is_macOS_15_0_or_newer || is3DConv) { input_t = input_t.contiguous(); } TORCH_CHECK(((input_t.dim() < 5) || is_macOS_13_2_or_newer), "Conv3D is only supported on MPS for MacOS_13_2 or newer"); - bool is3DConv = input_t.dim() == 5; TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is 
supported only for Floating types"); diff --git a/test/test_mps.py b/test/test_mps.py index 34b432462d9..b354fb7fd44 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -9058,6 +9058,19 @@ class TestNNMPS(NNTestCase): # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion y2.sum().backward() + # Regression test for https://github.com/pytorch/pytorch/issues/141471 + def test_conv3d_channels_last_3d(self): + m_cpu = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0), device="cpu") + m_mps = copy.deepcopy(m_cpu).to("mps") + + x_cpu = torch.randn(20, 16, 10, 50, 100, device="cpu").to(memory_format=torch.channels_last_3d) + x_mps = x_cpu.detach().clone().to("mps") + + res_cpu = m_cpu(x_cpu) + res_mps = m_mps(x_mps) + + self.assertEqual(res_cpu, res_mps) + def test_gemm_permute_transpose(self): batch_size = 32 n = 20