mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
MIGraphX Execution Provider: Stream Synchronization (#12899)
**Description**: Changes to the MIGraphx execution provider code to allow for stream synchronization on the gpu side **Motivation and Context** Performance boost by removing redundant host to device synchronizations The current implementation of the execution provider continuously calls hipDeviceSynchronize() between computations which adds overhead and an idle wait between the GPU's computations. This is noticeable during device This change leverages new functionality that's been added to MIGraphX to allow for GPU side synchronization which avoids the need for host->device waits. To maintain backwards compatibility with older MIGraphX versions, the compile time define MIGRAPHX_STREAM_SYNC has been added to the API to allow for older version operate with newer builds of onnxruntime without loss of functionality to the current feature set as of (08/09/22) Co-authored-by: Ted Themistokleous <tthemist@amd.com>
This commit is contained in:
parent
3c08f24efc
commit
a561fde126
6 changed files with 55 additions and 5 deletions
2
cmake/external/emsdk
vendored
2
cmake/external/emsdk
vendored
|
|
@ -1 +1 @@
|
|||
Subproject commit c220895fd1163c01f0a8b44229fb9e4fe0ae0958
|
||||
Subproject commit fc645b7626ebf86530dbd82fbece74d457e7ae07
|
||||
2
cmake/external/onnx
vendored
2
cmake/external/onnx
vendored
|
|
@ -1 +1 @@
|
|||
Subproject commit 5a5f8a5935762397aa68429b5493084ff970f774
|
||||
Subproject commit f7ee1ac60d06abe8e26c9b6bbe1e3db5286b614b
|
||||
2
cmake/external/protobuf
vendored
2
cmake/external/protobuf
vendored
|
|
@ -1 +1 @@
|
|||
Subproject commit a902b39270841beafc307dfa709610aa1cac2f06
|
||||
Subproject commit 0dab03ba7bc438d7ba3eac2b2c1eb39ed520f928
|
||||
|
|
@ -1193,6 +1193,15 @@ if (onnxruntime_USE_MIGRAPHX)
|
|||
set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections")
|
||||
target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync_cpp stdc++fs)
|
||||
|
||||
include(CheckLibraryExists)
|
||||
check_library_exists(migraphx::c "migraphx_program_run_async" "/opt/rocm/migraphx/lib" HAS_STREAM_SYNC)
|
||||
if(HAS_STREAM_SYNC)
|
||||
target_compile_definitions(onnxruntime_providers_migraphx PRIVATE -DMIGRAPHX_STREAM_SYNC)
|
||||
message(STATUS "MIGRAPHX GPU STREAM SYNC is ENABLED")
|
||||
else()
|
||||
message(STATUS "MIGRAPHX GPU STREAM SYNC is DISABLED")
|
||||
endif()
|
||||
|
||||
install(TARGETS onnxruntime_providers_migraphx
|
||||
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
#include "hip_fence.h"
|
||||
#include "gpu_data_transfer.h"
|
||||
#include "migraphx_call.h"
|
||||
#include "migraphx_inc.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
|
|
@ -1039,7 +1040,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
delete static_cast<MIGraphXFuncState*>(state);
|
||||
};
|
||||
|
||||
compute_info.compute_func = [](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) {
|
||||
compute_info.compute_func = [this](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) {
|
||||
Ort::KernelContext ctx(context);
|
||||
MIGraphXFuncState* mgx_state = reinterpret_cast<MIGraphXFuncState*>(state);
|
||||
|
||||
|
|
@ -1165,9 +1166,13 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
{
|
||||
// lock to avoid race condition
|
||||
std::lock_guard<OrtMutex> lock(*(mgx_state->mgx_mu_ptr));
|
||||
|
||||
#ifdef MIGRAPHX_STREAM_SYNC
|
||||
auto prog_outputs = prog.run_async(m, static_cast<hipStream_t>(GetComputeStream()));
|
||||
#else
|
||||
auto prog_outputs = prog.eval(m);
|
||||
HIP_CALL_THROW(hipDeviceSynchronize());
|
||||
|
||||
#endif
|
||||
// In case of input parameters are reused as output parameter call hipMemcpy
|
||||
auto output_num = prog_outputs.size();
|
||||
if (prog_output_indices.size() < output_num) {
|
||||
|
|
@ -1193,4 +1198,32 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#ifdef MIGRAPHX_STREAM_SYNC
|
||||
|
||||
Status MIGraphXExecutionProvider::Sync() const {
|
||||
HIP_CALL_THROW(hipStreamSynchronize(static_cast<hipStream_t>(nullptr)));
|
||||
|
||||
auto status = hipStreamQuery(static_cast<hipStream_t>(GetComputeStream()));
|
||||
if (status != hipSuccess) {
|
||||
return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::EP_FAIL);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MIGraphXExecutionProvider::OnRunStart()
|
||||
{
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MIGraphXExecutionProvider::OnRunEnd(bool) {
|
||||
auto status = hipStreamQuery(static_cast<hipStream_t>(GetComputeStream()));
|
||||
|
||||
if (status != hipSuccess) {
|
||||
HIP_CALL_THROW(hipStreamSynchronize(static_cast<hipStream_t>(GetComputeStream())));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -41,6 +41,14 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
|
|||
explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info);
|
||||
~MIGraphXExecutionProvider() = default;
|
||||
|
||||
#ifdef MIGRAPHX_STREAM_SYNC
|
||||
Status Sync() const override;
|
||||
|
||||
Status OnRunStart() override;
|
||||
|
||||
Status OnRunEnd(bool sync_stream) override;
|
||||
#endif
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
GetCapability(const onnxruntime::GraphViewer& graph_viewer,
|
||||
const IKernelLookup& /*kernel_lookup*/) const override;
|
||||
|
|
|
|||
Loading…
Reference in a new issue