mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Avoid cudaStreamSync at the end of Forward/Backward (#9470)
* Skip cudaStreamSynchronize at the end of fw * skip sync stream for end of backward
This commit is contained in:
parent
5797bd6db3
commit
ff23b9ff55
8 changed files with 19 additions and 13 deletions
|
|
@ -160,7 +160,7 @@ class IExecutionProvider {
|
|||
may not be finished on device This function should be regarded as the point
|
||||
that all commands of current Run has been submmited by CPU
|
||||
*/
|
||||
virtual common::Status OnRunEnd() { return Status::OK(); }
|
||||
virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); }
|
||||
|
||||
/**
|
||||
Called when session creation is complete
|
||||
|
|
|
|||
|
|
@ -333,11 +333,13 @@ Status CUDAExecutionProvider::OnRunStart() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CUDAExecutionProvider::OnRunEnd() {
|
||||
Status CUDAExecutionProvider::OnRunEnd(bool sync_stream) {
|
||||
// record deferred release event on default stream, and release per_thread_context
|
||||
auto current_deferred_release_event = GetPerThreadContext().GetCurrentDeferredReleaseEvent();
|
||||
CUDA_RETURN_IF_ERROR(cudaEventRecord(current_deferred_release_event, static_cast<cudaStream_t>(GetComputeStream())));
|
||||
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast<cudaStream_t>(GetComputeStream())));
|
||||
if (sync_stream) {
|
||||
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast<cudaStream_t>(GetComputeStream())));
|
||||
}
|
||||
ReleasePerThreadContext();
|
||||
std::lock_guard<OrtMutex> lock(deferred_release_cpu_ptr_mutex_);
|
||||
deferred_release_cpu_ptr_[current_deferred_release_event].recorded = true;
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
|
|||
|
||||
Status OnRunStart() override;
|
||||
|
||||
Status OnRunEnd() override;
|
||||
Status OnRunEnd(bool sync_stream) override;
|
||||
|
||||
const void* GetExecutionHandle() const noexcept override {
|
||||
// The CUDA interface does not return anything interesting.
|
||||
|
|
|
|||
|
|
@ -336,11 +336,13 @@ Status ROCMExecutionProvider::OnRunStart() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ROCMExecutionProvider::OnRunEnd() {
|
||||
Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) {
|
||||
// record deferred release event on default stream, and release per_thread_context
|
||||
auto current_deferred_release_event = GetPerThreadContext().GetCurrentDeferredReleaseEvent();
|
||||
HIP_RETURN_IF_ERROR(hipEventRecord(current_deferred_release_event, static_cast<hipStream_t>(GetComputeStream())));
|
||||
HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(GetComputeStream())));
|
||||
if (sync_stream) {
|
||||
HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(GetComputeStream())));
|
||||
}
|
||||
ReleasePerThreadContext();
|
||||
std::lock_guard<OrtMutex> lock(deferred_release_cpu_ptr_mutex_);
|
||||
deferred_release_cpu_ptr_[current_deferred_release_event].recorded = true;
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class ROCMExecutionProvider : public IExecutionProvider {
|
|||
|
||||
Status OnRunStart() override;
|
||||
|
||||
Status OnRunEnd() override;
|
||||
Status OnRunEnd(bool sync_stream) override;
|
||||
|
||||
const void* GetExecutionHandle() const noexcept override {
|
||||
// The ROCM interface does not return anything interesting.
|
||||
|
|
|
|||
|
|
@ -645,8 +645,10 @@ std::unique_ptr<IDataTransfer> TensorrtExecutionProvider::GetDataTransfer() cons
|
|||
return onnxruntime::CreateGPUDataTransfer(static_cast<void*>(GetComputeStream()));
|
||||
}
|
||||
|
||||
Status TensorrtExecutionProvider::OnRunEnd() {
|
||||
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast<cudaStream_t>(GetComputeStream())));
|
||||
Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) {
|
||||
if (sync_stream) {
|
||||
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast<cudaStream_t>(GetComputeStream())));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -911,7 +913,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
|
|||
if (input_shape != nullptr) {
|
||||
auto dim_size = input_shape->dim_size();
|
||||
for (int i = 0; i < dim_size; ++i) {
|
||||
auto &dim = input_shape->dim(i);
|
||||
auto& dim = input_shape->dim(i);
|
||||
if (!dim.has_dim_value() && !dim.has_dim_param()) {
|
||||
has_dim_value_or_param = false;
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
|
|||
|
||||
void RegisterAllocator(std::shared_ptr<AllocatorManager> allocator_manager) override;
|
||||
|
||||
Status OnRunEnd() override;
|
||||
Status OnRunEnd(bool sync_stream) override;
|
||||
|
||||
Status SetComputeStream(void* stream) override;
|
||||
|
||||
|
|
|
|||
|
|
@ -1730,7 +1730,7 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options,
|
|||
|
||||
// info all execution providers InferenceSession:Run ended
|
||||
for (auto* xp : exec_providers_to_stop) {
|
||||
auto status = xp->OnRunEnd();
|
||||
auto status = xp->OnRunEnd(/*sync_stream*/ false);
|
||||
ORT_CHECK_AND_SET_RETVAL(status);
|
||||
}
|
||||
|
||||
|
|
@ -1844,7 +1844,7 @@ Status InferenceSession::Run(const RunOptions& run_options,
|
|||
|
||||
// info all execution providers InferenceSession:Run ended
|
||||
for (auto* xp : exec_providers_to_stop) {
|
||||
auto status = xp->OnRunEnd();
|
||||
auto status = xp->OnRunEnd(/*sync_stream*/ true);
|
||||
ORT_CHECK_AND_SET_RETVAL(status);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue