diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index fcddd2a51e..111033c780 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -157,10 +157,6 @@ set(provider_excluded_files "cuda_execution_provider_info.h" "cuda_execution_provider.cc" "cuda_execution_provider.h" - "cuda_memory_check.cc" - "cuda_memory_check.h" - "cuda_fence.cc" - "cuda_fence.h" "cuda_kernel.h" "cuda_pch.cc" "cuda_pch.h" diff --git a/include/onnxruntime/core/providers/rocm/rocm_resource.h b/include/onnxruntime/core/providers/rocm/rocm_resource.h index f4a2076676..db032b4871 100644 --- a/include/onnxruntime/core/providers/rocm/rocm_resource.h +++ b/include/onnxruntime/core/providers/rocm/rocm_resource.h @@ -8,5 +8,9 @@ enum RocmResource : int { hip_stream_t = rocm_resource_offset, miopen_handle_t, - hipblas_handle_t + hipblas_handle_t, + deferred_cpu_allocator_t, + // below are rocm ep options + device_id_t, // 10004 + arena_extend_strategy_t }; diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index cdb4d1f7ed..b8fe875ba5 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh @@ -5,9 +5,12 @@ #include #include #include +#include #include +#include #include #include +//#include #include "core/providers/rocm/rocm_common.h" #include "core/providers/rocm/shared_inc/rocm_call.h" @@ -242,12 +245,63 @@ __device__ __inline__ double _Pow(double a, double b) { return pow(a, b); } template <> __device__ __inline__ half _Pow(half a, half b) { return half(powf((float)a, (float)b)); } +#define ISNAN_BFLOAT16(v__) static_cast(*reinterpret_cast(&v__) & ~BFloat16::kSignMask) \ + > BFloat16::kPositiveInfinityBits + +// Note that there is no consistent canonical NaN for FP16 and BF16; +// HIP uses 0x7FFF for HIPRT_NAN_BF16, but ONNX Runtime uses 0x7FC1. +// (see BFloat16Impl::kPositiveQNaNBits). +#define NAN_BFLOAT16 BFloat16::FromBits((uint16_t)0x7FFFU) + template __device__ __inline__ T _Min(T a, T b) { return a < b ? a : b; } +template <> +__device__ __inline__ float _Min(float a, float b) { + return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a < b ? a : b ); +} + +template <> +__device__ __inline__ double _Min(double a, double b) { + return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a < b ? a : b ); +} + +template <> +__device__ __inline__ half _Min(half a, half b) { + return __hmin_nan(a, b); +} + +template <> +__device__ __inline__ BFloat16 _Min(BFloat16 a, BFloat16 b) { + return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a < b ? a : b); +} + template __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; } +template <> +__device__ __inline__ float _Max(float a, float b) { + return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a > b ? a : b ); +} + +template <> +__device__ __inline__ double _Max(double a, double b) { + return (isnan(a) || isnan(b)) ? std::numeric_limits::quiet_NaN() : ( a > b ? a : b ); +} + +template <> +__device__ __inline__ half _Max(half a, half b) { + return __hmax_nan(a, b); +} + +template <> +__device__ __inline__ BFloat16 _Max(BFloat16 a, BFloat16 b) { + return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a > b ? a : b); +} + +#undef ISNAN_BFLOAT16 +#undef NAN_BFLOAT16 + template __device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; } @@ -443,36 +497,36 @@ struct _IsNan { template <> struct _IsNan { __device__ __inline__ bool operator()(half a) const { - return static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask) - > MLFloat16::kPositiveInfinityBits; + return static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask) + > MLFloat16::kPositiveInfinityBits; } }; template <> struct _IsNan { __device__ __inline__ bool operator()(BFloat16 a) const { - return static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask) - > BFloat16::kPositiveInfinityBits; + return static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask) + > BFloat16::kPositiveInfinityBits; } }; #if !defined(DISABLE_FLOAT8_TYPES) -template <> +template<> struct _IsNan { __device__ __inline__ bool operator()(Float8E4M3FN a) const { return (*reinterpret_cast(&a) & 0x7f) == 0x7f; } }; -template <> +template<> struct _IsNan { __device__ __inline__ bool operator()(Float8E4M3FNUZ a) const { return *reinterpret_cast(&a) == 0x80; } }; -template <> +template<> struct _IsNan { __device__ __inline__ bool operator()(Float8E5M2 a) const { uint8_t c = *reinterpret_cast(&a); @@ -480,7 +534,7 @@ struct _IsNan { } }; -template <> +template<> struct _IsNan { __device__ __inline__ bool operator()(Float8E5M2FNUZ a) const { return *reinterpret_cast(&a) == 0x80; diff --git a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu index 94bee88a46..e1c89a386d 100644 --- a/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu +++ b/onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops_diagonal.cu @@ -1,4 +1,3 @@ -#include "hip/hip_runtime.h" // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc index a1f5eba9a2..1340c49c38 100644 --- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc @@ -16,140 +16,29 @@ using namespace onnxruntime::common; namespace onnxruntime { namespace rocm { -// opset 11 explicitly added support for negative axis. implementation already allowed it. -#define REGISTER_KERNEL_TYPED(name, T) \ +#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \ ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ name, \ kOnnxDomain, \ - 1, 10, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 12, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ + 1, end, \ T, \ kRocmExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ name); -#define REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 11, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 12, 12, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + name, \ + kOnnxDomain, \ + version, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()).InputMemoryType(OrtMemTypeCPUInput, 1), \ name); -// Register those with changes in OpSet12. -#define REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(name, T) \ - REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -#define REGISTER_KERNEL_VERSIONED_TYPED_13(name, T) \ - REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, 13, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -// Register ReduceMin int64_t support in OpSet14. -#define REGISTER_KERNEL_TYPED_14(name, T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 14, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -// ROCM ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet -#define REGISTER_KERNEL_VERSIONED_TYPED_11(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 11, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); - -// Register with the latest version 13 -#define REGISTER_KERNEL_TYPED_13(name, T) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 11, 12, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - name, \ - kOnnxDomain, \ - 13, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .InputMemoryType(OrtMemTypeCPUInput, 1) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - name); +#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \ + REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \ + REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur) // TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored. template @@ -348,7 +237,9 @@ Status ReduceKernel::ReduceKernelShared( // double* Y, // const TensorShape& output_shape, // miopenReduceTensorOp_t miopen_reduce_op, -// std::vector& output_dims) const; +// miopenHandle_t miopen_handle, +// onnxruntime::Stream* stream, +// TensorShapeVector& output_dims) const; template Status ReduceKernel::ReduceKernelShared( const float* X, @@ -387,7 +278,7 @@ Status PrepareForReduce(const Tensor* X, } const auto input_dims = input_shape.GetDims(); - InlinedVector reduced(rank, false); + std::vector reduced(rank, false); if (axes.size() > 0) { prepare_reduce_metadata.output_dims = input_shape.AsShapeVector(); for (auto axis : axes) { @@ -724,11 +615,35 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input, return Status::OK(); } +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + +// template Status ReduceComputeCore( +// const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, +// /*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op, +// gsl::span axes, +// bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, +// Stream* ort_stream, +// const TensorShape* input_shape_override); + +template Status ReduceComputeCore( + const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata, + /*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op, + gsl::span axes, + bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction, + Stream* ort_stream, + const TensorShape* input_shape_override); + template template Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, miopenReduceTensorOp_t miopen_reduce_op) const { const Tensor* X = ctx->Input(0); - std::vector axes; + TensorShapeVector axes; size_t num_inputs = ctx->InputCount(); const Tensor* axes_tensor = num_inputs == 2 ? ctx->Input(1) : nullptr; // optional input. may be nullptr. @@ -904,7 +819,7 @@ template std::unique_ptr ReduceCompute axes, // bool keep_dims, bool calculate_log, bool calculate_sqt, bool log_sum_exp, -// bool fast_reduction, const TensorShape* input_shape_override); +// bool fast_reduction, Stream* stream, const TensorShape* input_shape_override); template std::unique_ptr ReduceCompute( const AllocatorPtr& gpu_allocator, miopenReduceTensorOp_t miopen_reduce_op, @@ -915,69 +830,76 @@ template std::unique_ptr ReduceCompute(hipError_t x) { template <> const char* RocmErrString(rocblas_status e) { ORT_IGNORE_RETURN_VALUE(hipDeviceSynchronize()); // void to silence nodiscard - switch (e) { CASE_ENUM_TO_STR(rocblas_status_success); CASE_ENUM_TO_STR(rocblas_status_invalid_handle); @@ -53,6 +52,24 @@ const char* RocmErrString(rocblas_status e) { } } +template <> +const char* RocmErrString(hipblasStatus_t e) { + ORT_IGNORE_RETURN_VALUE(hipDeviceSynchronize()); // void to silence nodiscard + switch (e) { + CASE_ENUM_TO_STR(HIPBLAS_STATUS_SUCCESS); + CASE_ENUM_TO_STR(HIPBLAS_STATUS_NOT_INITIALIZED); + CASE_ENUM_TO_STR(HIPBLAS_STATUS_ALLOC_FAILED); + CASE_ENUM_TO_STR(HIPBLAS_STATUS_INVALID_VALUE); + CASE_ENUM_TO_STR(HIPBLAS_STATUS_ARCH_MISMATCH); + CASE_ENUM_TO_STR(HIPBLAS_STATUS_MAPPING_ERROR); + CASE_ENUM_TO_STR(HIPBLAS_STATUS_EXECUTION_FAILED); + CASE_ENUM_TO_STR(HIPBLAS_STATUS_INTERNAL_ERROR); + CASE_ENUM_TO_STR(HIPBLAS_STATUS_NOT_SUPPORTED); + default: + return "(look for HIPBLAS_STATUS_xxx in hipblas_api.h)"; + } +} + template <> const char* RocmErrString(hiprandStatus_t) { ORT_IGNORE_RETURN_VALUE(hipDeviceSynchronize()); // void to silence nodiscard @@ -76,7 +93,7 @@ const char* RocmErrString(hipfftResult e) { CASE_ENUM_TO_STR(HIPFFT_SETUP_FAILED); CASE_ENUM_TO_STR(HIPFFT_INVALID_SIZE); default: - return "Unknown cufft error status"; + return "Unknown hipfft error status"; } } diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 298d54a996..f36b5e01db 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -10,6 +10,7 @@ #include "core/providers/rocm/rocm_fwd.h" #include "core/providers/rocm/gpu_data_transfer.h" #include "core/providers/rocm/rocm_profiler.h" +#include "core/session/onnxruntime_run_options_config_keys.h" #ifndef DISABLE_CONTRIB_OPS #include "contrib_ops/rocm/rocm_contrib_kernels.h" @@ -43,8 +44,7 @@ class Memcpy final : public OpKernel { // do we support async copy? // The rocmMemCpyAsync will handle the pinned memory and non-pinned memory, // so we don't need the check here. - auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, - Y->Location().device); + auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(*X, *Y, *ctx->GetComputeStream())); return Status::OK(); } else { @@ -89,12 +89,10 @@ class Memcpy final : public OpKernel { Y->Reserve(X_size); for (size_t i = 0; i < X_size; ++i) { const Tensor& source_tensor = X->Get(i); - std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), - alloc); + std::unique_ptr target_tensor = Tensor::Create(source_tensor.DataType(), source_tensor.Shape(), alloc); auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(source_tensor.Location().device, target_tensor->Location().device); - ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(source_tensor, *target_tensor, - *ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(gpu_data_transfer->CopyTensorAsync(source_tensor, *target_tensor, *ctx->GetComputeStream())); Y->Add(std::move(*target_tensor)); } return Status::OK(); @@ -130,8 +128,7 @@ ONNX_OPERATOR_KERNEL_EX( AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId device_id, size_t gpu_mem_limit, ArenaExtendStrategy arena_extend_strategy, - ROCMExecutionProviderExternalAllocatorInfo - external_allocator_info, + ROCMExecutionProviderExternalAllocatorInfo external_allocator_info, const OrtArenaCfg* default_memory_arena_cfg) { if (external_allocator_info.UseExternalAllocator()) { AllocatorCreationInfo default_memory_info( @@ -153,8 +150,7 @@ AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId devi device_id, true, {default_memory_arena_cfg ? *default_memory_arena_cfg - : OrtArenaCfg(gpu_mem_limit, static_cast(arena_extend_strategy), - -1, -1, -1, -1L)}, + : OrtArenaCfg(gpu_mem_limit, static_cast(arena_extend_strategy), -1, -1, -1, -1L)}, // make it stream aware true, // enable cross stream sharing? @@ -165,11 +161,8 @@ AllocatorPtr ROCMExecutionProvider::CreateRocmAllocator(OrtDevice::DeviceId devi } } -ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, hipStream_t stream, - size_t /*gpu_mem_limit*/, - ArenaExtendStrategy /*arena_extend_strategy*/, - ROCMExecutionProviderExternalAllocatorInfo - /*external_allocator_info*/, +ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, hipStream_t stream, size_t /*gpu_mem_limit*/, + ArenaExtendStrategy /*arena_extend_strategy*/, ROCMExecutionProviderExternalAllocatorInfo /*external_allocator_info*/, OrtArenaCfg* /*default_memory_arena_cfg*/) { HIP_CALL_THROW(hipSetDevice(device_id)); @@ -187,32 +180,60 @@ ROCMExecutionProvider::PerThreadContext::~PerThreadContext() { ORT_IGNORE_RETURN_VALUE(MIOPEN_CALL(miopenDestroy(miopen_handle_))); } -bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const { - return regular_run_count_before_graph_capture_ >= min_num_runs_before_hip_graph_capture_; +bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed( + RocmGraphAnnotation_t hip_graph_annotation_id) const { + if (!IsGraphCaptureAllowedOnRun(hip_graph_annotation_id)) { + return false; + } + if (graph_id_to_run_count_.find(hip_graph_annotation_id) == graph_id_to_run_count_.end()) { + return false; + } + return graph_id_to_run_count_.at(hip_graph_annotation_id) >= min_num_runs_before_hip_graph_capture_; } -void ROCMExecutionProvider::PerThreadContext::CaptureBegin(int) { - hip_graph_.Reset(); - hip_graph_.CaptureBegin(0); +bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowedOnRun( + RocmGraphAnnotation_t hip_graph_annotation_id) const { + return hip_graph_.IsGraphCaptureAllowedOnRun(hip_graph_annotation_id); } -void ROCMExecutionProvider::PerThreadContext::CaptureEnd(int) { - hip_graph_.CaptureEnd(0); - is_graph_captured_ = true; +RocmGraphAnnotation_t ROCMExecutionProvider::PerThreadContext::GetRocmGraphAnnotationId( + const onnxruntime::RunOptions& run_options) const { + auto graph_annotation_str = + run_options.GetConfigOptions().GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation); + // If graph annotation is not provided, fall back to the one hip graph per session behavior + RocmGraphAnnotation_t hip_graph_annotation_id = 0; + if (graph_annotation_str.has_value()) { + ORT_ENFORCE(TryParseStringWithClassicLocale(*graph_annotation_str, hip_graph_annotation_id), + "Failed to parse the hip graph annotation id: ", + *graph_annotation_str); + } + + return hip_graph_annotation_id; } -bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured(int) const { - return is_graph_captured_; +void ROCMExecutionProvider::PerThreadContext::CaptureBegin(RocmGraphAnnotation_t hip_graph_annotation_id) { + hip_graph_.CaptureBegin(hip_graph_annotation_id); } -Status ROCMExecutionProvider::PerThreadContext::ReplayGraph(int graph_annotation_id) { - ORT_ENFORCE(IsGraphCaptured(graph_annotation_id)); +void ROCMExecutionProvider::PerThreadContext::CaptureEnd(RocmGraphAnnotation_t hip_graph_annotation_id) { + hip_graph_.CaptureEnd(hip_graph_annotation_id); +} +bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured(RocmGraphAnnotation_t graph_annotation_id) const { + return hip_graph_.IsGraphCaptured(graph_annotation_id); +} + +Status ROCMExecutionProvider::PerThreadContext::ReplayGraph(RocmGraphAnnotation_t graph_annotation_id) { return hip_graph_.Replay(graph_annotation_id); } -void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() { - ++regular_run_count_before_graph_capture_; +void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture( + RocmGraphAnnotation_t hip_graph_annotation_id) { + if (graph_id_to_run_count_.find(hip_graph_annotation_id) == graph_id_to_run_count_.end()) { + graph_id_to_run_count_[hip_graph_annotation_id] = 1; + return; + } + graph_id_to_run_count_[hip_graph_annotation_id]++; } void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) { @@ -237,8 +258,7 @@ void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) { } ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kRocmExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, - info.device_id)}, + : IExecutionProvider{onnxruntime::kRocmExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_{info}, tuning_context_(this, &info_.tunable_op) { HIP_CALL_THROW(hipSetDevice(info_.device_id)); @@ -322,8 +342,7 @@ ROCMExecutionProvider::PerThreadContext& ROCMExecutionProvider::GetPerThreadCont // get or create a context if (context_state_.retired_context_pool.empty()) { context = std::make_shared(info_.device_id, stream_, info_.gpu_mem_limit, - info_.arena_extend_strategy, info_.external_allocator_info, - info_.default_memory_arena_cfg); + info_.arena_extend_strategy, info_.external_allocator_info, info_.default_memory_arena_cfg); } else { context = context_state_.retired_context_pool.back(); context_state_.retired_context_pool.pop_back(); @@ -364,26 +383,28 @@ Status ROCMExecutionProvider::Sync() const { return Status::OK(); } -Status ROCMExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { +Status ROCMExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { // always set ROCM device when session::Run() in case it runs in a worker thread HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId())); - if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && - !GetPerThreadContext().IsGraphCaptured(0)) { - LOGS_DEFAULT(INFO) << "Capturing the hip graph for this model"; - GetPerThreadContext().CaptureBegin(0); + RocmGraphAnnotation_t hip_graph_annotation_id = GetPerThreadContext().GetRocmGraphAnnotationId(run_options); + if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(hip_graph_annotation_id) && + GetPerThreadContext().IsGraphCaptureAllowed(hip_graph_annotation_id)) { + LOGS(*GetLogger(), INFO) << "Capturing the hip graph for this model"; + GetPerThreadContext().CaptureBegin(hip_graph_annotation_id); } return Status::OK(); } -Status ROCMExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { - if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(0)) { - if (GetPerThreadContext().IsGraphCaptureAllowed()) { - GetPerThreadContext().CaptureEnd(0); +Status ROCMExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) { + RocmGraphAnnotation_t hip_graph_annotation_id = GetPerThreadContext().GetRocmGraphAnnotationId(run_options); + if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(hip_graph_annotation_id)) { + if (GetPerThreadContext().IsGraphCaptureAllowed(hip_graph_annotation_id)) { + GetPerThreadContext().CaptureEnd(hip_graph_annotation_id); // HIP work issued to a capturing stream doesn’t actually run on the GPU, // so run the captured graph here to actually execute the work. - ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(0)); + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(hip_graph_annotation_id)); } else { - GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(); + GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(hip_graph_annotation_id); } } @@ -412,18 +433,19 @@ bool ROCMExecutionProvider::IsGraphCaptureEnabled() const { return info_.enable_hip_graph; } -bool ROCMExecutionProvider::IsGraphCaptured(int) const { - return GetPerThreadContext().IsGraphCaptured(0); +bool ROCMExecutionProvider::IsGraphCaptured(int graph_annotation_id) const { + return GetPerThreadContext().IsGraphCaptured(graph_annotation_id); } -Status ROCMExecutionProvider::ReplayGraph(int /*graph_annotation_id*/) { - return GetPerThreadContext().ReplayGraph(0); +Status ROCMExecutionProvider::ReplayGraph(int graph_annotation_id) { + return GetPerThreadContext().ReplayGraph(graph_annotation_id); } namespace rocm { // opset 1 to 9 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyToHost); + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, float, Cos); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, double, Cos); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, MLFloat16, Cos); @@ -482,8 +504,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, - LogSoftmax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, float, Pow); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, double, Pow); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 11, MLFloat16, Pow); @@ -516,32 +537,20 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, float, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Greater); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, float, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, double, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, - LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, - LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, - LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, - LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, float, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, double, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int32_t, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, int64_t, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint32_t, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, uint64_t, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, float, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, double, LessOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, - LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 15, MLFloat16, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, int32_t, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, int64_t, Add); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 12, uint32_t, Add); @@ -597,8 +606,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 10, float, Clip); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Reciprocal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Reciprocal); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, - Reciprocal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Reciprocal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, float, Sqrt); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, double, Sqrt); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 12, MLFloat16, Sqrt); @@ -612,18 +620,12 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, double, Erf); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, MLFloat16, Erf); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, bool, Not); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, - BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, - BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, - BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, float, - BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, double, - BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, MLFloat16, - BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 13, MLFloat16, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, float, LRN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, double, LRN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, MLFloat16, LRN); @@ -631,14 +633,11 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, Conv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, Conv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, - ConvTranspose); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, - ConvTranspose); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ConvTranspose); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ConvTranspose); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, float, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, double, AveragePool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, - AveragePool); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 9, MLFloat16, AveragePool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, GlobalAveragePool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, GlobalAveragePool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, GlobalAveragePool); @@ -651,51 +650,54 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, GlobalMaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, GlobalMaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, int64_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, double, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 10, MLFloat16, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceL1); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceL2); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int64_t, ReduceMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMean); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int64_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, uint8_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, int32_t, ReduceProd); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, float, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, double, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, MLFloat16, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, int32_t, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 12, int64_t, ReduceSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, float, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, double, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceLogSumExp); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, float, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, double, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 6, 8, MLFloat16, Cast); @@ -720,6 +722,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint32_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, uint64_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, bool, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, float, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, double, Pad); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 2, 10, MLFloat16, Pad); @@ -768,7 +771,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, float, Shrink); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, double, Shrink); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, MLFloat16, Shrink); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 9, 12, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, float, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, double, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Less); @@ -832,12 +834,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 19, IsInf); // opset 11 -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, ArgMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ArgMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Compress); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Concat); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Flatten); @@ -851,45 +847,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, Loop); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, NonMaxSuppression); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, Range); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 15, Scan); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, ScatterElements); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 12, int32_t, Slice); @@ -958,7 +915,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom // OpSet 12 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Clip); - class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, float, MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, double, MaxPool); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, MLFloat16, MaxPool); @@ -967,22 +923,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Pow); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, double, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int32_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int64_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int8_t, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, uint8_t, ReduceMax); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int64_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int8_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, uint8_t, ReduceMin); - class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, int64_t, GatherND); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, Dropout); @@ -1037,6 +977,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Neg); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Floor); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Floor); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Floor); @@ -1107,7 +1048,6 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kO class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint32_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, uint64_t, Cast); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, bool, Cast); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 19, IsNaN); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size); @@ -1127,6 +1067,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, U class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Concat); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Gather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, GatherElements); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 19, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul); @@ -1142,50 +1083,36 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, double, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, MLFloat16, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int32_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int64_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, int8_t, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, uint8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL1); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceL2); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceLogSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceLogSumExp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMean); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceProd); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceProd); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, ReduceSum); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, ReduceSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceSumSquare); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Dropout); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Resize); @@ -1281,16 +1208,19 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, float, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, double, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, float, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, double, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, MLFloat16, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int32_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, uint8_t, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, int64_t, ReduceMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kRocmExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kRocmExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kRocmExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, uint8_t, ReduceMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMin); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, Trilu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Add); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 14, BFloat16, Sub); @@ -1314,6 +1244,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, PRelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 18, Scan); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, Where); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, BFloat16, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double_t, Where); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, int32_t, Where); @@ -1335,6 +1266,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterND); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, GridSample); // Opset 17 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); @@ -1343,18 +1275,24 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization); // Opset 18 +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterND); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterND); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Resize); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Resize); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, Resize); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, uint8_t, Resize); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split); // Opset 19 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, float, Cast); @@ -1370,52 +1308,81 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, uint32_t, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, uint64_t, Cast); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, bool, Cast); +// #if !defined(DISABLE_FLOAT8_TYPES) +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E4M3FN, Cast); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Float8E5M2, Cast); +// #endif -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, - float, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, - float, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, - MLFloat16, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, - MLFloat16, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, float, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, float, DequantizeLinear); +// #if !defined(DISABLE_FLOAT8_TYPES) +// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, float, DequantizeLinear); +// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, float, DequantizeLinear); +// #endif +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, MLFloat16, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, MLFloat16, DequantizeLinear); +// #if !defined(DISABLE_FLOAT8_TYPES) +// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, MLFloat16, DequantizeLinear); +// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, MLFloat16, DequantizeLinear); +// #endif class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Identity); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, If); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Loop); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, - float, QuantizeLinear); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, - float, QuantizeLinear); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, - MLFloat16, QuantizeLinear); -class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, - MLFloat16, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, float, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, float, QuantizeLinear); +// #if !defined(DISABLE_FLOAT8_TYPES) +// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, float, QuantizeLinear); +// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, float, QuantizeLinear); +// #endif +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, uint8_t, MLFloat16, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, 20, int8_t, MLFloat16, QuantizeLinear); +// #if !defined(DISABLE_FLOAT8_TYPES) +// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E4M3FN, MLFloat16, QuantizeLinear); +// class ONNX_OPERATOR_VERSIONED_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2, MLFloat16, QuantizeLinear); +// #endif class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Reshape); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Scan); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Shape); // Opset 20 +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, float, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, double, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsInf); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsNaN); -// Opset 21 -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, float, - DequantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, float, - DequantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, MLFloat16, - DequantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, MLFloat16, - DequantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, float, - QuantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, float, - QuantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, MLFloat16, - QuantizeLinear); -class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, MLFloat16, - QuantizeLinear); +// Opset 21. +// TODO(fajin): support other quantized types +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, float, DequantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, float, DequantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, MLFloat16, DequantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, MLFloat16, DequantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, UInt4x2, float, DequantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, Int4x2, float, DequantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, UInt4x2, MLFloat16, DequantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, Int4x2, MLFloat16, DequantizeLinear); +// #if !defined(DISABLE_FLOAT8_TYPES) +// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, float, DequantizeLinear); +// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, float, DequantizeLinear); +// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, MLFloat16, DequantizeLinear); +// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, MLFloat16, DequantizeLinear); +// #endif + +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, float, QuantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, float, QuantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, uint8_t, MLFloat16, QuantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, int8_t, MLFloat16, QuantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, UInt4x2, float, QuantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, Int4x2, float, QuantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, UInt4x2, MLFloat16, QuantizeLinear); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 21, Int4x2, MLFloat16, QuantizeLinear); +// #if !defined(DISABLE_FLOAT8_TYPES) +// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, float, QuantizeLinear); +// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, float, QuantizeLinear); +// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, MLFloat16, QuantizeLinear); +// class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 21, Float8E5M2, MLFloat16, QuantizeLinear); +// #endif template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -1428,6 +1395,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1633,51 +1601,51 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1815,15 +1783,12 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + 19, IsInf)>, // opset 11 - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1837,45 +1802,6 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1949,22 +1875,6 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2012,7 +1922,6 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2020,6 +1929,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2088,6 +1998,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2122,62 +2033,43 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2266,16 +2158,12 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2299,6 +2187,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2328,23 +2217,30 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // Opset 18 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 19 BuildKernelCreateInfo, @@ -2360,11 +2256,23 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, +//#if !defined(DISABLE_FLOAT8_TYPES) +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +//#endif BuildKernelCreateInfo, BuildKernelCreateInfo, +//#if !defined(DISABLE_FLOAT8_TYPES) +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +//#endif BuildKernelCreateInfo, BuildKernelCreateInfo, +//#if !defined(DISABLE_FLOAT8_TYPES) +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +//#endif BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2372,26 +2280,58 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, +//#if !defined(DISABLE_FLOAT8_TYPES) +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +//#endif BuildKernelCreateInfo, BuildKernelCreateInfo, +//#if !defined(DISABLE_FLOAT8_TYPES) +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +//#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // opset 20 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // opset 21 + // TODO(fajin): support other quantized types BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +//#if !defined(DISABLE_FLOAT8_TYPES) +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +//#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +//#if !defined(DISABLE_FLOAT8_TYPES) +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +// BuildKernelCreateInfo, +//#endif }; for (auto& function_table_entry : function_table) { @@ -2456,6 +2396,9 @@ std::vector> ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup) const { InlinedVector candidates; + // A subset of the above vector. A subset of the tentative_nodes might be moved to CPU. + InlinedVector tentative_nodes; + const logging::Logger& logger = *GetLogger(); for (auto& node_index : graph.GetNodesInTopologicalOrder()) { const auto* p_node = graph.GetNode(node_index); if (p_node == nullptr) @@ -2463,13 +2406,16 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, const auto& node = *p_node; if (!node.GetExecutionProviderType().empty()) { + if (node.GetExecutionProviderType() == kRocmExecutionProvider) { + candidates.push_back(node.Index()); + } continue; } const KernelCreateInfo* rocm_kernel_def = kernel_lookup.LookUpKernel(node); // none of the provided registries has a ROCM kernel for this node if (rocm_kernel_def == nullptr) { - LOGS_DEFAULT(INFO) << "ROCM kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name(); + LOGS(logger, INFO) << "ROCM kernel not found in registries for Op type: " << node.OpType() << " node name: " << node.Name(); continue; } @@ -2487,9 +2433,10 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, if (!force_inside && not_supported) { if (not_supported) { - LOGS_DEFAULT(WARNING) << "ROCM kernel not supported. Fallback to CPU execution provider for Op type: " << node.OpType() << " node name: " << node.Name(); + LOGS(logger, WARNING) << "ROCM kernel not supported. Fallback to CPU execution provider for Op type: " << node.OpType() << " node name: " << node.Name(); } } else { + tentative_nodes.push_back(node.Index()); candidates.push_back(node.Index()); } } @@ -2497,7 +2444,7 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // For ROCM EP, exclude the subgraph that is preferred to be placed in CPU // These are usually shape related computation subgraphs // Following logic can be extended for other EPs - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, candidates); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) @@ -2521,7 +2468,8 @@ void ROCMExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_, use_ep_level_unified_stream_, GetPerThreadContext().MiopenHandle(), - GetPerThreadContext().HipblasHandle()); + GetPerThreadContext().HipblasHandle(), + info_); } OrtDevice ROCMExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index 7de6ef79fa..3caff88fe9 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -45,6 +45,12 @@ class ROCMExecutionProvider : public IExecutionProvider { return GetPerThreadContext().MiopenHandle(); } + hipStream_t ComputeStream() { + // this will return the ROCM EP level stream which can differ from the actual compute tasks stream + // the compute task stream is supplied within OpKernelContext during inference + return stream_; + } + template const T* GetConstOnes(size_t count, hipStream_t stream) { return GetPerThreadContext().template GetConstOnes(count, stream); @@ -75,8 +81,8 @@ class ROCMExecutionProvider : public IExecutionProvider { std::unique_ptr GetProfiler() override; bool IsGraphCaptureEnabled() const override; - bool IsGraphCaptured(int graph_annotation_id) const override; - Status ReplayGraph(int graph_annotation_id) override; + bool IsGraphCaptured(RocmGraphAnnotation_t graph_annotation_id) const override; + Status ReplayGraph(RocmGraphAnnotation_t graph_annotation_id) override; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override; @@ -98,6 +104,7 @@ class ROCMExecutionProvider : public IExecutionProvider { PerThreadContext(OrtDevice::DeviceId device_id, hipStream_t stream, size_t rocm_mem_limit, ArenaExtendStrategy arena_extend_strategy, ROCMExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg); ~PerThreadContext(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); hipblasHandle_t HipblasHandle() const { return hipblas_handle_; @@ -138,12 +145,14 @@ class ROCMExecutionProvider : public IExecutionProvider { } } - bool IsGraphCaptureAllowed() const; - void CaptureBegin(int graph_annotation_id); - void CaptureEnd(int graph_annotation_id); - bool IsGraphCaptured(int graph_annotation_id) const; - Status ReplayGraph(int graph_annotation_id); - void IncrementRegularRunCountBeforeGraphCapture(); + bool IsGraphCaptureAllowed(RocmGraphAnnotation_t hip_graph_annotation_id) const; + bool IsGraphCaptureAllowedOnRun(RocmGraphAnnotation_t hip_graph_annotation_id) const; + void CaptureBegin(RocmGraphAnnotation_t hip_graph_annotation_id); + void CaptureEnd(RocmGraphAnnotation_t hip_graph_annotation_id); + bool IsGraphCaptured(RocmGraphAnnotation_t hip_graph_annotation_id) const; + RocmGraphAnnotation_t GetRocmGraphAnnotationId(const onnxruntime::RunOptions& run_options) const; + Status ReplayGraph(RocmGraphAnnotation_t hip_graph_annotation_id); + void IncrementRegularRunCountBeforeGraphCapture(RocmGraphAnnotation_t hip_graph_annotation_id); private: hipblasHandle_t hipblas_handle_ = nullptr; @@ -157,8 +166,8 @@ class ROCMExecutionProvider : public IExecutionProvider { // Hip graph with multi threads will be supported in the future, so hip_graph_ // is put under PerThreadContext. ROCMGraph hip_graph_; - bool is_graph_captured_ = false; - int regular_run_count_before_graph_capture_ = 0; + // Map of graph id to regular_run_count_before_graph_capture + std::unordered_map graph_id_to_run_count_; // There is chance that the second regular run allocates GPU memory for causes like: // (1) memory pattern is enabled. (2) arena allocation for stream. diff --git a/onnxruntime/core/providers/rocm/rocm_kernel.h b/onnxruntime/core/providers/rocm/rocm_kernel.h index 7276299563..933a72122e 100644 --- a/onnxruntime/core/providers/rocm/rocm_kernel.h +++ b/onnxruntime/core/providers/rocm/rocm_kernel.h @@ -97,14 +97,14 @@ class RocmKernel : public OpKernel { return stream->hipblas_handle_; } - tunable::RocmTuningContext* GetTuningContext() const { - return static_cast(provider_->GetTuningContext()); - } - bool UseTF32() const { return false; } + tunable::RocmTuningContext* GetTuningContext() const { + return static_cast(provider_->GetTuningContext()); + } + // To support hipMemcpyAsync, the cpu memory should be allocated in pinned memory // and it can only be released after the copy has finished template @@ -177,6 +177,12 @@ class RocmKernel : public OpKernel { return provider_->PerThreadDefaultMiopenHandle(); } + inline hipStream_t DefaultHipStream() const { + // this will return the ROCM EP level stream which can differ from the actual compute tasks stream + // the compute task stream is supplied within OpKernelContext during inference + return provider_->ComputeStream(); + } + inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const { auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(src.Location().device, dst.Location().device); return gpu_data_transfer->CopyTensorAsync(src, dst, stream); diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc index fdf64d07e0..170a566d85 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc @@ -185,7 +185,7 @@ struct ROCM_Provider : Provider { info.has_user_compute_stream = params->has_user_compute_stream != 0; info.user_compute_stream = params->user_compute_stream; info.default_memory_arena_cfg = params->default_memory_arena_cfg; - info.enable_hip_graph = params->enable_hip_graph; + info.enable_hip_graph = params->enable_hip_graph != 0; info.tunable_op.enable = params->tunable_op_enable; info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms; diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.cc b/onnxruntime/core/providers/rocm/rocm_stream_handle.cc index c175252df3..bbd1e1befc 100644 --- a/onnxruntime/core/providers/rocm/rocm_stream_handle.cc +++ b/onnxruntime/core/providers/rocm/rocm_stream_handle.cc @@ -7,6 +7,25 @@ namespace onnxruntime { +DeferredCpuAllocator::DeferredCpuAllocator(RocmStream& rocm_stream) : rocm_stream_(rocm_stream) { + OrtAllocator::version = ORT_API_VERSION; + OrtAllocator::Alloc = + [](OrtAllocator* this_, size_t size) { + auto self = reinterpret_cast(this_); + return self->rocm_stream_.GetCpuAllocator()->Alloc(size); + }; + OrtAllocator::Free = + [](OrtAllocator* this_, void* p) { + auto self = reinterpret_cast(this_); + self->rocm_stream_.EnqueDeferredCPUBuffer(p); + }; + OrtAllocator::Info = + [](const OrtAllocator* this_) { + auto self = reinterpret_cast(this_); + return &self->rocm_stream_.GetCpuAllocator()->Info(); + }; +} + struct RocmNotification : public synchronize::Notification { RocmNotification(Stream& s) : Notification(s) { HIP_CALL_THROW(hipEventCreateWithFlags(&event_, hipEventDisableTiming)); @@ -25,7 +44,8 @@ struct RocmNotification : public synchronize::Notification { void wait_on_device(Stream& device_stream) { ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", device_stream.GetDevice().ToString()); // launch a wait command to the rocm stream - HIP_CALL_THROW(hipStreamWaitEvent(static_cast(device_stream.GetHandle()), event_, 0)); + HIP_CALL_THROW(hipStreamWaitEvent(static_cast(device_stream.GetHandle()), + event_, 0)); }; void wait_on_host() { @@ -42,10 +62,13 @@ RocmStream::RocmStream(hipStream_t stream, bool release_cpu_buffer_on_rocm_stream, bool own_flag, miopenHandle_t external_miopen_handle, - hipblasHandle_t external_hipblas_handle) : Stream(stream, device), - own_stream_(own_flag), - cpu_allocator_(cpu_allocator), - release_cpu_buffer_on_rocm_stream_(release_cpu_buffer_on_rocm_stream) { + hipblasHandle_t external_hipblas_handle, + const ROCMExecutionProviderInfo& ep_info) : Stream(stream, device), + own_stream_(own_flag), + cpu_allocator_(cpu_allocator), + release_cpu_buffer_on_rocm_stream_(release_cpu_buffer_on_rocm_stream), + deferred_cpu_allocator_(*this), + ep_info_(ep_info) { if (own_flag) { HIPBLAS_CALL_THROW(hipblasCreate(&hipblas_handle_)); HIPBLAS_CALL_THROW(hipblasSetStream(hipblas_handle_, stream)); @@ -152,6 +175,16 @@ void* RocmStream::GetResource(int version, int id) const { case RocmResource::hipblas_handle_t: return reinterpret_cast(hipblas_handle_); break; + case RocmResource::deferred_cpu_allocator_t: + return const_cast(&deferred_cpu_allocator_); + break; + case RocmResource::device_id_t: + return reinterpret_cast(ep_info_.device_id); + break; + case RocmResource::arena_extend_strategy_t: + return reinterpret_cast(ep_info_.arena_extend_strategy); + break; + break; default: break; } @@ -174,25 +207,28 @@ void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis hipStream_t external_stream, bool use_existing_stream, miopenHandle_t external_miopen_handle, - hipblasHandle_t external_hipblas_handle) { + hipblasHandle_t external_hipblas_handle, + const ROCMExecutionProviderInfo& ep_info) { // wait rocm notification on rocm ep stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitRocmNotificationOnDevice); // wait rocm notification on cpu ep stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitRocmNotificationOnHost); if (!use_existing_stream) - stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_rocm_stream](const OrtDevice& device) { + stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_rocm_stream, ep_info](const OrtDevice& device) { HIP_CALL_THROW(hipSetDevice(device.Id())); hipStream_t stream = nullptr; HIP_CALL_THROW(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); - return std::make_unique(stream, device, cpu_allocator, release_cpu_buffer_on_rocm_stream, true, nullptr, nullptr); + // HIP_CALL_THROW(hipStreamCreate(&stream)); + return std::make_unique(stream, device, cpu_allocator, release_cpu_buffer_on_rocm_stream, true, nullptr, nullptr, ep_info); }); else stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_rocm_stream, external_stream, external_miopen_handle, - external_hipblas_handle](const OrtDevice& device) { - return std::make_unique(external_stream, device, cpu_allocator, release_cpu_buffer_on_rocm_stream, false, external_miopen_handle, external_hipblas_handle); + external_hipblas_handle, + ep_info](const OrtDevice& device) { + return std::make_unique(external_stream, device, cpu_allocator, release_cpu_buffer_on_rocm_stream, false, external_miopen_handle, external_hipblas_handle, ep_info); }); } diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.h b/onnxruntime/core/providers/rocm/rocm_stream_handle.h index 98b8fa8556..320fb4661e 100644 --- a/onnxruntime/core/providers/rocm/rocm_stream_handle.h +++ b/onnxruntime/core/providers/rocm/rocm_stream_handle.h @@ -3,13 +3,21 @@ #pragma once #include "core/providers/rocm/rocm_pch.h" -// #include "core/providers/cuda/shared_inc/cuda_utils.h" +// #include "core/providers/rocm/shared_inc/rocm_utils.h" #include "core/providers/rocm/shared_inc/rocm_call.h" #include "core/framework/stream_handles.h" +#include "core/providers/rocm/rocm_execution_provider_info.h" namespace onnxruntime { + +struct RocmStream; void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification); +struct DeferredCpuAllocator : public OrtAllocator { + DeferredCpuAllocator(RocmStream&); + RocmStream& rocm_stream_; +}; + struct RocmStream : Stream { RocmStream(hipStream_t stream, const OrtDevice& device, @@ -17,7 +25,8 @@ struct RocmStream : Stream { bool release_cpu_buffer_on_rocm_stream, bool own_flag, miopenHandle_t external_miopen_handle, - hipblasHandle_t external_hipblas_handle); + hipblasHandle_t external_hipblas_handle, + const ROCMExecutionProviderInfo& ep_info); ~RocmStream(); @@ -37,12 +46,16 @@ struct RocmStream : Stream { void* GetResource(int version, int id) const override; + onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); } + WaitNotificationFn GetWaitNotificationFn() const override { return WaitRocmNotificationOnDevice; } private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; bool release_cpu_buffer_on_rocm_stream_{true}; + DeferredCpuAllocator deferred_cpu_allocator_; + const ROCMExecutionProviderInfo ep_info_; }; void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, @@ -52,5 +65,6 @@ void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis hipStream_t external_stream, bool use_existing_stream, miopenHandle_t external_miopen_handle, - hipblasHandle_t external_hipblas_handle); + hipblasHandle_t external_hipblas_handle, + const ROCMExecutionProviderInfo& ep_info); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tunable.h b/onnxruntime/core/providers/rocm/tunable/rocm_tunable.h index 580f465c49..95fa4f37d7 100644 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tunable.h +++ b/onnxruntime/core/providers/rocm/tunable/rocm_tunable.h @@ -4,7 +4,6 @@ #pragma once #include -#include #include "core/providers/rocm/rocm_common.h" // avoid provider_api.h ODR violation #include "core/framework/tunable.h" @@ -22,7 +21,6 @@ template using Op = Op; class Timer; - template using TunableOp = TunableOp; diff --git a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc index 05cdc82e90..88e5fde189 100644 --- a/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc +++ b/onnxruntime/core/providers/rocm/tunable/rocm_tuning_context.cc @@ -42,26 +42,6 @@ static Status ValidateRocBlasVersion(const std::string& value) { return Status::OK(); } -std::string RocmTuningResultsValidator::GetDeviceModel() const { - return ep_->GetDeviceProp().name; -} - -Status RocmTuningResultsValidator::ValidateDeviceModel(const std::string& value) const { - auto current = GetDeviceModel(); - ORT_RETURN_IF(current != value, "Device model mismatch: tuning results produced with device ", value, - ", onnxruntime currently run with device ", current); - return Status::OK(); -} - -RocmTuningResultsValidator::RocmTuningResultsValidator(ROCMExecutionProvider* ep) : ep_{ep} { - RegisterValidator("HIP_VERSION", GetHipVersion, ValidateHipVersion); - RegisterValidator("ROCBLAS_VERSION", GetRocBlasVersion, ValidateRocBlasVersion); - RegisterValidator( - "DEVICE_MODEL", - [this]() { return GetDeviceModel(); }, - [this](const std::string& value) { return ValidateDeviceModel(value); }); -} - std::string RocmTuningResultsValidator::GetOrtBuildConfig() const { std::ostringstream oss; #ifdef USE_COMPOSABLE_KERNEL @@ -87,6 +67,26 @@ std::string RocmTuningResultsValidator::GetOrtBuildConfig() const { return oss.str(); } +std::string RocmTuningResultsValidator::GetDeviceModel() const { + return ep_->GetDeviceProp().name; +} + +Status RocmTuningResultsValidator::ValidateDeviceModel(const std::string& value) const { + auto current = GetDeviceModel(); + ORT_RETURN_IF(current != value, "Device model mismatch: tuning results produced with device ", value, + ", onnxruntime currently run with device ", current); + return Status::OK(); +} + +RocmTuningResultsValidator::RocmTuningResultsValidator(ROCMExecutionProvider* ep) : ep_{ep} { + RegisterValidator("HIP_VERSION", GetHipVersion, ValidateHipVersion); + RegisterValidator("ROCBLAS_VERSION", GetRocBlasVersion, ValidateRocBlasVersion); + RegisterValidator( + "DEVICE_MODEL", + [this]() { return GetDeviceModel(); }, + [this](const std::string& value) { return ValidateDeviceModel(value); }); +} + RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info) : ITuningContext(ep), info_(info), validator_(ep) {} diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 6782215fcd..0be1c0b196 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -335,6 +335,7 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod #endif } else if (provider_type == 3) { #ifdef USE_ROCM + std::cout << "Running simple inference with rocm provider" << std::endl; OrtROCMProviderOptions rocm_options; session_options.AppendExecutionProvider_ROCM(rocm_options); #else @@ -384,7 +385,7 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod } static constexpr PATH_TYPE MODEL_URI = TSTR("testdata/mul_1.onnx"); -#if defined(USE_CUDA) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) static constexpr PATH_TYPE CUDA_GRAPH_ANNOTATION_MODEL_URI = TSTR("testdata/mul_1_dynamic.onnx"); #endif static constexpr PATH_TYPE MATMUL_MODEL_URI = TSTR("testdata/matmul_1.onnx"); @@ -2341,7 +2342,7 @@ TEST(CApiTest, basic_cuda_graph) { #endif } -#if defined(USE_CUDA) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) struct CudaGraphInputOutputData_0 { const std::array x_shape = {3, 2}; std::array x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; @@ -2385,6 +2386,12 @@ static void RunWithCudaGraphAnnotation(T& cg_data, Ort::MemoryAllocation& input_data, Ort::MemoryAllocation& output_data, const char* cuda_graph_annotation) { +// a local hipify of select cuda symbols to avoid code duplication +#ifdef USE_ROCM +#define cudaMemcpy hipMemcpy +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#endif #ifdef USE_DML Ort::SessionOptions session_options; Ort::Allocator allocator(session, info_mem); @@ -2488,6 +2495,11 @@ static void RunWithCudaGraphAnnotation(T& cg_data, // Clean up binding.ClearBoundInputs(); binding.ClearBoundOutputs(); +#ifdef USE_ROCM +#undef cudaMemcpy +#undef cudaMemcpyHostToDevice +#undef cudaMemcpyDeviceToHost +#endif } TEST(CApiTest, basic_cuda_graph_with_annotation) { @@ -2502,7 +2514,7 @@ TEST(CApiTest, basic_cuda_graph_with_annotation) { ort_dml_api->SessionOptionsAppendExecutionProvider_DML1(session_options, dml_objects.dml_device.Get(), dml_objects.command_queue.Get()); Ort::MemoryInfo info_mem("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeDefault); -#else +#elif defined(USE_CUDA) // Enable cuda graph in cuda provider option. OrtCUDAProviderOptionsV2* cuda_options = nullptr; ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); @@ -2516,6 +2528,20 @@ TEST(CApiTest, basic_cuda_graph_with_annotation) { static_cast(session_options), rel_cuda_options.get()) == nullptr); Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +#elif defined(USE_ROCM) + // Enable hip graph in rocm provider option. + OrtROCMProviderOptions* rocm_options = nullptr; + ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr); + std::unique_ptr + rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions); + std::vector keys{"enable_hip_graph"}; + std::vector values{"1"}; + ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr); + + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM( + static_cast(session_options), + rel_rocm_options.get()) == nullptr); + Ort::MemoryInfo info_mem("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); #endif Ort::Session session(*ort_env, CUDA_GRAPH_ANNOTATION_MODEL_URI, session_options); diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index 07167b0a61..ff246503e8 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -21,7 +21,10 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): s = s.replace("kCudaStreamCopyIn", "kHipStreamCopyIn") s = s.replace("kCudaStreamCopyOut", "kHipStreamCopyOut") s = s.replace("kTotalCudaStreams", "kTotalHipStreams") - + # these should be "hip" but it's easier to just use rocm to avoid complicated file renaming + s = s.replace("CudaGraph", "RocmGraph") + s = s.replace("CUDAGraph", "ROCMGraph") + s = s.replace("cuda_graph", "rocm_graph") s = s.replace("RegisterCudaContribKernels", "RegisterRocmContribKernels") s = s.replace("cudaEvent", "hipEvent") s = s.replace("CreateCudaAllocator", "CreateRocmAllocator")