mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[MPS] Add bilineard2d_aa implementation
This commit is contained in:
parent
2a55311773
commit
31a7658843
4 changed files with 93 additions and 1 deletions
|
|
@ -227,6 +227,59 @@ kernel void upsample_bilinear2d(
|
|||
}
|
||||
}
|
||||
|
||||
// Triangle ("tent") filter used for anti-aliased bilinear resampling:
//   weight(x) = max(0, 1 - |x|)
// The previous version returned |x| (a weight near 1) outside the unit
// support. The caller's tap range [x_min, x_max) can place a tap exactly
// on the support edge (|arg| == 1) when x_center + scale + .5 is integral,
// in which case the old code contributed weight 1 instead of 0 and skewed
// the normalized result. Returning 0 outside the support fixes that edge
// case and leaves all in-support weights unchanged.
inline float bilinear_functor(float x) {
  const float ax = x < 0.0f ? -x : x; // |x| without relying on abs() overloads
  return ax < 1.0f ? 1.0f - ax : 0.0f;
}
|
||||
|
||||
// Anti-aliased bilinear 2D resampling, forward pass. One thread handles a
// single output spatial location (output_x, output_y) and loops over every
// batch (n) and channel (c) at that location. Sizes/strides are packed as
// x = width, y = height, z = channels, w = batch.
template <typename T>
kernel void upsample_bilinear2d_aa(
    constant T* inputData [[buffer(0)]],
    device T* outputData [[buffer(1)]],
    constant ulong4& input_strides [[buffer(2)]],
    constant ulong4& output_strides [[buffer(3)]],
    constant long4& input_sizes [[buffer(4)]],
    constant long4& output_sizes [[buffer(5)]],
    constant float2& scales [[buffer(6)]],
    constant bool& align_corners [[buffer(7)]],
    uint thread_index [[thread_position_in_grid]]) {
  // Decompose the linear grid index into 2D output coordinates.
  auto output_x = thread_index % output_sizes.x;
  auto output_y = thread_index / output_sizes.x;
  // Center of the sampling window in input space; the +.5 moves from the
  // corner convention used by the helper to the pixel-center convention.
  // NOTE(review): /*cubic=*/true is passed for a bilinear kernel — presumably
  // to keep the helper from clamping negative raw indices; confirm intent.
  auto x_center = area_pixel_compute_source_index(
                      scales.x, output_x, align_corners, /*cubic=*/true) +
      .5;
  auto y_center = area_pixel_compute_source_index(
                      scales.y, output_y, align_corners, /*cubic=*/true) +
      .5;
  // Filter support extends `scale` pixels each side of the center, but never
  // shrinks below 1 pixel (the upsampling case).
  auto clamped_scales = max(1.0, scales);
  // Tap bounding box in input space, clamped to the valid input extent.
  auto x_min = max(0L, long(floor(x_center - clamped_scales.x + .5)));
  auto x_max =
      min(input_sizes.x, long(floor(x_center + clamped_scales.x + .5)));
  auto y_min = max(0L, long(floor(y_center - clamped_scales.y + .5)));
  auto y_max =
      min(input_sizes.y, long(floor(y_center + clamped_scales.y + .5)));
  for (int n = 0; n < output_sizes.w; n++) {
    for (int c = 0; c < output_sizes.z; c++) {
      float res = 0.0; // weighted sum of contributing input taps
      float ws = 0.0;  // accumulated weight, used for normalization
      // Base pointer for this (batch, channel) plane.
      constant auto* input =
          inputData + n * input_strides.w + c * input_strides.z;
      for (auto y = y_min; y < y_max; ++y) {
        // Separable filter: per-row weight dy, per-column weight dx.
        auto dy = bilinear_functor((y - y_center + 0.5) / clamped_scales.y);
        for (auto x = x_min; x < x_max; ++x) {
          auto dx = bilinear_functor((x - x_center + 0.5) / clamped_scales.x);
          auto val = input[x * input_strides.x + y * input_strides.y];
          res += val * dx * dy;
          ws += dx * dy;
        }
      }
      // Normalize by the total weight so windows clipped at the image border
      // still average correctly.
      outputData
          [n * output_strides.w + c * output_strides.z +
           output_x * output_strides.x + output_y * output_strides.y] =
              res / ws;
    }
  }
}
|
||||
|
||||
template <typename T>
|
||||
kernel void upsample_bicubic2d(
|
||||
constant T* inputData [[buffer(0)]],
|
||||
|
|
@ -378,6 +431,19 @@ kernel void upsample_bicubic2d_backward(
|
|||
constant bool& align_corners [[buffer(7)]], \
|
||||
uint thread_index [[thread_position_in_grid]])
|
||||
|
||||
// Explicitly instantiates upsample_bilinear2d_aa<DTYPE> under the host name
// "upsample_bilinear2d_aa_<DTYPE>" so the ObjC++ host code can look the
// pipeline up by string. The signature must match the template definition
// exactly. (Comment kept outside the macro: comments would break the
// backslash line continuations.)
#define INSTANTIATE_UPSAMPLE_BILINEAR_AA(DTYPE)                            \
  template [[host_name("upsample_bilinear2d_aa_" #DTYPE)]] kernel void    \
  upsample_bilinear2d_aa<DTYPE>(                                           \
      constant DTYPE * inputData [[buffer(0)]],                            \
      device DTYPE * outputData [[buffer(1)]],                             \
      constant ulong4 & input_strides [[buffer(2)]],                       \
      constant ulong4 & output_strides [[buffer(3)]],                      \
      constant long4 & input_sizes [[buffer(4)]],                          \
      constant long4 & output_sizes [[buffer(5)]],                         \
      constant float2 & scales [[buffer(6)]],                              \
      constant bool& align_corners [[buffer(7)]],                          \
      uint thread_index [[thread_position_in_grid]])
|
||||
|
||||
#define INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(DTYPE) \
|
||||
template [[host_name("upsample_bicubic2d_backward_" #DTYPE)]] kernel void \
|
||||
upsample_bicubic2d_backward<DTYPE>( \
|
||||
|
|
@ -394,6 +460,7 @@ kernel void upsample_bicubic2d_backward(
|
|||
// Concrete per-dtype kernel instantiations.
// NOTE(review): the anti-aliased variant is instantiated for float only,
// while the plain bilinear kernel also covers uchar and half — confirm
// whether half/uchar AA support is intentionally deferred.
INSTANTIATE_UPSAMPLE_BILINEAR(uchar);
INSTANTIATE_UPSAMPLE_BICUBIC(float);
INSTANTIATE_UPSAMPLE_BILINEAR(float);
INSTANTIATE_UPSAMPLE_BILINEAR_AA(float);
INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(float);
INSTANTIATE_UPSAMPLE_BICUBIC(half);
INSTANTIATE_UPSAMPLE_BILINEAR(half);
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@
|
|||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/_upsample_bilinear2d_aa_backward_native.h>
|
||||
#include <ATen/ops/_upsample_bilinear2d_aa_native.h>
|
||||
#include <ATen/ops/_upsample_nearest_exact1d.h>
|
||||
#include <ATen/ops/_upsample_nearest_exact1d_backward.h>
|
||||
#include <ATen/ops/_upsample_nearest_exact1d_backward_native.h>
|
||||
|
|
@ -461,4 +463,26 @@ TORCH_IMPL_FUNC(upsample_bicubic2d_backward_out_mps)
|
|||
grad_input, grad_output, output_size, input_size, align_corners, scales_h, scales_w, "bicubic2d");
|
||||
}
|
||||
|
||||
// Structured-kernel out-variant of aten::_upsample_bilinear2d_aa for MPS.
// Delegates to the shared MPS upsample dispatcher, selecting the
// "bilinear2d_aa" Metal kernel by name; `output` has already been sized
// by the structured meta function.
TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_out_mps)
(const Tensor& input,
 IntArrayRef output_size,
 bool align_corners,
 std::optional<double> scales_h,
 std::optional<double> scales_w,
 const Tensor& output) {
  mps::upsample_kernel_out_template(input, output_size, align_corners, scales_h, scales_w, output, "bilinear2d_aa");
}
|
||||
|
||||
// Structured-kernel out-variant of aten::_upsample_bilinear2d_aa_backward
// for MPS. Delegates to the shared backward dispatcher with the
// "bilinear2d_aa" kernel-name suffix.
// NOTE(review): this diff only adds a *forward* bilinear2d_aa Metal kernel;
// confirm the backward template resolves to an existing backward kernel,
// otherwise this registration will fail at runtime.
TORCH_IMPL_FUNC(_upsample_bilinear2d_aa_backward_out_mps)
(const Tensor& grad_output,
 IntArrayRef output_size,
 IntArrayRef input_size,
 bool align_corners,
 std::optional<double> scales_h,
 std::optional<double> scales_w,
 const Tensor& grad_input) {
  mps::upsample_kernel_backward_out_template(
      grad_input, grad_output, output_size, input_size, align_corners, scales_h, scales_w, "bilinear2d_aa");
}
|
||||
|
||||
} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -12726,6 +12726,7 @@
|
|||
dispatch:
|
||||
CPU: _upsample_bilinear2d_aa_out_cpu
|
||||
CUDA: _upsample_bilinear2d_aa_out_cuda
|
||||
MPS: _upsample_bilinear2d_aa_out_mps
|
||||
|
||||
- func: _upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
|
||||
python_module: nn
|
||||
|
|
@ -12737,6 +12738,7 @@
|
|||
dispatch:
|
||||
CPU: _upsample_bilinear2d_aa_backward_out_cpu
|
||||
CUDA: _upsample_bilinear2d_aa_backward_out_cuda
|
||||
MPS: _upsample_bilinear2d_aa_backward_out_mps
|
||||
|
||||
- func: _upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor
|
||||
python_module: nn
|
||||
|
|
|
|||
|
|
@ -9833,7 +9833,6 @@ class TestNNDeviceType(NNTestCase):
|
|||
else:
|
||||
_ = F.interpolate(x, (12, 12), mode=mode, antialias=antialias)
|
||||
|
||||
@expectedFailureMPS # NotImplementedError: aten::_upsample_bilinear2d_aa.out https://github.com/pytorch/pytorch/issues/77764
|
||||
@parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
|
||||
def test_upsamplingBilinear2d_aa_correctness(self, device, memory_format):
|
||||
# NOTE: We expand the batch dim such that `b*c` is above the maximum
|
||||
|
|
|
|||
Loading…
Reference in a new issue