MIGraphX Execution Provider: Stream Synchronization (#12899)

**Description**: Changes to the MIGraphX execution provider code to
allow for stream synchronization on the GPU side

**Motivation and Context**
Performance boost by removing redundant host to device synchronizations 

The current implementation of the execution provider continuously calls
hipDeviceSynchronize() between computations, which adds overhead and an
idle wait between the GPU's computations. This is noticeable during
device execution.

This change leverages new functionality that's been added to MIGraphX to
allow for GPU side synchronization which avoids the need for
host->device waits.

To maintain backwards compatibility with older MIGraphX versions, the
compile-time define MIGRAPHX_STREAM_SYNC has been added to the API to
allow older versions to operate with newer builds of onnxruntime without
loss of functionality relative to the current feature set (as of 08/09/22).

Co-authored-by: Ted Themistokleous <tthemist@amd.com>
This commit is contained in:
Ted Themistokleous 2022-10-14 13:23:51 -04:00 committed by GitHub
parent 3c08f24efc
commit a561fde126
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 55 additions and 5 deletions

@ -1 +1 @@
Subproject commit c220895fd1163c01f0a8b44229fb9e4fe0ae0958
Subproject commit fc645b7626ebf86530dbd82fbece74d457e7ae07

2
cmake/external/onnx vendored

@ -1 +1 @@
Subproject commit 5a5f8a5935762397aa68429b5493084ff970f774
Subproject commit f7ee1ac60d06abe8e26c9b6bbe1e3db5286b614b

@ -1 +1 @@
Subproject commit a902b39270841beafc307dfa709610aa1cac2f06
Subproject commit 0dab03ba7bc438d7ba3eac2b2c1eb39ed520f928

View file

@ -1193,6 +1193,15 @@ if (onnxruntime_USE_MIGRAPHX)
set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections")
target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync_cpp stdc++fs)
# Probe the MIGraphX C library for the migraphx_program_run_async symbol and
# record the outcome in HAS_STREAM_SYNC.
include(CheckLibraryExists)
# NOTE(review): the library location is hard-coded to /opt/rocm/migraphx/lib;
# confirm this still resolves for ROCm installs under a different prefix.
check_library_exists(migraphx::c "migraphx_program_run_async" "/opt/rocm/migraphx/lib" HAS_STREAM_SYNC)
if(HAS_STREAM_SYNC)
# Define MIGRAPHX_STREAM_SYNC for the provider target only (PRIVATE) so its
# sources can be conditionally compiled against the async entry point.
target_compile_definitions(onnxruntime_providers_migraphx PRIVATE -DMIGRAPHX_STREAM_SYNC)
message(STATUS "MIGRAPHX GPU STREAM SYNC is ENABLED")
else()
# Symbol not found: the define is left unset and only the status line is printed.
message(STATUS "MIGRAPHX GPU STREAM SYNC is DISABLED")
endif()
install(TARGETS onnxruntime_providers_migraphx
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}

View file

@ -12,6 +12,7 @@
#include "hip_fence.h"
#include "gpu_data_transfer.h"
#include "migraphx_call.h"
#include "migraphx_inc.h"
#include <fstream>
#include <algorithm>
@ -1039,7 +1040,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
delete static_cast<MIGraphXFuncState*>(state);
};
compute_info.compute_func = [](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) {
compute_info.compute_func = [this](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) {
Ort::KernelContext ctx(context);
MIGraphXFuncState* mgx_state = reinterpret_cast<MIGraphXFuncState*>(state);
@ -1165,9 +1166,13 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
{
// lock to avoid race condition
std::lock_guard<OrtMutex> lock(*(mgx_state->mgx_mu_ptr));
#ifdef MIGRAPHX_STREAM_SYNC
auto prog_outputs = prog.run_async(m, static_cast<hipStream_t>(GetComputeStream()));
#else
auto prog_outputs = prog.eval(m);
HIP_CALL_THROW(hipDeviceSynchronize());
#endif
// In case of input parameters are reused as output parameter call hipMemcpy
auto output_num = prog_outputs.size();
if (prog_output_indices.size() < output_num) {
@ -1193,4 +1198,32 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
return Status::OK();
}
#ifdef MIGRAPHX_STREAM_SYNC
// Blocks the host until work queued on the provider's compute stream has finished.
//
// The null (default) stream is synchronized first through HIP_CALL_THROW
// (presumably throwing on failure, per the macro name). The compute stream is
// then waited on with hipStreamSynchronize instead of being polled with
// hipStreamQuery: a query reports hipErrorNotReady while work is still
// pending, which would turn a merely busy stream into an EP_FAIL result
// rather than waiting for it to drain.
//
// Returns Status::OK() once the wait on the compute stream succeeds, or an
// ONNXRUNTIME / EP_FAIL status if that wait reports an error.
Status MIGraphXExecutionProvider::Sync() const {
  HIP_CALL_THROW(hipStreamSynchronize(static_cast<hipStream_t>(nullptr)));

  const auto status = hipStreamSynchronize(static_cast<hipStream_t>(GetComputeStream()));
  if (status != hipSuccess) {
    return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::EP_FAIL);
  }
  return Status::OK();
}
// Run-start hook: performs no work and always returns Status::OK().
Status MIGraphXExecutionProvider::OnRunStart() {
  return Status::OK();
}
// Run-end hook. The bool argument is accepted but not consulted.
//
// Polls the compute stream once; if the poll reports anything other than
// hipSuccess, waits on the stream through HIP_CALL_THROW. Always returns
// Status::OK() when control reaches the end of the function.
Status MIGraphXExecutionProvider::OnRunEnd(bool) {
  const auto stream_state = hipStreamQuery(static_cast<hipStream_t>(GetComputeStream()));
  if (stream_state == hipSuccess) {
    // Nothing outstanding on the compute stream; no wait is needed.
    return Status::OK();
  }

  HIP_CALL_THROW(hipStreamSynchronize(static_cast<hipStream_t>(GetComputeStream())));
  return Status::OK();
}
#endif
} // namespace onnxruntime

View file

@ -41,6 +41,14 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info);
~MIGraphXExecutionProvider() = default;
#ifdef MIGRAPHX_STREAM_SYNC
Status Sync() const override;
Status OnRunStart() override;
Status OnRunEnd(bool sync_stream) override;
#endif
std::vector<std::unique_ptr<ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const IKernelLookup& /*kernel_lookup*/) const override;