mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
### Description

Based on https://github.com/microsoft/onnxruntime/pull/9700, extended to cover `ArgMin` as well.

This pull request introduces several enhancements and fixes for the `ArgMax` and `ArgMin` operators in the CUDA execution provider. The changes ensure that these operators are handled correctly across different versions and improve the kernel registration and fallback mechanisms. Key changes include:

#### Enhancements to `ArgMax` and `ArgMin` Operators:

* Added new kernel class registrations for `ArgMax` and `ArgMin` for different data types and versions in `onnxruntime/core/providers/cuda/cuda_execution_provider.cc`.
* Introduced the `ArgMaxOrArgMinNeedFallbackToCPU` function to fall back to CPU when the `select_last_index` attribute is set to 1, as CUDA does not support this attribute.

#### Macro and Kernel Registration Improvements:

* Replaced `REGISTER_KERNEL_UNTIL_VERSIONED_TYPED` with the `REGISTER_KERNEL_VERSIONED_RANGE_TYPED` and `REGISTER_KERNEL_VERSIONED_SINCE_TYPED` macros for better version handling.
* Updated the kernel registration for `ArgMax` and `ArgMin` to use the new macros, ensuring proper version handling and support for different data types.

#### Safety Checks:

* Added safety checks in the `ArgMax` and `ArgMin` classes to ensure `select_last_index` is not set to 1, as it is not supported on CUDA.

#### Testing Enhancements:

* Added new tests for the `ArgMax` and `ArgMin` operators to verify behavior when `select_last_index` is set to 0, ensuring compatibility with both the CPU and CUDA execution providers.

### Motivation and Context

Improve CUDA kernel coverage for the Stable Diffusion model and hence improve its performance on CUDA.

| | | |
|---|---|---|
| .. | ||
| c_cxx | ||
| execution_providers/images | ||
| images | ||
| python | ||
| ABI_Dev_Notes.md | ||
| Android_testing.md | ||
| C_API_Guidelines.md | ||
| cmake_guideline.md | ||
| Coding_Conventions_and_Standards.md | ||
| ContribOperators.md | ||
| FAQ.md | ||
| How_To_Update_ONNX_Dev_Notes.md | ||
| Memory_Optimizer.md | ||
| Model_Test.md | ||
| NotesOnThreading.md | ||
| ONNX_Runtime_Server_Usage.md | ||
| onnxruntime_dependencies.dot | ||
| onnxruntime_dependencies.png | ||
| onnxruntime_extensions.md | ||
| OperatorKernels.md | ||
| ORT_Format_Update_in_1.13.md | ||
| ORT_Use_Triton_Kernel.md | ||
| ORTModule_Convergence_Notes.md | ||
| ORTModule_ModuleWithLoss_Wrapper.md | ||
| ORTModule_PythonOp_Notes.md | ||
| ORTModule_Training_Guidelines.md | ||
| PR_Guidelines.md | ||
| Privacy.md | ||
| Reduced_Operator_Kernel_build.md | ||
| ReleaseManagement.md | ||
| Roadmap.md | ||
| Server.md | ||
| TVM_EP.md | ||
| Versioning.md | ||
| WinML_principles.md | ||
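
The `select_last_index` behavior and the CPU fallback mentioned in the commit description above can be illustrated with a minimal Python sketch. This is not the actual C++ code in the CUDA execution provider: `arg_extreme` and `needs_cpu_fallback` are hypothetical names introduced only for illustration, and the sketch assumes a 1-D input and the ONNX default of `select_last_index = 0` when the attribute is absent.

```python
from typing import Sequence


def arg_extreme(values: Sequence[float], find_max: bool = True,
                select_last_index: int = 0) -> int:
    """Reference 1-D ArgMax/ArgMin following the ONNX attribute semantics.

    With select_last_index=0 ties resolve to the first occurrence;
    with select_last_index=1 they resolve to the last occurrence.
    """
    if not values:
        raise ValueError("values must be non-empty")
    best = max(values) if find_max else min(values)
    indices = [i for i, v in enumerate(values) if v == best]
    return indices[-1] if select_last_index else indices[0]


def needs_cpu_fallback(op_type: str, attributes: dict) -> bool:
    """Hypothetical stand-in for the fallback predicate (not the real C++ code).

    Returns True when an ArgMax/ArgMin node asks for select_last_index=1,
    which the commit description says CUDA does not support.
    """
    if op_type not in ("ArgMax", "ArgMin"):
        return False
    # Assumption: a missing attribute means the ONNX default of 0.
    return attributes.get("select_last_index", 0) == 1


data = [3.0, 7.0, 1.0, 7.0, 1.0]
first_max = arg_extreme(data, find_max=True, select_last_index=0)
last_max = arg_extreme(data, find_max=True, select_last_index=1)
first_min = arg_extreme(data, find_max=False, select_last_index=0)
last_min = arg_extreme(data, find_max=False, select_last_index=1)
```

The two settings only differ when the extreme value occurs more than once, which is why a node with `select_last_index = 0` can stay on CUDA while a node with `select_last_index = 1` is routed to the CPU kernel.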