mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
migrate PyTorch to c10::bit_cast (#98418)
Use the standardized version. Pull Request resolved: https://github.com/pytorch/pytorch/pull/98418 Approved by: https://github.com/ezyang
This commit is contained in:
parent
213cec3c45
commit
fe99d39fbd
11 changed files with 43 additions and 62 deletions
|
|
@ -431,6 +431,7 @@ cu_library(
|
|||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":aten_cuda_cpp",
|
||||
"//c10/util:bit_cast",
|
||||
"@cuda//:cublas",
|
||||
"@cuda//:cufft",
|
||||
"@cuda//:cusparse",
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#include <ATen/NestedTensorImpl.h>
|
||||
#include <ATen/TensorAccessor.h>
|
||||
#include <c10/util/Logging.h>
|
||||
#include <c10/util/bit_cast.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/detail/KernelUtils.h>
|
||||
|
|
@ -844,8 +845,8 @@ std::tuple<Tensor, Tensor, int64_t, int64_t, Tensor> _flash_attention_forward(
|
|||
|
||||
debug_attn_mask = return_debug_mask ? debug_attn_mask : at::empty({0}, query.options());
|
||||
|
||||
int64_t signed_philox_seed = sdp::bit_cast<int64_t>(philox_seed);
|
||||
int64_t signed_philox_offset= sdp::bit_cast<int64_t>(philox_offset);
|
||||
int64_t signed_philox_seed = c10::bit_cast<int64_t>(philox_seed);
|
||||
int64_t signed_philox_offset= c10::bit_cast<int64_t>(philox_offset);
|
||||
|
||||
return std::make_tuple(output, logsumexp, signed_philox_seed, signed_philox_offset, debug_attn_mask);
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAMathCompat.h>
|
||||
#include <c10/util/bit_cast.h>
|
||||
|
||||
#include <c10/core/TensorImpl.h>
|
||||
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
|
||||
|
|
@ -106,8 +107,8 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
|
|||
// The kernel computes irregadless we will drop for this functions return
|
||||
Tensor grad_softmax;
|
||||
|
||||
uint64_t unsigned_philox_seed = sdp::bit_cast<uint64_t>(philox_seed);
|
||||
uint64_t unsigned_philox_offset = sdp::bit_cast<uint64_t>(philox_offset);
|
||||
uint64_t unsigned_philox_seed = c10::bit_cast<uint64_t>(philox_seed);
|
||||
uint64_t unsigned_philox_offset = c10::bit_cast<uint64_t>(philox_offset);
|
||||
|
||||
std::tie(dq, dk, dv, grad_softmax) = fmha::mha_bwd(
|
||||
contiguous_grad_out,
|
||||
|
|
|
|||
|
|
@ -21,14 +21,6 @@
|
|||
|
||||
namespace sdp {
|
||||
|
||||
template <typename To, typename From>
|
||||
To bit_cast(From f) {
|
||||
static_assert(sizeof(To) == sizeof(From));
|
||||
To t;
|
||||
std::memcpy(&t, &f, sizeof(f));
|
||||
return t;
|
||||
}
|
||||
|
||||
// This helper function creates a constexpr std::array
|
||||
// From a compile time list of values
|
||||
template <typename V, typename... T>
|
||||
|
|
|
|||
|
|
@ -855,7 +855,7 @@ namespace {
|
|||
// generate expected_val
|
||||
for (int64_t i = 0; i < vec::size(); i++) {
|
||||
bit_rep hex_mask = 0;
|
||||
hex_mask=bit_cast<bit_rep>(mask[i]);
|
||||
hex_mask=c10::bit_cast<bit_rep>(mask[i]);
|
||||
expected_val[i] = (hex_mask & 0x01) ? b[i] : a[i];
|
||||
}
|
||||
// test with blendv
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
#include <ATen/cpu/vec/functional.h>
|
||||
#include <c10/util/bit_cast.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <chrono>
|
||||
|
|
@ -266,25 +267,6 @@ std::ostream& operator<<(std::ostream& stream, const CheckWithinDomains<T>& dmn)
|
|||
return stream;
|
||||
}
|
||||
|
||||
template <class To, class From>
|
||||
typename std::enable_if<
|
||||
(sizeof(To) == sizeof(From)) && std::is_trivially_copyable<From>::value&&
|
||||
std::is_trivial<To>::value,
|
||||
// this implementation requires that To is trivially default constructible
|
||||
To>::type
|
||||
bit_cast(const From& src) noexcept {
|
||||
To dst;
|
||||
std::memcpy(&dst, &src, sizeof(To));
|
||||
return dst;
|
||||
}
|
||||
|
||||
template <class To, class T>
|
||||
To bit_cast_ptr(T* p, size_t N = sizeof(To)) noexcept {
|
||||
unsigned char p1[sizeof(To)] = {};
|
||||
std::memcpy(p1, p, std::min(N, sizeof(To)));
|
||||
return bit_cast<To>(p1);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::enable_if_t<std::is_floating_point<T>::value, bool> check_both_nan(T x,
|
||||
T y) {
|
||||
|
|
@ -888,8 +870,8 @@ public:
|
|||
if (bitwise)
|
||||
{
|
||||
for (const auto i : c10::irange(sizeX)) {
|
||||
BVT b_exp = bit_cast<BVT>(expArr[i]);
|
||||
BVT b_act = bit_cast<BVT>(actArr[i]);
|
||||
BVT b_exp = c10::bit_cast<BVT>(expArr[i]);
|
||||
BVT b_act = c10::bit_cast<BVT>(actArr[i]);
|
||||
EXPECT_EQ(b_exp, b_act) << getDetail(i / unitStorageCount);
|
||||
if (::testing::Test::HasFailure())
|
||||
return true;
|
||||
|
|
@ -1121,7 +1103,7 @@ T func_cmp(Op call, T v0, T v1) {
|
|||
using bit_rep = BitType<T>;
|
||||
constexpr bit_rep mask = std::numeric_limits<bit_rep>::max();
|
||||
bit_rep ret = call(v0, v1) ? mask : 0;
|
||||
return bit_cast<T>(ret);
|
||||
return c10::bit_cast<T>(ret);
|
||||
}
|
||||
|
||||
struct PreventFma
|
||||
|
|
@ -1300,8 +1282,8 @@ template<typename T>
|
|||
std::enable_if_t<!is_complex<T>::value, T>
|
||||
local_and(const T& val0, const T& val1) {
|
||||
using bit_rep = BitType<T>;
|
||||
bit_rep ret = bit_cast<bit_rep>(val0) & bit_cast<bit_rep>(val1);
|
||||
return bit_cast<T> (ret);
|
||||
bit_rep ret = c10::bit_cast<bit_rep>(val0) & c10::bit_cast<bit_rep>(val1);
|
||||
return c10::bit_cast<T> (ret);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -1313,17 +1295,17 @@ local_and(const Complex<T>& val0, const Complex<T>& val1)
|
|||
T imag1 = val0.imag();
|
||||
T real2 = val1.real();
|
||||
T imag2 = val1.imag();
|
||||
bit_rep real_ret = bit_cast<bit_rep>(real1) & bit_cast<bit_rep>(real2);
|
||||
bit_rep imag_ret = bit_cast<bit_rep>(imag1) & bit_cast<bit_rep>(imag2);
|
||||
return Complex<T>(bit_cast<T>(real_ret), bit_cast<T>(imag_ret));
|
||||
bit_rep real_ret = c10::bit_cast<bit_rep>(real1) & c10::bit_cast<bit_rep>(real2);
|
||||
bit_rep imag_ret = c10::bit_cast<bit_rep>(imag1) & c10::bit_cast<bit_rep>(imag2);
|
||||
return Complex<T>(c10::bit_cast<T>(real_ret), c10::bit_cast<T>(imag_ret));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
std::enable_if_t<!is_complex<T>::value, T>
|
||||
local_or(const T& val0, const T& val1) {
|
||||
using bit_rep = BitType<T>;
|
||||
bit_rep ret = bit_cast<bit_rep>(val0) | bit_cast<bit_rep>(val1);
|
||||
return bit_cast<T> (ret);
|
||||
bit_rep ret = c10::bit_cast<bit_rep>(val0) | c10::bit_cast<bit_rep>(val1);
|
||||
return c10::bit_cast<T> (ret);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
|
|
@ -1334,17 +1316,17 @@ local_or(const Complex<T>& val0, const Complex<T>& val1) {
|
|||
T imag1 = val0.imag();
|
||||
T real2 = val1.real();
|
||||
T imag2 = val1.imag();
|
||||
bit_rep real_ret = bit_cast<bit_rep>(real1) | bit_cast<bit_rep>(real2);
|
||||
bit_rep imag_ret = bit_cast<bit_rep>(imag1) | bit_cast<bit_rep>(imag2);
|
||||
return Complex<T>(bit_cast<T> (real_ret), bit_cast<T>(imag_ret));
|
||||
bit_rep real_ret = c10::bit_cast<bit_rep>(real1) | c10::bit_cast<bit_rep>(real2);
|
||||
bit_rep imag_ret = c10::bit_cast<bit_rep>(imag1) | c10::bit_cast<bit_rep>(imag2);
|
||||
return Complex<T>(c10::bit_cast<T> (real_ret), c10::bit_cast<T>(imag_ret));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
std::enable_if_t<!is_complex<T>::value, T>
|
||||
local_xor(const T& val0, const T& val1) {
|
||||
using bit_rep = BitType<T>;
|
||||
bit_rep ret = bit_cast<bit_rep>(val0) ^ bit_cast<bit_rep>(val1);
|
||||
return bit_cast<T> (ret);
|
||||
bit_rep ret = c10::bit_cast<bit_rep>(val0) ^ c10::bit_cast<bit_rep>(val1);
|
||||
return c10::bit_cast<T> (ret);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
|
|
@ -1355,9 +1337,9 @@ local_xor(const Complex<T>& val0, const Complex<T>& val1) {
|
|||
T imag1 = val0.imag();
|
||||
T real2 = val1.real();
|
||||
T imag2 = val1.imag();
|
||||
bit_rep real_ret = bit_cast<bit_rep>(real1) ^ bit_cast<bit_rep>(real2);
|
||||
bit_rep imag_ret = bit_cast<bit_rep>(imag1) ^ bit_cast<bit_rep>(imag2);
|
||||
return Complex<T>(bit_cast<T> (real_ret), bit_cast<T>(imag_ret));
|
||||
bit_rep real_ret = c10::bit_cast<bit_rep>(real1) ^ c10::bit_cast<bit_rep>(real2);
|
||||
bit_rep imag_ret = c10::bit_cast<bit_rep>(imag1) ^ c10::bit_cast<bit_rep>(imag2);
|
||||
return Complex<T>(c10::bit_cast<T> (real_ret), c10::bit_cast<T>(imag_ret));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/bit_cast.h>
|
||||
|
||||
#include <limits>
|
||||
|
||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
||||
|
|
@ -27,7 +29,7 @@ inline C10_HOST_DEVICE BFloat16::BFloat16(float value)
|
|||
x(__bfloat16_as_ushort(__float2bfloat16(value)))
|
||||
#elif defined(__SYCL_DEVICE_ONLY__) && \
|
||||
defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS)
|
||||
x(sycl::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
|
||||
x(c10::bit_cast<uint16_t>(sycl::ext::oneapi::bfloat16(value)))
|
||||
#else
|
||||
// RNE by default
|
||||
x(detail::round_to_nearest_even(value))
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
#pragma once
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/bit_cast.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
|
||||
|
|
@ -32,7 +34,7 @@ inline C10_HOST_DEVICE Half::Half(float value)
|
|||
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
|
||||
x(__half_as_short(__float2half(value)))
|
||||
#elif defined(__SYCL_DEVICE_ONLY__)
|
||||
x(sycl::bit_cast<uint16_t>(sycl::half(value)))
|
||||
x(c10::bit_cast<uint16_t>(sycl::half(value)))
|
||||
#else
|
||||
x(detail::fp16_ieee_from_fp32_value(value))
|
||||
#endif
|
||||
|
|
@ -45,7 +47,7 @@ inline C10_HOST_DEVICE Half::operator float() const {
|
|||
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
|
||||
return __half2float(*reinterpret_cast<const __half*>(&x));
|
||||
#elif defined(__SYCL_DEVICE_ONLY__)
|
||||
return float(sycl::bit_cast<sycl::half>(x));
|
||||
return float(c10::bit_cast<sycl::half>(x));
|
||||
#else
|
||||
return detail::fp16_ieee_to_fp32_value(x);
|
||||
#endif
|
||||
|
|
@ -102,7 +104,7 @@ inline C10_HOST_DEVICE Half operator-(const Half& a) {
|
|||
defined(__HIP_DEVICE_COMPILE__)
|
||||
return __hneg(a);
|
||||
#elif defined(__SYCL_DEVICE_ONLY__)
|
||||
return -sycl::bit_cast<sycl::half>(a);
|
||||
return -c10::bit_cast<sycl::half>(a);
|
||||
#else
|
||||
return -static_cast<float>(a);
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -13,13 +13,13 @@ namespace c10 {
|
|||
// information as well as the source of our implementations.
|
||||
template <class To, class From>
|
||||
std::enable_if_t<
|
||||
sizeof(To) == sizeof(From) && std::is_trivially_copyable_v<From> &&
|
||||
std::is_trivially_copyable_v<To>,
|
||||
sizeof(To) == sizeof(From) && std::is_trivially_copyable<From>::value &&
|
||||
std::is_trivially_copyable<To>::value,
|
||||
To>
|
||||
// constexpr support needs compiler magic
|
||||
bit_cast(const From& src) noexcept {
|
||||
static_assert(
|
||||
std::is_trivially_constructible_v<To>,
|
||||
std::is_trivially_constructible<To>::value,
|
||||
"This implementation additionally requires "
|
||||
"destination type to be trivially constructible");
|
||||
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ def define_targets(rules):
|
|||
local_defines = ["C10_BUILD_MAIN_LIB"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":bit_cast",
|
||||
"//c10/macros",
|
||||
"@fmt",
|
||||
] + rules.select({
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <c10/util/bit_cast.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <climits>
|
||||
|
|
@ -615,11 +617,8 @@ inline double BitsToDouble(uint64_t Bits) {
|
|||
|
||||
/// This function takes a 32-bit integer and returns the bit equivalent float.
|
||||
inline float BitsToFloat(uint32_t Bits) {
|
||||
// TODO: Use bit_cast once C++20 becomes available.
|
||||
float F;
|
||||
static_assert(sizeof(uint32_t) == sizeof(float), "Unexpected type sizes");
|
||||
memcpy(&F, &Bits, sizeof(Bits));
|
||||
return F;
|
||||
// TODO: Use std::bit_cast once C++20 becomes available.
|
||||
return c10::bit_cast<float>(Bits);
|
||||
}
|
||||
|
||||
/// This function takes a double and returns the bit equivalent 64-bit integer.
|
||||
|
|
|
|||
Loading…
Reference in a new issue