diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index f2a16fb29d..cf298aee9f 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -172,10 +172,8 @@ target_link_libraries(${target} PRIVATE cuda) endif() - if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) - include(cutlass) - target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) - endif() + include(cutlass) + target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index a52e941b23..df62199dc2 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -783,7 +783,7 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $) config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut) onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock) - target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock) + target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut) endif() diff --git a/onnxruntime/core/mickey/README.md b/onnxruntime/core/mickey/README.md new file mode 100644 index 0000000000..7e8d30cd18 --- /dev/null +++ b/onnxruntime/core/mickey/README.md @@ -0,0 +1,6 @@ +# About Mickey + 
+Playful name for a template library of high-performance CUDA code that +is often shared by various AI operators. The intention is to keep this +library header-only, with no binary impact unless it is instantiated +where it is needed. diff --git a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h b/onnxruntime/core/mickey/blk_q4/prepack_sm80.h new file mode 100644 index 0000000000..e291ab39e8 --- /dev/null +++ b/onnxruntime/core/mickey/blk_q4/prepack_sm80.h @@ -0,0 +1,325 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * prepack_sm80.h + * + * Abstract: + * Prepack weights and quantization parameters (scales and offsets) for + * GEMM, where activations are fp16 or bf16, and weights are block-wise + * 4b quantized values, specifically for Ampere GPUs. + * + * Prepacking enables faster loading of weights and quantization parameters + * into tensor cores, and faster dequantization of weights. + * + * Only supports fp16 for now, bfloat16 support will be added later. 
+ */ + +#pragma once + +#include "core/common/common.h" +#include "core/util/matrix_layout.h" + +namespace onnxruntime { +namespace cuda { + +/** + * @brief Blockwise quantization methods + * @tparam ElementT source data type, fp16 + * @tparam block_size number of elemenets quantized together + * @tparam qbits number of bits in each quantized element + * @tparam Columnwise true: elements in a block come from one single column + * false: elements in a block come from one single row + */ +template < + typename ElementT, + int block_size, + int qbits, + bool Columnwise, + bool ExtraBoundsCheck = false> +struct BlockwiseQuantization { + static_assert(qbits == 4, "Only 4b block quantization is supported!"); + static_assert(sizeof(ElementT) == 2, "Only 16b floating point types are supported!"); + + using QuantBlocking = + std::conditional_t, + MatrixShape<1, block_size>>; + + using ElementW = uint8_t; // <- Weight is int4, uint8 for two of them + // We pack 4 weights into one 16b element, so we can leverage cutlass tile iterators + // for async share memory loading, and minimizing bank conflict during matrix loading + using ElementWPack = ElementT; + using LayoutWPack = ColumnMajorLayout; // <- layout of packed weight, must be column major + + // Current Ampere kernel use 8b zero point, need to shrink it to 4b in the future + using ElementQOffset = uint8_t; + + // Layout of the quantization parameters (scales and zero points) + // Major on the dimension that has the most parameters per squarish weight block. + // E.g. for column-wise quantization, a [64, 64] block has [2, 64] parameters, + // where each row has more data, so we use row major layout so that warp threads + // can use less load instructions to load more parameters. + using LayoutQmeta = + typename std::conditional::type; + + /** + * @brief Get quantized weight tensor dimensions. + * Actual weight type is int4, we use ElementW = uint8 to avoid possible compilation + * troubles. 
Since the layout is column major, we are packing 2 weights in a column + * into one int8 + */ + static inline auto get_quant_weights_shape(int rows, int columns) { + return make_Position(rows / 2, columns); + } + + static inline auto get_quant_meta_shape(int rows, int columns) { + return make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn); + } + + /** + * @brief Prepack weight matrix to facilitate matrix loading, depending on MMA + * instruction layout. + * + * The weight matrix is int4, yet we want to leverage existing fp16/bf16 + * tile loading and MMA layout code in CUTLASS. So we group 4 int4 into 2 + * bytes, pretending it's fp16. This grouping must be done in a way to be + * easily unpacked into tiles that match the MMA instruction layout. + * For MMA instruction <16, 8, 16>, each instruction processes 2 8x8 tiles, + * vertically stacked on the K dimension. And MmaTensorOpMultiplicandTileIterator + * loads a tile. + * + * So we stack 2x2 tiles on a 3rd dimeansion, and reshape them in a HWC fashion: + * T0, T2 + * T1, T3 + * ==> + * T0[0, 0], T1[0, 0], T2[0, 0], T3[0, 0] + * T0[1, 0], T1[1, 0], T2[1, 0], T3[1, 0] + * T0[2, 0], T1[2, 0], T2[2, 0], T3[2, 0] + * T0[3, 0], T1[3, 0], T2[3, 0], T3[3, 0] + * ... + * T0[0, 7], T1[0, 7], T2[0, 7], T3[0, 7] + * T0[1, 7], T1[1, 7], T2[1, 7], T3[1, 7] + * T0[2, 7], T1[2, 7], T2[2, 7], T3[2, 7] + * T0[3, 7], T1[3, 7], T2[3, 7], T3[3, 7] + * + * This pack a 8x16 int8 tile into a 16x8 int8 tile, i.e. 
a 8x8 16b tile + */ + static void prepack_weights( + int rows, + int columns, + const gsl::span& weights, // <- int4 weights, column major + const gsl::span& weights_prepacked // <- int4 prepacked weights tensor, same size buffer + ) { + ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0 && + (rows % QuantBlocking::kRow) == 0 && + (columns % QuantBlocking::kColumn) == 0, + "Does not support odd number of rows or columns!"); + ORT_ENFORCE(weights.size() == size_t(rows * columns / 2), + "Weight tensor shape mismatch!"); + ORT_ENFORCE(weights_prepacked.size() == weights.size(), + "Prepacked Weight tensor buffer should be the same size!"); + + const MatrixRef + tensor_weight(weights, make_Position(rows / 2, columns)); + const MatrixRef + tensor_weight_prepacked(weights_prepacked, make_Position(rows, columns / 2)); + + // TODO(fuchen)!! parallized this. + auto t0_base = make_Position(0, 0); + auto t1_base = make_Position(4, 0); + auto t2_base = make_Position(0, 8); + auto t3_base = make_Position(4, 8); + for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) { + for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) { + // Packing from a 8x16 tile to a 16x8 tile + auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16); + auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8); + for (int col = 0; col < 8; ++col) { + for (int row = 0; row < 4; ++row) { + auto cord = make_Position(row, col); + auto packed_cord = packed_tile_base + make_Position(row * 4, col); // packed tile is 16x8 + uint8_t buf[4]; + buf[0] = tensor_weight.at(dtile_base + t0_base + cord); + buf[1] = tensor_weight.at(dtile_base + t1_base + cord); + buf[2] = tensor_weight.at(dtile_base + t2_base + cord); + buf[3] = tensor_weight.at(dtile_base + t3_base + cord); + + // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that each pair of adjacent weights + // are in different b16 register at the same positions. 
This makes it easier to convert to + // fp16x2 format in a b32 register + + tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0); + tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0); + } + } + } + } + } + + /** + * @brief We rearrange the values of the quantization scale and offset tensors + * to facilitate faster loading to tensor core, only 16b gemm, and (1,n) + * block quantization. + */ + static constexpr bool ShouldRearrangeMeta = sizeof(ElementT) == 2 && QuantBlocking::kRow == 1; + + static void prepack_quant_scales( + size_t rows, + size_t columns, + const gsl::span& scales, // <- quant scales, column major layout + const gsl::span& scales_prepacked // <- quant scales prepacked, same size buffer + ) { + auto meta_shape = get_quant_meta_shape(rows, columns); + ORT_ENFORCE(scales.size() == size_t(meta_shape.product()), + "Quantization scale tensor shape mismatch!"); + ORT_ENFORCE(scales_prepacked.size() == size_t(meta_shape.product()), + "Prepacked quantization scale tensor buffer should be the same size!"); + + MatrixRef tensor_scale(scales, meta_shape); + MatrixRef tensor_scale_prepacked(scales_prepacked, meta_shape); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (sizeof(ElementT) == 2 && QuantBlocking::kRow == 1) { + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. 
Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two separate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + + for (int col = 0; col < tensor_scale.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); + tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); + tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); + tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); + } + } + } + } else { + // In all other cases, we don't prepack scale or offset + // Potential transpose if the prepacked layout is different from the original layout + for (int col = 0; col < tensor_scale.shape()[1]; ++col) { + for (int row = 0; row < tensor_scale.shape()[0]; ++row) { + tensor_scale_prepacked.at(row, col) = tensor_scale.at(row, col); + } + } + } + } + + static void prepack_quant_offsets( + size_t rows, + size_t columns, + const gsl::span& offsets, // <- quant offsets, int4, column major layout + const gsl::span& 
offsets_prepacked // <- quant offsets prepacked, double size buffer + ) { + auto meta_shape = get_quant_meta_shape(rows, columns); + + ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0, + "Does not support odd number of rows or columns!"); + ORT_ENFORCE(offsets_prepacked.size() == size_t(meta_shape.product()), + "Wrong buffer size for prepacked quantization offsets!"); + ORT_ENFORCE(offsets.size() == size_t(((meta_shape[0] + 1) / 2) * meta_shape[1]), + "Quantization offset tensor shape mismatch!"); + + MatrixRef + tensor_offset(offsets, make_Position((meta_shape[0] + 1) / 2, meta_shape[1])); + MatrixRef tensor_offset_prepacked(offsets_prepacked, meta_shape); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (sizeof(ElementT) == 2 && QuantBlocking::kRow == 1) { + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two separate loads for a mma instruction, due to the tile fragment layout shown + // above. 
To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row_blk = 0; row_blk < meta_shape[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own + // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to + // convert to fp16x2 format in a b32 register + uint8_t pair01 = tensor_offset.at(src_idx / 2, col); + uint8_t pair89 = tensor_offset.at((src_idx + 8) / 2, col); + tensor_offset_prepacked.at(dst_idx + 0, col) = pair01 & 0xf; + tensor_offset_prepacked.at(dst_idx + 1, col) = pair89 & 0xf; + tensor_offset_prepacked.at(dst_idx + 2, col) = pair01 >> 4; + tensor_offset_prepacked.at(dst_idx + 3, col) = pair89 >> 4; + } + } + } + } else { + // In all other cases, we don't prepack scale or offset + // Potential transpose if the prepacked layout is different from the original layout + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row = 0; row < meta_shape[0]; row += 2) { + uint8_t pair01 = tensor_offset.at(row / 2, col); + tensor_offset_prepacked.at(row + 0, col) = pair01 & 0xf; + if (row + 1 < meta_shape[0]) { + tensor_offset_prepacked.at(row + 1, col) = pair01 >> 4; + } + } + } + } + } +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index 48d975a7fd..b5784ecb56 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -779,6 +779,17 @@ MlasBlockwiseQuantMetaShape( int& meta_cols ); +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& 
meta_rows, + int& meta_cols + ); + template void MlasBlockwiseQuantizedShape( @@ -790,6 +801,16 @@ MlasBlockwiseQuantizedShape( int& q_cols ); +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); void MLASCALL MlasBlockwiseQuantizedBufferSizes( diff --git a/onnxruntime/core/util/matrix_layout.h b/onnxruntime/core/util/matrix_layout.h new file mode 100644 index 0000000000..a0405e3203 --- /dev/null +++ b/onnxruntime/core/util/matrix_layout.h @@ -0,0 +1,475 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * matrix_layout.h + * + * Abstract: + * Utils for simplifying positioning and striding in tensors. Inspired + * by CUTLASS, striving for 0 runtime cost while promote safety. + * + * Only supports 2D tensors (matrix) for now. + */ + +#pragma once + +#include +#include "core/common/gsl.h" + +// TODO!! Already have this in cuda, what about cpu code though? +#if defined(_MSC_VER) +#define ORT_FORCEINLINE __forceinline +#else +#define ORT_FORCEINLINE __attribute__((always_inline)) inline +#endif + +namespace onnxruntime { + +// +// Clang-format doesn't handle force inline decorator well, it insists on +// adding extra indentation to the next line, making it very confusing +// to read. So we turn it off for this file. 
+// clang-format off +// + +/** + * @brief A tuple of integers to represent tensor coordinates + */ +template < + int Rank_, ///< Logical rank of coordinate + typename Index_ = int, ///< Index type used for each dimension + typename LongIndex_ = int64_t ///< Long index type used for linear offsets + > +struct Position { + public: + /// Number of elements in Position + static int const kRank = Rank_; + + /// Index type used to store elements + using Index = Index_; + + /// Type used to represent linear offsets + using LongIndex = LongIndex_; + + private: + Index idx[kRank]; + + public: + ORT_FORCEINLINE explicit Position(Index value = Index(0)) { + for (int i = 0; i < kRank; ++i) { + idx[i] = value; + } + } + + /// Constructs from an array of integers + ORT_FORCEINLINE + Position(Index const (&_idx)[kRank]) { + for (int i = 0; i < kRank; ++i) { + idx[i] = _idx[i]; + } + } + + template + ORT_FORCEINLINE + Position(Position other) { + for (int i = 0; i < kRank; ++i) { + idx[i] = other[i]; + } + } + + ORT_FORCEINLINE + Position operator+(Position const& b) const { + Position c; + for (int i = 0; i < kRank; ++i) { + c.idx[i] = idx[i] + b.idx[i]; + } + return c; + } + + ORT_FORCEINLINE + Position operator-(Position const& b) const { + Position c; + for (int i = 0; i < kRank; ++i) { + c.idx[i] = idx[i] - b.idx[i]; + } + return c; + } + + ORT_FORCEINLINE + Position operator*(Position const& b) const { + Position c; + for (int i = 0; i < kRank; ++i) { + c.idx[i] = idx[i] * b.idx[i]; + } + return c; + } + + ORT_FORCEINLINE + Position operator/(Position const& b) const { + Position c; + for (int i = 0; i < kRank; ++i) { + c.idx[i] = idx[i] / b.idx[i]; + } + return c; + } + + ORT_FORCEINLINE + Position& operator+=(Position const& b) { + for (int i = 0; i < kRank; ++i) { + idx[i] += b.idx[i]; + } + return *this; + } + + ORT_FORCEINLINE + Position& operator-=(Position const& b) { + for (int i = 0; i < kRank; ++i) { + idx[i] -= b.idx[i]; + } + return *this; + } + + 
ORT_FORCEINLINE + Position& operator*=(Position const& b) { + for (int i = 0; i < kRank; ++i) { + idx[i] *= b.idx[i]; + } + return *this; + } + + ORT_FORCEINLINE + Position& operator/=(Position const& b) { + for (int i = 0; i < kRank; ++i) { + idx[i] /= b.idx[i]; + } + return *this; + } + + ORT_FORCEINLINE Index& operator[](int dim) { return idx[dim]; } + + ORT_FORCEINLINE Index const& operator[](int dim) const { return idx[dim]; } + + ORT_FORCEINLINE bool operator==(Position const& b) const { + bool equal = true; + for (int i = 0; equal && i < kRank; ++i) { + equal = (idx[i] == b.idx[i]); + } + return equal; + } + + ORT_FORCEINLINE bool operator!=(Position const& b) const { return !(*this == b); } + + ORT_FORCEINLINE + Position& clamp(Position const& max, Position const& min = Position()) { + for (int i = 0; i < kRank; ++i) { + idx[i] = std::max(std::min(idx[i], max.idx[i]), min.idx[i]); + } + return *this; + } + + ORT_FORCEINLINE + Index sum() const { + Index sum_(idx[0]); + for (int i = 1; i < kRank; ++i) { + sum_ += idx[i]; + } + return sum_; + } + + ORT_FORCEINLINE + LongIndex product() const { + LongIndex product_(idx[0]); + for (int i = 1; i < kRank; ++i) { + product_ *= idx[i]; + } + return product_; + } +}; + +template +Position<2, T, L> make_Position(T _0, T _1) { + T values[2] = {_0, _1}; + return Position<2, T, L>(values); +} + +template +Position<3, T, L> make_Position(T _0, T _1, T _2) { + T values[3] = {_0, _1, _2}; + return Position<3, T, L>(values); +} + +/// Describes the size of a matrix tile +template < + int Row_, ///< rows of a matrix + int Column_ ///< columns of a matrix + > +struct MatrixShape { + static int const kRow = Row_; ///< rows of a matrix + static int const kColumn = Column_; ///< columns of a matrix + static int const kCount = Row_ * Column_; ///< total number of elements in a matrix + + ORT_FORCEINLINE static Position<2> toCoord() { + return make_Position(kRow, kColumn); + } +}; + +/** + * @brief Defines a mapping from logical 
coordinate to linear memory + * offsets in a row major layout matrix + */ +class RowMajorLayout { + public: + /// Index type used for coordinates + using Index = int; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using MatCoord = Position<2, Index, LongIndex>; + + private: + Index stride_; + + public: + ORT_FORCEINLINE + RowMajorLayout(Index ldm = 0) : stride_(ldm) {} + + ORT_FORCEINLINE static RowMajorLayout packed(MatCoord const& extent) { + return RowMajorLayout(extent[1]); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ORT_FORCEINLINE + LongIndex operator()(MatCoord const& coord) const { + return LongIndex(coord[0]) * stride_ + coord[1]; + } + + /// Inverse of layout function, mapping linear offset to logical coordinate + ORT_FORCEINLINE + MatCoord inverse(LongIndex offset) const { + return make_Position(Index(offset / stride_), Index(offset % stride_)); + } + + ORT_FORCEINLINE + Index stride() const { + return stride_; + } +}; + +class ColumnMajorLayout { + public: + /// Index type used for coordinates + using Index = int; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using MatCoord = Position<2, Index, LongIndex>; + + private: + Index stride_; + + public: + ORT_FORCEINLINE + ColumnMajorLayout(Index ldm = 0) : stride_(ldm) {} + + ORT_FORCEINLINE static ColumnMajorLayout packed(MatCoord const& extent) { + return ColumnMajorLayout(extent[0]); + } + + /// Returns the offset of a coordinate in linear memory. 
+ /// Assumes coordinate has convention (row, column) + ORT_FORCEINLINE + LongIndex operator()(MatCoord const& coord) const { + return LongIndex(coord[1]) * LongIndex(stride_) + coord[0]; + } + + /// Inverse of layout function, mapping linear offset to logical coordinate + ORT_FORCEINLINE + MatCoord inverse(LongIndex offset) const { + return make_Position(Index(offset % stride_), Index(offset / stride_)); + } + + ORT_FORCEINLINE + Index stride() const { + return stride_; + } +}; + +/** + * @brief A reference to a tensor, with a layout object to map logical + * coordinates to linear offsets. + */ +template < + /// Data type of element stored within tensor, must be numerical types + typename Element_, + /// Defines a mapping from logical coordinate to linear memory offsets + typename Layout_, + /// If true, extra bounds checking is performed on all accesses + bool ExtraBoundsCheck_ = false> +class MatrixRef { + public: + /// Data type of individual access + using Element = Element_; + + using Reference = Element&; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using MatCoord = typename Layout::MatCoord; + + /// MatrixRef to constant data + using ConstMatrixRef = MatrixRef< + typename std::remove_const::type const, + Layout, ExtraBoundsCheck_>; + + /// MatrixRef to non-constant data + using NonConstMatrixRef = MatrixRef< + typename std::remove_const::type, + Layout, ExtraBoundsCheck_>; + + static constexpr bool IsNonConstRef = std::is_same>::value; + + private: + /// Pointer to data + gsl::span data_; + + /// Shape of matrix + MatCoord shape_; + + /// Layout object maps logical coordinates to linear offsets + Layout layout_; + + public: + ORT_FORCEINLINE + MatrixRef() : data_() {} + + ORT_FORCEINLINE + MatrixRef( + gsl::span const& 
data, ///< pointer to start of tensor + MatCoord const& shape ///< shape of tensor + ) : data_(data), shape_(shape), layout_(Layout::packed(shape)) { + Expects(data_.size() >= size_t(shape_.product())); + } + + ORT_FORCEINLINE + MatrixRef( + Element* ptr, ///< pointer to start of tensor + LongIndex size, ///< size of tensor in elements + MatCoord const& shape ///< shape of tensor + ) : data_(ptr, size), shape_(shape), layout_(Layout::packed(shape)) { + Expects(data_.size() >= shape_.product()); + } + + /// Converting constructor from MatrixRef to non-constant data. + template + ORT_FORCEINLINE + MatrixRef( + NonConstMatrixRef const& ref, ///< MatrixRef to non-const data + /// SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const + _Magic magic = (typename std::enable_if::type)0 + ) : data_(ref.data()), shape_(ref.shape()), layout_(Layout::packed(ref.shape())) {} + + ORT_FORCEINLINE + ConstMatrixRef const_ref() const { + return ConstMatrixRef(data_, shape_); + } + + ORT_FORCEINLINE + NonConstMatrixRef non_const_ref() { + return NonConstMatrixRef( + const_cast::type*>(data_.data()), + data_.size(), shape_); + } + + /// Returns true if the MatrixRef is non-null + ORT_FORCEINLINE + bool good() const { return !data_.empty(); } + + ORT_FORCEINLINE + gsl::span const& data() const { return data_; } + + ORT_FORCEINLINE + MatCoord const& shape() const { return shape_; } + + ORT_FORCEINLINE + Layout& layout() { return layout_; } + + ORT_FORCEINLINE + Layout layout() const { return layout_; } + + ORT_FORCEINLINE + Index stride() const { return layout_.stride(); } + + ORT_FORCEINLINE + Index& stride() { return layout_.stride(); } + + /// Computes the offset of an index from the origin of the tensor + ORT_FORCEINLINE + LongIndex offset(MatCoord const& coord) const { + if constexpr (ExtraBoundsCheck_) { + Expects(coord[0] >= 0 && coord[0] < shape_[0]); + Expects(coord[1] >= 0 && coord[1] < shape_[1]); + } + return layout_(coord); + } + + /// Returns 
a reference to the element at a given Coord + ORT_FORCEINLINE + Reference at(MatCoord const& coord) const { + return data_[offset(coord)]; + } + + ORT_FORCEINLINE + Reference at(int row, int col) const { + return data_[offset(make_Position(row, col))]; + } + + /// Returns a reference to the element at a given Coord + ORT_FORCEINLINE + Reference operator[](MatCoord const& coord) const { + return data_[offset(coord)]; + } +}; + +/// Constructs a MatrixRef, deducing types from arguments. +template < + typename Element, + typename Layout = RowMajorLayout, + bool ExtraBoundsCheck = false> +ORT_FORCEINLINE +MatrixRef +make_MatrixRef( + Element* ptr, + int64_t size, + typename Layout::MatCoord const& shape) { + return MatrixRef(ptr, size, shape); +} + +template < + typename Element, + typename Layout = RowMajorLayout, + bool ExtraBoundsCheck = false> +ORT_FORCEINLINE +MatrixRef +make_MatrixRef( + const gsl::span& span, + typename Layout::MatCoord const& shape) { + return MatrixRef(span, shape); +} + +// clang-format on + +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc new file mode 100644 index 0000000000..aba2b0b2cb --- /dev/null +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc @@ -0,0 +1,507 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include + +#include "core/framework/float16.h" +#include "core/mickey/blk_q4/prepack_sm80.h" +#include "core/mlas/inc/mlas_q4.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +void prepack_weights_ref( + int rows, + int columns, + const MatrixRef& tensor_weight, + const MatrixRef& tensor_weight_prepacked) { + EXPECT_TRUE(tensor_weight.shape()[0] == rows / 2 && tensor_weight.shape()[1] == columns); + EXPECT_TRUE(tensor_weight_prepacked.shape()[0] == rows && tensor_weight_prepacked.shape()[1] == columns / 2); + + auto t0_base = make_Position(0, 0); + auto t1_base = make_Position(4, 0); + auto t2_base = make_Position(0, 8); + auto t3_base = make_Position(4, 8); + for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) { + for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) { + // Packing from a 8x16 tile to a 16x8 tile + auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16); + auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8); + for (int col = 0; col < 8; ++col) { + for (int row = 0; row < 4; ++row) { + auto cord = make_Position(row, col); + auto packed_cord = packed_tile_base + make_Position(row * 4, col); // packed tile is 16x8 + uint8_t buf[4]; + buf[0] = tensor_weight.at(dtile_base + t0_base + cord); + buf[1] = tensor_weight.at(dtile_base + t1_base + cord); + buf[2] = tensor_weight.at(dtile_base + t2_base + cord); + buf[3] = tensor_weight.at(dtile_base + t3_base + cord); + + // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that each pair of adjacent weights + // are in different b16 register at the same positions. 
This makes it easier to convert to + // fp16x2 format in a b32 register + + tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0); + tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0); + } + } + } + } +} + +template < + typename ScaleElementT, + typename Layout, + typename QuantBlocking> +void prepack_quant_scales_ref( + int rows, + int columns, + const MatrixRef& tensor_scale, + const MatrixRef& tensor_scale_prepacked) { + EXPECT_TRUE(tensor_scale.shape()[0] == (rows / QuantBlocking::kRow) && tensor_scale.shape()[1] == (columns / QuantBlocking::kColumn)); + EXPECT_TRUE(tensor_scale_prepacked.shape() == tensor_scale.shape()); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (sizeof(ScaleElementT) == 2 && QuantBlocking::kRow == 1) { + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. 
With a column major layout, each thread + // needs two separate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + + for (int col = 0; col < tensor_scale.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); + tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); + tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); + tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); + } + } + } + } else { + // In all other cases, we don't prepack scale or offset + FAIL() << "Scale prepack only supported for 16b gemm with (1,n) quantization blocking"; + } +} + +template +void prepack_quant_offsets_ref( + size_t rows, + size_t columns, + MatrixRef tensor_offset, + MatrixRef tensor_offset_prepacked) { + // EXPECT_TRUE(tensor_offset.shape()[0] == (rows / QuantBlocking::kRow) && tensor_offset.shape()[1] == (columns / QuantBlocking::kColumn)); + EXPECT_TRUE(tensor_offset_prepacked.shape() == tensor_offset.shape()); + + // Only prepacking scale and offset tensors for an often-used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (QuantBlocking::kRow != 1) { + FAIL() << "Offsets prepack only supported for 16b gemm with (1,n) quantization blocking"; + } + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each 
thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two separate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + if (tensor_offset_prepacked.good()) { + for (int col = 0; col < tensor_offset.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_offset.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own + // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to + // convert to fp16x2 format in a b32 register + tensor_offset_prepacked.at(dst_idx + 0, col) = tensor_offset.at(src_idx + 0, col); + tensor_offset_prepacked.at(dst_idx + 1, col) = tensor_offset.at(src_idx + 8, col); + tensor_offset_prepacked.at(dst_idx + 2, col) = tensor_offset.at(src_idx + 1, col); + tensor_offset_prepacked.at(dst_idx + 3, col) = tensor_offset.at(src_idx + 9, col); + } + } + } + } +} + +template +void testPrepack(int rows, int columns, bool has_offset = true) { + using ElementT = MLFloat16; + constexpr int block_size = 32; + using Base = 
onnxruntime::cuda::BlockwiseQuantization< + ElementT, + block_size, + 4, + ColumnMajorQuantBlocking>; + + using QuantBlocking = typename Base::QuantBlocking; + using ElementW = typename Base::ElementW; + using LayoutWPack = typename Base::LayoutWPack; + using ElementQOffset = typename Base::ElementQOffset; + using LayoutQmeta = typename Base::LayoutQmeta; + + unsigned int seed = 28571; // Replace with desired seed value + std::seed_seq seq{seed}; + std::mt19937 gen(seq); + std::uniform_int_distribution<> dis(0, 8192); + + const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); + const auto meta_shape = Base::get_quant_meta_shape(rows, columns); + + // + // For testing quantization and dequantization, it is not + // straightforward to avoid flaky tests due to rounding errors. The way we + // try to achieve this is to: + // 1. Generate a set of quantized weights, scales and offsets + // 2. Dequantize the weights + // 3. Quantize the dequantized weights + // 4. Compare the dequantized-and-then-quantized weights with + // the original quantized weights + // + // Random filling of the initial values is key to getting this right. + // For weights, we must ensure each block gets a full range of + // values, i.e. must contain 0 and 15. And for scales, they must + // all be positive. 
+ // + + std::vector q_weights(q_weight_shape.product()); + MatrixRef tensor_q_weight( + q_weights, make_Position(rows / 2, columns)); + int v = 7; + for (int c = 0; c < tensor_q_weight.shape()[1]; c++) { + for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) { + uint8_t v0 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + uint8_t v1 = 0; + if (r + 1 < rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } + + tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0); + } + } + + std::vector q_scales(meta_shape.product()); + for (size_t i = 0; i < q_scales.size(); i++) { + q_scales[i] = ElementT(((dis(gen) % 127) + 1) / 32.0f); + } + MatrixRef tensor_scale( + q_scales, meta_shape); + + std::vector q_zp(meta_shape.product()); + for (size_t i = 0; i < q_zp.size(); i++) { + q_zp[i] = dis(gen) % 16; + } + MatrixRef tensor_offset( + q_zp, meta_shape); + +#if 0 // debug + // Fill tensor_q_weight with the patterned data, easier to debug with print + int loop_val = 0; + int offset = 3; + for (int col_tile = 0; col_tile < tensor_q_weight.extent().column()/8; ++col_tile) { + for (int row_tile = 0; row_tile < tensor_q_weight.extent().row()/4; ++row_tile) { + for (int col = 0; col < 8; ++col) { + for (int row = 0; row < 4; ++row) { + auto weight_cord = cutlass::make_Coord(row_tile * 4 + row, col_tile * 8 + col); + auto val = (loop_val + offset) % 256; + tensor_q_weight.at(weight_cord) = ElementW(val); + loop_val++; + if (loop_val == 256) { + loop_val = 0; + offset += 11; + } + } + } + } + } + for (int col = 0; col < tensor_scale.extent().column(); ++col){ + int c = col * QuantBlocking::kColumn; + for (int row = 0; row < tensor_scale.extent().row(); ++row){ + int r = row * QuantBlocking::kRow; + auto weight_cord = 
cutlass::make_Coord(r/2, c); + int w = 0; + if (r % 2 == 0) { + w = int(tensor_q_weight.at(weight_cord) & 0x0f); + } else { + w = int(tensor_q_weight.at(weight_cord) >> 4); + } + tensor_scale.at({row, col}) = w; + tensor_offset.at({row, col}) = ElementQOffset(w); + } + } + + int fill_val = -512; + int factor = 1; + for (int col = 0; col < tensor_scale.extent().column(); ++col){ + for (int row = 0; row < tensor_scale.extent().row(); ++row){ + tensor_scale.at({row, col}) = ElementQScale((float)fill_val * float(factor)); + fill_val++; + if (fill_val == 512) { + fill_val = -512; + factor += 1; + } + } + } + +#endif // debug + + std::vector dequants(rows * columns); + MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); + + // Dequantize weights and save into matrix B for reference + for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { + for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { + auto weight_cord = make_Position(row / 2, col); + auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); + const uint8_t offset = has_offset ? 
tensor_offset.at(scale_cord) : 8; + int w = 0; + if (row % 2 == 0) { + w = int(tensor_q_weight.at(weight_cord) & 0x0f); + } else { + w = int(tensor_q_weight.at(weight_cord) >> 4); + } + float scale = float(tensor_scale.at(scale_cord)); + float dequant = scale * float(w - offset); + tensor_dequant.at(row, col) = ElementT(dequant); + // Prints for help debugging in case of test failure + // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant); + } + } + + int q_rows, q_cols; + MlasBlockwiseQuantizedShape( + block_size, ColumnMajorQuantBlocking, rows, columns, q_rows, q_cols); + // to be exact, q_rows are padded to multiple of block_size, deal with it when we care about strange shapes + EXPECT_EQ(q_rows, q_weight_shape[0]); + EXPECT_EQ(q_cols, q_weight_shape[1]); + + // + // Quantization tool outputs: + // + std::vector o_elements(q_rows * q_cols); + MatrixRef tensor_o_elements(o_elements, q_weight_shape); + + std::vector o_scales(meta_shape.product()); + MatrixRef tensor_o_scales(o_scales, meta_shape); + + std::vector o_zp(((meta_shape[0] + 1) / 2) * meta_shape[1], true); + MatrixRef tensor_o_zp( + o_zp, make_Position((meta_shape[0] + 1) / 2, meta_shape[1])); + + MlasQuantizeBlockwise(o_elements.data(), o_scales.data(), has_offset ? o_zp.data() : nullptr, + tensor_dequant.data().data(), block_size, + ColumnMajorQuantBlocking, rows, columns, columns, nullptr); + for (int col = 0; col < tensor_q_weight.shape()[1]; ++col) { + for (int row = 0; row < tensor_q_weight.shape()[0]; ++row) { + EXPECT_EQ(tensor_o_elements.at(row, col), tensor_q_weight.at(row, col)) + << "quantized value mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (ColumnMajorQuantBlocking ? 
"Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row = 0; row < meta_shape[0]; row += 2) { + if (has_offset) { + uint8_t pair01 = tensor_o_zp.at(row / 2, col); + EXPECT_EQ(tensor_offset.at(row + 0, col), pair01 & 0xf) + << "quantized offset mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + if (row + 1 < meta_shape[0]) { + EXPECT_EQ(tensor_offset.at(row + 1, col), pair01 >> 4) + << "quantized offset mismatch at [" << row + 1 << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + EXPECT_EQ(tensor_scale.at(row + 0, col), tensor_o_scales.at(row + 0, col)) + << "quantized scale mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + if (row + 1 < meta_shape[0]) { + EXPECT_EQ(tensor_scale.at(row + 1, col), tensor_o_scales.at(row + 1, col)) + << "quantized scale mismatch at [" << row + 1 << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + } + + // + // Now we just setup fp16 weights tensor_dequant, quantized weights tensor_q_weight, + // quantization scale tensor_scale and quantization offset tensor_offset. The above + // testing just make sure our test setup is consistent with quantization tool output. 
+ // + // Next we test the prepack code + // + + std::vector packed_w_ref(q_weight_shape.product()); + MatrixRef tensor_packed_w_ref( + packed_w_ref, make_Position(rows, columns / 2)); + prepack_weights_ref(rows, columns, tensor_q_weight, tensor_packed_w_ref); + + std::vector packed_w(q_weight_shape.product()); + MatrixRef tensor_packed_w( + packed_w, make_Position(rows, columns / 2)); + Base::prepack_weights(rows, columns, o_elements, packed_w); + + for (int col = 0; col < tensor_packed_w.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_w.shape()[0]; ++row) { + EXPECT_EQ(tensor_packed_w_ref.at(row, col), tensor_packed_w.at(row, col)) + << "prepacked weights mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + std::vector packed_scales_ref(meta_shape.product()); + MatrixRef tensor_packed_s_ref = + Base::ShouldRearrangeMeta ? make_MatrixRef(packed_scales_ref, meta_shape) + : tensor_scale; + if (Base::ShouldRearrangeMeta) { + prepack_quant_scales_ref( + rows, columns, tensor_scale.const_ref(), tensor_packed_s_ref); + } + + std::vector packed_scales(meta_shape.product()); + MatrixRef tensor_packed_s( + packed_scales, meta_shape); + Base::prepack_quant_scales(rows, columns, o_scales, packed_scales); + + for (int col = 0; col < tensor_packed_s.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_s.shape()[0]; ++row) { + EXPECT_EQ(tensor_packed_s_ref.at(row, col), tensor_packed_s.at(row, col)) + << "prepacked scales mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + if (has_offset) { + std::vector packed_zp_ref(meta_shape.product()); + MatrixRef tensor_packed_zp_ref = + Base::ShouldRearrangeMeta ? 
make_MatrixRef(packed_zp_ref, meta_shape) + : tensor_offset; + if (Base::ShouldRearrangeMeta) { + prepack_quant_offsets_ref( + rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref); + } + + std::vector packed_zp(meta_shape.product()); + MatrixRef tensor_packed_zp( + packed_zp, meta_shape); + Base::prepack_quant_offsets(rows, columns, o_zp, packed_zp); + + for (int col = 0; col < tensor_packed_zp.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_zp.shape()[0]; ++row) { + EXPECT_EQ(tensor_packed_zp_ref.at(row, col), tensor_packed_zp.at(row, col)) + << "prepacked offsets mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + } +} + +// TODO: code runs on CPU, but this is for sm80 only, maybe enable only when testing on sm80 +TEST(BlkQ4_GEMM, PrepackSm80Test) { + testPrepack(32, 32); + testPrepack(32, 32, false); + testPrepack(32, 32); + testPrepack(32, 32, false); + testPrepack(32, 64); + testPrepack(32, 128); + testPrepack(32, 256); + testPrepack(64, 32); + testPrepack(128, 32); + testPrepack(256, 32); + testPrepack(256, 256); + testPrepack(32, 128, false); + testPrepack(128, 32, false); + testPrepack(256, 256, false); + testPrepack(32, 64); + testPrepack(32, 128); + testPrepack(32, 256); + testPrepack(64, 32); + testPrepack(128, 32); + testPrepack(256, 32); + testPrepack(256, 256); + testPrepack(32, 128, false); + testPrepack(128, 32, false); + testPrepack(256, 256, false); +} + +} // namespace test +} // namespace onnxruntime