ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
Tianlei Wu ba22d7879a
[CUDA/ROCm] Conditionally support ArgMax and ArgMin for opset 12 and above (#22713)
### Description
Based on https://github.com/microsoft/onnxruntime/pull/9700, extended to
cover ArgMin as well.

This pull request introduces several enhancements and fixes related to
the `ArgMax` and `ArgMin` operators in the CUDA execution provider. The
changes ensure proper handling of these operators across different
opset versions and improve kernel registration and the CPU fallback
mechanism.

Key changes include:

#### Enhancements to `ArgMax` and `ArgMin` Operators:

* Added new kernel class registrations for `ArgMax` and `ArgMin` for
different data types and versions in
`onnxruntime/core/providers/cuda/cuda_execution_provider.cc`.

* Introduced `ArgMaxOrArgMinNeedFallbackToCPU` function to handle
fallback to CPU when the `select_last_index` attribute is set to 1, as
CUDA does not support this attribute.
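
To make the fallback rule above concrete, here is a minimal pure-Python sketch of what `select_last_index` means and of the kind of predicate that decides whether a node should be left to the CPU. The names `arg_extreme_1d` and `needs_cpu_fallback` are hypothetical and are not the ONNX Runtime implementation; only the attribute name and its default of 0 come from the ONNX operator definition.

```python
def arg_extreme_1d(values, mode="max", select_last_index=0):
    """Reference ArgMax/ArgMin over a non-empty 1-D sequence."""
    target = max(values) if mode == "max" else min(values)
    hits = [i for i, v in enumerate(values) if v == target]
    # select_last_index=0 returns the first occurrence, 1 returns the last.
    return hits[-1] if select_last_index else hits[0]


def needs_cpu_fallback(op_type, attributes):
    """Hypothetical predicate: True when the node should run on CPU instead."""
    if op_type not in ("ArgMax", "ArgMin"):
        return False
    # ONNX defaults select_last_index to 0 when the attribute is absent.
    return attributes.get("select_last_index", 0) == 1
```

With ties in the input, the two settings give different indices, which is why a kernel that only implements first-occurrence semantics cannot silently accept `select_last_index=1`.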

#### Macro and Kernel Registration Improvements:

* Replaced `REGISTER_KERNEL_UNTIL_VERSIONED_TYPED` with
`REGISTER_KERNEL_VERSIONED_RANGE_TYPED` and
`REGISTER_KERNEL_VERSIONED_SINCE_TYPED` macros for better version
handling.

* Updated kernel registration for `ArgMax` and `ArgMin` to use the new
macros, ensuring proper version handling and support for different data
types.
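
The difference between a versioned-range registration and a since-version registration can be sketched as a small lookup table. This is an illustrative Python model only: `KernelEntry`, `find_kernel`, and the opset ranges in `REGISTRY` are assumptions made for the example and are not copied from the ONNX Runtime macros or source.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class KernelEntry:
    op_type: str
    start: int            # first opset version this kernel handles
    end: Optional[int]    # last opset version, or None for "since start"

    def covers(self, op_type: str, opset: int) -> bool:
        if op_type != self.op_type or opset < self.start:
            return False
        return self.end is None or opset <= self.end


def find_kernel(registry, op_type, opset):
    """Return the single entry covering (op_type, opset), or None."""
    matches = [k for k in registry if k.covers(op_type, opset)]
    if len(matches) > 1:
        raise ValueError(f"overlapping registrations for {op_type} at opset {opset}")
    return matches[0] if matches else None


# Illustrative ranges only; not copied from the ONNX Runtime source.
REGISTRY = [
    KernelEntry("ArgMax", 1, 10),      # closed range
    KernelEntry("ArgMax", 11, 11),     # closed range
    KernelEntry("ArgMax", 12, 12),     # closed range
    KernelEntry("ArgMax", 13, None),   # open-ended "since" registration
]
```

Closed ranges keep older opsets pinned to the kernel that matches their semantics, while the single open-ended entry picks up every later opset until a new registration splits it.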

#### Safety Checks:

* Added safety checks in the `ArgMax` and `ArgMin` classes to ensure
`select_last_index` is not set to 1, as it is not supported on CUDA.
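
A construction-time check of this kind can be sketched as follows. The class `ArgMaxSketch` is hypothetical and written in Python for brevity; the actual kernels are C++ classes, and only the attribute names and their ONNX defaults (`axis=0`, `keepdims=1`, `select_last_index=0`) are taken from the operator definition.

```python
class ArgMaxSketch:
    """Hypothetical kernel wrapper that rejects unsupported attributes up front."""

    def __init__(self, attributes):
        self.axis = attributes.get("axis", 0)
        self.keepdims = attributes.get("keepdims", 1)
        select_last_index = attributes.get("select_last_index", 0)
        # Fail at construction rather than produce first-occurrence results
        # for a node that asked for last-occurrence semantics.
        if select_last_index != 0:
            raise ValueError("select_last_index=1 is not supported by this kernel")
```

The check is a second line of defense: even if a node with `select_last_index=1` were assigned to this kernel despite the fallback rule, it would fail loudly instead of returning wrong indices.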

#### Testing Enhancements:

* Added new tests for `ArgMax` and `ArgMin` operators to verify behavior
when `select_last_index` is set to 0, ensuring compatibility with both
CPU and CUDA execution providers.

### Motivation and Context
Improve CUDA kernel coverage for the Stable Diffusion model and thereby
improve its performance on CUDA.
2024-11-06 09:54:32 -08:00
.config Add an 1ES PT baseline file (#22587) 2024-10-25 09:18:30 -07:00
.devcontainer
.gdn
.github [CI] Set up proper permissions for linting workflow (#22696) 2024-11-01 18:14:52 -07:00
.pipelines [DML EP] Update DML to 1.15.4 (#22635) 2024-10-29 17:13:57 -07:00
.vscode Stop VSCode appending file associations to settings.json (#21944) 2024-08-31 19:04:12 -07:00
cgmanifests Remove nsync (#20413) 2024-10-21 15:32:14 -07:00
cmake Refactor the cmake code that is related to delay loading (#22646) 2024-11-04 16:30:50 -08:00
csharp [C# MauiModelTester] Fix icon name in Info.plist (#21666) 2024-11-05 16:55:38 -08:00
dockerfiles [ROCm] Python 3.10 in ROCm CI, and ROCm 6.2.3 in MigraphX CI (#22527) 2024-10-25 11:47:16 -07:00
docs [CUDA/ROCm] Conditionally support ArgMax and ArgMin for opset 12 and above (#22713) 2024-11-06 09:54:32 -08:00
include/onnxruntime/core [CoreML] ML Program more ops (2/N) (#22480) 2024-11-01 08:37:56 +08:00
java Build CUDA and DML together (#22602) 2024-10-31 15:51:13 -07:00
js [WebNN EP] Fix issues with MLTensor caching (#22701) 2024-11-06 09:17:11 -08:00
objectivec [CoreML ML Program] support acclerators selector (#22383) 2024-10-15 11:50:11 +08:00
onnxruntime [CUDA/ROCm] Conditionally support ArgMax and ArgMin for opset 12 and above (#22713) 2024-11-06 09:54:32 -08:00
orttraining enable serialize prepacked weights into data file (#22256) 2024-10-24 22:24:48 -07:00
rust Fix typos according to reviewdog report. (#21335) 2024-07-22 13:37:32 -07:00
samples
tools Enable CUDA Python Test (#22717) 2024-11-05 16:26:50 -08:00
winml Fix warnings (#21809) 2024-08-21 14:23:37 -07:00
.clang-format
.clang-tidy
.dockerignore
.gitattributes Fix typos according to reviewdog report. (#21335) 2024-07-22 13:37:32 -07:00
.gitignore
.gitmodules Revert "Upgrade emsdk from 3.1.59 to 3.1.62" (#21817) 2024-08-22 11:21:00 -07:00
.lintrunner.toml [js] change default formatter for JavaScript/TypeScript from clang-format to Prettier (#21728) 2024-08-14 16:51:22 -07:00
build.bat
build.sh
build_arm64x.bat
CITATION.cff Fix citation author name issue (#19597) 2024-02-22 17:03:56 -08:00
CODEOWNERS
CONTRIBUTING.md
lgtm.yml
LICENSE
NuGet.config Update C# test projects (#21631) 2024-09-05 08:21:23 +10:00
ort.wprp Fully dynamic ETW controlled logging for ORT and QNN logs (#20537) 2024-06-06 21:11:14 -07:00
ORT_icon_for_light_bg.png
packages.config [DML EP] Update DML to 1.15.4 (#22635) 2024-10-29 17:13:57 -07:00
pyproject.toml Ignore ruff rule N813 (#21477) 2024-07-24 17:48:22 -07:00
README.md Update README.md with release roadmap info (#22486) 2024-10-18 11:00:43 -07:00
requirements-dev.txt
requirements-doc.txt
requirements-lintrunner.txt Update lintrunner requirements (#22185) 2024-09-23 18:27:16 -07:00
requirements-training.txt
requirements.txt Add compatibility for NumPy 2.0 (#21085) 2024-06-27 13:50:53 -07:00
SECURITY.md
setup.py Update CMake to 3.31.0rc1 (#22433) 2024-10-16 11:50:13 -07:00
ThirdPartyNotices.txt Remove nsync (#20413) 2024-10-21 15:32:14 -07:00
VERSION_NUMBER bumps up version in main from 1.20 -> 1.21 (#22482) 2024-10-17 12:32:35 -07:00

ONNX Runtime is a cross-platform inference and training machine-learning accelerator.

ONNX Runtime inference can enable faster customer experiences and lower costs, supporting models from deep learning frameworks such as PyTorch and TensorFlow/Keras as well as classical machine learning libraries such as scikit-learn, LightGBM, and XGBoost. ONNX Runtime is compatible with different hardware, drivers, and operating systems, and provides optimal performance by leveraging hardware accelerators where applicable alongside graph optimizations and transforms.

ONNX Runtime training can reduce model training time on multi-node NVIDIA GPUs for transformer models with a one-line addition to existing PyTorch training scripts.

Get Started & Resources

Builtin Pipeline Status

[Build status badges: inference and training pipelines for Windows, Linux, Mac, Android, iOS, Web, and Other]

This project is tested with BrowserStack.

Third-party Pipeline Status

[Build status badge: Linux]

Releases

The current release and past releases can be found here: https://github.com/microsoft/onnxruntime/releases.

For details on the upcoming release, including release dates, announcements, features, and guidance on submitting feature requests, please visit the release roadmap: https://onnxruntime.ai/roadmap.

Data/Telemetry

Windows distributions of this project may collect usage data and send it to Microsoft to help improve our products and services. See the privacy statement for more details.

Contributions and Feedback

We welcome contributions! Please see the contribution guidelines.

For feature requests or bug reports, please file a GitHub Issue.

For general discussion or questions, please use GitHub Discussions.

Code of Conduct

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.

License

This project is licensed under the MIT License.