[ROCm] Enable BFloat16 for Gemm and MatMul Op (#10398)

* gemm-bf16

* gemm bf16

* gemm bf16

* matmul bf16

* minor style change

Co-authored-by: Ethan Tao <ettao@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: root <root@GCRAMDRR1-MI100-087.redmond.corp.microsoft.com>
This commit is contained in:
ytaous 2022-01-27 00:09:16 -08:00 committed by GitHub
parent 5f49f40fa5
commit 4d305282da
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 176 additions and 7 deletions

View file

@ -44,6 +44,13 @@ __global__ void CopyVectorHalf(const half* x, int incx, half* y, int incy, int n
y[id * incy] = x[id * incx];
}
// Strided element-wise copy of BFloat16 values: y[i * incy] = x[i * incx] for i in [0, n).
// Only the x dimension of the launch configuration is used; each thread handles at most
// one element and threads whose index is >= n exit without touching memory.
__global__ void CopyVectorBFloat16(const onnxruntime::BFloat16* x, int incx, onnxruntime::BFloat16* y, int incy,
                                   int n) {
  // Compute the index in 64 bits so that neither the global thread id nor the
  // strided offsets (id * incx, id * incy) can overflow a 32-bit int for large inputs.
  const long long id = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (id >= n) return;
  y[id * incy] = x[id * incx];
}
} // namespace
rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation , rocblas_operation , int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int) {
@ -64,3 +71,11 @@ rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, cons
hipLaunchKernelGGL(CopyVectorHalf, dim3(dimGrid), dim3(dimBlock), 0, stream, x, incx, y, incy, n);
return rocblas_status_success;
}
// BFloat16 overload of the vector copy helper.
// Launches CopyVectorBFloat16 on `stream` with one thread per element and
// COPY_BLOCK_DIM threads per block. The rocblas_handle parameter is unused.
// Always returns rocblas_status_success; the launch itself is not checked here.
rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, const onnxruntime::BFloat16* x, int incx,
                                 onnxruntime::BFloat16* y, int incy) {
  // Nothing to copy. Returning early also avoids launching a kernel with a
  // zero-sized grid, which is not a valid launch configuration.
  if (n <= 0) return rocblas_status_success;

  // Ceil-divide in unsigned arithmetic; n is cast first so the addition cannot overflow int.
  const unsigned int num_blocks = (static_cast<unsigned int>(n) + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM;
  dim3 dimGrid(num_blocks, 1, 1);
  dim3 dimBlock(COPY_BLOCK_DIM, 1, 1);
  hipLaunchKernelGGL(CopyVectorBFloat16, dimGrid, dimBlock, 0, stream, x, incx, y, incy, n);
  return rocblas_status_success;
}

View file

@ -53,6 +53,7 @@ namespace rocm {
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(double)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)
template <typename T>
Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {

View file

@ -41,6 +41,7 @@ namespace rocm {
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(double)
REGISTER_KERNEL_TYPED(MLFloat16)
REGISTER_KERNEL_TYPED(BFloat16)
// StridedBatchedGemm can be used for the following GEMM computation
// C[pnm] = A[pnk]*B[km] or C[pnm] = A[pnk]*B[pkm]

View file

@ -1126,11 +1126,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, D
// class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Div);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Cast);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Softmax);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul);
// class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Relu);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Sigmoid);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Tanh);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm);
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, ReduceSum);
// OpSet 14
@ -1970,11 +1970,11 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Div)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Cast)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Softmax)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Relu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Sigmoid)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Tanh)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, ReduceSum)>,
// OpSet 14

View file

@ -5,6 +5,8 @@
#include "core/providers/rocm/rocm_common.h"
using namespace onnxruntime;
// Generalize library calls to be used in template functions
// gemm
@ -66,6 +68,33 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle,
rocblas_gemm_algo_standard, 0, 0);
}
// BFloat16 overload of the GEMM helper; forwards to rocblas_gemm_ex.
// alpha and beta are dereferenced on the host and widened to float because the
// requested compute type is f32. A, B and C are all described as bf16, and C is
// supplied for both trailing matrix arguments, so the result lands in C.
inline rocblas_status rocblasGemmHelper(rocblas_handle handle,
                                        rocblas_operation transa,
                                        rocblas_operation transb,
                                        int m, int n, int k,
                                        const BFloat16* alpha,
                                        const BFloat16* A, int lda,
                                        const BFloat16* B, int ldb,
                                        const BFloat16* beta,
                                        BFloat16* C, int ldc) {
  const float alpha_f32 = alpha->ToFloat();
  const float beta_f32 = beta->ToFloat();
  constexpr rocblas_datatype kBf16 = rocblas_datatype_bf16_r;

  // Inputs and output are bf16; accumulation happens in FP32.
  return rocblas_gemm_ex(handle, transa, transb,
                         m, n, k,
                         &alpha_f32,
                         A, kBf16, lda,
                         B, kBf16, ldb,
                         &beta_f32,
                         C, kBf16, ldc,
                         C, kBf16, ldc,
                         rocblas_datatype_f32_r,
                         rocblas_gemm_algo_standard, 0, 0);
}
// batched gemm
inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle,
rocblas_operation transa,
@ -130,6 +159,35 @@ inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle,
rocblas_gemm_algo_standard, 0, 0);
}
// BFloat16 overload of the batched GEMM helper; forwards to rocblas_gemm_batched_ex.
// Aarray/Barray/Carray are arrays of per-batch matrix pointers. alpha and beta are
// read on the host and widened to float to match the f32 compute type. Carray is
// supplied for both trailing matrix arguments, so results land in the C matrices.
inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle,
                                               rocblas_operation transa,
                                               rocblas_operation transb,
                                               int m, int n, int k,
                                               const BFloat16* alpha,
                                               const BFloat16* Aarray[], int lda,
                                               const BFloat16* Barray[], int ldb,
                                               const BFloat16* beta,
                                               BFloat16* Carray[], int ldc,
                                               int batch_count) {
  const float alpha_f32 = alpha->ToFloat();
  const float beta_f32 = beta->ToFloat();
  constexpr rocblas_datatype kBf16 = rocblas_datatype_bf16_r;

  // The library takes type-erased pointer arrays; the element type is conveyed by kBf16.
  const void** a_ptrs = reinterpret_cast<const void**>(Aarray);
  const void** b_ptrs = reinterpret_cast<const void**>(Barray);
  void** c_ptrs = reinterpret_cast<void**>(Carray);

  // Inputs and output are bf16; accumulation happens in FP32.
  return rocblas_gemm_batched_ex(handle, transa, transb,
                                 m, n, k,
                                 &alpha_f32,
                                 a_ptrs, kBf16, lda,
                                 b_ptrs, kBf16, ldb,
                                 &beta_f32,
                                 c_ptrs, kBf16, ldc,
                                 c_ptrs, kBf16, ldc,
                                 batch_count,
                                 rocblas_datatype_f32_r,
                                 rocblas_gemm_algo_standard, 0, 0);
}
// strided batched gemm
inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,
rocblas_operation transa,
@ -205,6 +263,37 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,
rocblas_gemm_algo_standard, 0, 0);
}
// BFloat16 overload of the strided-batched GEMM helper; forwards to
// rocblas_gemm_strided_batched_ex. Each of A, B and C is passed as a base pointer
// together with its leading dimension and a per-batch stride. alpha and beta are
// read on the host and widened to float to match the f32 compute type. C (with
// ldc/strideC) is supplied for both trailing matrix arguments, so results land in C.
inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,
                                                      rocblas_operation transa,
                                                      rocblas_operation transb,
                                                      int m, int n, int k,
                                                      const BFloat16* alpha,
                                                      const BFloat16* A, int lda,
                                                      long long int strideA,
                                                      const BFloat16* B, int ldb,
                                                      long long int strideB,
                                                      const BFloat16* beta,
                                                      BFloat16* C, int ldc,
                                                      long long int strideC,
                                                      int batch_count) {
  const float alpha_f32 = alpha->ToFloat();
  const float beta_f32 = beta->ToFloat();
  constexpr rocblas_datatype kBf16 = rocblas_datatype_bf16_r;

  // Inputs and output are bf16; accumulation happens in FP32.
  return rocblas_gemm_strided_batched_ex(handle, transa, transb,
                                         m, n, k,
                                         &alpha_f32,
                                         A, kBf16, lda, strideA,
                                         B, kBf16, ldb, strideB,
                                         &beta_f32,
                                         C, kBf16, ldc, strideC,
                                         C, kBf16, ldc, strideC,
                                         batch_count,
                                         rocblas_datatype_f32_r,
                                         rocblas_gemm_algo_standard, 0, 0);
}
// transpose using geam
inline rocblas_status rocblasTransposeHelper(hipStream_t /*stream*/, rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, float* C, int ldc) {
return rocblas_sgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
@ -222,3 +311,4 @@ inline rocblas_status rocblasCopyHelper(hipStream_t /*stream*/, rocblas_handle h
return rocblas_dcopy(handle, n, x, incx, y, incy);
}
// Copy helper overloads for 16-bit types. Unlike the inline overloads in this header,
// these are declarations only; their definitions must be provided by another translation unit.
// half: copies n elements from x (stride incx) to y (stride incy).
rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle handle, int n, const half* x, int incx, half* y, int incy);
// BFloat16: same parameter layout as the half overload.
rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy);

View file

@ -45,7 +45,7 @@ TEST(GemmOpTest, GemmNoTrans_double) {
TestGemmNoTrans<double>();
}
// Only CUDA kernel has float 16 support
// Only the CUDA and ROCm kernels have float16 support
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(GemmOpTest, GemmNoTrans_f16) {
#ifdef USE_CUDA
@ -86,13 +86,15 @@ TEST(GemmOpTest, GemmNoTrans_f16) {
}
#endif
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(GemmOpTest, GemmNoTrans_bfloat16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
if (!HasCudaEnvironment(min_cuda_architecture)) {
LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16";
return;
}
#endif
OpTester test("Gemm", 14);
test.AddAttribute("transA", (int64_t)0);
test.AddAttribute("transB", (int64_t)0);
@ -103,7 +105,11 @@ TEST(GemmOpTest, GemmNoTrans_bfloat16) {
test.AddInput<BFloat16>("C", {2, 3}, MakeBFloat16({1.f, 1.f, 1.f, 1.f, 1.f, 1.f}));
test.AddOutput<BFloat16>("Y", {2, 3}, MakeBFloat16({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f}));
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCudaExecutionProvider());
#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
#endif

View file

@ -3,6 +3,7 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "test/common/cuda_op_test_utils.h"
#include "default_providers.h"
namespace onnxruntime {
@ -162,6 +163,61 @@ TEST(MathOpTest, MatMulUint64Type) {
RunMatMulTest<uint64_t>(9);
}
#if defined(USE_CUDA) || defined(USE_ROCM)
// MatMul on MLFloat16 tensors: a 2x4 matrix times a 4x3 matrix of ones.
// Reference values are written as float and converted to MLFloat16 before
// being handed to the OpTester.
TEST(MathOpTest, MatMul_Float16) {
#ifdef USE_CUDA
  // On CUDA builds, bail out when the device does not meet the minimum architecture.
  const int min_cuda_architecture = 530;
  if (!HasCudaEnvironment(min_cuda_architecture)) {
    LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16";
    return;
  }
#endif
  OpTester test("MatMul", 14);

  // Float reference data: A is 2x4, B is 4x3 (all ones), Y is the expected 2x3 product.
  const std::vector<float> a_fp32{1.0f, 2.0f, 3.0f, 4.0f,
                                  -1.0f, -2.0f, -3.0f, -4.0f};
  const std::vector<float> b_fp32(12, 1.0f);
  const std::vector<float> y_fp32{10.0f, 10.0f, 10.0f,
                                  -10.0f, -10.0f, -10.0f};

  // Half-precision copies of the reference data.
  std::vector<MLFloat16> a_fp16(8);
  ConvertFloatToMLFloat16(a_fp32.data(), a_fp16.data(), 8);
  std::vector<MLFloat16> b_fp16(12);
  ConvertFloatToMLFloat16(b_fp32.data(), b_fp16.data(), 12);
  std::vector<MLFloat16> y_fp16(6);
  ConvertFloatToMLFloat16(y_fp32.data(), y_fp16.data(), 6);

  test.AddInput<MLFloat16>("A", {2, 4}, a_fp16);
  test.AddInput<MLFloat16>("B", {4, 3}, b_fp16);
  test.AddOutput<MLFloat16>("Y", {2, 3}, y_fp16);
  // TensorRT is excluded: fp16 is not supported there.
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
#endif
#if defined(USE_CUDA) || defined(USE_ROCM)
// MatMul on BFloat16 tensors: a 2x4 matrix times a 4x3 matrix of ones, run
// explicitly on the CUDA or ROCm execution provider depending on the build.
TEST(MathOpTest, MatMul_BFloat16) {
#ifdef USE_CUDA
  // NOTE(review): 530 matches the threshold used by the FP16 test; BFloat16 may
  // require a newer architecture - confirm against the CUDA kernel requirements.
  int min_cuda_architecture = 530;
  if (!HasCudaEnvironment(min_cuda_architecture)) {
    // Message names the type actually under test (it previously said FP16).
    LOGS_DEFAULT(WARNING) << "Hardware NOT support BF16";
    return;
  }
#endif
  OpTester test("MatMul", 14);
  test.AddInput<BFloat16>("A", {2, 4}, MakeBFloat16({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}));
  test.AddInput<BFloat16>("B", {4, 3}, MakeBFloat16({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}));
  test.AddOutput<BFloat16>("Y", {2, 3}, MakeBFloat16({10.0f, 10.0f, 10.0f, -10.0f, -10.0f, -10.0f}));

  // Pin the test to a single GPU provider so no other provider can pick up the node.
  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
#ifdef USE_CUDA
  execution_providers.push_back(DefaultCudaExecutionProvider());
#elif defined(USE_ROCM)
  // defined() keeps this branch working even if USE_ROCM is defined without a value.
  execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
#endif
#ifndef ENABLE_TRAINING // Prepacking is enabled only on non-training builds
TEST(MathOpTest, MatMulSharedPrepackedWeights) {
OpTester test("MatMul");