mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[MPS][BE] Implement bilineard2d as shader (#145581)
That significantly improves performance and addresses a correctness problem (to the extent permitted by reducing the precision of the scale-factor computation to float32). The uint8 scaling algorithm mimics the CPU/Pillow implementation
569b785371/src/libImaging/Resample.c (L306-L309)
I.e. using fixed precision integral arithmetic and rounding results of horizontal interpolation back to integers before performing vertical one, which results in technically less accurate results.
But even with those changes, `atol`, `rtol` must be tweaked to `1, 0` when scale factor is `1/3` or `2/3` because of the difference of representation of those values as floats and doubles.
Changes in the performance could be measured using the following script
```python
import torch
import time
import subprocess
def benchmark(device, dtype, size=2048, iters=1000):
    """Benchmark bilinear interpolation on `device` with `dtype`.

    Downscales a (1, 1, size, size) tensor by 0.5, checks the result
    against the CPU reference, then times `iters` interpolation calls.

    Returns:
        tuple[str, str]: ("True "/"False" match flag, average time per
        iteration in microseconds, formatted to one decimal place).
    """
    # Create example inputs
    x = torch.testing.make_tensor(1, 1, size, size, device=device, dtype=dtype)
    sf = .5
    # Check output against the CPU implementation
    y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="bilinear")
    z = torch.nn.functional.interpolate(x.cpu(), scale_factor=sf, mode="bilinear")
    outputs_match = torch.allclose(y.cpu(), z)
    if not outputs_match:
        atol = (y.cpu() - z).abs().max()
        rtol = ((y.cpu() - z)[z != 0] / z[z != 0]).abs().max()
        print(f"atol={atol} rtol={rtol}")
    # Drain any already-queued GPU work so the timed loop starts clean.
    if device == "mps":
        torch.mps.synchronize()
    # Measure time manually
    start_time = time.time() * 1000
    for _ in range(iters):
        y = torch.nn.functional.interpolate(x, scale_factor=sf, mode="bilinear")
    # Bug fix: original `torch.mps.synchronize` (no parentheses) was a no-op
    # attribute access, so MPS timings excluded the in-flight GPU work.
    # Also only synchronize when actually benchmarking the MPS device.
    if device == "mps":
        torch.mps.synchronize()
    end_time = time.time() * 1000
    # elapsed_ms * 1000 / iters = average microseconds per iteration;
    # numerically identical to the old total-ms output when iters == 1000.
    manual_delta = (end_time - start_time) * 1000 / iters
    average_time = f"{manual_delta:6.1f}"
    return "True " if outputs_match else "False", average_time
# Sweep every device/dtype combination; MPS first, then CPU, matching the
# left-to-right column order of the table printed below.
outputs_match_list = []
average_time_list = []
for device in ["mps", "cpu"]:
    for dtype in [torch.float32, torch.float16, torch.bfloat16, torch.uint8]:
        outputs_match, average_time = benchmark(device, dtype)
        outputs_match_list.append(str(outputs_match))
        average_time_list.append(average_time)
# Query the host CPU model name for the report header (macOS-only sysctl).
brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip()
print(f"\nBenchmarking Results (collected on {brand_string}):")
print("-"*40)
print("Device : MPS | CPU")
print("Dtype : FP32 | FP16 | BF16 | U8 | FP32 | FP16 | BF16 | U8")
print(f"Outputs Match : ", " | ".join(outputs_match_list))
print(f"Average Time (us) :", " |".join(average_time_list))
```
Benchmark results before
```
Benchmarking Results (collected on Apple M4 Pro):
----------------------------------------
Device : MPS | CPU
Dtype : FP32 | FP16 | BF16 | U8 | FP32 | FP16 | BF16 | U8
Outputs Match : True | True | True | False | True | True | True | True
Average Time (us) : 277.3 | 197.2 | 188.0 | 163.5 | 302.8 | 248.1 | 308.7 | 650.9
```
After (almost **100x** perf gain):
```
Benchmarking Results (collected on Apple M4 Pro):
----------------------------------------
Device : MPS | CPU
Dtype : FP32 | FP16 | BF16 | U8 | FP32 | FP16 | BF16 | U8
Outputs Match : True | True | True | True | True | True | True | True
Average Time (us) : 1.7 | 1.5 | 1.7 | 1.5 | 296.5 | 236.0 | 310.8 | 642.6
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145581
Approved by: https://github.com/Skylion007
ghstack dependencies: #145578
This commit is contained in:
parent
0afdee4c39
commit
3cf7874ebe
3 changed files with 107 additions and 11 deletions
|
|
@ -146,6 +146,87 @@ void upsample_increment_value_bounded(
|
|||
value);
|
||||
}
|
||||
|
||||
// Return type for linear interpolation: float for all element types except
// uchar, which stays uchar — the uint8 path below rounds back to integers
// after each interpolation step to match the CPU/Pillow algorithm.
template <typename T>
struct linear_return_type {
  typedef float type;
};
template <>
struct linear_return_type<uchar> {
  typedef uchar type;
};
template <typename T>
using linear_return_t = typename linear_return_type<T>::type;

// Linear interpolation between v0 and v1 with weight x (0 -> v0, 1 -> v1).
template <typename T>
inline linear_return_t<T> linear_interp(T v0, T v1, float x) {
  return x * v1 + (1 - x) * v0;
}
|
||||
|
||||
// See Note [ Weights computation for uint8_t and multiplication trick ]
// Essentially fall back to fixed-point arithmetic during uint8
// interpolation, which is not necessarily more accurate (see example below),
// but matches closely what the CPU can deliver.
// I.e. the mid-point of 152+249+172+35 is 152, but the algorithm yields 153,
// as horizontal and vertical interpolation are done in separate steps and
// results are rounded to uint8. Also, as Metal is currently limited to 32-bit
// floats, results will never match those on CPU, especially for 1/3, 2/3 scale.
template <>
inline uchar linear_interp(uchar v0, uchar v1, float x) {
  // 15 fractional bits of fixed-point precision, mirroring Pillow's Resample.c.
  constexpr auto PRECISION_BITS = 15;
  constexpr auto one = 1L << (PRECISION_BITS);
  constexpr auto onehalf = 1L << (PRECISION_BITS - 1);
  // Quantize the weights x and (1 - x) to fixed point, rounding to nearest.
  auto ix = static_cast<long>(x * one + .5);
  auto iomx = static_cast<long>((1.0 - x) * one + .5);
  // Adding onehalf makes the final right-shift round to nearest, not down.
  return (onehalf + v0 * iomx + v1 * ix) >> PRECISION_BITS;
}
|
||||
|
||||
// Bilinear 2D upsampling forward kernel.
// One thread per output spatial position; each thread loops over the batch
// (sizes.w) and channel (sizes.z) dimensions. Sizes/strides are packed as
// (x=width, y=height, z=channels, w=batch) — inferred from the indexing
// below; confirm against the host-side dispatch code.
template <typename T>
kernel void upsample_bilinear2d(
    constant T* inputData [[buffer(0)]],
    device T* outputData [[buffer(1)]],
    constant ulong4& input_strides [[buffer(2)]],
    constant ulong4& output_strides [[buffer(3)]],
    constant long4& input_sizes [[buffer(4)]],
    constant long4& output_sizes [[buffer(5)]],
    constant float2& scales [[buffer(6)]],
    constant bool& align_corners [[buffer(7)]],
    uint thread_index [[thread_position_in_grid]]) {
  // Decode the 2D output coordinate from the linear thread index.
  auto output_x = thread_index % output_sizes.x;
  auto output_y = thread_index / output_sizes.x;
  // Fractional source coordinates; fract() yields the interpolation weights.
  auto real_x = area_pixel_compute_source_index(
      scales.x, output_x, align_corners, /*cubic=*/false);
  auto t_x = fract(real_x);

  auto real_y = area_pixel_compute_source_index(
      scales.y, output_y, align_corners, /*cubic=*/false);
  auto t_y = fract(real_y);
  for (int n = 0; n < output_sizes.w; n++) {
    for (int c = 0; c < output_sizes.z; c++) {
      // Fetch the 2x2 neighborhood; upsample_get_value_bounded presumably
      // clamps out-of-range coordinates at the image border — see its
      // definition above this hunk.
      auto i00 = upsample_get_value_bounded<T>(
          inputData, input_sizes.xy, input_strides, n, c, real_y, real_x);
      auto i01 = upsample_get_value_bounded<T>(
          inputData, input_sizes.xy, input_strides, n, c, real_y, real_x + 1);
      auto i10 = upsample_get_value_bounded<T>(
          inputData, input_sizes.xy, input_strides, n, c, real_y + 1, real_x);
      auto i11 = upsample_get_value_bounded<T>(
          inputData,
          input_sizes.xy,
          input_strides,
          n,
          c,
          real_y + 1,
          real_x + 1);
      // Two horizontal interpolations, then one vertical. For uchar this
      // rounds intermediates back to uint8, matching the CPU/Pillow order.
      auto i0_l = linear_interp(i00, i01, t_x);
      auto i1_l = linear_interp(i10, i11, t_x);
      auto res = linear_interp(i0_l, i1_l, t_y);
      outputData
          [n * output_strides.w + c * output_strides.z +
           output_x * output_strides.x + output_y * output_strides.y] =
              static_cast<T>(res);
    }
  }
}
|
||||
|
||||
template <typename T>
|
||||
kernel void upsample_bicubic2d(
|
||||
constant T* inputData [[buffer(0)]],
|
||||
|
|
@ -284,6 +365,19 @@ kernel void upsample_bicubic2d_backward(
|
|||
constant bool& align_corners [[buffer(7)]], \
|
||||
uint thread_index [[thread_position_in_grid]])
|
||||
|
||||
// Explicitly instantiates upsample_bilinear2d<DTYPE> under the host-visible
// name "upsample_bilinear2d_<DTYPE>" so the ObjC++ side can look the
// specialization up by string.
#define INSTANTIATE_UPSAMPLE_BILINEAR(DTYPE)                           \
  template [[host_name("upsample_bilinear2d_" #DTYPE)]] kernel void   \
  upsample_bilinear2d<DTYPE>(                                         \
      constant DTYPE * inputData [[buffer(0)]],                       \
      device DTYPE * outputData [[buffer(1)]],                        \
      constant ulong4 & input_strides [[buffer(2)]],                  \
      constant ulong4 & output_strides [[buffer(3)]],                 \
      constant long4 & input_sizes [[buffer(4)]],                     \
      constant long4 & output_sizes [[buffer(5)]],                    \
      constant float2 & scales [[buffer(6)]],                         \
      constant bool& align_corners [[buffer(7)]],                     \
      uint thread_index [[thread_position_in_grid]])
|
||||
|
||||
#define INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(DTYPE) \
|
||||
template [[host_name("upsample_bicubic2d_backward_" #DTYPE)]] kernel void \
|
||||
upsample_bicubic2d_backward<DTYPE>( \
|
||||
|
|
@ -297,11 +391,15 @@ kernel void upsample_bicubic2d_backward(
|
|||
constant bool& align_corners [[buffer(7)]], \
|
||||
uint thread_index [[thread_position_in_grid]])
|
||||
|
||||
// Kernel instantiations. Note: uchar gets only a bilinear instantiation here
// (no bicubic variant is visible in this hunk).
INSTANTIATE_UPSAMPLE_BILINEAR(uchar);
INSTANTIATE_UPSAMPLE_BICUBIC(float);
INSTANTIATE_UPSAMPLE_BILINEAR(float);
INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(float);
INSTANTIATE_UPSAMPLE_BICUBIC(half);
INSTANTIATE_UPSAMPLE_BILINEAR(half);
INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(half);
// bfloat is only available starting with Metal 3.1.
#if __METAL_VERSION__ >= 310
INSTANTIATE_UPSAMPLE_BICUBIC(bfloat);
INSTANTIATE_UPSAMPLE_BILINEAR(bfloat);
INSTANTIATE_UPSAMPLE_BICUBIC_BACKWARD(bfloat);
#endif
|
||||
|
|
|
|||
|
|
@ -424,7 +424,7 @@ TORCH_IMPL_FUNC(upsample_bilinear2d_out_mps)
|
|||
std::optional<double> scales_h,
|
||||
std::optional<double> scales_w,
|
||||
const Tensor& output) {
|
||||
mps::upsample_out_template(input, output_size, std::nullopt, scales_h, scales_w, output, align_corners, "bilinear");
|
||||
mps::upsample_kernel_out_template(input, output_size, align_corners, scales_h, scales_w, output, "bilinear2d");
|
||||
}
|
||||
|
||||
TORCH_IMPL_FUNC(upsample_bilinear2d_backward_out_mps)
|
||||
|
|
|
|||
|
|
@ -601,9 +601,6 @@ def mps_ops_modifier(ops):
|
|||
|
||||
# cpu not giving nan for x/0.0
|
||||
'atan2': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8, torch.int8],
|
||||
|
||||
# inconsistency errors between cpu and mps, max seen atol is 2
|
||||
'nn.functional.interpolatebilinear': [torch.uint8],
|
||||
}
|
||||
|
||||
MACOS_BEFORE_13_3_XFAILLIST = {
|
||||
|
|
@ -636,8 +633,6 @@ def mps_ops_modifier(ops):
|
|||
MACOS_AFTER_13_1_XFAILLIST = {
|
||||
# before macOS 13.2 it falls back to cpu and pass the forward pass
|
||||
'grid_sampler_2d': [torch.float32, torch.float16, torch.bfloat16], # Unsupported Border padding mode
|
||||
# inconsistency errors between cpu and mps, max seen atol is 2
|
||||
'nn.functional.interpolatebilinear': [torch.uint8],
|
||||
}
|
||||
|
||||
MACOS_13_3_XFAILLIST = {
|
||||
|
|
@ -12326,7 +12321,6 @@ class TestConsistency(TestCaseMPS):
|
|||
'native_layer_norm',
|
||||
'nn.functional.layer_norm',
|
||||
'nn.functional.interpolate',
|
||||
'nn.functional.upsample_bilinear',
|
||||
'nn.functional.upsample_nearest',
|
||||
'norm', 'masked.normalize',
|
||||
'arange', 'linspace',
|
||||
|
|
@ -12414,10 +12408,14 @@ class TestConsistency(TestCaseMPS):
|
|||
mps_out = op(*mps_args, **mps_kwargs)
|
||||
|
||||
atol, rtol = self._compute_tolerances(op, dtype)
|
||||
if op.name == "nn.functional.upsample_bilinear" and dtype == torch.uint8:
|
||||
atol = 1.0
|
||||
rtol = 0.0
|
||||
|
||||
if (op.name == "nn.functional.interpolate" and dtype == torch.uint8 and
|
||||
cpu_kwargs.get("mode") == "bilinear" and
|
||||
cpu_kwargs.get("recompute_scale_factor") is True and
|
||||
cpu_kwargs.get("scale_factor") == 1.7):
|
||||
# For 1/3, 2/3 scale factors results will not match CPU ones
|
||||
# As MPS compute scales in floats, but CPU always used doubles, which results
|
||||
# in slight numerical differences
|
||||
atol, rtol = 1, 0
|
||||
self.assertEqual(cpu_out, mps_out, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue