mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
When scaled_dot_product_attention is used with float16 inputs, the exported graph contains Sqrt(float16(constant)), which ORT cannot constant-fold because the Sqrt CPU kernel does not support float16. As a result, the Triton code generator emits code such as: result = 128.0.to(tl.float32). That code fails to compile because .to() cannot be applied to a constant. This PR handles that case by skipping the Cast when the operand is a constant number. |
||
|---|---|---|
| .. | ||
| orttraining | ||
| tools | ||
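
The constant-cast problem described in the commit message above can be illustrated with a minimal sketch. This is not the code from the PR: `emit_cast` and `is_numeric_literal` are hypothetical helper names, and the sketch assumes the code generator builds Triton source as strings. It only shows the idea of leaving numeric literals uncast while still casting named tensors.

```python
import re

# Hypothetical helper, not ORT's actual code generator.
# Matches plain decimal literals such as "128.0", "-3", or "1e-5".
_NUMERIC_LITERAL = re.compile(r"[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?")


def is_numeric_literal(operand: str) -> bool:
    """Return True if the operand text is a numeric constant, not a variable name."""
    return _NUMERIC_LITERAL.fullmatch(operand) is not None


def emit_cast(operand: str, target_dtype: str) -> str:
    """Return source text for a Cast of `operand` to `target_dtype`.

    A numeric literal is returned unchanged, because calling .to() on a
    Python constant is not valid in generated Triton code.
    """
    if is_numeric_literal(operand):
        return operand
    return f"{operand}.to({target_dtype})"


# A named intermediate value still gets the cast; a constant does not.
print(emit_cast("tmp_0", "tl.float32"))
print(emit_cast("128.0", "tl.float32"))
```

With this approach, the failing line from the commit message would be generated as result = 128.0 instead of result = 128.0.to(tl.float32).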