diff --git a/aten/src/ATen/native/mps/kernels/UpSample.metal b/aten/src/ATen/native/mps/kernels/UpSample.metal
index 9d36f06ac20..b809fcc061f 100644
--- a/aten/src/ATen/native/mps/kernels/UpSample.metal
+++ b/aten/src/ATen/native/mps/kernels/UpSample.metal
@@ -146,6 +146,87 @@ void upsample_increment_value_bounded(
       value);
 }
 
+template <typename T>
+struct linear_return_type {
+  typedef float type;
+};
+template <>
+struct linear_return_type<uchar> {
+  typedef uchar type;
+};
+template <typename T>
+using linear_return_t = typename linear_return_type<T>::type;
+
+template <typename T>
+inline linear_return_t<T> linear_interp(T v0, T v1, float x) {
+  return x * v1 + (1 - x) * v0;
+}
+
+// See Note [ Weights computation for uint8_t and multiplication trick ]
+// Essentially fall back to fixed floating point arithmetic during uint8
+// interpolation, which is not necessarily more accurate (see example below),
+// but matches closest to what CPU can deliver
+// I.e. mid-point 152+249+172+35 is 152, but algorithm yields 153 as horizontal
+// and vertical interpolation is done in separate steps and results are rounded
+// to uint8. Also, as Metal is currently limited to 32-bit floats, results will
+// never match those on CPU especially for 1/3, 2/3 scale
+template <>
+inline uchar linear_interp(uchar v0, uchar v1, float x) {
+  constexpr auto PRECISION_BITS = 15;
+  constexpr auto one = 1L << (PRECISION_BITS);
+  constexpr auto onehalf = 1L << (PRECISION_BITS - 1);
+  auto ix = static_cast<long>(x * one + .5);
+  auto iomx = static_cast<long>((1.0 - x) * one + .5);
+  return (onehalf + v0 * iomx + v1 * ix) >> PRECISION_BITS;
+}
+
+template <typename T>
+kernel void upsample_bilinear2d(
+    constant T* inputData [[buffer(0)]],
+    device T* outputData [[buffer(1)]],
+    constant ulong4& input_strides [[buffer(2)]],
+    constant ulong4& output_strides [[buffer(3)]],
+    constant long4& input_sizes [[buffer(4)]],
+    constant long4& output_sizes [[buffer(5)]],
+    constant float2& scales [[buffer(6)]],
+    constant bool& align_corners [[buffer(7)]],
+    uint thread_index
+        [[thread_position_in_grid]]) {
+  auto output_x = thread_index % output_sizes.x;
+  auto output_y = thread_index / output_sizes.x;
+  auto real_x = area_pixel_compute_source_index(
+      scales.x, output_x, align_corners, /*cubic=*/false);
+  auto t_x = fract(real_x);
+
+  auto real_y = area_pixel_compute_source_index(
+      scales.y, output_y, align_corners, /*cubic=*/false);
+  auto t_y = fract(real_y);
+  for (int n = 0; n < output_sizes.w; n++) {
+    for (int c = 0; c < output_sizes.z; c++) {
+      auto i00 = upsample_get_value_bounded<T>(
+          inputData, input_sizes.xy, input_strides, n, c, real_y, real_x);
+      auto i01 = upsample_get_value_bounded<T>(
+          inputData, input_sizes.xy, input_strides, n, c, real_y, real_x + 1);
+      auto i10 = upsample_get_value_bounded<T>(
+          inputData, input_sizes.xy, input_strides, n, c, real_y + 1, real_x);
+      auto i11 = upsample_get_value_bounded<T>(
+          inputData,
+          input_sizes.xy,
+          input_strides,
+          n,
+          c,
+          real_y + 1,
+          real_x + 1);
+      auto i0_l = linear_interp(i00, i01, t_x);
+      auto i1_l = linear_interp(i10, i11, t_x);
+      auto res = linear_interp(i0_l, i1_l, t_y);
+      outputData
+          [n * output_strides.w + c * output_strides.z +
+           output_x * output_strides.x + output_y * output_strides.y] =
+          static_cast<T>(res);
+    }
+  }
+}
+
 template <typename T>
 kernel void upsample_bicubic2d(
     constant T* inputData [[buffer(0)]],
@@ -284,6 +365,19 @@ kernel void upsample_bicubic2d_backward(
     constant bool& align_corners [[buffer(7)]], \
     uint thread_index [[thread_position_in_grid]])
 
+#define INSTANTIATE_UPSAMPLE_BILINEAR(DTYPE)                        \
+  template [[host_name("upsample_bilinear2d_" #DTYPE)]] kernel void \
+  upsample_bilinear2d<DTYPE>(                                       \
+      constant DTYPE * inputData [[buffer(0)]],                     \
+      device DTYPE * outputData [[buffer(1)]],                      \
+      constant ulong4 & input_strides [[buffer(2)]],                \
+      constant ulong4 & output_strides [[buffer(3)]],               \
+      constant long4 & input_sizes [[buffer(4)]],                   \
+      constant long4 & output_sizes [[buffer(5)]],                  \
+      constant float2 & scales [[buffer(6)]],                       \
+      constant bool& align_corners [[buffer(7)]],                   \
+      uint thread_index [[thread_position_in_grid]])
+
 #define INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(DTYPE)                        \
   template [[host_name("upsample_bicubic2d_backward_" #DTYPE)]] kernel void \
   upsample_bicubic2d_backward(                                              \
@@ -297,11 +391,15 @@ kernel void upsample_bicubic2d_backward(
     constant bool& align_corners [[buffer(7)]], \
     uint thread_index [[thread_position_in_grid]])
 
+INSTANTIATE_UPSAMPLE_BILINEAR(uchar);
 INSTANTIATE_UPSAMPLE_BICUBIC(float);
+INSTANTIATE_UPSAMPLE_BILINEAR(float);
 INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(float);
 INSTANTIATE_UPSAMPLE_BICUBIC(half);
+INSTANTIATE_UPSAMPLE_BILINEAR(half);
 INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(half);
 #if __METAL_VERSION__ >= 310
 INSTANTIATE_UPSAMPLE_BICUBIC(bfloat);
+INSTANTIATE_UPSAMPLE_BILINEAR(bfloat);
 INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(bfloat);
 #endif
diff --git a/aten/src/ATen/native/mps/operations/UpSample.mm b/aten/src/ATen/native/mps/operations/UpSample.mm
index 93abd301725..a3b099b6cc1 100644
--- a/aten/src/ATen/native/mps/operations/UpSample.mm
+++ b/aten/src/ATen/native/mps/operations/UpSample.mm
@@ -424,7 +424,7 @@ TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps)
     std::optional<double> scales_h,
     std::optional<double> scales_w,
     const Tensor& output) {
-  mps::upsample_out_template(input, output_size, std::nullopt, scales_h, scales_w, output, align_corners, "bilinear");
+  mps::upsample_kernel_out_template(input, output_size, align_corners, scales_h, scales_w, output, "bilinear2d");
 }
 
 TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_mps)
diff --git a/test/test_mps.py b/test/test_mps.py
index 2c7fec90cb8..2eb4eedf80b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -601,9 +601,6 @@ def mps_ops_modifier(ops):
         # cpu not giving nan for x/0.0
         'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
-
-        # inconsistency errors between cpu and mps, max seen atol is 2
-        'nn.functional.interpolatebilinear': [torch.uint8],
     }
 
     MACOS_BEFORE_13_3_XFAILLIST = {
@@ -636,8 +633,6 @@ def mps_ops_modifier(ops):
     MACOS_AFTER_13_1_XFAILLIST = {
         # before macOS 13.2 it falls back to cpu and pass the forward pass
         'grid_sampler_2d': [torch.float32, torch.float16, torch.bfloat16],  # Unsupported Border padding mode
-
-        # inconsistency errors between cpu and mps, max seen atol is 2
-        'nn.functional.interpolatebilinear': [torch.uint8],
     }
 
     MACOS_13_3_XFAILLIST = {
@@ -12326,7 +12321,6 @@ class TestConsistency(TestCaseMPS):
             'native_layer_norm', 'nn.functional.layer_norm',
             'nn.functional.interpolate',
-            'nn.functional.upsample_bilinear',
             'nn.functional.upsample_nearest',
             'norm', 'masked.normalize',
             'arange', 'linspace',
@@ -12414,10 +12408,14 @@ class TestConsistency(TestCaseMPS):
             mps_out = op(*mps_args, **mps_kwargs)
 
             atol, rtol = self._compute_tolerances(op, dtype)
-            if op.name == "nn.functional.upsample_bilinear" and dtype == torch.uint8:
-                atol = 1.0
-                rtol = 0.0
-
+            if (op.name == "nn.functional.interpolate" and dtype == torch.uint8 and
+                    cpu_kwargs.get("mode") == "bilinear" and
+                    cpu_kwargs.get("recompute_scale_factor") is True and
+                    cpu_kwargs.get("scale_factor") == 1.7):
+                # For 1/3, 2/3 scale factors results will not match CPU ones
+                # as MPS computes scales in floats, but CPU always uses doubles, which results
+                # in slight numerical differences
+                atol, rtol = 1, 0
             self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)