mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[MPS] Convert channels_last_3d to contiguous for input tensor in nn.Conv3d (#141780)
When the input tensor to Conv3d is in the channels_last_3d memory format, the Conv3d op generates incorrect output (see the example image in #141471). This PR checks whether the convolution is 3D (i.e. the input tensor is 5-dimensional) and, if so, converts the input tensor to contiguous memory format before running the op. Added a regression test that verifies the MPS output by running the same op on the CPU and comparing the results. I'm unsure whether Conv3d supports the channels-last memory format after #128393. If it does, we should consider updating the logic to use it, as that would be more efficient than making a contiguous copy. Perhaps @DenisVieriu97 knows or has more context? Fixes #141471 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141780 Approved by: https://github.com/malfet
This commit is contained in:
parent
5deca07c0d
commit
90f19fee8a
2 changed files with 15 additions and 2 deletions
|
|
@ -127,13 +127,13 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_,
|
|||
const bool is_macOS_13_2_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
|
||||
const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
|
||||
Tensor input_t = input_t_;
|
||||
if (!is_macOS_15_0_or_newer) {
|
||||
bool is3DConv = input_t.dim() == 5;
|
||||
if (!is_macOS_15_0_or_newer || is3DConv) {
|
||||
input_t = input_t.contiguous();
|
||||
}
|
||||
|
||||
TORCH_CHECK(((input_t.dim() < 5) || is_macOS_13_2_or_newer),
|
||||
"Conv3D is only supported on MPS for MacOS_13_2 or newer");
|
||||
bool is3DConv = input_t.dim() == 5;
|
||||
|
||||
TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types");
|
||||
|
||||
|
|
|
|||
|
|
@ -9058,6 +9058,19 @@ class TestNNMPS(NNTestCase):
|
|||
# This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
|
||||
y2.sum().backward()
|
||||
|
||||
# Regression test for https://github.com/pytorch/pytorch/issues/141471
|
||||
def test_conv3d_channels_last_3d(self):
|
||||
m_cpu = nn.Conv3d(16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(4, 2, 0), device="cpu")
|
||||
m_mps = copy.deepcopy(m_cpu).to("mps")
|
||||
|
||||
x_cpu = torch.randn(20, 16, 10, 50, 100, device="cpu").to(memory_format=torch.channels_last_3d)
|
||||
x_mps = x_cpu.detach().clone().to("mps")
|
||||
|
||||
res_cpu = m_cpu(x_cpu)
|
||||
res_mps = m_mps(x_mps)
|
||||
|
||||
self.assertEqual(res_cpu, res_mps)
|
||||
|
||||
def test_gemm_permute_transpose(self):
|
||||
batch_size = 32
|
||||
n = 20
|
||||
|
|
|
|||
Loading…
Reference in a new issue