mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-24 22:17:32 +00:00
[ROCm] Enable BFloat16 for Gemm and MatMul Op (#10398)
* gemm-bf16 * gemm bf16 * gemm bf16 * matmul bf16 * minor style change Co-authored-by: Ethan Tao <ettao@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net> Co-authored-by: root <root@GCRAMDRR1-MI100-087.redmond.corp.microsoft.com>
This commit is contained in:
parent
5f49f40fa5
commit
4d305282da
7 changed files with 176 additions and 7 deletions
|
|
@ -44,6 +44,13 @@ __global__ void CopyVectorHalf(const half* x, int incx, half* y, int incy, int n
|
|||
y[id * incy] = x[id * incx];
|
||||
}
|
||||
|
||||
// Strided element-wise copy of BFloat16 values: y[i * incy] = x[i * incx] for i in [0, n).
// Each thread handles at most one element, selected by its flat 1-D global thread index.
// Threads whose index is >= n return immediately, so the launch may be rounded up to
// whole blocks.
__global__ void CopyVectorBFloat16(const onnxruntime::BFloat16* x, int incx, onnxruntime::BFloat16* y, int incy,
                                   int n) {
  const int id = blockIdx.x * blockDim.x + threadIdx.x;
  if (id >= n) return;
  // Widen to 64 bits before scaling by the stride so that id * inc cannot wrap around a
  // 32-bit int for long vectors combined with a non-unit stride.
  const long long elem = id;
  y[elem * incy] = x[elem * incx];
}
|
||||
|
||||
} // namespace
|
||||
|
||||
rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocblas_operation , rocblas_operation , int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int) {
|
||||
|
|
@ -64,3 +71,11 @@ rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, cons
|
|||
hipLaunchKernelGGL(CopyVectorHalf, dim3(dimGrid), dim3(dimBlock), 0, stream, x, incx, y, incy, n);
|
||||
return rocblas_status_success;
|
||||
}
|
||||
|
||||
// BFloat16 overload of the copy helper: launches CopyVectorBFloat16 on `stream` to perform
// y[i * incy] = x[i * incx] for i in [0, n). The rocblas_handle argument is unused; it is
// only present so this overload has the same shape as the other rocblasCopyHelper overloads.
// The launch is asynchronous and its result is not checked here, so this function always
// returns rocblas_status_success.
rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, const onnxruntime::BFloat16* x, int incx,
                                 onnxruntime::BFloat16* y, int incy) {
  // Nothing to copy. Returning early also avoids deriving a grid size from a non-positive
  // element count (which yields a zero-block launch for small values).
  if (n <= 0) return rocblas_status_success;

  // One thread per element, rounded up to whole blocks of COPY_BLOCK_DIM threads.
  dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1);
  dim3 dimBlock(COPY_BLOCK_DIM, 1, 1);
  hipLaunchKernelGGL(CopyVectorBFloat16, dimGrid, dimBlock, 0, stream, x, incx, y, incy, n);
  return rocblas_status_success;
}
|
||||
|
|
@ -53,6 +53,7 @@ namespace rocm {
|
|||
REGISTER_KERNEL_TYPED(float)
|
||||
REGISTER_KERNEL_TYPED(double)
|
||||
REGISTER_KERNEL_TYPED(MLFloat16)
|
||||
REGISTER_KERNEL_TYPED(BFloat16)
|
||||
|
||||
template <typename T>
|
||||
Status Gemm<T>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ namespace rocm {
|
|||
REGISTER_KERNEL_TYPED(float)
|
||||
REGISTER_KERNEL_TYPED(double)
|
||||
REGISTER_KERNEL_TYPED(MLFloat16)
|
||||
REGISTER_KERNEL_TYPED(BFloat16)
|
||||
|
||||
// StridedBatchedGemm can be used for the following GEMM computation
|
||||
// C[pnm] = A[pnk]*B[km] or C[pnm] = A[pnk]*B[pkm]
|
||||
|
|
|
|||
|
|
@ -1126,11 +1126,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, D
|
|||
// class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Div);
|
||||
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Cast);
|
||||
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Softmax);
|
||||
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul);
|
||||
// class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Relu);
|
||||
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Sigmoid);
|
||||
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Tanh);
|
||||
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm);
|
||||
// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, ReduceSum);
|
||||
|
||||
// OpSet 14
|
||||
|
|
@ -1970,11 +1970,11 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Div)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Cast)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Softmax)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, MatMul)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Relu)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Sigmoid)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Tanh)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Gemm)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, ReduceSum)>,
|
||||
|
||||
// OpSet 14
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
|
||||
#include "core/providers/rocm/rocm_common.h"
|
||||
|
||||
using namespace onnxruntime;
|
||||
|
||||
// Generalize library calls to be used in template functions
|
||||
|
||||
// gemm
|
||||
|
|
@ -66,6 +68,33 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle,
|
|||
rocblas_gemm_algo_standard, 0, 0);
|
||||
}
|
||||
|
||||
// BFloat16 overload of the GEMM helper, forwarding to rocblas_gemm_ex.
// A, B and C are described to rocBLAS as bf16 matrices while the compute type is f32.
// alpha and beta are dereferenced on the host and widened to float before the call.
// Returns the status reported by rocblas_gemm_ex.
inline rocblas_status rocblasGemmHelper(rocblas_handle handle,
                                        rocblas_operation transa,
                                        rocblas_operation transb,
                                        int m, int n, int k,
                                        const BFloat16* alpha,
                                        const BFloat16* A, int lda,
                                        const BFloat16* B, int ldb,
                                        const BFloat16* beta,
                                        BFloat16* C, int ldc) {
  // Scalars are handed to rocBLAS as floats, matching the f32 compute type below.
  const float alpha_f32 = alpha->ToFloat();
  const float beta_f32 = beta->ToFloat();

  const rocblas_datatype io_type = rocblas_datatype_bf16_r;
  const rocblas_datatype compute_type = rocblas_datatype_f32_r;  // accumulating in FP32

  // C is passed twice (as the input matrix and as the result matrix), so the result
  // is written back into C.
  return rocblas_gemm_ex(handle, transa, transb, m, n, k,
                         &alpha_f32,
                         A, io_type, lda,
                         B, io_type, ldb,
                         &beta_f32,
                         C, io_type, ldc,
                         C, io_type, ldc,
                         compute_type,
                         rocblas_gemm_algo_standard, 0, 0);
}
|
||||
|
||||
// batched gemm
|
||||
inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle,
|
||||
rocblas_operation transa,
|
||||
|
|
@ -130,6 +159,35 @@ inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle,
|
|||
rocblas_gemm_algo_standard, 0, 0);
|
||||
}
|
||||
|
||||
// BFloat16 overload of the batched GEMM helper, forwarding to rocblas_gemm_batched_ex.
// Aarray, Barray and Carray are arrays of per-batch matrix pointers; they are passed to
// rocBLAS as untyped pointer arrays tagged with the bf16 data type. The compute type is f32.
// alpha and beta are dereferenced on the host and widened to float before the call.
// Returns the status reported by rocblas_gemm_batched_ex.
inline rocblas_status rocblasGemmBatchedHelper(rocblas_handle handle,
                                               rocblas_operation transa,
                                               rocblas_operation transb,
                                               int m, int n, int k,
                                               const BFloat16* alpha,
                                               const BFloat16* Aarray[], int lda,
                                               const BFloat16* Barray[], int ldb,
                                               const BFloat16* beta,
                                               BFloat16* Carray[], int ldc,
                                               int batch_count) {
  // Scalars are handed to rocBLAS as floats, matching the f32 compute type below.
  const float alpha_f32 = alpha->ToFloat();
  const float beta_f32 = beta->ToFloat();

  // rocBLAS takes type-erased pointer arrays; the element type is conveyed by io_type.
  const void** a_ptrs = reinterpret_cast<const void**>(Aarray);
  const void** b_ptrs = reinterpret_cast<const void**>(Barray);
  void** c_ptrs = reinterpret_cast<void**>(Carray);

  const rocblas_datatype io_type = rocblas_datatype_bf16_r;
  const rocblas_datatype compute_type = rocblas_datatype_f32_r;  // accumulating in FP32

  // The C pointer array is passed twice (input and result), so results land in Carray.
  return rocblas_gemm_batched_ex(handle, transa, transb, m, n, k,
                                 &alpha_f32,
                                 a_ptrs, io_type, lda,
                                 b_ptrs, io_type, ldb,
                                 &beta_f32,
                                 c_ptrs, io_type, ldc,
                                 c_ptrs, io_type, ldc,
                                 batch_count,
                                 compute_type,
                                 rocblas_gemm_algo_standard, 0, 0);
}
|
||||
|
||||
// strided batched gemm
|
||||
inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,
|
||||
rocblas_operation transa,
|
||||
|
|
@ -205,6 +263,37 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,
|
|||
rocblas_gemm_algo_standard, 0, 0);
|
||||
}
|
||||
|
||||
// BFloat16 overload of the strided-batched GEMM helper, forwarding to
// rocblas_gemm_strided_batched_ex. A, B and C are each passed as one base pointer plus a
// per-batch stride (strideA / strideB / strideC) and are tagged with the bf16 data type.
// The compute type is f32. alpha and beta are dereferenced on the host and widened to
// float before the call. Returns the status reported by rocBLAS.
inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle,
                                                      rocblas_operation transa,
                                                      rocblas_operation transb,
                                                      int m, int n, int k,
                                                      const BFloat16* alpha,
                                                      const BFloat16* A, int lda,
                                                      long long int strideA,
                                                      const BFloat16* B, int ldb,
                                                      long long int strideB,
                                                      const BFloat16* beta,
                                                      BFloat16* C, int ldc,
                                                      long long int strideC,
                                                      int batch_count) {
  // Scalars are handed to rocBLAS as floats, matching the f32 compute type below.
  const float alpha_f32 = alpha->ToFloat();
  const float beta_f32 = beta->ToFloat();

  const rocblas_datatype io_type = rocblas_datatype_bf16_r;
  const rocblas_datatype compute_type = rocblas_datatype_f32_r;  // accumulating in FP32

  // C (with its stride) is passed twice, as input and as result, so the result is
  // written back into C.
  return rocblas_gemm_strided_batched_ex(handle, transa, transb, m, n, k,
                                         &alpha_f32,
                                         A, io_type, lda, strideA,
                                         B, io_type, ldb, strideB,
                                         &beta_f32,
                                         C, io_type, ldc, strideC,
                                         C, io_type, ldc, strideC,
                                         batch_count,
                                         compute_type,
                                         rocblas_gemm_algo_standard, 0, 0);
}
|
||||
|
||||
// transpose using geam
|
||||
inline rocblas_status rocblasTransposeHelper(hipStream_t /*stream*/, rocblas_handle handle, rocblas_operation transa, rocblas_operation transb, int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, float* C, int ldc) {
|
||||
return rocblas_sgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
|
||||
|
|
@ -222,3 +311,4 @@ inline rocblas_status rocblasCopyHelper(hipStream_t /*stream*/, rocblas_handle h
|
|||
return rocblas_dcopy(handle, n, x, incx, y, incy);
|
||||
}
|
||||
rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle handle, int n, const half* x, int incx, half* y, int incy);
|
||||
rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy);
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ TEST(GemmOpTest, GemmNoTrans_double) {
|
|||
TestGemmNoTrans<double>();
|
||||
}
|
||||
|
||||
// Only CUDA kernel has float 16 support
|
||||
// Only the CUDA and ROCm kernels have float16 support
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
TEST(GemmOpTest, GemmNoTrans_f16) {
|
||||
#ifdef USE_CUDA
|
||||
|
|
@ -86,13 +86,15 @@ TEST(GemmOpTest, GemmNoTrans_f16) {
|
|||
}
|
||||
#endif
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
TEST(GemmOpTest, GemmNoTrans_bfloat16) {
|
||||
#ifdef USE_CUDA
|
||||
int min_cuda_architecture = 530;
|
||||
if (!HasCudaEnvironment(min_cuda_architecture)) {
|
||||
LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16";
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
OpTester test("Gemm", 14);
|
||||
test.AddAttribute("transA", (int64_t)0);
|
||||
test.AddAttribute("transB", (int64_t)0);
|
||||
|
|
@ -103,7 +105,11 @@ TEST(GemmOpTest, GemmNoTrans_bfloat16) {
|
|||
test.AddInput<BFloat16>("C", {2, 3}, MakeBFloat16({1.f, 1.f, 1.f, 1.f, 1.f, 1.f}));
|
||||
test.AddOutput<BFloat16>("Y", {2, 3}, MakeBFloat16({11.0f, 11.0f, 11.0f, -9.0f, -9.0f, -9.0f}));
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
#ifdef USE_CUDA
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
#elif USE_ROCM
|
||||
execution_providers.push_back(DefaultRocmExecutionProvider());
|
||||
#endif
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
#include "test/common/cuda_op_test_utils.h"
|
||||
#include "default_providers.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -162,6 +163,61 @@ TEST(MathOpTest, MatMulUint64Type) {
|
|||
RunMatMulTest<uint64_t>(9);
|
||||
}
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
// MatMul (opset 14) on MLFloat16 tensors: a 2x4 matrix multiplied by a 4x3 all-ones matrix,
// so every output element is the sum of the corresponding row of A (10 and -10).
// On CUDA builds the test is skipped when HasCudaEnvironment(530) reports false.
TEST(MathOpTest, MatMul_Float16) {
#ifdef USE_CUDA
  const int min_cuda_architecture = 530;
  if (!HasCudaEnvironment(min_cuda_architecture)) {
    LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16";
    return;
  }
#endif
  OpTester test("MatMul", 14);

  // Reference data in float; converted to MLFloat16 below.
  const std::vector<float> lhs_f32{1.0f, 2.0f, 3.0f, 4.0f,
                                   -1.0f, -2.0f, -3.0f, -4.0f};
  const std::vector<float> rhs_f32(12, 1.0f);
  const std::vector<float> expected_f32{10.0f, 10.0f, 10.0f,
                                        -10.0f, -10.0f, -10.0f};

  std::vector<MLFloat16> lhs_f16(8);
  ConvertFloatToMLFloat16(lhs_f32.data(), lhs_f16.data(), 8);

  std::vector<MLFloat16> rhs_f16(12);
  ConvertFloatToMLFloat16(rhs_f32.data(), rhs_f16.data(), 12);

  std::vector<MLFloat16> expected_f16(6);
  ConvertFloatToMLFloat16(expected_f32.data(), expected_f16.data(), 6);

  test.AddInput<MLFloat16>("A", {2, 4}, lhs_f16);
  test.AddInput<MLFloat16>("B", {4, 3}, rhs_f16);
  test.AddOutput<MLFloat16>("Y", {2, 3}, expected_f16);

  // TensorRT: fp16 is not supported, so that provider is excluded from the run.
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
#endif
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
// MatMul (opset 14) on BFloat16 tensors: a 2x4 matrix multiplied by a 4x3 all-ones matrix,
// so every output element is the sum of the corresponding row of A (10 and -10).
// The test runs only on the explicitly supplied execution provider: CUDA on CUDA builds,
// ROCm on ROCm builds.
TEST(MathOpTest, MatMul_BFloat16) {
#ifdef USE_CUDA
  // NOTE(review): 530 is the same threshold the FP16 tests use; confirm whether BFloat16
  // MatMul needs a higher minimum architecture on CUDA.
  int min_cuda_architecture = 530;
  if (!HasCudaEnvironment(min_cuda_architecture)) {
    // This test exercises BFloat16, so name that type in the skip message.
    LOGS_DEFAULT(WARNING) << "Hardware NOT support BFloat16";
    return;
  }
#endif
  OpTester test("MatMul", 14);

  test.AddInput<BFloat16>("A", {2, 4}, MakeBFloat16({1.0f, 2.0f, 3.0f, 4.0f, -1.0f, -2.0f, -3.0f, -4.0f}));
  test.AddInput<BFloat16>("B", {4, 3}, MakeBFloat16({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}));
  test.AddOutput<BFloat16>("Y", {2, 3}, MakeBFloat16({10.0f, 10.0f, 10.0f, -10.0f, -10.0f, -10.0f}));

  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
  // Use defined(...) so the branch depends on whether the macro exists, not on its value,
  // matching the guard that encloses this test.
#ifdef USE_CUDA
  execution_providers.push_back(DefaultCudaExecutionProvider());
#elif defined(USE_ROCM)
  execution_providers.push_back(DefaultRocmExecutionProvider());
#endif
  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
#endif
|
||||
|
||||
#ifndef ENABLE_TRAINING // Prepacking is enabled only on non-training builds
|
||||
TEST(MathOpTest, MatMulSharedPrepackedWeights) {
|
||||
OpTester test("MatMul");
|
||||
|
|
|
|||
Loading…
Reference in a new issue