From e429a3b72e787ddcc26ee2ba177643c9177bab24 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 14 Nov 2024 10:33:57 -0800 Subject: [PATCH] Move complex from Half.h to complex.h (#140565) Executing on old TODO on the way to sharing Half.h with ExecuTorch. Differential Revision: [D65888037](https://our.internmc.facebook.com/intern/diff/D65888037/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/140565 Approved by: https://github.com/ezyang, https://github.com/malfet ghstack dependencies: #140564 --- c10/util/Half.h | 51 --------------------------------- c10/util/complex.h | 50 ++++++++++++++++++++++++++++++++ torch/csrc/utils/byte_order.cpp | 1 + 3 files changed, 51 insertions(+), 51 deletions(-) diff --git a/c10/util/Half.h b/c10/util/Half.h index 9a7d159758b..c0085798a96 100644 --- a/c10/util/Half.h +++ b/c10/util/Half.h @@ -12,7 +12,6 @@ #include #include #include -#include #include #include @@ -385,56 +384,6 @@ struct alignas(2) Half { #endif }; -// TODO : move to complex.h -template <> -struct alignas(4) complex { - Half real_; - Half imag_; - - // Constructors - complex() = default; - // Half constructor is not constexpr so the following constructor can't - // be constexpr - C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag) - : real_(real), imag_(imag) {} - C10_HOST_DEVICE inline complex(const c10::complex& value) - : real_(value.real()), imag_(value.imag()) {} - - // Conversion operator - inline C10_HOST_DEVICE operator c10::complex() const { - return {real_, imag_}; - } - - constexpr C10_HOST_DEVICE Half real() const { - return real_; - } - constexpr C10_HOST_DEVICE Half imag() const { - return imag_; - } - - C10_HOST_DEVICE complex& operator+=(const complex& other) { - real_ = static_cast(real_) + static_cast(other.real_); - imag_ = static_cast(imag_) + static_cast(other.imag_); - return *this; - } - - C10_HOST_DEVICE complex& operator-=(const complex& other) { - real_ = static_cast(real_) - static_cast(other.real_); - imag_ = static_cast(imag_) - static_cast(other.imag_); - return *this; - } - - C10_HOST_DEVICE complex& operator*=(const complex& other) { - auto a = static_cast(real_); - auto b = static_cast(imag_); - auto c = static_cast(other.real()); - auto d = static_cast(other.imag()); - real_ = a * c - b * d; - imag_ = a * d + b * c; - return *this; - } -}; - C10_API inline std::ostream& operator<<(std::ostream& out, const Half& value) { out << (float)value; return out; diff --git a/c10/util/complex.h b/c10/util/complex.h index c08e10aa0f2..b63710d9458 100644 --- a/c10/util/complex.h +++ b/c10/util/complex.h @@ -3,6 +3,7 @@ #include #include +#include #if defined(__CUDACC__) || defined(__HIPCC__) #include @@ -606,6 +607,55 @@ C10_HOST_DEVICE complex polar(const T& r, const T& theta = T()) { #endif } +template <> +struct alignas(4) complex { + Half real_; + Half imag_; + + // Constructors + complex() = default; + // Half constructor is not constexpr so the following constructor can't + // be constexpr + C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag) + : real_(real), imag_(imag) {} + C10_HOST_DEVICE inline complex(const c10::complex& value) + : real_(value.real()), imag_(value.imag()) {} + + // Conversion operator + inline C10_HOST_DEVICE operator c10::complex() const { + return {real_, imag_}; + } + + constexpr C10_HOST_DEVICE Half real() const { + return real_; + } + constexpr C10_HOST_DEVICE Half imag() const { + return imag_; + } + + C10_HOST_DEVICE complex& operator+=(const complex& other) { + real_ = static_cast(real_) + static_cast(other.real_); + imag_ = static_cast(imag_) + static_cast(other.imag_); + return *this; + } + + C10_HOST_DEVICE complex& operator-=(const complex& other) { + real_ = static_cast(real_) - static_cast(other.real_); + imag_ = static_cast(imag_) - static_cast(other.imag_); + return *this; + } + + C10_HOST_DEVICE complex& operator*=(const complex& other) { + auto a = static_cast(real_); + auto b = static_cast(imag_); + auto c = static_cast(other.real()); + auto d = static_cast(other.imag()); + real_ = a * c - b * d; + imag_ = a * d + b * c; + return *this; + } +}; + } // namespace c10 C10_CLANG_DIAGNOSTIC_POP() diff --git a/torch/csrc/utils/byte_order.cpp b/torch/csrc/utils/byte_order.cpp index e7eaf2de0c6..5a5ec190991 100644 --- a/torch/csrc/utils/byte_order.cpp +++ b/torch/csrc/utils/byte_order.cpp @@ -1,4 +1,5 @@ #include +#include #include #include