From 9ef28f092f575ce15bfc40c7663e649e75809c53 Mon Sep 17 00:00:00 2001 From: KnightYao Date: Fri, 5 Jul 2024 23:11:59 +0800 Subject: [PATCH] [Fix Bug] Fp8*Fp8 Run Error (#20911) Fix a runtime error in the fp8*fp8 GEMM that occurs when input A is e5m2 and input B is e4m3, by using CUBLAS_COMPUTE_32F instead of CUBLAS_COMPUTE_32F_FAST_TF32 as the compute type. ### Description ### Motivation and Context --- onnxruntime/contrib_ops/cuda/math/gemm_float8.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index 28ab27ee33..07c5de2fe8 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -207,7 +207,7 @@ Status GemmFloat8::ComputeGemm( #endif case CUDA_R_8F_E4M3: case CUDA_R_8F_E5M2: - compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + compute_type = CUBLAS_COMPUTE_32F; break; #endif default: @@ -219,7 +219,7 @@ Status GemmFloat8::ComputeGemm( compute_type = CUBLAS_COMPUTE_32F_FAST_16BF; break; case CUDA_R_32F: - compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + compute_type = CUBLAS_COMPUTE_32F; break; default: ORT_THROW("Unable to determine computeType in operator GemmFloat8.");