From a561fde126211aa22255455bbffc7cec8fd2b38c Mon Sep 17 00:00:00 2001 From: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Date: Fri, 14 Oct 2022 13:23:51 -0400 Subject: [PATCH] MIGraphX Execution Provider: Stream Synchronization (#12899) **Description**: Changes to the MIGraphX execution provider code to allow for stream synchronization on the GPU side **Motivation and Context** Performance boost by removing redundant host-to-device synchronizations. The current implementation of the execution provider continuously calls hipDeviceSynchronize() between computations, which adds overhead and an idle wait between the GPU's computations. This is noticeable during device This change leverages new functionality that's been added to MIGraphX to allow for GPU side synchronization which avoids the need for host->device waits. To maintain backwards compatibility with older MIGraphX versions, the compile-time define MIGRAPHX_STREAM_SYNC has been added to the API to allow older versions to operate with newer builds of onnxruntime without loss of functionality to the current feature set as of 08/09/22. Co-authored-by: Ted Themistokleous --- cmake/external/emsdk | 2 +- cmake/external/onnx | 2 +- cmake/external/protobuf | 2 +- cmake/onnxruntime_providers.cmake | 9 +++++ .../migraphx/migraphx_execution_provider.cc | 37 ++++++++++++++++++- .../migraphx/migraphx_execution_provider.h | 8 ++++ 6 files changed, 55 insertions(+), 5 deletions(-) diff --git a/cmake/external/emsdk b/cmake/external/emsdk index c220895fd1..fc645b7626 160000 --- a/cmake/external/emsdk +++ b/cmake/external/emsdk @@ -1 +1 @@ -Subproject commit c220895fd1163c01f0a8b44229fb9e4fe0ae0958 +Subproject commit fc645b7626ebf86530dbd82fbece74d457e7ae07 diff --git a/cmake/external/onnx b/cmake/external/onnx index 5a5f8a5935..f7ee1ac60d 160000 --- a/cmake/external/onnx +++ b/cmake/external/onnx @@ -1 +1 @@ -Subproject commit 5a5f8a5935762397aa68429b5493084ff970f774 +Subproject commit
f7ee1ac60d06abe8e26c9b6bbe1e3db5286b614b diff --git a/cmake/external/protobuf b/cmake/external/protobuf index a902b39270..0dab03ba7b 160000 --- a/cmake/external/protobuf +++ b/cmake/external/protobuf @@ -1 +1 @@ -Subproject commit a902b39270841beafc307dfa709610aa1cac2f06 +Subproject commit 0dab03ba7bc438d7ba3eac2b2c1eb39ed520f928 diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index cfa936de5d..3aa11a9354 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -1193,6 +1193,15 @@ if (onnxruntime_USE_MIGRAPHX) set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections") target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync_cpp stdc++fs) + include(CheckLibraryExists) + check_library_exists(migraphx::c "migraphx_program_run_async" "/opt/rocm/migraphx/lib" HAS_STREAM_SYNC) + if(HAS_STREAM_SYNC) + target_compile_definitions(onnxruntime_providers_migraphx PRIVATE -DMIGRAPHX_STREAM_SYNC) + message(STATUS "MIGRAPHX GPU STREAM SYNC is ENABLED") + else() + message(STATUS "MIGRAPHX GPU STREAM SYNC is DISABLED") + endif() + install(TARGETS onnxruntime_providers_migraphx ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index f5d7b8f2b4..f63909eb4b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -12,6 +12,7 @@ #include "hip_fence.h" #include "gpu_data_transfer.h" #include "migraphx_call.h" +#include "migraphx_inc.h" #include #include @@ -1039,7 +1040,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& delete static_cast(state); }; - compute_info.compute_func 
= [](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) { + compute_info.compute_func = [this](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) { Ort::KernelContext ctx(context); MIGraphXFuncState* mgx_state = reinterpret_cast(state); @@ -1165,9 +1166,13 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& { // lock to avoid race condition std::lock_guard lock(*(mgx_state->mgx_mu_ptr)); + + #ifdef MIGRAPHX_STREAM_SYNC + auto prog_outputs = prog.run_async(m, static_cast(GetComputeStream())); + #else auto prog_outputs = prog.eval(m); HIP_CALL_THROW(hipDeviceSynchronize()); - + #endif // In case of input parameters are reused as output parameter call hipMemcpy auto output_num = prog_outputs.size(); if (prog_output_indices.size() < output_num) { @@ -1193,4 +1198,32 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& return Status::OK(); } +#ifdef MIGRAPHX_STREAM_SYNC + +Status MIGraphXExecutionProvider::Sync() const { + HIP_CALL_THROW(hipStreamSynchronize(static_cast(nullptr))); + + auto status = hipStreamQuery(static_cast(GetComputeStream())); + if (status != hipSuccess) { + return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::EP_FAIL); + } + return Status::OK(); +} + +Status MIGraphXExecutionProvider::OnRunStart() +{ + return Status::OK(); +} + +Status MIGraphXExecutionProvider::OnRunEnd(bool) { + auto status = hipStreamQuery(static_cast(GetComputeStream())); + + if (status != hipSuccess) { + HIP_CALL_THROW(hipStreamSynchronize(static_cast(GetComputeStream()))); + } + return Status::OK(); +} + +#endif + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index d16a982414..c76dec2598 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -41,6 +41,14 @@ class 
MIGraphXExecutionProvider : public IExecutionProvider { explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info); ~MIGraphXExecutionProvider() = default; +#ifdef MIGRAPHX_STREAM_SYNC + Status Sync() const override; + + Status OnRunStart() override; + + Status OnRunEnd(bool sync_stream) override; +#endif + std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const override;