From b3dea3e413e1d58fa92ebc60616c6a78c331fcc0 Mon Sep 17 00:00:00 2001
From: Kulin Seth
Date: Wed, 10 Aug 2022 14:30:20 +0000
Subject: [PATCH] Add the Conv1D support for NHWC format. (#83121)

Fixes https://github.com/pytorch/pytorch/issues/81557

cc @DenisVieriu97
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83121
Approved by: https://github.com/malfet
---
 .../ATen/native/mps/operations/Convolution.mm | 46 ++++++++++++++-----
 test/test_mps.py                              | 24 ++++++++++
 2 files changed, 58 insertions(+), 12 deletions(-)

diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm
index 0fe690698c3..2c74dcf0766 100644
--- a/aten/src/ATen/native/mps/operations/Convolution.mm
+++ b/aten/src/ATen/native/mps/operations/Convolution.mm
@@ -33,8 +33,9 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
   descriptor_.dataLayout = (memory_format == at::MemoryFormat::Contiguous) ?
         MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutNHWC;
 
-  descriptor_.weightsLayout = (memory_format == at::MemoryFormat::Contiguous) ?
-        MPSGraphTensorNamedDataLayoutOIHW : MPSGraphTensorNamedDataLayoutHWIO;
+
+  // PyTorch always uses OIHW memory layout for weights
+  descriptor_.weightsLayout = MPSGraphTensorNamedDataLayoutOIHW;
 
   descriptor_.groups = groups;
 }
@@ -61,6 +62,7 @@ Tensor _mps_convolution(
     bias_defined = bias_opt->defined();
 
   auto memory_format = input_t.suggest_memory_format();
+  bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
   auto output_t = at::empty(
                     conv_output_size(input->sizes(), weight->sizes(),
                                      padding, stride, dilation),
@@ -68,7 +70,7 @@ Tensor _mps_convolution(
                     c10::nullopt,
                     kMPS,
                     c10::nullopt,
-                    memory_format);
+                    c10::nullopt);
 
   if (output_t.numel() == 0) {
     return output_t;
@@ -122,6 +124,18 @@ Tensor _mps_convolution(
                               + mps::getTensorsStringKey({input_t, weight_t}) + ":"
                               + to_string(bias_defined) + ":" + bias_shape_key;
   CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
+  MPSShape* inputShape = nil;
+
+  if (is_channels_last) {
+    const auto inputSizes = input_t.sizes();
+    const NSUInteger N = inputSizes[0];
+    const NSUInteger C = inputSizes[1];
+    const NSUInteger H = inputSizes[2];
+    const NSUInteger W = inputSizes[3];
+    inputShape = @[@(N), @(H), @(W), @(C)];
+  } else {
+    inputShape = native_mps::getMPSShape(input_t);
+  }
 
   if(!cachedGraph) {
     native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
@@ -138,21 +152,29 @@ Tensor _mps_convolution(
                          padding[1], padding[0],
                          memory_format, groups);
 
-          MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);
+          MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape);
           MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
+
          MPSGraphTensor* biasTensor = nil;
          if(bias_defined)
            biasTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType((bias_opt.value()).scalar_type()));

-          MPSGraphTensor* outputTensor = [mpsGraph convolution2DWithSourceTensor:inputTensor
-                                                                    weightsTensor:weightTensor
-                                                                       descriptor:descriptor_
-                                                                             name:nil];
+          MPSGraphTensor* outputTensor = [mpsGraph convolution2DWithSourceTensor: inputTensor
+                                                                   weightsTensor: weightTensor
+                                                                      descriptor: descriptor_
+                                                                            name: nil];
+          if (is_channels_last) {
+            // NHWC -> NCHW
+            outputTensor = [mpsGraph transposeTensor: [mpsGraph transposeTensor:outputTensor dimension:-1 withDimension:-2 name:nil]
+                                           dimension: -2
+                                       withDimension: -3
+                                                name: nil];
+          }

          if(bias_defined) {
-            outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor
-                                                secondaryTensor:biasTensor
-                                                           name:nil];
+            outputTensor = [mpsGraph additionWithPrimaryTensor: outputTensor
+                                               secondaryTensor: biasTensor
+                                                          name: nil];
          }

          newCachedGraph->inputTensor_ = inputTensor;
@@ -165,7 +187,7 @@ Tensor _mps_convolution(
     cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
   }
 
-  auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
+  auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, inputShape);
   auto weightsPlaceholder = native_mps::Placeholder(cachedGraph->weightTensor_, weight_t);
   auto biasPlaceholder = native_mps::Placeholder();
   // Reshape the bias to be broadcastable with output of conv2d
diff --git a/test/test_mps.py b/test/test_mps.py
index 3d006b8a182..f0740b280b6 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1096,6 +1096,30 @@ class TestMPS(TestCase):
             helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride,
                    padding=padding, output_padding=output_padding, dilation=dilation)
 
+    def test_conv1d_channels_last(self):
+        model_cpu = torch.nn.Conv1d(1, 128, 3)
+        a_cpu = torch.arange((128 * 176), dtype=torch.float32)
+        a_cpu = a_cpu.view(128, 176, 1).permute(0, 2, 1)
+        out_cpu = model_cpu(a_cpu)  # pass
+
+        a_mps = a_cpu.detach().clone().to("mps")
+        model_mps = model_cpu.to("mps")
+        out_mps = model_mps(a_mps)
+
+        self.assertEqual(out_cpu, out_mps.cpu(), rtol=2.6e-05, atol=2e-04)
+
+    def test_conv1d_contiguous(self):
+        model_cpu = torch.nn.Conv1d(1, 128, 3)
+        a_cpu = torch.ones(128, 1, 176)
+        out_cpu = model_cpu(a_cpu)
+
+        a_mps = a_cpu.detach().clone().to("mps")
+        model_mps = model_cpu.to("mps")
+        out_mps = model_mps(a_mps)
+
+        self.assertEqual(out_cpu.shape, out_mps.shape)
+        self.assertEqual(out_cpu, out_mps.cpu())
+
     # Test sigmoid
     def test_sigmoid(self):
         def helper(shape):
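
Side note, not part of the patch itself: the channels-last branch above undoes the graph's NHWC output layout with two chained transposes (swap dims -1/-2, then -2/-3). A minimal PyTorch sketch, with made-up tensor sizes, showing that this composition is the same permutation as a direct NHWC -> NCHW permute:

    import torch

    # Dummy NHWC tensor: N=2, H=4, W=5, C=3 (sizes chosen only for illustration).
    x_nhwc = torch.arange(2 * 4 * 5 * 3, dtype=torch.float32).reshape(2, 4, 5, 3)

    # Mirror the MPSGraph code: swap the last two dims, then dims -2 and -3.
    step1 = x_nhwc.transpose(-1, -2)   # NHWC -> NHCW
    nchw = step1.transpose(-2, -3)     # NHCW -> NCHW

    # Equivalent single permutation NHWC -> NCHW.
    assert torch.equal(nchw, x_nhwc.permute(0, 3, 1, 2))
    print(nchw.shape)  # torch.Size([2, 3, 4, 5])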