mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-15 21:00:47 +00:00
Add the Conv1D support for NHWC format. (#83121)
Fixes https://github.com/pytorch/pytorch/issues/81557
cc @DenisVieriu97
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83121
Approved by: https://github.com/malfet
This commit is contained in:
parent
c25220f05f
commit
b3dea3e413
2 changed files with 58 additions and 12 deletions
|
|
@ -33,8 +33,9 @@ void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
|
|||
|
||||
descriptor_.dataLayout = (memory_format == at::MemoryFormat::Contiguous) ?
|
||||
MPSGraphTensorNamedDataLayoutNCHW : MPSGraphTensorNamedDataLayoutNHWC;
|
||||
descriptor_.weightsLayout = (memory_format == at::MemoryFormat::Contiguous) ?
|
||||
MPSGraphTensorNamedDataLayoutOIHW : MPSGraphTensorNamedDataLayoutHWIO;
|
||||
|
||||
// PyTorch always uses OIHW memory layout for weights
|
||||
descriptor_.weightsLayout = MPSGraphTensorNamedDataLayoutOIHW;
|
||||
descriptor_.groups = groups;
|
||||
}
|
||||
|
||||
|
|
@ -61,6 +62,7 @@ Tensor _mps_convolution(
|
|||
bias_defined = bias_opt->defined();
|
||||
|
||||
auto memory_format = input_t.suggest_memory_format();
|
||||
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
|
||||
auto output_t = at::empty(
|
||||
conv_output_size(input->sizes(), weight->sizes(),
|
||||
padding, stride, dilation),
|
||||
|
|
@ -68,7 +70,7 @@ Tensor _mps_convolution(
|
|||
c10::nullopt,
|
||||
kMPS,
|
||||
c10::nullopt,
|
||||
memory_format);
|
||||
c10::nullopt);
|
||||
|
||||
if (output_t.numel() == 0) {
|
||||
return output_t;
|
||||
|
|
@ -122,6 +124,18 @@ Tensor _mps_convolution(
|
|||
+ mps::getTensorsStringKey({input_t, weight_t}) + ":"
|
||||
+ to_string(bias_defined) + ":" + bias_shape_key;
|
||||
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
|
||||
MPSShape* inputShape = nil;
|
||||
|
||||
if (is_channels_last) {
|
||||
const auto inputSizes = input_t.sizes();
|
||||
const NSUInteger N = inputSizes[0];
|
||||
const NSUInteger C = inputSizes[1];
|
||||
const NSUInteger H = inputSizes[2];
|
||||
const NSUInteger W = inputSizes[3];
|
||||
inputShape = @[@(N), @(H), @(W), @(C)];
|
||||
} else {
|
||||
inputShape = native_mps::getMPSShape(input_t);
|
||||
}
|
||||
|
||||
if(!cachedGraph) {
|
||||
native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
|
||||
|
|
@ -138,21 +152,29 @@ Tensor _mps_convolution(
|
|||
padding[1], padding[0],
|
||||
memory_format, groups);
|
||||
|
||||
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);
|
||||
MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, native_mps::getMPSScalarType(input_t.scalar_type()), inputShape);
|
||||
MPSGraphTensor* weightTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
|
||||
|
||||
MPSGraphTensor* biasTensor = nil;
|
||||
if(bias_defined)
|
||||
biasTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType((bias_opt.value()).scalar_type()));
|
||||
|
||||
MPSGraphTensor* outputTensor = [mpsGraph convolution2DWithSourceTensor:inputTensor
|
||||
weightsTensor:weightTensor
|
||||
descriptor:descriptor_
|
||||
name:nil];
|
||||
MPSGraphTensor* outputTensor = [mpsGraph convolution2DWithSourceTensor: inputTensor
|
||||
weightsTensor: weightTensor
|
||||
descriptor: descriptor_
|
||||
name: nil];
|
||||
if (is_channels_last) {
|
||||
// NHWC -> NCHW
|
||||
outputTensor = [mpsGraph transposeTensor: [mpsGraph transposeTensor:outputTensor dimension:-1 withDimension:-2 name:nil]
|
||||
dimension: -2
|
||||
withDimension: -3
|
||||
name: nil];
|
||||
}
|
||||
|
||||
if(bias_defined) {
|
||||
outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor
|
||||
secondaryTensor:biasTensor
|
||||
name:nil];
|
||||
outputTensor = [mpsGraph additionWithPrimaryTensor: outputTensor
|
||||
secondaryTensor: biasTensor
|
||||
name: nil];
|
||||
}
|
||||
|
||||
newCachedGraph->inputTensor_ = inputTensor;
|
||||
|
|
@ -165,7 +187,7 @@ Tensor _mps_convolution(
|
|||
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
|
||||
}
|
||||
|
||||
auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
|
||||
auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t, inputShape);
|
||||
auto weightsPlaceholder = native_mps::Placeholder(cachedGraph->weightTensor_, weight_t);
|
||||
auto biasPlaceholder = native_mps::Placeholder();
|
||||
// Reshape the bias to be broadcastable with output of conv2d
|
||||
|
|
|
|||
|
|
@ -1096,6 +1096,30 @@ class TestMPS(TestCase):
|
|||
helper((N, C_out, H, W), (C_out, C_in, kH, kW), bias_shape=(C_in), stride=stride,
|
||||
padding=padding, output_padding=output_padding, dilation=dilation)
|
||||
|
||||
def test_conv1d_channels_last(self):
    """Compare Conv1d results on CPU and on MPS for a permuted input.

    The input is created with shape (128, 176, 1) and then permuted to
    (128, 1, 176) before being fed to a Conv1d(1, 128, 3) layer.
    NOTE(review): presumably this layout is what drives the MPS backend
    down its channels-last path -- confirm against the convolution op.
    """
    batch, length = 128, 176
    conv = torch.nn.Conv1d(1, 128, 3)

    # Distinct values per element, laid out as (batch, length, 1) and then
    # viewed as (batch, 1, length) through a permute.
    flat = torch.arange(batch * length, dtype=torch.float32)
    cpu_input = flat.view(batch, length, 1).permute(0, 2, 1)

    # The reference result is computed before the module is sent to "mps".
    expected = conv(cpu_input)

    mps_input = cpu_input.detach().clone().to("mps")
    conv_on_mps = conv.to("mps")
    actual = conv_on_mps(mps_input).cpu()

    self.assertEqual(expected, actual, rtol=2.6e-05, atol=2e-04)
|
||||
|
||||
def test_conv1d_contiguous(self):
    """Compare Conv1d results on CPU and on MPS for a plain (128, 1, 176) input.

    Uses an all-ones tensor created directly in its final shape and checks
    both the output shape and the output values.
    """
    conv = torch.nn.Conv1d(1, 128, 3)
    cpu_input = torch.ones(128, 1, 176)

    # The reference result is computed before the module is sent to "mps".
    expected = conv(cpu_input)

    mps_input = cpu_input.detach().clone().to("mps")
    conv_on_mps = conv.to("mps")
    actual = conv_on_mps(mps_input)

    # Shape is checked on the device tensor; values after copying back to CPU.
    self.assertEqual(expected.shape, actual.shape)
    self.assertEqual(expected, actual.cpu())
|
||||
|
||||
# Test sigmoid
|
||||
def test_sigmoid(self):
|
||||
def helper(shape):
|
||||
|
|
|
|||
Loading…
Reference in a new issue