mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
[ROCm] Enable int8 for MatMulInteger Op (#11776)
This commit is contained in:
parent
3d2bcb3386
commit
03dfcb0e87
3 changed files with 69 additions and 6 deletions
68
onnxruntime/core/providers/rocm/integer_gemm.cc
Normal file
68
onnxruntime/core/providers/rocm/integer_gemm.cc
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <rocblas.h>
|
||||
#include "core/providers/rocm/shared_inc/integer_gemm.h"
|
||||
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
#include "core/providers/rocm/shared_inc/rocm_call.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace rocm {
|
||||
|
||||
inline int roundoff(int v, int d) {
|
||||
return (v + d - 1) / d * d;
|
||||
}
|
||||
|
||||
Status GemmInt8(int m, int n, int k,
|
||||
int32_t alpha, int32_t beta,
|
||||
const int8_t* a, int lda, const int8_t* b, int ldb, int32_t* c, int ldc,
|
||||
const RocmKernel* rocm_kernel) {
|
||||
ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null");
|
||||
ORT_ENFORCE(rocm_kernel != nullptr, "kernel is null");
|
||||
|
||||
hipStream_t stream = rocm_kernel->Stream();
|
||||
|
||||
// pad A and B to make their leading dimension be multiples of 32
|
||||
// because cublasGemmEx requires:
|
||||
// 1. leading dimension is multiples of 4
|
||||
// 2. A, B is 32-bit aligned
|
||||
|
||||
const int mask = 0x1F;
|
||||
int lda_aligned = lda;
|
||||
IAllocatorUniquePtr<int8_t> a_padded;
|
||||
if ((mask & lda_aligned) != 0) {
|
||||
lda_aligned = roundoff(lda, 32);
|
||||
a_padded = rocm_kernel->GetScratchBuffer<int8_t>(m * lda_aligned);
|
||||
HIP_RETURN_IF_ERROR(hipMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, hipMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
|
||||
int ldb_aligned = ldb;
|
||||
IAllocatorUniquePtr<int8_t> b_padded;
|
||||
if ((mask & ldb_aligned) != 0) {
|
||||
ldb_aligned = roundoff(ldb, 32);
|
||||
b_padded = rocm_kernel->GetScratchBuffer<int8_t>(k * ldb_aligned);
|
||||
HIP_RETURN_IF_ERROR(hipMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, hipMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
|
||||
auto handle = rocm_kernel->RocblasHandle();
|
||||
rocblas_set_stream(handle, stream);
|
||||
ROCBLAS_RETURN_IF_ERROR(rocblas_gemm_ex(
|
||||
handle,
|
||||
rocblas_operation_none, rocblas_operation_none,
|
||||
n, m, k,
|
||||
&alpha,
|
||||
ldb_aligned == ldb ? b : b_padded.get(), rocblas_datatype_i8_r, ldb_aligned,
|
||||
lda_aligned == lda ? a : a_padded.get(), rocblas_datatype_i8_r, lda_aligned,
|
||||
&beta,
|
||||
c, rocblas_datatype_i32_r, ldc,
|
||||
c, rocblas_datatype_i32_r, ldc, // C == D
|
||||
rocblas_datatype_i32_r,
|
||||
rocblas_gemm_algo_standard,
|
||||
0, 0));
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace rocm
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1262,7 +1262,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, MatMul)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, int8_t, MatMulInteger)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, int8_t, MatMulInteger)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 10, float, Clip)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, float, Elu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, double, Elu)>,
|
||||
|
|
|
|||
|
|
@ -104,10 +104,6 @@ provider_excluded_files = [
|
|||
"math/einsum.h",
|
||||
"math/gemm.cc",
|
||||
"math/matmul.cc",
|
||||
"math/matmul_integer.cc",
|
||||
"math/matmul_integer.cu",
|
||||
"math/matmul_integer.cuh",
|
||||
"math/matmul_integer.h",
|
||||
"math/softmax_impl.cu",
|
||||
"math/softmax_warpwise_impl.cuh",
|
||||
"math/softmax.cc",
|
||||
|
|
@ -140,7 +136,6 @@ provider_excluded_files = [
|
|||
"rnn/rnn_impl.h",
|
||||
"shared_inc/cuda_call.h",
|
||||
"shared_inc/fpgeneric.h",
|
||||
"shared_inc/integer_gemm.h",
|
||||
"cuda_allocator.cc",
|
||||
"cuda_allocator.h",
|
||||
"cuda_call.cc",
|
||||
|
|
|
|||
Loading…
Reference in a new issue