From 07317316ccbbbacc4e8aa6fd50ba8563a482e226 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Thu, 12 Oct 2023 19:13:53 -0700 Subject: [PATCH] CUDA EP vs ROCM EP hipify audit (#17776) Migrate most CUDA EP improvements and changes to ROCM EP. The process involves using hipify against all CUDA EP files (i.e. do not exclude any files from onnxruntime_rocm_hipify.cmake) then vimdiff compare them against the ROCM EP files that are under source control and pull in most changes. These changes include functional as well as formatting and makes comparing CUDA EP and ROCM EP easier, though it makes the PR diff somewhat less obvious due to formatting changes. - hipify audit of onnxruntime/core/providers/rocm, enable ops - Loop - Scan - hipify audit of onnxruntime/contrib_ops/rocm - fix contrib ops search implementation - enable more contrib ops - Affine - ComplexMul - ConvTransposeWithDynamicPads - Crop - DynamicSlice - FFT [Rfft, Irfft] - GreedySearch - ImageScaler - ParametricSoftplus - ScaledTanh - ThresholdRelu --------- Co-authored-by: cloudhan --- cmake/onnxruntime_providers_migraphx.cmake | 4 +- cmake/onnxruntime_providers_rocm.cmake | 5 +- cmake/onnxruntime_rocm_hipify.cmake | 27 - .../cuda/transformers/greedy_search.cc | 2 + .../contrib_ops/rocm/rocm_contrib_kernels.cc | 146 +- .../providers/cuda/cuda_provider_factory.cc | 2 +- .../core/providers/cuda/nn/conv_transpose.cc | 3 +- .../core/providers/rocm/cu_inc/common.cuh | 20 +- onnxruntime/core/providers/rocm/fpgeneric.cu | 4 +- .../core/providers/rocm/gpu_data_transfer.cc | 36 +- .../core/providers/rocm/gpu_data_transfer.h | 4 +- .../core/providers/rocm/integer_gemm.cc | 21 +- onnxruntime/core/providers/rocm/math/einsum.h | 5 +- .../math/einsum_utils/einsum_auxiliary_ops.h | 13 +- .../core/providers/rocm/math/softmax.cc | 23 +- onnxruntime/core/providers/rocm/nn/conv.cc | 44 +- onnxruntime/core/providers/rocm/nn/conv.h | 11 +- .../providers/rocm/reduction/reduction_ops.cc | 65 +- .../core/providers/rocm/rocm_allocator.cc | 5 +- .../core/providers/rocm/rocm_allocator.h | 3 +- onnxruntime/core/providers/rocm/rocm_call.cc | 2 +- .../providers/rocm/rocm_execution_provider.cc | 2039 ++++++++--------- .../providers/rocm/rocm_execution_provider.h | 25 +- .../rocm/rocm_execution_provider_info.cc | 6 +- onnxruntime/core/providers/rocm/rocm_fwd.h | 13 - onnxruntime/core/providers/rocm/rocm_kernel.h | 58 +- .../providers/rocm/rocm_provider_factory.cc | 67 +- .../providers/rocm/rocm_provider_factory.h | 12 +- .../core/providers/rocm/rocm_stream_handle.cc | 36 +- .../core/providers/rocm/rocm_stream_handle.h | 9 +- onnxruntime/core/providers/rocm/rocm_utils.cu | 9 +- .../providers/rocm/shared_inc/fast_divmod.h | 90 - .../providers/rocm/shared_inc/rocm_call.h | 4 + .../test/contrib_ops/element_wise_ops_test.cc | 34 +- onnxruntime/test/contrib_ops/fft_op_test.cc | 24 +- .../test/contrib_ops/greedy_search_test.cc | 44 +- .../providers/cpu/controlflow/loop_test.cc | 8 +- .../providers/cpu/controlflow/scan_test.cc | 10 +- tools/ci_build/amd_hipify.py | 20 + 39 files changed, 1493 insertions(+), 1460 deletions(-) delete mode 100644 onnxruntime/core/providers/rocm/rocm_fwd.h delete mode 100644 onnxruntime/core/providers/rocm/shared_inc/fast_divmod.h diff --git a/cmake/onnxruntime_providers_migraphx.cmake b/cmake/onnxruntime_providers_migraphx.cmake index 632600288b..91ac66a407 100644 --- a/cmake/onnxruntime_providers_migraphx.cmake +++ b/cmake/onnxruntime_providers_migraphx.cmake @@ -42,7 +42,7 @@ onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) add_dependencies(onnxruntime_providers_migraphx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) - target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}) + target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime) set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX) set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime") target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1) @@ -72,4 +72,4 @@ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - ) \ No newline at end of file + ) diff --git a/cmake/onnxruntime_providers_rocm.cmake b/cmake/onnxruntime_providers_rocm.cmake index 2bb4c7d600..b662682915 100644 --- a/cmake/onnxruntime_providers_rocm.cmake +++ b/cmake/onnxruntime_providers_rocm.cmake @@ -10,6 +10,7 @@ find_package(hiprand REQUIRED) find_package(rocblas REQUIRED) find_package(MIOpen REQUIRED) + find_package(hipfft REQUIRED) # MIOpen version if(NOT DEFINED ENV{MIOPEN_PATH}) @@ -48,7 +49,7 @@ find_library(RCCL_LIB rccl REQUIRED) find_library(ROCTRACER_LIB roctracer64 REQUIRED) - set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen ${RCCL_LIB} ${ROCTRACER_LIB}) + set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${RCCL_LIB} ${ROCTRACER_LIB}) file(GLOB_RECURSE onnxruntime_providers_rocm_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/rocm/*.h" @@ -219,4 +220,4 @@ install(TARGETS onnxruntime_providers_rocm ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) \ No newline at end of file + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 55d03c1427..6bab3babab 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -48,15 +48,6 @@ set(contrib_ops_excluded_files "diffusion/group_norm_impl.cu" "diffusion/group_norm_impl.h" "diffusion/nhwc_conv.cc" - "math/complex_mul.cc" - "math/complex_mul.h" - "math/complex_mul_impl.cu" - "math/complex_mul_impl.h" - "math/cufft_plan_cache.h" - "math/fft_ops.cc" - "math/fft_ops.h" - "math/fft_ops_impl.cu" - "math/fft_ops_impl.h" "quantization/attention_quantization.cc" "quantization/attention_quantization.h" "quantization/attention_quantization_impl.cu" @@ -86,19 +77,6 @@ set(contrib_ops_excluded_files "quantization/qordered_ops/qordered_unary_ops.cc" "quantization/qordered_ops/qordered_unary_ops_impl.h" "quantization/qordered_ops/qordered_unary_ops_impl.cu" - "tensor/crop.cc" - "tensor/crop.h" - "tensor/crop_impl.cu" - "tensor/crop_impl.h" - "tensor/dynamicslice.cc" - "tensor/image_scaler.cc" - "tensor/image_scaler.h" - "tensor/image_scaler_impl.cu" - "tensor/image_scaler_impl.h" - "transformers/greedy_search.cc" - "transformers/greedy_search.h" - "conv_transpose_with_dynamic_pads.cc" - "conv_transpose_with_dynamic_pads.h" "cuda_contrib_kernels.cc" "cuda_contrib_kernels.h" "inverse.cc" @@ -119,10 +97,6 @@ endif() set(provider_excluded_files "atomic/common.cuh" - "controlflow/loop.cc" - "controlflow/loop.h" - "controlflow/scan.cc" - "controlflow/scan.h" "cu_inc/common.cuh" "math/einsum_utils/einsum_auxiliary_ops.cc" "math/einsum_utils/einsum_auxiliary_ops.h" @@ -170,7 +144,6 @@ set(provider_excluded_files "cuda_memory_check.h" "cuda_fence.cc" "cuda_fence.h" - "cuda_fwd.h" "cuda_kernel.h" "cuda_pch.cc" "cuda_pch.h" diff --git a/onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc b/onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc index d9014ca8f5..812ab0b1bc 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc @@ -48,10 +48,12 @@ GreedySearch::GreedySearch(const OpKernelInfo& info) SetConsoleDumper(&g_cuda_dumper_greedysearch); +#ifndef USE_ROCM cuda_device_prop_ = &reinterpret_cast(info.GetExecutionProvider())->GetDeviceProp(); cuda_device_arch_ = static_cast(cuda_device_prop_)->major * 100 + static_cast(cuda_device_prop_)->minor * 10; +#endif } Status GreedySearch::ComputeInternal(OpKernelContext* context) const { diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 7bc0f99081..0f8fe68de7 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -29,6 +29,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GatedRelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GatedRelativePositionBias); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RemovePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RemovePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, RestorePadding); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, RestorePadding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Rfft); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Rfft); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Rfft); @@ -52,6 +60,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Affine); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Attention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Attention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, PackedMultiHeadAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, PackedMultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BeamSearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ConvTransposeWithDynamicPads); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, Crop); @@ -61,12 +73,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int64_t, DynamicSlice); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, EmbedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GreedySearch); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, GroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, NhwcConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, NhwcConv); @@ -113,6 +124,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLayerNormalization); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedGelu); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QuantizeWithOrder); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, DequantizeWithOrder); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedAttention); +// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedLongformerAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedSelfAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GemmFastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GemmFastGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GemmFastGelu); @@ -139,6 +161,7 @@ KernelCreateInfo BuildKernelCreateInfo() { return info; } +// clang-format off Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing @@ -162,70 +185,73 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // These ops were experimental ops in onnx domain which have been removed now. We add them here as // contrib ops to maintain backward compatibility - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -238,7 +264,6 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -249,16 +274,25 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo - + // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // TransposedMatMul is still here for backward compatibility BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -278,6 +312,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, #endif + }; for (auto& function_table_entry : function_table) { @@ -289,6 +324,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { return Status::OK(); } +// clang-format on } // namespace rocm } // namespace contrib diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 5a11f2529f..afdccb0a15 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -243,7 +243,7 @@ struct CUDA_Provider : Provider { cuda_options.arena_extend_strategy = internal_options.arena_extend_strategy; cuda_options.do_copy_in_default_stream = internal_options.do_copy_in_default_stream; cuda_options.has_user_compute_stream = internal_options.has_user_compute_stream; - // The 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance can be set byC API UpdateCUDAProviderOptionsWithValue() as well. + // The 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance can be set by C API UpdateCUDAProviderOptionsWithValue() as well. // We only set the 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance if it is provided in options if (options.find("has_user_compute_stream") != options.end()) { cuda_options.user_compute_stream = internal_options.user_compute_stream; diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 2b3326c528..04f6bc46dc 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -102,8 +102,9 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ } s_.y_dims = gsl::make_span(y_dims); - if (w_dims_changed) + if (w_dims_changed) { ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); + } // Special case when there is a dim value of 0 in the shape. // Return only after we have cached the following for subsequent runs : diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index 429ceb1f7c..5f966ac746 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh @@ -2,8 +2,6 @@ // Licensed under the MIT License. #pragma once -#include -#include #include #include #include @@ -294,6 +292,14 @@ __device__ __inline__ T _Gelu(T a) { return a * _Normcdf(a); } +template <> +__device__ __inline__ half _Gelu(half a) { + const half kHalf = half(0.5); + const half kOne = half(1.0); + const half kAlpha = half(M_SQRT1_2); + return a * kHalf * (kOne + _Erf(kAlpha * a)); +} + template __device__ __inline__ T _Mod(T a, T b) { T r = a % b; @@ -348,21 +354,19 @@ struct GridDim { }; }; -// aligned vector generates vectorized load/store +// aligned vector generates vectorized load/store on ROCM template struct alignas(sizeof(T) * vec_size) aligned_vector { T val[vec_size]; }; -#define CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N) \ +#define CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N) \ HIP_LONG id = blockDim.x * blockIdx.x + threadIdx.x; \ - if (id >= N) \ + if (id >= N) \ return; // HIP_KERNEL_ASSERT is a macro that wraps an assert() call inside rocm kernels. -// TODO ROCM added support recently, should verify. -#define HIP_KERNEL_ASSERT(...) -// #define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__) +#define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__) // WARP related definitions and functions constexpr int GPU_WARP_SIZE = warpSize; diff --git a/onnxruntime/core/providers/rocm/fpgeneric.cu b/onnxruntime/core/providers/rocm/fpgeneric.cu index 4df7e0b5a5..d130758bec 100644 --- a/onnxruntime/core/providers/rocm/fpgeneric.cu +++ b/onnxruntime/core/providers/rocm/fpgeneric.cu @@ -68,7 +68,7 @@ rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocbla rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, const half* x, int incx, half* y, int incy) { dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1); dim3 dimBlock(COPY_BLOCK_DIM, 1, 1); - CopyVectorHalf<<>>(x, incx, y, incy, n); + CopyVectorHalf<<>>(x, incx, y, incy, n); return rocblas_status_success; } @@ -76,6 +76,6 @@ rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, cons onnxruntime::BFloat16* y, int incy) { dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1); dim3 dimBlock(COPY_BLOCK_DIM, 1, 1); - CopyVectorBFloat16<<>>(x, incx, y, incy, n); + CopyVectorBFloat16<<>>(x, incx, y, incy, n); return rocblas_status_success; } diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc index fd45ad675a..635a25480b 100644 --- a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc @@ -2,14 +2,15 @@ // Licensed under the MIT License. #include "core/providers/shared_library/provider_api.h" -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/gpu_data_transfer.h" -// use default stream for copy for now, to avoid racing in BFC arena as in issue #4829 -// note this may cause some models to run slower if there are ops running on CPU -// so we leave it as optional, in case user need the previous behavior -// a full fix to BFC arena is being looked at, and once it's in, we can revert this change +#include "core/providers/rocm/gpu_data_transfer.h" +#include "core/providers/rocm/rocm_common.h" + namespace onnxruntime { +GPUDataTransfer::GPUDataTransfer() {} + +GPUDataTransfer::~GPUDataTransfer() {} + bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED || dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HIP_PINNED; @@ -34,12 +35,12 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const } else { // copy from other CPU memory to GPU, this is blocking HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); // TODO: still need stream sync? since already blocking + HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); } } else if (src_device.Type() == OrtDevice::GPU) { // copying from GPU to CPU memory, this is blocking HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost)); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); // TODO: still need stream sync? since already blocking + HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); } else { // copying between cpu memory memcpy(dst_data, src_data, bytes); @@ -57,34 +58,29 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, auto& dst_device = dst.Location().device; if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::CPU && src_device.MemType() == OrtDevice::MemType::HIP_PINNED) { + if (src_device.Type() == OrtDevice::CPU) { // copy from pinned memory to GPU, this is non-blocking HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking - // Copy only if the two addresses are different. if (dst_data != src_data) { HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast(stream.GetHandle()))); } - } else { - // copy from other CPU memory to GPU, this is blocking - HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream.GetHandle()))); } } else if (src_device.Type() == OrtDevice::GPU) { - if (dst_device.Type() == OrtDevice::CPU && dst_device.MemType() == OrtDevice::MemType::HIP_PINNED) { + if (dst_device.Type() == OrtDevice::CPU) { // copying from GPU to pinned memory, this is non-blocking HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); - } else { - // copying from GPU to CPU memory, this is blocking - HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); - HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream.GetHandle()))); } } else { - // copying between cpu memory + if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) { + // sync the stream first to make sure the data arrived + HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream.GetHandle()))); + } memcpy(dst_data, src_data, bytes); } return Status::OK(); } + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.h b/onnxruntime/core/providers/rocm/gpu_data_transfer.h index 3d35ed52ff..3d297bdce4 100644 --- a/onnxruntime/core/providers/rocm/gpu_data_transfer.h +++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.h @@ -10,8 +10,8 @@ namespace onnxruntime { class GPUDataTransfer : public IDataTransfer { public: - GPUDataTransfer() = default; - ~GPUDataTransfer() = default; + GPUDataTransfer(); + ~GPUDataTransfer(); bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; diff --git a/onnxruntime/core/providers/rocm/integer_gemm.cc b/onnxruntime/core/providers/rocm/integer_gemm.cc index 3c82a436d7..9771f42fd3 100644 --- a/onnxruntime/core/providers/rocm/integer_gemm.cc +++ b/onnxruntime/core/providers/rocm/integer_gemm.cc @@ -5,13 +5,14 @@ #include #include "core/providers/rocm/shared_inc/integer_gemm.h" +#include "core/common/safeint.h" #include "core/providers/rocm/rocm_common.h" #include "core/providers/rocm/shared_inc/rocm_call.h" namespace onnxruntime { namespace rocm { -inline int roundoff(int v, int d) { +constexpr int roundoff(int v, int d) { return (v + d - 1) / d * d; } @@ -21,20 +22,21 @@ Status GemmInt8(int m, int n, int k, const RocmKernel* rocm_kernel, onnxruntime::Stream* ort_stream) { ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null"); ORT_ENFORCE(rocm_kernel != nullptr, "kernel is null"); + ORT_ENFORCE(ort_stream != nullptr, "Rocm kernel must have the stream instance"); - hipStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; + hipStream_t stream = static_cast(ort_stream->GetHandle()); // pad A and B to make their leading dimension be multiples of 32 - // because cublasGemmEx requires: + // because rocblas_gemm_ex requires: // 1. leading dimension is multiples of 4 // 2. A, B is 32-bit aligned - const int mask = 0x1F; + constexpr int mask = 0x1F; int lda_aligned = lda; IAllocatorUniquePtr a_padded; if ((mask & lda_aligned) != 0) { lda_aligned = roundoff(lda, 32); - a_padded = rocm_kernel->GetScratchBuffer(m * lda_aligned, ort_stream); + a_padded = rocm_kernel->GetScratchBuffer(SafeInt(m) * lda_aligned, ort_stream); HIP_RETURN_IF_ERROR(hipMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, hipMemcpyDeviceToDevice, stream)); } @@ -42,14 +44,15 @@ Status GemmInt8(int m, int n, int k, IAllocatorUniquePtr b_padded; if ((mask & ldb_aligned) != 0) { ldb_aligned = roundoff(ldb, 32); - b_padded = rocm_kernel->GetScratchBuffer(k * ldb_aligned, ort_stream); + b_padded = rocm_kernel->GetScratchBuffer(SafeInt(k) * ldb_aligned, ort_stream); HIP_RETURN_IF_ERROR(hipMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, hipMemcpyDeviceToDevice, stream)); } - RocmStream* ort_rocm_stream = static_cast(ort_stream); - auto handle = ort_rocm_stream->rocblas_handle_; + auto* ort_rocm_stream = dynamic_cast(ort_stream); + auto rocblas = ort_rocm_stream->rocblas_handle_; + ROCBLAS_RETURN_IF_ERROR(rocblas_gemm_ex( - handle, + rocblas, rocblas_operation_none, rocblas_operation_none, n, m, k, &alpha, diff --git a/onnxruntime/core/providers/rocm/math/einsum.h b/onnxruntime/core/providers/rocm/math/einsum.h index a4adc3da98..6be412348e 100644 --- a/onnxruntime/core/providers/rocm/math/einsum.h +++ b/onnxruntime/core/providers/rocm/math/einsum.h @@ -17,8 +17,7 @@ class Einsum final : public onnxruntime::Einsum { Einsum(const OpKernelInfo& info) : onnxruntime::Einsum(info) { // We need to cast away the const as PerThreadRocblasHandle() is currently a non-const method // TODO: Clean up the ROCMExecutionProvider interface to avoid this - rocm_ep_ = const_cast( - static_cast(info.GetExecutionProvider())); + rocm_ep_ = static_cast(info.GetExecutionProvider()); } Status Compute(OpKernelContext* context) const override; @@ -32,7 +31,7 @@ class Einsum final : public onnxruntime::Einsum { using onnxruntime::Einsum::equation_; // We need to access to the ROCM EP instance to get the rocblas/miopen handles - ROCMExecutionProvider* rocm_ep_; + const ROCMExecutionProvider* rocm_ep_; }; } // namespace rocm diff --git a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h index 623bb1d590..e1fc3f40ee 100644 --- a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h +++ b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h @@ -21,19 +21,18 @@ namespace EinsumOp { // Holds ROCM assets required for ROCM ops that need to be executed as part of the Einsum flow struct EinsumRocmAssets { explicit EinsumRocmAssets(rocblas_handle rocblas_handle, - ROCMExecutionProvider* rocm_ep, - Stream* ort_stream, - AllocatorPtr gpu_allocator) : rocblas_handle_(rocblas_handle), - rocm_ep_(rocm_ep), - ort_stream_(ort_stream), - gpu_allocator_(gpu_allocator) {} + const ROCMExecutionProvider* rocm_ep, + Stream* ort_stream, AllocatorPtr gpu_allocator) : rocblas_handle_(rocblas_handle), + rocm_ep_(rocm_ep), + ort_stream_(ort_stream), + gpu_allocator_(gpu_allocator) {} hipStream_t GetRocmStream() { return ort_stream_ ? static_cast(ort_stream_->GetHandle()) : nullptr; } rocblas_handle rocblas_handle_; - ROCMExecutionProvider* rocm_ep_; + const ROCMExecutionProvider* rocm_ep_; Stream* ort_stream_; AllocatorPtr gpu_allocator_; }; diff --git a/onnxruntime/core/providers/rocm/math/softmax.cc b/onnxruntime/core/providers/rocm/math/softmax.cc index 5a07737d92..8d922d0bb4 100644 --- a/onnxruntime/core/providers/rocm/math/softmax.cc +++ b/onnxruntime/core/providers/rocm/math/softmax.cc @@ -29,20 +29,23 @@ Status SoftMaxComputeHelper( auto X_data = reinterpret_cast(X); if (D <= 1024 && D * sizeof(T) <= 4096) { - return dispatch_warpwise_softmax_forward, IsLogSoftmax>( - stream, Y_data, X_data, gsl::narrow_cast(D), - gsl::narrow_cast(D), gsl::narrow_cast(N), tuning_ctx); + return dispatch_warpwise_softmax_forward< + HipT_IN, HipT_OUT, AccumulationType_t, IsLogSoftmax>( + stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(N), tuning_ctx); } + return dispatch_blockwise_softmax_forward, IsLogSoftmax>( - stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), - gsl::narrow_cast(D), gsl::narrow_cast(N), tuning_ctx); + stream, Y_data, X_data, gsl::narrow_cast(D), gsl::narrow_cast(D), gsl::narrow_cast(D), + gsl::narrow_cast(N), tuning_ctx); } -#define SPECIALIZED_SOFTMAX_HELPER_IMPL(T, TOut) \ - template Status SoftMaxComputeHelper(Stream * stream, const T* input, const TensorShape& shape, TOut* Y, \ - int64_t axis, RocmTuningContext* tuning_ctx); \ - template Status SoftMaxComputeHelper(Stream * stream, const T* input, const TensorShape& shape, TOut* Y, \ - int64_t axis, RocmTuningContext* tuning_ctx); +#define SPECIALIZED_SOFTMAX_HELPER_IMPL(T, TOut) \ + template Status SoftMaxComputeHelper(Stream * stream, const T* input, \ + const TensorShape& shape, TOut* Y, int64_t axis, \ + RocmTuningContext* tuning_ctx); \ + template Status SoftMaxComputeHelper(Stream * stream, const T* input, \ + const TensorShape& shape, TOut* Y, int64_t axis, \ + RocmTuningContext* tuning_ctx); SPECIALIZED_SOFTMAX_HELPER_IMPL(MLFloat16, float) SPECIALIZED_SOFTMAX_HELPER_IMPL(float, float) diff --git a/onnxruntime/core/providers/rocm/nn/conv.cc b/onnxruntime/core/providers/rocm/nn/conv.cc index 6846813c7c..6214ec7bc0 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.cc +++ b/onnxruntime/core/providers/rocm/nn/conv.cc @@ -44,14 +44,13 @@ const miopenConvFwdAlgorithm_t Conv::kAllAlgos[] = { miopenConvolutionFwdAlgoWinograd, miopenConvolutionFwdAlgoImplicitGEMM}; -miopenStatus_t GetWorkspaceSize(miopenHandle_t handle, const MiopenConvState& s, - miopenConvFwdAlgorithm_t algo, size_t* sz) { +miopenStatus_t GetWorkspaceSize(miopenHandle_t handle, const MiopenConvState& s, miopenConvFwdAlgorithm_t algo, size_t* sz) { return miopenConvolutionForwardGetWorkSpaceSize(handle, s.w_desc, s.x_tensor, s.conv_desc, s.y_tensor, sz); } size_t GetMaxWorkspaceSize(miopenHandle_t handle, const MiopenConvState& s, const miopenConvFwdAlgorithm_t* algo, int n_algo) { - // TODO: get maximum available size from memory arean + // TODO: get maximum available size from memory arena size_t free, total; HIP_CALL_THROW(hipMemGetInfo(&free, &total)); // Assuming 10% of fragmentation @@ -68,8 +67,7 @@ size_t GetMaxWorkspaceSize(miopenHandle_t handle, const MiopenConvState& input_dims, + const void* input_data, gsl::span input_dims, void* output_data, const gsl::span& output_dims, const gsl::span& starts, @@ -103,8 +101,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) // Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC. constexpr bool channels_last = NHWC; if (channels_last && (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); } // set B @@ -140,7 +137,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) const size_t kernel_rank = kernel_shape.size(); - ConvAttributes::ConvPadVector pads(conv_attrs_.pads); + ConvPadVector pads(conv_attrs_.pads); if (pads.empty()) { pads.resize(kernel_rank * 2, 0); } @@ -174,7 +171,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) TensorShapeVector slice_axes; slice_axes.reserve(kernel_rank); - const size_t spatial_dim_start = channels_last ? 1 : 2; + constexpr size_t spatial_dim_start = channels_last ? 1 : 2; const size_t spatial_dim_end = spatial_dim_start + kernel_rank; TensorShape spatial_shape = X->Shape().Slice(spatial_dim_start, spatial_dim_end); @@ -183,7 +180,6 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) strides, dilations, pads, y_dims, y_dims_with_adjusted_pads, post_slicing_required, slice_starts, slice_ends, slice_axes, channels_last)); - if (channels_last) { y_dims.push_back(M); y_dims_with_adjusted_pads.push_back(M); @@ -198,9 +194,6 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) s_.slice_axes = slice_axes; s_.Y = context->Output(0, TensorShape(s_.y_dims)); - if (s_.Y->Shape().Size() == 0) { - return Status::OK(); - } if (post_slicing_required) { // Post slicing needed. Create and fill in the Conv results in an intermediate buffer. s_.memory_for_miopen_conv_results = GetScratchBuffer(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size, context->GetComputeStream()); @@ -225,18 +218,23 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) } if (w_dims_changed) { - if (channels_last) { + if (!channels_last) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, MiopenTensor::GetDataType())); + } else { ORT_RETURN_IF_ERROR(s_.w_desc.Set(MiopenTensor::GetDataType(), miopenTensorNHWC, w_dims[0], w_dims[3], w_dims[1], w_dims[2])); - } else { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, MiopenTensor::GetDataType())); } } + // We must delay returning early until here so that the weight dims have been cached properly + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + if (channels_last) { ORT_RETURN_IF_ERROR(s_.x_tensor.Set(MiopenTensor::GetDataType(), miopenTensorNHWC, @@ -357,7 +355,7 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { // To deal with asymmetric padding, we may have over-padded on one or both sides of the spatial dimensions // This may have lead to extra results that are unnecessary and hence we slice that off here if (s_.post_slicing_required) { - ORT_RETURN_IF_ERROR(SliceOutUnwantedOutputSection(Stream(context), s_.y_data, s_.y_dims_with_adjusted_pads, + ORT_RETURN_IF_ERROR(SliceOutUnwantedOutputSection(Stream(context), s_.y_data, gsl::make_span(s_.y_dims_with_adjusted_pads), s_.Y->MutableDataRaw(), s_.y_dims.GetDims(), s_.slice_starts, s_.slice_ends, s_.slice_axes, s_.element_size)); } @@ -384,18 +382,18 @@ MiopenConvolutionDescriptor::~MiopenConvolutionDescriptor() { Status MiopenConvolutionDescriptor::Set( size_t rank, - gsl::span pads, - gsl::span strides, - gsl::span dilations, + const gsl::span& pads, + const gsl::span& strides, + const gsl::span& dilations, int groups, miopenConvolutionMode_t mode, miopenDataType_t data_type) { if (!desc_) MIOPEN_RETURN_IF_ERROR(miopenCreateConvolutionDescriptor(&desc_)); - InlinedVector pad_dims(rank); - InlinedVector stride_dims(rank); - InlinedVector dilation_dims(rank); + InlinedVector pad_dims(rank); + InlinedVector stride_dims(rank); + InlinedVector dilation_dims(rank); for (size_t i = 0; i < rank; i++) { pad_dims[i] = gsl::narrow_cast(pads[i]); stride_dims[i] = gsl::narrow_cast(strides[i]); diff --git a/onnxruntime/core/providers/rocm/nn/conv.h b/onnxruntime/core/providers/rocm/nn/conv.h index f4f2331e91..bc9846203e 100644 --- a/onnxruntime/core/providers/rocm/nn/conv.h +++ b/onnxruntime/core/providers/rocm/nn/conv.h @@ -10,6 +10,9 @@ #include namespace onnxruntime { + +using ConvPadVector = ConvAttributes::ConvPadVector; + namespace rocm { class MiopenConvolutionDescriptor final { @@ -18,9 +21,9 @@ class MiopenConvolutionDescriptor final { ~MiopenConvolutionDescriptor(); Status Set(size_t rank, - gsl::span pads, - gsl::span strides, - gsl::span dilations, + const gsl::span& pads, + const gsl::span& strides, + const gsl::span& dilations, int groups, miopenConvolutionMode_t mode, miopenDataType_t data_type); @@ -198,7 +201,7 @@ class Conv : public RocmKernel { Status SliceOutUnwantedOutputSection(hipStream_t stream, const void* input_data, - const gsl::span& input_dims, + gsl::span input_dims, void* output_data, const gsl::span& output_dims, const gsl::span& starts, diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc index 4f726017d8..820745b22f 100644 --- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc @@ -8,6 +8,9 @@ #include "core/providers/rocm/math/binary_elementwise_ops_impl.h" #include "core/providers/rocm/math/binary_elementwise_ops.h" #include "core/providers/rocm/math/unary_elementwise_ops_impl.h" +#ifdef ENABLE_TRAINING +#include "contrib_ops/cpu/aten_ops/aten_op.h" +#endif using namespace onnxruntime::common; namespace onnxruntime { @@ -100,8 +103,8 @@ namespace rocm { (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); -// ROCM ArgMax/ArgMin doesn't have OpSet12 implementation (with select_last_index attr), keep it in OpSet11 for now. -#define REGISTER_KERNEL_TYPED_11(name, T) \ +// ROCM ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet +#define REGISTER_KERNEL_VERSIONED_TYPED_11(name, T) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ @@ -110,10 +113,10 @@ namespace rocm { kRocmExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ - 11, \ + 11, 11, \ T, \ kRocmExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ @@ -166,7 +169,6 @@ Status ReduceKernel::ReduceKernelShared( const auto rank = input_shape.NumDimensions(); auto hip_stream = stream ? static_cast(stream->GetHandle()) : nullptr; - // Block of fast matrix reduction. if (fast_reduction_) { int m{}, n{}; @@ -210,10 +212,8 @@ Status ReduceKernel::ReduceKernelShared( ORT_RETURN_IF_ERROR(reduce_desc.Set(miopen_reduce_op, MiopenTensor::GetDataType(), ReduceTensorIndices)); else ORT_RETURN_IF_ERROR(reduce_desc.Set(miopen_reduce_op, miopen_type_X, ReduceTensorIndices)); - const auto one = ReduceConsts::One; const auto zero = ReduceConsts::Zero; - MiopenTensor input_tensor; MiopenTensor output_tensor; ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_miopen, miopen_type_X)); @@ -444,17 +444,18 @@ template Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, /*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op, gsl::span axes, - bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, Stream* ort_stream, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, const TensorShape* input_shape_override) { typedef typename ToHipType::MappedType HipT; const TensorShape& input_shape = input_shape_override ? *input_shape_override : input.Shape(); + hipStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; int64_t input_count = prepare_reduce_metadata.input_count; int64_t output_count = prepare_reduce_metadata.output_count; auto& output_dims = prepare_reduce_metadata.output_dims; auto& input_dims_miopen = prepare_reduce_metadata.input_dims_miopen; auto& output_dims_miopen = prepare_reduce_metadata.output_dims_miopen; - hipStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; // special case when there is a dim value of 0 in the shape. if (input_count == 0) { assert(output.Shape().Size() == 0); @@ -540,7 +541,6 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, const auto one = ReduceConsts::One; const auto zero = ReduceConsts::Zero; - MiopenTensor input_tensor; MiopenTensor output_tensor; ORT_RETURN_IF_ERROR(input_tensor.Set(input_dims_miopen, miopen_type_X)); @@ -588,11 +588,12 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, MIOPEN_RETURN_IF_ERROR(miopenGetReductionIndicesSize(RocmKernel::GetMiopenHandle(rocm_stream), reduce_max_desc, input_tensor, output_tensor, &indices_bytes_max)); auto indices_rocm_max = indices_bytes == 0 ? nullptr : IAllocator::MakeUniquePtr(gpu_allocator, indices_bytes, false, ort_stream, WaitRocmNotificationOnDevice); + auto* p_output = reinterpret_cast(output.template MutableData()); MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( RocmKernel::GetMiopenHandle(rocm_stream), reduce_max_desc, indices_rocm_max.get(), indices_bytes_max, workspace_rocm.get(), workspace_bytes, &one, input_tensor, reinterpret_cast(input.Data()), - &zero, output_tensor, reinterpret_cast(output.MutableData()))); + &zero, output_tensor, p_output)); } // Exp(X-ReduceMax) @@ -652,11 +653,12 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, if (input_count == output_count) { HIP_RETURN_IF_ERROR(hipMemcpyAsync(reinterpret_cast(output.MutableData()), input_data, input_count * sizeof(T), hipMemcpyDeviceToDevice, stream)); } else { + auto* p_output = reinterpret_cast(output.template MutableData()); MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( RocmKernel::GetMiopenHandle(rocm_stream), reduce_desc, indices_rocm.get(), indices_bytes, workspace_rocm.get(), workspace_bytes, &one, input_tensor, input_data, - &zero, output_tensor, reinterpret_cast(output.MutableData()))); + &zero, output_tensor, p_output)); } } else { // miopenReduceTensor for ReduceSum has issue if input and output has same size, we just need to copy the data for this case @@ -675,11 +677,12 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, Impl_Cast(stream, temp_output.get(), reinterpret_cast(output.MutableData()), output_count); } else { + auto* p_output = reinterpret_cast(output.template MutableData()); MIOPEN_RETURN_IF_ERROR(miopenReduceTensor( RocmKernel::GetMiopenHandle(rocm_stream), reduce_desc, indices_rocm.get(), indices_bytes, workspace_rocm.get(), workspace_bytes, &one, input_tensor, reinterpret_cast(input.Data()), - &zero, output_tensor, reinterpret_cast(output.MutableData()))); + &zero, output_tensor, p_output)); } } } @@ -743,18 +746,29 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, miopenR // empty axes and no-op if (axes.empty() && noop_with_empty_axes_) { auto* Y = ctx->Output(0, X->Shape()); - HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->MutableData(), X->Data(), X->SizeInBytes(), hipMemcpyDeviceToDevice, Stream(ctx))); + HIP_RETURN_IF_ERROR(hipMemcpyAsync(Y->MutableData(), X->Data(), X->SizeInBytes(), + hipMemcpyDeviceToDevice, Stream(ctx))); return Status::OK(); } +#ifdef ENABLE_TRAINING + // Use ATen for ReduceSum if possible. + const TensorShape& input_shape = X->Shape(); + if (contrib::IsATenOperatorExecutorInitialized() && miopen_reduce_op == MIOPEN_REDUCE_TENSOR_ADD && !calculate_log_ && + !calculate_sqt_ && !log_sum_exp_ && input_shape.Size() > 0) { + if (axes.empty()) { + axes.resize(input_shape.NumDimensions()); + std::iota(axes.begin(), axes.end(), 0); + } + ORT_RETURN_IF_ERROR(contrib::ExecuteReduceSumATen(ctx, axes, keepdims_)); + return Status::OK(); + } +#endif + PrepareReduceMetadata prepare_reduce_metadata; - ORT_RETURN_IF_ERROR(PrepareForReduce(X, - keepdims_, - axes, - prepare_reduce_metadata)); + ORT_RETURN_IF_ERROR(PrepareForReduce(X, keepdims_, axes, prepare_reduce_metadata)); Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims); const bool fast_reduction = fast_reduction_ && !ctx->GetUseDeterministicCompute(); - return ReduceComputeCore(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), *X, prepare_reduce_metadata, *Y, miopen_reduce_op, axes, calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction, ctx->GetComputeStream()); } @@ -837,7 +851,6 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, miopenR MIOPEN_RETURN_IF_ERROR(miopenReduceTensor(GetMiopenHandle(ctx), reduce_desc, indices_rocm.get(), indices_bytes, \ workspace_rocm.get(), workspace_bytes, &one, input_tensor, temp_X.get(), \ &zero, output_tensor, temp_Y.get())); \ - \ Impl_Cast(Stream(ctx), temp_Y.get(), reinterpret_cast(Y->MutableData()), output_count); \ \ return Status::OK(); \ @@ -909,13 +922,13 @@ template std::unique_ptr ReduceCompute #include "core/common/inlined_containers.h" #include "core/framework/allocator.h" #include "core/platform/ort_mutex.h" @@ -56,7 +55,7 @@ class ROCMPinnedAllocator : public IAllocator { ROCMPinnedAllocator(const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0), + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0 /*CPU device always with id 0*/), 0, OrtMemTypeCPUOutput)) {} void* Alloc(size_t size) override; diff --git a/onnxruntime/core/providers/rocm/rocm_call.cc b/onnxruntime/core/providers/rocm/rocm_call.cc index 730f55608c..484e59f4de 100644 --- a/onnxruntime/core/providers/rocm/rocm_call.cc +++ b/onnxruntime/core/providers/rocm/rocm_call.cc @@ -39,11 +39,11 @@ const char* RocmErrString(rocblas_status e) { CASE_ENUM_TO_STR(rocblas_status_invalid_handle); CASE_ENUM_TO_STR(rocblas_status_not_implemented); CASE_ENUM_TO_STR(rocblas_status_invalid_pointer); + CASE_ENUM_TO_STR(rocblas_status_size_query_mismatch); CASE_ENUM_TO_STR(rocblas_status_invalid_size); CASE_ENUM_TO_STR(rocblas_status_memory_error); CASE_ENUM_TO_STR(rocblas_status_internal_error); CASE_ENUM_TO_STR(rocblas_status_perf_degraded); - CASE_ENUM_TO_STR(rocblas_status_size_query_mismatch); CASE_ENUM_TO_STR(rocblas_status_size_increased); CASE_ENUM_TO_STR(rocblas_status_size_unchanged); CASE_ENUM_TO_STR(rocblas_status_invalid_value); diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index c9975d0bc7..d7c5098d9d 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/common/inlined_containers.h" #include "core/providers/shared_library/provider_api.h" #include "core/platform/env_var_utils.h" #include "core/providers/rocm/rocm_execution_provider.h" @@ -9,7 +10,6 @@ #include "core/providers/rocm/rocm_fwd.h" #include "core/providers/rocm/gpu_data_transfer.h" #include "core/providers/rocm/rocm_profiler.h" -#include "core/providers/rocm/rocm_stream_handle.h" #ifndef DISABLE_CONTRIB_OPS #include "contrib_ops/rocm/rocm_contrib_kernels.h" @@ -23,6 +23,8 @@ #include "core/providers/rocm/triton_kernel.h" #endif +#include "core/providers/rocm/rocm_stream_handle.h" + using namespace onnxruntime::common; namespace onnxruntime { @@ -38,42 +40,64 @@ class Memcpy final : public OpKernel { ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); Tensor* Y = ctx->Output(0, X->Shape()); ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor."); - const IDataTransfer* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); - return gpu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream()); - } else if (X_type->IsSparseTensorType()) { - const auto* X = ctx->Input(0); - ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); - SparseTensor* Y = ctx->OutputSparse(0, X->DenseShape()); - ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output sparse tensor."); - return X->Copy(Info().GetDataTransferManager(), *Y); - } else if (X_type->IsTensorSequenceType()) { - const TensorSeq* X = ctx->Input(0); - ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr."); - TensorSeq* Y = ctx->Output(0); - ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence."); - auto X_dtype = X->DataType(); - Y->SetType(X_dtype); - AllocatorPtr alloc; - auto status = ctx->GetTempSpaceAllocator(&alloc); - if (!status.IsOK()) { - return Status(common::ONNXRUNTIME, common::FAIL, - "Memcpy rocm: unable to get an allocator."); - } - auto X_size = X->Size(); - Y->Reserve(X_size); - for (size_t i = 0; i < X_size; ++i) { - const Tensor& source_tensor = X->Get(i); - std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc); - const IDataTransfer* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(source_tensor.Location().device, target_tensor->Location().device); - Status retval = gpu_data_transfer->CopyTensorAsync(source_tensor, *target_tensor, *ctx->GetComputeStream()); - if (!retval.IsOK()) { - return retval; - } - Y->Add(std::move(*target_tensor)); - } + // do we support async copy? + // The rocmMemCpyAsync will handle the pinned memory and non-pinned memory, + // so we don't need the check here. + auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); + ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream())); return Status::OK(); + } else { + if (X_type->IsSparseTensorType()) { + // TODO: support aysnc copy for sparse tensor + // sync the stream first, since it is a sync memory copy + HIP_CALL_THROW(hipStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle()))); + const auto* X = ctx->Input(0); + ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor is nullptr."); + SparseTensor* Y = ctx->OutputSparse(0, X->DenseShape()); + ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output sparse tensor."); + return X->Copy(Info().GetDataTransferManager(), *Y); + } else if (X_type->IsTensorSequenceType()) { + const TensorSeq* X = ctx->Input(0); + ORT_ENFORCE(X != nullptr, "Memcpy: Input tensor sequence is nullptr."); + TensorSeq* Y = ctx->Output(0); + ORT_ENFORCE(Y != nullptr, "Memcpy: Failed to allocate output tensor sequence."); + auto X_dtype = X->DataType(); + Y->SetType(X_dtype); + AllocatorPtr alloc; + + // If we are copying contents to ROCM, the allocator to use + // to allocate the buffers of the new tensors in the sequence + // can be temp space allocator associated with the ROCM EP + if (Node().OpType() == "MemcpyFromHost") { + auto status = ctx->GetTempSpaceAllocator(&alloc); + if (!status.IsOK()) { + return Status(common::ONNXRUNTIME, common::FAIL, + "Memcpy rocm: unable to get an allocator."); + } + } else { + // If we are copying contents to CPU (op type is "MemcpyToHost"), + // the allocator to use to allocate the buffers of the new tensors + // in the sequence will be the allocator from the CPU EP + auto status = ctx->GetTempSpaceCPUAllocator(&alloc); + if (!status.IsOK()) { + return Status(common::ONNXRUNTIME, common::FAIL, + "Memcpy rocm: unable to get the CPU allocator."); + } + } + auto X_size = X->Size(); + Y->Reserve(X_size); + for (size_t i = 0; i < X_size; ++i) { + const Tensor& source_tensor = X->Get(i); + std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc); + auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(source_tensor.Location().device, + target_tensor->Location().device); + ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(source_tensor, *target_tensor, *ctx->GetComputeStream())); + Y->Add(std::move(*target_tensor)); + } + return Status::OK(); + } + return Status(common::ONNXRUNTIME, common::FAIL, "Memcpy: Unsupported input type."); } - return Status(common::ONNXRUNTIME, common::FAIL, "Memcpy: Unsupported input type."); } }; @@ -100,18 +124,23 @@ ONNX_OPERATOR_KERNEL_EX( } // namespace rocm -AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId device_id, size_t gpu_mem_limit, ArenaExtendStrategy arena_extend_strategy, - ROCMExecutionProviderExternalAllocatorInfo external_allocator_info, OrtArenaCfg* default_memory_arena_cfg) { +AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId device_id, + size_t gpu_mem_limit, + ArenaExtendStrategy arena_extend_strategy, + ROCMExecutionProviderExternalAllocatorInfo external_allocator_info, + const OrtArenaCfg* default_memory_arena_cfg) { if (external_allocator_info.UseExternalAllocator()) { AllocatorCreationInfo default_memory_info( [external_allocator_info](OrtDevice::DeviceId id) { - return std::make_unique(id, HIP, external_allocator_info.alloc, external_allocator_info.free, external_allocator_info.empty_cache); + return std::make_unique(id, HIP, + external_allocator_info.alloc, + external_allocator_info.free, + external_allocator_info.empty_cache); }, device_id, false); return CreateAllocator(default_memory_info); - } else { AllocatorCreationInfo default_memory_info( [](OrtDevice::DeviceId id) { @@ -120,12 +149,7 @@ AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId devi device_id, true, {default_memory_arena_cfg ? *default_memory_arena_cfg - : OrtArenaCfg(gpu_mem_limit, - static_cast(arena_extend_strategy), - -1, - -1, - -1, - -1)}, + : OrtArenaCfg(gpu_mem_limit, static_cast(arena_extend_strategy), -1, -1, -1, -1L)}, // make it stream aware true, // enable cross stream sharing? @@ -149,20 +173,8 @@ ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de } ROCMExecutionProvider::PerThreadContext::~PerThreadContext() { - // dtor shouldn't throw. if something went wrong earlier (e.g. out of ROCM memory) the handles - // here may be bad, and the destroy calls can throw. - // https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#Rc-dtor-noexcept - try { - ROCBLAS_CALL_THROW(rocblas_destroy_handle(rocblas_handle_)); - } catch (const std::exception& ex) { - LOGS_DEFAULT(ERROR) << "rocblas_destroy_handle threw:" << ex.what(); - } - - try { - MIOPEN_CALL_THROW(miopenDestroy(miopen_handle_)); - } catch (const std::exception& ex) { - LOGS_DEFAULT(ERROR) << "miopenDestroy threw:" << ex.what(); - } + ORT_IGNORE_RETURN_VALUE(ROCBLAS_CALL(rocblas_destroy_handle(rocblas_handle_))); + ORT_IGNORE_RETURN_VALUE(MIOPEN_CALL(miopenDestroy(miopen_handle_))); } void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) { @@ -235,7 +247,7 @@ ROCMExecutionProvider::~ROCMExecutionProvider() { } if (!external_stream_ && stream_) { - HIP_CALL_THROW(hipStreamDestroy(stream_)); + ORT_IGNORE_RETURN_VALUE(HIP_CALL(hipStreamDestroy(stream_))); } } @@ -315,7 +327,7 @@ Status ROCMExecutionProvider::OnRunStart() { Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) { if (sync_stream) { - HIP_RETURN_IF_ERROR(hipStreamSynchronize(stream_)); + HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream_))); } // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU), @@ -716,12 +728,12 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, Mod); // opset 11 -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, ArgMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, ArgMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, ArgMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, ArgMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Compress); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Concat); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Flatten); @@ -774,7 +786,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSumSquare); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSumSquare); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 18, Scan); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 15, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, ScatterElements); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, Slice); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int64_t, Slice); @@ -827,12 +839,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, float, Round); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, double, Round); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, MLFloat16, Round); - class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, int8_t, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, uint8_t, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear); - class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 13, CumSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int64_t_int64_t_int64_t, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, int64_t_float_int64_t, OneHot); @@ -1087,6 +1097,17 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, SpaceToDepth); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, DepthToSpace); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, BFloat16, Sub); @@ -1105,17 +1126,6 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int16_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint16_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint32_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint64_t, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Sign); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Sign); // OpSet 14 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, CumSum); @@ -1186,12 +1196,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 15, 18, Shape); // Opset 16 -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, PRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, PRelu); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, PRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, LeakyRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LeakyRelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LeakyRelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, PRelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, PRelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, PRelu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 18, Scan); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double_t, Where); @@ -1258,979 +1269,932 @@ KernelCreateInfo BuildKernelCreateInfo() { return {}; } +// clang-format off static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // opset 10 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // opset 10 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // opset 11 - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // opset 11 + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // OpSet 12 - BuildKernelCreateInfo, + // OpSet 12 + BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // OpSet 13 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // OpSet 13 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // OpSet 14 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // OpSet 14 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // OpSet 15 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // OpSet 15 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // Opset 16 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // Opset 16 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // Opset 17 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // Opset 17 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // Opset 18 - BuildKernelCreateInfo, + // Opset 18 + BuildKernelCreateInfo, - // Opset 19 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // Opset 19 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { @@ -2250,6 +2214,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { return Status::OK(); } +// clang-format on } // namespace rocm @@ -2336,7 +2301,6 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // These are usually shape related computation subgraphs // Following logic can be extended for other EPs auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, candidates); - std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) @@ -2371,7 +2335,7 @@ OrtDevice ROCMExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) cons std::vector ROCMExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo pinned_memory_info( - [](OrtDevice::DeviceId device_id) { + [](OrtDevice::DeviceId) { return std::make_unique(HIP_PINNED); }, // TODO: should we use info_.device_id instead of DEFAULT_CPU_ALLOCATOR_DEVICE_ID? @@ -2383,7 +2347,8 @@ std::vector ROCMExecutionProvider::CreatePreferredAllocators() { return std::vector{ CreateRocmAllocator(info_.device_id, info_.gpu_mem_limit, info_.arena_extend_strategy, info_.external_allocator_info, info_.default_memory_arena_cfg), - CreateAllocator(pinned_memory_info)}; + CreateAllocator(pinned_memory_info), + }; } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index 3e86afb7d6..c4945b9ac2 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -36,11 +36,11 @@ class ROCMExecutionProvider : public IExecutionProvider { return nullptr; } - rocblas_handle PerThreadRocblasHandle() { + rocblas_handle PerThreadDefaultRocblasHandle() { return GetPerThreadContext().RocblasHandle(); } - miopenHandle_t PerThreadMiopenHandle() { + miopenHandle_t PerThreadDefaultMiopenHandle() { return GetPerThreadContext().MiopenHandle(); } @@ -60,7 +60,6 @@ class ROCMExecutionProvider : public IExecutionProvider { const hipDeviceProp_t& GetDeviceProp() const { return device_prop_; }; int GetMiopenConvExhaustiveSearch() const { return info_.miopen_conv_exhaustive_search; } bool DoCopyOnDefaultStream() const { return info_.do_copy_in_default_stream; } - bool GetMiopenConvUseMaxWorkspace() const { return info_.miopen_conv_use_max_workspace; } ProviderOptions GetProviderOptions() const override { @@ -68,15 +67,15 @@ class ROCMExecutionProvider : public IExecutionProvider { } static AllocatorPtr CreateRocmAllocator(OrtDevice::DeviceId device_id, size_t rocm_mem_limit, ArenaExtendStrategy arena_extend_strategy, - ROCMExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg); + ROCMExecutionProviderExternalAllocatorInfo external_alloc_info, const OrtArenaCfg* arena_cfg); ITuningContext* GetTuningContext() const override; std::unique_ptr GetProfiler() override; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; - std::vector CreatePreferredAllocators() override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; + std::vector CreatePreferredAllocators() override; private: ROCMExecutionProviderInfo info_; @@ -105,21 +104,30 @@ class ROCMExecutionProvider : public IExecutionProvider { template const T* GetConstOnes(size_t count, hipStream_t stream) { - if (std::is_same::value) { + constexpr bool is_float = std::is_same::value; + constexpr bool is_double = std::is_same::value; + constexpr bool is_half = std::is_same::value; + constexpr bool is_BFloat16 = std::is_same::value; + if (is_float) { if (!constant_ones_float_) { constant_ones_float_ = rocm::CreateConstantOnes(); } return reinterpret_cast(constant_ones_float_->GetBuffer(stream, count)); - } else if (std::is_same::value) { + } else if (is_double) { if (!constant_ones_double_) { constant_ones_double_ = rocm::CreateConstantOnes(); } return reinterpret_cast(constant_ones_double_->GetBuffer(stream, count)); - } else if (std::is_same::value) { + } else if (is_half) { if (!constant_ones_half_) { constant_ones_half_ = rocm::CreateConstantOnes(); } return reinterpret_cast(constant_ones_half_->GetBuffer(stream, count)); + } else if (is_BFloat16) { + if (!constant_ones_bfloat16_) { + constant_ones_bfloat16_ = rocm::CreateConstantOnes(); + } + return reinterpret_cast(constant_ones_bfloat16_->GetBuffer(stream, count)); } else { return nullptr; } @@ -132,6 +140,7 @@ class ROCMExecutionProvider : public IExecutionProvider { std::unique_ptr> constant_ones_float_; std::unique_ptr> constant_ones_double_; std::unique_ptr> constant_ones_half_; + std::unique_ptr> constant_ones_bfloat16_; }; using PerThreadContextMap = std::unordered_map>; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc index 91e3aaaa42..650635c153 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc @@ -27,12 +27,10 @@ constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_dur } // namespace provider_option_names } // namespace rocm -namespace { const EnumNameMapping arena_extend_strategy_mapping{ {ArenaExtendStrategy::kNextPowerOfTwo, "kNextPowerOfTwo"}, {ArenaExtendStrategy::kSameAsRequested, "kSameAsRequested"}, }; -} // namespace ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { ROCMExecutionProviderInfo info{}; @@ -81,7 +79,9 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToEnumReference( rocm::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) - .AddAssignmentToReference(rocm::provider_option_names::kMiopenConvExhaustiveSearch, info.miopen_conv_exhaustive_search) + .AddAssignmentToReference( + rocm::provider_option_names::kMiopenConvExhaustiveSearch, + info.miopen_conv_exhaustive_search) .AddAssignmentToReference(rocm::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream) .AddAssignmentToReference(rocm::provider_option_names::kMiopenConvUseMaxWorkspace, info.miopen_conv_use_max_workspace) .AddValueParser( diff --git a/onnxruntime/core/providers/rocm/rocm_fwd.h b/onnxruntime/core/providers/rocm/rocm_fwd.h deleted file mode 100644 index b123446fa9..0000000000 --- a/onnxruntime/core/providers/rocm/rocm_fwd.h +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/framework/op_kernel.h" - -namespace onnxruntime { -namespace rocm { -template -KernelCreateInfo BuildKernelCreateInfo(); -} -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_kernel.h b/onnxruntime/core/providers/rocm/rocm_kernel.h index 02f15fdad8..463c1cf0d2 100644 --- a/onnxruntime/core/providers/rocm/rocm_kernel.h +++ b/onnxruntime/core/providers/rocm/rocm_kernel.h @@ -35,14 +35,12 @@ class RocmKernel : public OpKernel { // use this to precisely locate the node where ROCM failure comes from // if (hipSuccess != hipDeviceSynchronize()) // __debugbreak(); - if (s.IsOK()) { auto err = hipGetLastError(); if (err != hipSuccess) { - s = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "HIP error ", hipGetErrorName(err), ":", hipGetErrorString(err)); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "HIP error ", hipGetErrorName(err), ":", hipGetErrorString(err)); } } - return s; } @@ -64,18 +62,18 @@ class RocmKernel : public OpKernel { return IAllocator::MakeUniquePtr(Info().GetAllocator(OrtMemType::OrtMemTypeDefault), count_or_bytes, true); } - template - inline IAllocatorUniquePtr AllocateBufferOnCPUPinned(size_t count_or_bytes) const { - if (count_or_bytes == 0) return nullptr; - return IAllocator::MakeUniquePtr(Info().GetAllocator(OrtMemType::OrtMemTypeCPU), count_or_bytes); - } - inline void AddDeferredReleaseCPUPtr(void* p, onnxruntime::Stream* ort_stream) const { ORT_ENFORCE(ort_stream->GetDevice().Type() == OrtDevice::GPU); auto* rocm_ep_stream = static_cast(ort_stream); rocm_ep_stream->EnqueDeferredCPUBuffer(p); } + template + inline IAllocatorUniquePtr AllocateBufferOnCPUPinned(size_t count_or_bytes) const { + if (count_or_bytes == 0) return nullptr; + return IAllocator::MakeUniquePtr(Info().GetAllocator(OrtMemType::OrtMemTypeCPU), count_or_bytes); + } + const hipDeviceProp_t& GetDeviceProp() const { return provider_->GetDeviceProp(); } inline hipStream_t Stream(OpKernelContext* ctx) const { @@ -83,6 +81,22 @@ class RocmKernel : public OpKernel { return stream ? static_cast(stream->GetHandle()) : nullptr; } + inline miopenHandle_t GetMiopenHandle(OpKernelContext* ctx) const { + return GetMiopenHandle(static_cast(ctx->GetComputeStream())); + } + + static inline miopenHandle_t GetMiopenHandle(onnxruntime::RocmStream* stream) { + return stream->miopen_handle_; + } + + inline rocblas_handle GetRocblasHandle(OpKernelContext* ctx) const { + return GetRocblasHandle(static_cast(ctx->GetComputeStream())); + } + + static inline rocblas_handle GetRocblasHandle(onnxruntime::RocmStream* stream) { + return stream->rocblas_handle_; + } + tunable::RocmTuningContext* GetTuningContext() const { return static_cast(provider_->GetTuningContext()); } @@ -106,7 +120,7 @@ class RocmKernel : public OpKernel { } } - RocmAsyncBuffer(const RocmKernel* op_kernel, gsl::span vec) : RocmAsyncBuffer(op_kernel, vec.size()) { + RocmAsyncBuffer(const RocmKernel* op_kernel, gsl::span vec) : RocmAsyncBuffer(op_kernel, vec.size()) { memcpy(CpuPtr(), vec.data(), vec.size() * sizeof(T)); } @@ -151,28 +165,12 @@ class RocmKernel : public OpKernel { const RocmKernel* op_kernel_; }; - inline rocblas_handle RocblasHandle() const { - return provider_->PerThreadRocblasHandle(); + inline rocblas_handle DefaultRocblasHandle() const { + return provider_->PerThreadDefaultRocblasHandle(); } - inline miopenHandle_t MiopenHandle() const { - return provider_->PerThreadMiopenHandle(); - } - - static inline rocblas_handle GetRocblasHandle(onnxruntime::RocmStream* stream) { - return stream->rocblas_handle_; - } - - inline rocblas_handle GetRocblasHandle(OpKernelContext* ctx) const { - return GetRocblasHandle(static_cast(ctx->GetComputeStream())); - } - - static inline miopenHandle_t GetMiopenHandle(onnxruntime::RocmStream* stream) { - return stream->miopen_handle_; - } - - inline miopenHandle_t GetMiopenHandle(OpKernelContext* ctx) const { - return GetMiopenHandle(static_cast(ctx->GetComputeStream())); + inline miopenHandle_t DefaultMiopenHandle() const { + return provider_->PerThreadDefaultMiopenHandle(); } protected: diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc index e55b2edbad..4d88c25469 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc @@ -3,15 +3,13 @@ #include "core/providers/shared_library/provider_api.h" #include "core/providers/rocm/rocm_provider_factory.h" - -#include +#include "core/providers/rocm/rocm_provider_factory_creator.h" #include "core/common/gsl.h" #include "core/providers/rocm/rocm_execution_provider.h" #include "core/providers/rocm/rocm_execution_provider_info.h" #include "core/providers/rocm/rocm_allocator.h" -#include "core/providers/rocm/rocm_provider_factory_creator.h" #include "core/providers/rocm/gpu_data_transfer.h" #include "core/providers/rocm/math/unary_elementwise_ops_impl.h" @@ -47,7 +45,7 @@ std::unique_ptr ROCMProviderFactory::CreateProvider() { return std::make_unique(info_); } -struct ProviderInfo_ROCM_Impl : ProviderInfo_ROCM { +struct ProviderInfo_ROCM_Impl final : ProviderInfo_ROCM { OrtStatus* SetCurrentGpuDeviceId(_In_ int device_id) override { int num_devices; auto hip_err = ::hipGetDeviceCount(&num_devices); @@ -128,9 +126,24 @@ struct ProviderInfo_ROCM_Impl : ProviderInfo_ROCM { } // Used by slice_concatenate_test.cc and onnxruntime_pybind_state.cc - void rocmMemcpy_HostToDevice(void* dst, const void* src, size_t count) override { HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyHostToDevice)); } + + void rocmMemcpy_HostToDevice(void* dst, const void* src, size_t count) override { + // hipMemcpy() operates on the default stream + HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyHostToDevice)); + + // To ensure that the copy has completed, invoke a stream sync for the default stream. + // For transfers from pageable host memory to device memory, a stream sync is performed before the copy is initiated. + // The function will return once the pageable buffer has been copied to the staging memory for DMA transfer + // to device memory, but the DMA to final destination may not have completed. + + HIP_CALL_THROW(hipStreamSynchronize(0)); + } + // Used by onnxruntime_pybind_state.cc - void rocmMemcpy_DeviceToHost(void* dst, const void* src, size_t count) override { HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost)); } + void rocmMemcpy_DeviceToHost(void* dst, const void* src, size_t count) override { + // For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed. + HIP_CALL_THROW(hipMemcpy(dst, src, count, hipMemcpyDeviceToHost)); + } int hipGetDeviceCount() override { int num_devices = 0; @@ -152,10 +165,9 @@ struct ProviderInfo_ROCM_Impl : ProviderInfo_ROCM { return std::make_shared(info); } - std::shared_ptr CreateRocmAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::ROCMExecutionProviderExternalAllocatorInfo& external_allocator_info, OrtArenaCfg* default_memory_arena_cfg) override { + std::shared_ptr CreateRocmAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::ROCMExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) override { return ROCMExecutionProvider::CreateRocmAllocator(device_id, gpu_mem_limit, arena_extend_strategy, external_allocator_info, default_memory_arena_cfg); } - } g_info; struct ROCM_Provider : Provider { @@ -169,8 +181,8 @@ struct ROCM_Provider : Provider { info.gpu_mem_limit = params->gpu_mem_limit; info.arena_extend_strategy = static_cast(params->arena_extend_strategy); info.miopen_conv_exhaustive_search = params->miopen_conv_exhaustive_search; - info.do_copy_in_default_stream = params->do_copy_in_default_stream; - info.has_user_compute_stream = params->has_user_compute_stream; + info.do_copy_in_default_stream = params->do_copy_in_default_stream != 0; + info.has_user_compute_stream = params->has_user_compute_stream != 0; info.user_compute_stream = params->user_compute_stream; info.default_memory_arena_cfg = params->default_memory_arena_cfg; info.tunable_op.enable = params->tunable_op_enable; @@ -180,21 +192,32 @@ struct ROCM_Provider : Provider { return std::make_shared(info); } + /** + * This function will be called by the C API UpdateROCMProviderOptions(). + * + * What this function does is equivalent to resetting the OrtROCMProviderOptions instance with + * default ROCMExecutionProviderInf instance first and then set up the provided provider options. + * See ROCMExecutionProviderInfo::FromProviderOptions() for more details. + */ void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { - auto info = onnxruntime::ROCMExecutionProviderInfo::FromProviderOptions(options); + auto internal_options = onnxruntime::ROCMExecutionProviderInfo::FromProviderOptions(options); auto& rocm_options = *reinterpret_cast(provider_options); - rocm_options.device_id = info.device_id; - rocm_options.gpu_mem_limit = info.gpu_mem_limit; - rocm_options.arena_extend_strategy = static_cast(info.arena_extend_strategy); - rocm_options.miopen_conv_exhaustive_search = info.miopen_conv_exhaustive_search; - rocm_options.do_copy_in_default_stream = info.do_copy_in_default_stream; - rocm_options.has_user_compute_stream = info.has_user_compute_stream; - rocm_options.user_compute_stream = info.user_compute_stream; - rocm_options.default_memory_arena_cfg = info.default_memory_arena_cfg; - rocm_options.tunable_op_enable = info.tunable_op.enable; - rocm_options.tunable_op_tuning_enable = info.tunable_op.tuning_enable; - rocm_options.tunable_op_max_tuning_duration_ms = info.tunable_op.max_tuning_duration_ms; + rocm_options.device_id = internal_options.device_id; + rocm_options.gpu_mem_limit = internal_options.gpu_mem_limit; + rocm_options.arena_extend_strategy = static_cast(internal_options.arena_extend_strategy); + rocm_options.miopen_conv_exhaustive_search = internal_options.miopen_conv_exhaustive_search; + rocm_options.do_copy_in_default_stream = internal_options.do_copy_in_default_stream; + rocm_options.has_user_compute_stream = internal_options.has_user_compute_stream; + // The 'has_user_compute_stream' of the OrtROCMProviderOptions instance can be set by C API UpdateROCMProviderOptionsWithValue() as well. + // We only set the 'has_user_compute_stream' of the OrtROCMProviderOptions instance if it is provided in options + if (options.find("has_user_compute_stream") != options.end()) { + rocm_options.user_compute_stream = internal_options.user_compute_stream; + } + rocm_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg; + rocm_options.tunable_op_enable = internal_options.tunable_op.enable; + rocm_options.tunable_op_tuning_enable = internal_options.tunable_op.tuning_enable; + rocm_options.tunable_op_max_tuning_duration_ms = internal_options.tunable_op.max_tuning_duration_ms; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.h b/onnxruntime/core/providers/rocm/rocm_provider_factory.h index 8cd7bd3573..80b887af4e 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.h +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.h @@ -3,6 +3,7 @@ #include "onnxruntime_c_api.h" #include "core/framework/provider_options.h" +#include "core/common/common.h" namespace onnxruntime { class IAllocator; @@ -43,7 +44,16 @@ struct ProviderInfo_ROCM { #endif virtual std::shared_ptr CreateExecutionProviderFactory(const onnxruntime::ROCMExecutionProviderInfo& info) = 0; - virtual std::shared_ptr CreateRocmAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::ROCMExecutionProviderExternalAllocatorInfo& external_allocator_info, OrtArenaCfg* default_memory_arena_cfg) = 0; + virtual std::shared_ptr CreateRocmAllocator(int16_t device_id, size_t gpu_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy, onnxruntime::ROCMExecutionProviderExternalAllocatorInfo& external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) = 0; + + // This function is the entry point to ROCM EP's UT cases. + // All tests ared only called from onnxruntime_test_all. + virtual void TestAll() { + ORT_NOT_IMPLEMENTED(__FUNCTION__, " is only implements in test code path."); + } + + protected: + ~ProviderInfo_ROCM() = default; // Can only be destroyed through a subclass instance }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.cc b/onnxruntime/core/providers/rocm/rocm_stream_handle.cc index 0d9877e6b1..670aae91ca 100644 --- a/onnxruntime/core/providers/rocm/rocm_stream_handle.cc +++ b/onnxruntime/core/providers/rocm/rocm_stream_handle.cc @@ -1,7 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/providers/rocm/rocm_resource.h" #include "core/providers/rocm/rocm_stream_handle.h" #include "core/providers/rocm/rocm_common.h" // #include "core/common/spin_pause.h" -#include "core/providers/rocm/rocm_resource.h" namespace onnxruntime { @@ -82,15 +84,29 @@ void RocmStream::EnqueDeferredCPUBuffer(void* cpu_buffer) { deferred_cpu_buffers_.push_back(cpu_buffer); } -struct CpuBuffersInfo { // TODO: should be moved to base class +struct CpuBuffersInfo { + // This struct stores the information needed + // to release CPU buffers allocated for GPU kernels. + // It's used to enqueue their release after + // associated GPU kernels in a ROCM stream. + + // This is a CPU allocator in ROCM EP. + // It must be the one used to allocate the + // following pointers. AllocatorPtr allocator; + // buffers[i] is the i-th pointer added by + // AddDeferredReleaseCPUPtr for a specific + // ROCM stream. For example, this fields + // should contain all values in + // deferred_release_buffer_pool_[my_stream] + // when release my_stream's buffers. std::unique_ptr buffers; // CPU buffer buffers[i]. // Number of buffer points in "buffers". size_t n_buffers; }; -static void ReleaseCpuBufferCallback(hipStream_t /*stream*/, hipError_t /*status*/, void* raw_info) { // TODO: should be moved to base class +static void ReleaseCpuBufferCallback(void* raw_info) { std::unique_ptr info = std::make_unique(); info.reset(reinterpret_cast(raw_info)); for (size_t i = 0; i < info->n_buffers; ++i) { @@ -111,14 +127,7 @@ Status RocmStream::CleanUpOnRunEnd() { cpu_buffers_info->buffers[i] = deferred_cpu_buffers_.at(i); } cpu_buffers_info->n_buffers = deferred_cpu_buffers_.size(); - // TODO(wechi): CUDA deprecates cudaStreamAddCallback and - // uses another API, cudaLaunchHostFunc(which can be - // captured in CUDA graph). Once AMD adds similar feature, - // we should replace the following line with - // hipLaunchHostFunc(stream, ReleaseCpuBufferCallback, cpu_buffers_info); - - // Release memory asynchronously to avoid blocking the compute stream. - HIP_RETURN_IF_ERROR(hipStreamAddCallback(static_cast(GetHandle()), ReleaseCpuBufferCallback, cpu_buffers_info.release(), 0)); + HIP_RETURN_IF_ERROR(hipLaunchHostFunc(static_cast(GetHandle()), ReleaseCpuBufferCallback, cpu_buffers_info.release())); } else { HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(GetHandle()))); for (auto* buffer : deferred_cpu_buffers_) { @@ -130,10 +139,10 @@ Status RocmStream::CleanUpOnRunEnd() { return Status::OK(); } -void* RocmStream::GetResource(int version, int type) const { +void* RocmStream::GetResource(int version, int id) const { ORT_ENFORCE(version <= ORT_ROCM_RESOUCE_VERSION, "resource version unsupported!"); void* resource{}; - switch (type) { + switch (id) { case RocmResource::hip_stream_t: return reinterpret_cast(GetHandle()); break; @@ -149,6 +158,7 @@ void* RocmStream::GetResource(int version, int type) const { return resource; } +// CPU Stream command handles void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification) { static_cast(¬ification)->wait_on_device(stream); } diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.h b/onnxruntime/core/providers/rocm/rocm_stream_handle.h index 865cff0abf..1f3e5b7554 100644 --- a/onnxruntime/core/providers/rocm/rocm_stream_handle.h +++ b/onnxruntime/core/providers/rocm/rocm_stream_handle.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include "core/providers/rocm/rocm_pch.h" // #include "core/providers/cuda/shared_inc/cuda_utils.h" @@ -17,14 +20,12 @@ struct RocmStream : Stream { ~RocmStream(); - std::unique_ptr CreateNotification(size_t num_consumers) override; + std::unique_ptr CreateNotification(size_t /*num_consumers*/) override; void Flush() override; Status CleanUpOnRunEnd() override; - void* GetResource(int version, int id) const override; - void EnqueDeferredCPUBuffer(void* cpu_buffer); bool own_stream_{true}; @@ -33,6 +34,8 @@ struct RocmStream : Stream { rocblas_handle rocblas_handle_{}; + void* GetResource(int version, int id) const override; + private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; diff --git a/onnxruntime/core/providers/rocm/rocm_utils.cu b/onnxruntime/core/providers/rocm/rocm_utils.cu index cbf410e78a..b817e025ce 100644 --- a/onnxruntime/core/providers/rocm/rocm_utils.cu +++ b/onnxruntime/core/providers/rocm/rocm_utils.cu @@ -30,13 +30,14 @@ template void Fill(hipStream_t stream, T* output, T value, int64_t count) { int blocksPerGrid = static_cast(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); HIP_LONG N = static_cast(count); - _Fill<<>>(output, value, N); + _Fill + <<>>(output, value, N); } template class ConstantBufferImpl : public IConstantBuffer { public: - ConstantBufferImpl(T val) : buffer_(nullptr), count_(0), val_(val) {} - + ConstantBufferImpl(T val) : buffer_(nullptr), count_(0), val_(val) { + } ~ConstantBufferImpl() { if (buffer_) HIP_CALL_THROW(hipFree(buffer_)); @@ -70,6 +71,7 @@ std::unique_ptr> CreateConstantOnes() { template std::unique_ptr> CreateConstantOnes(); template std::unique_ptr> CreateConstantOnes(); template std::unique_ptr> CreateConstantOnes(); +template std::unique_ptr> CreateConstantOnes(); #define SPECIALIZED_FILL(T) \ template void Fill(hipStream_t stream, T * output, T value, int64_t count); @@ -81,6 +83,7 @@ SPECIALIZED_FILL(int64_t) SPECIALIZED_FILL(float) SPECIALIZED_FILL(double) SPECIALIZED_FILL(__half) +SPECIALIZED_FILL(BFloat16) } // namespace rocm } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/shared_inc/fast_divmod.h b/onnxruntime/core/providers/rocm/shared_inc/fast_divmod.h deleted file mode 100644 index 83ca0a443c..0000000000 --- a/onnxruntime/core/providers/rocm/shared_inc/fast_divmod.h +++ /dev/null @@ -1,90 +0,0 @@ -// -// Copyright (c) 2016, NVIDIA CORPORATION. All rights reserved -// Licensed under the MIT license. See LICENSE.md file in the project root for full license information. -// - -#pragma once - -#include -#include -#include -#include -#include "core/common/common.h" - -namespace onnxruntime { -namespace rocm { - -// DivMod is a helper class for integer division and modulo operation. -// There is a fast version for int type and a slow version for other type. -template -struct DivMod { - DivMod(T d = 1) { - d_ = d == 0 ? 1 : d; - ORT_ENFORCE(d_ >= 1 && d_ <= std::numeric_limits::max()); - } - - __host__ __device__ inline T div(T n) const { - return n / d_; - } - - __host__ __device__ inline T mod(T n) const { - return n % d_; - } - - __host__ __device__ inline void divmod(T n, T& q, T& r) const { - q = div(n); - r = n - q * d_; - } - - T d_; // divisor -}; - -// The code below is based on section 4 Unsigned division of paper https://gmplib.org/~tege/divcnst-pldi94.pdf -// In current ORT, fast_divmod is used for calculating the position of a element in tensor, -// so unsigned integer division from the paper is good enough for ORT. The advantage is that div is very simple, -// then GPU compiler can do loop unroll easilly when divmod is called in a loop. -template <> -struct DivMod { - DivMod(int d = 1) { - d_ = d == 0 ? 1 : d; - ORT_ENFORCE(d_ >= 1 && d_ <= static_cast(std::numeric_limits::max())); - - for (l_ = 0; l_ < 32; l_++) - if ((1U << l_) >= d_) break; - - uint64_t one = 1; - uint64_t m = ((one << 32) * ((one << l_) - d_)) / d_ + 1; - M_ = static_cast(m); - // according to paper, the value of m' should fit in a unsigned integer. - ORT_ENFORCE(M_ > 0 && M_ == m); - } - - __host__ __device__ inline int div(int n) const { -#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) - uint32_t t = __umulhi(M_, n); - return (t + n) >> l_; -#else - // Using uint64_t for t, then t + n won't overflow. - uint64_t t = ((uint64_t)M_ * n) >> 32; - return static_cast((t + n) >> l_); -#endif - } - - __host__ __device__ inline int mod(int n) const { - return n - div(n) * d_; - } - - __host__ __device__ inline void divmod(int n, int& q, int& r) const { - q = div(n); - r = n - q * d_; - } - - uint32_t d_; // divisor - uint32_t M_; // m' in the paper. - uint32_t l_; // l_ = ceil(log2(d_)) -}; - -using fast_divmod = DivMod; // Keep the old name for backward compatibility. - -} // namespace rocm -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h b/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h index d6623ef63f..b6b40666b8 100644 --- a/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h +++ b/onnxruntime/core/providers/rocm/shared_inc/rocm_call.h @@ -17,16 +17,20 @@ std::conditional_t RocmCall( #define HIP_CALL(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define ROCBLAS_CALL(expr) (RocmCall((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__)) + #define HIPSPARSE_CALL(expr) (RocmCall((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define HIPRAND_CALL(expr) (RocmCall((expr), #expr, "HIPRAND", HIPRAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define MIOPEN_CALL(expr) (RocmCall((expr), #expr, "MIOPEN", miopenStatusSuccess, "", __FILE__, __LINE__)) #define MIOPEN_CALL2(expr, m) (RocmCall((expr), #expr, "MIOPEN", miopenStatusSuccess, m, __FILE__, __LINE__)) + #define HIPFFT_CALL(expr) (RocmCall((expr), #expr, "HIPFFT", HIPFFT_SUCCESS, "", __FILE__, __LINE__)) #define HIP_CALL_THROW(expr) (RocmCall((expr), #expr, "HIP", hipSuccess, "", __FILE__, __LINE__)) #define ROCBLAS_CALL_THROW(expr) (RocmCall((expr), #expr, "ROCBLAS", rocblas_status_success, "", __FILE__, __LINE__)) + #define HIPSPARSE_CALL_THROW(expr) (RocmCall((expr), #expr, "HIPSPARSE", HIPSPARSE_STATUS_SUCCESS, "", __FILE__, __LINE__)) #define HIPRAND_CALL_THROW(expr) (RocmCall((expr), #expr, "HIPRAND", HIPRAND_STATUS_SUCCESS, "", __FILE__, __LINE__)) + #define MIOPEN_CALL_THROW(expr) (RocmCall((expr), #expr, "MIOPEN", miopenStatusSuccess, "", __FILE__, __LINE__)) #define MIOPEN_CALL_THROW2(expr, m) (RocmCall((expr), #expr, "MIOPEN", miopenStatusSuccess, m, __FILE__, __LINE__)) #define HIPFFT_CALL_THROW(expr) (RocmCall((expr), #expr, "HIPFFT", HIPFFT_SUCCESS, "", __FILE__, __LINE__)) diff --git a/onnxruntime/test/contrib_ops/element_wise_ops_test.cc b/onnxruntime/test/contrib_ops/element_wise_ops_test.cc index 15e2449cd2..c641103a74 100644 --- a/onnxruntime/test/contrib_ops/element_wise_ops_test.cc +++ b/onnxruntime/test/contrib_ops/element_wise_ops_test.cc @@ -162,13 +162,13 @@ static void RunBiasGeluTestBFloat16(const std::vector& input_dims, cons tester.AddInput("B", bias_dims, bias_data_bf16); tester.AddOutput("C", input_dims, output_data_bf16); std::vector> execution_providers; -#ifdef USE_CUDA +#if defined(USE_CUDA) execution_providers.push_back(DefaultCudaExecutionProvider()); -#elif USE_ROCM +#elif defined(USE_ROCM) execution_providers.push_back(DefaultRocmExecutionProvider()); -#elif USE_DNNL +#elif defined(USE_DNNL) execution_providers.push_back(DefaultDnnlExecutionProvider()); -#elif USE_DML +#elif defined(USE_DML) execution_providers.push_back(DefaultDmlExecutionProvider()); #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); @@ -197,9 +197,8 @@ TEST(BiasGeluTest, BFloat16) { } #endif +#if defined(USE_CUDA) || defined(USE_ROCM) TEST(MathOpTest, ComplexMul) { - if (DefaultCudaExecutionProvider() == nullptr) return; - std::vector input_a_data = { -0.5f, 0.6f}; @@ -219,13 +218,15 @@ TEST(MathOpTest, ComplexMul) { tester.AddOutput("C", {4, 2}, output_data); std::vector> execution_providers; +#if defined(USE_CUDA) execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif defined(USE_ROCM) + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } TEST(MathOpTest, ComplexMulConj) { - if (DefaultCudaExecutionProvider() == nullptr) return; - std::vector input_a_data = { -0.5f, 0.6f}; @@ -245,13 +246,15 @@ TEST(MathOpTest, ComplexMulConj) { tester.AddOutput("C", {4, 2}, output_data); std::vector> execution_providers; +#ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif defined(USE_ROCM) + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } TEST(MathOpTest, ComplexMul_fp16) { - if (DefaultCudaExecutionProvider() == nullptr) return; - std::vector input_a_data = { MLFloat16(-0.5f), MLFloat16(0.6f)}; @@ -271,13 +274,15 @@ TEST(MathOpTest, ComplexMul_fp16) { tester.AddOutput("C", {4, 2}, output_data); std::vector> execution_providers; +#ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif defined(USE_ROCM) + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } TEST(MathOpTest, ComplexMulConj_fp16) { - if (DefaultCudaExecutionProvider() == nullptr) return; - std::vector input_a_data = { MLFloat16(-0.5f), MLFloat16(0.6f)}; @@ -297,9 +302,14 @@ TEST(MathOpTest, ComplexMulConj_fp16) { tester.AddOutput("C", {4, 2}, output_data); std::vector> execution_providers; +#ifdef USE_CUDA execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif defined(USE_ROCM) + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/fft_op_test.cc b/onnxruntime/test/contrib_ops/fft_op_test.cc index eaadb95c8a..56a6466c76 100644 --- a/onnxruntime/test/contrib_ops/fft_op_test.cc +++ b/onnxruntime/test/contrib_ops/fft_op_test.cc @@ -8,7 +8,15 @@ namespace onnxruntime { namespace test { TEST(ContribOpTest, Rfft) { - if (DefaultCudaExecutionProvider() == nullptr) return; + if (DefaultCudaExecutionProvider() == nullptr && DefaultRocmExecutionProvider() == nullptr) return; + + std::vector> execution_providers; + if (DefaultCudaExecutionProvider() != nullptr) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + if (DefaultRocmExecutionProvider() != nullptr) { + execution_providers.push_back(DefaultRocmExecutionProvider()); + } OpTester test("Rfft", 1, onnxruntime::kMSDomain); test.AddAttribute("signal_ndim", static_cast(1)); @@ -17,13 +25,19 @@ TEST(ContribOpTest, Rfft) { // Target values conputed using PyTorch torch.fft.rfft(X, dim=-1, norm="backward") test.AddInput("X", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f}); test.AddOutput("Y", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f}); - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } TEST(ContribOpTest, Irfft) { - if (DefaultCudaExecutionProvider() == nullptr) return; + if (DefaultCudaExecutionProvider() == nullptr && DefaultRocmExecutionProvider() == nullptr) return; + + std::vector> execution_providers; + if (DefaultCudaExecutionProvider() != nullptr) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + if (DefaultRocmExecutionProvider() != nullptr) { + execution_providers.push_back(DefaultRocmExecutionProvider()); + } OpTester test("Irfft", 1, onnxruntime::kMSDomain); test.AddAttribute("signal_ndim", static_cast(1)); @@ -31,8 +45,6 @@ TEST(ContribOpTest, Irfft) { test.AddAttribute("normalized", static_cast(0)); test.AddInput("X", {4, 3, 2}, {0.0400f, 0.0000f, 1.6919f, -2.5154f, -0.1722f, 0.0000f, 0.2627f, 0.0000f, -0.4218f, 1.4748f, 1.2454f, 0.0000f, -1.7779f, 0.0000f, 3.8766f, -1.8512f, -0.9730f, 0.0000f, 0.9740f, 0.0000f, -1.4290f, 0.8341f, -4.8699f, 0.0000f}); test.AddOutput("Y", {4, 4}, {0.8129f, 1.3108f, -0.8790f, -1.2046f, 0.1661f, -0.9831f, 0.5879f, 0.4918f, 1.2506f, 0.7244f, -2.6260f, -1.1268f, -1.6885f, 1.0439f, -0.2595f, 1.8780f}); - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } // namespace test diff --git a/onnxruntime/test/contrib_ops/greedy_search_test.cc b/onnxruntime/test/contrib_ops/greedy_search_test.cc index f5259c1391..1baf50c1ba 100644 --- a/onnxruntime/test/contrib_ops/greedy_search_test.cc +++ b/onnxruntime/test/contrib_ops/greedy_search_test.cc @@ -50,12 +50,26 @@ TEST(GreedySearchTest, GptGreedySearchFp16_VocabPadded) { const char* input_names[] = {"input_ids", "max_length", "min_length", "repetition_penalty"}; const char* const output_names[] = {"sequences"}; - constexpr int min_cuda_architecture = 530; - if (HasCudaEnvironment(min_cuda_architecture)) { - Ort::SessionOptions session_options; #ifdef USE_CUDA - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + constexpr int min_cuda_architecture = 530; + bool is_cuda = HasCudaEnvironment(min_cuda_architecture); +#else + bool is_cuda = false; #endif +#ifdef USE_ROCM + bool is_rocm = true; +#else + bool is_rocm = false; +#endif + + if (is_cuda || is_rocm) { + Ort::SessionOptions session_options; + if (is_cuda) { + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + } + if (is_rocm) { + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0)); + } // The following model was obtained by padding the vocabulary size in testdata/transformers/tiny_gpt2_beamsearch_fp16.onnx // (by making beam_size == 1) from 1000 to 1600 (just for illustrative and testing purposes) to see if the greedy search @@ -117,12 +131,26 @@ TEST(GreedySearchTest, GptGreedySearchFp32) { const char* input_names[] = {"input_ids", "max_length", "min_length", "repetition_penalty"}; const char* const output_names[] = {"sequences"}; - constexpr int min_cuda_architecture = 530; - if (HasCudaEnvironment(min_cuda_architecture)) { - Ort::SessionOptions session_options; #ifdef USE_CUDA - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + constexpr int min_cuda_architecture = 530; + bool is_cuda = HasCudaEnvironment(min_cuda_architecture); +#else + bool is_cuda = false; #endif +#ifdef USE_ROCM + bool is_rocm = true; +#else + bool is_rocm = false; +#endif + + if (is_cuda || is_rocm) { + Ort::SessionOptions session_options; + if (is_cuda) { + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + } + if (is_rocm) { + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0)); + } Ort::Session session(*ort_env, ORT_TSTR("testdata/transformers/tiny_gpt2_greedysearch_with_init_decoder.onnx"), session_options); diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index 8dcf632192..9c0b779870 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -358,7 +358,11 @@ void RunTest(int64_t max_iterations, // we want the CUDA provider to be first, and the CPU provider second. all except the Loop node should run on // CUDA given that, which creates the scenario where we need to copy to/from CPU to execute the Loop node correctly. std::vector> execution_providers; +#if defined(USE_CUDA) execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif defined(USE_ROCM) + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(expect_result, failure_message, {kTensorrtExecutionProvider}, nullptr, &execution_providers); @@ -1038,8 +1042,8 @@ TEST(Loop, IterationCountAsOutput) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -#ifdef USE_CUDA -// test that when part of the subgraph run on CUDA it executes successfully +#if defined(USE_CUDA) || defined(USE_ROCM) +// test that when part of the subgraph run on CUDA/ROCm it executes successfully TEST(Loop, MixedExecutionProviders) { RunOptions options{}; options.mixed_execution_providers = true; diff --git a/onnxruntime/test/providers/cpu/controlflow/scan_test.cc b/onnxruntime/test/providers/cpu/controlflow/scan_test.cc index 8008fd129c..3d46893cdb 100644 --- a/onnxruntime/test/providers/cpu/controlflow/scan_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/scan_test.cc @@ -411,7 +411,11 @@ static void RunTest_v9(const std::string test_name, int64_t sequence_len, int64_ // we want the CUDA provider to be first, and the CPU provider second. all except the Scan node should run on // CUDA given that, which creates the scenario where we need to copy to/from CPU to execute the Scan node correctly. std::vector> execution_providers; +#if defined(USE_CUDA) execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif defined(USE_ROCM) + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(expect_result, failure_message, options.excluded_provider_types, nullptr, &execution_providers); @@ -1162,7 +1166,11 @@ void UnknownDimInSubgraphOutput(bool is_v8, bool mixed_execution_providers = fal // we want the CUDA provider to be first, and the CPU provider second. all except the Scan node should run on // CUDA given that, which creates the scenario where we need to copy to/from CPU to execute the Scan node correctly. std::vector> execution_providers; +#if defined(USE_CUDA) execution_providers.push_back(DefaultCudaExecutionProvider()); +#elif defined(USE_ROCM) + execution_providers.push_back(DefaultRocmExecutionProvider()); +#endif execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", RunOptions().excluded_provider_types, nullptr, @@ -1174,7 +1182,7 @@ void UnknownDimInSubgraphOutput(bool is_v8, bool mixed_execution_providers = fal TEST_8_AND_9(UnknownDimInSubgraphOutput); -#ifdef USE_CUDA +#if defined(USE_CUDA) || defined(USE_ROCM) TEST(Scan, MixedExecutionProviders) { RunOptions options{}; options.is_v8 = false; diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index e029312804..6f49231752 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -150,6 +150,8 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): # CUFFT -> HIPFFT s = s.replace("CUFFT", "HIPFFT") + s = s.replace("cufftXtMakePlanMany", "hipfftXtMakePlanMany") + s = s.replace("cufftXtExec", "hipfftXtExec") # Undo where above hipify steps went too far. s = s.replace("id, ROCM", "id, CUDA") # cuda_execution_provider.cc @@ -169,6 +171,24 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): s = s.replace("#include ", "#include ") s = s.replace("#include ", "#include ") s = s.replace("#include ", "#include ") + s = s.replace("#include ", "#include ") + s = s.replace('#include "hipfft.h"', "#include ") + s = s.replace('#include "hipfftXt.h"', "#include ") + + # Fix onnxruntime/contrib_ops/rocm/transformers. They include cpu headers which use "cuda" in their names. + s = s.replace("rocm_device_prop_", "cuda_device_prop_") + s = s.replace("rocm_device_arch_", "cuda_device_arch_") + + # We want hipfft, which needs hipDataType etc, but only do this for files that have "fft" in their names + # And we do this last, undoing or fixing hipify mistakes. + if "fft" in src_file_path: + s = s.replace("rocblas_datatype", "hipDataType") + s = s.replace("hipDataType_f32_c", "HIP_C_32F") + s = s.replace("hipDataType_f32_r", "HIP_R_32F") + s = s.replace("hipDataType_f64_c", "HIP_C_64F") + s = s.replace("hipDataType_f64_r", "HIP_R_64F") + s = s.replace("hipDataType_f16_c", "HIP_C_16F") + s = s.replace("hipDataType_f16_r", "HIP_R_16F") with open(dst_file_path, "w") as f: f.write(s)