diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index 28ab27ee33..07c5de2fe8 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -207,7 +207,7 @@ Status GemmFloat8::ComputeGemm( #endif case CUDA_R_8F_E4M3: case CUDA_R_8F_E5M2: - compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + compute_type = CUBLAS_COMPUTE_32F; break; #endif default: @@ -219,7 +219,7 @@ Status GemmFloat8::ComputeGemm( compute_type = CUBLAS_COMPUTE_32F_FAST_16BF; break; case CUDA_R_32F: - compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + compute_type = CUBLAS_COMPUTE_32F; break; default: ORT_THROW("Unable to determine computeType in operator GemmFloat8.");