[MPS] Convert channels_last_3d to contiguous for input tensor in nn.Conv3d (#141780)

When the input tensor to Conv3d is in the channels_last_3d memory format, the Conv3d op generates incorrect output (see the example image in #141471). This PR checks whether the op is 3D and, if so, converts the input tensor to a contiguous layout.

Added a regression test that verifies the output by running the same op on the CPU.

I'm unsure whether Conv3d supports the channels-last memory format after #128393. If it does, we should consider updating the logic to use it, as that would be more efficient. Perhaps @DenisVieriu97 knows or has more context?

Fixes #141471
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141780
Approved by: https://github.com/malfet
This commit is contained in:
Roy Hvaara 2024-12-01 18:36:53 +00:00 committed by PyTorch MergeBot
parent 5deca07c0d
commit 90f19fee8a
2 changed files with 15 additions and 2 deletions

View file

@ -127,13 +127,13 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_,
const bool is_macOS_13_2_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
Tensor input_t = input_t_;
if (!is_macOS_15_0_or_newer) {
bool is3DConv = input_t.dim() == 5;
if (!is_macOS_15_0_or_newer || is3DConv) {
input_t = input_t.contiguous();
}
TORCH_CHECK(((input_t.dim() < 5) || is_macOS_13_2_or_newer),
"Conv3D is only supported on MPS for MacOS_13_2 or newer");
bool is3DConv = input_t.dim() == 5;
TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types");

View file

@ -9058,6 +9058,19 @@ class TestNNMPS(NNTestCase):
# This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
y2.sum().backward()
# Regression test for https://github.com/pytorch/pytorch/issues/141471
def test_conv3d_channels_last_3d(self):
    """Conv3d on MPS must match CPU when the input uses channels_last_3d."""
    # Build one Conv3d on CPU, then deep-copy it to MPS so both modules
    # share identical (randomly initialized) weights and biases.
    conv_cpu = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0), device="cpu")
    conv_mps = copy.deepcopy(conv_cpu).to("mps")
    # Input in the channels_last_3d memory format — the layout that
    # previously produced incorrect MPS output.
    inp_cpu = torch.randn(20, 16, 10, 50, 100, device="cpu").to(memory_format=torch.channels_last_3d)
    inp_mps = inp_cpu.detach().clone().to("mps")
    # The MPS result must agree with the CPU reference.
    self.assertEqual(conv_cpu(inp_cpu), conv_mps(inp_mps))
def test_gemm_permute_transpose(self):
batch_size = 32
n = 20