From ceb1e2b1a6eb4a14d37b202c8f25e666f5c95ac7 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Wed, 16 Feb 2022 11:11:08 +0800 Subject: [PATCH] [ROCm] Bugfix of BFloat16-float conversion and Add FastGelu Kernel for AMD (#10557) * bf16 bugfix on amd * enable fastgelu ut on amd --- include/onnxruntime/core/framework/float16.h | 22 +++++ .../contrib_ops/rocm/bert/fast_gelu_impl.cu | 10 +++ .../contrib_ops/rocm/rocm_contrib_kernels.cc | 4 +- .../test/contrib_ops/fastgelu_op_test.cc | 87 +++++++++++-------- 4 files changed, 86 insertions(+), 37 deletions(-) diff --git a/include/onnxruntime/core/framework/float16.h b/include/onnxruntime/core/framework/float16.h index 8de851d9ae..4b256b3137 100644 --- a/include/onnxruntime/core/framework/float16.h +++ b/include/onnxruntime/core/framework/float16.h @@ -50,6 +50,20 @@ struct BFloat16 { inline ORT_HOST_DEVICE BFloat16(float v) { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 val = __bfloat16_as_ushort(__float2bfloat16(v)); +#elif defined(USE_ROCM) + // We should be using memcpy in order to respect the strict aliasing rule but it fails in the HIP environment. + if (v != v) { // isnan + val = UINT16_C(0x7FC0); + } else { + union { + uint32_t U32; + float F32; + }; + + F32 = v; + uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + val = static_cast((U32 + rounding_bias) >> 16); + } #else ORT_IF_CONSTEXPR(endian::native == endian::little) { std::memcpy(&val, reinterpret_cast(&v) + sizeof(uint16_t), sizeof(uint16_t)); @@ -63,6 +77,14 @@ struct BFloat16 { inline ORT_HOST_DEVICE float ToFloat() const { #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 return __bfloat162float(*reinterpret_cast(&val)); +#elif defined(USE_ROCM) + // We should be using memcpy in order to respect the strict aliasing rule but it fails in the HIP environment. + float result = 0; + uint32_t tmp = val; + tmp <<= 16; + float* tempRes = reinterpret_cast(&tmp); + result = *tempRes; + return result; #else float result; char* const first = reinterpret_cast(&result); diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu index edab650cbe..f2c2abf950 100644 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu @@ -113,6 +113,16 @@ bool LaunchFastGeluKernel(const hipDeviceProp_t& prop, hipStream_t stream, int i return HIP_CALL(hipPeekAtLastError()); } +template <> +bool LaunchFastGeluKernel(const hipDeviceProp_t& prop, hipStream_t stream, int input_length, int bias_length, + const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) { + constexpr int blockSize = 256; + const int gridSize = (input_length + blockSize - 1) / blockSize; + hipLaunchKernelGGL(HIP_KERNEL_NAME(FastGeluKernel), dim3(gridSize), dim3(blockSize), 0, stream, + A, B, C, input_length, bias_length, input, bias, output); + return HIP_CALL(hipPeekAtLastError()); +} + } // namespace rocm } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 99c4755c36..ac2435baa0 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -80,7 +80,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention); -// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, BFloat16_float, LayerNormalization); @@ -166,7 +166,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // TransposedMatMul is still here for backward compatibility BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, diff --git a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc index 89ae377263..302226e107 100644 --- a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc @@ -40,41 +40,51 @@ const std::vector GetExpectedResult(const std::vector& input_data, return ComputeGelu(add_bias_data); } -static void RunFastGeluTest( - const std::vector& input_data, - const std::vector& bias_data, - const std::vector& output_data, - const std::vector& input_dims, - const std::vector& bias_dims, - const std::vector& output_dims, - bool has_bias = true, - bool use_float16 = false) { +#if defined(USE_CUDA) || defined(USE_ROCM) +static void RunFastGeluGpuTest(const std::vector& input_data, const std::vector& bias_data, + const std::vector& output_data, const std::vector& input_dims, + const std::vector& bias_dims, const std::vector& output_dims, + bool has_bias = true, bool use_float16 = false) { +#ifdef USE_CUDA // Test CUDA operator. int min_cuda_architecture = use_float16 ? 530 : 0; - if (HasCudaEnvironment(min_cuda_architecture)) { - OpTester tester("FastGelu", 1, onnxruntime::kMSDomain); + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16"; + return; + } +#endif + OpTester tester("FastGelu", 1, onnxruntime::kMSDomain); - if (use_float16) { - tester.AddInput("X", input_dims, ToFloat16(input_data)); - if (has_bias) { - tester.AddInput("bias", bias_dims, ToFloat16(bias_data)); - } - tester.AddOutput("Y", output_dims, ToFloat16(output_data)); - } else { - tester.AddInput("X", input_dims, input_data); - if (has_bias) { - tester.AddInput("bias", bias_dims, bias_data); - } - tester.AddOutput("Y", output_dims, output_data); + if (use_float16) { + tester.AddInput("X", input_dims, ToFloat16(input_data)); + if (has_bias) { + tester.AddInput("bias", bias_dims, ToFloat16(bias_data)); } - - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + tester.AddOutput("Y", output_dims, ToFloat16(output_data)); + } else { + tester.AddInput("X", input_dims, input_data); + if (has_bias) { + tester.AddInput("bias", bias_dims, bias_data); + } + tester.AddOutput("Y", output_dims, output_data); } + std::vector> execution_providers; +#ifdef USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} +#endif + +static void RunFastGeluCpuTest(const std::vector& input_data, const std::vector& bias_data, + const std::vector& output_data, const std::vector& input_dims, + const std::vector& bias_dims, const std::vector& output_dims, + bool has_bias = true) { // Test CPU operator: only float32 is implemented for FastGelu CPU. - if (nullptr != DefaultCpuExecutionProvider().get() && !use_float16) { + if (nullptr != DefaultCpuExecutionProvider().get()) { OpTester tester("FastGelu", 1, onnxruntime::kMSDomain); tester.AddInput("X", input_dims, input_data); @@ -107,7 +117,10 @@ static void RunFastGeluTest( std::vector input_dims = {batch_size, sequence_length, hidden_size}; std::vector bias_dims = {hidden_size}; std::vector output_dims = input_dims; - RunFastGeluTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); +#if defined(USE_CUDA) || defined(USE_ROCM) + RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); +#endif + RunFastGeluCpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); } TEST(FastGeluTest, FastGeluWithNullInput) { @@ -152,6 +165,8 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat32) { RunFastGeluTest(input_data, bias_data, batch_size, sequence_length, hidden_size); } +// CUDA and ROCm only for Float16 and BFloat16 type. +#if defined(USE_CUDA) || defined(USE_ROCM) TEST(FastGeluTest, FastGeluWithBiasFloat16) { int batch_size = 1; int sequence_length = 2; @@ -172,7 +187,7 @@ TEST(FastGeluTest, FastGeluWithBiasFloat16) { std::vector bias_dims = {hidden_size}; std::vector output_dims = input_dims; - RunFastGeluTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, true, true); + RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, true, true); } TEST(FastGeluTest, FastGeluWithoutBiasFloat16) { @@ -194,17 +209,17 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16) { std::vector bias_dims = {}; std::vector output_dims = input_dims; - RunFastGeluTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, false, true); + RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, false, true); } -// CUDA only, ROCM has not been supported yet -#ifdef USE_CUDA TEST(FastGeluTest, FastGeluWithBias_BFloat16) { +#ifdef USE_CUDA int min_cuda_architecture = 530; if (!HasCudaEnvironment(min_cuda_architecture)) { LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16"; return; } +#endif OpTester tester("FastGelu", 1, onnxruntime::kMSDomain); int batch_size = 1; @@ -235,12 +250,14 @@ TEST(FastGeluTest, FastGeluWithBias_BFloat16) { tester.AddOutput("Y", output_dims, f_Y); std::vector> execution_providers; +#ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } #endif - - } // namespace test } // namespace onnxruntime