From 7725d0ba12d5f42ad93fc1b94edcf060e6eaa82c Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Thu, 6 Feb 2025 17:57:31 +0000 Subject: [PATCH] [METAL] inline bfloat min/max (#146588) After a recent commit 36c6e09528a7e071edecde083254da70cba26c95 , building from source with `python setup.py develop` leads to an error due to multiple symbols for min/max: ``` FAILED: caffe2/aten/src/ATen/kernels_bfloat.metallib /Users/Irakli_Salia/Desktop/pytorch/build/caffe2/aten/src/ATen/kernels_bfloat.metallib cd /Users/Irakli_Salia/Desktop/pytorch/build/caffe2/aten/src/ATen && xcrun metallib -o kernels_bfloat.metallib BinaryKernel_31.air Bucketization_31.air CrossKernel_31.air FusedOptimizerOps_31.air Gamma_31.air HistogramKernel_31.air Im2Col_31.air Indexing_31.air LinearAlgebra_31.air Quantized_31.air RMSNorm_31.air RenormKernel_31.air Repeat_31.air SpecialOps_31.air TriangularOps_31.air UnaryKernel_31.air UnfoldBackward_31.air UpSample_31.air LLVM ERROR: multiple symbols ('_ZN3c105metal3minIDF16bEEN5metal9enable_ifIXgssr5metalE19is_floating_point_vIT_EES4_E4typeES4_S4_')! ``` This PR fixes that. @malfet Pull Request resolved: https://github.com/pytorch/pytorch/pull/146588 Approved by: https://github.com/FFFrog, https://github.com/Skylion007, https://github.com/malfet --- c10/metal/utils.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/c10/metal/utils.h b/c10/metal/utils.h index 0429787951c..04a09fb77c4 100644 --- a/c10/metal/utils.h +++ b/c10/metal/utils.h @@ -108,13 +108,13 @@ template #if __METAL_VERSION__ >= 310 template <> -bfloat min(bfloat a, bfloat b) { +inline bfloat min(bfloat a, bfloat b) { return bfloat( ::metal::isunordered(a, b) ? NAN : ::metal::min(float(a), float(b))); } template <> -bfloat max(bfloat a, bfloat b) { +inline bfloat max(bfloat a, bfloat b) { return bfloat( ::metal::isunordered(a, b) ? NAN : ::metal::max(float(a), float(b))); }