Add Conv1D support for NHWC format. (#83121)

Fixes https://github.com/pytorch/pytorch/issues/81557
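
For context, the issue reports that `Conv1d` fails on the MPS backend when the input is in a channels-last (non-contiguous) layout. A minimal sketch of the failing pattern, mirroring the regression test added below (assumes an MPS-capable machine):

```python
import torch

# Conv1d over a non-contiguous (channels-last style) input.
model = torch.nn.Conv1d(1, 128, 3)
x = torch.arange(128 * 176, dtype=torch.float32).view(128, 176, 1).permute(0, 2, 1)

out_cpu = model(x)                        # works on CPU
out_mps = model.to("mps")(x.to("mps"))    # failed before this change
torch.testing.assert_close(out_cpu, out_mps.cpu(), rtol=2.6e-5, atol=2e-4)
```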

cc @DenisVieriu97

Pull Request resolved: https://github.com/pytorch/pytorch/pull/83121
Approved by: https://github.com/malfet
Kulin Seth 2022-08-10 14:30:20 +00:00 committed by PyTorch MergeBot
parent c25220f05f
commit b3dea3e413
2 changed files with 58 additions and 12 deletions


@@ -33,8 +33,9 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
   descriptor_.dataLayout = (memory_format == at::MemoryFormat::Contiguous) ?
       MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutNHWC;
-  descriptor_.weightsLayout = (memory_format == at::MemoryFormat::Contiguous) ?
-      MPSGraphTensorNamedDataLayoutOIHW : MPSGraphTensorNamedDataLayoutHWIO;
+  // PyTorch always uses OIHW memory layout for weights
+  descriptor_.weightsLayout = MPSGraphTensorNamedDataLayoutOIHW;
   descriptor_.groups = groups;
 }
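
The comment in the hunk above reflects that PyTorch keeps convolution weights in OIHW logical order regardless of memory format; `channels_last` only changes the underlying strides, so the descriptor no longer needs to switch to HWIO. A small illustration (sketch, shapes chosen arbitrarily):

```python
import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3).to(memory_format=torch.channels_last)
# The weight's logical shape stays (O, I, H, W) even in channels_last
# memory format; only the strides change.
print(conv.weight.shape)  # torch.Size([8, 3, 3, 3])
print(conv.weight.is_contiguous(memory_format=torch.channels_last))  # True
```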
@@ -61,6 +62,7 @@ Tensor _mps_convolution(
     bias_defined = bias_opt->defined();
   auto memory_format = input_t.suggest_memory_format();
+  bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
   auto output_t = at::empty(
                     conv_output_size(input->sizes(), weight->sizes(),
                                      padding, stride, dilation),
@@ -68,7 +70,7 @@ Tensor _mps_convolution(
                     c10::nullopt,
                     kMPS,
                     c10::nullopt,
-                    memory_format);
+                    c10::nullopt);
   if (output_t.numel() == 0) {
     return output_t;
@@ -122,6 +124,18 @@ Tensor _mps_convolution(
                   + mps::getTensorsStringKey({input_t, weight_t}) + ":"
                   + to_string(bias_defined) + ":" + bias_shape_key;
     CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
+    MPSShape* inputShape = nil;
+    if (is_channels_last) {
+      const auto inputSizes = input_t.sizes();
+      const NSUInteger N = inputSizes[0];
+      const NSUInteger C = inputSizes[1];
+      const NSUInteger H = inputSizes[2];
+      const NSUInteger W = inputSizes[3];
+      inputShape = @[@(N), @(H), @(W), @(C)];
+    } else {
+      inputShape = native_mps::getMPSShape(input_t);
+    }
+
     if(!cachedGraph) {
       native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
@@ -138,21 +152,29 @@ Tensor _mps_convolution(
                        padding[1], padding[0],
                        memory_format, groups);
-        MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);
+        MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape);
         MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
         MPSGraphTensor* biasTensor = nil;
         if(bias_defined)
           biasTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType((bias_opt.value()).scalar_type()));
-        MPSGraphTensor* outputTensor = [mpsGraph convolution2DWithSourceTensor:inputTensor
-                                                                 weightsTensor:weightTensor
-                                                                    descriptor:descriptor_
-                                                                          name:nil];
+        MPSGraphTensor* outputTensor = [mpsGraph convolution2DWithSourceTensor: inputTensor
+                                                                  weightsTensor: weightTensor
+                                                                     descriptor: descriptor_
+                                                                           name: nil];
+        if (is_channels_last) {
+          // NHWC -> NCHW
+          outputTensor = [mpsGraph transposeTensor: [mpsGraph transposeTensor:outputTensor dimension:-1 withDimension:-2 name:nil]
+                                         dimension: -2
+                                     withDimension: -3
+                                              name: nil];
+        }
+
         if(bias_defined) {
-          outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor
-                                             secondaryTensor:biasTensor
-                                                        name:nil];
+          outputTensor = [mpsGraph additionWithPrimaryTensor: outputTensor
+                                             secondaryTensor: biasTensor
+                                                        name: nil];
         }
         newCachedGraph->inputTensor_ = inputTensor;
@@ -165,7 +187,7 @@ Tensor _mps_convolution(
       cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
     }
-    auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
+    auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, inputShape);
     auto weightsPlaceholder = native_mps::Placeholder(cachedGraph->weightTensor_, weight_t);
     auto biasPlaceholder = native_mps::Placeholder();
     // Reshape the bias to be broadcastable with output of conv2d
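
The graph's NHWC output is mapped back to PyTorch's NCHW layout by the two successive dimension swaps in the hunk above: first swapping the last two dimensions (W and C), then the next pair (H and C). A quick sketch of the equivalence in plain PyTorch (shapes are arbitrary examples):

```python
import torch

x = torch.randn(2, 5, 7, 3)                # NHWC: (N, H, W, C)
y = x.transpose(-1, -2).transpose(-2, -3)  # swap W<->C, then H<->C
assert y.shape == (2, 3, 5, 7)             # NCHW: (N, C, H, W)
```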


@@ -1096,6 +1096,30 @@ class TestMPS(TestCase):
             helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride,
                    padding=padding, output_padding=output_padding, dilation=dilation)
 
+    def test_conv1d_channels_last(self):
+        model_cpu = torch.nn.Conv1d(1, 128, 3)
+        a_cpu = torch.arange((128 * 176), dtype=torch.float32)
+        a_cpu = a_cpu.view(128, 176, 1).permute(0, 2, 1)
+        out_cpu = model_cpu(a_cpu)  # pass
+
+        a_mps = a_cpu.detach().clone().to("mps")
+        model_mps = model_cpu.to("mps")
+        out_mps = model_mps(a_mps)
+
+        self.assertEqual(out_cpu, out_mps.cpu(), rtol=2.6e-05, atol=2e-04)
+
+    def test_conv1d_contiguous(self):
+        model_cpu = torch.nn.Conv1d(1, 128, 3)
+        a_cpu = torch.ones(128, 1, 176)
+        out_cpu = model_cpu(a_cpu)
+
+        a_mps = a_cpu.detach().clone().to("mps")
+        model_mps = model_cpu.to("mps")
+        out_mps = model_mps(a_mps)
+
+        self.assertEqual(out_cpu.shape, out_mps.shape)
+        self.assertEqual(out_cpu, out_mps.cpu())
+
     # Test sigmoid
     def test_sigmoid(self):
         def helper(shape):