diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index b9a774995d..a57463df78 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -84,6 +84,7 @@ option(tensorflow_C_PACKAGE_PATH "Path to tensorflow C package installation dir"
option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in language other than cpp" OFF)
option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump node input shapes and output data to standard output when executing the model." OFF)
option(onnxruntime_USE_DML "Build with DirectML support" OFF)
+option(onnxruntime_USE_MIGRAPHX "Build with AMDMIGraphX support" OFF)
option(onnxruntime_USE_WINML "Build with WinML support" OFF)
option(onnxruntime_USE_ACL "Build with ACL support" OFF)
option(onnxruntime_USE_ACL_1902 "Build with ACL version 1902 support" OFF)
@@ -854,6 +855,14 @@ if (onnxruntime_USE_TENSORRT)
endif()
endif()
+if (onnxruntime_USE_MIGRAPHX)
+ if (WIN32)
+ message(FATAL_ERROR "MIGraphX does not support build in Windows!")
+ endif()
+ set(AMD_MIGRAPHX_HOME ${onnxruntime_MIGRAPHX_HOME})
+ add_definitions(-DUSE_MIGRAPHX=1)
+endif()
+
if (onnxruntime_USE_TVM)
if (WIN32 AND MSVC)
# wd4100: identifier' : unreferenced formal parameter
diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake
index 6815227589..a234638b63 100644
--- a/cmake/onnxruntime.cmake
+++ b/cmake/onnxruntime.cmake
@@ -82,6 +82,7 @@ target_link_libraries(onnxruntime PRIVATE
${PROVIDERS_NNAPI}
${PROVIDERS_RKNPU}
${PROVIDERS_TENSORRT}
+ ${PROVIDERS_MIGRAPHX}
${PROVIDERS_OPENVINO}
${PROVIDERS_NUPHAR}
${PROVIDERS_VITISAI}
diff --git a/cmake/onnxruntime_csharp.cmake b/cmake/onnxruntime_csharp.cmake
index a6856af498..5fd082f53c 100644
--- a/cmake/onnxruntime_csharp.cmake
+++ b/cmake/onnxruntime_csharp.cmake
@@ -22,6 +22,10 @@ if (onnxruntime_USE_TENSORRT)
STRING(APPEND CSHARP_PREPROCESSOR_DEFINES "USE_TENSORRT,")
endif()
+if (onnxruntime_USE_MIGRAPHX)
+ STRING(APPEND CSHARP_PREPROCESSOR_DEFINES "USE_MIGRAPHX,")
+endif()
+
if (onnxruntime_USE_OPENVINO)
STRING(APPEND CSHARP_PREPROCESSOR_DEFINES "USE_OPENVINO,")
endif()
diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake
index 9998321a45..2f33a36532 100644
--- a/cmake/onnxruntime_providers.cmake
+++ b/cmake/onnxruntime_providers.cmake
@@ -67,6 +67,10 @@ if(onnxruntime_USE_DML)
set(PROVIDERS_DML onnxruntime_providers_dml)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES dml)
endif()
+if(onnxruntime_USE_MIGRAPHX)
+ set(PROVIDERS_MIGRAPHX onnxruntime_providers_migraphx)
+ list(APPEND ONNXRUNTIME_PROVIDER_NAMES migraphx)
+endif()
if(onnxruntime_USE_OPENVINO)
set(PROVIDERS_OPENVINO onnxruntime_providers_openvino)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES openvino)
@@ -607,6 +611,32 @@ if (onnxruntime_USE_DML)
set_target_properties(onnxruntime_providers_dml PROPERTIES FOLDER "ONNXRuntime")
endif()
+if (onnxruntime_USE_MIGRAPHX)
+ # Add search paths for default rocm installation
+ list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm)
+
+ find_package(hip)
+ find_package(migraphx PATHS ${AMD_MIGRAPHX_HOME})
+
+ set(migraphx_libs migraphx::c hip::host)
+
+ file(GLOB_RECURSE onnxruntime_providers_migraphx_cc_srcs CONFIGURE_DEPENDS
+ "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.h"
+ "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.cc"
+ )
+
+ source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs})
+ add_library(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs})
+ target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs})
+ set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime")
+ target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare)
+ target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT})
+ onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnxruntime_framework onnx)
+ add_dependencies(onnxruntime_providers_migraphx ${onnxruntime_EXTERNAL_DEPENDENCIES})
+ install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/migraphx DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers)
+ set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX)
+endif()
+
if (onnxruntime_USE_ACL)
add_definitions(-DUSE_ACL=1)
file(GLOB_RECURSE onnxruntime_providers_acl_cc_srcs
diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake
index 5edaa75bc7..81625925b8 100644
--- a/cmake/onnxruntime_python.cmake
+++ b/cmake/onnxruntime_python.cmake
@@ -85,6 +85,7 @@ set(onnxruntime_pybind11_state_libs
${PROVIDERS_CUDA}
${PROVIDERS_DNNL}
${PROVIDERS_TENSORRT}
+ ${PROVIDERS_MIGRAPHX}
${PROVIDERS_NGRAPH}
${PROVIDERS_OPENVINO}
${PROVIDERS_NUPHAR}
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index d949310ed7..572485d74a 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -314,6 +314,11 @@ if(onnxruntime_USE_DML)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_dml)
endif()
+if(onnxruntime_USE_MIGRAPHX)
+ list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx)
+endif()
+
+
file(GLOB_RECURSE onnxruntime_test_tvm_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/test/tvm/*.h"
"${ONNXRUNTIME_ROOT}/test/tvm/*.cc"
@@ -341,6 +346,7 @@ set(ONNXRUNTIME_TEST_LIBS
${PROVIDERS_CUDA}
${PROVIDERS_DNNL}
${PROVIDERS_TENSORRT}
+ ${PROVIDERS_MIGRAPHX}
${PROVIDERS_NGRAPH}
${PROVIDERS_OPENVINO}
${PROVIDERS_NUPHAR}
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
index 98e2cd013d..c099f8ce6c 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
@@ -444,6 +444,9 @@ namespace Microsoft.ML.OnnxRuntime
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Tensorrt(IntPtr /*(OrtSessionOptions*)*/ options, int device_id);
+ [DllImport(nativeLib, CharSet = charSet)]
+ public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int device_id);
+
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Nnapi(IntPtr /*(OrtSessionOptions*)*/ options);
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
index 65d349aa76..0fe79c0475 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
@@ -135,6 +135,14 @@ namespace Microsoft.ML.OnnxRuntime
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Tensorrt(_nativePtr, deviceId));
}
+ ///
+ /// Use only if you have the onnxruntime package specific to this Execution Provider.
+ ///
+ public void AppendExecutionProvider_MIGraphX(int deviceId)
+ {
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_MIGraphX(_nativePtr, deviceId));
+ }
+
///
/// Use only if you have the onnxruntime package specific to this Execution Provider.
///
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
index b17baab6c4..b0e30d12e9 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -97,6 +97,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
#if USE_TENSORRT
opt.AppendExecutionProvider_Tensorrt(0);
#endif
+#if USE_MIGRAPHX
+ opt.AppendExecutionProvider_MIGraphX(0);
+#endif
#if USE_NNAPI
opt.AppendExecutionProvider_Nnapi();
#endif
@@ -1614,6 +1617,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
#if USE_TENSORRT
,"OrtSessionOptionsAppendExecutionProvider_Tensorrt"
#endif
+#if USE_MIGRAPHX
+ ,"OrtSessionOptionsAppendExecutionProvider_MIGraphX"
+#endif
#if USE_NNAPI
,"OrtSessionOptionsAppendExecutionProvider_Nnapi"
#endif
diff --git a/dockerfiles/Dockerfile.migraphx b/dockerfiles/Dockerfile.migraphx
new file mode 100644
index 0000000000..fcf33bee30
--- /dev/null
+++ b/dockerfiles/Dockerfile.migraphx
@@ -0,0 +1,49 @@
+# --------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------
+# Dockerfile to run ONNXRuntime with MIGraphX integration
+#--------------------------------------------------------------------------
+
+FROM ubuntu:16.04
+
+ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
+ARG ONNXRUNTIME_BRANCH=master
+ENV DEBIAN_FRONTEND noninteractive
+ENV LC_ALL C.UTF-8
+ENV LANG C.UTF-8
+ENV MIGRAPHX_DISABLE_FAST_GELU=1
+
+# Install rocm
+RUN apt-get update && apt-get install -y --no-install-recommends curl && \
+ curl -sL http://repo.radeon.com/rocm/apt/debian/rocm.gpg.key | apt-key add - && \
+ sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/debian/ xenial main > /etc/apt/sources.list.d/rocm.list'
+
+RUN apt-get update &&\
+ apt-get install -y sudo git bash build-essential cmake libpython3.5-dev python3-pip miopen-hip rocblas half
+
+# Install rbuild
+RUN pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz
+
+# Install MIGraphX from source
+RUN mkdir -p /migraphx
+RUN cd /migraphx && git clone --depth=1 --branch migraphx_for_ort https://github.com/ROCmSoftwarePlatform/AMDMIGraphX src
+RUN cd /migraphx && rbuild package --cxx /opt/rocm/bin/hcc -d /migraphx/deps -B /migraphx/build -S /migraphx/src/ -DPYTHON_EXECUTABLE=/usr/bin/python3
+RUN dpkg -i /migraphx/build/*.deb
+RUN rm -rf /migraphx
+
+WORKDIR /code
+ENV PATH /opt/miniconda/bin:/code/cmake-3.14.3-Linux-x86_64/bin:${PATH}
+
+# Workaround for broken cmake in hip's binary package
+RUN sed -i -e 's/hcc::hccrt;hcc::hc_am//g' /opt/rocm/hip/lib/cmake/hip/hip-targets-release.cmake
+ENV CXXFLAGS "-D__HIP_PLATFORM_HCC__=1"
+
+# Prepare onnxruntime repository & build onnxruntime
+RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\
+ /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\
+ cd onnxruntime &&\
+ /bin/sh ./build.sh --config Release --build_wheel --update --build --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_migraphx &&\
+ pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\
+ cd .. &&\
+ rm -rf onnxruntime cmake-3.14.3-Linux-x86_64
diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h
index c348ab4ff2..d79b9bc8d1 100644
--- a/include/onnxruntime/core/framework/allocator.h
+++ b/include/onnxruntime/core/framework/allocator.h
@@ -23,13 +23,14 @@ struct OrtDevice {
// Pre-defined device types.
static const DeviceType CPU = 0;
- static const DeviceType GPU = 1; //CUDA
+ static const DeviceType GPU = 1; //CUDA or HIP
static const DeviceType FPGA = 2;
struct MemType {
// Pre-defined memory types.
static const MemoryType DEFAULT = 0;
static const MemoryType CUDA_PINNED = 1;
+ static const MemoryType HIP_PINNED = 2;
};
constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_)
@@ -141,6 +142,8 @@ namespace onnxruntime {
constexpr const char* CPU = "Cpu";
constexpr const char* CUDA = "Cuda";
constexpr const char* CUDA_PINNED = "CudaPinned";
+constexpr const char* MIGRAPHX = "MIGraphX";
+constexpr const char* MIGRAPHX_PINNED = "MIGraphXPinned";
constexpr const char* TRT = "Tensorrt";
constexpr const char* TRT_PINNED = "TensorrtPinned";
diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h
index b491b5f124..fa69c90113 100644
--- a/include/onnxruntime/core/graph/constants.h
+++ b/include/onnxruntime/core/graph/constants.h
@@ -22,6 +22,7 @@ constexpr const char* kMSNchwcDomain = "com.microsoft.nchwc";
constexpr const char* kMSFeaturizersDomain = "com.microsoft.mlfeaturizers";
constexpr const char* kMSDmlDomain = "com.microsoft.dml";
constexpr const char* kNGraphDomain = "com.intel.ai";
+constexpr const char* kMIGraphXDomain = "";
constexpr const char* kVitisAIDomain = "com.xilinx";
constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider";
constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider";
@@ -34,5 +35,6 @@ constexpr const char* kTensorrtExecutionProvider = "TensorrtExecutionProvider";
constexpr const char* kNnapiExecutionProvider = "NnapiExecutionProvider";
constexpr const char* kRknpuExecutionProvider = "RknpuExecutionProvider";
constexpr const char* kDmlExecutionProvider = "DmlExecutionProvider";
+constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider";
constexpr const char* kAclExecutionProvider = "ACLExecutionProvider";
} // namespace onnxruntime
diff --git a/include/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h b/include/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h
new file mode 100644
index 0000000000..8f8219fde0
--- /dev/null
+++ b/include/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h
@@ -0,0 +1,15 @@
+// Copyright 2019 AMD AMDMIGraphX
+
+#include "onnxruntime_c_api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessionOptions* options, int device_id);
+
+#ifdef __cplusplus
+}
+#endif
+
+
diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
index 47db160487..4c89debf15 100644
--- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
+++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
@@ -22,6 +22,7 @@
#include "onnxruntime/core/providers/nuphar/nuphar_provider_factory.h"
#include "onnxruntime/core/providers/openvino/openvino_provider_factory.h"
#include "onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h"
+#include "onnxruntime/core/providers/migraphx/migraphx_provider_factory.h"
#include "onnxruntime/core/providers/acl/acl_provider_factory.h"
#ifdef USE_DIRECTML
#include "onnxruntime/core/providers/dml/dml_provider_factory.h"
@@ -423,6 +424,22 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addNup
#endif
}
+/*
+ * Class: ai_onnxruntime_OrtSession_SessionOptions
+ * Method: addMIGraphX
+ * Signature: (JJI)V
+ */
+JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addMIGraphX
+ (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint deviceNum) {
+ (void)jobj;
+ #ifdef USE_MIGRAPHX
+ checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_MIGraphX((OrtSessionOptions*) handle, deviceNum));
+ #else
+ (void)apiHandle;(void)handle;(void)deviceNum; // Parameters used when MIGraphX is defined.
+ throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"This binary was not compiled with MIGraphX support.");
+ #endif
+}
+
/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: addDirectML
diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
new file mode 100644
index 0000000000..2d443d43b7
--- /dev/null
+++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
@@ -0,0 +1,60 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "migraphx_inc.h"
+#include "gpu_data_transfer.h"
+
+namespace onnxruntime {
+GPUDataTransfer::GPUDataTransfer() {
+ // create streams, default is nullptr
+ streams_[kHipStreamDefault] = nullptr;
+ hipStreamCreateWithFlags(&streams_[kHipStreamCopyIn], hipStreamNonBlocking);
+ hipStreamCreateWithFlags(&streams_[kHipStreamCopyOut], hipStreamNonBlocking);
+}
+
+GPUDataTransfer::~GPUDataTransfer() {
+ hipStreamDestroy(streams_[kHipStreamCopyIn]);
+ hipStreamDestroy(streams_[kHipStreamCopyOut]);
+}
+
+bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
+ return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED
+ || dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HIP_PINNED;
+}
+
+common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const {
+ size_t bytes = src.SizeInBytes();
+ const void* src_data = src.DataRaw();
+ void* dst_data = dst.MutableDataRaw();
+
+ auto& src_device = src.Location().device;
+ auto& dst_device = dst.Location().device;
+
+ if (dst_device.Type() == OrtDevice::GPU) {
+ if (src_device.Type() == OrtDevice::CPU && src_device.MemType() == OrtDevice::MemType::HIP_PINNED) {
+ // copy from pinned memory to GPU, this is non-blocking
+ hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, streams_[exec_queue_id]);
+ } else if (src_device.Type() == OrtDevice::GPU) {
+ // copying between GPU, this is non-blocking
+ hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, streams_[kHipStreamDefault]);
+ } else {
+ // copy from other CPU memory to GPU, this is blocking
+ hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice);
+ }
+ } else if (src_device.Type() == OrtDevice::GPU) {
+ if (dst_device.Type() == OrtDevice::CPU && dst_device.MemType() == OrtDevice::MemType::HIP_PINNED) {
+ // copying from GPU to pinned memory, this is non-blocking
+ hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, streams_[exec_queue_id]);
+ } else {
+ // copying from GPU to CPU memory, this is blocking
+ hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost);
+ }
+ } else {
+ // copying between cpu memory
+ memcpy(dst_data, src_data, bytes);
+ }
+
+ return Status::OK();
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.h b/onnxruntime/core/providers/migraphx/gpu_data_transfer.h
new file mode 100644
index 0000000000..9b966236cd
--- /dev/null
+++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.h
@@ -0,0 +1,36 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "migraphx_inc.h"
+#include "core/framework/data_transfer.h"
+
+namespace onnxruntime {
+
+enum HIPStreamType : int {
+ kHipStreamDefault = 0,
+ kHipStreamCopyIn,
+ kHipStreamCopyOut,
+ kTotalHipStreams,
+};
+
+class GPUDataTransfer : public IDataTransfer {
+ public:
+ GPUDataTransfer();
+ ~GPUDataTransfer();
+
+ bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;
+
+ common::Status CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const override;
+
+ hipStream_t GetStream(int queue_id) const {
+ ORT_ENFORCE(queue_id >= 0 && queue_id < kTotalHipStreams);
+ return streams_[queue_id];
+ }
+
+ private:
+ hipStream_t streams_[kTotalHipStreams];
+};
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/hip_allocator.cc b/onnxruntime/core/providers/migraphx/hip_allocator.cc
new file mode 100644
index 0000000000..bd2b4c785d
--- /dev/null
+++ b/onnxruntime/core/providers/migraphx/hip_allocator.cc
@@ -0,0 +1,71 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "migraphx_inc.h"
+#include "hip_allocator.h"
+#include "core/framework/allocatormgr.h"
+#include "core/framework/session_state.h"
+#include "hip_fence.h"
+#include "gpu_data_transfer.h"
+
+namespace onnxruntime {
+
+static const GPUDataTransfer* GetGPUDataTransfer(const SessionState* session_state) {
+ OrtDevice gpu_device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0);
+ OrtDevice cpu_device;
+ return dynamic_cast(session_state->GetDataTransferMgr().GetDataTransfer(gpu_device, cpu_device));
+}
+
+void HIPAllocator::CheckDevice() const {
+#ifndef NDEBUG
+ // check device to match at debug build
+ // if it's expected to change, call hipSetDevice instead of the check
+ int current_device;
+ hipGetDevice(¤t_device);
+ ORT_ENFORCE(current_device == info_.id);
+#endif
+}
+
+void* HIPAllocator::Alloc(size_t size) {
+ CheckDevice();
+ void* p = nullptr;
+ if (size > 0) {
+ hipMalloc((void**)&p, size);
+ }
+ return p;
+}
+
+void HIPAllocator::Free(void* p) {
+ CheckDevice();
+ hipFree(p); // do not throw error since it's OK for hipFree to fail during shutdown
+}
+
+const OrtMemoryInfo& HIPAllocator::Info() const {
+ return info_;
+}
+
+FencePtr HIPAllocator::CreateFence(const SessionState* session_state) {
+ return std::make_shared(GetGPUDataTransfer(session_state));
+}
+
+void* HIPPinnedAllocator::Alloc(size_t size) {
+ void* p = nullptr;
+ if (size > 0) {
+ hipHostMalloc((void**)&p, size);
+ }
+ return p;
+}
+
+void HIPPinnedAllocator::Free(void* p) {
+ hipHostFree(p);
+}
+
+const OrtMemoryInfo& HIPPinnedAllocator::Info() const {
+ return info_;
+}
+
+FencePtr HIPPinnedAllocator::CreateFence(const SessionState* session_state) {
+ return std::make_shared(GetGPUDataTransfer(session_state));
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/hip_allocator.h b/onnxruntime/core/providers/migraphx/hip_allocator.h
new file mode 100644
index 0000000000..bfd651ae08
--- /dev/null
+++ b/onnxruntime/core/providers/migraphx/hip_allocator.h
@@ -0,0 +1,38 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/framework/allocator.h"
+
+namespace onnxruntime {
+
+class HIPAllocator : public IDeviceAllocator {
+ public:
+ HIPAllocator(int device_id, const char* name) : info_(name, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id), device_id, OrtMemTypeDefault) {}
+ virtual void* Alloc(size_t size) override;
+ virtual void Free(void* p) override;
+ virtual const OrtMemoryInfo& Info() const override;
+ virtual FencePtr CreateFence(const SessionState* session_state) override;
+
+ private:
+ void CheckDevice() const;
+
+ private:
+ const OrtMemoryInfo info_;
+};
+
+//TODO: add a default constructor
+class HIPPinnedAllocator : public IDeviceAllocator {
+ public:
+ HIPPinnedAllocator(int device_id, const char* name) : info_(name, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, device_id), device_id, OrtMemTypeCPUOutput) {}
+ virtual void* Alloc(size_t size) override;
+ virtual void Free(void* p) override;
+ virtual const OrtMemoryInfo& Info() const override;
+ virtual FencePtr CreateFence(const SessionState* session_state) override;
+
+ private:
+ const OrtMemoryInfo info_;
+};
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/hip_fence.cc b/onnxruntime/core/providers/migraphx/hip_fence.cc
new file mode 100644
index 0000000000..44313c756a
--- /dev/null
+++ b/onnxruntime/core/providers/migraphx/hip_fence.cc
@@ -0,0 +1,53 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "migraphx_inc.h"
+#include "hip_fence.h"
+#include "gpu_data_transfer.h"
+
+namespace onnxruntime {
+
+HIPFence::HIPFence(const GPUDataTransfer* data_transfer) : data_transfer_(data_transfer) {
+ hipEventCreate(&read_event_);
+ hipEventCreate(&write_event_);
+}
+
+HIPFence::~HIPFence() {
+ hipEventDestroy(read_event_);
+ hipEventDestroy(write_event_);
+}
+
+void HIPFence::BeforeUsingAsInput(onnxruntime::ProviderType provider_type, int async_queue_id) {
+ (void)provider_type;
+ (void)async_queue_id;
+ // sync on CPU for all other providers, this is blocking
+ hipEventSynchronize(write_event_);
+}
+
+void HIPFence::BeforeUsingAsOutput(onnxruntime::ProviderType provider_type, int queue_id) {
+ (void)provider_type;
+ (void)queue_id;
+
+ // sync on CPU for all other providers, this is blocking
+ hipEventSynchronize(read_event_);
+ hipEventSynchronize(write_event_);
+}
+
+bool HIPFence::CanRelease() {
+ return hipEventQuery(read_event_) == hipSuccess &&
+ hipEventQuery(write_event_) == hipSuccess;
+}
+
+void HIPFence::AfterUsedAsInput(int queue_id) {
+ // update read fence
+ hipStream_t stream = data_transfer_->GetStream(queue_id);
+ hipEventRecord(read_event_, stream);
+}
+
+void HIPFence::AfterUsedAsOutput(int queue_id) {
+ // update write fence
+ hipStream_t stream = data_transfer_->GetStream(queue_id);
+ hipEventRecord(write_event_, stream);
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/hip_fence.h b/onnxruntime/core/providers/migraphx/hip_fence.h
new file mode 100644
index 0000000000..ba9803ee37
--- /dev/null
+++ b/onnxruntime/core/providers/migraphx/hip_fence.h
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/framework/tensor.h"
+#include "core/graph/basic_types.h"
+
+namespace onnxruntime {
+class GPUDataTransfer;
+
+class HIPFence : public IFence {
+ public:
+ HIPFence(const GPUDataTransfer* data_transfer);
+ virtual ~HIPFence();
+ virtual void BeforeUsingAsInput(onnxruntime::ProviderType provider_type, int queue_id) override;
+ virtual void BeforeUsingAsOutput(onnxruntime::ProviderType provider_type, int queue_id) override;
+ virtual void AfterUsedAsInput(int queue_id) override;
+ virtual void AfterUsedAsOutput(int queue_id) override;
+ virtual bool CanRelease() override;
+
+ private:
+ hipEvent_t read_event_;
+ hipEvent_t write_event_;
+ const GPUDataTransfer* data_transfer_;
+};
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
new file mode 100644
index 0000000000..4fbac1d6fa
--- /dev/null
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
@@ -0,0 +1,1297 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License
+
+#include "core/common/common.h"
+#include "core/common/logging/logging.h"
+#include "core/framework/compute_capability.h"
+#include "core/framework/allocatormgr.h"
+#include "core/framework/kernel_registry.h"
+#include "core/framework/memcpy.h"
+#include "core/graph/graph_viewer.h"
+#include "core/graph/model.h"
+#include "core/graph/graph_utils.h"
+#include "core/session/onnxruntime_cxx_api.h"
+#include "core/optimizer/reshape_fusion.h"
+#include "migraphx_inc.h"
+#include "migraphx_execution_provider.h"
+#include "hip_allocator.h"
+#include "gpu_data_transfer.h"
+#include
+
+#if defined(_MSC_VER)
+#pragma warning(disable : 4244 4245)
+#elif __GNUC__
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#endif
+#if defined(_MSC_VER)
+#pragma warning(default : 4244 4245)
+#elif __GNUC__
+#pragma GCC diagnostic pop
+#endif
+
+#define MEMCPY_S(dest, src, destsz, srcsz) memcpy(dest, src, std::min(destsz, srcsz))
+
+namespace onnxruntime {
+
+ONNX_OPERATOR_KERNEL_EX(
+ MemcpyFromHost,
+ kOnnxDomain,
+ 1,
+ kMIGraphXExecutionProvider,
+ KernelDefBuilder()
+ .InputMemoryType(0)
+ .ExecQueueId(kHipStreamCopyIn)
+ .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
+ Memcpy);
+
+ONNX_OPERATOR_KERNEL_EX(
+ MemcpyToHost,
+ kOnnxDomain,
+ 1,
+ kMIGraphXExecutionProvider,
+ KernelDefBuilder()
+ .OutputMemoryType(0)
+ .ExecQueueId(kHipStreamCopyOut)
+ .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
+ Memcpy);
+
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMIGraphXExecutionProvider, kOnnxDomain, 1, MemcpyFromHost);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMIGraphXExecutionProvider, kOnnxDomain, 1, MemcpyToHost);
+
+static void RegisterMIGraphXKernels(KernelRegistry& kernel_registry) {
+ static const BuildKernelCreateInfoFn function_table[] = {
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ };
+
+ for (auto& function_table_entry : function_table) {
+ kernel_registry.Register(function_table_entry());
+ }
+}
+
+std::shared_ptr GetMIGraphXKernelRegistry() {
+ std::shared_ptr kernel_registry = std::make_shared();
+ RegisterMIGraphXKernels(*kernel_registry);
+
+ return kernel_registry;
+}
+
+std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() const {
+ static std::shared_ptr kernel_registry = onnxruntime::GetMIGraphXKernelRegistry();
+ return kernel_registry;
+}
+
+MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info)
+ : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider} {
+
+ // Set GPU device to be used
+ hipSetDevice(info.device_id);
+ DeviceAllocatorRegistrationInfo default_memory_info(
+ {OrtMemTypeDefault, [](int id) { return onnxruntime::make_unique(id, MIGRAPHX); }, std::numeric_limits::max()});
+ allocator_ = CreateAllocator(default_memory_info, device_id_);
+ InsertAllocator(allocator_);
+
+
+ DeviceAllocatorRegistrationInfo pinned_memory_info(
+ {OrtMemTypeCPUOutput, [](int) { return onnxruntime::make_unique(0, MIGRAPHX_PINNED); }, std::numeric_limits::max()});
+ InsertAllocator(CreateAllocator(pinned_memory_info, device_id_));
+
+
+ // create the target based on the device_id
+ hipDeviceProp_t prop;
+ hipGetDeviceProperties(&prop, device_id_);
+ std::set valid_targets = {"gpu", "cpu"};
+ if (valid_targets.count(info.target_device) == 0)
+ {
+ LOGS_DEFAULT(FATAL) << "Device " << info.target_device << " are not supported";
+ }
+
+ t_ = migraphx::target(info.target_device.c_str());
+}
+
+AllocatorPtr MIGraphXExecutionProvider::GetAllocator(int id, OrtMemType mem_type) const {
+ if (mem_type == OrtMemTypeDefault) {
+ return allocator_;
+ } else {
+ return IExecutionProvider::GetAllocator(id, mem_type);
+ }
+}
+
+std::unique_ptr MIGraphXExecutionProvider::GetDataTransfer() const {
+ return onnxruntime::make_unique();
+}
+
+static bool IsTypeSupported(const NodeArg* node_arg) {
+ const auto* type_proto = node_arg->TypeAsProto();
+ if (!type_proto) {
+ return false;
+ }
+
+ switch (type_proto->tensor_type().elem_type()) {
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16:
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT:
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE:
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8:
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16:
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32:
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64:
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8:
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16:
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32:
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT64:
+ return true;
+ default:
+ return false;
+ }
+}
+
+static bool get_migraphx_type(ONNXTensorElementDataType type,
+ migraphx_shape_datatype_t &mgx_type)
+{
+ mgx_type = migraphx_shape_float_type;
+ switch(type) {
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16:
+ mgx_type = migraphx_shape_half_type;
+ break;
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT:
+ mgx_type = migraphx_shape_float_type;
+ break;
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE:
+ mgx_type = migraphx_shape_double_type;
+ break;
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8:
+ mgx_type = migraphx_shape_int8_type;
+ break;
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16:
+ mgx_type = migraphx_shape_int16_type;
+ break;
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32:
+ mgx_type = migraphx_shape_int32_type;
+ break;
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64:
+ mgx_type = migraphx_shape_int64_type;
+ break;
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8:
+ mgx_type = migraphx_shape_uint8_type;
+ break;
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16:
+ mgx_type = migraphx_shape_uint16_type;
+ break;
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32:
+ mgx_type = migraphx_shape_uint32_type;
+ break;
+ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT64:
+ mgx_type = migraphx_shape_uint64_type;
+ break;
+ default:
+ LOGS_DEFAULT(WARNING) << "MiGraphx: unsupported data type " << type << ", fallback to CPU";
+ LOGS_DEFAULT(WARNING) << "implementation" << std::endl;
+ return false;
+ }
+
+ return true;
+}
+
+static bool can_eval_concat(const Node* concat, const InitializedTensorSet& initializers, const logging::Logger& logger)
+{
+ if (concat == nullptr) return true;
+ const auto concat_args = concat->InputDefs();
+ if (concat_args.size() != 3)
+ {
+ return false;
+ }
+
+ auto arg_0 = concat_args[0];
+ bool b_found = (initializers.find(arg_0->Name()) != initializers.end());
+ auto arg_2 = concat_args[2];
+ b_found &= (initializers.find(arg_2->Name()) != initializers.end());
+ if (b_found)
+ {
+ std::vector parent_path{
+ {0, 1, "Unsqueeze", {1, 11}, kOnnxDomain},
+ {0, 0, "Gather", {1, 11}, kOnnxDomain},
+ {0, 0, "Shape", {1}, kOnnxDomain}};
+ std::vector edges;
+ b_found = graph_utils::FindPath(*concat, true, parent_path, edges, logger);
+ if (b_found)
+ {
+ const Node& gather = edges[1]->GetNode();
+ const auto* arg_index = gather.InputDefs()[1];
+ if (initializers.find(arg_index->Name()) != initializers.end())
+ {
+ return true;
+ }
+ }
+ }
+
+ return false;
+}
+
+static bool can_eval_cast(const Node* cast, const InitializedTensorSet& initializers, const logging::Logger& logger)
+{
+ std::vector parent_path = {
+ {0, 0, "Concat", {1, 4, 11}, kOnnxDomain},
+ {0, 0, "Unsqueeze", {1}, kOnnxDomain},
+ {0, 0, "Cast", {1, 6, 9}, kOnnxDomain},
+ {0, 0, "Squeeze", {1, 11}, kOnnxDomain},
+ {0, 0, "Slice", {1, 10, 11}, kOnnxDomain},
+ {0, 0, "Cast", {1, 6, 9}, kOnnxDomain},
+ {0, 0, "Shape", {1}, kOnnxDomain}};
+ std::vector edges;
+ if (graph_utils::FindPath(*cast, true, parent_path, edges, logger)) {
+ const Node& concat = edges[1]->GetNode();
+ const Node& slice = edges[5]->GetNode();
+ const auto& concat_args = concat.InputDefs();
+ bool const_flag = true;
+ for (std::size_t i = 1; i < concat_args.size(); i++)
+ {
+ const_flag &= (initializers.find(concat_args[i]->Name()) != initializers.end());
+ }
+ if (const_flag)
+ {
+ const auto& slice_args = slice.InputDefs();
+ for (std::size_t i = 1; i < slice_args.size(); ++i)
+ {
+ const_flag &= (initializers.find(slice_args[i]->Name()) != initializers.end());
+ }
+ }
+ if (const_flag)
+ {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+static bool can_eval_input_shape(const Node* node, const InitializedTensorSet& initializers, const logging::Logger& logger)
+{
+ // scenario 1: [Root] --> Shape --> Cast --> Cast
+ std::vector parent_path{
+ {0, 1, "Cast", {1, 6, 9}, kOnnxDomain},
+ {0, 0, "Cast", {1, 6, 9}, kOnnxDomain},
+ {0, 0, "Shape", {1}, kOnnxDomain}};
+
+ std::vector edges;
+ if (graph_utils::FindPath(*node, true, parent_path, edges, logger))
+ {
+ return true;
+ }
+
+ // scenario 2:
+ const Node* concat = graph_utils::GetInputNode(*node, 1);
+ if (concat and concat->OpType() == "Concat")
+ {
+ if (can_eval_concat(concat, initializers, logger))
+ {
+ return true;
+ }
+ }
+
+ // scenario 3:
+ const Node* cast = graph_utils::GetInputNode(*node, 1);
+ if (cast and cast->OpType() == "Cast")
+ {
+ if (can_eval_cast(cast, initializers, logger))
+ {
+ return true;
+ }
+ }
+
+ // scenario 4:
+ std::vector parent_path_4 = {
+ {0, 1, "Cast", {1, 6, 9}, kOnnxDomain},
+ {0, 0, "Concat", {1, 4, 11}, kOnnxDomain},
+ {0, 0, "Unsqueeze", {1}, kOnnxDomain},
+ {0, 0, "Mul", {1, 6, 7}, kOnnxDomain},
+ {0, 0, "Cast", {1, 6, 9}, kOnnxDomain},
+ {0, 0, "Squeeze", {1, 11}, kOnnxDomain},
+ {0, 0, "Slice", {1, 10, 11}, kOnnxDomain},
+ {0, 0, "Cast", {1, 6, 9}, kOnnxDomain},
+ {0, 0, "Shape", {1}, kOnnxDomain}};
+ if (graph_utils::FindPath(*node, true, parent_path_4, edges, logger)) {
+ const Node& concat = edges[1]->GetNode();
+ const Node& mul = edges[3]->GetNode();
+ const Node& slice = edges[6]->GetNode();
+ const auto& concat_args = concat.InputDefs();
+ bool const_flag = true;
+ for (std::size_t i = 1; i < concat_args.size(); i++)
+ {
+ const_flag &= (initializers.find(concat_args[i]->Name()) != initializers.end());
+ }
+ if (const_flag)
+ {
+ const auto& mul_args = mul.InputDefs();
+ const_flag &= (initializers.find(mul_args[1]->Name()) != initializers.end());
+ }
+ if (const_flag)
+ {
+ const auto& slice_args = slice.InputDefs();
+ for (std::size_t i = 1; i < slice_args.size(); ++i)
+ {
+ const_flag &= (initializers.find(slice_args[i]->Name()) != initializers.end());
+ }
+ }
+ if (const_flag)
+ {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+static bool IsUnsupportedOpMode(const Node* node, const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger) {
+ const auto& optype = node->OpType();
+ const auto& initializers = graph_viewer.GetAllInitializedTensors();
+ if (optype == "AveragePool") {
+ // ceil_mode attribute is not supported in MIGraphX
+ const auto& attributes = node->GetAttributes();
+ const auto ceil_attr = attributes.find("ceil_mode");
+ // default value of ceil_mode (0) is supported.
+ if (ceil_attr != attributes.end() && ceil_attr->second.i() != 0) {
+ return true;
+ }
+
+ // input can only have 4 dims
+ const auto input_shape = node->InputDefs()[0]->Shape();
+ if (input_shape != nullptr and input_shape->dim_size() != 4)
+ {
+ return true;
+ }
+
+ // migraphx does not support count_include_pad to be 1
+ const auto cip_attr = attributes.find("count_include_pad");
+ if (cip_attr != attributes.end() && cip_attr->second.i() != 0)
+ {
+ return true;
+ }
+
+ const auto ap_attr = attributes.find("auto_pad");
+ if (ap_attr != attributes.end())
+ {
+ // explicit pad should be symmetric in migraphx
+ auto s_pad = ap_attr->second.s();
+ auto pads_attr = attributes.find("pads");
+ if (s_pad == "NOTSET")
+ {
+ if (pads_attr != attributes.end())
+ {
+ auto pads = pads_attr->second.ints();
+ if (pads.size() != 4)
+ {
+ return true;
+ }
+
+ if ((pads[0] != pads[2]) || (pads[1] != pads[3]))
+ {
+ return true;
+ }
+ }
+ }
+ // either SAME_UPPER or SAME_LOWER
+ else if (s_pad.find("SAME") != std::string::npos)
+ {
+ // pads cannot exist when auto_pad is same_upper or same_lower
+ if (pads_attr != attributes.end())
+ {
+ return true;
+ }
+
+ // compute the padding size to see whether they are symmetric
+ std::vector strides = {1, 1};
+ auto stride_attr = attributes.find("strides");
+ if (stride_attr != attributes.end())
+ {
+ auto attr_strides = stride_attr->second.ints();
+ strides.clear();
+ std::copy(attr_strides.begin(), attr_strides.end(), std::back_inserter(strides));
+ }
+
+ std::vector kernel_lens = {1, 1};
+ auto kernel_attr = attributes.find("kernel_shape");
+ if (kernel_attr != attributes.end())
+ {
+ auto attr_k = kernel_attr->second.ints();
+ std::copy(attr_k.begin(), attr_k.end(), kernel_lens.begin());
+ }
+
+ auto tensor_dims = input_shape->dim();
+ std::vector in_lens;
+ std::transform(tensor_dims.begin(),
+ tensor_dims.end(),
+ std::back_inserter(in_lens),
+ [&](auto&& d) -> std::size_t {
+ if(d.has_dim_value())
+ {
+ return d.dim_value();
+ }
+ return 1;
+ });
+
+ std::vector out_lens(2);
+ out_lens[0] = (in_lens[2] + strides[0] - 1) / strides[0];
+ out_lens[1] = (in_lens[3] + strides[1] - 1) / strides[1];
+ std::vector explicit_pads(2);
+ explicit_pads[0] = (out_lens[0] - 1) * strides[0] + kernel_lens[0] - in_lens[2];
+ explicit_pads[1] = (out_lens[1] - 1) * strides[1] + kernel_lens[1] - in_lens[3];
+ if ((explicit_pads[0] & 1) != 0 or (explicit_pads[1] & 1) != 0)
+ {
+ return true;
+ }
+ }
+ }
+ } else if (optype == "BatchNormalization") {
+ // input can only have 4 dims
+ const auto input_shape = node->InputDefs()[0]->Shape();
+ if (input_shape != nullptr and input_shape->dim_size() != 4)
+ {
+ return true;
+ }
+ } else if (optype == "Clip") {
+ auto args = node->InputDefs();
+ if (args.size() >= 3)
+ {
+ if (initializers.find(args[2]->Name()) == initializers.end())
+ return true;
+ }
+ if (args.size() >= 2)
+ {
+ if (initializers.find(args[1]->Name()) == initializers.end())
+ return true;
+ }
+ } else if (optype == "Conv") {
+ // input can only have 4 dims
+ const auto input_shape = node->InputDefs()[0]->Shape();
+ if (input_shape != nullptr and input_shape->dim_size() != 4)
+ {
+ return true;
+ }
+ } else if (optype == "ConstantOfShape") {
+ const auto shape_arg = node->InputDefs()[0];
+ if (initializers.find(shape_arg->Name()) != initializers.end())
+ {
+ return false;
+ }
+ const Node* shape_node = graph_utils::GetInputNode(*node, 0);
+ if (shape_node and shape_node->OpType() == "Concat")
+ {
+ if (can_eval_concat(shape_node, initializers, logger))
+ {
+ return false;
+ }
+ }
+ else if (shape_node and shape_node->OpType() == "Cast")
+ {
+ if (can_eval_cast(shape_node, initializers, logger))
+ {
+ return false;
+ }
+ }
+ return true;
+ } else if (optype == "ConvInteger") {
+ if (node->InputDefs()[0]->Shape()->dim_size() != 4)
+ {
+ return true;
+ }
+
+ // migraphx can handle only two inputs
+ if (node->InputDefs().size() != 2)
+ {
+ return true;
+ }
+
+ // only support int8 type
+ const auto& input_type = node->InputDefs()[0]->TypeAsProto();
+ if (input_type == nullptr)
+ {
+ return true;
+ }
+
+ if (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)
+ {
+ return true;
+ }
+ } else if (optype == "Expand") {
+ // MIGraphX only supports constant shape input values
+ const auto& shape_input = node->InputDefs()[1];
+ return !graph_viewer.IsConstantInitializer(shape_input->Name(), true);
+ } else if (optype == "MaxPool") {
+ //MaxPool "indices" output is not currently supported.
+ if (node->OutputDefs().size() > 1) {
+ return true;
+ }
+
+ // ceil_mode and dilations attrs are not supported in MIGraphX
+ const auto& attributes = node->GetAttributes();
+ const auto ceil_attr = attributes.find("ceil_mode");
+ // default value of ceil_mode (0) is supported.
+ if (ceil_attr != attributes.end() and ceil_attr->second.i() != 0) {
+ return true;
+ }
+
+ auto dila_attr = attributes.find("dilations");
+ if (dila_attr != attributes.end()) {
+ auto dilas = dila_attr->second.ints();
+ bool ret = std::all_of(dilas.begin(), dilas.end(), [](auto i) { return i == 1;});
+ if (ret == false)
+ {
+ return true;
+ }
+ }
+
+ // storage order 1 (column major format) is not supported
+ const auto storage_order_attr = attributes.find("storage_order");
+ if (storage_order_attr != attributes.end() and storage_order_attr->second.i() != 0)
+ {
+ return true;
+ }
+
+ // input can only have 4 dims
+ const auto input_shape = node->InputDefs()[0]->Shape();
+ if (input_shape != nullptr and input_shape->dim_size() != 4)
+ {
+ return true;
+ }
+ } else if (optype == "MatMulInteger") {
+ // migraphx can handle only two inputs
+ if (node->InputDefs().size() != 2)
+ {
+ return true;
+ }
+
+ // only support int8 type
+ const auto& input_type = node->InputDefs()[0]->TypeAsProto();
+ if (input_type == nullptr)
+ {
+ return true;
+ }
+
+ if (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8)
+ {
+ return true;
+ }
+ } else if (optype == "OneHot") {
+ const auto& arg_depth = node->InputDefs()[1];
+ return (initializers.find(arg_depth->Name()) == initializers.end());
+ } else if (optype == "Pad") {
+ const auto& args = node->InputDefs();
+ // if pad size is not constant, migraphx cannot support
+ if (args.size() >= 2)
+ {
+ const auto& shape_arg = node->InputDefs()[1];
+ if (initializers.find(shape_arg->Name()) == initializers.end()) {
+ return true;
+ }
+ }
+
+ const auto& attributes = node->GetAttributes();
+ // Pad only support constant mode
+ const auto mode_attr = attributes.find("mode");
+ std::string mode = "constant";
+ if(mode_attr != attributes.end())
+ {
+ mode = mode_attr->second.s();
+ }
+ static const std::set allowed_modes = {"constant", "reflect"};
+ if (allowed_modes.count(mode) == 0)
+ {
+ return true;
+ }
+
+ // input value only applied to constant mode
+ if (mode == "constant")
+ {
+ if (args.size() == 3)
+ {
+ const auto& val_arg = node->InputDefs()[2];
+ if (initializers.find(val_arg->Name()) == initializers.end()) {
+ return true;
+ }
+ }
+ }
+ } else if (optype == "Range") {
+ const auto& args = node->InputDefs();
+ if (!std::all_of(args.begin(), args.end(), [&](auto arg) {
+ return (initializers.find(arg->Name()) != initializers.end());
+ })) {
+ return true;
+ }
+ } else if (optype == "Reshape") {
+ const auto& args = node->InputDefs();
+ if (args.size() == 2)
+ {
+ const auto& shape_arg = args[1];
+ if (initializers.find(shape_arg->Name()) != initializers.end()) {
+ return false;
+ }
+
+ if (can_eval_input_shape(node, initializers, logger))
+ {
+ return false;
+ }
+ }
+ return true;
+ } else if (optype == "Slice") {
+ // MIGraphX does not properly handle the situation where any
+ // value of the "starts" attribute is higher than a corresponding
+ // value in the "ends"
+ const auto& args = node->InputDefs();
+ if (args.size() == 5) {
+ return true;
+ }
+
+ const auto& attributes = node->GetAttributes();
+ if (args.size() >= 4) {
+ if (initializers.find(args[3]->Name()) == initializers.end())
+ return true;
+ }
+
+ if (args.size() >= 3) {
+ if (initializers.find(args[2]->Name()) == initializers.end())
+ return true;
+ }
+
+ if (args.size() >= 2) {
+ if (initializers.find(args[1]->Name()) == initializers.end())
+ return true;
+ }
+
+ if (attributes.count("starts") > 0 and attributes.count("ends") > 0) {
+ const auto& starts = attributes.find("starts")->second.ints();
+ const auto& ends = attributes.find("ends")->second.ints();
+ for (int i = 0; i < starts.size(); ++i) {
+ if (starts.Get(i) > ends.Get(i)) {
+ return true;
+ }
+ }
+ }
+ }
+ else if (optype == "Tile")
+ {
+ const auto& args = node->InputDefs();
+ return (initializers.find(args[1]->Name()) == initializers.end());
+ }
+
+ //Op doesn't fall into known any of unsupported modes.
+ return false;
+}
+
+static bool IsNodeSupported(const std::set& op_set,
+ const onnxruntime::GraphViewer& graph_viewer,
+ const NodeIndex node_idx,
+ const logging::Logger& logger) {
+ const auto& node = graph_viewer.GetNode(node_idx);
+ const auto& optype = node->OpType();
+ const auto& domain = node->Domain();
+
+ // Three types of checking:
+ // 1. Check input and output data types are supported.
+ // 2. Check op_type is implemented in migraphx
+ // 3. Check the mode is implemented in migraphx
+ // if 3. failed, call the constant folding capability in migraphx
+ // to see whether some input parameters can be calculated statically
+ // check data type
+ bool are_types_supported = true;
+
+ node->ForEachDef([&are_types_supported](const onnxruntime::NodeArg& node_arg, bool /*is_input*/) {
+ are_types_supported &= IsTypeSupported(&node_arg);
+ });
+
+ if (!are_types_supported) {
+ return false;
+ }
+
+ // whether an operator implemented in migraphx
+ if (op_set.count(optype) == 0) {
+ return false;
+ }
+
+ // check that some modes might not be supported in migraphx for some operators
+ if (domain == kOnnxDomain && IsUnsupportedOpMode(node, graph_viewer, logger)) {
+ // not supported, then check the constant folding capability of migraphx
+ // to see whether it is supported
+ return false;
+ }
+
+ return true;
+}
+
+static void AppendNodesToSubGraph(const std::vector& nodes,
+ const std::vector& inputs,
+ const std::vector& outputs,
+ std::vector>& result) {
+ static size_t op_counter = 0;
+
+ auto meta_def = onnxruntime::make_unique();
+ meta_def->name = "MIGraphX_" + std::to_string(++op_counter);
+ meta_def->domain = kMIGraphXDomain;
+ meta_def->since_version = 1;
+ meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL;
+ meta_def->inputs = inputs;
+ meta_def->outputs = outputs;
+
+ std::unique_ptr sub_graph = onnxruntime::make_unique();
+ sub_graph->nodes = nodes;
+ sub_graph->SetMetaDef(meta_def);
+ result.push_back(onnxruntime::make_unique(std::move(sub_graph)));
+}
+
+static std::vector
+GetUnsupportedNodeIndices(const GraphViewer& graph_viewer,
+ /*out*/ std::unordered_set& mgx_required_initializers,
+ const logging::Logger& logger) {
+ static std::set mgx_supported_ops = {"Abs", "Acos", "Acosh", "Add", "ArgMax", "ArgMin",
+ "Asin", "Asinh", "Atan", "Atanh", "AveragePool", "BatchNormalization", "Cast", "Ceil", "Clip",
+ "Concat", "Constant", "ConstantFill", "ConstantOfShape", "Conv", "Cos", "Cosh", "Div", "Dropout",
+ "Elu", "Erf", "Exp", "Expand", "Flatten", "Floor", "GRU", "Gather", "Gemm", "GlobalAveragePool",
+ "GlobalMaxPool", "Identity", "ImageScaler", "InstanceNormalization", "LRN", "LSTM", "LeakyRelu",
+ "Log", "LogSoftmax", "MatMul", "Max", "MaxPool", "Min", "Mul", "OneHot", "Pad", "Pow", "PRelu",
+ "RNN","Range", "Reciprocal", "ReduceL1", "ReduceL2", "ReduceLogSum", "ReduceLogSumExp", "ReduceMax",
+ "ReduceMean", "ReduceMin", "ReduceProd", "ReduceSum", "ReduceSumSquare", "Relu", "Reshape",
+ "Round", "Shape", "Sigmoid", "Sign", "Sin", "Sinh", "Slice", "Softmax", "Split", "Sqrt", "Squeeze",
+ "Sub", "Sum", "Tan", "Tanh", "Tile", "Transpose", "Unsqueeze"};
+ std::vector unsupported_nodes_idx;
+ for (const auto& node_idx : graph_viewer.GetNodesInTopologicalOrder()) {
+ if (IsNodeSupported(mgx_supported_ops, graph_viewer, node_idx, logger)) {
+ // Collect inputs that are initializers
+ graph_viewer.GetNode(node_idx)->ForEachDef([&mgx_required_initializers, &graph_viewer](const onnxruntime::NodeArg& node_arg, bool is_input) {
+ if(is_input && graph_viewer.GetAllInitializedTensors().count(node_arg.Name())) {
+ mgx_required_initializers.insert(node_arg.Name());
+ } }, true);
+ } else {
+ unsupported_nodes_idx.push_back(node_idx);
+ }
+ }
+
+ return unsupported_nodes_idx;
+}
+
+// Returns a vector clusters(or node_idx). For each unsupported node, the graph
+// is split into 3 parts. supported_cluster + (UNsupported_node + rest_of_the_graph).
+// This functions returns vector of all supported_subgraphx by amdmigraphx
+static std::vector>
+GetPartitionedSubgraphs(const std::vector& topological_order, const std::vector& unsupported_nodes) {
+ std::vector> mgx_subgraphx;
+
+ auto prev = topological_order.begin();
+
+ for (const auto& unsup_node : unsupported_nodes) {
+ auto it = std::find(prev, topological_order.end(), unsup_node);
+ // Create a cluster vector[supported_node_idx, unsupported_node_idx)
+ // and append it to return list.
+ std::vector this_subgraph{prev, it};
+ if (!this_subgraph.empty()) {
+ mgx_subgraphx.push_back(std::move(this_subgraph));
+ }
+ // Point prev to node idx past this unsuported node.
+ prev = ++it;
+ }
+
+ // Tail
+ std::vector this_subgraph{prev, topological_order.end()};
+ if (!this_subgraph.empty()) {
+ mgx_subgraphx.push_back(std::move(this_subgraph));
+ }
+
+ return mgx_subgraphx;
+}
+
+static void GetInputsOutputsOfSubgraph(const GraphViewer& graph_viewer,
+ const std::vector& nodes,
+ const std::unordered_set& mgx_required_initializers,
+ std::vector& nodes_inputs,
+ std::vector& nodes_outputs) {
+ std::unordered_set input_args;
+ std::vector ordered_input_args;
+ std::unordered_set output_args;
+ std::unordered_set external_output_args;
+
+ for (const auto& node_idx : nodes) {
+ const auto& node = graph_viewer.GetNode(node_idx);
+
+ // Collect all inputs and outputs
+ node->ForEachDef(
+ [&input_args, &ordered_input_args, &output_args](const NodeArg& node_arg, bool is_input) {
+ if (is_input) {
+ if (!input_args.count(node_arg.Name())) {
+ ordered_input_args.push_back(node_arg.Name());
+ }
+ input_args.insert(node_arg.Name());
+ } else {
+ output_args.insert(node_arg.Name());
+ }
+ },
+ true);
+
+ // Check if output of this node is used by nodes outside
+ // subgraph. If yes add this to cluster outputs
+ for (auto it = node->OutputNodesBegin(); it != node->OutputNodesEnd(); ++it) {
+ const auto& ext_node = graph_viewer.GetNode((*it).Index());
+
+ if (std::find(nodes.begin(), nodes.end(), ext_node->Index()) == nodes.end()) {
+ // Node is external to subgraph. Search through its
+ // inputs to find the output that is generated by subgraph.
+ std::set ext_node_inputs;
+ ext_node->ForEachDef(
+ [&ext_node_inputs](const onnxruntime::NodeArg& arg, bool is_input) {
+ if (is_input) {
+ ext_node_inputs.insert(arg.Name());
+ }
+ },
+ true);
+
+ for (const auto& out_def : node->OutputDefs()) {
+ if (ext_node_inputs.find(out_def->Name()) != ext_node_inputs.end()) {
+ external_output_args.insert(out_def->Name());
+ }
+ }
+ }
+ }
+ }
+
+ //Extract initializers used by subgraph.
+ std::unordered_set original_graph_inputs;
+ for (const auto& node_arg : graph_viewer.GetInputsIncludingInitializers()) {
+ original_graph_inputs.insert(node_arg->Name());
+ }
+
+ const auto& initializers = graph_viewer.GetAllInitializedTensors();
+ std::vector const_inputs;
+ for (const auto& in_arg : ordered_input_args) {
+ if ((initializers.count(in_arg) && !original_graph_inputs.count(in_arg)) ||
+ mgx_required_initializers.count(in_arg)) {
+ const_inputs.push_back(in_arg);
+ }
+ }
+
+ for (const auto& in_arg : ordered_input_args) {
+ if (!output_args.count(in_arg) &&
+ !((initializers.count(in_arg) && !original_graph_inputs.count(in_arg)) ||
+ mgx_required_initializers.count(in_arg))) {
+ nodes_inputs.push_back(in_arg);
+ }
+ }
+
+ for (const auto& in_arg : const_inputs) {
+ nodes_inputs.push_back(in_arg);
+ }
+
+ std::copy(external_output_args.begin(), external_output_args.end(), std::back_inserter(nodes_outputs));
+ for (const auto& node_arg : graph_viewer.GetOutputs()) {
+ const auto& name = node_arg->Name();
+ if (output_args.count(name) && !external_output_args.count(name)) {
+ nodes_outputs.push_back(name);
+ }
+ }
+}
+
+std::vector>
+MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
+ const std::vector& /*kernel_registries*/) const {
+
+ std::vector> result;
+ if (graph_viewer.IsSubgraph()) {
+ return result;
+ }
+
+ for (const auto& tensor : graph_viewer.GetAllInitializedTensors()) {
+ if (tensor.second->has_data_location() && tensor.second->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) {
+ LOGS_DEFAULT(WARNING) << "MIGraphX: Initializers with external data lepcation are not currently supported";
+ return result;
+ }
+ }
+
+ // Construct modelproto from graph
+ onnxruntime::Model model(graph_viewer.Name(), true, ModelMetaData(), PathString{},
+ IOnnxRuntimeOpSchemaRegistryList(), graph_viewer.DomainToVersionMap(),
+ std::vector(), *GetLogger());
+
+ std::unordered_map map_dim_param_values;
+ onnxruntime::Graph& graph_build = model.MainGraph();
+ for (const auto& node : graph_viewer.Nodes()) {
+ std::vector inputs, outputs;
+ for (auto input : node.InputDefs()) {
+ auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
+ inputs.push_back(&n_input);
+ }
+ for (auto output : node.OutputDefs()) {
+ auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto());
+ outputs.push_back(&n_output);
+ }
+ graph_build.AddNode(node.Name(), node.OpType(), node.Description(), inputs, outputs, &node.GetAttributes(), node.Domain());
+ }
+
+ //Add initializer to graph
+ std::size_t init_tensor_num = 0;
+ const auto& init_tensors = graph_viewer.GetAllInitializedTensors();
+ for (const auto& tensor : init_tensors) {
+ init_tensor_num++;
+ graph_build.AddInitializedTensor(*(tensor.second));
+ }
+
+ ONNX_NAMESPACE::ModelProto model_proto = model.ToProto();
+ model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
+
+ auto status = graph_build.Resolve();
+ std::string onnx_string_buffer;
+ model_proto.SerializeToString(&onnx_string_buffer);
+
+ // This is a list of initializers that migraphx considers as constants.
+ // Example weights, reshape shape etc.
+ std::unordered_set mgx_required_initializers;
+ const auto unsupported_nodes = GetUnsupportedNodeIndices(graph_viewer, mgx_required_initializers, *GetLogger());
+
+ // Too many unsupported operators, fallback to run on CPU
+ if (unsupported_nodes.size() >= 6)
+ {
+ return result;
+ }
+
+ //If all ops are supported, no partitioning is required. Short-circuit and avoid splitting.
+ if (unsupported_nodes.empty()) {
+ std::vector inputs;
+ std::vector outputs;
+
+ //Fill inputs with names
+ std::for_each(graph_viewer.GetInputs().begin(), graph_viewer.GetInputs().end(),
+ [&inputs](const NodeArg* node_arg) { inputs.push_back(node_arg->Name()); });
+
+ // In scenarios, when there are no inputs or all inputs being initializers,
+ // ConstantFolding optimization in onnxruntime pre-computes the value.
+ if (inputs.empty()) {
+ return result;
+ }
+
+ // Initializers need to be part of meta_def->inputs
+ std::for_each(mgx_required_initializers.begin(), mgx_required_initializers.end(),
+ [&inputs](const std::string& initializer) { inputs.push_back(initializer); });
+
+ // Fill outputs with names
+ std::for_each(graph_viewer.GetOutputs().begin(), graph_viewer.GetOutputs().end(),
+ [&outputs](const NodeArg* node_arg) { outputs.push_back(node_arg->Name()); });
+
+ // Create and add this graph to result.
+ AppendNodesToSubGraph(graph_viewer.GetNodesInTopologicalOrder(), inputs, outputs, result);
+
+ } else { // unsupported_nodes_idx.empty()
+ const auto mgx_clusters = GetPartitionedSubgraphs(graph_viewer.GetNodesInTopologicalOrder(), unsupported_nodes);
+
+ for (const auto& this_cluster : mgx_clusters) {
+ std::vector cluster_inputs, cluster_outputs;
+ GetInputsOutputsOfSubgraph(graph_viewer, this_cluster, mgx_required_initializers, cluster_inputs, cluster_outputs);
+
+ if (!cluster_inputs.empty()) {
+ AppendNodesToSubGraph(this_cluster, cluster_inputs, cluster_outputs, result);
+ }
+ }
+ }
+
+ return result;
+}
+
+static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime::Node* fused_node,
+ const logging::Logger& logger) {
+ const auto* node_function = fused_node->GetFunctionBody();
+
+ ORT_ENFORCE(node_function != nullptr, "Could not extract function body for node: ", fused_node->Name());
+
+ const Graph& node_subgraph = node_function->Body();
+ onnxruntime::Model model{node_subgraph.Name(), true, ModelMetaData{}, PathString{},
+ IOnnxRuntimeOpSchemaRegistryList{}, node_subgraph.DomainToVersionMap(),
+ std::vector(), logger};
+
+ ONNX_NAMESPACE::ModelProto model_proto = model.ToProto();
+ //model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
+
+ *(model_proto.mutable_graph()) = node_subgraph.ToGraphProto();
+
+ auto opset = model_proto.add_opset_import();
+ opset->set_domain(kOnnxDomain);
+ opset->set_version(node_subgraph.DomainToVersionMap().at(kOnnxDomain));
+
+ return model_proto;
+}
+
+bool get_input_output_names(std::string& onnx_buffer,
+ std::vector& input_names,
+ std::vector& output_names)
+{
+ bool no_input_shape = false;
+
+ input_names.clear();
+ output_names.clear();
+ onnx::ModelProto model;
+ if (model.ParseFromArray(onnx_buffer.data(), onnx_buffer.size()))
+ {
+ if (model.has_graph())
+ {
+ // compute output names
+ auto& graph = model.graph();
+
+ // compute input names
+ std::unordered_set ini_names;
+ for(auto&& f : graph.initializer())
+ ini_names.insert(f.name());
+
+ for(auto&& input : graph.input())
+ {
+ const std::string& name = input.name();
+ if (ini_names.count(name) == 0)
+ {
+ input_names.push_back(name);
+ auto dim_size = input.type().tensor_type().shape().dim_size();
+ if (dim_size == 0)
+ {
+ no_input_shape = true;
+ }
+ }
+ }
+
+ auto prog_output = graph.output();
+ std::vector all_output_names;
+ std::vector prog_output_names;
+ std::transform(prog_output.begin(),
+ prog_output.end(),
+ std::back_inserter(all_output_names),
+ [](auto& node) { return node.name(); });
+ std::copy_if(
+ all_output_names.begin(),
+ all_output_names.end(),
+ std::back_inserter(output_names),
+ [&](const auto& name) { return !name.empty(); });
+ }
+ }
+
+ return no_input_shape;
+}
+
+Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes,
+ std::vector& node_compute_funcs) {
+ migraphx::onnx_options options;
+ bool no_input_shape = false;
+ // std::size_t fused_node_idx = 0;
+ for (const auto& fused_node : fused_nodes) {
+ // map parameter input name to index
+ std::unordered_map input_name_index;
+ const auto& input_defs = fused_node->InputDefs();
+ input_name_index.reserve(input_defs.size());
+ for (std::size_t i = 0; i < input_defs.size(); ++i) {
+ input_name_index[input_defs[i]->Name()] = i;
+ }
+
+ // reconstruct the subgraph proto from fused nodes
+ onnx::ModelProto model_proto = GetModelProtoFromFusedNode(fused_node, *GetLogger());
+ std::string onnx_string_buffer;
+ model_proto.SerializeToString(&onnx_string_buffer);
+ std::vector input_names, output_names;
+ no_input_shape |= get_input_output_names(onnx_string_buffer, input_names, output_names);
+
+ // by parsing the model_proto, create a program corresponding to
+ // the input fused_node
+ migraphx::program prog;
+
+ if (!no_input_shape)
+ {
+ prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options);
+ prog.compile(t_);
+
+ auto prog_output_shapes = prog.get_output_shapes();
+ for (std::size_t i = 0; i < output_names.size(); ++i)
+ {
+ auto out_len = prog_output_shapes[i].lengths();
+ options.set_input_parameter_shape(output_names[i], out_len);
+ }
+ }
+
+ // compile the program
+ map_progs_[fused_node->Name()] = prog;
+
+ map_onnx_string_[fused_node->Name()] = onnx_string_buffer;
+ map_input_index_[fused_node->Name()] = input_name_index;
+ map_no_input_shape_[fused_node->Name()] = no_input_shape;
+ NodeComputeInfo compute_info;
+ compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
+ std::unique_ptr p = onnxruntime::make_unique();
+ *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name],
+ map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_,
+ map_no_input_shape_[context->node_name]};
+ *state = p.release();
+ return 0;
+ };
+
+ compute_info.release_state_func = [](FunctionState state) {
+ if (state)
+ delete static_cast(state);
+ };
+
+ compute_info.compute_func = [](FunctionState state, const OrtCustomOpApi* api, OrtKernelContext* context) {
+ Ort::CustomOpApi ort{*api};
+ MIGraphXFuncState* mgx_state = reinterpret_cast(state);
+ std::unordered_map& map_input_name_index = mgx_state->input_name_indexes;
+ migraphx::target t = mgx_state->t;
+ migraphx::program& prog = mgx_state->prog;
+ std::string& onnx_string = mgx_state->onnx_string;
+ migraphx::onnx_options& cmp_options = mgx_state->options;
+ bool &no_input_shape = mgx_state->no_input_shape;
+
+ // mean no program at all, so need to get the input shape info
+ // from input data
+ bool input_shape_match = true;
+ migraphx::program_parameter_shapes param_shapes;
+ if (no_input_shape)
+ {
+ for (auto& it : map_input_name_index)
+ {
+ auto& name = it.first;
+ auto& index = it.second;
+ const OrtValue* input_tensor = ort.KernelContext_GetInput(context, index);
+ auto tensor_info = ort.GetTensorTypeAndShape(input_tensor);
+ const auto& tensor_shape = ort.GetTensorShape(tensor_info);
+ std::vector ort_lens(tensor_shape.begin(), tensor_shape.end());
+ cmp_options.set_input_parameter_shape(name, ort_lens);
+ input_shape_match = false;
+ }
+ }
+ else
+ {
+ param_shapes = prog.get_parameter_shapes();
+ auto prog_output_shapes = prog.get_output_shapes();
+
+ // check whether input shapes match with shapes of program inputs
+ // migraphx::onnx_options cmp_options;
+ if (param_shapes.size() > 0)
+ {
+ for (auto&& name : param_shapes.names())
+ {
+ if (map_input_name_index.count(name) > 0)
+ {
+ const OrtValue* input_tensor = ort.KernelContext_GetInput(context, map_input_name_index[name]);
+ auto tensor_info = ort.GetTensorTypeAndShape(input_tensor);
+ const auto& tensor_shape = ort.GetTensorShape(tensor_info);
+ std::vector ort_lens(tensor_shape.begin(), tensor_shape.end());
+
+ auto mgx_s = param_shapes[name];
+ auto mgx_lens = mgx_s.lengths();
+ auto mgx_strides = mgx_s.strides();
+ if (mgx_lens.size() == 1 and mgx_lens[0] == 1 and
+ mgx_strides.size() == 1 and mgx_strides[0] == 0)
+ {
+ mgx_lens.clear();
+ }
+
+ if (mgx_lens != ort_lens)
+ {
+ cmp_options.set_input_parameter_shape(name, ort_lens);
+ input_shape_match = false;
+ }
+ }
+ }
+ }
+ }
+
+ // input shapes are different, needs to re-parse onnx and
+ // re-compile the program
+ if (!input_shape_match)
+ {
+ prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options);
+ prog.compile(t);
+ mgx_state->prog = prog;
+ param_shapes = prog.get_parameter_shapes();
+ no_input_shape = false;
+ }
+
+ migraphx::program_parameters m;
+ auto prog_output_shapes = prog.get_output_shapes();
+ std::vector prog_output_indices;
+ if (param_shapes.size() > 0)
+ {
+ for (auto&& name : param_shapes.names())
+ {
+ if (map_input_name_index.count(name) > 0)
+ {
+ const OrtValue* input_tensor = ort.KernelContext_GetInput(context, map_input_name_index[name]);
+ auto tensor_info = ort.GetTensorTypeAndShape(input_tensor);
+ const auto& tensor_shape = ort.GetTensorShape(tensor_info);
+ auto tensor_type = ort.GetTensorElementType(tensor_info);
+ ort.ReleaseTensorTypeAndShapeInfo(tensor_info);
+
+ migraphx_shape_datatype_t mgx_type;
+ get_migraphx_type(tensor_type, mgx_type);
+ auto mgx_s = param_shapes[name];
+
+ if (mgx_type != mgx_s.type())
+ {
+ LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch";
+ }
+
+ m.add(name, migraphx::argument(param_shapes[name], const_cast(ort.GetTensorData(input_tensor))));
+ }
+ // It is a output argument
+ else
+ {
+ auto compute_output_index = [] (const std::string& name) -> int {
+ std::string out_name_prefix = "#output_";
+ auto pos = name.find(out_name_prefix);
+ if (pos == std::string::npos)
+ {
+ return -1;
+ }
+
+ std::string index_str = name.substr(pos + out_name_prefix.length());
+ return std::stoi(index_str);
+ };
+
+ int output_index = compute_output_index(name);
+ if (output_index != -1)
+ {
+ prog_output_indices.push_back(output_index);
+ auto mgx_output_shape = prog_output_shapes[output_index];
+ auto lens = mgx_output_shape.lengths();
+ std::vector ort_output_shape(lens.begin(), lens.end());
+ OrtValue* output_tensor = ort.KernelContext_GetOutput(context, output_index, ort_output_shape.data(), ort_output_shape.size());
+ void* output_data = ort.GetTensorMutableData(output_tensor);
+
+ // argument shape
+ auto mgx_arg_shape = param_shapes[name];
+ m.add(name, migraphx::argument(mgx_arg_shape, output_data));
+ }
+ }
+ }
+ }
+
+ {
+ // lock to avoid race condition
+ std::lock_guard lock(*(mgx_state->mgx_mu_ptr));
+ auto prog_outputs = prog.eval(m);
+ hipDeviceSynchronize();
+
+ // In case of input parameters are reused as output parameter call hipMemcpy
+ auto output_num = prog_outputs.size();
+ if (prog_output_indices.size() < output_num)
+ {
+ for (std::size_t i = 0; i < output_num; ++i)
+ {
+ if (std::find(prog_output_indices.begin(), prog_output_indices.end(), i) != prog_output_indices.end())
+ continue;
+ auto gpu_res = prog_outputs[i];
+ migraphx::shape res_shape = gpu_res.get_shape();
+ auto res_lens = res_shape.lengths();
+ std::vector ort_shape{res_lens.begin(), res_lens.end()};
+ OrtValue* output_tensor = ort.KernelContext_GetOutput(context, i, ort_shape.data(), ort_shape.size());
+ void* output_data = ort.GetTensorMutableData(output_tensor);
+ hipMemcpy(output_data, gpu_res.data(), res_shape.bytes(), hipMemcpyDeviceToDevice);
+ }
+ }
+ }
+
+ return Status::OK();
+ };
+ node_compute_funcs.push_back(compute_info);
+ }
+
+ return Status::OK();
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
new file mode 100644
index 0000000000..ee9dd3324c
--- /dev/null
+++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
@@ -0,0 +1,63 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License
+
+#pragma once
+
+#include "core/framework/execution_provider.h"
+#include "core/platform/ort_mutex.h"
+#include