[MPS] Add bilinear2d_aa implementation

This commit is contained in:
Nikita Shulga 2024-10-08 06:39:06 -07:00
parent 2a55311773
commit 31a7658843
4 changed files with 93 additions and 1 deletion

View file

@ -227,6 +227,59 @@ kernel void upsample_bilinear2d(
}
}
// Triangle ("tent") filter for antialiased bilinear resampling: weight 1 at
// the center, falling off linearly to 0 at |x| == 1, and exactly 0 outside
// the support.  The previous version returned abs(x) for |x| >= 1 -- a
// weight that *grows* with distance -- so any tap landing at or marginally
// past the window edge was badly mis-weighted.  The CPU antialias reference
// filter returns 0 outside the unit support.
inline float bilinear_functor(float x) {
  float ax = x < 0.0 ? -x : x; // |x| without relying on an abs() overload
  return ax < 1.0 ? 1.0 - ax : 0.0;
}
// Antialiased bilinear resampling kernel (the antialias=True interpolate
// path).  Each thread computes one output spatial location (output_x,
// output_y) and loops over every (n, c) plane at that location.
//
// Layout (from the indexing below): the .x/.y components of the stride/size
// vectors are used for the spatial width/height axes, .z for the `c` loop
// and .w for the `n` loop -- presumably channel and batch; confirm against
// the host-side launcher.
template <typename T>
kernel void upsample_bilinear2d_aa(
    constant T* inputData [[buffer(0)]],
    device T* outputData [[buffer(1)]],
    constant ulong4& input_strides [[buffer(2)]],
    constant ulong4& output_strides [[buffer(3)]],
    constant long4& input_sizes [[buffer(4)]],
    constant long4& output_sizes [[buffer(5)]],
    constant float2& scales [[buffer(6)]],
    constant bool& align_corners [[buffer(7)]],
    uint thread_index [[thread_position_in_grid]]) {
  // Linear thread index -> 2D output coordinate (row-major over the width).
  auto output_x = thread_index % output_sizes.x;
  auto output_y = thread_index / output_sizes.x;
  // Filter-window center in input coordinates.  /*cubic=*/true is passed so
  // that area_pixel_compute_source_index does not clamp negative source
  // indices; the +.5 offset then yields the exact window center.
  // NOTE(review): confirm this flag choice is intentional for the AA path
  // and not copied from the bicubic kernel.
  auto x_center = area_pixel_compute_source_index(
                      scales.x, output_x, align_corners, /*cubic=*/true) +
      .5;
  auto y_center = area_pixel_compute_source_index(
                      scales.y, output_y, align_corners, /*cubic=*/true) +
      .5;
  // Filter half-width is max(scale, 1): never narrower than one input pixel.
  auto clamped_scales = max(1.0, scales);
  // Half-open tap ranges [min, max), clamped to the input extent.
  auto x_min = max(0L, long(floor(x_center - clamped_scales.x + .5)));
  auto x_max =
      min(input_sizes.x, long(floor(x_center + clamped_scales.x + .5)));
  auto y_min = max(0L, long(floor(y_center - clamped_scales.y + .5)));
  auto y_max =
      min(input_sizes.y, long(floor(y_center + clamped_scales.y + .5)));
  for (int n = 0; n < output_sizes.w; n++) {
    for (int c = 0; c < output_sizes.z; c++) {
      float res = 0.0; // weighted sum of input taps
      float ws = 0.0;  // sum of weights, used for normalization
      // Base of the (n, c) input plane.
      constant auto* input =
          inputData + n * input_strides.w + c * input_strides.z;
      for (auto y = y_min; y < y_max; ++y) {
        auto dy = bilinear_functor((y - y_center + 0.5) / clamped_scales.y);
        for (auto x = x_min; x < x_max; ++x) {
          auto dx = bilinear_functor((x - x_center + 0.5) / clamped_scales.x);
          auto val = input[x * input_strides.x + y * input_strides.y];
          res += val * dx * dy;
          ws += dx * dy;
        }
      }
      // Dividing by the weight sum keeps unit gain even when the window is
      // clipped at the image border.
      outputData
          [n * output_strides.w + c * output_strides.z +
           output_x * output_strides.x + output_y * output_strides.y] =
              res / ws;
    }
  }
}
template <typename T>
kernel void upsample_bicubic2d(
constant T* inputData [[buffer(0)]],
@ -378,6 +431,19 @@ kernel void upsample_bicubic2d_backward(
constant bool& align_corners [[buffer(7)]], \
uint thread_index [[thread_position_in_grid]])
// Instantiates upsample_bilinear2d_aa for a concrete DTYPE and exposes it to
// the host under the name "upsample_bilinear2d_aa_<DTYPE>".  The parameter
// list must stay textually in sync with the kernel template's signature.
// (No comments inside the macro: line splicing happens before comment
// removal, so a // comment would swallow the trailing backslash.)
#define INSTANTIATE_UPSAMPLE_BILINEAR_AA(DTYPE)                            \
  template [[host_name("upsample_bilinear2d_aa_" #DTYPE)]] kernel void    \
  upsample_bilinear2d_aa<DTYPE>(                                          \
      constant DTYPE * inputData [[buffer(0)]],                           \
      device DTYPE * outputData [[buffer(1)]],                            \
      constant ulong4 & input_strides [[buffer(2)]],                      \
      constant ulong4 & output_strides [[buffer(3)]],                     \
      constant long4 & input_sizes [[buffer(4)]],                         \
      constant long4 & output_sizes [[buffer(5)]],                        \
      constant float2 & scales [[buffer(6)]],                             \
      constant bool& align_corners [[buffer(7)]],                         \
      uint thread_index [[thread_position_in_grid]])
#define INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(DTYPE) \
template [[host_name("upsample_bicubic2d_backward_" #DTYPE)]] kernel void \
upsample_bicubic2d_backward<DTYPE>( \
@ -394,6 +460,7 @@ kernel void upsample_bicubic2d_backward(
// Concrete kernel instantiations (one host-visible entry point per dtype).
INSTANTIATE_UPSAMPLE_BILINEAR(uchar);
INSTANTIATE_UPSAMPLE_BICUBIC(float);
INSTANTIATE_UPSAMPLE_BILINEAR(float);
// NOTE(review): the antialiased variant is instantiated only for float,
// while plain bilinear also covers uchar/half -- confirm that half/uchar
// coverage is intentionally deferred.
INSTANTIATE_UPSAMPLE_BILINEAR_AA(float);
INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(float);
INSTANTIATE_UPSAMPLE_BICUBIC(half);
INSTANTIATE_UPSAMPLE_BILINEAR(half);

View file

@ -9,6 +9,8 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_upsample_bilinear2d_aa_backward_native.h>
#include <ATen/ops/_upsample_bilinear2d_aa_native.h>
#include <ATen/ops/_upsample_nearest_exact1d.h>
#include <ATen/ops/_upsample_nearest_exact1d_backward.h>
#include <ATen/ops/_upsample_nearest_exact1d_backward_native.h>
@ -461,4 +463,26 @@ TORCH_IMPL_FUNC(upsample_bicubic2d_backward_out_mps)
grad_input, grad_output, output_size, input_size, align_corners, scales_h, scales_w, "bicubic2d");
}
// Structured-kernel out-variant of aten::_upsample_bilinear2d_aa for MPS.
// Thin dispatcher: all work happens in mps::upsample_kernel_out_template,
// which receives the "bilinear2d_aa" tag -- presumably mapped to the
// "upsample_bilinear2d_aa_<dtype>" Metal kernels; confirm in the template.
TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_out_mps)
(const Tensor& input,
 IntArrayRef output_size,
 bool align_corners,
 std::optional<double> scales_h,
 std::optional<double> scales_w,
 const Tensor& output) {
  mps::upsample_kernel_out_template(input, output_size, align_corners, scales_h, scales_w, output, "bilinear2d_aa");
}
// Structured-kernel out-variant of the _upsample_bilinear2d_aa backward pass
// for MPS; delegates to the shared backward template with the
// "bilinear2d_aa" tag.
// NOTE(review): no "upsample_bilinear2d_aa_backward" Metal kernel
// instantiation is visible in this change -- verify the backward shader
// exists, otherwise this dispatch will fail at runtime.
TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_backward_out_mps)
(const Tensor& grad_output,
 IntArrayRef output_size,
 IntArrayRef input_size,
 bool align_corners,
 std::optional<double> scales_h,
 std::optional<double> scales_w,
 const Tensor& grad_input) {
  mps::upsample_kernel_backward_out_template(
      grad_input, grad_output, output_size, input_size, align_corners, scales_h, scales_w, "bilinear2d_aa");
}
} // namespace at::native

View file

@ -12726,6 +12726,7 @@
dispatch:
CPU: _upsample_bilinear2d_aa_out_cpu
CUDA: _upsample_bilinear2d_aa_out_cuda
MPS: _upsample_bilinear2d_aa_out_mps
- func: _upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
python_module: nn
@ -12737,6 +12738,7 @@
dispatch:
CPU: _upsample_bilinear2d_aa_backward_out_cpu
CUDA: _upsample_bilinear2d_aa_backward_out_cuda
MPS: _upsample_bilinear2d_aa_backward_out_mps
- func: _upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
python_module: nn

View file

@ -9833,7 +9833,6 @@ class TestNNDeviceType(NNTestCase):
else:
_ = F.interpolate(x, (12, 12), mode=mode, antialias=antialias)
@expectedFailureMPS # NotImplementedError: aten::_upsample_bilinear2d_aa.out https://github.com/pytorch/pytorch/issues/77764
@parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
def test_upsamplingBilinear2d_aa_correctness(self, device, memory_format):
# NOTE: We expand the batch dim such that `b*c` is above the maximum