onnxruntime/orttraining
Vincent Wang 9f68a27c7a
[ORTModule] Handle Cast on Constant Number on Triton Code-gen (#19321)
When scaled_dot_product_attention is used with the float16 type, the
exported graph contains Sqrt(float16(constant)), which ORT cannot
constant-fold because the Sqrt CPU kernel doesn't support float16. As a
result, the Triton code-gen generates code like:

result = 128.0.to(tl.float32)

This code cannot be compiled because .to() cannot be applied to a
constant number.

This PR handles this case by skipping the Cast when its input is a
constant number.
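
The idea can be sketched in plain Python as follows. This is a minimal
illustration, not the code from this PR: the helper names
(is_constant_number, gen_cast) and the string-based representation of
expressions are assumptions made for the example, and the real Triton
code-gen in ORTModule may detect constants differently.

```python
import re

# Hypothetical pattern for a numeric literal such as "128.0", "-2", or "1e-3".
_NUMBER = re.compile(r"^[+-]?(\d+\.?\d*|\.\d+)([eE][+-]?\d+)?$")


def is_constant_number(expr: str) -> bool:
    """Return True if the expression text is a plain numeric literal."""
    return _NUMBER.match(expr.strip()) is not None


def gen_cast(output: str, input_expr: str, triton_dtype: str) -> str:
    """Emit one line of generated code for a Cast node (illustrative only)."""
    if is_constant_number(input_expr):
        # A literal has no .to() method, so emit it unchanged and skip the Cast.
        return f"{output} = {input_expr}"
    # Tensor-valued expressions keep the explicit conversion.
    return f"{output} = {input_expr}.to({triton_dtype})"


print(gen_cast("result", "128.0", "tl.float32"))
print(gen_cast("result", "t0", "tl.float32"))
```

With this sketch, a constant input yields "result = 128.0" instead of
the uncompilable "result = 128.0.to(tl.float32)", while a tensor input
such as t0 still gets "result = t0.to(tl.float32)".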
2024-01-30 17:04:01 +08:00
orttraining [ORTModule] Handle Cast on Constant Number on Triton Code-gen (#19321) 2024-01-30 17:04:01 +08:00
tools [ROCm] Update CI/Packaging pipeline to ROCm6.0 (#18985) 2024-01-03 17:25:15 +08:00