diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index b9a774995d..a57463df78 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -84,6 +84,7 @@ option(tensorflow_C_PACKAGE_PATH "Path to tensorflow C package installation dir"
 option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in language other than cpp" OFF)
 option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump node input shapes and output data to standard output when executing the model." OFF)
 option(onnxruntime_USE_DML "Build with DirectML support" OFF)
+option(onnxruntime_USE_MIGRAPHX "Build with AMDMIGraphX support" OFF)
 option(onnxruntime_USE_WINML "Build with WinML support" OFF)
 option(onnxruntime_USE_ACL "Build with ACL support" OFF)
 option(onnxruntime_USE_ACL_1902 "Build with ACL version 1902 support" OFF)
@@ -854,6 +855,15 @@ if (onnxruntime_USE_TENSORRT)
   endif()
 endif()
 
+# MIGraphX execution provider: rejected on Windows; AMD_MIGRAPHX_HOME mirrors onnxruntime_MIGRAPHX_HOME.
+if (onnxruntime_USE_MIGRAPHX)
+  if (WIN32)
+    message(FATAL_ERROR "MIGraphX does not support build in Windows!")
+  endif()
+  set(AMD_MIGRAPHX_HOME ${onnxruntime_MIGRAPHX_HOME})
+  add_definitions(-DUSE_MIGRAPHX=1)
+endif()
+
 if (onnxruntime_USE_TVM)
   if (WIN32 AND MSVC)
     # wd4100: identifier' : unreferenced formal parameter
diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake
index 6815227589..a234638b63 100644
--- a/cmake/onnxruntime.cmake
+++ b/cmake/onnxruntime.cmake
@@ -82,6 +82,7 @@ target_link_libraries(onnxruntime PRIVATE
     ${PROVIDERS_NNAPI}
     ${PROVIDERS_RKNPU}
     ${PROVIDERS_TENSORRT}
+    ${PROVIDERS_MIGRAPHX}
     ${PROVIDERS_OPENVINO}
     ${PROVIDERS_NUPHAR}
     ${PROVIDERS_VITISAI}
diff --git a/cmake/onnxruntime_csharp.cmake b/cmake/onnxruntime_csharp.cmake
index a6856af498..5fd082f53c 100644
--- a/cmake/onnxruntime_csharp.cmake
+++ b/cmake/onnxruntime_csharp.cmake
@@ -22,6 +22,10 @@ if (onnxruntime_USE_TENSORRT)
   STRING(APPEND CSHARP_PREPROCESSOR_DEFINES "USE_TENSORRT,")
 endif()
 
+if (onnxruntime_USE_MIGRAPHX)
+  STRING(APPEND CSHARP_PREPROCESSOR_DEFINES "USE_MIGRAPHX,")
+endif()
+
 if (onnxruntime_USE_OPENVINO)
   STRING(APPEND CSHARP_PREPROCESSOR_DEFINES "USE_OPENVINO,")
 endif()
diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake
index 9998321a45..2f33a36532 100644
--- a/cmake/onnxruntime_providers.cmake
+++ b/cmake/onnxruntime_providers.cmake
@@ -67,6 +67,10 @@ if(onnxruntime_USE_DML)
   set(PROVIDERS_DML onnxruntime_providers_dml)
   list(APPEND ONNXRUNTIME_PROVIDER_NAMES dml)
 endif()
+if(onnxruntime_USE_MIGRAPHX)
+  set(PROVIDERS_MIGRAPHX onnxruntime_providers_migraphx)
+  list(APPEND ONNXRUNTIME_PROVIDER_NAMES migraphx)
+endif()
 if(onnxruntime_USE_OPENVINO)
   set(PROVIDERS_OPENVINO onnxruntime_providers_openvino)
   list(APPEND ONNXRUNTIME_PROVIDER_NAMES openvino)
@@ -607,6 +611,35 @@ if (onnxruntime_USE_DML)
   set_target_properties(onnxruntime_providers_dml PROPERTIES FOLDER "ONNXRuntime")
 endif()
 
+if (onnxruntime_USE_MIGRAPHX)
+  # Add search paths for default rocm installation
+  list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm)
+
+  # REQUIRED: stop at configure time with a clear message when HIP or MIGraphX
+  # cannot be found, rather than failing later on the missing imported targets.
+  find_package(hip REQUIRED)
+  find_package(migraphx REQUIRED PATHS ${AMD_MIGRAPHX_HOME})
+
+  set(migraphx_libs migraphx::c hip::host)
+
+  file(GLOB_RECURSE onnxruntime_providers_migraphx_cc_srcs CONFIGURE_DEPENDS
+    "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.h"
+    "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.cc"
+  )
+
+  source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs})
+  add_library(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs})
+  target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs})
+  set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime")
+  # Report sign-compare diagnostics for this target as warnings, not errors.
+  target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare)
+  target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT})
+  onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnxruntime_framework onnx)
+  add_dependencies(onnxruntime_providers_migraphx ${onnxruntime_EXTERNAL_DEPENDENCIES})
+  install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/migraphx DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers)
+  set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX)
+endif()
+
 if (onnxruntime_USE_ACL)
   add_definitions(-DUSE_ACL=1)
   file(GLOB_RECURSE onnxruntime_providers_acl_cc_srcs
diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake
index 5edaa75bc7..81625925b8 100644
--- a/cmake/onnxruntime_python.cmake
+++ b/cmake/onnxruntime_python.cmake
@@ -85,6 +85,7 @@ set(onnxruntime_pybind11_state_libs
     ${PROVIDERS_CUDA}
     ${PROVIDERS_DNNL}
     ${PROVIDERS_TENSORRT}
+    ${PROVIDERS_MIGRAPHX}
     ${PROVIDERS_NGRAPH}
     ${PROVIDERS_OPENVINO}
     ${PROVIDERS_NUPHAR}
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index d949310ed7..572485d74a 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -314,6 +314,11 @@ if(onnxruntime_USE_DML)
   list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_dml)
 endif()
 
+if(onnxruntime_USE_MIGRAPHX)
+  list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx)
+endif()
+
+
 file(GLOB_RECURSE onnxruntime_test_tvm_src CONFIGURE_DEPENDS
   "${ONNXRUNTIME_ROOT}/test/tvm/*.h"
   "${ONNXRUNTIME_ROOT}/test/tvm/*.cc"
@@ -341,6 +346,7 @@ set(ONNXRUNTIME_TEST_LIBS
     ${PROVIDERS_CUDA}
     ${PROVIDERS_DNNL}
     ${PROVIDERS_TENSORRT}
+    ${PROVIDERS_MIGRAPHX}
     ${PROVIDERS_NGRAPH}
     ${PROVIDERS_OPENVINO}
     ${PROVIDERS_NUPHAR}
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
index 98e2cd013d..c099f8ce6c 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
@@ -444,6 +444,9 @@ namespace Microsoft.ML.OnnxRuntime
         [DllImport(nativeLib, CharSet = charSet)]
         public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Tensorrt(IntPtr /*(OrtSessionOptions*)*/ options, int device_id);
 
+        [DllImport(nativeLib, CharSet = charSet)]
+        public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int device_id);
+
         [DllImport(nativeLib, CharSet = charSet)]
         public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Nnapi(IntPtr /*(OrtSessionOptions*)*/ options);
 
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
index 65d349aa76..0fe79c0475 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
@@ -135,6 +135,14 @@ namespace Microsoft.ML.OnnxRuntime
             NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Tensorrt(_nativePtr, deviceId));
         }
 
+        /// <summary>
+        /// Use only if you have the onnxruntime package specific to this Execution Provider.
+        /// </summary>
+        public void AppendExecutionProvider_MIGraphX(int deviceId)
+        {
+            NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_MIGraphX(_nativePtr, deviceId));
+        }
+
         /// <summary>
         /// Use only if you have the onnxruntime package specific to this Execution Provider.
         /// </summary>
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
index b17baab6c4..b0e30d12e9 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -97,6 +97,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
 #if USE_TENSORRT
             opt.AppendExecutionProvider_Tensorrt(0);
 #endif
+#if USE_MIGRAPHX
+            opt.AppendExecutionProvider_MIGraphX(0);
+#endif
 #if USE_NNAPI
             opt.AppendExecutionProvider_Nnapi();
 #endif
@@ -1614,6 +1617,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
 #if USE_TENSORRT
             ,"OrtSessionOptionsAppendExecutionProvider_Tensorrt"
 #endif
+#if USE_MIGRAPHX
+            ,"OrtSessionOptionsAppendExecutionProvider_MIGraphX"
+#endif
 #if USE_NNAPI
             ,"OrtSessionOptionsAppendExecutionProvider_Nnapi"
 #endif
diff --git a/dockerfiles/Dockerfile.migraphx b/dockerfiles/Dockerfile.migraphx
new file mode 100644
index 0000000000..fcf33bee30
--- /dev/null
+++ b/dockerfiles/Dockerfile.migraphx
@@ -0,0 +1,49 @@
+# --------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------
+# Dockerfile to run ONNXRuntime with MIGraphX integration
+#--------------------------------------------------------------------------
+
+FROM ubuntu:16.04
+
+ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
+ARG ONNXRUNTIME_BRANCH=master
+ENV DEBIAN_FRONTEND noninteractive
+ENV LC_ALL C.UTF-8
+ENV LANG C.UTF-8
+ENV MIGRAPHX_DISABLE_FAST_GELU=1
+
+# Install rocm
+RUN apt-get update && apt-get install -y --no-install-recommends curl && \
+  curl -sL http://repo.radeon.com/rocm/apt/debian/rocm.gpg.key | apt-key add - && \
+  sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/debian/ xenial main > /etc/apt/sources.list.d/rocm.list'
+
+RUN apt-get update &&\
+  apt-get install -y sudo git bash build-essential cmake libpython3.5-dev python3-pip miopen-hip rocblas half
+
+# Install rbuild
+RUN pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz
+
+# Install MIGraphX from source; one RUN so the removed build tree is not kept in an earlier image layer
+RUN mkdir -p /migraphx && cd /migraphx && \
+  git clone --depth=1 --branch migraphx_for_ort https://github.com/ROCmSoftwarePlatform/AMDMIGraphX src && \
+  rbuild package --cxx /opt/rocm/bin/hcc -d /migraphx/deps -B /migraphx/build -S /migraphx/src/ -DPYTHON_EXECUTABLE=/usr/bin/python3 && \
+  dpkg -i /migraphx/build/*.deb && \
+  cd / && rm -rf /migraphx
+
+WORKDIR /code
+ENV PATH /opt/miniconda/bin:/code/cmake-3.14.3-Linux-x86_64/bin:${PATH}
+
+# Workaround for broken cmake in hip's binary package
+RUN sed -i -e 's/hcc::hccrt;hcc::hc_am//g' /opt/rocm/hip/lib/cmake/hip/hip-targets-release.cmake
+ENV CXXFLAGS "-D__HIP_PLATFORM_HCC__=1"
+
+# Prepare onnxruntime repository & build onnxruntime
+RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\
+  /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\
+  cd onnxruntime &&\
+  /bin/sh ./build.sh --config Release --build_wheel --update --build --parallel --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_migraphx &&\
+  pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\
+  cd .. &&\
+  rm -rf onnxruntime cmake-3.14.3-Linux-x86_64
diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h
index c348ab4ff2..d79b9bc8d1 100644
--- a/include/onnxruntime/core/framework/allocator.h
+++ b/include/onnxruntime/core/framework/allocator.h
@@ -23,13 +23,14 @@ struct OrtDevice {
 
   // Pre-defined device types.
   static const DeviceType CPU = 0;
-  static const DeviceType GPU = 1;  //CUDA
+  static const DeviceType GPU = 1;  //CUDA or HIP
   static const DeviceType FPGA = 2;
 
   struct MemType {
     // Pre-defined memory types.
     static const MemoryType DEFAULT = 0;
     static const MemoryType CUDA_PINNED = 1;
+    static const MemoryType HIP_PINNED = 2;
   };
 
   constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_)
@@ -141,6 +142,8 @@ namespace onnxruntime {
 constexpr const char* CPU = "Cpu";
 constexpr const char* CUDA = "Cuda";
 constexpr const char* CUDA_PINNED = "CudaPinned";
+constexpr const char* MIGRAPHX = "MIGraphX";
+constexpr const char* MIGRAPHX_PINNED = "MIGraphXPinned";
 constexpr const char* TRT = "Tensorrt";
 constexpr const char* TRT_PINNED = "TensorrtPinned";
 
diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h
index b491b5f124..fa69c90113 100644
--- a/include/onnxruntime/core/graph/constants.h
+++ b/include/onnxruntime/core/graph/constants.h
@@ -22,6 +22,7 @@ constexpr const char* kMSNchwcDomain = "com.microsoft.nchwc";
 constexpr const char* kMSFeaturizersDomain = "com.microsoft.mlfeaturizers";
 constexpr const char* kMSDmlDomain = "com.microsoft.dml";
 constexpr const char* kNGraphDomain = "com.intel.ai";
+constexpr const char* kMIGraphXDomain = "";
 constexpr const char* kVitisAIDomain = "com.xilinx";
 constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider";
 constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider";
@@ -34,5 +35,6 @@ constexpr const char* kTensorrtExecutionProvider = "TensorrtExecutionProvider";
 constexpr const char* kNnapiExecutionProvider = "NnapiExecutionProvider";
 constexpr const char* kRknpuExecutionProvider = "RknpuExecutionProvider";
 constexpr const char* kDmlExecutionProvider = "DmlExecutionProvider";
+constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider";
 constexpr const char* kAclExecutionProvider = "ACLExecutionProvider";
 }  // namespace onnxruntime
diff --git a/include/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h b/include/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h
new file mode 100644
index 0000000000..8f8219fde0
--- /dev/null
+++ b/include/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h
@@ -0,0 +1,16 @@
+// Copyright 2019 AMD AMDMIGraphX
+
+#pragma once
+
+#include "onnxruntime_c_api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// C entry point called by the C# and Java bindings to add the MIGraphX provider to `options`.
+ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessionOptions* options, int device_id);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
index 47db160487..4c89debf15 100644
--- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
+++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
@@ -22,6 +22,7 @@
 #include "onnxruntime/core/providers/nuphar/nuphar_provider_factory.h"
 #include "onnxruntime/core/providers/openvino/openvino_provider_factory.h"
 #include "onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h"
+#include "onnxruntime/core/providers/migraphx/migraphx_provider_factory.h"
 #include "onnxruntime/core/providers/acl/acl_provider_factory.h"
 #ifdef USE_DIRECTML
 #include "onnxruntime/core/providers/dml/dml_provider_factory.h"
@@ -423,6 +424,22 @@ JNIEXPORT void JNICALL
Java_ai_onnxruntime_OrtSession_00024SessionOptions_addNup
 #endif
 }
 
+/*
+ * Class:     ai_onnxruntime_OrtSession_SessionOptions
+ * Method:    addMIGraphX
+ * Signature: (JJI)V
+ */
+JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addMIGraphX
+  (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint deviceNum) {
+    (void)jobj;
+  #ifdef USE_MIGRAPHX
+    checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_MIGraphX((OrtSessionOptions*) handle, deviceNum));
+  #else
+    (void)apiHandle;(void)handle;(void)deviceNum; // Parameters used when MIGraphX is defined.
+    throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"This binary was not compiled with MIGraphX support.");
+  #endif
+}
+
 /*
  * Class:     ai_onnxruntime_OrtSession_SessionOptions
  * Method:    addDirectML
diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
new file mode 100644
index 0000000000..2d443d43b7
--- /dev/null
+++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
@@ -0,0 +1,81 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "migraphx_inc.h"
+#include "gpu_data_transfer.h"
+
+namespace onnxruntime {
+
+// Creates the copy-in and copy-out streams. The kHipStreamDefault slot stays
+// nullptr, which the original code uses as the default stream.
+GPUDataTransfer::GPUDataTransfer() {
+  // Null every slot first so that a stream whose creation fails is not left
+  // holding an indeterminate handle.
+  for (auto& stream : streams_) {
+    stream = nullptr;
+  }
+  hipStreamCreateWithFlags(&streams_[kHipStreamCopyIn], hipStreamNonBlocking);
+  hipStreamCreateWithFlags(&streams_[kHipStreamCopyOut], hipStreamNonBlocking);
+}
+
+// Destroys the streams created by the constructor; null slots are skipped.
+GPUDataTransfer::~GPUDataTransfer() {
+  if (streams_[kHipStreamCopyIn] != nullptr) {
+    hipStreamDestroy(streams_[kHipStreamCopyIn]);
+  }
+  if (streams_[kHipStreamCopyOut] != nullptr) {
+    hipStreamDestroy(streams_[kHipStreamCopyOut]);
+  }
+}
+
+// A copy is handled here when either end is on the GPU or in HIP pinned memory.
+bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
+  return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED
+      || dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HIP_PINNED;
+}
+
+// Copies src into dst, choosing the HIP copy kind from the two device types.
+// Pinned-memory copies are issued asynchronously on the stream for
+// exec_queue_id, which GetStream range-checks. Returns a FAIL status when HIP
+// reports an error instead of unconditionally reporting success.
+common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const {
+  size_t bytes = src.SizeInBytes();
+  const void* src_data = src.DataRaw();
+  void* dst_data = dst.MutableDataRaw();
+
+  auto& src_device = src.Location().device;
+  auto& dst_device = dst.Location().device;
+
+  hipError_t hip_err = hipSuccess;
+  if (dst_device.Type() == OrtDevice::GPU) {
+    if (src_device.Type() == OrtDevice::CPU && src_device.MemType() == OrtDevice::MemType::HIP_PINNED) {
+      // copy from pinned memory to GPU, this is non-blocking
+      hip_err = hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, GetStream(exec_queue_id));
+    } else if (src_device.Type() == OrtDevice::GPU) {
+      // copying between GPU, this is non-blocking
+      hip_err = hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, streams_[kHipStreamDefault]);
+    } else {
+      // copy from other CPU memory to GPU, this is blocking
+      hip_err = hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice);
+    }
+  } else if (src_device.Type() == OrtDevice::GPU) {
+    if (dst_device.Type() == OrtDevice::CPU && dst_device.MemType() == OrtDevice::MemType::HIP_PINNED) {
+      // copying from GPU to pinned memory, this is non-blocking
+      hip_err = hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, GetStream(exec_queue_id));
+    } else {
+      // copying from GPU to CPU memory, this is blocking
+      hip_err = hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost);
+    }
+  } else {
+    // copying between cpu memory
+    memcpy(dst_data, src_data, bytes);
+  }
+
+  if (hip_err != hipSuccess) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "HIP memory copy failed: ", hipGetErrorString(hip_err));
+  }
+
+  return Status::OK();
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.h b/onnxruntime/core/providers/migraphx/gpu_data_transfer.h
new file mode 100644
index 0000000000..9b966236cd
--- /dev/null
+++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.h
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "migraphx_inc.h"
+#include "core/framework/data_transfer.h"
+
+namespace onnxruntime {
+
+// Indices into GPUDataTransfer's stream table; also used as execution queue ids.
+enum HIPStreamType : int {
+  kHipStreamDefault = 0,
+  kHipStreamCopyIn,
+  kHipStreamCopyOut,
+  kTotalHipStreams,
+};
+
+// Copies tensor data between host and HIP device memory.
+class GPUDataTransfer : public IDataTransfer {
+ public:
+  GPUDataTransfer();
+  ~GPUDataTransfer();
+
+  bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;
+
+  common::Status CopyTensor(const Tensor& src, Tensor& dst, int exec_queue_id) const override;
+
+  // Returns the stream for queue_id; enforces that the id is within the table.
+  hipStream_t GetStream(int queue_id) const {
+    ORT_ENFORCE(queue_id >= 0 && queue_id < kTotalHipStreams);
+    return streams_[queue_id];
+  }
+
+ private:
+  hipStream_t streams_[kTotalHipStreams];
+};
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/hip_allocator.cc b/onnxruntime/core/providers/migraphx/hip_allocator.cc
new file mode 100644
index 0000000000..bd2b4c785d
--- /dev/null
+++ b/onnxruntime/core/providers/migraphx/hip_allocator.cc
@@ -0,0 +1,75 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "migraphx_inc.h"
+#include "hip_allocator.h"
+#include "core/framework/allocatormgr.h"
+#include "core/framework/session_state.h"
+#include "hip_fence.h"
+#include "gpu_data_transfer.h"
+
+namespace onnxruntime {
+
+// Looks up the GPU-to-CPU data transfer registered with the session; yields
+// nullptr if that transfer is not a GPUDataTransfer.
+static const GPUDataTransfer* GetGPUDataTransfer(const SessionState* session_state) {
+  OrtDevice gpu_device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0);
+  OrtDevice cpu_device;
+  return dynamic_cast<const GPUDataTransfer*>(session_state->GetDataTransferMgr().GetDataTransfer(gpu_device, cpu_device));
+}
+
+void HIPAllocator::CheckDevice() const {
+#ifndef NDEBUG
+  // check device to match at debug build
+  // if it's expected to change, call hipSetDevice instead of the check
+  int current_device = -1;  // defined value even if the query below fails
+  hipGetDevice(&current_device);
+  ORT_ENFORCE(current_device == info_.id);
+#endif
+}
+
+// Requests `size` bytes from hipMalloc; a zero-size request returns nullptr without calling HIP.
+void* HIPAllocator::Alloc(size_t size) {
+  CheckDevice();
+  void* p = nullptr;
+  if (size > 0) {
+    hipMalloc((void**)&p, size);
+  }
+  return p;
+}
+
+void HIPAllocator::Free(void* p) {
+  CheckDevice();
+  hipFree(p);  // do not throw error since it's OK for hipFree to fail during shutdown
+}
+
+const OrtMemoryInfo& HIPAllocator::Info() const {
+  return info_;
+}
+
+FencePtr HIPAllocator::CreateFence(const SessionState* session_state) {
+  return std::make_shared<HIPFence>(GetGPUDataTransfer(session_state));
+}
+
+// Requests `size` bytes from hipHostMalloc; a zero-size request returns nullptr without calling HIP.
+void* HIPPinnedAllocator::Alloc(size_t size) {
+  void* p = nullptr;
+  if (size > 0) {
+    hipHostMalloc((void**)&p, size);
+  }
+  return p;
+}
+
+void HIPPinnedAllocator::Free(void* p) {
+  hipHostFree(p);
+}
+
+const OrtMemoryInfo& HIPPinnedAllocator::Info() const {
+  return info_;
+}
+
+FencePtr HIPPinnedAllocator::CreateFence(const SessionState* session_state) {
+  return std::make_shared<HIPFence>(GetGPUDataTransfer(session_state));
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/hip_allocator.h b/onnxruntime/core/providers/migraphx/hip_allocator.h
new file mode 100644
index 0000000000..bfd651ae08
--- /dev/null
+++ b/onnxruntime/core/providers/migraphx/hip_allocator.h
@@ -0,0 +1,52 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/framework/allocator.h"
+
+namespace onnxruntime {
+
+// Device allocator whose memory info describes default GPU memory on device_id.
+class HIPAllocator : public IDeviceAllocator {
+ public:
+  HIPAllocator(int device_id, const char* name)
+      : info_(name,
+              OrtAllocatorType::OrtDeviceAllocator,
+              OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id),
+              device_id,
+              OrtMemTypeDefault) {}
+
+  void* Alloc(size_t size) override;
+  void Free(void* p) override;
+  const OrtMemoryInfo& Info() const override;
+  FencePtr CreateFence(const SessionState* session_state) override;
+
+ private:
+  // Called at the start of Alloc and Free.
+  void CheckDevice() const;
+
+  const OrtMemoryInfo info_;
+};
+
+//TODO: add a default constructor
+// Device allocator whose memory info describes HIP pinned memory on the CPU side.
+class HIPPinnedAllocator : public IDeviceAllocator {
+ public:
+  HIPPinnedAllocator(int device_id, const char* name)
+      : info_(name,
+              OrtAllocatorType::OrtDeviceAllocator,
+              OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, device_id),
+              device_id,
+              OrtMemTypeCPUOutput) {}
+
+  void* Alloc(size_t size) override;
+  void Free(void* p) override;
+  const OrtMemoryInfo& Info() const override;
+  FencePtr CreateFence(const SessionState* session_state) override;
+
+ private:
+  const OrtMemoryInfo info_;
+};
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/hip_fence.cc b/onnxruntime/core/providers/migraphx/hip_fence.cc
new file mode 100644
index 0000000000..44313c756a
--- /dev/null
+++ b/onnxruntime/core/providers/migraphx/hip_fence.cc
@@ -0,0 +1,54 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "migraphx_inc.h"
+#include "hip_fence.h"
+#include "gpu_data_transfer.h"
+
+namespace onnxruntime {
+
+// Creates the read and write events; data_transfer supplies the streams they are recorded on.
+HIPFence::HIPFence(const GPUDataTransfer* data_transfer) : data_transfer_(data_transfer) {
+  hipEventCreate(&read_event_);
+  hipEventCreate(&write_event_);
+}
+
+HIPFence::~HIPFence() {
+  hipEventDestroy(read_event_);
+  hipEventDestroy(write_event_);
+}
+
+// Waits on the host for the write event.
+// Both parameters are intentionally unused.
+void HIPFence::BeforeUsingAsInput(onnxruntime::ProviderType /*provider_type*/, int /*async_queue_id*/) {
+  // sync on CPU for all other providers, this is blocking
+  hipEventSynchronize(write_event_);
+}
+
+// Waits on the host for the read event and then the write event.
+// Both parameters are intentionally unused.
+void HIPFence::BeforeUsingAsOutput(onnxruntime::ProviderType /*provider_type*/, int /*queue_id*/) {
+  // sync on CPU for all other providers, this is blocking
+  hipEventSynchronize(read_event_);
+  hipEventSynchronize(write_event_);
+}
+
+// True when hipEventQuery reports success for both events; the write event is only queried if the read event passed.
+bool HIPFence::CanRelease() {
+  if (hipEventQuery(read_event_) != hipSuccess) {
+    return false;
+  }
+  return hipEventQuery(write_event_) == hipSuccess;
+}
+
+// Records the read event on the stream that belongs to queue_id.
+void HIPFence::AfterUsedAsInput(int queue_id) {
+  hipEventRecord(read_event_, data_transfer_->GetStream(queue_id));
+}
+
+// Records the write event on the stream that belongs to queue_id.
+void HIPFence::AfterUsedAsOutput(int queue_id) {
+  hipEventRecord(write_event_, data_transfer_->GetStream(queue_id));
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/migraphx/hip_fence.h b/onnxruntime/core/providers/migraphx/hip_fence.h
new file mode 100644
index 0000000000..ba9803ee37
--- /dev/null
+++ b/onnxruntime/core/providers/migraphx/hip_fence.h
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +#pragma once +#include "core/framework/tensor.h" +#include "core/graph/basic_types.h" + +namespace onnxruntime { +class GPUDataTransfer; + +class HIPFence : public IFence { + public: + HIPFence(const GPUDataTransfer* data_transfer); + virtual ~HIPFence(); + virtual void BeforeUsingAsInput(onnxruntime::ProviderType provider_type, int queue_id) override; + virtual void BeforeUsingAsOutput(onnxruntime::ProviderType provider_type, int queue_id) override; + virtual void AfterUsedAsInput(int queue_id) override; + virtual void AfterUsedAsOutput(int queue_id) override; + virtual bool CanRelease() override; + + private: + hipEvent_t read_event_; + hipEvent_t write_event_; + const GPUDataTransfer* data_transfer_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc new file mode 100644 index 0000000000..4fbac1d6fa --- /dev/null +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -0,0 +1,1297 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License + +#include "core/common/common.h" +#include "core/common/logging/logging.h" +#include "core/framework/compute_capability.h" +#include "core/framework/allocatormgr.h" +#include "core/framework/kernel_registry.h" +#include "core/framework/memcpy.h" +#include "core/graph/graph_viewer.h" +#include "core/graph/model.h" +#include "core/graph/graph_utils.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/optimizer/reshape_fusion.h" +#include "migraphx_inc.h" +#include "migraphx_execution_provider.h" +#include "hip_allocator.h" +#include "gpu_data_transfer.h" +#include + +#if defined(_MSC_VER) +#pragma warning(disable : 4244 4245) +#elif __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif +#if defined(_MSC_VER) +#pragma warning(default : 4244 4245) +#elif __GNUC__ +#pragma GCC diagnostic pop +#endif + +#define MEMCPY_S(dest, src, destsz, srcsz) memcpy(dest, src, std::min(destsz, srcsz)) + +namespace onnxruntime { + +ONNX_OPERATOR_KERNEL_EX( + MemcpyFromHost, + kOnnxDomain, + 1, + kMIGraphXExecutionProvider, + KernelDefBuilder() + .InputMemoryType(0) + .ExecQueueId(kHipStreamCopyIn) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + Memcpy); + +ONNX_OPERATOR_KERNEL_EX( + MemcpyToHost, + kOnnxDomain, + 1, + kMIGraphXExecutionProvider, + KernelDefBuilder() + .OutputMemoryType(0) + .ExecQueueId(kHipStreamCopyOut) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + Memcpy); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMIGraphXExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMIGraphXExecutionProvider, kOnnxDomain, 1, MemcpyToHost); + +static void RegisterMIGraphXKernels(KernelRegistry& kernel_registry) { + static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, + BuildKernelCreateInfo, + }; + + for (auto& function_table_entry : function_table) { + 
kernel_registry.Register(function_table_entry()); + } +} + +std::shared_ptr GetMIGraphXKernelRegistry() { + std::shared_ptr kernel_registry = std::make_shared(); + RegisterMIGraphXKernels(*kernel_registry); + + return kernel_registry; +} + +std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() const { + static std::shared_ptr kernel_registry = onnxruntime::GetMIGraphXKernelRegistry(); + return kernel_registry; +} + +MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) + : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider} { + + // Set GPU device to be used + hipSetDevice(info.device_id); + DeviceAllocatorRegistrationInfo default_memory_info( + {OrtMemTypeDefault, [](int id) { return onnxruntime::make_unique(id, MIGRAPHX); }, std::numeric_limits::max()}); + allocator_ = CreateAllocator(default_memory_info, device_id_); + InsertAllocator(allocator_); + + + DeviceAllocatorRegistrationInfo pinned_memory_info( + {OrtMemTypeCPUOutput, [](int) { return onnxruntime::make_unique(0, MIGRAPHX_PINNED); }, std::numeric_limits::max()}); + InsertAllocator(CreateAllocator(pinned_memory_info, device_id_)); + + + // create the target based on the device_id + hipDeviceProp_t prop; + hipGetDeviceProperties(&prop, device_id_); + std::set valid_targets = {"gpu", "cpu"}; + if (valid_targets.count(info.target_device) == 0) + { + LOGS_DEFAULT(FATAL) << "Device " << info.target_device << " are not supported"; + } + + t_ = migraphx::target(info.target_device.c_str()); +} + +AllocatorPtr MIGraphXExecutionProvider::GetAllocator(int id, OrtMemType mem_type) const { + if (mem_type == OrtMemTypeDefault) { + return allocator_; + } else { + return IExecutionProvider::GetAllocator(id, mem_type); + } +} + +std::unique_ptr MIGraphXExecutionProvider::GetDataTransfer() const { + return onnxruntime::make_unique(); +} + +static bool IsTypeSupported(const NodeArg* node_arg) { + const auto* type_proto = node_arg->TypeAsProto(); + if (!type_proto) 
{ + return false; + } + + switch (type_proto->tensor_type().elem_type()) { + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32: + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT64: + return true; + default: + return false; + } +} + +static bool get_migraphx_type(ONNXTensorElementDataType type, + migraphx_shape_datatype_t &mgx_type) +{ + mgx_type = migraphx_shape_float_type; + switch(type) { + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: + mgx_type = migraphx_shape_half_type; + break; + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: + mgx_type = migraphx_shape_float_type; + break; + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE: + mgx_type = migraphx_shape_double_type; + break; + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8: + mgx_type = migraphx_shape_int8_type; + break; + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16: + mgx_type = migraphx_shape_int16_type; + break; + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32: + mgx_type = migraphx_shape_int32_type; + break; + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64: + mgx_type = migraphx_shape_int64_type; + break; + case 
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8: + mgx_type = migraphx_shape_uint8_type; + break; + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16: + mgx_type = migraphx_shape_uint16_type; + break; + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32: + mgx_type = migraphx_shape_uint32_type; + break; + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT64: + mgx_type = migraphx_shape_uint64_type; + break; + default: + LOGS_DEFAULT(WARNING) << "MiGraphx: unsupported data type " << type << ", fallback to CPU"; + LOGS_DEFAULT(WARNING) << "implementation" << std::endl; + return false; + } + + return true; +} + +static bool can_eval_concat(const Node* concat, const InitializedTensorSet& initializers, const logging::Logger& logger) +{ + if (concat == nullptr) return true; + const auto concat_args = concat->InputDefs(); + if (concat_args.size() != 3) + { + return false; + } + + auto arg_0 = concat_args[0]; + bool b_found = (initializers.find(arg_0->Name()) != initializers.end()); + auto arg_2 = concat_args[2]; + b_found &= (initializers.find(arg_2->Name()) != initializers.end()); + if (b_found) + { + std::vector parent_path{ + {0, 1, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 0, "Gather", {1, 11}, kOnnxDomain}, + {0, 0, "Shape", {1}, kOnnxDomain}}; + std::vector edges; + b_found = graph_utils::FindPath(*concat, true, parent_path, edges, logger); + if (b_found) + { + const Node& gather = edges[1]->GetNode(); + const auto* arg_index = gather.InputDefs()[1]; + if (initializers.find(arg_index->Name()) != initializers.end()) + { + return true; + } + } + } + + return false; +} + +static bool can_eval_cast(const Node* cast, const InitializedTensorSet& initializers, const logging::Logger& logger) +{ + std::vector parent_path = { + {0, 0, "Concat", {1, 4, 11}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1}, kOnnxDomain}, + {0, 0, "Cast", {1, 6, 9}, kOnnxDomain}, + {0, 0, "Squeeze", {1, 11}, kOnnxDomain}, + {0, 
0, "Slice", {1, 10, 11}, kOnnxDomain}, + {0, 0, "Cast", {1, 6, 9}, kOnnxDomain}, + {0, 0, "Shape", {1}, kOnnxDomain}}; + std::vector edges; + if (graph_utils::FindPath(*cast, true, parent_path, edges, logger)) { + const Node& concat = edges[1]->GetNode(); + const Node& slice = edges[5]->GetNode(); + const auto& concat_args = concat.InputDefs(); + bool const_flag = true; + for (std::size_t i = 1; i < concat_args.size(); i++) + { + const_flag &= (initializers.find(concat_args[i]->Name()) != initializers.end()); + } + if (const_flag) + { + const auto& slice_args = slice.InputDefs(); + for (std::size_t i = 1; i < slice_args.size(); ++i) + { + const_flag &= (initializers.find(slice_args[i]->Name()) != initializers.end()); + } + } + if (const_flag) + { + return true; + } + } + + return false; +} + +static bool can_eval_input_shape(const Node* node, const InitializedTensorSet& initializers, const logging::Logger& logger) +{ + // scenario 1: [Root] --> Shape --> Cast --> Cast + std::vector parent_path{ + {0, 1, "Cast", {1, 6, 9}, kOnnxDomain}, + {0, 0, "Cast", {1, 6, 9}, kOnnxDomain}, + {0, 0, "Shape", {1}, kOnnxDomain}}; + + std::vector edges; + if (graph_utils::FindPath(*node, true, parent_path, edges, logger)) + { + return true; + } + + // scenario 2: + const Node* concat = graph_utils::GetInputNode(*node, 1); + if (concat and concat->OpType() == "Concat") + { + if (can_eval_concat(concat, initializers, logger)) + { + return true; + } + } + + // scenario 3: + const Node* cast = graph_utils::GetInputNode(*node, 1); + if (cast and cast->OpType() == "Cast") + { + if (can_eval_cast(cast, initializers, logger)) + { + return true; + } + } + + // scenario 4: + std::vector parent_path_4 = { + {0, 1, "Cast", {1, 6, 9}, kOnnxDomain}, + {0, 0, "Concat", {1, 4, 11}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1}, kOnnxDomain}, + {0, 0, "Mul", {1, 6, 7}, kOnnxDomain}, + {0, 0, "Cast", {1, 6, 9}, kOnnxDomain}, + {0, 0, "Squeeze", {1, 11}, kOnnxDomain}, + {0, 0, "Slice", {1, 10, 11}, 
kOnnxDomain}, + {0, 0, "Cast", {1, 6, 9}, kOnnxDomain}, + {0, 0, "Shape", {1}, kOnnxDomain}}; + if (graph_utils::FindPath(*node, true, parent_path_4, edges, logger)) { + const Node& concat = edges[1]->GetNode(); + const Node& mul = edges[3]->GetNode(); + const Node& slice = edges[6]->GetNode(); + const auto& concat_args = concat.InputDefs(); + bool const_flag = true; + for (std::size_t i = 1; i < concat_args.size(); i++) + { + const_flag &= (initializers.find(concat_args[i]->Name()) != initializers.end()); + } + if (const_flag) + { + const auto& mul_args = mul.InputDefs(); + const_flag &= (initializers.find(mul_args[1]->Name()) != initializers.end()); + } + if (const_flag) + { + const auto& slice_args = slice.InputDefs(); + for (std::size_t i = 1; i < slice_args.size(); ++i) + { + const_flag &= (initializers.find(slice_args[i]->Name()) != initializers.end()); + } + } + if (const_flag) + { + return true; + } + } + + return false; +} +
// Reports whether `node` uses an operator configuration that this provider
// treats as unsupported by MIGraphX.
//
// Parameters:
//   node         - node to inspect; must not be null.
//   graph_viewer - graph that owns the node; used to look up initializers.
//   logger       - forwarded to the graph path-matching helpers.
//
// Returns true when the node must be rejected (it will then run on another
// provider) and false when no known unsupported mode was detected. Operators
// without a dedicated branch below are always reported as supported.
static bool IsUnsupportedOpMode(const Node* node, const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger) {
  const auto& optype = node->OpType();
  const auto& initializers = graph_viewer.GetAllInitializedTensors();

  // True when the given input is backed by an initializer of the graph.
  const auto is_initializer = [&initializers](const NodeArg* arg) {
    return initializers.find(arg->Name()) != initializers.end();
  };

  if (optype == "AveragePool") {
    const auto& attributes = node->GetAttributes();

    // Only the default ceil_mode (0) is supported.
    const auto ceil_attr = attributes.find("ceil_mode");
    if (ceil_attr != attributes.end() && ceil_attr->second.i() != 0) {
      return true;
    }

    // Input can only have 4 dims (an unknown shape is let through here).
    const auto input_shape = node->InputDefs()[0]->Shape();
    if (input_shape != nullptr && input_shape->dim_size() != 4) {
      return true;
    }

    // count_include_pad == 1 is not supported.
    const auto cip_attr = attributes.find("count_include_pad");
    if (cip_attr != attributes.end() && cip_attr->second.i() != 0) {
      return true;
    }

    const auto ap_attr = attributes.find("auto_pad");
    if (ap_attr != attributes.end()) {
      const auto& s_pad = ap_attr->second.s();
      const auto pads_attr = attributes.find("pads");
      if (s_pad == "NOTSET") {
        // Explicit pads must be symmetric: [h_begin, w_begin, h_end, w_end].
        if (pads_attr != attributes.end()) {
          const auto& pads = pads_attr->second.ints();
          if (pads.size() != 4) {
            return true;
          }
          if ((pads[0] != pads[2]) || (pads[1] != pads[3])) {
            return true;
          }
        }
      } else if (s_pad.find("SAME") != std::string::npos) {
        // Either SAME_UPPER or SAME_LOWER: "pads" must not be given as well.
        if (pads_attr != attributes.end()) {
          return true;
        }

        // The implicit padding is derived from the spatial input sizes, so
        // without shape information symmetry cannot be proven. Reject instead
        // of dereferencing a null shape.
        if (input_shape == nullptr) {
          return true;
        }

        std::vector<int64_t> strides = {1, 1};
        const auto stride_attr = attributes.find("strides");
        if (stride_attr != attributes.end()) {
          const auto& attr_strides = stride_attr->second.ints();
          strides.assign(attr_strides.begin(), attr_strides.end());
        }
        // Exactly two positive strides are required for the 2-D computation
        // below (also guards the divisions against zero).
        if (strides.size() != 2 || strides[0] <= 0 || strides[1] <= 0) {
          return true;
        }

        std::vector<int64_t> kernel_lens = {1, 1};
        const auto kernel_attr = attributes.find("kernel_shape");
        if (kernel_attr != attributes.end()) {
          const auto& attr_k = kernel_attr->second.ints();
          // kernel_lens has room for two entries only.
          if (attr_k.size() != 2) {
            return true;
          }
          std::copy(attr_k.begin(), attr_k.end(), kernel_lens.begin());
        }

        // Dimensions without a concrete value are treated as 1, as before.
        // dim_size() == 4 is guaranteed by the check above.
        std::vector<int64_t> in_lens;
        for (const auto& d : input_shape->dim()) {
          in_lens.push_back(d.has_dim_value() ? d.dim_value() : 1);
        }

        // Total padding per spatial axis; an odd total cannot be split evenly
        // between the two sides, i.e. it would be asymmetric.
        for (std::size_t axis = 0; axis < 2; ++axis) {
          const int64_t in_len = in_lens[axis + 2];
          const int64_t out_len = (in_len + strides[axis] - 1) / strides[axis];
          const int64_t explicit_pad = (out_len - 1) * strides[axis] + kernel_lens[axis] - in_len;
          if ((explicit_pad & 1) != 0) {
            return true;
          }
        }
      }
    }
  } else if (optype == "BatchNormalization") {
    // Input can only have 4 dims.
    const auto input_shape = node->InputDefs()[0]->Shape();
    if (input_shape != nullptr && input_shape->dim_size() != 4) {
      return true;
    }
  } else if (optype == "Clip") {
    // When min/max are passed as inputs they must be initializers.
    const auto args = node->InputDefs();
    if (args.size() >= 3 && !is_initializer(args[2])) {
      return true;
    }
    if (args.size() >= 2 && !is_initializer(args[1])) {
      return true;
    }
  } else if (optype == "Conv") {
    // Input can only have 4 dims.
    const auto input_shape = node->InputDefs()[0]->Shape();
    if (input_shape != nullptr && input_shape->dim_size() != 4) {
      return true;
    }
  } else if (optype == "ConstantOfShape") {
    // The shape input must be an initializer or be produced by one of the
    // sub-graph patterns recognised by can_eval_concat / can_eval_cast.
    if (is_initializer(node->InputDefs()[0])) {
      return false;
    }
    const Node* shape_node = graph_utils::GetInputNode(*node, 0);
    if (shape_node && shape_node->OpType() == "Concat") {
      if (can_eval_concat(shape_node, initializers, logger)) {
        return false;
      }
    } else if (shape_node && shape_node->OpType() == "Cast") {
      if (can_eval_cast(shape_node, initializers, logger)) {
        return false;
      }
    }
    return true;
  } else if (optype == "ConvInteger") {
    // Input can only have 4 dims. Unlike Conv, an unknown shape is rejected:
    // the previous code dereferenced the shape unconditionally, so a null
    // shape was never a supported case here.
    const auto input_shape = node->InputDefs()[0]->Shape();
    if (input_shape == nullptr || input_shape->dim_size() != 4) {
      return true;
    }

    // Only the two-input form (no zero points) is handled.
    if (node->InputDefs().size() != 2) {
      return true;
    }

    // Only int8 data is supported.
    const auto& input_type = node->InputDefs()[0]->TypeAsProto();
    if (input_type == nullptr) {
      return true;
    }
    if (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
      return true;
    }
  } else if (optype == "Expand") {
    // Only constant shape input values are supported.
    const auto& shape_input = node->InputDefs()[1];
    return !graph_viewer.IsConstantInitializer(shape_input->Name(), true);
  } else if (optype == "MaxPool") {
    // The optional "indices" output is not supported.
    if (node->OutputDefs().size() > 1) {
      return true;
    }

    // Only the default ceil_mode (0) is supported.
    const auto& attributes = node->GetAttributes();
    const auto ceil_attr = attributes.find("ceil_mode");
    if (ceil_attr != attributes.end() && ceil_attr->second.i() != 0) {
      return true;
    }

    // Every dilation must be 1.
    const auto dila_attr = attributes.find("dilations");
    if (dila_attr != attributes.end()) {
      const auto& dilas = dila_attr->second.ints();
      if (!std::all_of(dilas.begin(), dilas.end(), [](auto i) { return i == 1; })) {
        return true;
      }
    }

    // storage_order 1 (column major format) is not supported.
    const auto storage_order_attr = attributes.find("storage_order");
    if (storage_order_attr != attributes.end() && storage_order_attr->second.i() != 0) {
      return true;
    }

    // Input can only have 4 dims.
    const auto input_shape = node->InputDefs()[0]->Shape();
    if (input_shape != nullptr && input_shape->dim_size() != 4) {
      return true;
    }
  } else if (optype == "MatMulInteger") {
    // Only the two-input form (no zero points) is handled.
    if (node->InputDefs().size() != 2) {
      return true;
    }

    // Only int8 data is supported.
    const auto& input_type = node->InputDefs()[0]->TypeAsProto();
    if (input_type == nullptr) {
      return true;
    }
    if (input_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8) {
      return true;
    }
  } else if (optype == "OneHot") {
    // The depth input must be an initializer.
    return !is_initializer(node->InputDefs()[1]);
  } else if (optype == "Pad") {
    const auto& args = node->InputDefs();
    // If the pad sizes are given as an input they must be constant.
    if (args.size() >= 2 && !is_initializer(args[1])) {
      return true;
    }

    // Only the "constant" (default) and "reflect" modes are accepted.
    const auto& attributes = node->GetAttributes();
    const auto mode_attr = attributes.find("mode");
    std::string mode = "constant";
    if (mode_attr != attributes.end()) {
      mode = mode_attr->second.s();
    }
    static const std::set<std::string> allowed_modes = {"constant", "reflect"};
    if (allowed_modes.count(mode) == 0) {
      return true;
    }

    // The fill value input only applies to constant mode and must itself be
    // constant.
    if (mode == "constant" && args.size() == 3 && !is_initializer(args[2])) {
      return true;
    }
  } else if (optype == "Range") {
    // All inputs must be initializers.
    const auto& args = node->InputDefs();
    if (!std::all_of(args.begin(), args.end(), is_initializer)) {
      return true;
    }
  } else if (optype == "Reshape") {
    // The target shape must be an initializer or be computable from one of
    // the patterns recognised by can_eval_input_shape.
    const auto& args = node->InputDefs();
    if (args.size() == 2) {
      if (is_initializer(args[1])) {
        return false;
      }
      if (can_eval_input_shape(node, initializers, logger)) {
        return false;
      }
    }
    return true;
  } else if (optype == "Slice") {
    const auto& args = node->InputDefs();
    // The five-input form (with "steps") is rejected.
    if (args.size() == 5) {
      return true;
    }

    // starts / ends / axes given as inputs must all be initializers.
    for (std::size_t i = 1; i < args.size() && i <= 3; ++i) {
      if (!is_initializer(args[i])) {
        return true;
      }
    }

    // MIGraphX does not properly handle the situation where any value of the
    // "starts" attribute is higher than the corresponding value in "ends".
    const auto& attributes = node->GetAttributes();
    const auto starts_attr = attributes.find("starts");
    const auto ends_attr = attributes.find("ends");
    if (starts_attr != attributes.end() && ends_attr != attributes.end()) {
      const auto& starts = starts_attr->second.ints();
      const auto& ends = ends_attr->second.ints();
      // Mismatched lengths cannot be compared element-wise; reject rather
      // than index past the end of "ends".
      if (starts.size() != ends.size()) {
        return true;
      }
      for (int i = 0; i < starts.size(); ++i) {
        if (starts.Get(i) > ends.Get(i)) {
          return true;
        }
      }
    }
  } else if (optype == "Tile") {
    // The repeats input must be an initializer.
    return !is_initializer(node->InputDefs()[1]);
  }

  // Op doesn't fall into any of the known unsupported modes.
  return false;
}

+static bool IsNodeSupported(const std::set& op_set, + const onnxruntime::GraphViewer& graph_viewer, + const NodeIndex node_idx, + const logging::Logger& logger) { + const auto& node = graph_viewer.GetNode(node_idx); + const auto& optype = node->OpType(); + const auto& domain = node->Domain(); + + // Three types of checking: + // 1. Check input and output data types are supported. + // 2. Check op_type is implemented in migraphx + // 3. Check the mode is implemented in migraphx + // if 3. 
failed, call the constant folding capability in migraphx + // to see whether some input parameters can be calculated statically + // check data type + bool are_types_supported = true; + + node->ForEachDef([&are_types_supported](const onnxruntime::NodeArg& node_arg, bool /*is_input*/) { + are_types_supported &= IsTypeSupported(&node_arg); + }); + + if (!are_types_supported) { + return false; + } + + // whether an operator implemented in migraphx + if (op_set.count(optype) == 0) { + return false; + } + + // check that some modes might not be supported in migraphx for some operators + if (domain == kOnnxDomain && IsUnsupportedOpMode(node, graph_viewer, logger)) { + // not supported, then check the constant folding capability of migraphx + // to see whether it is supported + return false; + } + + return true; +} + +static void AppendNodesToSubGraph(const std::vector& nodes, + const std::vector& inputs, + const std::vector& outputs, + std::vector>& result) { + static size_t op_counter = 0; + + auto meta_def = onnxruntime::make_unique(); + meta_def->name = "MIGraphX_" + std::to_string(++op_counter); + meta_def->domain = kMIGraphXDomain; + meta_def->since_version = 1; + meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; + meta_def->inputs = inputs; + meta_def->outputs = outputs; + + std::unique_ptr sub_graph = onnxruntime::make_unique(); + sub_graph->nodes = nodes; + sub_graph->SetMetaDef(meta_def); + result.push_back(onnxruntime::make_unique(std::move(sub_graph))); +} + +static std::vector +GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, + /*out*/ std::unordered_set& mgx_required_initializers, + const logging::Logger& logger) { + static std::set mgx_supported_ops = {"Abs", "Acos", "Acosh", "Add", "ArgMax", "ArgMin", + "Asin", "Asinh", "Atan", "Atanh", "AveragePool", "BatchNormalization", "Cast", "Ceil", "Clip", + "Concat", "Constant", "ConstantFill", "ConstantOfShape", "Conv", "Cos", "Cosh", "Div", "Dropout", + "Elu", "Erf", "Exp", "Expand", "Flatten", "Floor", 
"GRU", "Gather", "Gemm", "GlobalAveragePool", + "GlobalMaxPool", "Identity", "ImageScaler", "InstanceNormalization", "LRN", "LSTM", "LeakyRelu", + "Log", "LogSoftmax", "MatMul", "Max", "MaxPool", "Min", "Mul", "OneHot", "Pad", "Pow", "PRelu", + "RNN","Range", "Reciprocal", "ReduceL1", "ReduceL2", "ReduceLogSum", "ReduceLogSumExp", "ReduceMax", + "ReduceMean", "ReduceMin", "ReduceProd", "ReduceSum", "ReduceSumSquare", "Relu", "Reshape", + "Round", "Shape", "Sigmoid", "Sign", "Sin", "Sinh", "Slice", "Softmax", "Split", "Sqrt", "Squeeze", + "Sub", "Sum", "Tan", "Tanh", "Tile", "Transpose", "Unsqueeze"}; + std::vector unsupported_nodes_idx; + for (const auto& node_idx : graph_viewer.GetNodesInTopologicalOrder()) { + if (IsNodeSupported(mgx_supported_ops, graph_viewer, node_idx, logger)) { + // Collect inputs that are initializers + graph_viewer.GetNode(node_idx)->ForEachDef([&mgx_required_initializers, &graph_viewer](const onnxruntime::NodeArg& node_arg, bool is_input) { + if(is_input && graph_viewer.GetAllInitializedTensors().count(node_arg.Name())) { + mgx_required_initializers.insert(node_arg.Name()); + } }, true); + } else { + unsupported_nodes_idx.push_back(node_idx); + } + } + + return unsupported_nodes_idx; +} + +// Returns a vector clusters(or node_idx). For each unsupported node, the graph +// is split into 3 parts. supported_cluster + (UNsupported_node + rest_of_the_graph). +// This functions returns vector of all supported_subgraphx by amdmigraphx +static std::vector> +GetPartitionedSubgraphs(const std::vector& topological_order, const std::vector& unsupported_nodes) { + std::vector> mgx_subgraphx; + + auto prev = topological_order.begin(); + + for (const auto& unsup_node : unsupported_nodes) { + auto it = std::find(prev, topological_order.end(), unsup_node); + // Create a cluster vector[supported_node_idx, unsupported_node_idx) + // and append it to return list. 
+ std::vector this_subgraph{prev, it}; + if (!this_subgraph.empty()) { + mgx_subgraphx.push_back(std::move(this_subgraph)); + } + // Point prev to node idx past this unsuported node. + prev = ++it; + } + + // Tail + std::vector this_subgraph{prev, topological_order.end()}; + if (!this_subgraph.empty()) { + mgx_subgraphx.push_back(std::move(this_subgraph)); + } + + return mgx_subgraphx; +} + +static void GetInputsOutputsOfSubgraph(const GraphViewer& graph_viewer, + const std::vector& nodes, + const std::unordered_set& mgx_required_initializers, + std::vector& nodes_inputs, + std::vector& nodes_outputs) { + std::unordered_set input_args; + std::vector ordered_input_args; + std::unordered_set output_args; + std::unordered_set external_output_args; + + for (const auto& node_idx : nodes) { + const auto& node = graph_viewer.GetNode(node_idx); + + // Collect all inputs and outputs + node->ForEachDef( + [&input_args, &ordered_input_args, &output_args](const NodeArg& node_arg, bool is_input) { + if (is_input) { + if (!input_args.count(node_arg.Name())) { + ordered_input_args.push_back(node_arg.Name()); + } + input_args.insert(node_arg.Name()); + } else { + output_args.insert(node_arg.Name()); + } + }, + true); + + // Check if output of this node is used by nodes outside + // subgraph. If yes add this to cluster outputs + for (auto it = node->OutputNodesBegin(); it != node->OutputNodesEnd(); ++it) { + const auto& ext_node = graph_viewer.GetNode((*it).Index()); + + if (std::find(nodes.begin(), nodes.end(), ext_node->Index()) == nodes.end()) { + // Node is external to subgraph. Search through its + // inputs to find the output that is generated by subgraph. 
+ std::set ext_node_inputs; + ext_node->ForEachDef( + [&ext_node_inputs](const onnxruntime::NodeArg& arg, bool is_input) { + if (is_input) { + ext_node_inputs.insert(arg.Name()); + } + }, + true); + + for (const auto& out_def : node->OutputDefs()) { + if (ext_node_inputs.find(out_def->Name()) != ext_node_inputs.end()) { + external_output_args.insert(out_def->Name()); + } + } + } + } + } + + //Extract initializers used by subgraph. + std::unordered_set original_graph_inputs; + for (const auto& node_arg : graph_viewer.GetInputsIncludingInitializers()) { + original_graph_inputs.insert(node_arg->Name()); + } + + const auto& initializers = graph_viewer.GetAllInitializedTensors(); + std::vector const_inputs; + for (const auto& in_arg : ordered_input_args) { + if ((initializers.count(in_arg) && !original_graph_inputs.count(in_arg)) || + mgx_required_initializers.count(in_arg)) { + const_inputs.push_back(in_arg); + } + } + + for (const auto& in_arg : ordered_input_args) { + if (!output_args.count(in_arg) && + !((initializers.count(in_arg) && !original_graph_inputs.count(in_arg)) || + mgx_required_initializers.count(in_arg))) { + nodes_inputs.push_back(in_arg); + } + } + + for (const auto& in_arg : const_inputs) { + nodes_inputs.push_back(in_arg); + } + + std::copy(external_output_args.begin(), external_output_args.end(), std::back_inserter(nodes_outputs)); + for (const auto& node_arg : graph_viewer.GetOutputs()) { + const auto& name = node_arg->Name(); + if (output_args.count(name) && !external_output_args.count(name)) { + nodes_outputs.push_back(name); + } + } +} + +std::vector> +MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, + const std::vector& /*kernel_registries*/) const { + + std::vector> result; + if (graph_viewer.IsSubgraph()) { + return result; + } + + for (const auto& tensor : graph_viewer.GetAllInitializedTensors()) { + if (tensor.second->has_data_location() && tensor.second->data_location() == 
ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + LOGS_DEFAULT(WARNING) << "MIGraphX: Initializers with external data lepcation are not currently supported"; + return result; + } + } + + // Construct modelproto from graph + onnxruntime::Model model(graph_viewer.Name(), true, ModelMetaData(), PathString{}, + IOnnxRuntimeOpSchemaRegistryList(), graph_viewer.DomainToVersionMap(), + std::vector(), *GetLogger()); + + std::unordered_map map_dim_param_values; + onnxruntime::Graph& graph_build = model.MainGraph(); + for (const auto& node : graph_viewer.Nodes()) { + std::vector inputs, outputs; + for (auto input : node.InputDefs()) { + auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); + inputs.push_back(&n_input); + } + for (auto output : node.OutputDefs()) { + auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); + outputs.push_back(&n_output); + } + graph_build.AddNode(node.Name(), node.OpType(), node.Description(), inputs, outputs, &node.GetAttributes(), node.Domain()); + } + + //Add initializer to graph + std::size_t init_tensor_num = 0; + const auto& init_tensors = graph_viewer.GetAllInitializedTensors(); + for (const auto& tensor : init_tensors) { + init_tensor_num++; + graph_build.AddInitializedTensor(*(tensor.second)); + } + + ONNX_NAMESPACE::ModelProto model_proto = model.ToProto(); + model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + + auto status = graph_build.Resolve(); + std::string onnx_string_buffer; + model_proto.SerializeToString(&onnx_string_buffer); + + // This is a list of initializers that migraphx considers as constants. + // Example weights, reshape shape etc. 
+ std::unordered_set mgx_required_initializers; + const auto unsupported_nodes = GetUnsupportedNodeIndices(graph_viewer, mgx_required_initializers, *GetLogger()); + + // Too many unsupported operators, fallback to run on CPU + if (unsupported_nodes.size() >= 6) + { + return result; + } + + //If all ops are supported, no partitioning is required. Short-circuit and avoid splitting. + if (unsupported_nodes.empty()) { + std::vector inputs; + std::vector outputs; + + //Fill inputs with names + std::for_each(graph_viewer.GetInputs().begin(), graph_viewer.GetInputs().end(), + [&inputs](const NodeArg* node_arg) { inputs.push_back(node_arg->Name()); }); + + // In scenarios, when there are no inputs or all inputs being initializers, + // ConstantFolding optimization in onnxruntime pre-computes the value. + if (inputs.empty()) { + return result; + } + + // Initializers need to be part of meta_def->inputs + std::for_each(mgx_required_initializers.begin(), mgx_required_initializers.end(), + [&inputs](const std::string& initializer) { inputs.push_back(initializer); }); + + // Fill outputs with names + std::for_each(graph_viewer.GetOutputs().begin(), graph_viewer.GetOutputs().end(), + [&outputs](const NodeArg* node_arg) { outputs.push_back(node_arg->Name()); }); + + // Create and add this graph to result. 
+ AppendNodesToSubGraph(graph_viewer.GetNodesInTopologicalOrder(), inputs, outputs, result); + + } else { // unsupported_nodes_idx.empty() + const auto mgx_clusters = GetPartitionedSubgraphs(graph_viewer.GetNodesInTopologicalOrder(), unsupported_nodes); + + for (const auto& this_cluster : mgx_clusters) { + std::vector cluster_inputs, cluster_outputs; + GetInputsOutputsOfSubgraph(graph_viewer, this_cluster, mgx_required_initializers, cluster_inputs, cluster_outputs); + + if (!cluster_inputs.empty()) { + AppendNodesToSubGraph(this_cluster, cluster_inputs, cluster_outputs, result); + } + } + } + + return result; +} + +static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime::Node* fused_node, + const logging::Logger& logger) { + const auto* node_function = fused_node->GetFunctionBody(); + + ORT_ENFORCE(node_function != nullptr, "Could not extract function body for node: ", fused_node->Name()); + + const Graph& node_subgraph = node_function->Body(); + onnxruntime::Model model{node_subgraph.Name(), true, ModelMetaData{}, PathString{}, + IOnnxRuntimeOpSchemaRegistryList{}, node_subgraph.DomainToVersionMap(), + std::vector(), logger}; + + ONNX_NAMESPACE::ModelProto model_proto = model.ToProto(); + //model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + + *(model_proto.mutable_graph()) = node_subgraph.ToGraphProto(); + + auto opset = model_proto.add_opset_import(); + opset->set_domain(kOnnxDomain); + opset->set_version(node_subgraph.DomainToVersionMap().at(kOnnxDomain)); + + return model_proto; +} + +bool get_input_output_names(std::string& onnx_buffer, + std::vector& input_names, + std::vector& output_names) +{ + bool no_input_shape = false; + + input_names.clear(); + output_names.clear(); + onnx::ModelProto model; + if (model.ParseFromArray(onnx_buffer.data(), onnx_buffer.size())) + { + if (model.has_graph()) + { + // compute output names + auto& graph = model.graph(); + + // compute input names + std::unordered_set ini_names; + 
for(auto&& f : graph.initializer()) + ini_names.insert(f.name()); + + for(auto&& input : graph.input()) + { + const std::string& name = input.name(); + if (ini_names.count(name) == 0) + { + input_names.push_back(name); + auto dim_size = input.type().tensor_type().shape().dim_size(); + if (dim_size == 0) + { + no_input_shape = true; + } + } + } + + auto prog_output = graph.output(); + std::vector all_output_names; + std::vector prog_output_names; + std::transform(prog_output.begin(), + prog_output.end(), + std::back_inserter(all_output_names), + [](auto& node) { return node.name(); }); + std::copy_if( + all_output_names.begin(), + all_output_names.end(), + std::back_inserter(output_names), + [&](const auto& name) { return !name.empty(); }); + } + } + + return no_input_shape; +} + +Status MIGraphXExecutionProvider::Compile(const std::vector& fused_nodes, + std::vector& node_compute_funcs) { + migraphx::onnx_options options; + bool no_input_shape = false; + // std::size_t fused_node_idx = 0; + for (const auto& fused_node : fused_nodes) { + // map parameter input name to index + std::unordered_map input_name_index; + const auto& input_defs = fused_node->InputDefs(); + input_name_index.reserve(input_defs.size()); + for (std::size_t i = 0; i < input_defs.size(); ++i) { + input_name_index[input_defs[i]->Name()] = i; + } + + // reconstruct the subgraph proto from fused nodes + onnx::ModelProto model_proto = GetModelProtoFromFusedNode(fused_node, *GetLogger()); + std::string onnx_string_buffer; + model_proto.SerializeToString(&onnx_string_buffer); + std::vector input_names, output_names; + no_input_shape |= get_input_output_names(onnx_string_buffer, input_names, output_names); + + // by parsing the model_proto, create a program corresponding to + // the input fused_node + migraphx::program prog; + + if (!no_input_shape) + { + prog = migraphx::parse_onnx_buffer(onnx_string_buffer, options); + prog.compile(t_); + + auto prog_output_shapes = prog.get_output_shapes(); + for 
(std::size_t i = 0; i < output_names.size(); ++i) + { + auto out_len = prog_output_shapes[i].lengths(); + options.set_input_parameter_shape(output_names[i], out_len); + } + } + + // compile the program + map_progs_[fused_node->Name()] = prog; + + map_onnx_string_[fused_node->Name()] = onnx_string_buffer; + map_input_index_[fused_node->Name()] = input_name_index; + map_no_input_shape_[fused_node->Name()] = no_input_shape; + NodeComputeInfo compute_info; + compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { + std::unique_ptr p = onnxruntime::make_unique(); + *p = {context->allocate_func, context->release_func, context->allocator_handle, map_progs_[context->node_name], + map_onnx_string_[context->node_name], options, t_, map_input_index_[context->node_name], &mgx_mu_, + map_no_input_shape_[context->node_name]}; + *state = p.release(); + return 0; + }; + + compute_info.release_state_func = [](FunctionState state) { + if (state) + delete static_cast(state); + }; + + compute_info.compute_func = [](FunctionState state, const OrtCustomOpApi* api, OrtKernelContext* context) { + Ort::CustomOpApi ort{*api}; + MIGraphXFuncState* mgx_state = reinterpret_cast(state); + std::unordered_map& map_input_name_index = mgx_state->input_name_indexes; + migraphx::target t = mgx_state->t; + migraphx::program& prog = mgx_state->prog; + std::string& onnx_string = mgx_state->onnx_string; + migraphx::onnx_options& cmp_options = mgx_state->options; + bool &no_input_shape = mgx_state->no_input_shape; + + // mean no program at all, so need to get the input shape info + // from input data + bool input_shape_match = true; + migraphx::program_parameter_shapes param_shapes; + if (no_input_shape) + { + for (auto& it : map_input_name_index) + { + auto& name = it.first; + auto& index = it.second; + const OrtValue* input_tensor = ort.KernelContext_GetInput(context, index); + auto tensor_info = ort.GetTensorTypeAndShape(input_tensor); + const auto& tensor_shape = 
ort.GetTensorShape(tensor_info); + std::vector ort_lens(tensor_shape.begin(), tensor_shape.end()); + cmp_options.set_input_parameter_shape(name, ort_lens); + input_shape_match = false; + } + } + else + { + param_shapes = prog.get_parameter_shapes(); + auto prog_output_shapes = prog.get_output_shapes(); + + // check whether input shapes match with shapes of program inputs + // migraphx::onnx_options cmp_options; + if (param_shapes.size() > 0) + { + for (auto&& name : param_shapes.names()) + { + if (map_input_name_index.count(name) > 0) + { + const OrtValue* input_tensor = ort.KernelContext_GetInput(context, map_input_name_index[name]); + auto tensor_info = ort.GetTensorTypeAndShape(input_tensor); + const auto& tensor_shape = ort.GetTensorShape(tensor_info); + std::vector ort_lens(tensor_shape.begin(), tensor_shape.end()); + + auto mgx_s = param_shapes[name]; + auto mgx_lens = mgx_s.lengths(); + auto mgx_strides = mgx_s.strides(); + if (mgx_lens.size() == 1 and mgx_lens[0] == 1 and + mgx_strides.size() == 1 and mgx_strides[0] == 0) + { + mgx_lens.clear(); + } + + if (mgx_lens != ort_lens) + { + cmp_options.set_input_parameter_shape(name, ort_lens); + input_shape_match = false; + } + } + } + } + } + + // input shapes are different, needs to re-parse onnx and + // re-compile the program + if (!input_shape_match) + { + prog = migraphx::parse_onnx_buffer(onnx_string, cmp_options); + prog.compile(t); + mgx_state->prog = prog; + param_shapes = prog.get_parameter_shapes(); + no_input_shape = false; + } + + migraphx::program_parameters m; + auto prog_output_shapes = prog.get_output_shapes(); + std::vector prog_output_indices; + if (param_shapes.size() > 0) + { + for (auto&& name : param_shapes.names()) + { + if (map_input_name_index.count(name) > 0) + { + const OrtValue* input_tensor = ort.KernelContext_GetInput(context, map_input_name_index[name]); + auto tensor_info = ort.GetTensorTypeAndShape(input_tensor); + const auto& tensor_shape = ort.GetTensorShape(tensor_info); + 
auto tensor_type = ort.GetTensorElementType(tensor_info); + ort.ReleaseTensorTypeAndShapeInfo(tensor_info); + + migraphx_shape_datatype_t mgx_type; + get_migraphx_type(tensor_type, mgx_type); + auto mgx_s = param_shapes[name]; + + if (mgx_type != mgx_s.type()) + { + LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; + } + + m.add(name, migraphx::argument(param_shapes[name], const_cast(ort.GetTensorData(input_tensor)))); + } + // It is a output argument + else + { + auto compute_output_index = [] (const std::string& name) -> int { + std::string out_name_prefix = "#output_"; + auto pos = name.find(out_name_prefix); + if (pos == std::string::npos) + { + return -1; + } + + std::string index_str = name.substr(pos + out_name_prefix.length()); + return std::stoi(index_str); + }; + + int output_index = compute_output_index(name); + if (output_index != -1) + { + prog_output_indices.push_back(output_index); + auto mgx_output_shape = prog_output_shapes[output_index]; + auto lens = mgx_output_shape.lengths(); + std::vector ort_output_shape(lens.begin(), lens.end()); + OrtValue* output_tensor = ort.KernelContext_GetOutput(context, output_index, ort_output_shape.data(), ort_output_shape.size()); + void* output_data = ort.GetTensorMutableData(output_tensor); + + // argument shape + auto mgx_arg_shape = param_shapes[name]; + m.add(name, migraphx::argument(mgx_arg_shape, output_data)); + } + } + } + } + + { + // lock to avoid race condition + std::lock_guard lock(*(mgx_state->mgx_mu_ptr)); + auto prog_outputs = prog.eval(m); + hipDeviceSynchronize(); + + // In case of input parameters are reused as output parameter call hipMemcpy + auto output_num = prog_outputs.size(); + if (prog_output_indices.size() < output_num) + { + for (std::size_t i = 0; i < output_num; ++i) + { + if (std::find(prog_output_indices.begin(), prog_output_indices.end(), i) != prog_output_indices.end()) + continue; + auto gpu_res = prog_outputs[i]; + migraphx::shape res_shape = gpu_res.get_shape(); + auto 
res_lens = res_shape.lengths(); + std::vector ort_shape{res_lens.begin(), res_lens.end()}; + OrtValue* output_tensor = ort.KernelContext_GetOutput(context, i, ort_shape.data(), ort_shape.size()); + void* output_data = ort.GetTensorMutableData(output_tensor); + hipMemcpy(output_data, gpu_res.data(), res_shape.bytes(), hipMemcpyDeviceToDevice); + } + } + } + + return Status::OK(); + }; + node_compute_funcs.push_back(compute_info); + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h new file mode 100644 index 0000000000..ee9dd3324c --- /dev/null +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License + +#pragma once + +#include "core/framework/execution_provider.h" +#include "core/platform/ort_mutex.h" +#include +#include "migraphx_inc.h" + +namespace onnxruntime { + +// Information needed to construct amdmigraphx execution providers. +struct MIGraphXExecutionProviderInfo { + std::string target_device; + int device_id {0}; +}; + +// Information to construct kernel function state. +struct MIGraphXFuncState { + AllocateFunc allocate_func = nullptr; + DestroyFunc release_func = nullptr; + AllocatorHandle allocate_handle = nullptr; + migraphx::program prog{}; + std::string onnx_string; + migraphx::onnx_options options; + migraphx::target t{}; + std::unordered_map input_name_indexes; + OrtMutex* mgx_mu_ptr = nullptr; + bool no_input_shape = false; +}; + +// Logical device representation. 
+class MIGraphXExecutionProvider : public IExecutionProvider { + public: + explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info); + ~MIGraphXExecutionProvider() = default; + + std::vector> + GetCapability(const onnxruntime::GraphViewer& graph_viewer, + const std::vector& kernel_registries) const override; + + Status Compile(const std::vector& fused_nodes, + std::vector& node_compute_funcs) override; + + virtual std::shared_ptr GetKernelRegistry() const override; + std::unique_ptr GetDataTransfer() const override; + AllocatorPtr GetAllocator(int id, OrtMemType mem_type) const override; + +private: + int device_id_; + migraphx::target t_; + OrtMutex mgx_mu_; + + std::unordered_map map_progs_; + std::unordered_map map_onnx_string_; + std::unordered_map> map_input_index_; + std::unordered_map map_no_input_shape_; + + AllocatorPtr allocator_; +}; + +} diff --git a/onnxruntime/core/providers/migraphx/migraphx_inc.h b/onnxruntime/core/providers/migraphx/migraphx_inc.h new file mode 100644 index 0000000000..96b24051ac --- /dev/null +++ b/onnxruntime/core/providers/migraphx/migraphx_inc.h @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License + +#pragma once + +#include +#include +#include diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc new file mode 100644 index 0000000000..7b616b8fc7 --- /dev/null +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License + +#include "core/providers/migraphx/migraphx_provider_factory.h" +#include +#include "migraphx_execution_provider.h" +#include "core/session/abi_session_options_impl.h" + +using namespace onnxruntime; + +namespace onnxruntime { +struct MIGraphXProviderFactory : IExecutionProviderFactory { + MIGraphXProviderFactory(int device_id) : device_id_(device_id) {} + ~MIGraphXProviderFactory() = default; + + std::unique_ptr CreateProvider() override { + MIGraphXExecutionProviderInfo info; + info.device_id = device_id_; + info.target_device = "gpu"; + return std::make_unique(info); + } + +private: + int device_id_; +}; + +std::shared_ptr CreateExecutionProviderFactory_MIGraphX(int device_id) { + return std::make_shared(device_id); +} + +} // namespace onnxruntime + +ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessionOptions* options, int device_id) { + options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_MIGraphX(device_id)); + return nullptr; +} diff --git a/onnxruntime/core/providers/migraphx/symbols.txt b/onnxruntime/core/providers/migraphx/symbols.txt new file mode 100644 index 0000000000..be8e8e150c --- /dev/null +++ b/onnxruntime/core/providers/migraphx/symbols.txt @@ -0,0 +1 @@ +OrtSessionOptionsAppendExecutionProvider_MIGraphX diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index d6d64d56a3..c11c53eb4f 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -52,6 +52,12 @@ #define BACKEND_NGRAPH "" #endif +#if USE_MIGRAPHX +#define BACKEND_MIGRAPHX "-MIGRAPHX" +#else +#define BACKEND_MIGRAPHX "" +#endif + #ifdef USE_OPENVINO #if OPENVINO_CONFIG_CPU_FP32 #define BACKEND_OPENVINO "-OPENVINO_CPU_FP32" @@ -94,7 +100,7 @@ #define BACKEND_OPENBLAS "" #endif -#define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_MKLML BACKEND_NGRAPH BACKEND_NUPHAR 
BACKEND_OPENBLAS BACKEND_OPENVINO +#define BACKEND_DEVICE BACKEND_PROC BACKEND_DNNL BACKEND_MKLML BACKEND_NGRAPH BACKEND_OPENVINO BACKEND_NUPHAR BACKEND_OPENBLAS BACKEND_MIGRAPHX #include "core/session/onnxruntime_cxx_api.h" #include "core/providers/providers.h" #include "core/providers/cpu/cpu_execution_provider.h" @@ -109,6 +115,9 @@ onnxruntime::ArenaExtendStrategy arena_extend_strategy = onnxruntime::ArenaExten #ifdef USE_TENSORRT #include "core/providers/tensorrt/tensorrt_provider_factory.h" #endif +#ifdef USE_MIGRAPHX +#include "core/providers/migraphx/migraphx_provider_factory.h" +#endif #ifdef USE_NGRAPH #include "core/providers/ngraph/ngraph_provider_factory.h" #endif @@ -130,6 +139,7 @@ std::shared_ptr CreateExecutionProviderFactory_CUDA(O size_t cuda_mem_limit, onnxruntime::ArenaExtendStrategy arena_extend_strategy); std::shared_ptr CreateExecutionProviderFactory_Tensorrt(int device_id); +std::shared_ptr CreateExecutionProviderFactory_MIGraphX(int device_id); std::shared_ptr CreateExecutionProviderFactory_Dnnl(int use_arena); std::shared_ptr CreateExecutionProviderFactory_NGraph(const char* ng_backend_type); std::shared_ptr CreateExecutionProviderFactory_OpenVINO(const char* device); @@ -277,7 +287,7 @@ inline void RegisterExecutionProvider(InferenceSession* sess, onnxruntime::IExec const std::vector& GetAllProviders() { static std::vector all_providers = {kTensorrtExecutionProvider, kCudaExecutionProvider, kDnnlExecutionProvider, kNGraphExecutionProvider, kOpenVINOExecutionProvider, kNupharExecutionProvider, - kVitisAIExecutionProvider, kCpuExecutionProvider}; + kVitisAIExecutionProvider, kCpuExecutionProvider, kMIGraphXExecutionProvider}; return all_providers; } @@ -287,6 +297,9 @@ const std::vector& GetAvailableProviders() { #ifdef USE_TENSORRT available_providers.push_back(kTensorrtExecutionProvider); #endif +#ifdef USE_MIGRAPHX + available_providers.push_back(kMIGraphXExecutionProvider); +#endif #ifdef USE_CUDA 
available_providers.push_back(kCudaExecutionProvider); #endif @@ -318,6 +331,10 @@ void RegisterExecutionProviders(InferenceSession* sess, const std::vector values("streamLength", 2.0f); Record record(names, values); - const std::string* name; + const std::string* name = nullptr; auto status = record.GetName(2, &name); EXPECT_FALSE(status.IsOK()); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 4aa4c0e251..2b16a2e601 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -37,7 +37,7 @@ void usage() { "\t-v: verbose\n" "\t-n [test_case_name]: Specifies a single test case to run.\n" "\t-e [EXECUTION_PROVIDER]: EXECUTION_PROVIDER could be 'cpu', 'cuda', 'dnnl', 'tensorrt', 'ngraph', " - "'openvino', 'nuphar', or 'acl'. " + "'openvino', 'nuphar', 'migraphx' or 'acl'. " "Default: 'cpu'.\n" "\t-x: Use parallel executor, default (without -x): sequential executor.\n" "\t-d [device_id]: Specifies the device id for multi-device (e.g. GPU). 
The value should > 0\n" @@ -101,6 +101,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { bool enable_nnapi = false; bool enable_dml = false; bool enable_acl = false; + bool enable_migraphx = false; int device_id = 0; GraphOptimizationLevel graph_optimization_level = ORT_ENABLE_ALL; bool user_graph_optimization_level_set = false; @@ -167,7 +168,9 @@ int real_main(int argc, char* argv[], Ort::Env& env) { enable_dml = true; } else if (!CompareCString(optarg, ORT_TSTR("acl"))) { enable_acl = true; - } else { + } else if (!CompareCString(optarg, ORT_TSTR("migraphx"))) { + enable_migraphx = true; + }else { usage(); return -1; } @@ -363,6 +366,14 @@ int real_main(int argc, char* argv[], Ort::Env& env) { return -1; #endif } + if (enable_migraphx) { +#ifdef USE_MIGRAPHX + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(sf, device_id)); +#else + fprintf(stderr, "MIGRAPHX is not supported in this build"); + return -1; +#endif + } if (user_graph_optimization_level_set) { sf.SetGraphOptimizationLevel(graph_optimization_level); diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index b5ecc8a415..42284995b3 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -85,6 +85,12 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device performance_test_config.run_config.enable_cpu_mem_arena ? 
1 : 0)); #else ORT_THROW("Acl is not supported in this build\n"); +#endif + } else if (provider_name == onnxruntime::kMIGraphXExecutionProvider) { +#ifdef USE_MIGRAPHX + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_MIGraphX(session_options, 0)); +#else + ORT_THROW("MIGraphX is not supported in this build\n"); #endif } else if (!provider_name.empty() && provider_name != onnxruntime::kCpuExecutionProvider) { ORT_THROW("This backend is not included in perf test runner.\n"); diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index f968c4bf3b..3e0323e26f 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -17,7 +17,6 @@ pytest_plugins = 'onnx.backend.test.report', class OrtBackendTest(onnx.backend.test.BackendTest): - def __init__(self, backend, parent_module=None): super(OrtBackendTest, self).__init__(backend, parent_module) @@ -42,9 +41,11 @@ def create_backend_test(testname=None): backend_test.include(testname + '.*') else: # read filters data - with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'testdata', 'onnx_backend_test_series_filters.jsonc')) as f: + with open( + os.path.join(os.path.dirname(os.path.realpath(__file__)), 'testdata', + 'onnx_backend_test_series_filters.jsonc')) as f: filters_lines = f.readlines() - filters_lines = [x.split('//')[0] for x in filters_lines] + filters_lines = [x.split('//')[0] for x in filters_lines] filters = json.loads('\n'.join(filters_lines)) current_failing_tests = filters['current_failing_tests'] @@ -70,11 +71,23 @@ def create_backend_test(testname=None): if c2.supports_device('OPENVINO_CPU_FP32'): current_failing_tests += filters['current_failing_tests_OPENVINO_CPU_FP32'] + if c2.supports_device('MIGRAPHX'): + current_failing_tests += [ + '^test_constant_pad_cpu', '^test_softmax_axis_1_cpu', '^test_softmax_axis_0_cpu', + '^test_softmax_default_axis_cpu', 
'^test_round_cpu', '^test_lrn_default_cpu', '^test_lrn_cpu', + '^test_logsoftmax_axis_0_cpu', '^test_logsoftmax_axis_1_cpu', '^test_logsoftmax_default_axis_cpu', + '^test_dynamicquantizelinear_expanded_cpu', '^test_dynamicquantizelinear_max_adjusted_cpu', + '^test_dynamicquantizelinear_max_adjusted_expanded_cpu', '^test_dynamicquantizelinear_min_adjusted_cpu', + '^test_dynamicquantizelinear_min_adjusted_expanded_cpu', + '^test_range_float_type_positive_delta_expanded_cpu', + '^test_range_int32_type_negative_delta_expanded_cpu', '^test_operator_symbolic_override_nested_cpu' + ] + filters = current_failing_tests + \ - filters['tests_with_pre_opset7_dependencies'] + \ - filters['unsupported_usages'] + \ - filters['failing_permanently'] + \ - filters['test_with_types_disabled_due_to_binary_size_concerns'] + filters['tests_with_pre_opset7_dependencies'] + \ + filters['unsupported_usages'] + \ + filters['failing_permanently'] + \ + filters['test_with_types_disabled_due_to_binary_size_concerns'] backend_test.exclude('(' + '|'.join(filters) + ')') print('excluded tests:', filters) @@ -92,7 +105,8 @@ def parse_args(): # Add an argument to match a single test name, by adding the name to the 'include' filter. # Using -k with python unittest (https://docs.python.org/3/library/unittest.html#command-line-options) - # doesn't work as it filters on the test method name (Runner._add_model_test) rather than inidividual test case names. + # doesn't work as it filters on the test method name (Runner._add_model_test) rather than inidividual + # test case names. 
parser.add_argument( '-t', '--test-name', diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index d28b9bfd08..0710ad9249 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -20,6 +20,7 @@ std::shared_ptr CreateExecutionProviderFactory_Nuphar std::shared_ptr CreateExecutionProviderFactory_Nnapi(); std::shared_ptr CreateExecutionProviderFactory_Rknpu(); std::shared_ptr CreateExecutionProviderFactory_Tensorrt(int device_id); +std::shared_ptr CreateExecutionProviderFactory_MIGraphX(int device_id); std::shared_ptr CreateExecutionProviderFactory_ACL(int use_arena); namespace test { @@ -36,6 +37,14 @@ std::unique_ptr DefaultTensorrtExecutionProvider() { #endif } +std::unique_ptr DefaultMIGraphXExecutionProvider() { +#ifdef USE_MIGRAPHX + return CreateExecutionProviderFactory_MIGraphX(0)->CreateProvider(); +#else + return nullptr; +#endif +} + std::unique_ptr DefaultOpenVINOExecutionProvider() { #ifdef USE_OPENVINO return CreateExecutionProviderFactory_OpenVINO("")->CreateProvider(); diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index 83a16d3f67..69433675ae 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -13,6 +13,7 @@ std::unique_ptr DefaultDnnlExecutionProvider(bool enable_are std::unique_ptr DefaultNGraphExecutionProvider(); std::unique_ptr DefaultNupharExecutionProvider(bool allow_unaligned_buffers = true); std::unique_ptr DefaultTensorrtExecutionProvider(); +std::unique_ptr DefaultMIGraphXExecutionProvider(); std::unique_ptr DefaultOpenVINOExecutionProvider(); std::unique_ptr DefaultNnapiExecutionProvider(); std::unique_ptr DefaultRknpuExecutionProvider(); diff --git a/onnxruntime/test/util/include/providers.h b/onnxruntime/test/util/include/providers.h index 9861565bf4..14cc74a60f 100644 --- 
a/onnxruntime/test/util/include/providers.h +++ b/onnxruntime/test/util/include/providers.h @@ -31,3 +31,7 @@ #ifdef USE_ACL #include "core/providers/acl/acl_provider_factory.h" #endif +#ifdef USE_MIGRAPHX +#include "core/providers/migraphx/migraphx_provider_factory.h" +#endif + diff --git a/samples/c_cxx/include/providers.h b/samples/c_cxx/include/providers.h index 2bba91530b..077cfbd870 100644 --- a/samples/c_cxx/include/providers.h +++ b/samples/c_cxx/include/providers.h @@ -22,3 +22,6 @@ #ifdef USE_DML #include "onnxruntime/core/providers/dml/dml_provider_factory.h" #endif +#ifdef USE_MIGRAPHX +#include "onnxruntime/core/providers/migraphx/migraphx_provider_factory.h" +#endif diff --git a/server/external/spdlog b/server/external/spdlog index 352281313f..23f0cdf901 160000 --- a/server/external/spdlog +++ b/server/external/spdlog @@ -1 +1 @@ -Subproject commit 352281313fe1c4313bc222cb9de222afd50c822f +Subproject commit 23f0cdf9014650c79e214c2d0e935ab0f8821cc5 diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 0dd1493be5..c0544aa826 100755 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -284,6 +284,10 @@ def parse_arguments(): "--use_tensorrt", action='store_true', help="Build with TensorRT") parser.add_argument( "--tensorrt_home", help="Path to TensorRT installation dir") + parser.add_argument( + "--use_migraphx", action='store_true', help="Build with MIGraphX") + parser.add_argument( + "--migraphx_home", help="Path to MIGraphX installation dir") parser.add_argument( "--use_full_protobuf", action='store_true', help="Use the full protobuf library") @@ -509,7 +513,7 @@ def setup_test_data(build_dir, configs): def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, - cudnn_home, tensorrt_home, path_to_protoc_exe, configs, + cudnn_home, tensorrt_home, migraphx_home, path_to_protoc_exe, configs, cmake_extra_defines, args, cmake_extra_args): log.info("Generating CMake build tree") cmake_dir = os.path.join(source_dir, 
"cmake") @@ -582,6 +586,9 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, "-Donnxruntime_USE_TENSORRT=" + ("ON" if args.use_tensorrt else "OFF"), "-Donnxruntime_TENSORRT_HOME=" + ( tensorrt_home if args.use_tensorrt else ""), + # set vars for migraphx + "-Donnxruntime_USE_MIGRAPHX=" + ("ON" if args.use_migraphx else "OFF"), + "-Donnxruntime_MIGRAPHX_HOME=" + (migraphx_home if args.use_migraphx else ""), # By default - we currently support only cross compiling for # ARM/ARM64 (no native compilation supported through this # script). @@ -994,6 +1001,23 @@ def setup_tensorrt_vars(args): return tensorrt_home +def setup_migraphx_vars(args): + + migraphx_home = None + + if (args.use_migraphx): + print("migraphx_home = {}".format(args.migraphx_home)) + migraphx_home = args.migraphx_home or os.getenv("MIGRAPHX_HOME") or None + + migraphx_home_not_valid = (migraphx_home and not os.path.exists(migraphx_home)) + + if (migraphx_home_not_valid): + raise BuildError("migraphx_home paths must be specified and valid.", + "migraphx_home='{}' valid={}." + .format(migraphx_home, migraphx_home_not_valid)) + return migraphx_home or '' + + def setup_dml_build(args, cmake_path, build_dir, configs): if args.use_dml: for config in configs: @@ -1561,6 +1585,9 @@ def main(): # if using tensorrt, setup tensorrt paths tensorrt_home = setup_tensorrt_vars(args) + # if using migraphx, setup migraphx paths + migraphx_home = setup_migraphx_vars(args) + os.makedirs(build_dir, exist_ok=True) log.info("Build started") @@ -1645,7 +1672,7 @@ def main(): setup_test_data(build_dir, configs) generate_build_tree( cmake_path, source_dir, build_dir, cuda_home, cudnn_home, - tensorrt_home, path_to_protoc_exe, configs, cmake_extra_defines, + tensorrt_home, migraphx_home, path_to_protoc_exe, configs, cmake_extra_defines, args, cmake_extra_args) if args.clean: