[ROCm] Enable int8 for MatMulInteger Op (#11776)

This commit is contained in:
Xinya Zhang 2022-07-21 13:20:48 -05:00 committed by GitHub
parent 3d2bcb3386
commit 03dfcb0e87
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 69 additions and 6 deletions

View file

@ -0,0 +1,68 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <hip/hip_runtime.h>
#include <rocblas.h>
#include "core/providers/rocm/shared_inc/integer_gemm.h"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/shared_inc/rocm_call.h"
namespace onnxruntime {
namespace rocm {
inline int roundoff(int v, int d) {
return (v + d - 1) / d * d;
}
Status GemmInt8(int m, int n, int k,
int32_t alpha, int32_t beta,
const int8_t* a, int lda, const int8_t* b, int ldb, int32_t* c, int ldc,
const RocmKernel* rocm_kernel) {
ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null");
ORT_ENFORCE(rocm_kernel != nullptr, "kernel is null");
hipStream_t stream = rocm_kernel->Stream();
// pad A and B to make their leading dimension be multiples of 32
// because cublasGemmEx requires:
// 1. leading dimension is multiples of 4
// 2. A, B is 32-bit aligned
const int mask = 0x1F;
int lda_aligned = lda;
IAllocatorUniquePtr<int8_t> a_padded;
if ((mask & lda_aligned) != 0) {
lda_aligned = roundoff(lda, 32);
a_padded = rocm_kernel->GetScratchBuffer<int8_t>(m * lda_aligned);
HIP_RETURN_IF_ERROR(hipMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, hipMemcpyDeviceToDevice, stream));
}
int ldb_aligned = ldb;
IAllocatorUniquePtr<int8_t> b_padded;
if ((mask & ldb_aligned) != 0) {
ldb_aligned = roundoff(ldb, 32);
b_padded = rocm_kernel->GetScratchBuffer<int8_t>(k * ldb_aligned);
HIP_RETURN_IF_ERROR(hipMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, hipMemcpyDeviceToDevice, stream));
}
auto handle = rocm_kernel->RocblasHandle();
rocblas_set_stream(handle, stream);
ROCBLAS_RETURN_IF_ERROR(rocblas_gemm_ex(
handle,
rocblas_operation_none, rocblas_operation_none,
n, m, k,
&alpha,
ldb_aligned == ldb ? b : b_padded.get(), rocblas_datatype_i8_r, ldb_aligned,
lda_aligned == lda ? a : a_padded.get(), rocblas_datatype_i8_r, lda_aligned,
&beta,
c, rocblas_datatype_i32_r, ldc,
c, rocblas_datatype_i32_r, ldc, // C == D
rocblas_datatype_i32_r,
rocblas_gemm_algo_standard,
0, 0));
return Status::OK();
}
} // namespace rocm
} // namespace onnxruntime

View file

@ -1262,7 +1262,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, MatMul)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, int8_t, MatMulInteger)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, int8_t, MatMulInteger)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 10, float, Clip)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, float, Elu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, double, Elu)>,

View file

@ -104,10 +104,6 @@ provider_excluded_files = [
"math/einsum.h",
"math/gemm.cc",
"math/matmul.cc",
"math/matmul_integer.cc",
"math/matmul_integer.cu",
"math/matmul_integer.cuh",
"math/matmul_integer.h",
"math/softmax_impl.cu",
"math/softmax_warpwise_impl.cuh",
"math/softmax.cc",
@ -140,7 +136,6 @@ provider_excluded_files = [
"rnn/rnn_impl.h",
"shared_inc/cuda_call.h",
"shared_inc/fpgeneric.h",
"shared_inc/integer_gemm.h",
"cuda_allocator.cc",
"cuda_allocator.h",
"cuda_call.cc",