mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
[TensorRT EP] Enable a minimal CUDA EP compilation without kernels (#19052)
Addresses https://github.com/microsoft/onnxruntime/issues/18542. I followed the advice given by @RyanUnderhill [here](https://github.com/microsoft/onnxruntime/pull/18731#issuecomment-1848261925) and went with a minimal CUDA EP for now.
This commit is contained in:
parent
bd9d8fb2a5
commit
bc219ed553
11 changed files with 97 additions and 39 deletions
|
|
@ -79,6 +79,7 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF)
|
|||
cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS;LINUX" OFF)
|
||||
|
||||
option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" OFF)
|
||||
# Opt-in switch (default OFF) for a CUDA provider build that contains only the memcpy ops.
option(onnxruntime_CUDA_MINIMAL "Build CUDA without any operations apart from memcpy ops. Useful for a very minimal TRT build" OFF)
|
||||
option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." OFF)
|
||||
option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF)
|
||||
option(onnxruntime_USE_COREML "Build with CoreML support" OFF)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,25 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
file(GLOB_RECURSE onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc"
|
||||
)
|
||||
|
||||
# Select the C++ sources (.h/.cc) of the CUDA provider.
# Minimal build: a non-recursive GLOB picks up only the files directly inside
# core/providers/cuda and core/providers/cuda/tunable.
# Regular build: a recursive GLOB picks up every .h/.cc below core/providers/cuda.
if (onnxruntime_CUDA_MINIMAL)
|
||||
file(GLOB onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/cuda/tunable/*.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/cuda/tunable/*.cc"
|
||||
)
|
||||
# Drop two files from the minimal source list (these are regular sources, not pch files).
# NOTE(review): presumably excluded because they need libraries/features the minimal build does not provide - confirm.
|
||||
list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/cuda/integer_gemm.cc"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/cuda/triton_kernel.h"
|
||||
)
|
||||
else()
|
||||
file(GLOB_RECURSE onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc"
|
||||
)
|
||||
endif()
|
||||
# Remove pch files
|
||||
list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.h"
|
||||
|
|
@ -16,11 +31,16 @@
|
|||
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc"
|
||||
)
|
||||
file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh"
|
||||
)
|
||||
|
||||
|
||||
# Collect the CUDA device sources (.cu/.cuh) only for the regular build.
# In the minimal build onnxruntime_providers_cuda_cu_srcs is not set by this block.
# NOTE(review): the minimal branch empties onnxruntime_providers_cuda_shared_srcs,
# whereas the else branch populates onnxruntime_providers_cuda_cu_srcs - confirm
# that clearing the shared-library source list (rather than the .cu list) is intended.
if (onnxruntime_CUDA_MINIMAL)
|
||||
set(onnxruntime_providers_cuda_shared_srcs "")
|
||||
else()
|
||||
file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu"
|
||||
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh"
|
||||
)
|
||||
endif()
|
||||
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs})
|
||||
set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs})
|
||||
|
||||
|
|
@ -156,10 +176,15 @@
|
|||
endif()
|
||||
|
||||
add_dependencies(${target} onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
|
||||
target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
|
||||
if(onnxruntime_CUDNN_HOME)
|
||||
target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include)
|
||||
target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib)
|
||||
# Link setup for ${target}.
# Minimal build: define USE_CUDA_MINIMAL for the target's own sources and link
# without cublasLt/cublas/cudnn/curand/cufft.
# Regular build: link those CUDA libraries as well and, when onnxruntime_CUDNN_HOME
# is set, add its include/ and lib/ directories to the target.
if(onnxruntime_CUDA_MINIMAL)
|
||||
target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL)
|
||||
target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
|
||||
else()
|
||||
target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
|
||||
if(onnxruntime_CUDNN_HOME)
|
||||
target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include)
|
||||
target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (onnxruntime_USE_TRITON_KERNEL)
|
||||
|
|
|
|||
|
|
@ -16,9 +16,10 @@
|
|||
#include "core/providers/custom_op_context.h"
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
#include <cublas_v2.h>
|
||||
#include <cudnn.h>
|
||||
|
||||
#endif
|
||||
namespace Ort {
|
||||
|
||||
namespace Custom {
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ const char* CudaErrString<cudaError_t>(cudaError_t x) {
|
|||
return cudaGetErrorString(x);
|
||||
}
|
||||
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
template <>
|
||||
const char* CudaErrString<cublasStatus_t>(cublasStatus_t e) {
|
||||
cudaDeviceSynchronize();
|
||||
|
|
@ -76,6 +77,7 @@ const char* CudaErrString<cufftResult>(cufftResult e) {
|
|||
return "Unknown cufft error status";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef ORT_USE_NCCL
|
||||
template <>
|
||||
|
|
@ -132,6 +134,7 @@ std::conditional_t<THRW, void, Status> CudaCall(
|
|||
|
||||
template Status CudaCall<cudaError, false>(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line);
|
||||
template void CudaCall<cudaError, true>(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line);
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
template Status CudaCall<cublasStatus_t, false>(cublasStatus_t retCode, const char* exprString, const char* libName, cublasStatus_t successCode, const char* msg, const char* file, const int line);
|
||||
template void CudaCall<cublasStatus_t, true>(cublasStatus_t retCode, const char* exprString, const char* libName, cublasStatus_t successCode, const char* msg, const char* file, const int line);
|
||||
template Status CudaCall<cudnnStatus_t, false>(cudnnStatus_t retCode, const char* exprString, const char* libName, cudnnStatus_t successCode, const char* msg, const char* file, const int line);
|
||||
|
|
@ -140,6 +143,7 @@ template Status CudaCall<curandStatus_t, false>(curandStatus_t retCode, const ch
|
|||
template void CudaCall<curandStatus_t, true>(curandStatus_t retCode, const char* exprString, const char* libName, curandStatus_t successCode, const char* msg, const char* file, const int line);
|
||||
template Status CudaCall<cufftResult, false>(cufftResult retCode, const char* exprString, const char* libName, cufftResult successCode, const char* msg, const char* file, const int line);
|
||||
template void CudaCall<cufftResult, true>(cufftResult retCode, const char* exprString, const char* libName, cufftResult successCode, const char* msg, const char* file, const int line);
|
||||
#endif
|
||||
|
||||
#ifdef ORT_USE_NCCL
|
||||
template Status CudaCall<ncclResult_t, false>(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line);
|
||||
|
|
|
|||
|
|
@ -14,6 +14,27 @@ namespace cuda {
|
|||
// 0x04 - pedantic
|
||||
constexpr const char* kCudaGemmOptions = "ORT_CUDA_GEMM_OPTIONS";
|
||||
|
||||
// Maps a cudaDataType_t value to the name of its enumerator as a string literal.
// Handles CUDA_R_16F, CUDA_R_16BF and CUDA_R_32F; the two FP8 cases are compiled
// only when DISABLE_FLOAT8_TYPES is not defined. Any other value yields "<unknown>".
const char* CudaDataTypeToString(cudaDataType_t dt) {
|
||||
switch (dt) {
|
||||
case CUDA_R_16F:
|
||||
return "CUDA_R_16F";
|
||||
case CUDA_R_16BF:
|
||||
return "CUDA_R_16BF";
|
||||
case CUDA_R_32F:
|
||||
return "CUDA_R_32F";
|
||||
#if !defined(DISABLE_FLOAT8_TYPES)
|
||||
// Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8
|
||||
case CUDA_R_8F_E4M3:
|
||||
return "CUDA_R_8F_E4M3";
|
||||
case CUDA_R_8F_E5M2:
|
||||
return "CUDA_R_8F_E5M2";
|
||||
#endif
|
||||
// Fallback for every value not listed above.
default:
|
||||
return "<unknown>";
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
// Initialize the singleton instance
|
||||
HalfGemmOptions HalfGemmOptions::instance;
|
||||
|
||||
|
|
@ -54,26 +75,6 @@ const char* cublasGetErrorEnum(cublasStatus_t error) {
|
|||
}
|
||||
}
|
||||
|
||||
const char* CudaDataTypeToString(cudaDataType_t dt) {
|
||||
switch (dt) {
|
||||
case CUDA_R_16F:
|
||||
return "CUDA_R_16F";
|
||||
case CUDA_R_16BF:
|
||||
return "CUDA_R_16BF";
|
||||
case CUDA_R_32F:
|
||||
return "CUDA_R_32F";
|
||||
#if !defined(DISABLE_FLOAT8_TYPES)
|
||||
// Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8
|
||||
case CUDA_R_8F_E4M3:
|
||||
return "CUDA_R_8F_E4M3";
|
||||
case CUDA_R_8F_E5M2:
|
||||
return "CUDA_R_8F_E5M2";
|
||||
#endif
|
||||
default:
|
||||
return "<unknown>";
|
||||
}
|
||||
}
|
||||
|
||||
const char* CublasComputeTypeToString(cublasComputeType_t ct) {
|
||||
switch (ct) {
|
||||
case CUBLAS_COMPUTE_16F:
|
||||
|
|
@ -92,6 +93,7 @@ const char* CublasComputeTypeToString(cublasComputeType_t ct) {
|
|||
return "<unknown>";
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// It must exist somewhere already.
|
||||
cudaDataType_t ToCudaDataType(int32_t element_type) {
|
||||
|
|
|
|||
|
|
@ -22,13 +22,14 @@ namespace onnxruntime {
|
|||
namespace cuda {
|
||||
|
||||
#define CUDA_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDA_CALL(expr))
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
#define CUBLAS_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUBLAS_CALL(expr))
|
||||
#define CUSPARSE_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUSPARSE_CALL(expr))
|
||||
#define CURAND_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CURAND_CALL(expr))
|
||||
#define CUDNN_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDNN_CALL(expr))
|
||||
#define CUDNN2_RETURN_IF_ERROR(expr, m) ORT_RETURN_IF_ERROR(CUDNN_CALL2(expr, m))
|
||||
#define CUFFT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUFFT_CALL(expr))
|
||||
|
||||
#endif
|
||||
// Type mapping for MLFloat16 to half
|
||||
template <typename T>
|
||||
class ToCudaType {
|
||||
|
|
@ -93,7 +94,7 @@ inline bool CalculateFdmStrides(gsl::span<fast_divmod> p, const std::vector<int6
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
class CublasMathModeSetter {
|
||||
public:
|
||||
CublasMathModeSetter(const cudaDeviceProp& prop, cublasHandle_t handle, cublasMath_t mode) : handle_(handle) {
|
||||
|
|
@ -189,6 +190,7 @@ const char* cublasGetErrorEnum(cublasStatus_t error);
|
|||
const char* CudaDataTypeToString(cudaDataType_t dt);
|
||||
|
||||
const char* CublasComputeTypeToString(cublasComputeType_t ct);
|
||||
#endif
|
||||
|
||||
cudaDataType_t ToCudaDataType(int32_t element_type);
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
#include "core/providers/cuda/gpu_data_transfer.h"
|
||||
#include "core/providers/cuda/cuda_profiler.h"
|
||||
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
#include "contrib_ops/cuda/cuda_contrib_kernels.h"
|
||||
#endif
|
||||
|
|
@ -27,6 +28,7 @@
|
|||
#ifdef USE_TRITON_KERNEL
|
||||
#include "core/providers/cuda/triton_kernel.h"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#include "core/providers/cuda/cuda_stream_handle.h"
|
||||
|
||||
|
|
@ -169,21 +171,23 @@ CUDAExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de
|
|||
ArenaExtendStrategy /*arena_extend_strategy*/, CUDAExecutionProviderExternalAllocatorInfo /*external_allocator_info*/,
|
||||
OrtArenaCfg* /*default_memory_arena_cfg*/) {
|
||||
CUDA_CALL_THROW(cudaSetDevice(device_id));
|
||||
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_));
|
||||
CUBLAS_CALL_THROW(cublasLtCreate(&cublas_lt_handle_));
|
||||
CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream));
|
||||
|
||||
CUDNN_CALL_THROW(cudnnCreate(&cudnn_handle_));
|
||||
CUDNN_CALL_THROW(cudnnSetStream(cudnn_handle_, stream));
|
||||
|
||||
#endif
|
||||
cuda_graph_.SetStream(stream);
|
||||
}
|
||||
|
||||
// Destroys the cuBLAS, cuBLASLt and cuDNN handles held by this context.
// The status of each destroy call is deliberately discarded via ORT_IGNORE_RETURN_VALUE.
// When USE_CUDA_MINIMAL is defined the whole body is compiled out.
CUDAExecutionProvider::PerThreadContext::~PerThreadContext() {
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasDestroy(cublas_handle_)));
|
||||
ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasLtDestroy(cublas_lt_handle_)));
|
||||
ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnDestroy(cudnn_handle_)));
|
||||
#endif
|
||||
}
|
||||
|
||||
bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const {
|
||||
|
|
@ -441,6 +445,7 @@ namespace cuda {
|
|||
// opset 1 to 9
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MemcpyFromHost);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MemcpyToHost);
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, float, Cos);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, double, Cos);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, MLFloat16, Cos);
|
||||
|
|
@ -1315,6 +1320,7 @@ class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Reshape);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Scan);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Shape);
|
||||
#endif
|
||||
|
||||
template <>
|
||||
KernelCreateInfo BuildKernelCreateInfo<void>() {
|
||||
|
|
@ -1326,6 +1332,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MemcpyFromHost)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MemcpyToHost)>,
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 4, 10, Concat)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, Flatten)>,
|
||||
|
|
@ -2201,6 +2208,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Reshape)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Scan)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Shape)>,
|
||||
#endif
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
@ -2210,6 +2218,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
}
|
||||
}
|
||||
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::cuda::RegisterCudaContribKernels(kernel_registry));
|
||||
#endif
|
||||
|
|
@ -2220,6 +2229,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
|
|||
|
||||
#ifdef ENABLE_TRAINING_OPS
|
||||
ORT_RETURN_IF_ERROR(::onnxruntime::cuda::RegisterCudaTrainingKernels(kernel_registry));
|
||||
#endif
|
||||
#endif
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -10,12 +10,19 @@
|
|||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
#include <cublas_v2.h>
|
||||
#include <cusparse.h>
|
||||
#include <curand.h>
|
||||
#include <cudnn.h>
|
||||
#include <cufft.h>
|
||||
#include <cublasLt.h>
|
||||
#else
|
||||
typedef void* cudnnHandle_t;
|
||||
typedef void* cublasHandle_t;
|
||||
typedef void* cublasLtHandle_t;
|
||||
#endif
|
||||
|
||||
#ifdef ORT_USE_NCCL
|
||||
#include <nccl.h>
|
||||
|
|
|
|||
|
|
@ -69,6 +69,7 @@ CudaStream::CudaStream(cudaStream_t stream,
|
|||
release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream),
|
||||
deferred_cpu_allocator_(*this),
|
||||
ep_info_(ep_info) {
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
if (own_flag) {
|
||||
CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_));
|
||||
CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream));
|
||||
|
|
@ -80,10 +81,12 @@ CudaStream::CudaStream(cudaStream_t stream,
|
|||
cudnn_handle_ = external_cudnn_handle;
|
||||
CUDNN_CALL_THROW(cudnnSetStream(cudnn_handle_, stream));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
CudaStream::~CudaStream() {
|
||||
ORT_IGNORE_RETURN_VALUE(CleanUpOnRunEnd());
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
if (own_stream_) {
|
||||
cublasDestroy(cublas_handle_);
|
||||
cudnnDestroy(cudnn_handle_);
|
||||
|
|
@ -91,6 +94,7 @@ CudaStream::~CudaStream() {
|
|||
if (handle)
|
||||
cudaStreamDestroy(static_cast<cudaStream_t>(handle));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
std::unique_ptr<synchronize::Notification> CudaStream::CreateNotification(size_t /*num_consumers*/) {
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
#include "core/common/gsl.h"
|
||||
#include "shared_inc/cuda_call.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
|
|
@ -222,3 +222,4 @@ const Float8E5M2 Consts<Float8E5M2>::One = Float8E5M2(1.0f, true);
|
|||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
#include <cfloat>
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
|
||||
#ifndef USE_CUDA_MINIMAL
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
|
|
@ -260,3 +260,4 @@ SetPoolingNdDescriptorHelper(cudnnPoolingDescriptor_t poolingDesc,
|
|||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in a new issue