diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index 7d2a7f2bea..d7e5e2b9e3 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -1198,9 +1198,9 @@ if (onnxruntime_USE_CUDA)
endif()
endif()
endif()
- set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr --default-stream legacy")
+ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
if (NOT WIN32)
- set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} --expt-relaxed-constexpr --compiler-options -fPIC")
+ set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} --compiler-options -fPIC")
endif()
# Options passed to cudafe
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=bad_friend_decl\"")
diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h
index 6069ed4839..a2454997bc 100644
--- a/include/onnxruntime/core/framework/execution_provider.h
+++ b/include/onnxruntime/core/framework/execution_provider.h
@@ -165,6 +165,9 @@ class IExecutionProvider {
*/
virtual common::Status OnSessionInitializationEnd() { return Status::OK(); }
+ virtual common::Status SetComputeStream(void*) { return Status::OK(); }
+ virtual void* GetComputeStream() const { return nullptr; }
+
void InsertAllocator(AllocatorPtr allocator);
void ReplaceAllocator(AllocatorPtr allocator);
// TODO: temporary solution, need to unify the interface in EP and AllocatorManager
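For orientation, a minimal sketch (illustrative only, not part of this diff) of how a CUDA-based provider could implement the two new virtuals; the derived class name and the stream_ member are assumptions:

// Sketch only: class name, constructor argument, and stream_ member are assumptions.
class MyCudaExecutionProvider : public IExecutionProvider {
 public:
  MyCudaExecutionProvider() : IExecutionProvider("MyCudaExecutionProvider") {}

  // Adopt a caller-owned cudaStream_t; the provider must not create or destroy it.
  common::Status SetComputeStream(void* stream) override {
    stream_ = static_cast<cudaStream_t>(stream);
    return Status::OK();
  }

  void* GetComputeStream() const override { return stream_; }

 private:
  cudaStream_t stream_ = nullptr;  // external compute stream; other provider plumbing omitted
};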
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 3b5c4af359..b0985608fc 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -266,8 +266,19 @@ typedef struct OrtCUDAProviderOptions {
size_t cuda_mem_limit; // default cuda memory limitation to maximum finite value of size_t.
int arena_extend_strategy; // default arena extend strategy to kNextPowerOfTwo.
int do_copy_in_default_stream;
+ int has_user_compute_stream;
+ void* user_compute_stream;
} OrtCUDAProviderOptions;
+///
+/// Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT
+///
+typedef struct OrtTensorRTProviderOptions {
+ int device_id;
+ int has_user_compute_stream;
+ void* user_compute_stream;
+} OrtTensorRTProviderOptions;
+
///
/// Options for the OpenVINO provider that are passed to SessionOptionsAppendExecutionProvider_OpenVINO
///
@@ -1146,6 +1157,12 @@ struct OrtApi {
*/
ORT_API2_STATUS(ModelMetadataGetGraphDescription, _In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value);
+  /**
+  * Append the TensorRT execution provider to the session options.
+  * If TensorRT is not available (due to a non-TensorRT-enabled build), this function will return failure.
+  */
+ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_TensorRT,
+ _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options);
};
/*
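As a usage illustration for the struct and API entry added above (a sketch, not part of the diff): my_stream is a hypothetical caller-created cudaStream_t, and status checks are elided.

// Sketch only: error handling elided; the caller keeps ownership of my_stream.
void ConfigureTensorRTWithUserStream(OrtSessionOptions* session_options, cudaStream_t my_stream) {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  OrtTensorRTProviderOptions trt_options{};
  trt_options.device_id = 0;
  trt_options.has_user_compute_stream = 1;      // opt in to the caller-supplied stream
  trt_options.user_compute_stream = my_stream;  // cudaStream_t passed as void*

  api->SessionOptionsAppendExecutionProvider_TensorRT(session_options, &trt_options);
}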
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index d5aa79a79d..be43d9cd21 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -326,6 +326,7 @@ struct SessionOptions : Base {
SessionOptions& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options);
SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options);
+ SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options);
};
struct ModelMetadata : Base {
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index a5ce8219f6..a818c3c691 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -490,6 +490,11 @@ inline SessionOptions& SessionOptions::AppendExecutionProvider_CUDA(const OrtCUD
return *this;
}
+inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(p_, &provider_options));
+ return *this;
+}
+
inline SessionOptions& SessionOptions::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(p_, &provider_options));
return *this;
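Through the C++ wrappers added above, a caller-side sketch that shares one user-created stream between the TensorRT and CUDA providers (env, model_path, and the stream are assumed to exist; the remaining provider-option fields are left value-initialized for brevity, which a real caller would want to set explicitly):

// Sketch only: assumes a CUDA- and TensorRT-enabled build of onnxruntime.
Ort::Session CreateSessionOnUserStream(Ort::Env& env, const ORTCHAR_T* model_path, cudaStream_t stream) {
  OrtTensorRTProviderOptions trt_options{};
  trt_options.device_id = 0;
  trt_options.has_user_compute_stream = 1;
  trt_options.user_compute_stream = stream;  // same caller-owned stream for both EPs

  OrtCUDAProviderOptions cuda_options{};
  cuda_options.device_id = 0;
  cuda_options.has_user_compute_stream = 1;
  cuda_options.user_compute_stream = stream;

  Ort::SessionOptions session_options;
  session_options.AppendExecutionProvider_TensorRT(trt_options);  // TensorRT first, CUDA as fallback
  session_options.AppendExecutionProvider_CUDA(cuda_options);
  return Ort::Session(env, model_path, session_options);
}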
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.cc b/onnxruntime/contrib_ops/cuda/activation/activations.cc
index 45bda90b1e..6a26e0f6c3 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations.cc
+++ b/onnxruntime/contrib_ops/cuda/activation/activations.cc
@@ -29,6 +29,7 @@ namespace cuda {
ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); \
Ctx##x func_ctx = MakeFuncCtx(); \
Impl_##x<typename ToCudaType<T>::MappedType>( \
+ Stream(), \
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(p.input_tensor->template Data<T>()), \
reinterpret_cast<typename ToCudaType<T>::MappedType*>(p.output_tensor->template MutableData<T>()), \
&func_ctx, p.output_tensor->Shape().Size()); \
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu
index 62601a1c69..7988ecd42f 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu
@@ -45,14 +45,15 @@ struct OP_Gelu : public CtxGelu {
#define UNARY_ACTIVATION_IMPL(name) \
UNARY_ACTIVATION_IMPL_DECLARATION(name) { \
- UnaryElementWiseImpl(input_data, \
+ UnaryElementWiseImpl(stream, \
+ input_data, \
output_data, \
*reinterpret_cast<const OP_##name<T>*>(func_ctx), \
count); \
}
#define SPECIALIZED_UNARY_ACTIVATION_IMPL(name, T) \
- template void Impl_##name(const T* input_data, T* output_data, const Ctx##name* func_ctx, size_t count);
+ template void Impl_##name(cudaStream_t stream, const T* input_data, T* output_data, const Ctx##name* func_ctx, size_t count);
#define SPECIALIZED_UNARY_ACTIVATIONL_HFD(name) \
SPECIALIZED_UNARY_ACTIVATION_IMPL(name, half) \
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h
index 95ea6d5af6..56ece01e46 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h
+++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h
@@ -22,6 +22,7 @@ typedef onnxruntime::cuda::CtxNull CtxGelu;
#define UNARY_ACTIVATION_IMPL_DECLARATION(name) \
template <typename T> \
void Impl_##name( \
+ cudaStream_t stream, \
const T* input_data, \
T* output_data, \
const Ctx##name* func_ctx, \
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc
index 25a23a5111..ce9147ad1b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -88,6 +88,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
auto temp_buffer = GetScratchBuffer<void>(workSpaceSize);
if (!LaunchAttentionKernel(
device_prop,
+ Stream(),
reinterpret_cast(gemm_buffer.get()),
nullptr == mask_index ? nullptr : mask_index->template Data<int>(),
nullptr == mask_index ? nullptr : &(mask_index->Shape().GetDims()),
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 00f92b4f1c..a342168c6d 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -148,6 +148,7 @@ bool QkvToContext(
bool LaunchAttentionKernel(
const cudaDeviceProp& prop,
+ cudaStream_t stream,
const void* input,
const int* mask_index,
const std::vector* mask_index_dims,
@@ -163,9 +164,6 @@ bool LaunchAttentionKernel(
int past_sequence_length,
const void* past,
void* present) {
- // use default stream
- const cudaStream_t stream = nullptr;
-
if (element_size == 2) {
return QkvToContext(prop, cublas, stream,
batch_size, sequence_length, num_heads, head_size, element_size,
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
index c51c007290..30f03b8668 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -20,6 +20,7 @@ size_t GetAttentionWorkspaceSize(
bool LaunchAttentionKernel(
const cudaDeviceProp& prop, // Device Properties
+ cudaStream_t stream, // cuda stream
const void* input, // Input tensor
const int* mask_index, // Attention mask raw data or index (end position of each sequence, or end positions and start positions). NULL means no mask.
const std::vector* mask_index_dims, // Mask index shape
diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc
index 8adffa85ed..e975181d29 100644
--- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc
@@ -61,6 +61,7 @@ Status EmbedLayerNorm::ComputeInternal(OpKernelContext* context) const {
size_t element_size = sizeof(T);
if (!LaunchEmbedLayerNormKernel(
+ Stream(),
output->template MutableData(),
mask_index->template MutableData(),
input_ids->template Data(),
diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu
index 9e856e2e35..ad005e40e0 100644
--- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.cu
@@ -173,6 +173,7 @@ bool EmbedSkipLayerNorm(
}
bool LaunchEmbedLayerNormKernel(
+ cudaStream_t stream,
void* output,
void* mask_index,
const int* input_ids,
@@ -188,10 +189,8 @@ bool LaunchEmbedLayerNormKernel(
int batch_size,
int sequence_length,
const size_t element_size) {
- const cudaStream_t stream = nullptr; // default stream
-
if (nullptr == input_mask) {
- if (!CUDA_CALL(cudaMemsetAsync(mask_index, 0, sizeof(int) * batch_size)))
+ if (!CUDA_CALL(cudaMemsetAsync(mask_index, 0, sizeof(int) * batch_size, stream)))
return false;
} else if (!ComputeMaskIndex(stream, sequence_length, batch_size, input_mask, static_cast<int*>(mask_index))) {
return false;
diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.h b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.h
index 18648e6799..6977fd3e8e 100644
--- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm_impl.h
@@ -6,7 +6,8 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
-bool LaunchEmbedLayerNormKernel(void* output, // output tensor
+bool LaunchEmbedLayerNormKernel(cudaStream_t stream,
+ void* output, // output tensor
void* mask_index, // output mask index
const int* input_ids, // input word IDs
const int* segment_ids, // input segment IDs
diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
index 642ef3458c..8e4bfb1c84 100644
--- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc
@@ -47,7 +47,7 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const {
int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size();
typedef typename ToCudaType<T>::MappedType CudaT;
if (!LaunchFastGeluKernel(GetDeviceProp(),
- nullptr,
+ Stream(),
static_cast<int>(input_length),
static_cast<int>(bias_length),
reinterpret_cast<const CudaT*>(input->template Data<T>()),
diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc b/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc
index ef2eecb1ec..9ec5298c2b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc
@@ -111,6 +111,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const {
auto workspace_buffer = GetScratchBuffer<void>(workSpaceSize);
if (!LaunchLongformerAttentionKernel(
device_prop,
+ Stream(),
reinterpret_cast(gemm_buffer.get()),
reinterpret_cast(mask->template Data()),
reinterpret_cast(global_gemm_buffer.get()),
diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu
index fd9637dfc9..191a979fc9 100644
--- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu
@@ -814,6 +814,7 @@ bool LongformerQkvToContext(
bool LaunchLongformerAttentionKernel(
const cudaDeviceProp& prop,
+ cudaStream_t stream,
const void* input,
const void* attention_mask,
const void* global_input,
@@ -828,9 +829,6 @@ bool LaunchLongformerAttentionKernel(
void* workspace,
cublasHandle_t& cublas,
const size_t element_size) {
- // use default stream
- const cudaStream_t stream = nullptr;
-
if (element_size == 2) {
return LongformerQkvToContext(prop, cublas, stream,
batch_size, sequence_length, num_heads, head_size, window, element_size,
diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.h
index 632f6d6e5c..c08461e800 100644
--- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.h
@@ -18,6 +18,7 @@ size_t GetLongformerAttentionWorkspaceSize(
bool LaunchLongformerAttentionKernel(
const cudaDeviceProp& device_prop, // Device Properties
+ cudaStream_t stream, // CUDA stream
const void* input, // Input tensor
const void* attention_mask, // Attention mask with shape (B, S)
const void* global_input, // Global attention input, or nullptr when max_num_global == 0.
diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
index f8f6c2ad49..b8238f7690 100644
--- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
@@ -93,6 +93,7 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const {
size_t element_size = sizeof(T);
if (!LaunchSkipLayerNormKernel(
+ Stream(),
output->template MutableData<T>(),
input->template Data<T>(),
skip->template Data<T>(),
diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
index 9c11ff85e0..a7b6aabe52 100644
--- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.cu
@@ -100,6 +100,7 @@ bool ComputeSkipLayerNorm(
}
bool LaunchSkipLayerNormKernel(
+ cudaStream_t stream,
void* output,
const void* input,
const void* skip,
@@ -110,9 +111,6 @@ bool LaunchSkipLayerNormKernel(
int hidden_size,
int element_count,
size_t element_size) {
- // use default stream
- const cudaStream_t stream = nullptr;
-
if (element_size == 2) {
return ComputeSkipLayerNorm(
stream,
diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h
index 308242c010..0148231f2b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm_impl.h
@@ -8,6 +8,7 @@ namespace contrib {
namespace cuda {
bool LaunchSkipLayerNormKernel(
+ cudaStream_t stream,
void* output, // output tensor
const void* input, // input tensor
const void* skip, // skip tensor
diff --git a/onnxruntime/contrib_ops/cuda/fused_conv.cc b/onnxruntime/contrib_ops/cuda/fused_conv.cc
index 0e24032c48..6cce365871 100644
--- a/onnxruntime/contrib_ops/cuda/fused_conv.cc
+++ b/onnxruntime/contrib_ops/cuda/fused_conv.cc
@@ -90,7 +90,7 @@ class FusedConv : public onnxruntime::cuda::Conv {
Base::s_.y_data, beta, Base::s_.y_tensor, Base::s_.y_data));
}
if (Base::s_.post_slicing_required) {
- onnxruntime::cuda::SliceOutUnwantedOutputSection(Base::s_.y_data, Base::s_.y_dims_with_adjusted_pads, Base::s_.Y->MutableDataRaw(),
+ onnxruntime::cuda::SliceOutUnwantedOutputSection(this->Stream(), Base::s_.y_data, Base::s_.y_dims_with_adjusted_pads, Base::s_.Y->MutableDataRaw(),
Base::s_.y_dims, Base::s_.slice_starts, Base::s_.slice_ends, Base::s_.slice_axes, Base::s_.element_size);
}
return Status::OK();
diff --git a/onnxruntime/contrib_ops/cuda/inverse.cc b/onnxruntime/contrib_ops/cuda/inverse.cc
index 546fc105de..f762b09d9f 100644
--- a/onnxruntime/contrib_ops/cuda/inverse.cc
+++ b/onnxruntime/contrib_ops/cuda/inverse.cc
@@ -35,22 +35,24 @@ ONNX_OPERATOR_KERNEL_EX(
namespace inverse_internal {
template <typename T>
-Status ComputeMatrixOffsets(T* workspace_data, size_t num_batches, size_t rows, IAllocatorUniquePtr<T*>& matrix_ptrs) {
+Status ComputeMatrixOffsets(cudaStream_t stream, T* workspace_data, size_t num_batches, size_t rows, IAllocatorUniquePtr<T*>& matrix_ptrs) {
std::vector<T*> cuda_ptrs;
const size_t matrix_size = rows * rows;
for (size_t i = 0; i < num_batches; ++i) {
cuda_ptrs.push_back(workspace_data);
workspace_data += matrix_size;
}
- CUDA_RETURN_IF_ERROR(cudaMemcpy(matrix_ptrs.get(), cuda_ptrs.data(), sizeof(T*) * num_batches,
- cudaMemcpyHostToDevice));
+
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(matrix_ptrs.get(), cuda_ptrs.data(), sizeof(T*) * num_batches,
+ cudaMemcpyHostToDevice, stream));
return Status::OK();
}
-Status CheckForSingularity(const IAllocatorUniquePtr<int>& info, const std::unique_ptr<int[]>& info_cpu, size_t num_batches) {
+Status CheckForSingularity(cudaStream_t stream, const IAllocatorUniquePtr<int>& info, const std::unique_ptr<int[]>& info_cpu, size_t num_batches) {
// Let's check if any of the info values is non-zero
- CUDA_RETURN_IF_ERROR(cudaMemcpy(info_cpu.get(), info.get(), sizeof(int) * num_batches,
- cudaMemcpyDeviceToHost));
+ // cudaMemcpyAsync from device memory to pageable host memory will return only once the copy has completed.
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(info_cpu.get(), info.get(), sizeof(int) * num_batches,
+ cudaMemcpyDeviceToHost, stream));
for (size_t i = 0; i < num_batches; ++i) {
if (info_cpu[i] != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Matrix is singular at batch:", i);
@@ -63,7 +65,7 @@ Status CheckForSingularity(const IAllocatorUniquePtr& info, const std::uniq
template <typename T>
struct Inverse::ComputeImpl {
- Status operator()(Inverse::CublasHandle cublas_h, const Inverse* inst, const Tensor& input, Tensor& output,
+ Status operator()(cudaStream_t stream, Inverse::CublasHandle cublas_h, const Inverse* inst, const Tensor& input, Tensor& output,
const IAllocatorUniquePtr& info, const IAllocatorUniquePtr& pivots,
size_t num_batches, size_t rows) const {
using namespace onnxruntime::cuda;
@@ -79,52 +81,52 @@ struct Inverse::ComputeImpl {
IAllocatorUniquePtr<float> input_workspace = inst->GetScratchBuffer<float>(input_count);
if (std::is_same<T, MLFloat16>::value) {
// Convert from MLFloat16(half) to float
- Impl_Cast<half, float>(reinterpret_cast<const half*>(input.Data<MLFloat16>()), input_workspace.get(), input_count);
+ Impl_Cast<half, float>(stream, reinterpret_cast<const half*>(input.Data<MLFloat16>()), input_workspace.get(), input_count);
} else {
- CUDA_RETURN_IF_ERROR(cudaMemcpy(input_workspace.get(), input.Data<float>(), sizeof(float) * input_count,
- cudaMemcpyDeviceToDevice));
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_workspace.get(), input.Data<float>(), sizeof(float) * input_count,
+ cudaMemcpyDeviceToDevice, stream));
}
IAllocatorUniquePtr<float*> matrix_ptrs = inst->GetScratchBuffer<float*>(n_batches);
- ORT_RETURN_IF_ERROR(ComputeMatrixOffsets(input_workspace.get(), num_batches, rows, matrix_ptrs));
+ ORT_RETURN_IF_ERROR(ComputeMatrixOffsets(stream, input_workspace.get(), num_batches, rows, matrix_ptrs));
// Do LU factorization
CUBLAS_RETURN_IF_ERROR(cublasSgetrfBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), info.get(), n_batches));
- ORT_RETURN_IF_ERROR(CheckForSingularity(info, info_cpu, num_batches));
+ ORT_RETURN_IF_ERROR(CheckForSingularity(stream, info, info_cpu, num_batches));
// Need to compute ptrs for output buffers
// Output for MLFloat
IAllocatorUniquePtr<float*> output_ptrs = inst->GetScratchBuffer<float*>(n_batches);
if (std::is_same<T, MLFloat16>::value) {
IAllocatorUniquePtr<float> ml_float_output = inst->GetScratchBuffer<float>(input_count);
- ORT_RETURN_IF_ERROR(ComputeMatrixOffsets(ml_float_output.get(), num_batches, rows, output_ptrs));
+ ORT_RETURN_IF_ERROR(ComputeMatrixOffsets(stream, ml_float_output.get(), num_batches, rows, output_ptrs));
// Do the inverse
CUBLAS_RETURN_IF_ERROR(cublasSgetriBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), output_ptrs.get(), dim, info.get(), n_batches));
- ORT_RETURN_IF_ERROR(CheckForSingularity(info, info_cpu, num_batches));
+ ORT_RETURN_IF_ERROR(CheckForSingularity(stream, info, info_cpu, num_batches));
// Copy the result to output with casting
- Impl_Cast<float, half>(ml_float_output.get(), reinterpret_cast<half*>(output.MutableData<MLFloat16>()), input_count);
+ Impl_Cast<float, half>(stream, ml_float_output.get(), reinterpret_cast<half*>(output.MutableData<MLFloat16>()), input_count);
// We are done here
} else {
- ORT_RETURN_IF_ERROR(ComputeMatrixOffsets(output.MutableData<float>(), num_batches, rows, output_ptrs));
+ ORT_RETURN_IF_ERROR(ComputeMatrixOffsets(stream, output.MutableData<float>(), num_batches, rows, output_ptrs));
// Do the inverse
CUBLAS_RETURN_IF_ERROR(cublasSgetriBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), output_ptrs.get(), dim, info.get(), n_batches));
- ORT_RETURN_IF_ERROR(CheckForSingularity(info, info_cpu, num_batches));
+ ORT_RETURN_IF_ERROR(CheckForSingularity(stream, info, info_cpu, num_batches));
// We are done here
}
} else if (std::is_same::value) {
IAllocatorUniquePtr<double> input_workspace = inst->GetScratchBuffer<double>(static_cast<size_t>(input_count));
- CUDA_RETURN_IF_ERROR(cudaMemcpy(input_workspace.get(), input.Data<double>(), sizeof(double) * input_count,
- cudaMemcpyDeviceToDevice));
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_workspace.get(), input.Data<double>(), sizeof(double) * input_count,
+ cudaMemcpyDeviceToDevice, stream));
IAllocatorUniquePtr<double*> matrix_ptrs = inst->GetScratchBuffer<double*>(n_batches);
- ORT_RETURN_IF_ERROR(ComputeMatrixOffsets(input_workspace.get(), num_batches, rows, matrix_ptrs));
+ ORT_RETURN_IF_ERROR(ComputeMatrixOffsets(stream, input_workspace.get(), num_batches, rows, matrix_ptrs));
// Do LU factorization
CUBLAS_RETURN_IF_ERROR(cublasDgetrfBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), info.get(), n_batches));
- ORT_RETURN_IF_ERROR(CheckForSingularity(info, info_cpu, num_batches));
+ ORT_RETURN_IF_ERROR(CheckForSingularity(stream, info, info_cpu, num_batches));
// Need to compute ptrs for output buffers
IAllocatorUniquePtr<double*> output_ptrs = inst->GetScratchBuffer<double*>(n_batches);
- ORT_RETURN_IF_ERROR(ComputeMatrixOffsets(output.MutableData<double>(), num_batches, rows, output_ptrs));
+ ORT_RETURN_IF_ERROR(ComputeMatrixOffsets(stream, output.MutableData<double>(), num_batches, rows, output_ptrs));
CUBLAS_RETURN_IF_ERROR(cublasDgetriBatched(cublas_h, dim, matrix_ptrs.get(), dim, pivots.get(), output_ptrs.get(), dim, info.get(), n_batches));
- ORT_RETURN_IF_ERROR(CheckForSingularity(info, info_cpu, num_batches));
+ ORT_RETURN_IF_ERROR(CheckForSingularity(stream, info, info_cpu, num_batches));
// We are done here
} else {
ORT_THROW("Type is not supported");
@@ -148,11 +150,11 @@ Status Inverse::ComputeInternal(OpKernelContext* ctx) const {
}
IAllocatorUniquePtr<int> info = GetScratchBuffer<int>(num_batches);
- CUDA_RETURN_IF_ERROR(cudaMemsetAsync(info.get(), 0, num_batches));
+ CUDA_RETURN_IF_ERROR(cudaMemsetAsync(info.get(), 0, num_batches, Stream()));
IAllocatorUniquePtr<int> pivots = GetScratchBuffer<int>(rows * num_batches);
utils::MLTypeCallDispatcherRet t_disp(input->GetElementType());
- return t_disp.Invoke(Base::CublasHandle(), this, *input, *output, info, pivots, num_batches, rows);
+ return t_disp.Invoke(Stream(), Base::CublasHandle(), this, *input, *output, info, pivots, num_batches, rows);
}
} // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/layer_norm.cc b/onnxruntime/contrib_ops/cuda/layer_norm.cc
index 3a864bc7b7..12f37f36a0 100644
--- a/onnxruntime/contrib_ops/cuda/layer_norm.cc
+++ b/onnxruntime/contrib_ops/cuda/layer_norm.cc
@@ -98,7 +98,7 @@ Status LayerNorm::ComputeInternal(OpKernelContext* ctx) const
inv_var_data = reinterpret_cast(var->template MutableData());
}
- HostApplyLayerNorm(GetDeviceProp(), Y_data, mean_data, inv_var_data, X_data, n1, n2, epsilon_, scale_data, bias_data);
+ HostApplyLayerNorm(GetDeviceProp(), Stream(), Y_data, mean_data, inv_var_data, X_data, n1, n2, epsilon_, scale_data, bias_data);
return Status::OK();
}
diff --git a/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu b/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu
index 0d2d6fd2e2..46e8fa2900 100644
--- a/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/layer_norm_impl.cu
@@ -350,6 +350,7 @@ __global__ void cuApplyLayerNorm(
template <typename T, typename U, bool simplified>
void HostApplyLayerNorm(
const cudaDeviceProp& prop,
+ cudaStream_t stream,
T* output,
U* mean,
U* invvar,
@@ -367,7 +368,7 @@ void HostApplyLayerNorm(
const dim3 blocks(1, std::min(n1, maxGridY), 1);
int nshared =
threads.y > 1 ? threads.y * sizeof(U) + (threads.y / 2) * sizeof(U) : 0;
- cuApplyLayerNorm<<>>(
+ cuApplyLayerNorm<<>>(
output,
mean,
invvar,
@@ -378,7 +379,7 @@ void HostApplyLayerNorm(
}
#define LAYERNORM_LINEAR_IMPL(T, U, simplified) \
- template void HostApplyLayerNorm<T, U, simplified>(const cudaDeviceProp& prop, T* output, U* mean, U* invvar, const T* input, int n1, int n2, \
+ template void HostApplyLayerNorm<T, U, simplified>(const cudaDeviceProp& prop, cudaStream_t stream, T* output, U* mean, U* invvar, const T* input, int n1, int n2, \
double epsilon, const T* gamma, const T* beta);
LAYERNORM_LINEAR_IMPL(float, float, true)
diff --git a/onnxruntime/contrib_ops/cuda/layer_norm_impl.h b/onnxruntime/contrib_ops/cuda/layer_norm_impl.h
index 039b7700a6..1705d99915 100644
--- a/onnxruntime/contrib_ops/cuda/layer_norm_impl.h
+++ b/onnxruntime/contrib_ops/cuda/layer_norm_impl.h
@@ -32,6 +32,7 @@ namespace cuda {
template <typename T, typename U, bool simplified>
void HostApplyLayerNorm(
const cudaDeviceProp& prop,
+ cudaStream_t stream,
T* output,
U* mean,
U* invvar,
diff --git a/onnxruntime/contrib_ops/cuda/math/bias_softmax.cc b/onnxruntime/contrib_ops/cuda/math/bias_softmax.cc
index 71d8679319..d9d30055dd 100644
--- a/onnxruntime/contrib_ops/cuda/math/bias_softmax.cc
+++ b/onnxruntime/contrib_ops/cuda/math/bias_softmax.cc
@@ -15,6 +15,7 @@ namespace cuda {
template <typename T>
void DispatchBiasSoftmaxForwardImpl(
+ cudaStream_t stream,
Tensor* output_tensor,
const Tensor* input_tensor,
const Tensor* input_bias_tensor,
@@ -25,6 +26,7 @@ void DispatchBiasSoftmaxForwardImpl(
template <typename T>
void DispatchBiasSoftMaxForwardViaDnnLibraryImpl(
+ cudaStream_t stream,
cudnnHandle_t cudaDnnHandle,
int element_count,
int batch_count,
@@ -64,12 +66,12 @@ Status BiasSoftmax::ComputeInternal(OpKernelContext* ctx) const {
// expect thread blocks can fill SM at high occupancy without overflowing registers
utils::MLTypeCallDispatcher
t_disp(X->GetElementType());
- t_disp.Invoke(Y, X, B, D, N, D, broadcast_size);
+ t_disp.Invoke(Stream(), Y, X, B, D, N, D, broadcast_size);
} else {
// need to fallback to add kernel + CUDA DNN library softmax call :/
utils::MLTypeCallDispatcher
t_disp(X->GetElementType());
- t_disp.Invoke(CudnnHandle(), D, N, broadcast_axis, softmax_axis, X_shape, X, B_shape, B, Y);
+ t_disp.Invoke(Stream(), CudnnHandle(), D, N, broadcast_axis, softmax_axis, X_shape, X, B_shape, B, Y);
}
return Status::OK();
@@ -77,6 +79,7 @@ Status BiasSoftmax::ComputeInternal(OpKernelContext* ctx) const {
template <typename T>
void DispatchBiasSoftmaxForward<T>::operator()(
+ cudaStream_t stream,
Tensor* output,
const Tensor* input,
const Tensor* input_bias,
@@ -85,6 +88,7 @@ void DispatchBiasSoftmaxForward::operator()(
int batch_stride,
int bias_broadcast_size_per_batch) {
DispatchBiasSoftmaxForwardImpl(
+ stream,
output,
input,
input_bias,
@@ -96,6 +100,7 @@ void DispatchBiasSoftmaxForward::operator()(
template <typename T>
void DispatchBiasSoftMaxForwardViaDnnLibrary<T>::operator()(
+ cudaStream_t stream,
cudnnHandle_t cudaDnnHandle,
int element_count,
int batch_count,
@@ -107,6 +112,7 @@ void DispatchBiasSoftMaxForwardViaDnnLibrary::operator()(
const onnxruntime::Tensor* B,
onnxruntime::Tensor* Y) {
DispatchBiasSoftMaxForwardViaDnnLibraryImpl(
+ stream,
cudaDnnHandle,
element_count,
batch_count,
diff --git a/onnxruntime/contrib_ops/cuda/math/bias_softmax.h b/onnxruntime/contrib_ops/cuda/math/bias_softmax.h
index 5bbc7266a3..03baec8d35 100644
--- a/onnxruntime/contrib_ops/cuda/math/bias_softmax.h
+++ b/onnxruntime/contrib_ops/cuda/math/bias_softmax.h
@@ -13,6 +13,7 @@ namespace cuda {
template <typename T>
struct DispatchBiasSoftmaxForward {
void operator()(
+ cudaStream_t stream,
Tensor* output,
const Tensor* input,
const Tensor* input_bias,
@@ -25,6 +26,7 @@ struct DispatchBiasSoftmaxForward {
template <typename T>
struct DispatchBiasSoftMaxForwardViaDnnLibrary {
void operator()(
+ cudaStream_t stream,
cudnnHandle_t cudaDnnHandle,
int element_count,
int batch_count,
diff --git a/onnxruntime/contrib_ops/cuda/math/bias_softmax_impl.cu b/onnxruntime/contrib_ops/cuda/math/bias_softmax_impl.cu
index 959a2d191c..27b2363219 100644
--- a/onnxruntime/contrib_ops/cuda/math/bias_softmax_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/math/bias_softmax_impl.cu
@@ -127,6 +127,7 @@ __global__ void BiasSoftmaxWarpForward(
template <typename T>
void DispatchBiasSoftmaxForwardImpl(
+ cudaStream_t stream,
Tensor* output_tensor,
const Tensor* input_tensor,
const Tensor* input_bias_tensor,
@@ -167,47 +168,47 @@ void DispatchBiasSoftmaxForwardImpl(
switch (log2_elements) {
case 0: // 1
BiasSoftmaxWarpForward
- <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
+ <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 1: // 2
BiasSoftmaxWarpForward
- <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
+ <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 2: // 4
BiasSoftmaxWarpForward
- <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
+ <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 3: // 8
BiasSoftmaxWarpForward
- <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
+ <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 4: // 16
BiasSoftmaxWarpForward
- <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
+ <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 5: // 32
BiasSoftmaxWarpForward
- <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
+ <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 6: // 64
BiasSoftmaxWarpForward
- <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
+ <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 7: // 128
BiasSoftmaxWarpForward
- <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
+ <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 8: // 256
BiasSoftmaxWarpForward
- <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
+ <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 9: // 512
BiasSoftmaxWarpForward
- <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
+ <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
case 10: // 1024
BiasSoftmaxWarpForward
- <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
+ <<>>(output, input, input_bias, element_count, batch_count, batch_stride, bias_broadcast_size_per_batch);
break;
default:
break;
@@ -216,6 +217,7 @@ void DispatchBiasSoftmaxForwardImpl(
#define SPECIALIZED_BIAS_SOFTMAX_IMPL(T) \
template void DispatchBiasSoftmaxForwardImpl<T>( \
+ cudaStream_t stream, \
Tensor * output_tensor, \
const Tensor* input_tensor, \
const Tensor* input_bias_tensor, \
@@ -232,6 +234,7 @@ SPECIALIZED_BIAS_SOFTMAX_IMPL(MLFloat16)
// note: This is an unhappy path! There is no performance benefit for the fusion.
template <typename T>
void DispatchBiasSoftMaxForwardViaDnnLibraryImpl(
+ cudaStream_t stream,
cudnnHandle_t cudaDnnHandle,
int element_count,
int batch_count,
@@ -278,6 +281,7 @@ void DispatchBiasSoftMaxForwardViaDnnLibraryImpl(
// invoke elementwise add with broadcast kernel
::onnxruntime::cuda::BinaryElementWiseImpl(
+ stream,
(int32_t)X_shape.NumDimensions(),
&lhs_padded_strides,
X_data,
@@ -311,6 +315,7 @@ void DispatchBiasSoftMaxForwardViaDnnLibraryImpl(
#define SPECIALIZED_BIAS_SOFTMAX_IMPL_VIA_DNN(T) \
template void DispatchBiasSoftMaxForwardViaDnnLibraryImpl<T>( \
+ cudaStream_t stream, \
cudnnHandle_t cudaDnnHandle, \
int element_count, \
int batch_count, \
diff --git a/onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops.cc b/onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops.cc
index a96e576b7d..5f85223a6b 100644
--- a/onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops.cc
+++ b/onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops.cc
@@ -25,6 +25,7 @@ namespace cuda {
BinaryElementwisePreparation prepare; \
ORT_RETURN_IF_ERROR(Prepare(context, &prepare)); \
Impl_##x<typename ToCudaType<T>::MappedType>( \
+ Stream(), \
prepare.output_rank_or_simple_broadcast, \
&prepare.lhs_padded_strides, \
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(prepare.lhs_tensor->template Data<T>()), \
diff --git a/onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops_impl.cu b/onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops_impl.cu
index c6b977ddbe..01791ed94c 100644
--- a/onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops_impl.cu
@@ -20,7 +20,8 @@ namespace cuda {
#define CONTRIB_BINARY_ELEMENTWISE_IMPL(name) \
CONTRIB_BINARY_ELEMENTWISE_IMPL_DECLARATION(name) { \
- BinaryElementWiseImpl(output_rank_or_simple_broadcast, \
+ BinaryElementWiseImpl(stream, \
+ output_rank_or_simple_broadcast, \
lhs_padded_strides, \
lhs_data, \
rhs_padded_strides, \
@@ -34,7 +35,8 @@ namespace cuda {
}
#define CONTRIB_SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, T) \
- template void Impl_##x(int32_t output_rank, \
+ template void Impl_##x(cudaStream_t stream, \
+ int32_t output_rank, \
const TArray* lhs_padded_strides, \
const T* lhs_data, \
const TArray* rhs_padded_strides, \
diff --git a/onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops_impl.h b/onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops_impl.h
index bb2af2f55a..6ff4233278 100644
--- a/onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops_impl.h
+++ b/onnxruntime/contrib_ops/cuda/math/binary_elementwise_ops_impl.h
@@ -20,6 +20,7 @@ namespace cuda {
#define CONTRIB_BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \
template <typename T> \
void Impl_##name( \
+ cudaStream_t stream, \
int32_t output_rank_or_simple_broadcast, \
const TArray* lhs_padded_strides, \
const T* lhs_data, \
diff --git a/onnxruntime/contrib_ops/cuda/math/complex_mul.cc b/onnxruntime/contrib_ops/cuda/math/complex_mul.cc
index 70d286ae0d..9584e8de3c 100644
--- a/onnxruntime/contrib_ops/cuda/math/complex_mul.cc
+++ b/onnxruntime/contrib_ops/cuda/math/complex_mul.cc
@@ -42,6 +42,7 @@ Status ComplexMul::ComputeInternal(OpKernelContext* context) const {
BinaryElementwisePreparation prepare;
ORT_RETURN_IF_ERROR(Prepare(context, &prepare));
ComplexMul_Impl<typename ToCudaType<T>::MappedType>(
+ Stream(),
prepare.output_rank_or_simple_broadcast,
&prepare.lhs_padded_strides,
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(prepare.lhs_tensor->template Data<T>()),
diff --git a/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu b/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu
index 0004cf9433..fdbc986b89 100644
--- a/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu
@@ -90,6 +90,7 @@ __global__ void _ElementWiseWithStrideTwo(
template <typename T>
void ComplexMul_Impl(
+ cudaStream_t stream,
int32_t output_rank_or_simple_broadcast,
const TArray* lhs_padded_strides,
const T* lhs_data,
@@ -110,7 +111,7 @@ void ComplexMul_Impl(
CUDA_LONG N = static_cast<CUDA_LONG>(count);
if (lhs_padded_strides && rhs_padded_strides && lhs_padded_strides->Size() && rhs_padded_strides->Size())
- _ElementWiseWithStrideTwo<<>>(
+ _ElementWiseWithStrideTwo<<>>(
output_rank_or_simple_broadcast,
*lhs_padded_strides,
lhs_data,
@@ -123,7 +124,7 @@ void ComplexMul_Impl(
rhs_size,
is_conj);
else if (lhs_padded_strides && lhs_padded_strides->Size())
- _ElementWiseWithStrideTwo<<>>(
+ _ElementWiseWithStrideTwo<<>>(
output_rank_or_simple_broadcast,
*lhs_padded_strides,
lhs_data,
@@ -136,7 +137,7 @@ void ComplexMul_Impl(
rhs_size,
is_conj);
else
- _ElementWiseWithStrideTwo<<>>(
+ _ElementWiseWithStrideTwo<<>>(
output_rank_or_simple_broadcast,
*lhs_padded_strides,
lhs_data,
@@ -152,6 +153,7 @@ void ComplexMul_Impl(
#define SPECIALIZE_STACKEDCOMPLEXMUL_IMPL(T) \
template void ComplexMul_Impl( \
+ cudaStream_t stream, \
int32_t output_rank_or_simple_broadcast, \
const TArray* lhs_padded_strides, \
const T* lhs_data, \
diff --git a/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.h b/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.h
index d48eea9a9f..dae66d8325 100644
--- a/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.h
+++ b/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.h
@@ -13,6 +13,7 @@ using namespace ::onnxruntime::cuda;
template <typename T>
void ComplexMul_Impl(
+ cudaStream_t stream,
int32_t output_rank_or_simple_broadcast,
const TArray* lhs_padded_strides,
const T* lhs_data,
diff --git a/onnxruntime/contrib_ops/cuda/math/fft_ops.cc b/onnxruntime/contrib_ops/cuda/math/fft_ops.cc
index 3c60644d70..c685882e92 100644
--- a/onnxruntime/contrib_ops/cuda/math/fft_ops.cc
+++ b/onnxruntime/contrib_ops/cuda/math/fft_ops.cc
@@ -127,11 +127,11 @@ Status FFTBase::DoFFT(OpKernelContext* context, const Tensor* X, bool complex
Tensor* Y = const_cast<OpKernelContext*>(context)->Output(0, TensorShape(output_dims));
auto* x_data = reinterpret_cast(X->template Data());
auto* y_data = reinterpret_cast(Y->template MutableData());
-
+ CUFFT_RETURN_IF_ERROR(cufftSetStream(plan_info.plan, Stream()));
CUFFT_RETURN_IF_ERROR(cufftXtExec(plan_info.plan, const_cast(x_data), y_data, inverse ? CUFFT_INVERSE : CUFFT_FORWARD));
if (inverse) {
- PostProcess(signal_dims, output_size, y_data);
+ PostProcess(Stream(), signal_dims, output_size, y_data);
}
return Status::OK();
diff --git a/onnxruntime/contrib_ops/cuda/math/fft_ops_impl.cu b/onnxruntime/contrib_ops/cuda/math/fft_ops_impl.cu
index 20d6272628..c1f4a088e0 100644
--- a/onnxruntime/contrib_ops/cuda/math/fft_ops_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/math/fft_ops_impl.cu
@@ -27,14 +27,14 @@ __global__ void _Normalize(
}
template <typename T>
-void PostProcess(const std::vector<int64_t>& signal_dims, int64_t N, T* output_data) {
+void PostProcess(cudaStream_t stream, const std::vector<int64_t>& signal_dims, int64_t N, T* output_data) {
int64_t scale = std::accumulate(signal_dims.begin(), signal_dims.end(), 1ll, std::multiplies<int64_t>());
int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock));
- _Normalize<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(output_data, N, static_cast<T>(scale));
+ _Normalize<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(output_data, N, static_cast<T>(scale));
}
#define SPECIALIZED_IMPL(T) \
- template void PostProcess(const std::vector<int64_t>& signal_dims, int64_t N, T* output_data);
+ template void PostProcess(cudaStream_t stream, const std::vector<int64_t>& signal_dims, int64_t N, T* output_data);
SPECIALIZED_IMPL(float)
SPECIALIZED_IMPL(double)
diff --git a/onnxruntime/contrib_ops/cuda/math/fft_ops_impl.h b/onnxruntime/contrib_ops/cuda/math/fft_ops_impl.h
index 8a7f7789c0..2312acd5d3 100644
--- a/onnxruntime/contrib_ops/cuda/math/fft_ops_impl.h
+++ b/onnxruntime/contrib_ops/cuda/math/fft_ops_impl.h
@@ -12,7 +12,7 @@ namespace contrib {
namespace cuda {
template <typename T>
-void PostProcess(const std::vector<int64_t>& signal_dims, int64_t N, T* output_data);
+void PostProcess(cudaStream_t stream, const std::vector<int64_t>& signal_dims, int64_t N, T* output_data);
} // namespace cuda
} // namespace contrib
diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc
index 67d51b53d5..5833e2fcee 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc
@@ -158,6 +158,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const {
}
// scale back and bias
CudaDequantizeWithBias(
+ Stream(),
gemm_buffer_quantized.get(),
reinterpret_cast(bias->template Data()),
reinterpret_cast(gemm_buffer.get()),
@@ -172,6 +173,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const {
auto temp_buffer = GetScratchBuffer<void>(workSpaceSize);
if (!LaunchAttentionKernel(
GetDeviceProp(),
+ Stream(),
reinterpret_cast(gemm_buffer.get()),
nullptr == mask_index ? nullptr : mask_index->template Data(),
nullptr == mask_index ? nullptr : &(mask_index->Shape().GetDims()),
diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization_impl.cu b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization_impl.cu
index 42791ae795..168c8a6f42 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization_impl.cu
@@ -31,10 +31,10 @@ __global__ void DequantizeLinearKernel(const int32_t* quantize, const T* bias, T
}
template <typename T>
-Status CudaDequantizeWithBias(const int32_t* quantize, const T* bias, T* output, T scale, int m, int n) {
+Status CudaDequantizeWithBias(cudaStream_t stream, const int32_t* quantize, const T* bias, T* output, T scale, int m, int n) {
int blocksPerGrid = static_cast<int>(CeilDiv(m * n, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
CUDA_LONG N = static_cast<CUDA_LONG>(m * n);
- DequantizeLinearKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
+ DequantizeLinearKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(
quantize,
bias,
output,
@@ -44,8 +44,8 @@ Status CudaDequantizeWithBias(const int32_t* quantize, const T* bias, T* output,
return Status::OK();
}
-template Status CudaDequantizeWithBias(const int32_t* quantize, const float* bias, float* output, float scale, int m, int n);
-template Status CudaDequantizeWithBias(const int32_t* quantize, const half* bias, half* output, half scale, int m, int n);
+template Status CudaDequantizeWithBias(cudaStream_t stream, const int32_t* quantize, const float* bias, float* output, float scale, int m, int n);
+template Status CudaDequantizeWithBias(cudaStream_t stream, const int32_t* quantize, const half* bias, half* output, half scale, int m, int n);
} // namespace cuda
} // namespace contrib
diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization_impl.cuh b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization_impl.cuh
index dc0ba262fa..b1aa2b9226 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization_impl.cuh
+++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization_impl.cuh
@@ -8,7 +8,7 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
template <typename Tin>
-Status CudaDequantizeWithBias(const int32_t* quantize, const Tin* bias, Tin* output, Tin scale, int m, int n);
+Status CudaDequantizeWithBias(cudaStream_t stream, const int32_t* quantize, const Tin* bias, Tin* output, Tin scale, int m, int n);
} // namespace cuda
} // namespace contrib
diff --git a/onnxruntime/contrib_ops/cuda/tensor/crop.cc b/onnxruntime/contrib_ops/cuda/tensor/crop.cc
index 66e022e3c4..76495c8b23 100644
--- a/onnxruntime/contrib_ops/cuda/tensor/crop.cc
+++ b/onnxruntime/contrib_ops/cuda/tensor/crop.cc
@@ -56,6 +56,7 @@ Status Crop::ComputeInternal(OpKernelContext* context) const {
fast_divmod fdm_YHW(gsl::narrow_cast<int>((bottomLimit - topBorder) * (rightLimit - leftBorder)));
CropImpl<typename ToCudaType<T>::MappedType>(
+ Stream(),
reinterpret_cast<const typename ToCudaType<T>::MappedType*>(X->template Data<T>()),
gsl::narrow_cast<int>(leftBorder),
gsl::narrow_cast