Avoid cudaStreamSync at the end of Forward/Backward (#9470)

* Skip cudaStreamSynchronize at the end of fw 

* skip sync stream for end of backward
This commit is contained in:
Sherlock 2021-10-21 11:28:25 -07:00 committed by GitHub
parent 5797bd6db3
commit ff23b9ff55
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 19 additions and 13 deletions

View file

@ -160,7 +160,7 @@ class IExecutionProvider {
may not be finished on device This function should be regarded as the point
that all commands of current Run has been submmited by CPU
*/
virtual common::Status OnRunEnd() { return Status::OK(); }
virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); }
/**
Called when session creation is complete

View file

@ -333,11 +333,13 @@ Status CUDAExecutionProvider::OnRunStart() {
return Status::OK();
}
Status CUDAExecutionProvider::OnRunEnd() {
Status CUDAExecutionProvider::OnRunEnd(bool sync_stream) {
// record deferred release event on default stream, and release per_thread_context
auto current_deferred_release_event = GetPerThreadContext().GetCurrentDeferredReleaseEvent();
CUDA_RETURN_IF_ERROR(cudaEventRecord(current_deferred_release_event, static_cast<cudaStream_t>(GetComputeStream())));
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast<cudaStream_t>(GetComputeStream())));
if (sync_stream) {
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast<cudaStream_t>(GetComputeStream())));
}
ReleasePerThreadContext();
std::lock_guard<OrtMutex> lock(deferred_release_cpu_ptr_mutex_);
deferred_release_cpu_ptr_[current_deferred_release_event].recorded = true;

View file

@ -29,7 +29,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
Status OnRunStart() override;
Status OnRunEnd() override;
Status OnRunEnd(bool sync_stream) override;
const void* GetExecutionHandle() const noexcept override {
// The CUDA interface does not return anything interesting.

View file

@ -336,11 +336,13 @@ Status ROCMExecutionProvider::OnRunStart() {
return Status::OK();
}
Status ROCMExecutionProvider::OnRunEnd() {
Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) {
// record deferred release event on default stream, and release per_thread_context
auto current_deferred_release_event = GetPerThreadContext().GetCurrentDeferredReleaseEvent();
HIP_RETURN_IF_ERROR(hipEventRecord(current_deferred_release_event, static_cast<hipStream_t>(GetComputeStream())));
HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(GetComputeStream())));
if (sync_stream) {
HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(GetComputeStream())));
}
ReleasePerThreadContext();
std::lock_guard<OrtMutex> lock(deferred_release_cpu_ptr_mutex_);
deferred_release_cpu_ptr_[current_deferred_release_event].recorded = true;

View file

@ -29,7 +29,7 @@ class ROCMExecutionProvider : public IExecutionProvider {
Status OnRunStart() override;
Status OnRunEnd() override;
Status OnRunEnd(bool sync_stream) override;
const void* GetExecutionHandle() const noexcept override {
// The ROCM interface does not return anything interesting.

View file

@ -645,8 +645,10 @@ std::unique_ptr<IDataTransfer> TensorrtExecutionProvider::GetDataTransfer() cons
return onnxruntime::CreateGPUDataTransfer(static_cast<void*>(GetComputeStream()));
}
Status TensorrtExecutionProvider::OnRunEnd() {
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast<cudaStream_t>(GetComputeStream())));
Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) {
if (sync_stream) {
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(static_cast<cudaStream_t>(GetComputeStream())));
}
return Status::OK();
}
@ -911,7 +913,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
if (input_shape != nullptr) {
auto dim_size = input_shape->dim_size();
for (int i = 0; i < dim_size; ++i) {
auto &dim = input_shape->dim(i);
auto& dim = input_shape->dim(i);
if (!dim.has_dim_value() && !dim.has_dim_param()) {
has_dim_value_or_param = false;
break;

View file

@ -125,7 +125,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
void RegisterAllocator(std::shared_ptr<AllocatorManager> allocator_manager) override;
Status OnRunEnd() override;
Status OnRunEnd(bool sync_stream) override;
Status SetComputeStream(void* stream) override;

View file

@ -1730,7 +1730,7 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options,
// info all execution providers InferenceSession:Run ended
for (auto* xp : exec_providers_to_stop) {
auto status = xp->OnRunEnd();
auto status = xp->OnRunEnd(/*sync_stream*/ false);
ORT_CHECK_AND_SET_RETVAL(status);
}
@ -1844,7 +1844,7 @@ Status InferenceSession::Run(const RunOptions& run_options,
// info all execution providers InferenceSession:Run ended
for (auto* xp : exec_providers_to_stop) {
auto status = xp->OnRunEnd();
auto status = xp->OnRunEnd(/*sync_stream*/ true);
ORT_CHECK_AND_SET_RETVAL(status);
}