[ROCm] redo hipify of version controlled files (#22449)

### Description
Updates the ROCm EP opsets to match the current CUDA EP opsets. Also
enable the test CApiTest.basic_cuda_graph_with_annotation.

Note that some changes are whitespace-only. These changes were made to
improve the comparison of corresponding ROCm and CUDA EP source files
when using a side by side diff tool.

### Motivation and Context
The ROCm EP derives from the CUDA EP. Many source files are shared
between the EPs and "hipified" during the ROCm EP build, however quite a
few files within the ROCm EP are under source control after their
initial hipification. Over time these ROCm EP files get stale relative
to their CUDA EP counterparts. It becomes necessary to re-hipify these
otherwise static files in order to pick up important changes such as
opset differences.
This commit is contained in:
Jeff Daily 2024-10-18 12:40:54 -07:00 committed by GitHub
parent d2a5ee2e5e
commit 5aabc53121
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 754 additions and 722 deletions

View file

@ -157,10 +157,6 @@ set(provider_excluded_files
"cuda_execution_provider_info.h"
"cuda_execution_provider.cc"
"cuda_execution_provider.h"
"cuda_memory_check.cc"
"cuda_memory_check.h"
"cuda_fence.cc"
"cuda_fence.h"
"cuda_kernel.h"
"cuda_pch.cc"
"cuda_pch.h"

View file

@ -8,5 +8,9 @@
enum RocmResource : int {
hip_stream_t = rocm_resource_offset,
miopen_handle_t,
hipblas_handle_t
hipblas_handle_t,
deferred_cpu_allocator_t,
// below are rocm ep options
device_id_t, // 10004
arena_extend_strategy_t
};

View file

@ -5,9 +5,12 @@
#include <stdint.h>
#include <vector>
#include <mutex>
#include <limits>
#include <assert.h>
#include <math.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
//#include <hip/hip_bf16.h>
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/shared_inc/rocm_call.h"
@ -242,12 +245,63 @@ __device__ __inline__ double _Pow(double a, double b) { return pow(a, b); }
template <>
__device__ __inline__ half _Pow(half a, half b) { return half(powf((float)a, (float)b)); }
#define ISNAN_BFLOAT16(v__) static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&v__) & ~BFloat16::kSignMask) \
> BFloat16::kPositiveInfinityBits
// Note that there is no consistent canonical NaN for FP16 and BF16;
// HIP uses 0x7FFF for HIPRT_NAN_BF16, but ONNX Runtime uses 0x7FC1.
// (see BFloat16Impl::kPositiveQNaNBits).
#define NAN_BFLOAT16 BFloat16::FromBits((uint16_t)0x7FFFU)
template <typename T>
__device__ __inline__ T _Min(T a, T b) { return a < b ? a : b; }
template <>
__device__ __inline__ float _Min(float a, float b) {
return (isnan(a) || isnan(b)) ? std::numeric_limits<float>::quiet_NaN() : ( a < b ? a : b );
}
template <>
__device__ __inline__ double _Min(double a, double b) {
return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : ( a < b ? a : b );
}
template <>
__device__ __inline__ half _Min(half a, half b) {
return __hmin_nan(a, b);
}
template <>
__device__ __inline__ BFloat16 _Min(BFloat16 a, BFloat16 b) {
return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a < b ? a : b);
}
template <typename T>
__device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
template <>
__device__ __inline__ float _Max(float a, float b) {
return (isnan(a) || isnan(b)) ? std::numeric_limits<float>::quiet_NaN() : ( a > b ? a : b );
}
template <>
__device__ __inline__ double _Max(double a, double b) {
return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : ( a > b ? a : b );
}
template <>
__device__ __inline__ half _Max(half a, half b) {
return __hmax_nan(a, b);
}
template <>
__device__ __inline__ BFloat16 _Max(BFloat16 a, BFloat16 b) {
return (ISNAN_BFLOAT16(a) || ISNAN_BFLOAT16(b)) ? NAN_BFLOAT16 : (a > b ? a : b);
}
#undef ISNAN_BFLOAT16
#undef NAN_BFLOAT16
template <typename T>
__device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }
@ -443,36 +497,36 @@ struct _IsNan {
template <>
struct _IsNan<half> {
__device__ __inline__ bool operator()(half a) const {
return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~MLFloat16::kSignMask)
> MLFloat16::kPositiveInfinityBits;
return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~MLFloat16::kSignMask)
> MLFloat16::kPositiveInfinityBits;
}
};
template <>
struct _IsNan<BFloat16> {
__device__ __inline__ bool operator()(BFloat16 a) const {
return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~BFloat16::kSignMask)
> BFloat16::kPositiveInfinityBits;
return static_cast<uint16_t>(*reinterpret_cast<const uint16_t*>(&a) & ~BFloat16::kSignMask)
> BFloat16::kPositiveInfinityBits;
}
};
#if !defined(DISABLE_FLOAT8_TYPES)
template <>
template<>
struct _IsNan<Float8E4M3FN> {
__device__ __inline__ bool operator()(Float8E4M3FN a) const {
return (*reinterpret_cast<const uint8_t*>(&a) & 0x7f) == 0x7f;
}
};
template <>
template<>
struct _IsNan<Float8E4M3FNUZ> {
__device__ __inline__ bool operator()(Float8E4M3FNUZ a) const {
return *reinterpret_cast<const uint8_t*>(&a) == 0x80;
}
};
template <>
template<>
struct _IsNan<Float8E5M2> {
__device__ __inline__ bool operator()(Float8E5M2 a) const {
uint8_t c = *reinterpret_cast<const uint8_t*>(&a);
@ -480,7 +534,7 @@ struct _IsNan<Float8E5M2> {
}
};
template <>
template<>
struct _IsNan<Float8E5M2FNUZ> {
__device__ __inline__ bool operator()(Float8E5M2FNUZ a) const {
return *reinterpret_cast<const uint8_t*>(&a) == 0x80;

View file

@ -1,4 +1,3 @@
#include "hip/hip_runtime.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

View file

@ -16,140 +16,29 @@ using namespace onnxruntime::common;
namespace onnxruntime {
namespace rocm {
// opset 11 explicitly added support for negative axis. implementation already allowed it.
#define REGISTER_KERNEL_TYPED(name, T) \
#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
1, 10, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
11, 12, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>); \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
13, \
1, end, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);
#define REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
1, 10, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
11, 11, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
12, 12, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
version, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()).InputMemoryType(OrtMemTypeCPUInput, 1), \
name<T>);
// Register those with changes in OpSet12.
#define REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(name, T) \
REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
13, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);
#define REGISTER_KERNEL_VERSIONED_TYPED_13(name, T) \
REGISTER_KERNEL_VERSIONED_TYPED_12(name, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
13, 13, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);
// Register ReduceMin int64_t support in OpSet14.
#define REGISTER_KERNEL_TYPED_14(name, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
14, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);
// ROCM ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet
#define REGISTER_KERNEL_VERSIONED_TYPED_11(name, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
1, 10, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
11, 11, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);
// Register with the latest version 13
#define REGISTER_KERNEL_TYPED_13(name, T) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
1, 10, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
11, 12, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>); \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
13, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()) \
.InputMemoryType(OrtMemTypeCPUInput, 1) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);
#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \
REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur)
// TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored.
template <bool allow_multi_axes>
@ -348,7 +237,9 @@ Status ReduceKernel<allow_multi_axes>::ReduceKernelShared(
// double* Y,
// const TensorShape& output_shape,
// miopenReduceTensorOp_t miopen_reduce_op,
// std::vector<int64_t>& output_dims) const;
// miopenHandle_t miopen_handle,
// onnxruntime::Stream* stream,
// TensorShapeVector& output_dims) const;
template Status ReduceKernel<true>::ReduceKernelShared<float, float, MIOPEN_REDUCE_TENSOR_NO_INDICES>(
const float* X,
@ -387,7 +278,7 @@ Status PrepareForReduce(const Tensor* X,
}
const auto input_dims = input_shape.GetDims();
InlinedVector<bool> reduced(rank, false);
std::vector<bool> reduced(rank, false);
if (axes.size() > 0) {
prepare_reduce_metadata.output_dims = input_shape.AsShapeVector();
for (auto axis : axes) {
@ -724,11 +615,35 @@ Status ReduceComputeCore(const AllocatorPtr& gpu_allocator, const Tensor& input,
return Status::OK();
}
template Status ReduceComputeCore<float, MIOPEN_REDUCE_TENSOR_NO_INDICES>(
const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata,
/*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op,
gsl::span<const int64_t> axes,
bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction,
Stream* ort_stream,
const TensorShape* input_shape_override);
// template Status ReduceComputeCore<double, MIOPEN_REDUCE_TENSOR_NO_INDICES>(
// const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata,
// /*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op,
// gsl::span<const int64_t> axes,
// bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction,
// Stream* ort_stream,
// const TensorShape* input_shape_override);
template Status ReduceComputeCore<MLFloat16, MIOPEN_REDUCE_TENSOR_NO_INDICES>(
const AllocatorPtr& gpu_allocator, const Tensor& input, PrepareReduceMetadata& prepare_reduce_metadata,
/*out*/ Tensor& output, miopenReduceTensorOp_t miopen_reduce_op,
gsl::span<const int64_t> axes,
bool calculate_log, bool calculate_sqt, bool log_sum_exp, bool fast_reduction,
Stream* ort_stream,
const TensorShape* input_shape_override);
template <bool allow_multi_axes>
template <typename T, miopenReduceTensorIndices_t ReduceTensorIndices>
Status ReduceKernel<allow_multi_axes>::ComputeImpl(OpKernelContext* ctx, miopenReduceTensorOp_t miopen_reduce_op) const {
const Tensor* X = ctx->Input<Tensor>(0);
std::vector<int64_t> axes;
TensorShapeVector axes;
size_t num_inputs = ctx->InputCount();
const Tensor* axes_tensor = num_inputs == 2 ? ctx->Input<Tensor>(1) : nullptr; // optional input. may be nullptr.
@ -904,7 +819,7 @@ template std::unique_ptr<Tensor> ReduceCompute<float, MIOPEN_REDUCE_TENSOR_NO_IN
// AllocatorPtr allocator,
// const Tensor& input, gsl::span<const int64_t> axes,
// bool keep_dims, bool calculate_log, bool calculate_sqt, bool log_sum_exp,
// bool fast_reduction, const TensorShape* input_shape_override);
// bool fast_reduction, Stream* stream, const TensorShape* input_shape_override);
template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, MIOPEN_REDUCE_TENSOR_NO_INDICES>(
const AllocatorPtr& gpu_allocator, miopenReduceTensorOp_t miopen_reduce_op,
@ -915,69 +830,76 @@ template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, MIOPEN_REDUCE_TENSOR_N
} // namespace ReductionOps
#define REGISTER_KERNEL_HFD(name) \
REGISTER_KERNEL_TYPED(name, MLFloat16) \
REGISTER_KERNEL_TYPED(name, float) \
REGISTER_KERNEL_TYPED(name, BFloat16)
// REGISTER_KERNEL_TYPED(name, double)
// ROCM ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, MLFloat16, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, float, 11)
// REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, double, 11)
#define REGISTER_KERNEL_HFD_VERSIONED_11(name) \
REGISTER_KERNEL_VERSIONED_TYPED_11(name, MLFloat16) \
REGISTER_KERNEL_VERSIONED_TYPED_11(name, float)
// REGISTER_KERNEL_VERSIONED_TYPED_11(name, double)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, MLFloat16, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, float, 11)
// REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, double, 11)
REGISTER_KERNEL_HFD_VERSIONED_11(ArgMax)
REGISTER_KERNEL_HFD_VERSIONED_11(ArgMin)
REGISTER_KERNEL_HFD(ReduceL1)
REGISTER_KERNEL_HFD(ReduceL2)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, float, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, int32_t, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, int64_t, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, int8_t, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, uint8_t, 17, 18)
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, MLFloat16)
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, float)
// REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, double)
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, int32_t)
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, int64_t)
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, int8_t)
REGISTER_KERNEL_TYPED_13_WITH_VERSIONED_12(ReduceMax, uint8_t)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, float, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, BFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMean, int32_t, 17, 18)
REGISTER_KERNEL_HFD(ReduceMean)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, float, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, int32_t, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, int64_t, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, int8_t, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMin, uint8_t, 17, 18)
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, MLFloat16)
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, float)
// REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, double)
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, int32_t)
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, int64_t)
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, int8_t)
REGISTER_KERNEL_VERSIONED_TYPED_13(ReduceMin, uint8_t)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, float, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, BFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceProd, int32_t, 17, 18)
REGISTER_KERNEL_TYPED_14(ReduceMin, MLFloat16)
REGISTER_KERNEL_TYPED_14(ReduceMin, float)
// REGISTER_KERNEL_TYPED_14(ReduceMin, double)
REGISTER_KERNEL_TYPED_14(ReduceMin, int32_t)
REGISTER_KERNEL_TYPED_14(ReduceMin, int8_t)
REGISTER_KERNEL_TYPED_14(ReduceMin, uint8_t)
REGISTER_KERNEL_TYPED_14(ReduceMin, int64_t)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, MLFloat16, 12, 13)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, float, 12, 13)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, double, 12, 13)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, int32_t, 12, 13)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, int64_t, 12, 13)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSum, BFloat16, 12, 13)
REGISTER_KERNEL_HFD(ReduceProd)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSum, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSum, float, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSum, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSum, BFloat16, 17, 18)
REGISTER_KERNEL_TYPED_13(ReduceSum, MLFloat16)
REGISTER_KERNEL_TYPED_13(ReduceSum, float)
// REGISTER_KERNEL_TYPED_13(ReduceSum, double)
REGISTER_KERNEL_TYPED_13(ReduceSum, int32_t)
REGISTER_KERNEL_TYPED_13(ReduceSum, int64_t)
REGISTER_KERNEL_TYPED_13(ReduceSum, BFloat16)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSumSquare, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSumSquare, float, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSumSquare, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceSumSquare, BFloat16, 17, 18)
REGISTER_KERNEL_HFD(ReduceLogSum)
REGISTER_KERNEL_HFD(ReduceSumSquare)
REGISTER_KERNEL_HFD(ReduceLogSumExp)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSumExp, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSumExp, float, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSumExp, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceLogSumExp, BFloat16, 17, 18)
#define REGISTER_KERNEL_INT32(name) \
REGISTER_KERNEL_TYPED(name, int32_t)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL1, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL1, float, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL1, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL1, BFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL1, int32_t, 17, 18)
REGISTER_KERNEL_INT32(ReduceL1)
REGISTER_KERNEL_INT32(ReduceL2)
REGISTER_KERNEL_INT32(ReduceMean)
REGISTER_KERNEL_INT32(ReduceProd)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, float, 17, 18)
// REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, double, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, BFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceL2, int32_t, 17, 18)
} // namespace rocm
} // namespace onnxruntime

View file

@ -33,7 +33,6 @@ const char* RocmErrString<hipError_t>(hipError_t x) {
template <>
const char* RocmErrString<rocblas_status>(rocblas_status e) {
ORT_IGNORE_RETURN_VALUE(hipDeviceSynchronize()); // void to silence nodiscard
switch (e) {
CASE_ENUM_TO_STR(rocblas_status_success);
CASE_ENUM_TO_STR(rocblas_status_invalid_handle);
@ -53,6 +52,24 @@ const char* RocmErrString<rocblas_status>(rocblas_status e) {
}
}
template <>
const char* RocmErrString<hipblasStatus_t>(hipblasStatus_t e) {
ORT_IGNORE_RETURN_VALUE(hipDeviceSynchronize()); // void to silence nodiscard
switch (e) {
CASE_ENUM_TO_STR(HIPBLAS_STATUS_SUCCESS);
CASE_ENUM_TO_STR(HIPBLAS_STATUS_NOT_INITIALIZED);
CASE_ENUM_TO_STR(HIPBLAS_STATUS_ALLOC_FAILED);
CASE_ENUM_TO_STR(HIPBLAS_STATUS_INVALID_VALUE);
CASE_ENUM_TO_STR(HIPBLAS_STATUS_ARCH_MISMATCH);
CASE_ENUM_TO_STR(HIPBLAS_STATUS_MAPPING_ERROR);
CASE_ENUM_TO_STR(HIPBLAS_STATUS_EXECUTION_FAILED);
CASE_ENUM_TO_STR(HIPBLAS_STATUS_INTERNAL_ERROR);
CASE_ENUM_TO_STR(HIPBLAS_STATUS_NOT_SUPPORTED);
default:
return "(look for HIPBLAS_STATUS_xxx in hipblas_api.h)";
}
}
template <>
const char* RocmErrString<hiprandStatus_t>(hiprandStatus_t) {
ORT_IGNORE_RETURN_VALUE(hipDeviceSynchronize()); // void to silence nodiscard
@ -76,7 +93,7 @@ const char* RocmErrString<hipfftResult>(hipfftResult e) {
CASE_ENUM_TO_STR(HIPFFT_SETUP_FAILED);
CASE_ENUM_TO_STR(HIPFFT_INVALID_SIZE);
default:
return "Unknown cufft error status";
return "Unknown hipfft error status";
}
}

File diff suppressed because it is too large Load diff

View file

@ -45,6 +45,12 @@ class ROCMExecutionProvider : public IExecutionProvider {
return GetPerThreadContext().MiopenHandle();
}
hipStream_t ComputeStream() {
// this will return the ROCM EP level stream which can differ from the actual compute tasks stream
// the compute task stream is supplied within OpKernelContext during inference
return stream_;
}
template <typename T>
const T* GetConstOnes(size_t count, hipStream_t stream) {
return GetPerThreadContext().template GetConstOnes<T>(count, stream);
@ -75,8 +81,8 @@ class ROCMExecutionProvider : public IExecutionProvider {
std::unique_ptr<profiling::EpProfiler> GetProfiler() override;
bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured(int graph_annotation_id) const override;
Status ReplayGraph(int graph_annotation_id) override;
bool IsGraphCaptured(RocmGraphAnnotation_t graph_annotation_id) const override;
Status ReplayGraph(RocmGraphAnnotation_t graph_annotation_id) override;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
std::vector<AllocatorPtr> CreatePreferredAllocators() override;
@ -98,6 +104,7 @@ class ROCMExecutionProvider : public IExecutionProvider {
PerThreadContext(OrtDevice::DeviceId device_id, hipStream_t stream, size_t rocm_mem_limit, ArenaExtendStrategy arena_extend_strategy,
ROCMExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg);
~PerThreadContext();
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext);
hipblasHandle_t HipblasHandle() const {
return hipblas_handle_;
@ -138,12 +145,14 @@ class ROCMExecutionProvider : public IExecutionProvider {
}
}
bool IsGraphCaptureAllowed() const;
void CaptureBegin(int graph_annotation_id);
void CaptureEnd(int graph_annotation_id);
bool IsGraphCaptured(int graph_annotation_id) const;
Status ReplayGraph(int graph_annotation_id);
void IncrementRegularRunCountBeforeGraphCapture();
bool IsGraphCaptureAllowed(RocmGraphAnnotation_t hip_graph_annotation_id) const;
bool IsGraphCaptureAllowedOnRun(RocmGraphAnnotation_t hip_graph_annotation_id) const;
void CaptureBegin(RocmGraphAnnotation_t hip_graph_annotation_id);
void CaptureEnd(RocmGraphAnnotation_t hip_graph_annotation_id);
bool IsGraphCaptured(RocmGraphAnnotation_t hip_graph_annotation_id) const;
RocmGraphAnnotation_t GetRocmGraphAnnotationId(const onnxruntime::RunOptions& run_options) const;
Status ReplayGraph(RocmGraphAnnotation_t hip_graph_annotation_id);
void IncrementRegularRunCountBeforeGraphCapture(RocmGraphAnnotation_t hip_graph_annotation_id);
private:
hipblasHandle_t hipblas_handle_ = nullptr;
@ -157,8 +166,8 @@ class ROCMExecutionProvider : public IExecutionProvider {
// Hip graph with multi threads will be supported in the future, so hip_graph_
// is put under PerThreadContext.
ROCMGraph hip_graph_;
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;
// Map of graph id to regular_run_count_before_graph_capture
std::unordered_map<RocmGraphAnnotation_t, int> graph_id_to_run_count_;
// There is chance that the second regular run allocates GPU memory for causes like:
// (1) memory pattern is enabled. (2) arena allocation for stream.

View file

@ -97,14 +97,14 @@ class RocmKernel : public OpKernel {
return stream->hipblas_handle_;
}
tunable::RocmTuningContext* GetTuningContext() const {
return static_cast<tunable::RocmTuningContext*>(provider_->GetTuningContext());
}
bool UseTF32() const {
return false;
}
tunable::RocmTuningContext* GetTuningContext() const {
return static_cast<tunable::RocmTuningContext*>(provider_->GetTuningContext());
}
// To support hipMemcpyAsync, the cpu memory should be allocated in pinned memory
// and it can only be released after the copy has finished
template <typename T>
@ -177,6 +177,12 @@ class RocmKernel : public OpKernel {
return provider_->PerThreadDefaultMiopenHandle();
}
inline hipStream_t DefaultHipStream() const {
// this will return the ROCM EP level stream which can differ from the actual compute tasks stream
// the compute task stream is supplied within OpKernelContext during inference
return provider_->ComputeStream();
}
inline Status CopyTensor(const Tensor& src, Tensor& dst, onnxruntime::Stream& stream) const {
auto* gpu_data_transfer = Info().GetDataTransferManager().GetDataTransfer(src.Location().device, dst.Location().device);
return gpu_data_transfer->CopyTensorAsync(src, dst, stream);

View file

@ -185,7 +185,7 @@ struct ROCM_Provider : Provider {
info.has_user_compute_stream = params->has_user_compute_stream != 0;
info.user_compute_stream = params->user_compute_stream;
info.default_memory_arena_cfg = params->default_memory_arena_cfg;
info.enable_hip_graph = params->enable_hip_graph;
info.enable_hip_graph = params->enable_hip_graph != 0;
info.tunable_op.enable = params->tunable_op_enable;
info.tunable_op.tuning_enable = params->tunable_op_tuning_enable;
info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms;

View file

@ -7,6 +7,25 @@
namespace onnxruntime {
DeferredCpuAllocator::DeferredCpuAllocator(RocmStream& rocm_stream) : rocm_stream_(rocm_stream) {
OrtAllocator::version = ORT_API_VERSION;
OrtAllocator::Alloc =
[](OrtAllocator* this_, size_t size) {
auto self = reinterpret_cast<DeferredCpuAllocator*>(this_);
return self->rocm_stream_.GetCpuAllocator()->Alloc(size);
};
OrtAllocator::Free =
[](OrtAllocator* this_, void* p) {
auto self = reinterpret_cast<DeferredCpuAllocator*>(this_);
self->rocm_stream_.EnqueDeferredCPUBuffer(p);
};
OrtAllocator::Info =
[](const OrtAllocator* this_) {
auto self = reinterpret_cast<const DeferredCpuAllocator*>(this_);
return &self->rocm_stream_.GetCpuAllocator()->Info();
};
}
struct RocmNotification : public synchronize::Notification {
RocmNotification(Stream& s) : Notification(s) {
HIP_CALL_THROW(hipEventCreateWithFlags(&event_, hipEventDisableTiming));
@ -25,7 +44,8 @@ struct RocmNotification : public synchronize::Notification {
void wait_on_device(Stream& device_stream) {
ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", device_stream.GetDevice().ToString());
// launch a wait command to the rocm stream
HIP_CALL_THROW(hipStreamWaitEvent(static_cast<hipStream_t>(device_stream.GetHandle()), event_, 0));
HIP_CALL_THROW(hipStreamWaitEvent(static_cast<hipStream_t>(device_stream.GetHandle()),
event_, 0));
};
void wait_on_host() {
@ -42,10 +62,13 @@ RocmStream::RocmStream(hipStream_t stream,
bool release_cpu_buffer_on_rocm_stream,
bool own_flag,
miopenHandle_t external_miopen_handle,
hipblasHandle_t external_hipblas_handle) : Stream(stream, device),
own_stream_(own_flag),
cpu_allocator_(cpu_allocator),
release_cpu_buffer_on_rocm_stream_(release_cpu_buffer_on_rocm_stream) {
hipblasHandle_t external_hipblas_handle,
const ROCMExecutionProviderInfo& ep_info) : Stream(stream, device),
own_stream_(own_flag),
cpu_allocator_(cpu_allocator),
release_cpu_buffer_on_rocm_stream_(release_cpu_buffer_on_rocm_stream),
deferred_cpu_allocator_(*this),
ep_info_(ep_info) {
if (own_flag) {
HIPBLAS_CALL_THROW(hipblasCreate(&hipblas_handle_));
HIPBLAS_CALL_THROW(hipblasSetStream(hipblas_handle_, stream));
@ -152,6 +175,16 @@ void* RocmStream::GetResource(int version, int id) const {
case RocmResource::hipblas_handle_t:
return reinterpret_cast<void*>(hipblas_handle_);
break;
case RocmResource::deferred_cpu_allocator_t:
return const_cast<DeferredCpuAllocator*>(&deferred_cpu_allocator_);
break;
case RocmResource::device_id_t:
return reinterpret_cast<void*>(ep_info_.device_id);
break;
case RocmResource::arena_extend_strategy_t:
return reinterpret_cast<void*>(ep_info_.arena_extend_strategy);
break;
break;
default:
break;
}
@ -174,25 +207,28 @@ void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
hipStream_t external_stream,
bool use_existing_stream,
miopenHandle_t external_miopen_handle,
hipblasHandle_t external_hipblas_handle) {
hipblasHandle_t external_hipblas_handle,
const ROCMExecutionProviderInfo& ep_info) {
// wait rocm notification on rocm ep
stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitRocmNotificationOnDevice);
// wait rocm notification on cpu ep
stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitRocmNotificationOnHost);
if (!use_existing_stream)
stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_rocm_stream](const OrtDevice& device) {
stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_rocm_stream, ep_info](const OrtDevice& device) {
HIP_CALL_THROW(hipSetDevice(device.Id()));
hipStream_t stream = nullptr;
HIP_CALL_THROW(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
return std::make_unique<RocmStream>(stream, device, cpu_allocator, release_cpu_buffer_on_rocm_stream, true, nullptr, nullptr);
// HIP_CALL_THROW(hipStreamCreate(&stream));
return std::make_unique<RocmStream>(stream, device, cpu_allocator, release_cpu_buffer_on_rocm_stream, true, nullptr, nullptr, ep_info);
});
else
stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator,
release_cpu_buffer_on_rocm_stream,
external_stream,
external_miopen_handle,
external_hipblas_handle](const OrtDevice& device) {
return std::make_unique<RocmStream>(external_stream, device, cpu_allocator, release_cpu_buffer_on_rocm_stream, false, external_miopen_handle, external_hipblas_handle);
external_hipblas_handle,
ep_info](const OrtDevice& device) {
return std::make_unique<RocmStream>(external_stream, device, cpu_allocator, release_cpu_buffer_on_rocm_stream, false, external_miopen_handle, external_hipblas_handle, ep_info);
});
}

View file

@ -3,13 +3,21 @@
#pragma once
#include "core/providers/rocm/rocm_pch.h"
// #include "core/providers/cuda/shared_inc/cuda_utils.h"
// #include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/shared_inc/rocm_call.h"
#include "core/framework/stream_handles.h"
#include "core/providers/rocm/rocm_execution_provider_info.h"
namespace onnxruntime {
struct RocmStream;
void WaitRocmNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
struct DeferredCpuAllocator : public OrtAllocator {
DeferredCpuAllocator(RocmStream&);
RocmStream& rocm_stream_;
};
struct RocmStream : Stream {
RocmStream(hipStream_t stream,
const OrtDevice& device,
@ -17,7 +25,8 @@ struct RocmStream : Stream {
bool release_cpu_buffer_on_rocm_stream,
bool own_flag,
miopenHandle_t external_miopen_handle,
hipblasHandle_t external_hipblas_handle);
hipblasHandle_t external_hipblas_handle,
const ROCMExecutionProviderInfo& ep_info);
~RocmStream();
@ -37,12 +46,16 @@ struct RocmStream : Stream {
void* GetResource(int version, int id) const override;
onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); }
WaitNotificationFn GetWaitNotificationFn() const override { return WaitRocmNotificationOnDevice; }
private:
std::vector<void*> deferred_cpu_buffers_;
AllocatorPtr cpu_allocator_;
bool release_cpu_buffer_on_rocm_stream_{true};
DeferredCpuAllocator deferred_cpu_allocator_;
const ROCMExecutionProviderInfo ep_info_;
};
void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
@ -52,5 +65,6 @@ void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
hipStream_t external_stream,
bool use_existing_stream,
miopenHandle_t external_miopen_handle,
hipblasHandle_t external_hipblas_handle);
hipblasHandle_t external_hipblas_handle,
const ROCMExecutionProviderInfo& ep_info);
} // namespace onnxruntime

View file

@ -4,7 +4,6 @@
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include "core/providers/rocm/rocm_common.h" // avoid provider_api.h ODR violation
#include "core/framework/tunable.h"
@ -22,7 +21,6 @@ template <typename ParamsT>
using Op = Op<ParamsT>;
class Timer;
template <typename ParamsT>
using TunableOp = TunableOp<ParamsT, Timer>;

View file

@ -42,26 +42,6 @@ static Status ValidateRocBlasVersion(const std::string& value) {
return Status::OK();
}
std::string RocmTuningResultsValidator::GetDeviceModel() const {
return ep_->GetDeviceProp().name;
}
Status RocmTuningResultsValidator::ValidateDeviceModel(const std::string& value) const {
auto current = GetDeviceModel();
ORT_RETURN_IF(current != value, "Device model mismatch: tuning results produced with device ", value,
", onnxruntime currently run with device ", current);
return Status::OK();
}
RocmTuningResultsValidator::RocmTuningResultsValidator(ROCMExecutionProvider* ep) : ep_{ep} {
RegisterValidator("HIP_VERSION", GetHipVersion, ValidateHipVersion);
RegisterValidator("ROCBLAS_VERSION", GetRocBlasVersion, ValidateRocBlasVersion);
RegisterValidator(
"DEVICE_MODEL",
[this]() { return GetDeviceModel(); },
[this](const std::string& value) { return ValidateDeviceModel(value); });
}
std::string RocmTuningResultsValidator::GetOrtBuildConfig() const {
std::ostringstream oss;
#ifdef USE_COMPOSABLE_KERNEL
@ -87,6 +67,26 @@ std::string RocmTuningResultsValidator::GetOrtBuildConfig() const {
return oss.str();
}
std::string RocmTuningResultsValidator::GetDeviceModel() const {
return ep_->GetDeviceProp().name;
}
Status RocmTuningResultsValidator::ValidateDeviceModel(const std::string& value) const {
auto current = GetDeviceModel();
ORT_RETURN_IF(current != value, "Device model mismatch: tuning results produced with device ", value,
", onnxruntime currently run with device ", current);
return Status::OK();
}
RocmTuningResultsValidator::RocmTuningResultsValidator(ROCMExecutionProvider* ep) : ep_{ep} {
RegisterValidator("HIP_VERSION", GetHipVersion, ValidateHipVersion);
RegisterValidator("ROCBLAS_VERSION", GetRocBlasVersion, ValidateRocBlasVersion);
RegisterValidator(
"DEVICE_MODEL",
[this]() { return GetDeviceModel(); },
[this](const std::string& value) { return ValidateDeviceModel(value); });
}
RocmTuningContext::RocmTuningContext(ROCMExecutionProvider* ep, TunableOpInfo* info)
: ITuningContext(ep), info_(info), validator_(ep) {}

View file

@ -335,6 +335,7 @@ static void TestInference(Ort::Env& env, const std::basic_string<ORTCHAR_T>& mod
#endif
} else if (provider_type == 3) {
#ifdef USE_ROCM
std::cout << "Running simple inference with rocm provider" << std::endl;
OrtROCMProviderOptions rocm_options;
session_options.AppendExecutionProvider_ROCM(rocm_options);
#else
@ -384,7 +385,7 @@ static void TestInference(Ort::Env& env, const std::basic_string<ORTCHAR_T>& mod
}
static constexpr PATH_TYPE MODEL_URI = TSTR("testdata/mul_1.onnx");
#if defined(USE_CUDA) || defined(USE_DML)
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
static constexpr PATH_TYPE CUDA_GRAPH_ANNOTATION_MODEL_URI = TSTR("testdata/mul_1_dynamic.onnx");
#endif
static constexpr PATH_TYPE MATMUL_MODEL_URI = TSTR("testdata/matmul_1.onnx");
@ -2341,7 +2342,7 @@ TEST(CApiTest, basic_cuda_graph) {
#endif
}
#if defined(USE_CUDA) || defined(USE_DML)
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
struct CudaGraphInputOutputData_0 {
const std::array<int64_t, 2> x_shape = {3, 2};
std::array<float, 3 * 2> x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
@ -2385,6 +2386,12 @@ static void RunWithCudaGraphAnnotation(T& cg_data,
Ort::MemoryAllocation& input_data,
Ort::MemoryAllocation& output_data,
const char* cuda_graph_annotation) {
// a local hipify of select cuda symbols to avoid code duplication
#ifdef USE_ROCM
#define cudaMemcpy hipMemcpy
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#endif
#ifdef USE_DML
Ort::SessionOptions session_options;
Ort::Allocator allocator(session, info_mem);
@ -2488,6 +2495,11 @@ static void RunWithCudaGraphAnnotation(T& cg_data,
// Clean up
binding.ClearBoundInputs();
binding.ClearBoundOutputs();
#ifdef USE_ROCM
#undef cudaMemcpy
#undef cudaMemcpyHostToDevice
#undef cudaMemcpyDeviceToHost
#endif
}
TEST(CApiTest, basic_cuda_graph_with_annotation) {
@ -2502,7 +2514,7 @@ TEST(CApiTest, basic_cuda_graph_with_annotation) {
ort_dml_api->SessionOptionsAppendExecutionProvider_DML1(session_options, dml_objects.dml_device.Get(), dml_objects.command_queue.Get());
Ort::MemoryInfo info_mem("DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeDefault);
#else
#elif defined(USE_CUDA)
// Enable cuda graph in cuda provider option.
OrtCUDAProviderOptionsV2* cuda_options = nullptr;
ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr);
@ -2516,6 +2528,20 @@ TEST(CApiTest, basic_cuda_graph_with_annotation) {
static_cast<OrtSessionOptions*>(session_options),
rel_cuda_options.get()) == nullptr);
Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault);
#elif defined(USE_ROCM)
// Enable hip graph in rocm provider option.
OrtROCMProviderOptions* rocm_options = nullptr;
ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr);
std::unique_ptr<OrtROCMProviderOptions, decltype(api.ReleaseROCMProviderOptions)>
rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions);
std::vector<const char*> keys{"enable_hip_graph"};
std::vector<const char*> values{"1"};
ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr);
ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM(
static_cast<OrtSessionOptions*>(session_options),
rel_rocm_options.get()) == nullptr);
Ort::MemoryInfo info_mem("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault);
#endif
Ort::Session session(*ort_env, CUDA_GRAPH_ANNOTATION_MODEL_URI, session_options);

View file

@ -21,7 +21,10 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path):
s = s.replace("kCudaStreamCopyIn", "kHipStreamCopyIn")
s = s.replace("kCudaStreamCopyOut", "kHipStreamCopyOut")
s = s.replace("kTotalCudaStreams", "kTotalHipStreams")
# these should be "hip" but it's easier to just use rocm to avoid complicated file renaming
s = s.replace("CudaGraph", "RocmGraph")
s = s.replace("CUDAGraph", "ROCMGraph")
s = s.replace("cuda_graph", "rocm_graph")
s = s.replace("RegisterCudaContribKernels", "RegisterRocmContribKernels")
s = s.replace("cudaEvent", "hipEvent")
s = s.replace("CreateCudaAllocator", "CreateRocmAllocator")