diff --git a/BUILD.bazel b/BUILD.bazel index c71781606b7..1018f7907ad 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -332,6 +332,7 @@ intern_build_aten_ops( "@fbgemm", "@mkl", "@sleef", + "@mkl_dnn//:mkl-dnn", ], ) diff --git a/aten/src/ATen/native/CPUBlas.cpp b/aten/src/ATen/native/CPUBlas.cpp index 2e0dadfbabc..689d3bacd5a 100644 --- a/aten/src/ATen/native/CPUBlas.cpp +++ b/aten/src/ATen/native/CPUBlas.cpp @@ -41,6 +41,17 @@ extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int * #include #endif // USE_FBGEMM +#if AT_MKLDNN_ENABLED() +#include +#endif // oneDNN + +#define ONEDNN_UKERNEL_ENABLED (DNNL_VERSION_MAJOR > 3 || (DNNL_VERSION_MAJOR == 3 && DNNL_VERSION_MINOR >= 5)) + +#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))) +#include +#include +#endif // oneDNN BRGEMM + namespace at::native::cpublas { namespace internal { @@ -822,4 +833,366 @@ void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex +struct UnsafeUkernelKeyHasher { + std::size_t operator()(const key_t& key) const; +}; + +template<> +std::size_t UnsafeUkernelKeyHasher::operator()(const BrgemmKey& key) const { + // Use beta, M, N, and K to compute hash to reduce the overhead as + // batch size, alpha, and data types are unlikely to change within the same kernel and + // leading dimensions are likely to be related to M, K, N or use fixed values.
+ std::size_t h = std::hash()(key.beta + 1); + h = std::hash()(key.M) ^ (h << 1); + h = std::hash()(key.N) ^ (h << 1); + h = std::hash()(key.K) ^ (h << 1); + h = std::hash()(key.ldc) ^ (h << 1); + return h; +} + +template<> +std::size_t UnsafeUkernelKeyHasher::operator()(const PackKey& key) const { + // Use K and N to compute hash to reduce the overhead as + // data types are unlikely to change and + // ld_in/ld_out is likely to be related to K, N or use fixed values + std::size_t h = std::hash()(key.K); + h = std::hash()(key.N) ^ (h << 1); + return h; +} + +template +struct KernelCache { + using kstore_t = std::unordered_map, UnsafeUkernelKeyHasher>; + static inline std::shared_ptr&& fetch_or_create( + const key_t& key, + const std::function()>& callback) { + auto&& search = get_store().find(key); + if (search != get_store().end()) { + return std::move(search->second); + } else { + get_store().insert({key, callback()}); + return std::move(get_store()[key]); + } + } + + static inline kstore_t& get_store() { + static thread_local kstore_t cache_kernels; + return cache_kernels; + } +}; + +// Helper struct for convenient brgemm configuration +struct GemmHelper { + GemmHelper( + int64_t M, + int64_t N, + int64_t K, + int64_t bs, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + ScalarType dt_a, + ScalarType dt_b, + ScalarType dt_c, + const float alpha, + const float beta) { + // Create brgemm + brg = dnnl::ukernel::brgemm( + M, + N, + K, + bs, + ld_a, + ld_b, + ld_c, + get_dnnl_dtype(dt_a), + get_dnnl_dtype(dt_b), + get_dnnl_dtype(dt_c), + alpha, + beta); + // Create a scratchpad buffer for the brgemm execution + scratchpad = std::vector(brg.get_scratchpad_size()); + // Prepare default vector of pairs of tensors A and B offsets for each batch. 
+ A_B_offsets.resize(1); + A_B_offsets[0] = std::make_pair(0, 0); + } + dnnl::ukernel::brgemm brg; + std::vector scratchpad; + std::vector> A_B_offsets; +}; + +struct Brgemm : public KernelCache { + // Fetch/create GemmHelper object and execute brgemm with batch size = 1 + template + static inline void call( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const float alpha, + const float beta, + const scalar_t_a* A, + const scalar_t_b* B, + scalar_t_c* C) { + auto&& key = BrgemmKey( + M, + N, + K, + int64_t(1), + ld_a, + ld_b, + ld_c, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + alpha, + beta); + // Fetch/create GemmHelper object + auto&& value = fetch_or_create(key, [&]() { + auto&& v = std::make_shared( + M, + N, + K, + 1, + ld_a, + ld_b, + ld_c, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + c10::CppTypeToScalarType::value, + alpha, + beta); + (*v).brg.generate(); + return std::move(v); + }); + if (get_current() != value) { + dnnl::ukernel::brgemm::release_hw_context(); + ((*value).brg).set_hw_context(); + get_current() = value; + } + ((*value).brg) + .execute(A, B, (*value).A_B_offsets, C, (*value).scratchpad.data()); + } + + static inline std::shared_ptr& get_current() { + static thread_local std::shared_ptr current; + return current; + } + + static inline bool device_check(ScalarType dtype) { + if (!at::globalContext().userEnabledMkldnn()) { + return false; + } + if (dtype == ScalarType::Half) { + static bool fp16_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_fp16; + return fp16_support; + } + return false; + } +}; + +using pack_t = dnnl::ukernel::brgemm_pack_B; +struct Pack : public KernelCache { + static inline void call( + int64_t K, + int64_t N, + int64_t ld_in, + int64_t ld_out, + ScalarType dt_in, + ScalarType dt_out, + const void* in, + void* out) { + auto&& key = PackKey(K, N, ld_in, ld_out, dt_in,
dt_out); + auto&& pack = fetch_or_create(key, [&]() { + auto&& p = std::make_shared( + K, N, ld_in, ld_out, get_dnnl_dtype(dt_in), get_dnnl_dtype(dt_out)); + if (need_pack(dt_in)) { + (*p).generate(); + } + return std::move(p); + }); + if (need_pack(dt_in)) { + (*pack).execute(in, out); + } else { + TORCH_CHECK(false, "No need to pack"); + } + } + + static inline bool need_pack(ScalarType dtype) { + if (!at::globalContext().userEnabledMkldnn()) { + return false; + } + if (dtype == ScalarType::Half) { + static bool fp16_pack = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_amx_fp16; + return fp16_pack; + } + return false; + } +}; +#endif + +void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const float alpha, + const float beta, + const at::Half* A, + const at::Half* B, + float* C) { +#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))) + if (Brgemm::device_check(ScalarType::Half)) { + Brgemm::call( + M, N, K, ld_a, ld_b, ld_c, alpha, beta, A, B, C); + return; + } +#endif + TORCH_CHECK(false, + "Half Brgemm is only supported on X64 when oneDNN ukernel is enabled and avx512_fp16 is supported"); +} + +void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const float alpha, + const float beta, + const at::BFloat16* A, + const at::BFloat16* B, + float* C) { + TORCH_CHECK(false, + "BFloat16 Brgemm is currently not supported"); +} + +void brgemm_release() { +#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))) + dnnl::ukernel::brgemm::release_hw_context(); +#endif +} + +void pack( + int64_t K, + int64_t N, + int64_t ld_in, + int64_t ld_out, + ScalarType dt_in, + ScalarType dt_out, + const void* in, + void* out) { +#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))) + Pack::call(K, N, ld_in, ld_out, dt_in, dt_out, in, out); +#else + 
TORCH_CHECK(false, "pack is only supported on X64 with oneDNN ukernel enabled"); +#endif +} + +bool need_pack(ScalarType dt_in) { +#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC))) + return Pack::need_pack(dt_in); +#else + return false; +#endif +} + +} // namespace at::native::cpublas diff --git a/aten/src/ATen/native/CPUBlas.h b/aten/src/ATen/native/CPUBlas.h index 3b30df1c21f..01ec49e4b54 100644 --- a/aten/src/ATen/native/CPUBlas.h +++ b/aten/src/ATen/native/CPUBlas.h @@ -7,6 +7,7 @@ #include #include + namespace at::native::cpublas { namespace internal { @@ -186,4 +187,58 @@ void copy(int64_t n, const float *x, int64_t incx, float *y, int64_t incy); void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); void copy(int64_t n, const c10::complex *x, int64_t incx, c10::complex *y, int64_t incy); -} // namespace at::native::cpublas +// Batch-reduce GEMM +// Operates by the following formula: +// C = alpha * SUM(A[i] x B[i]) + beta * C, i = 0 to batch size +// A Base pointer to a tensor A. +// B Base pointer to a tensor B. +// Byte offsets vector of pairs of tensors A and B offsets for +// each batch. The number of batches must coincide with the +// `batch_size` value passed at object construction stage. +// C Pointer to a tensor C (accumulation buffer). +// scratchpad Pointer to a scratchpad buffer. 
+// Currently, only brgemm with batch size = 1 will be used +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const float alpha, + const float beta, + const at::Half* A, + const at::Half* B, + float* C); + +TORCH_API void brgemm( + int64_t M, + int64_t N, + int64_t K, + int64_t ld_a, + int64_t ld_b, + int64_t ld_c, + const float alpha, + const float beta, + const at::BFloat16* A, + const at::BFloat16* B, + float* C); + +// Release brgemm hardware context +void brgemm_release(); + +// Pack B matrix to get better performance if needed +void pack( + int64_t K, + int64_t N, + int64_t ld_in, + int64_t ld_out, + ScalarType dt_in, + ScalarType dt_out, + const void* in, + void* out); + +// Whether pack is needed in the platform. +bool need_pack(ScalarType dt_in); + +} // namespace at::native::cpublas diff --git a/cmake/Modules/FindMKLDNN.cmake b/cmake/Modules/FindMKLDNN.cmake index a8d42b6ee8d..234d361d7f5 100644 --- a/cmake/Modules/FindMKLDNN.cmake +++ b/cmake/Modules/FindMKLDNN.cmake @@ -87,13 +87,18 @@ IF(NOT MKLDNN_FOUND) SET(ONEDNN_BUILD_GRAPH ON CACHE BOOL "" FORCE) ENDIF(NOT APPLE AND NOT WIN32 AND NOT BUILD_LITE_INTERPRETER) + IF(EXISTS "${MKLDNN_ROOT}/include/oneapi/dnnl/dnnl_ukernel.hpp") + MESSAGE("-- Will build oneDNN UKERNEL") + SET(DNNL_EXPERIMENTAL_UKERNEL ON CACHE BOOL "" FORCE) + ENDIF(EXISTS "${MKLDNN_ROOT}/include/oneapi/dnnl/dnnl_ukernel.hpp") + FIND_PACKAGE(BLAS) FIND_PATH(IDEEP_INCLUDE_DIR ideep.hpp PATHS ${IDEEP_ROOT} PATH_SUFFIXES include) - FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include/oneapi/dnnl) + FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h dnnl_ukernel.hpp dnnl_ukernel.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include/oneapi/dnnl) IF(NOT MKLDNN_INCLUDE_DIR) MESSAGE("MKLDNN_INCLUDE_DIR not found") EXECUTE_PROCESS(COMMAND git${CMAKE_EXECUTABLE_SUFFIX} submodule update --init mkl-dnn WORKING_DIRECTORY ${IDEEP_ROOT}) - FIND_PATH(MKLDNN_INCLUDE_DIR 
dnnl.hpp dnnl.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include) + FIND_PATH(MKLDNN_INCLUDE_DIR dnnl.hpp dnnl.h dnnl_ukernel.hpp dnnl_ukernel.h PATHS ${MKLDNN_ROOT} PATH_SUFFIXES include) ENDIF(NOT MKLDNN_INCLUDE_DIR) IF(BUILD_ONEDNN_GRAPH) FIND_PATH(LLGA_INCLUDE_DIR dnnl_graph.hpp PATHS ${LLGA_ROOT} PATH_SUFFIXES include/oneapi/dnnl) diff --git a/third_party/mkl-dnn.BUILD b/third_party/mkl-dnn.BUILD index 64154894d65..aa55943bc91 100644 --- a/third_party/mkl-dnn.BUILD +++ b/third_party/mkl-dnn.BUILD @@ -11,9 +11,9 @@ _DNNL_RUNTIME_OMP = { "#cmakedefine DNNL_SYCL_CUDA": "/* #undef DNNL_SYCL_CUDA */", "#cmakedefine DNNL_SYCL_HIP": "/* #undef DNNL_SYCL_HIP */", "#cmakedefine DNNL_ENABLE_STACK_CHECKER": "#undef DNNL_ENABLE_STACK_CHECKER", + "#cmakedefine DNNL_EXPERIMENTAL_UKERNEL": "/* #undef DNNL_EXPERIMENTAL_UKERNEL */", "#cmakedefine DNNL_EXPERIMENTAL": "#undef DNNL_EXPERIMENTAL", "#cmakedefine DNNL_EXPERIMENTAL_SPARSE": "#undef DNNL_EXPERIMENTAL_SPARSE", - "#cmakedefine DNNL_EXPERIMENTAL_UKERNEL": "#undef DNNL_EXPERIMENTAL_UKERNEL", "#cmakedefine ONEDNN_BUILD_GRAPH": "#undef ONEDNN_BUILD_GRAPH", "#cmakedefine DNNL_EXPERIMENTAL_PROFILING": "#undef DNNL_EXPERIMENTAL_PROFILING", "#cmakedefine01 BUILD_TRAINING": "#define BUILD_TRAINING 1", @@ -138,6 +138,7 @@ cc_library( "DNNL_ENABLE_CONCURRENT_EXEC", "DNNL_ENABLE_PRIMITIVE_CACHE", "DNNL_ENABLE_CPU_ISA_HINTS", + "DNNL_EXPERIMENTAL_UKERNEL", "ONEDNN_BUILD_GRAPH", ], )