From 32af2ba68f32edab31c7eb2a5ccb4666c0a1cb09 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 29 Aug 2024 10:37:50 -0700 Subject: [PATCH] enhance string util functions (#21893) ### Description - make `MakeString` force inline - refactor ORT_FORCEINLINE macro - move to one place to avoid macro redefinition error - ~~add a `StringJoin` utility~~ ### Motivation and Context --- include/onnxruntime/core/common/make_string.h | 22 ++++++++++--------- .../quantization/blockwise_quant_block_bnb4.h | 20 +++++++---------- onnxruntime/core/framework/murmurhash3.cc | 14 +++++------- .../providers/cpu/tensor/gather_elements.cc | 10 +++------ onnxruntime/core/util/force_inline.h | 10 +++++++++ onnxruntime/core/util/matrix_layout.h | 6 +---- 6 files changed, 40 insertions(+), 42 deletions(-) create mode 100644 onnxruntime/core/util/force_inline.h diff --git a/include/onnxruntime/core/common/make_string.h b/include/onnxruntime/core/common/make_string.h index 826898de85..c47e6cc35e 100644 --- a/include/onnxruntime/core/common/make_string.h +++ b/include/onnxruntime/core/common/make_string.h @@ -21,27 +21,29 @@ #include #include +#include "core/util/force_inline.h" + namespace onnxruntime { namespace detail { -inline void MakeStringImpl(std::ostringstream& /*ss*/) noexcept { +ORT_FORCEINLINE void MakeStringImpl(std::ostringstream& /*ss*/) noexcept { } template -inline void MakeStringImpl(std::ostringstream& ss, const T& t) noexcept { +ORT_FORCEINLINE void MakeStringImpl(std::ostringstream& ss, const T& t) noexcept { ss << t; } template -inline void MakeStringImpl(std::ostringstream& ss, const T& t, const Args&... args) noexcept { +ORT_FORCEINLINE void MakeStringImpl(std::ostringstream& ss, const T& t, const Args&... args) noexcept { MakeStringImpl(ss, t); MakeStringImpl(ss, args...); } // see MakeString comments for explanation of why this is necessary template -inline std::string MakeStringImpl(const Args&... args) noexcept { +ORT_FORCEINLINE std::string MakeStringImpl(const Args&... args) noexcept { std::ostringstream ss; MakeStringImpl(ss, args...); return ss.str(); @@ -78,7 +80,7 @@ using if_char_array_make_ptr_t = typename if_char_array_make_ptr::type; * This version uses the current locale. */ template -std::string MakeString(const Args&... args) { +ORT_FORCEINLINE std::string MakeString(const Args&... args) { // We need to update the types from the MakeString template instantiation to decay any char[n] to char*. // e.g. MakeString("in", "out") goes from MakeString to MakeStringImpl // so that MakeString("out", "in") will also match MakeStringImpl instead of requiring @@ -98,7 +100,7 @@ std::string MakeString(const Args&... args) { * This version uses std::locale::classic(). */ template -std::string MakeStringWithClassicLocale(const Args&... args) { +ORT_FORCEINLINE std::string MakeStringWithClassicLocale(const Args&... args) { std::ostringstream ss; ss.imbue(std::locale::classic()); detail::MakeStringImpl(ss, args...); @@ -107,19 +109,19 @@ std::string MakeStringWithClassicLocale(const Args&... args) { // MakeString versions for already-a-string types. -inline std::string MakeString(const std::string& str) { +ORT_FORCEINLINE std::string MakeString(const std::string& str) { return str; } -inline std::string MakeString(const char* cstr) { +ORT_FORCEINLINE std::string MakeString(const char* cstr) { return cstr; } -inline std::string MakeStringWithClassicLocale(const std::string& str) { +ORT_FORCEINLINE std::string MakeStringWithClassicLocale(const std::string& str) { return str; } -inline std::string MakeStringWithClassicLocale(const char* cstr) { +ORT_FORCEINLINE std::string MakeStringWithClassicLocale(const char* cstr) { return cstr; } diff --git a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h index cb8e97a592..f1692a1ab7 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h +++ b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h @@ -7,21 +7,17 @@ #include #include +#include "core/util/force_inline.h" + namespace onnxruntime { namespace contrib { -#if defined(_MSC_VER) -#define FORCEINLINE __forceinline -#else -#define FORCEINLINE __attribute__((always_inline)) inline -#endif - typedef enum Bnb_DataType_t { FP4 = 0, NF4 = 1, } Bnb_DataType_t; -FORCEINLINE uint8_t QuantizeOneFP4(float x) { +ORT_FORCEINLINE uint8_t QuantizeOneFP4(float x) { // FP4 with bias of 3 // first bit is a sign // subnormals @@ -69,7 +65,7 @@ FORCEINLINE uint8_t QuantizeOneFP4(float x) { } } -FORCEINLINE uint8_t QuantizeOneNF4(float x) { +ORT_FORCEINLINE uint8_t QuantizeOneNF4(float x) { if (x > 0.03979014977812767f) { if (x > 0.3893125355243683f) { // 1 if (x > 0.6427869200706482f) { // 11 @@ -120,7 +116,7 @@ FORCEINLINE uint8_t QuantizeOneNF4(float x) { } template -FORCEINLINE uint8_t QuantizeOneBnb4(float x) { +ORT_FORCEINLINE uint8_t QuantizeOneBnb4(float x) { if constexpr (DATA_TYPE == FP4) return QuantizeOneFP4(x); else @@ -128,7 +124,7 @@ FORCEINLINE uint8_t QuantizeOneBnb4(float x) { } template -FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) { +ORT_FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) { float local_absmax = 0.0f; int32_t block_len = std::min(block_size, numel - block_idx * block_size); @@ -177,7 +173,7 @@ static float nf4_qaunt_map[16] = {-1.0f, 1.0f}; template -FORCEINLINE T DequantizeOneBnb4(uint8_t x) { +ORT_FORCEINLINE T DequantizeOneBnb4(uint8_t x) { if constexpr (DATA_TYPE == FP4) return static_cast(fp4_qaunt_map[x]); else @@ -185,7 +181,7 @@ FORCEINLINE T DequantizeOneBnb4(uint8_t x) { } template -FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) { +ORT_FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) { int32_t block_len = std::min(block_size, numel - block_idx * block_size); int32_t src_offset = block_idx * block_size / 2; int32_t dst_offset = block_idx * block_size; diff --git a/onnxruntime/core/framework/murmurhash3.cc b/onnxruntime/core/framework/murmurhash3.cc index e2dbba9b07..802f0a4c58 100644 --- a/onnxruntime/core/framework/murmurhash3.cc +++ b/onnxruntime/core/framework/murmurhash3.cc @@ -17,6 +17,8 @@ #include "core/framework/endian.h" +#include "core/util/force_inline.h" + //----------------------------------------------------------------------------- // Platform-specific functions and macros @@ -24,8 +26,6 @@ #if defined(_MSC_VER) -#define FORCE_INLINE __forceinline - #include #define ROTL32(x, y) _rotl(x, y) @@ -37,8 +37,6 @@ #else // defined(_MSC_VER) -#define FORCE_INLINE inline __attribute__((always_inline)) - inline uint32_t rotl32(uint32_t x, int8_t r) { return (x << r) | (x >> (32 - r)); } @@ -61,7 +59,7 @@ inline uint64_t rotl64(uint64_t x, int8_t r) { // // Changes to support big-endian from https://github.com/explosion/murmurhash/pull/27/ // were manually applied to original murmurhash3 source code. -FORCE_INLINE uint32_t getblock32(const uint32_t* p, int i) { +ORT_FORCEINLINE uint32_t getblock32(const uint32_t* p, int i) { if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) { return p[i]; } else { @@ -73,7 +71,7 @@ FORCE_INLINE uint32_t getblock32(const uint32_t* p, int i) { } } -FORCE_INLINE uint64_t getblock64(const uint64_t* p, int i) { +ORT_FORCEINLINE uint64_t getblock64(const uint64_t* p, int i) { if constexpr (onnxruntime::endian::native == onnxruntime::endian::little) { return p[i]; } else { @@ -92,7 +90,7 @@ FORCE_INLINE uint64_t getblock64(const uint64_t* p, int i) { //----------------------------------------------------------------------------- // Finalization mix - force all bits of a hash block to avalanche -FORCE_INLINE constexpr uint32_t fmix32(uint32_t h) { +ORT_FORCEINLINE constexpr uint32_t fmix32(uint32_t h) { h ^= h >> 16; h *= 0x85ebca6b; h ^= h >> 13; @@ -104,7 +102,7 @@ FORCE_INLINE constexpr uint32_t fmix32(uint32_t h) { //---------- -FORCE_INLINE constexpr uint64_t fmix64(uint64_t k) { +ORT_FORCEINLINE constexpr uint64_t fmix64(uint64_t k) { k ^= k >> 33; k *= BIG_CONSTANT(0xff51afd7ed558ccd); k ^= k >> 33; diff --git a/onnxruntime/core/providers/cpu/tensor/gather_elements.cc b/onnxruntime/core/providers/cpu/tensor/gather_elements.cc index a2239844eb..495ff19d86 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather_elements.cc +++ b/onnxruntime/core/providers/cpu/tensor/gather_elements.cc @@ -5,6 +5,8 @@ #include "gather_elements.h" #include "onnxruntime_config.h" +#include "core/util/force_inline.h" + namespace onnxruntime { ONNX_CPU_OPERATOR_VERSIONED_KERNEL( @@ -66,14 +68,8 @@ static inline size_t CalculateOffset(size_t inner_dim, const TensorPitches& inpu return base_offset; } -#if defined(_MSC_VER) -#define FORCEINLINE __forceinline -#else -#define FORCEINLINE __attribute__((always_inline)) inline -#endif - template -FORCEINLINE int64_t GetIndex(size_t i, const T* indices, int64_t axis_size) { +ORT_FORCEINLINE int64_t GetIndex(size_t i, const T* indices, int64_t axis_size) { int64_t index = indices[i]; if (index < 0) // Handle negative indices index += axis_size; diff --git a/onnxruntime/core/util/force_inline.h b/onnxruntime/core/util/force_inline.h new file mode 100644 index 0000000000..cd15107004 --- /dev/null +++ b/onnxruntime/core/util/force_inline.h @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if defined(_MSC_VER) +#define ORT_FORCEINLINE __forceinline +#else +#define ORT_FORCEINLINE __attribute__((always_inline)) inline +#endif diff --git a/onnxruntime/core/util/matrix_layout.h b/onnxruntime/core/util/matrix_layout.h index dbf961ab5b..105da1f5e2 100644 --- a/onnxruntime/core/util/matrix_layout.h +++ b/onnxruntime/core/util/matrix_layout.h @@ -17,11 +17,7 @@ #include #include -#if defined(_MSC_VER) -#define ORT_FORCEINLINE __forceinline -#else -#define ORT_FORCEINLINE __attribute__((always_inline)) inline -#endif +#include "core/util/force_inline.h" namespace onnxruntime {