From 3cf7874ebea2300c9eddbbe79a004ef3d6801b93 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Sat, 25 Jan 2025 12:00:43 -0800 Subject: [PATCH] [MPS][BE] Implement bilineard2d as shader (#145581) That significantly improves performance and addresses correctness problem (to an extent permitted by reducing precision of scale factor computation to float32). uint8 scaling algorithm mimics CPU/Pillow implementation https://github.com/python-pillow/Pillow/blob/569b785371aa717a004adb0166feb565bbb01b7b/src/libImaging/Resample.c#L306-L309 I.e. using fixed precision integral arithmetic and rounding results of horizontal interpolation back to integers before performing vertical one, which results in technically less accurate results. But even with those changes, `atol`, `rtol` must be tweaked to `1, 0` when scale factor is `1/3` or `2/3` because of the difference of representation of those values as floats and doubles. Changes in the performance could be measured using the following script ```python import torch import time import subprocess def benchmark(device, dtype): # Create example inputs x = torch.testing.make_tensor(1, 1, 2048, 2048, device=device, dtype=dtype) sf = .5 # Check output y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="bilinear") z = torch.nn.functional.interpolate(x.cpu(), scale_factor=sf, mode="bilinear") outputs_match = torch.allclose(y.cpu(), z) if not outputs_match: atol = (y.cpu() - z).abs().max() rtol = ((y.cpu() - z)[z!=0]/z[z!=0]).abs().max() print(f"atol={atol} rtol={rtol}") # Measure time manually start_time = time.time() * 1000 for _ in range(1000): y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="bilinear") torch.mps.synchronize() end_time = time.time() * 1000 manual_delta = (end_time - start_time) average_time = f"{manual_delta:6.1f}" return "True " if outputs_match else "False", average_time outputs_match_list = [] average_time_list = [] for device in ["mps", "cpu"]: for dtype in [torch.float32, torch.float16, 
torch.bfloat16, torch.uint8]: outputs_match, average_time = benchmark(device, dtype) outputs_match_list.append(str(outputs_match)) average_time_list.append(average_time) brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip() print(f"\nBenchmarking Results (collected on {brand_string}):") print("-"*40) print("Device : MPS | CPU") print("Dtype : FP32 | FP16 | BF16 | U8 | FP32 | FP16 | BF16 | U8") print(f"Outputs Match : ", " | ".join(outputs_match_list)) print(f"Average Time (us) :", " |".join(average_time_list)) ``` Benchmark results before ``` Benchmarking Results (collected on Apple M4 Pro): ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | U8 | FP32 | FP16 | BF16 | U8 Outputs Match : True | True | True | False | True | True | True | True Average Time (us) : 277.3 | 197.2 | 188.0 | 163.5 | 302.8 | 248.1 | 308.7 | 650.9 ``` After(almost **100x** perf gain): ``` Benchmarking Results (collected on Apple M4 Pro): ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | U8 | FP32 | FP16 | BF16 | U8 Outputs Match : True | True | True | True | True | True | True | True Average Time (us) : 1.7 | 1.5 | 1.7 | 1.5 | 296.5 | 236.0 | 310.8 | 642.6 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/145581 Approved by: https://github.com/Skylion007 ghstack dependencies: #145578 --- .../ATen/native/mps/kernels/UpSample.metal | 98 +++++++++++++++++++ .../ATen/native/mps/operations/UpSample.mm | 2 +- test/test_mps.py | 18 ++-- 3 files changed, 107 insertions(+), 11 deletions(-) diff --git a/aten/src/ATen/native/mps/kernels/UpSample.metal b/aten/src/ATen/native/mps/kernels/UpSample.metal index 9d36f06ac20..b809fcc061f 100644 --- a/aten/src/ATen/native/mps/kernels/UpSample.metal +++ b/aten/src/ATen/native/mps/kernels/UpSample.metal @@ -146,6 +146,87 @@ void upsample_increment_value_bounded( value); } +template +struct linear_return_type { + 
typedef float type; +}; +template <> +struct linear_return_type { + typedef uchar type; +}; +template +using linear_return_t = typename linear_return_type::type; + +template +inline linear_return_t linear_interp(T v0, T v1, float x) { + return x * v1 + (1 - x) * v0; +} + +// See Note [ Weights computation for uint8_t and multiplication trick ] +// Essentially fall back to fixed-point arithmetic during uint8 +// interpolation, which is not necessarily more accurate (see example below), +// but matches closest to what CPU can deliver +// I.e. mid-point 152+249+172+35 is 152, but algorithm yields 153 as horizontal +// and vertical interpolation is done in separate steps and results are rounded +// to uint8. Also, as Metal is currently limited to 32-bit floats, results will +// never match those on CPU especially for 1/3, 2/3 scale +template <> +inline uchar linear_interp(uchar v0, uchar v1, float x) { + constexpr auto PRECISION_BITS = 15; + constexpr auto one = 1L << (PRECISION_BITS); + constexpr auto onehalf = 1L << (PRECISION_BITS - 1); + auto ix = static_cast(x * one + .5); + auto iomx = static_cast((1.0 - x) * one + .5); + return (onehalf + v0 * iomx + v1 * ix) >> PRECISION_BITS; +} + +template +kernel void upsample_bilinear2d( + constant T* inputData [[buffer(0)]], + device T* outputData [[buffer(1)]], + constant ulong4& input_strides [[buffer(2)]], + constant ulong4& output_strides [[buffer(3)]], + constant long4& input_sizes [[buffer(4)]], + constant long4& output_sizes [[buffer(5)]], + constant float2& scales [[buffer(6)]], + constant bool& align_corners [[buffer(7)]], + uint thread_index [[thread_position_in_grid]]) { + auto output_x = thread_index % output_sizes.x; + auto output_y = thread_index / output_sizes.x; + auto real_x = area_pixel_compute_source_index( + scales.x, output_x, align_corners, /*cubic=*/false); + auto t_x = fract(real_x); + + auto real_y = area_pixel_compute_source_index( + scales.y, output_y, align_corners, /*cubic=*/false); + 
auto t_y = fract(real_y); + for (int n = 0; n < output_sizes.w; n++) { + for (int c = 0; c < output_sizes.z; c++) { + auto i00 = upsample_get_value_bounded( + inputData, input_sizes.xy, input_strides, n, c, real_y, real_x); + auto i01 = upsample_get_value_bounded( + inputData, input_sizes.xy, input_strides, n, c, real_y, real_x + 1); + auto i10 = upsample_get_value_bounded( + inputData, input_sizes.xy, input_strides, n, c, real_y + 1, real_x); + auto i11 = upsample_get_value_bounded( + inputData, + input_sizes.xy, + input_strides, + n, + c, + real_y + 1, + real_x + 1); + auto i0_l = linear_interp(i00, i01, t_x); + auto i1_l = linear_interp(i10, i11, t_x); + auto res = linear_interp(i0_l, i1_l, t_y); + outputData + [n * output_strides.w + c * output_strides.z + + output_x * output_strides.x + output_y * output_strides.y] = + static_cast(res); + } + } +} + template kernel void upsample_bicubic2d( constant T* inputData [[buffer(0)]], @@ -284,6 +365,19 @@ kernel void upsample_bicubic2d_backward( constant bool& align_corners [[buffer(7)]], \ uint thread_index [[thread_position_in_grid]]) +#define INSTANTIATE_UPSAMPLE_BILINEAR(DTYPE) \ + template [[host_name("upsample_bilinear2d_" #DTYPE)]] kernel void \ + upsample_bilinear2d( \ + constant DTYPE * inputData [[buffer(0)]], \ + device DTYPE * outputData [[buffer(1)]], \ + constant ulong4 & input_strides [[buffer(2)]], \ + constant ulong4 & output_strides [[buffer(3)]], \ + constant long4 & input_sizes [[buffer(4)]], \ + constant long4 & output_sizes [[buffer(5)]], \ + constant float2 & scales [[buffer(6)]], \ + constant bool& align_corners [[buffer(7)]], \ + uint thread_index [[thread_position_in_grid]]) + #define INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(DTYPE) \ template [[host_name("upsample_bicubic2d_backward_" #DTYPE)]] kernel void \ upsample_bicubic2d_backward( \ @@ -297,11 +391,15 @@ kernel void upsample_bicubic2d_backward( constant bool& align_corners [[buffer(7)]], \ uint thread_index [[thread_position_in_grid]]) 
+INSTANTIATE_UPSAMPLE_BILINEAR(uchar); INSTANTIATE_UPSAMPLE_BICUBIC(float); +INSTANTIATE_UPSAMPLE_BILINEAR(float); INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(float); INSTANTIATE_UPSAMPLE_BICUBIC(half); +INSTANTIATE_UPSAMPLE_BILINEAR(half); INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(half); #if __METAL_VERSION__ >= 310 INSTANTIATE_UPSAMPLE_BICUBIC(bfloat); +INSTANTIATE_UPSAMPLE_BILINEAR(bfloat); INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(bfloat); #endif diff --git a/aten/src/ATen/native/mps/operations/UpSample.mm b/aten/src/ATen/native/mps/operations/UpSample.mm index 93abd301725..a3b099b6cc1 100644 --- a/aten/src/ATen/native/mps/operations/UpSample.mm +++ b/aten/src/ATen/native/mps/operations/UpSample.mm @@ -424,7 +424,7 @@ TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps) std::optional scales_h, std::optional scales_w, const Tensor& output) { - mps::upsample_out_template(input, output_size, std::nullopt, scales_h, scales_w, output, align_corners, "bilinear"); + mps::upsample_kernel_out_template(input, output_size, align_corners, scales_h, scales_w, output, "bilinear2d"); } TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_mps) diff --git a/test/test_mps.py b/test/test_mps.py index 2c7fec90cb8..2eb4eedf80b 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -601,9 +601,6 @@ def mps_ops_modifier(ops): # cpu not giving nan for x/0.0 'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8], - - # inconsistency errors between cpu and mps, max seen atol is 2 - 'nn.functional.interpolatebilinear': [torch.uint8], } MACOS_BEFORE_13_3_XFAILLIST = { @@ -636,8 +633,6 @@ def mps_ops_modifier(ops): MACOS_AFTER_13_1_XFAILLIST = { # before macOS 13.2 it falls back to cpu and pass the forward pass 'grid_sampler_2d': [torch.float32, torch.float16, torch.bfloat16], # Unsupported Border padding mode - # inconsistency errors between cpu and mps, max seen atol is 2 - 'nn.functional.interpolatebilinear': [torch.uint8], } MACOS_13_3_XFAILLIST = { @@ -12326,7 +12321,6 @@ class 
TestConsistency(TestCaseMPS): 'native_layer_norm', 'nn.functional.layer_norm', 'nn.functional.interpolate', - 'nn.functional.upsample_bilinear', 'nn.functional.upsample_nearest', 'norm', 'masked.normalize', 'arange', 'linspace', @@ -12414,10 +12408,14 @@ class TestConsistency(TestCaseMPS): mps_out = op(*mps_args, **mps_kwargs) atol, rtol = self._compute_tolerances(op, dtype) - if op.name == "nn.functional.upsample_bilinear" and dtype == torch.uint8: - atol = 1.0 - rtol = 0.0 - + if (op.name == "nn.functional.interpolate" and dtype == torch.uint8 and + cpu_kwargs.get("mode") == "bilinear" and + cpu_kwargs.get("recompute_scale_factor") is True and + cpu_kwargs.get("scale_factor") == 1.7): + # For 1/3, 2/3 scale factors results will not match CPU ones + # As MPS computes scales in floats, but CPU always uses doubles, which results + # in slight numerical differences + atol, rtol = 1, 0 self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)