From 03dfcb0e87b6c28a75f074e3e6105c714c40819f Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Thu, 21 Jul 2022 13:20:48 -0500 Subject: [PATCH] [ROCm] Enable int8 for MatMulInteger Op (#11776) --- .../core/providers/rocm/integer_gemm.cc | 68 +++++++++++++++++++ .../providers/rocm/rocm_execution_provider.cc | 2 +- tools/ci_build/amd_hipify.py | 5 -- 3 files changed, 69 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/core/providers/rocm/integer_gemm.cc diff --git a/onnxruntime/core/providers/rocm/integer_gemm.cc b/onnxruntime/core/providers/rocm/integer_gemm.cc new file mode 100644 index 0000000000..86e457d743 --- /dev/null +++ b/onnxruntime/core/providers/rocm/integer_gemm.cc @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + + +#include +#include +#include "core/providers/rocm/shared_inc/integer_gemm.h" + +#include "core/providers/rocm/rocm_common.h" +#include "core/providers/rocm/shared_inc/rocm_call.h" + +namespace onnxruntime { +namespace rocm { + +inline int roundoff(int v, int d) { + return (v + d - 1) / d * d; +} + +Status GemmInt8(int m, int n, int k, + int32_t alpha, int32_t beta, + const int8_t* a, int lda, const int8_t* b, int ldb, int32_t* c, int ldc, + const RocmKernel* rocm_kernel) { + ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null"); + ORT_ENFORCE(rocm_kernel != nullptr, "kernel is null"); + + hipStream_t stream = rocm_kernel->Stream(); + + // pad A and B to make their leading dimension be multiples of 32 + // because cublasGemmEx requires: + // 1. leading dimension is multiples of 4 + // 2. A, B is 32-bit aligned + + const int mask = 0x1F; + int lda_aligned = lda; + IAllocatorUniquePtr a_padded; + if ((mask & lda_aligned) != 0) { + lda_aligned = roundoff(lda, 32); + a_padded = rocm_kernel->GetScratchBuffer(m * lda_aligned); + HIP_RETURN_IF_ERROR(hipMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, hipMemcpyDeviceToDevice, stream)); + } + + int ldb_aligned = ldb; + IAllocatorUniquePtr b_padded; + if ((mask & ldb_aligned) != 0) { + ldb_aligned = roundoff(ldb, 32); + b_padded = rocm_kernel->GetScratchBuffer(k * ldb_aligned); + HIP_RETURN_IF_ERROR(hipMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, hipMemcpyDeviceToDevice, stream)); + } + + auto handle = rocm_kernel->RocblasHandle(); + rocblas_set_stream(handle, stream); + ROCBLAS_RETURN_IF_ERROR(rocblas_gemm_ex( + handle, + rocblas_operation_none, rocblas_operation_none, + n, m, k, + &alpha, + ldb_aligned == ldb ? b : b_padded.get(), rocblas_datatype_i8_r, ldb_aligned, + lda_aligned == lda ? a : a_padded.get(), rocblas_datatype_i8_r, lda_aligned, + &beta, + c, rocblas_datatype_i32_r, ldc, + c, rocblas_datatype_i32_r, ldc, // C == D + rocblas_datatype_i32_r, + rocblas_gemm_algo_standard, + 0, 0)); + return Status::OK(); +} +} // namespace rocm +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 6fc65bc80d..a3cda47a98 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1262,7 +1262,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index 4bf2cdbc1f..c9d6b75130 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -104,10 +104,6 @@ provider_excluded_files = [ "math/einsum.h", "math/gemm.cc", "math/matmul.cc", - "math/matmul_integer.cc", - "math/matmul_integer.cu", - "math/matmul_integer.cuh", - "math/matmul_integer.h", "math/softmax_impl.cu", "math/softmax_warpwise_impl.cuh", "math/softmax.cc", @@ -140,7 +136,6 @@ provider_excluded_files = [ "rnn/rnn_impl.h", "shared_inc/cuda_call.h", "shared_inc/fpgeneric.h", - "shared_inc/integer_gemm.h", "cuda_allocator.cc", "cuda_allocator.h", "cuda_call.cc",