From 4d305282da4a4c0bd20447ff633db852d65f3d23 Mon Sep 17 00:00:00 2001 From: ytaous <4484531+ytaous@users.noreply.github.com> Date: Thu, 27 Jan 2022 00:09:16 -0800 Subject: [PATCH] [ROCm] Enable BFloat16 for Gemm and MatMul Op (#10398) * gemm-bf16 * gemm bf16 * gemm bf16 * matmul bf16 * minor style change Co-authored-by: Ethan Tao Co-authored-by: root --- onnxruntime/core/providers/rocm/fpgeneric.cu | 15 ++++ onnxruntime/core/providers/rocm/math/gemm.cc | 1 + .../core/providers/rocm/math/matmul.cc | 1 + .../providers/rocm/rocm_execution_provider.cc | 8 +- .../providers/rocm/shared_inc/fpgeneric.h | 90 +++++++++++++++++++ .../test/providers/cpu/math/gemm_test.cc | 12 ++- .../test/providers/cpu/math/matmul_test.cc | 56 ++++++++++++ 7 files changed, 176 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/rocm/fpgeneric.cu b/onnxruntime/core/providers/rocm/fpgeneric.cu index c53934c688..0ef3e0af93 100644 --- a/onnxruntime/core/providers/rocm/fpgeneric.cu +++ b/onnxruntime/core/providers/rocm/fpgeneric.cu @@ -44,6 +44,13 @@ __global__ void CopyVectorHalf(const half* x, int incx, half* y, int incy, int n y[id * incy] = x[id * incx]; } +__global__ void CopyVectorBFloat16(const onnxruntime::BFloat16* x, int incx, onnxruntime::BFloat16* y, int incy, + int n) { + int id = blockIdx.x * blockDim.x + threadIdx.x; + if (id >= n) return; + y[id * incy] = x[id * incx]; +} + } // namespace rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation , rocblas_operation , int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int) { @@ -64,3 +71,11 @@ rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, cons hipLaunchKernelGGL(CopyVectorHalf, dim3(dimGrid), dim3(dimBlock), 0, stream, x, incx, y, incy, n); return rocblas_status_success; } + +rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, const onnxruntime::BFloat16* x, int incx, + onnxruntime::BFloat16* y, int incy) { + dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1); + dim3 dimBlock(COPY_BLOCK_DIM, 1, 1); + hipLaunchKernelGGL(CopyVectorBFloat16, dim3(dimGrid), dim3(dimBlock), 0, stream, x, incx, y, incy, n); + return rocblas_status_success; +} \ No newline at end of file diff --git a/onnxruntime/core/providers/rocm/math/gemm.cc b/onnxruntime/core/providers/rocm/math/gemm.cc index 4a6091420d..8e32c9d130 100644 --- a/onnxruntime/core/providers/rocm/math/gemm.cc +++ b/onnxruntime/core/providers/rocm/math/gemm.cc @@ -53,6 +53,7 @@ namespace rocm { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(double) REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) template Status Gemm::ComputeInternal(OpKernelContext* ctx) const { diff --git a/onnxruntime/core/providers/rocm/math/matmul.cc b/onnxruntime/core/providers/rocm/math/matmul.cc index 4e6e8a125d..198ff367b8 100644 --- a/onnxruntime/core/providers/rocm/math/matmul.cc +++ b/onnxruntime/core/providers/rocm/math/matmul.cc @@ -41,6 +41,7 @@ namespace rocm { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(double) REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) // StridedBatchedGemm can be used for the following GEMM computation // C[pnm] = A[pnk]*B[km] or C[pnm] = A[pnk]*B[pkm] diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 6aead1c46b..8d10c6f05c 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1126,11 +1126,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, D // class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Div); // class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Cast); // class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Softmax); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul); // class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Relu); // class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Sigmoid); // class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Tanh); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm); // class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, ReduceSum); // OpSet 14 @@ -1970,11 +1970,11 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // OpSet 14 diff --git a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h index 3fb52c2421..cf6971a697 100644 --- a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h @@ -5,6 +5,8 @@ #include "core/providers/rocm/rocm_common.h" +using namespace onnxruntime; + // Generalize library calls to be use in template functions // gemm @@ -66,6 +68,33 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle, rocblas_gemm_algo_standard, 0, 0); } +inline rocblas_status rocblasGemmHelper(rocblas_handle handle, + rocblas_operation transa, + rocblas_operation transb, + int m, int n, int k, + const BFloat16* alpha, + const BFloat16* A, int lda, + const BFloat16* B, int ldb, + const BFloat16* beta, + BFloat16* C, int ldc) { + float h_a = alpha->ToFloat(); + float h_b = beta->ToFloat(); + + // accumulating in FP32 + return rocblas_gemm_ex(handle, + transa, + transb, + m, n, k, + &h_a, + A, rocblas_datatype_bf16_r, lda, + B, rocblas_datatype_bf16_r, ldb, + &h_b, + C, rocblas_datatype_bf16_r, ldc, + C, rocblas_datatype_bf16_r, ldc, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, 0); +} + // batched gemm inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle, rocblas_operation transa, @@ -130,6 +159,35 @@ inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle, rocblas_gemm_algo_standard, 0, 0); } +inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle, + rocblas_operation transa, + rocblas_operation transb, + int m, int n, int k, + const BFloat16* alpha, + const BFloat16* Aarray[], int lda, + const BFloat16* Barray[], int ldb, + const BFloat16* beta, + BFloat16* Carray[], int ldc, + int batch_count) { + float h_a = alpha->ToFloat(); + float h_b = beta->ToFloat(); + + // accumulating in FP32 + return rocblas_gemm_batched_ex(handle, + transa, + transb, + m, n, k, + &h_a, + (const void**)Aarray, rocblas_datatype_bf16_r, lda, + (const void**)Barray, rocblas_datatype_bf16_r, ldb, + &h_b, + (void**)Carray, rocblas_datatype_bf16_r, ldc, + (void**)Carray, rocblas_datatype_bf16_r, ldc, + batch_count, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, 0); +} + // strided batched gemm inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, rocblas_operation transa, @@ -205,6 +263,37 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, rocblas_gemm_algo_standard, 0, 0); } +inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, + rocblas_operation transa, + rocblas_operation transb, + int m, int n, int k, + const BFloat16* alpha, + const BFloat16* A, int lda, + long long int strideA, + const BFloat16* B, int ldb, + long long int strideB, + const BFloat16* beta, + BFloat16* C, int ldc, + long long int strideC, + int batch_count) { + float h_a = alpha->ToFloat(); + float h_b = beta->ToFloat(); + // accumulating in FP32 + return rocblas_gemm_strided_batched_ex(handle, + transa, + transb, + m, n, k, + &h_a, + A, rocblas_datatype_bf16_r, lda, strideA, + B, rocblas_datatype_bf16_r, ldb, strideB, + &h_b, + C, rocblas_datatype_bf16_r, ldc, strideC, + C, rocblas_datatype_bf16_r, ldc, strideC, + batch_count, + rocblas_datatype_f32_r, + rocblas_gemm_algo_standard, 0, 0); +} + // transpose using geam inline rocblas_status rocblasTransposeHelper(hipStream_t /*stream*/, rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, float* C, int ldc) { return rocblas_sgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); @@ -222,3 +311,4 @@ inline rocblas_status rocblasCopyHelper(hipStream_t /*stream*/, rocblas_handle h return rocblas_dcopy(handle, n, x, incx, y, incy); } rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle handle, int n, const half* x, int incx, half* y, int incy); +rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy); diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 5bdd5c3dea..2411f2a14b 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -45,7 +45,7 @@ TEST(GemmOpTest, GemmNoTrans_double) { TestGemmNoTrans(); } -// Only CUDA kernel has float 16 support +// Only CUDA and ROCM kernel has float 16 support #if defined(USE_CUDA) || defined(USE_ROCM) TEST(GemmOpTest, GemmNoTrans_f16) { #ifdef USE_CUDA @@ -86,13 +86,15 @@ TEST(GemmOpTest, GemmNoTrans_f16) { } #endif -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) TEST(GemmOpTest, GemmNoTrans_bfloat16) { +#ifdef USE_CUDA int min_cuda_architecture = 530; if (!HasCudaEnvironment(min_cuda_architecture)) { LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16"; return; } +#endif OpTester test("Gemm", 14); test.AddAttribute("transA", (int64_t)0); test.AddAttribute("transB", (int64_t)0); @@ -103,7 +105,11 @@ TEST(GemmOpTest, GemmNoTrans_bfloat16) { test.AddInput("C", {2, 3}, MakeBFloat16({1.f, 1.f, 1.f, 1.f, 1.f, 1.f})); test.AddOutput("Y", {2, 3}, MakeBFloat16({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f})); std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); +#ifdef USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } #endif diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index c286d52cfc..862ef375b0 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -3,6 +3,7 @@ #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "test/common/cuda_op_test_utils.h" #include "default_providers.h" namespace onnxruntime { @@ -162,6 +163,61 @@ TEST(MathOpTest, MatMulUint64Type) { RunMatMulTest(9); } +#if defined(USE_CUDA) || defined(USE_ROCM) +TEST(MathOpTest, MatMul_Float16) { +#ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16"; + return; + } +#endif + OpTester test("MatMul", 14); + + std::vector A{1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}; + std::vector B(12, 1.0f); + std::vector Y{10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f}; + + std::vector f_A(8); + std::vector f_B(12); + std::vector f_Y(6); + ConvertFloatToMLFloat16(A.data(), f_A.data(), 8); + ConvertFloatToMLFloat16(B.data(), f_B.data(), 12); + ConvertFloatToMLFloat16(Y.data(), f_Y.data(), 6); + + test.AddInput("A", {2, 4}, f_A); + test.AddInput("B", {4, 3}, f_B); + test.AddOutput("Y", {2, 3}, f_Y); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: fp16 is not supported +} +#endif + +#if defined(USE_CUDA) || defined(USE_ROCM) +TEST(MathOpTest, MatMul_BFloat16) { +#ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16"; + return; + } +#endif + OpTester test("MatMul", 14); + + test.AddInput("A", {2, 4}, MakeBFloat16({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f})); + test.AddInput("B", {4, 3}, MakeBFloat16({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})); + test.AddOutput("Y", {2, 3}, MakeBFloat16({10.0f, 10.0f, 10.0f, -10.0f, -10.0f, -10.0f})); + std::vector> execution_providers; +#ifdef USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} +#endif + #ifndef ENABLE_TRAINING // Prepacking is enabled only on non-training builds TEST(MathOpTest, MatMulSharedPrepackedWeights) { OpTester test("MatMul");