# Patch overview (free-form preamble placed before the first "diff --git" header).
#
# Paths touched by this patch:
#   * cmake/external/emsdk, cmake/external/onnx, cmake/external/protobuf
#       Submodule pointer changes only (mode 160000, one "Subproject commit" line each).
#   * cmake/onnxruntime_providers.cmake  (hunk inside `if (onnxruntime_USE_MIGRAPHX)`)
#       Includes CheckLibraryExists and probes for the symbol "migraphx_program_run_async"
#       with the location argument "/opt/rocm/migraphx/lib". When HAS_STREAM_SYNC is set,
#       MIGRAPHX_STREAM_SYNC is added as a PRIVATE compile definition on
#       onnxruntime_providers_migraphx; both outcomes are reported via message(STATUS ...).
#   * onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
#       - Adds #include "migraphx_inc.h".
#       - compute_func lambda changes from capturing nothing to capturing `this`
#         (it now calls GetComputeStream()).
#       - Under MIGRAPHX_STREAM_SYNC the program is launched with
#         prog.run_async(m, <compute stream>); otherwise the previous
#         prog.eval(m) followed by hipDeviceSynchronize() is kept.
#       - Adds, under the same macro:
#           Sync()        : hipStreamSynchronize on a null stream handle, then hipStreamQuery on the
#                           compute stream; returns an ONNXRUNTIME/EP_FAIL Status if the query result
#                           is not hipSuccess, Status::OK() otherwise.
#           OnRunStart()  : returns Status::OK().
#           OnRunEnd(bool): hipStreamQuery on the compute stream; if the result is not hipSuccess,
#                           hipStreamSynchronize on the compute stream; returns Status::OK().
#   * onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
#       Declares Sync(), OnRunStart() and OnRunEnd(bool sync_stream) as overrides, guarded by
#       #ifdef MIGRAPHX_STREAM_SYNC.
#
# NOTE(review): the library search location is the literal "/opt/rocm/migraphx/lib"; presumably a
#   MIGraphX install under a different prefix would be reported as DISABLED -- confirm.
# NOTE(review): Sync() waits on the null stream rather than on GetComputeStream() and then treats a
#   non-hipSuccess query of the compute stream as EP_FAIL -- confirm that the null-stream wait is
#   guaranteed to drain the compute stream for the way that stream is created.
# NOTE(review): OnRunEnd leaves its bool parameter unnamed and unused in the definition.
# NOTE(review): the header's overrides exist only when MIGRAPHX_STREAM_SYNC is defined, and the
#   definition is PRIVATE to one target -- verify no other target includes this header.
# NOTE(review): in this copy of the patch the original line breaks are collapsed and angle-bracket
#   contents appear to be missing (e.g. "static_cast(GetComputeStream())", "#include #include");
#   it is reproduced below unchanged and is presumably not applicable as-is -- confirm against the
#   original commit.
diff --git a/cmake/external/emsdk b/cmake/external/emsdk index c220895fd1..fc645b7626 160000 --- a/cmake/external/emsdk +++ b/cmake/external/emsdk @@ -1 +1 @@ -Subproject commit c220895fd1163c01f0a8b44229fb9e4fe0ae0958 +Subproject commit fc645b7626ebf86530dbd82fbece74d457e7ae07 diff --git a/cmake/external/onnx b/cmake/external/onnx index 5a5f8a5935..f7ee1ac60d 160000 --- a/cmake/external/onnx +++ b/cmake/external/onnx @@ -1 +1 @@ -Subproject commit 5a5f8a5935762397aa68429b5493084ff970f774 +Subproject commit f7ee1ac60d06abe8e26c9b6bbe1e3db5286b614b diff --git a/cmake/external/protobuf b/cmake/external/protobuf index a902b39270..0dab03ba7b 160000 --- a/cmake/external/protobuf +++ b/cmake/external/protobuf @@ -1 +1 @@ -Subproject commit a902b39270841beafc307dfa709610aa1cac2f06 +Subproject commit 0dab03ba7bc438d7ba3eac2b2c1eb39ed520f928 diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index cfa936de5d..3aa11a9354 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -1193,6 +1193,15 @@ if (onnxruntime_USE_MIGRAPHX) set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections") target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync_cpp stdc++fs) + include(CheckLibraryExists) + check_library_exists(migraphx::c "migraphx_program_run_async" "/opt/rocm/migraphx/lib" HAS_STREAM_SYNC) + if(HAS_STREAM_SYNC) + target_compile_definitions(onnxruntime_providers_migraphx PRIVATE -DMIGRAPHX_STREAM_SYNC) + message(STATUS "MIGRAPHX GPU STREAM SYNC is ENABLED") + else() + message(STATUS "MIGRAPHX GPU STREAM SYNC is DISABLED") + endif() + install(TARGETS onnxruntime_providers_migraphx ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc 
b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index f5d7b8f2b4..f63909eb4b 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -12,6 +12,7 @@ #include "hip_fence.h" #include "gpu_data_transfer.h" #include "migraphx_call.h" +#include "migraphx_inc.h" #include #include @@ -1039,7 +1040,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& delete static_cast(state); }; - compute_info.compute_func = [](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) { + compute_info.compute_func = [this](FunctionState state, const OrtApi* /* api */, OrtKernelContext* context) { Ort::KernelContext ctx(context); MIGraphXFuncState* mgx_state = reinterpret_cast(state); @@ -1165,9 +1166,13 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& { // lock to avoid race condition std::lock_guard lock(*(mgx_state->mgx_mu_ptr)); + + #ifdef MIGRAPHX_STREAM_SYNC + auto prog_outputs = prog.run_async(m, static_cast(GetComputeStream())); + #else auto prog_outputs = prog.eval(m); HIP_CALL_THROW(hipDeviceSynchronize()); - + #endif // In case of input parameters are reused as output parameter call hipMemcpy auto output_num = prog_outputs.size(); if (prog_output_indices.size() < output_num) { @@ -1193,4 +1198,32 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& return Status::OK(); } +#ifdef MIGRAPHX_STREAM_SYNC + +Status MIGraphXExecutionProvider::Sync() const { + HIP_CALL_THROW(hipStreamSynchronize(static_cast(nullptr))); + + auto status = hipStreamQuery(static_cast(GetComputeStream())); + if (status != hipSuccess) { + return Status(onnxruntime::common::ONNXRUNTIME, onnxruntime::common::EP_FAIL); + } + return Status::OK(); +} + +Status MIGraphXExecutionProvider::OnRunStart() +{ + return Status::OK(); +} + +Status MIGraphXExecutionProvider::OnRunEnd(bool) { + auto status = 
hipStreamQuery(static_cast(GetComputeStream())); + + if (status != hipSuccess) { + HIP_CALL_THROW(hipStreamSynchronize(static_cast(GetComputeStream()))); + } + return Status::OK(); +} + +#endif + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index d16a982414..c76dec2598 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -41,6 +41,14 @@ class MIGraphXExecutionProvider : public IExecutionProvider { explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info); ~MIGraphXExecutionProvider() = default; +#ifdef MIGRAPHX_STREAM_SYNC + Status Sync() const override; + + Status OnRunStart() override; + + Status OnRunEnd(bool sync_stream) override; +#endif + std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const override;