[Fix Bug] Fp8*Fp8 Run Error (#20911)

Fix a runtime error in the fp8*fp8 GEMM that occurs when input A is e5m2 and input B is e4m3.

### Description
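The change in the diff can be summarized with a minimal sketch: the FP8 input cases and the `CUDA_R_32F` case now select `CUBLAS_COMPUTE_32F` instead of `CUBLAS_COMPUTE_32F_FAST_TF32`. The sketch uses stand-in enums so that it compiles without the CUDA toolkit. `DataType`, `ComputeType`, and `PickComputeType` are hypothetical names, not part of onnxruntime or cuBLAS, and the two switches touched by the diff are flattened into one for brevity.

```cpp
#include <stdexcept>

// Hypothetical stand-ins for cudaDataType_t and cublasComputeType_t,
// mirroring only the names that appear in the diff.
enum class DataType { R_16F, R_32F, R_8F_E4M3, R_8F_E5M2 };
enum class ComputeType { COMPUTE_32F, COMPUTE_32F_FAST_TF32 };

// Simplified selection covering only the cases changed by this commit.
// Before the change, all three cases returned COMPUTE_32F_FAST_TF32.
inline ComputeType PickComputeType(DataType type) {
  switch (type) {
    case DataType::R_8F_E4M3:
    case DataType::R_8F_E5M2:
      return ComputeType::COMPUTE_32F;  // was COMPUTE_32F_FAST_TF32
    case DataType::R_32F:
      return ComputeType::COMPUTE_32F;  // was COMPUTE_32F_FAST_TF32
    default:
      // Same message as the ORT_THROW in the diff; other cases are omitted here.
      throw std::runtime_error(
          "Unable to determine computeType in operator GemmFloat8.");
  }
}
```

With this sketch, `PickComputeType(DataType::R_8F_E5M2)` and `PickComputeType(DataType::R_8F_E4M3)` both yield `ComputeType::COMPUTE_32F`, so a mixed e5m2/e4m3 input pair resolves to the same compute type.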
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
KnightYao 2024-07-05 23:11:59 +08:00 committed by GitHub
parent 3f6b7430d6
commit 9ef28f092f


@@ -207,7 +207,7 @@ Status GemmFloat8::ComputeGemm(
#endif
case CUDA_R_8F_E4M3:
case CUDA_R_8F_E5M2:
-compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
+compute_type = CUBLAS_COMPUTE_32F;
break;
#endif
default:
@@ -219,7 +219,7 @@ Status GemmFloat8::ComputeGemm(
compute_type = CUBLAS_COMPUTE_32F_FAST_16BF;
break;
case CUDA_R_32F:
-compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
+compute_type = CUBLAS_COMPUTE_32F;
break;
default:
ORT_THROW("Unable to determine computeType in operator GemmFloat8.");