[MPS][BE] Implement bilineard2d as shader (#145581)

That significantly improves performance and addresses a correctness problem (to the extent permitted by reducing the precision of scale-factor computation to float32). The uint8 scaling algorithm mimics the CPU/Pillow implementation
569b785371/src/libImaging/Resample.c (L306-L309)
I.e. using fixed precision integral arithmetic and rounding results of horizontal interpolation back to integers before performing vertical one, which results in technically less accurate results.

But even with those changes, `atol`, `rtol` must be tweaked to `1, 0` when scale factor is `1/3` or `2/3` because of the difference of representation  of those values as floats and doubles.

Performance changes can be measured using the following script
```python
import torch
import time
import subprocess

def benchmark(device, dtype, size=2048, iters=1000):
    """Check `interpolate` output on `device` against the CPU reference and
    time `iters` invocations.

    Args:
        device: torch device string ("mps" or "cpu").
        dtype: tensor dtype to benchmark.
        size: spatial size of the square test input (default keeps the
            original 2048x2048 workload).
        iters: number of timed iterations (default keeps the original 1000).

    Returns:
        (outputs_match, average_time): outputs_match is "True "/"False",
        average_time is the mean per-call time in microseconds, formatted
        to one decimal place.
    """
    # Create example inputs
    x = torch.testing.make_tensor(1, 1, size, size, device=device, dtype=dtype)
    sf = .5

    # Check output against the CPU implementation
    y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="bilinear")
    z = torch.nn.functional.interpolate(x.cpu(), scale_factor=sf, mode="bilinear")
    outputs_match = torch.allclose(y.cpu(), z)
    if not outputs_match:
       atol = (y.cpu() - z).abs().max()
       rtol = ((y.cpu() - z)[z!=0]/z[z!=0]).abs().max()
       print(f"atol={atol} rtol={rtol}")

    # Measure time manually
    start_time = time.time() * 1000
    for _ in range(iters):
        y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="bilinear")
    # BUG FIX: the original referenced `torch.mps.synchronize` without calling
    # it (missing parentheses), so queued MPS work was never awaited and GPU
    # timings were understated. Only synchronize when timing the MPS device so
    # the CPU path still works on non-Apple hosts.
    if device == "mps":
        torch.mps.synchronize()
    end_time = time.time() * 1000
    manual_delta = (end_time - start_time)
    # total elapsed ms / iters * 1000 == average microseconds per call
    # (numerically identical to the original output when iters == 1000).
    average_time = f"{manual_delta / iters * 1000:6.1f}"

    return "True " if outputs_match else "False", average_time

# Collect match/timing results for every (device, dtype) combination.
# NOTE(review): requires an MPS-capable macOS build of PyTorch and the
# `sysctl` utility, so this driver only runs on Apple hardware.
outputs_match_list = []
average_time_list = []
for device in ["mps", "cpu"]:
    for dtype in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]:
        outputs_match, average_time = benchmark(device, dtype)
        outputs_match_list.append(str(outputs_match))
        average_time_list.append(average_time)

# Identify the CPU model for the report header.
brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip()
print(f"\nBenchmarking Results (collected on {brand_string}):")
print("-"*40)
# Fixed-width header rows; the separators below line up with these columns.
print("Device            :                MPS                 |               CPU")
print("Dtype             :   FP32  |  FP16  |  BF16  |   U8   |  FP32  |  FP16  |  BF16  |  U8")
print(f"Outputs Match     :  ", " |  ".join(outputs_match_list))
print(f"Average Time (us) :", "  |".join(average_time_list))
```

Benchmark results before
```
Benchmarking Results (collected on Apple M4 Pro):
----------------------------------------
Device            :                MPS                 |               CPU
Dtype             :   FP32  |  FP16  |  BF16  |   U8   |  FP32  |  FP16  |  BF16  |  U8
Outputs Match     :   True  |  True  |  True  |  False |  True  |  True  |  True  |  True
Average Time (us) :  277.3  | 197.2  | 188.0  | 163.5  | 302.8  | 248.1  | 308.7  | 650.9
```
After (almost **100x** perf gain):
```
Benchmarking Results (collected on Apple M4 Pro):
----------------------------------------
Device            :                MPS                 |               CPU
Dtype             :   FP32  |  FP16  |  BF16  |   U8   |  FP32  |  FP16  |  BF16  |  U8
Outputs Match     :   True  |  True  |  True  |  True  |  True  |  True  |  True  |  True
Average Time (us) :    1.7  |   1.5  |   1.7  |   1.5  | 296.5  | 236.0  | 310.8  | 642.6
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145581
Approved by: https://github.com/Skylion007
ghstack dependencies: #145578
This commit is contained in:
Nikita Shulga 2025-01-25 12:00:43 -08:00 committed by PyTorch MergeBot
parent 0afdee4c39
commit 3cf7874ebe
3 changed files with 107 additions and 11 deletions

View file

@ -146,6 +146,87 @@ void upsample_increment_value_bounded(
value);
}
// Maps an element type to the accumulator type used by linear_interp:
// float for floating-point inputs, uchar for uint8 (which uses the
// fixed-point specialization below).
template <typename T>
struct linear_return_type {
typedef float type;
};
template <>
struct linear_return_type<uchar> {
typedef uchar type;
};
template <typename T>
using linear_return_t = typename linear_return_type<T>::type;
// Standard lerp: blends v0 and v1 by weight x (x = 0 yields v0, x = 1 yields v1).
template <typename T>
inline linear_return_t<T> linear_interp(T v0, T v1, float x) {
return x * v1 + (1 - x) * v0;
}
// See Note [ Weights computation for uint8_t and multiplication trick ]
// Essentially fall back to fixed-point integral arithmetic during uint8
// interpolation, which is not necessarily more accurate (see example below),
// but most closely matches what the CPU implementation delivers.
// I.e. the midpoint of 152+249+172+35 is 152, but the algorithm yields 153,
// as horizontal and vertical interpolation are done in separate steps and the
// results are rounded to uint8. Also, as Metal is currently limited to 32-bit
// floats, results will never exactly match those on CPU, especially for
// 1/3 and 2/3 scales.
template <>
inline uchar linear_interp(uchar v0, uchar v1, float x) {
// 15 fractional bits, mirroring Pillow's Resample.c fixed-point scheme.
constexpr auto PRECISION_BITS = 15;
constexpr auto one = 1L << (PRECISION_BITS);
constexpr auto onehalf = 1L << (PRECISION_BITS - 1);
// Convert weights to fixed point; the +.5 implements round-to-nearest.
auto ix = static_cast<long>(x * one + .5);
auto iomx = static_cast<long>((1.0 - x) * one + .5);
// onehalf biases the sum so the final shift also rounds to nearest.
return (onehalf + v0 * iomx + v1 * ix) >> PRECISION_BITS;
}
// Bilinear 2D upsampling kernel: one thread per output (x, y) location,
// looping over batch and channel dimensions.
// NOTE(review): sizes/strides appear packed as .x = width, .y = height,
// .z = channels, .w = batch, judging from the indexing below — confirm
// against the host-side launcher.
template <typename T>
kernel void upsample_bilinear2d(
constant T* inputData [[buffer(0)]],
device T* outputData [[buffer(1)]],
constant ulong4& input_strides [[buffer(2)]],
constant ulong4& output_strides [[buffer(3)]],
constant long4& input_sizes [[buffer(4)]],
constant long4& output_sizes [[buffer(5)]],
constant float2& scales [[buffer(6)]],
constant bool& align_corners [[buffer(7)]],
uint thread_index [[thread_position_in_grid]]) {
// Decode the 2D output coordinate from the flat grid index.
auto output_x = thread_index % output_sizes.x;
auto output_y = thread_index / output_sizes.x;
// Map the output pixel back to fractional source coordinates; the
// fractional parts t_x/t_y become the interpolation weights.
auto real_x = area_pixel_compute_source_index(
scales.x, output_x, align_corners, /*cubic=*/false);
auto t_x = fract(real_x);
auto real_y = area_pixel_compute_source_index(
scales.y, output_y, align_corners, /*cubic=*/false);
auto t_y = fract(real_y);
for (int n = 0; n < output_sizes.w; n++) {
for (int c = 0; c < output_sizes.z; c++) {
// Fetch the 2x2 neighborhood (bounds-clamped by the helper).
auto i00 = upsample_get_value_bounded<T>(
inputData, input_sizes.xy, input_strides, n, c, real_y, real_x);
auto i01 = upsample_get_value_bounded<T>(
inputData, input_sizes.xy, input_strides, n, c, real_y, real_x + 1);
auto i10 = upsample_get_value_bounded<T>(
inputData, input_sizes.xy, input_strides, n, c, real_y + 1, real_x);
auto i11 = upsample_get_value_bounded<T>(
inputData,
input_sizes.xy,
input_strides,
n,
c,
real_y + 1,
real_x + 1);
// Interpolate horizontally on both rows, then vertically between them.
// For uint8 this intentionally rounds back to integers between steps
// (see the linear_interp<uchar> specialization).
auto i0_l = linear_interp(i00, i01, t_x);
auto i1_l = linear_interp(i10, i11, t_x);
auto res = linear_interp(i0_l, i1_l, t_y);
outputData
[n * output_strides.w + c * output_strides.z +
output_x * output_strides.x + output_y * output_strides.y] =
static_cast<T>(res);
}
}
}
template <typename T>
kernel void upsample_bicubic2d(
constant T* inputData [[buffer(0)]],
@ -284,6 +365,19 @@ kernel void upsample_bicubic2d_backward(
constant bool& align_corners [[buffer(7)]], \
uint thread_index [[thread_position_in_grid]])
#define INSTANTIATE_UPSAMPLE_BILINEAR(DTYPE) \
template [[host_name("upsample_bilinear2d_" #DTYPE)]] kernel void \
upsample_bilinear2d<DTYPE>( \
constant DTYPE * inputData [[buffer(0)]], \
device DTYPE * outputData [[buffer(1)]], \
constant ulong4 & input_strides [[buffer(2)]], \
constant ulong4 & output_strides [[buffer(3)]], \
constant long4 & input_sizes [[buffer(4)]], \
constant long4 & output_sizes [[buffer(5)]], \
constant float2 & scales [[buffer(6)]], \
constant bool& align_corners [[buffer(7)]], \
uint thread_index [[thread_position_in_grid]])
#define INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(DTYPE) \
template [[host_name("upsample_bicubic2d_backward_" #DTYPE)]] kernel void \
upsample_bicubic2d_backward<DTYPE>( \
@ -297,11 +391,15 @@ kernel void upsample_bicubic2d_backward(
constant bool& align_corners [[buffer(7)]], \
uint thread_index [[thread_position_in_grid]])
INSTANTIATE_UPSAMPLE_BILINEAR(uchar);
INSTANTIATE_UPSAMPLE_BICUBIC(float);
INSTANTIATE_UPSAMPLE_BILINEAR(float);
INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(float);
INSTANTIATE_UPSAMPLE_BICUBIC(half);
INSTANTIATE_UPSAMPLE_BILINEAR(half);
INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_UPSAMPLE_BICUBIC(bfloat);
INSTANTIATE_UPSAMPLE_BILINEAR(bfloat);
INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(bfloat);
#endif

View file

@ -424,7 +424,7 @@ TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps)
std::optional<double> scales_h,
std::optional<double> scales_w,
const Tensor& output) {
mps::upsample_out_template(input, output_size, std::nullopt, scales_h, scales_w, output, align_corners, "bilinear");
mps::upsample_kernel_out_template(input, output_size, align_corners, scales_h, scales_w, output, "bilinear2d");
}
TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_mps)

View file

@ -601,9 +601,6 @@ def mps_ops_modifier(ops):
# cpu not giving nan for x/0.0
'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
# inconsistency errors between cpu and mps, max seen atol is 2
'nn.functional.interpolatebilinear': [torch.uint8],
}
MACOS_BEFORE_13_3_XFAILLIST = {
@ -636,8 +633,6 @@ def mps_ops_modifier(ops):
MACOS_AFTER_13_1_XFAILLIST = {
# before macOS 13.2 it falls back to cpu and pass the forward pass
'grid_sampler_2d': [torch.float32, torch.float16, torch.bfloat16], # Unsupported Border padding mode
# inconsistency errors between cpu and mps, max seen atol is 2
'nn.functional.interpolatebilinear': [torch.uint8],
}
MACOS_13_3_XFAILLIST = {
@ -12326,7 +12321,6 @@ class TestConsistency(TestCaseMPS):
'native_layer_norm',
'nn.functional.layer_norm',
'nn.functional.interpolate',
'nn.functional.upsample_bilinear',
'nn.functional.upsample_nearest',
'norm', 'masked.normalize',
'arange', 'linspace',
@ -12414,10 +12408,14 @@ class TestConsistency(TestCaseMPS):
mps_out = op(*mps_args, **mps_kwargs)
atol, rtol = self._compute_tolerances(op, dtype)
if op.name == "nn.functional.upsample_bilinear" and dtype == torch.uint8:
atol = 1.0
rtol = 0.0
if (op.name == "nn.functional.interpolate" and dtype == torch.uint8 and
cpu_kwargs.get("mode") == "bilinear" and
cpu_kwargs.get("recompute_scale_factor") is True and
cpu_kwargs.get("scale_factor") == 1.7):
# For 1/3, 2/3 scale factors results will not match CPU ones
# As MPS compute scales in floats, but CPU always used doubles, which results
# in slight numerical differences
atol, rtol = 1, 0
self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)