migrate PyTorch to c10::bit_cast (#98418)

Use the standardized version.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98418
Approved by: https://github.com/ezyang
This commit is contained in:
mikey dagitses 2023-04-06 05:54:20 -07:00 committed by PyTorch MergeBot
parent 213cec3c45
commit fe99d39fbd
11 changed files with 43 additions and 62 deletions

View file

@ -431,6 +431,7 @@ cu_library(
visibility = ["//visibility:public"],
deps = [
":aten_cuda_cpp",
"//c10/util:bit_cast",
"@cuda//:cublas",
"@cuda//:cufft",
"@cuda//:cusparse",

View file

@ -7,6 +7,7 @@
#include <ATen/NestedTensorImpl.h>
#include <ATen/TensorAccessor.h>
#include <c10/util/Logging.h>
#include <c10/util/bit_cast.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
@ -844,8 +845,8 @@ std::tuple<Tensor, Tensor, int64_t, int64_t, Tensor> _flash_attention_forward(
debug_attn_mask = return_debug_mask ? debug_attn_mask : at::empty({0}, query.options());
int64_t signed_philox_seed = sdp::bit_cast<int64_t>(philox_seed);
int64_t signed_philox_offset= sdp::bit_cast<int64_t>(philox_offset);
int64_t signed_philox_seed = c10::bit_cast<int64_t>(philox_seed);
int64_t signed_philox_offset= c10::bit_cast<int64_t>(philox_offset);
return std::make_tuple(output, logsumexp, signed_philox_seed, signed_philox_offset, debug_attn_mask);
#endif

View file

@ -4,6 +4,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/bit_cast.h>
#include <c10/core/TensorImpl.h>
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
@ -106,8 +107,8 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
// The kernel computes it regardless; we will drop it for this function's return
Tensor grad_softmax;
uint64_t unsigned_philox_seed = sdp::bit_cast<uint64_t>(philox_seed);
uint64_t unsigned_philox_offset = sdp::bit_cast<uint64_t>(philox_offset);
uint64_t unsigned_philox_seed = c10::bit_cast<uint64_t>(philox_seed);
uint64_t unsigned_philox_offset = c10::bit_cast<uint64_t>(philox_offset);
std::tie(dq, dk, dv, grad_softmax) = fmha::mha_bwd(
contiguous_grad_out,

View file

@ -21,14 +21,6 @@
namespace sdp {
template <typename To, typename From>
To bit_cast(From f) {
static_assert(sizeof(To) == sizeof(From));
To t;
std::memcpy(&t, &f, sizeof(f));
return t;
}
// This helper function creates a constexpr std::array
// from a compile-time list of values
template <typename V, typename... T>

View file

@ -855,7 +855,7 @@ namespace {
// generate expected_val
for (int64_t i = 0; i < vec::size(); i++) {
bit_rep hex_mask = 0;
hex_mask=bit_cast<bit_rep>(mask[i]);
hex_mask=c10::bit_cast<bit_rep>(mask[i]);
expected_val[i] = (hex_mask & 0x01) ? b[i] : a[i];
}
// test with blendv

View file

@ -1,6 +1,7 @@
#pragma once
#include <ATen/cpu/vec/vec.h>
#include <ATen/cpu/vec/functional.h>
#include <c10/util/bit_cast.h>
#include <c10/util/irange.h>
#include <gtest/gtest.h>
#include <chrono>
@ -266,25 +267,6 @@ std::ostream& operator<<(std::ostream& stream, const CheckWithinDomains<T>& dmn)
return stream;
}
template <class To, class From>
typename std::enable_if<
(sizeof(To) == sizeof(From)) && std::is_trivially_copyable<From>::value&&
std::is_trivial<To>::value,
// this implementation requires that To is trivially default constructible
To>::type
bit_cast(const From& src) noexcept {
To dst;
std::memcpy(&dst, &src, sizeof(To));
return dst;
}
template <class To, class T>
To bit_cast_ptr(T* p, size_t N = sizeof(To)) noexcept {
unsigned char p1[sizeof(To)] = {};
std::memcpy(p1, p, std::min(N, sizeof(To)));
return bit_cast<To>(p1);
}
template <typename T>
std::enable_if_t<std::is_floating_point<T>::value, bool> check_both_nan(T x,
T y) {
@ -888,8 +870,8 @@ public:
if (bitwise)
{
for (const auto i : c10::irange(sizeX)) {
BVT b_exp = bit_cast<BVT>(expArr[i]);
BVT b_act = bit_cast<BVT>(actArr[i]);
BVT b_exp = c10::bit_cast<BVT>(expArr[i]);
BVT b_act = c10::bit_cast<BVT>(actArr[i]);
EXPECT_EQ(b_exp, b_act) << getDetail(i / unitStorageCount);
if (::testing::Test::HasFailure())
return true;
@ -1121,7 +1103,7 @@ T func_cmp(Op call, T v0, T v1) {
using bit_rep = BitType<T>;
constexpr bit_rep mask = std::numeric_limits<bit_rep>::max();
bit_rep ret = call(v0, v1) ? mask : 0;
return bit_cast<T>(ret);
return c10::bit_cast<T>(ret);
}
struct PreventFma
@ -1300,8 +1282,8 @@ template<typename T>
std::enable_if_t<!is_complex<T>::value, T>
local_and(const T& val0, const T& val1) {
using bit_rep = BitType<T>;
bit_rep ret = bit_cast<bit_rep>(val0) & bit_cast<bit_rep>(val1);
return bit_cast<T> (ret);
bit_rep ret = c10::bit_cast<bit_rep>(val0) & c10::bit_cast<bit_rep>(val1);
return c10::bit_cast<T> (ret);
}
template <typename T>
@ -1313,17 +1295,17 @@ local_and(const Complex<T>& val0, const Complex<T>& val1)
T imag1 = val0.imag();
T real2 = val1.real();
T imag2 = val1.imag();
bit_rep real_ret = bit_cast<bit_rep>(real1) & bit_cast<bit_rep>(real2);
bit_rep imag_ret = bit_cast<bit_rep>(imag1) & bit_cast<bit_rep>(imag2);
return Complex<T>(bit_cast<T>(real_ret), bit_cast<T>(imag_ret));
bit_rep real_ret = c10::bit_cast<bit_rep>(real1) & c10::bit_cast<bit_rep>(real2);
bit_rep imag_ret = c10::bit_cast<bit_rep>(imag1) & c10::bit_cast<bit_rep>(imag2);
return Complex<T>(c10::bit_cast<T>(real_ret), c10::bit_cast<T>(imag_ret));
}
template<typename T>
std::enable_if_t<!is_complex<T>::value, T>
local_or(const T& val0, const T& val1) {
using bit_rep = BitType<T>;
bit_rep ret = bit_cast<bit_rep>(val0) | bit_cast<bit_rep>(val1);
return bit_cast<T> (ret);
bit_rep ret = c10::bit_cast<bit_rep>(val0) | c10::bit_cast<bit_rep>(val1);
return c10::bit_cast<T> (ret);
}
template<typename T>
@ -1334,17 +1316,17 @@ local_or(const Complex<T>& val0, const Complex<T>& val1) {
T imag1 = val0.imag();
T real2 = val1.real();
T imag2 = val1.imag();
bit_rep real_ret = bit_cast<bit_rep>(real1) | bit_cast<bit_rep>(real2);
bit_rep imag_ret = bit_cast<bit_rep>(imag1) | bit_cast<bit_rep>(imag2);
return Complex<T>(bit_cast<T> (real_ret), bit_cast<T>(imag_ret));
bit_rep real_ret = c10::bit_cast<bit_rep>(real1) | c10::bit_cast<bit_rep>(real2);
bit_rep imag_ret = c10::bit_cast<bit_rep>(imag1) | c10::bit_cast<bit_rep>(imag2);
return Complex<T>(c10::bit_cast<T> (real_ret), c10::bit_cast<T>(imag_ret));
}
template<typename T>
std::enable_if_t<!is_complex<T>::value, T>
local_xor(const T& val0, const T& val1) {
using bit_rep = BitType<T>;
bit_rep ret = bit_cast<bit_rep>(val0) ^ bit_cast<bit_rep>(val1);
return bit_cast<T> (ret);
bit_rep ret = c10::bit_cast<bit_rep>(val0) ^ c10::bit_cast<bit_rep>(val1);
return c10::bit_cast<T> (ret);
}
template<typename T>
@ -1355,9 +1337,9 @@ local_xor(const Complex<T>& val0, const Complex<T>& val1) {
T imag1 = val0.imag();
T real2 = val1.real();
T imag2 = val1.imag();
bit_rep real_ret = bit_cast<bit_rep>(real1) ^ bit_cast<bit_rep>(real2);
bit_rep imag_ret = bit_cast<bit_rep>(imag1) ^ bit_cast<bit_rep>(imag2);
return Complex<T>(bit_cast<T> (real_ret), bit_cast<T>(imag_ret));
bit_rep real_ret = c10::bit_cast<bit_rep>(real1) ^ c10::bit_cast<bit_rep>(real2);
bit_rep imag_ret = c10::bit_cast<bit_rep>(imag1) ^ c10::bit_cast<bit_rep>(imag2);
return Complex<T>(c10::bit_cast<T> (real_ret), c10::bit_cast<T>(imag_ret));
}
template <typename T>

View file

@ -1,6 +1,8 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/bit_cast.h>
#include <limits>
C10_CLANG_DIAGNOSTIC_PUSH()
@ -27,7 +29,7 @@ inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
x(__bfloat16_as_ushort(__float2bfloat16(value)))
#elif defined(__SYCL_DEVICE_ONLY__) && \
defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
x(sycl::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
x(c10::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
#else
// RNE by default
x(detail::round_to_nearest_even(value))

View file

@ -1,6 +1,8 @@
#pragma once
#include <c10/macros/Macros.h>
#include <c10/util/bit_cast.h>
#include <cstring>
#include <limits>
@ -32,7 +34,7 @@ inline C10_HOST_DEVICE Half::Half(float value)
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
x(__half_as_short(__float2half(value)))
#elif defined(__SYCL_DEVICE_ONLY__)
x(sycl::bit_cast<uint16_t>(sycl::half(value)))
x(c10::bit_cast<uint16_t>(sycl::half(value)))
#else
x(detail::fp16_ieee_from_fp32_value(value))
#endif
@ -45,7 +47,7 @@ inline C10_HOST_DEVICE Half::operator float() const {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
return __half2float(*reinterpret_cast<const __half*>(&x));
#elif defined(__SYCL_DEVICE_ONLY__)
return float(sycl::bit_cast<sycl::half>(x));
return float(c10::bit_cast<sycl::half>(x));
#else
return detail::fp16_ieee_to_fp32_value(x);
#endif
@ -102,7 +104,7 @@ inline C10_HOST_DEVICE Half operator-(const Half& a) {
defined(__HIP_DEVICE_COMPILE__)
return __hneg(a);
#elif defined(__SYCL_DEVICE_ONLY__)
return -sycl::bit_cast<sycl::half>(a);
return -c10::bit_cast<sycl::half>(a);
#else
return -static_cast<float>(a);
#endif

View file

@ -13,13 +13,13 @@ namespace c10 {
// information as well as the source of our implementations.
template <class To, class From>
std::enable_if_t<
sizeof(To) == sizeof(From) && std::is_trivially_copyable_v<From> &&
std::is_trivially_copyable_v<To>,
sizeof(To) == sizeof(From) && std::is_trivially_copyable<From>::value &&
std::is_trivially_copyable<To>::value,
To>
// constexpr support needs compiler magic
bit_cast(const From& src) noexcept {
static_assert(
std::is_trivially_constructible_v<To>,
std::is_trivially_constructible<To>::value,
"This implementation additionally requires "
"destination type to be trivially constructible");

View file

@ -33,6 +33,7 @@ def define_targets(rules):
local_defines = ["C10_BUILD_MAIN_LIB"],
visibility = ["//visibility:public"],
deps = [
":bit_cast",
"//c10/macros",
"@fmt",
] + rules.select({

View file

@ -12,6 +12,8 @@
#pragma once
#include <c10/util/bit_cast.h>
#include <algorithm>
#include <cassert>
#include <climits>
@ -615,11 +617,8 @@ inline double BitsToDouble(uint64_t Bits) {
/// This function takes a 32-bit integer and returns the bit equivalent float.
inline float BitsToFloat(uint32_t Bits) {
// TODO: Use bit_cast once C++20 becomes available.
float F;
static_assert(sizeof(uint32_t) == sizeof(float), "Unexpected type sizes");
memcpy(&F, &Bits, sizeof(Bits));
return F;
// TODO: Use std::bit_cast once C++20 becomes available.
return c10::bit_cast<float>(Bits);
}
/// This function takes a double and returns the bit equivalent 64-bit integer.