mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
[Fix Bug] Fp8*Fp8 Run Error (#20911)
Fix fp8*fp8 when input A is e5m2, input B is e4m3 will run error ### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
3f6b7430d6
commit
9ef28f092f
1 changed files with 2 additions and 2 deletions
|
|
@ -207,7 +207,7 @@ Status GemmFloat8::ComputeGemm(
|
|||
#endif
|
||||
case CUDA_R_8F_E4M3:
|
||||
case CUDA_R_8F_E5M2:
|
||||
compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||
compute_type = CUBLAS_COMPUTE_32F;
|
||||
break;
|
||||
#endif
|
||||
default:
|
||||
|
|
@ -219,7 +219,7 @@ Status GemmFloat8::ComputeGemm(
|
|||
compute_type = CUBLAS_COMPUTE_32F_FAST_16BF;
|
||||
break;
|
||||
case CUDA_R_32F:
|
||||
compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||
compute_type = CUBLAS_COMPUTE_32F;
|
||||
break;
|
||||
default:
|
||||
ORT_THROW("Unable to determine computeType in operator GemmFloat8.");
|
||||
|
|
|
|||
Loading…
Reference in a new issue