From c588d5d13afcc3a4e1846d5c41ef59c214e6fa09 Mon Sep 17 00:00:00 2001 From: Shucai Xiao Date: Fri, 12 Mar 2021 09:51:08 -0600 Subject: [PATCH] Add rocm execution provider to prover_list (#6306) * code changes to add rocm ep to ep_list --- cmake/onnxruntime_providers.cmake | 6 +- .../Microsoft.ML.OnnxRuntime/NativeMethods.cs | 4 + .../SessionOptions.cs | 33 ++++++++ .../InferenceTest.cs | 56 +++++++++++++ dockerfiles/Dockerfile.rocm | 49 ++++++++++++ dockerfiles/scripts/install_rocm_deps.sh | 79 +++++++++++++++++++ .../core/session/onnxruntime_c_api.h | 17 ++++ .../core/session/onnxruntime_cxx_api.h | 1 + .../core/session/onnxruntime_cxx_inline.h | 5 ++ .../rocm/rocm_execution_provider_info.cc | 3 + .../rocm/rocm_execution_provider_info.h | 1 + .../providers/rocm/rocm_provider_factory.cc | 15 ++++ onnxruntime/core/session/onnxruntime_c_api.cc | 10 +++ onnxruntime/core/session/ort_apis.h | 2 + onnxruntime/test/onnx/main.cc | 18 ++++- onnxruntime/test/perftest/ort_test_session.cc | 11 +++ onnxruntime/test/providers/cpu/model_tests.cc | 5 ++ 17 files changed, 311 insertions(+), 4 deletions(-) create mode 100644 dockerfiles/Dockerfile.rocm create mode 100644 dockerfiles/scripts/install_rocm_deps.sh diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index a4ea17d7b4..95dcb0873d 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -918,11 +918,11 @@ if (onnxruntime_USE_ROCM) add_definitions(-DUSE_ROCM=1) # Add search paths for default hip installation - list(APPEND CMAKE_PREFIX_PATH ${onnxruntime_ROCM_HOME} ${onnxruntime_ROCM_HOME}/hip ${onnxruntime_ROCM_HOME}/hcc ${onnxruntime_ROCM_HOME}/miopen) + list(APPEND CMAKE_PREFIX_PATH ${onnxruntime_ROCM_HOME} ${onnxruntime_ROCM_HOME}/hip ${onnxruntime_ROCM_HOME}/hcc ${onnxruntime_ROCM_HOME}/miopen ${onnxruntime_ROCM_HOME}/hiprand ${onnxruntime_ROCM_HOME}/rocrand) set(CMAKE_MODULE_PATH "${onnxruntime_ROCM_HOME}/hip/cmake" ${CMAKE_MODULE_PATH}) find_package(HIP) - + find_package(hiprand REQUIRED) find_library(HIP_LIB amdhip64 REQUIRED) find_library(ROC_BLAS rocblas REQUIRED) find_library(MIOPEN_LIB MIOpen REQUIRED) @@ -1027,7 +1027,7 @@ if (onnxruntime_USE_ROCM) target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-undefined-var-template) endif() # During transition to separate hipFFT repo, put hipfft/include early - target_include_directories(onnxruntime_providers_rocm PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/include/hipcub ${onnxruntime_ROCM_HOME}/include/hiprand ${onnxruntime_ROCM_HOME}/include/rocrand) + target_include_directories(onnxruntime_providers_rocm PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/include/hipcub ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include) target_include_directories(onnxruntime_providers_rocm PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${MPI_INCLUDE_DIRS} ${SAFEINT_INCLUDE_DIR} ${ONNXRUNTIME_ROOT}/../cmake/external/eigen) if (onnxruntime_ENABLE_TRAINING) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs index 9fdea780df..0df4c77404 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs @@ -184,6 +184,7 @@ namespace Microsoft.ML.OnnxRuntime public IntPtr AddInitializer; public IntPtr CreateEnvWithCustomLoggerAndGlobalThreadPools; public IntPtr SessionOptionsAppendExecutionProvider_CUDA; + public IntPtr SessionOptionsAppendExecutionProvider_ROCM; public IntPtr SessionOptionsAppendExecutionProvider_OpenVINO; public IntPtr SetGlobalDenormalAsZero; public IntPtr CreateArenaCfg; @@ -560,6 +561,9 @@ namespace Microsoft.ML.OnnxRuntime [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_CUDA(IntPtr /*(OrtSessionOptions*) */ options, int device_id); + [DllImport(nativeLib, CharSet = charSet)] + public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_ROCM(IntPtr /*(OrtSessionOptions*) */ options, int device_id); + [DllImport(nativeLib, CharSet = charSet)] public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_DML(IntPtr /*(OrtSessionOptions*) */ options, int device_id); diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs index fb9c840bd5..6bc48a0d70 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs @@ -94,6 +94,30 @@ namespace Microsoft.ML.OnnxRuntime return options; } + /// + /// A helper method to construct a SessionOptions object for ROCM execution. + /// Use only if ROCM is installed and you have the onnxruntime package specific to this Execution Provider. + /// + /// A SessionsOptions() object configured for execution on deviceId=0 + public static SessionOptions MakeSessionOptionWithRocmProvider() + { + return MakeSessionOptionWithRocmProvider(0); + } + + /// + /// A helper method to construct a SessionOptions object for ROCM execution. + /// Use only if ROCM is installed and you have the onnxruntime package specific to this Execution Provider. + /// + /// + /// A SessionsOptions() object configured for execution on deviceId + public static SessionOptions MakeSessionOptionWithRocmProvider(int deviceId = 0) + { + SessionOptions options = new SessionOptions(); + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_ROCM(options.Handle, deviceId)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options.Handle, 1)); + return options; + } + #endregion #region ExecutionProviderAppends @@ -156,6 +180,15 @@ namespace Microsoft.ML.OnnxRuntime NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Tensorrt(handle, deviceId)); } + /// + /// Use only if you have the onnxruntime package specific to this Execution Provider. + /// + /// integer device ID + public void AppendExecutionProvider_ROCM(int deviceId) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_ROCM(handle, deviceId)); + } + /// /// Use only if you have the onnxruntime package specific to this Execution Provider. /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index c21c4ec607..ad401aba1d 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -96,6 +96,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests #if USE_CUDA opt.AppendExecutionProvider_CUDA(0); #endif +#if USE_ROCM + opt.AppendExecutionProvider_ROCM(0); +#endif #if USE_DML // Explicitly set dll probe path so that the (potentially) stale system DirectML.dll // doesn't get loaded by the test process when it is eventually delay loaded by onnruntime.dll @@ -185,6 +188,10 @@ namespace Microsoft.ML.OnnxRuntime.Tests # if USE_CUDA Assert.True(Array.Exists(providers, provider => provider == "CUDAExecutionProvider");); #endif +# if USE_ROCM + Assert.True(Array.Exists(providers, provider => provider == "ROCMExecutionProvider");); +#endif + } [Fact] @@ -2019,6 +2026,34 @@ namespace Microsoft.ML.OnnxRuntime.Tests } #endif +#if USE_ROCM + void TestROCMAllocatorInternal(InferenceSession session) + { + int device_id = 0; + using (var info_rocm = new OrtMemoryInfo(OrtMemoryInfo.allocatorROCM, OrtAllocatorType.ArenaAllocator, device_id, OrtMemType.Default)) + { + Assert.Equal("Rocm", info_rocm.Name); + Assert.Equal(device_id, info_rocm.Id); + Assert.Equal(OrtAllocatorType.ArenaAllocator, info_rocm.GetAllocatorType()); + Assert.Equal(OrtMemType.Default, info_rocm.GetMemoryType()); + + using (var allocator = new OrtAllocator(session, info_rocm)) + { + var alloc_info = allocator.Info; + Assert.True(info_rocm.Equals(alloc_info)); + + uint size = 1024; + OrtMemoryAllocation chunk = allocator.Allocate(size); + Assert.Equal(chunk.Size, size); + Assert.True(chunk.Info.Equals(alloc_info)); + chunk.Dispose(); + alloc_info.Dispose(); + } + } + } +#endif + + [Fact] private void TestAllocator() { @@ -2029,12 +2064,21 @@ namespace Microsoft.ML.OnnxRuntime.Tests #if USE_CUDA options.AppendExecutionProvider_CUDA(0); #endif + +#if USE_ROCM + options.AppendExecutionProvider_ROCM(0); +#endif + using (var session = new InferenceSession(modelPath, options)) { TestCPUAllocatorInternal(session); #if USE_CUDA TestCUDAAllocatorInternal(session); #endif +#if USE_ROCM + TestROCMAllocatorInternal(session); +#endif + } } } @@ -2294,6 +2338,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests #if USE_CUDA ,"OrtSessionOptionsAppendExecutionProvider_CUDA" #endif +#if USE_ROCM + ,"OrtSessionOptionsAppendExecutionProvider_ROCM" +#endif #if USE_DML ,"OrtSessionOptionsAppendExecutionProvider_DML" #endif @@ -2566,6 +2613,15 @@ namespace Microsoft.ML.OnnxRuntime.Tests { option.AppendExecutionProvider_CPU(1); } +#elif USE_ROCM + using (var option = (deviceId.HasValue) ? + SessionOptions.MakeSessionOptionWithRocmProvider(deviceId.Value) : + new SessionOptions()) + { + if(!deviceId.HasValue) + { + option.AppendExecutionProvider_CPU(1); + } #else using (var option = new SessionOptions()) { diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm new file mode 100644 index 0000000000..1fdc467180 --- /dev/null +++ b/dockerfiles/Dockerfile.rocm @@ -0,0 +1,49 @@ +# -------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------- +# Dockerfile to run ONNXRuntime with MIGraphX integration +#-------------------------------------------------------------------------- + +FROM ubuntu:18.04 + +ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime +ARG ONNXRUNTIME_BRANCH=master + +ENV DEBIAN_FRONTEND noninteractive +RUN apt-get clean && apt-get update && apt-get install -y locales +RUN locale-gen en_US.UTF-8 +RUN update-locale LANG=en_US.UTF-8 +ENV LC_ALL C.UTF-8 +ENV LANG C.UTF-8 + +# Install rocm +RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && \ + curl -sL http://repo.radeon.com/rocm/apt/debian/rocm.gpg.key | apt-key add - && \ + sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/4.0/ xenial main > /etc/apt/sources.list.d/rocm.list' + +RUN apt-get update &&\ + apt-get install -y --no-install-recommends sudo git bash build-essential cmake libelf1 rocm-dkms libpython3.6-dev python3-pip miopen-hip rocblas\ + libnuma-dev kmod half hipsparse rocfft hipblas + +# Install yapf +RUN pip3 install yapf==0.28.0 + +ENV PATH /opt/miniconda/bin:/code/cmake-3.14.3-Linux-x86_64/bin:${PATH} + +# Install dependencies +COPY ./scripts/install_rocm_deps.sh / +RUN chmod +x /install_rocm_deps.sh && /install_rocm_deps.sh && rm /install_rocm_deps.sh + +WORKDIR /code + +# Prepare onnxruntime repository & build onnxruntime +RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ + /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\ + cd onnxruntime &&\ + /bin/sh ./build.sh --config Release --build_wheel --update --build --parallel --cmake_extra_defines\ + ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_rocm --rocm_home=/opt/rocm &&\ + pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\ + cd .. &&\ + rm -rf onnxruntime cmake-3.14.3-Linux-x86_64 + diff --git a/dockerfiles/scripts/install_rocm_deps.sh b/dockerfiles/scripts/install_rocm_deps.sh new file mode 100644 index 0000000000..eed8125b74 --- /dev/null +++ b/dockerfiles/scripts/install_rocm_deps.sh @@ -0,0 +1,79 @@ +#!/bin/bash +prefix=/opt/rocm +DEBIAN_FRONTEND=noninteractive +apt-get update && apt-get install -y --no-install-recommends \ + wget \ + zip \ + ca-certificates \ + build-essential \ + curl \ + libcurl4-openssl-dev \ + libssl-dev \ + python3-dev + +# rocm-cmake +wget --quiet https://github.com/RadeonOpenCompute/rocm-cmake/archive/rocm-3.8.0.tar.gz +tar -xzvf rocm-3.8.0.tar.gz +rm rocm-3.8.0.tar.gz +cd rocm-cmake-rocm-3.8.0 +mkdir build +cd build +cmake -DCMAKE_INSTALL_PREFIX=$prefix .. +make -j8 +make install +cd ../.. +rm -rf rocm-cmake-rocm-3.8.0 + +# rccl +wget --quiet https://github.com/ROCmSoftwarePlatform/rccl/archive/rocm-4.0.0.tar.gz +tar -xzvf rocm-4.0.0.tar.gz +rm rocm-4.0.0.tar.gz +cd rccl-rocm-4.0.0 +mkdir build +cd build +CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. +make -j8 +make install +cd ../.. +rm -rf rccl-rocm-4.0.0 + +#rocrand +wget --quiet https://github.com/ROCmSoftwarePlatform/rocRAND/archive/rocm-4.0.0.tar.gz +tar -xzvf rocm-4.0.0.tar.gz +rm rocm-4.0.0.tar.gz +cd rocRAND-rocm-4.0.0 +mkdir build +cd build +CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. +make -j8 +make install +cd ../.. +rm -rf rocRAND-rocm-4.0.0 + +#hipcub +wget --quiet https://github.com/ROCmSoftwarePlatform/hipCUB/archive/rocm-4.0.0.tar.gz +tar -xzvf rocm-4.0.0.tar.gz +rm rocm-4.0.0.tar.gz +cd hipCUB-rocm-4.0.0 +mkdir build +cd build +CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. +make -j8 +make package +make install +cd ../.. +rm -rf hipCUB-rocm-4.0.0 + +#rocprim +wget --quiet https://github.com/ROCmSoftwarePlatform/rocPRIM/archive/rocm-4.0.0.tar.gz +tar -xzvf rocm-4.0.0.tar.gz +rm rocm-4.0.0.tar.gz +cd rocPRIM-rocm-4.0.0 +mkdir build +cd build +CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix .. +make -j8 +make install +cd ../.. +rm -rf rocPRIM-rocm-4.0.0 + diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 719a881a0e..e9c0fea093 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -275,6 +275,16 @@ typedef struct OrtCUDAProviderOptions { void* user_compute_stream; } OrtCUDAProviderOptions; +/// +/// Options for the ROCM provider that are passed to SessionOptionsAppendExecutionProvider_ROCM +/// +typedef struct OrtROCMProviderOptions { + int device_id; // hip device with id=0 as default device. + int miopen_conv_exhaustive_search; // miopen conv algo exhaustive search option + size_t hip_mem_limit; // default hip memory limitation to maximum finite value of size_t. + int arena_extend_strategy; // default area extend strategy to KNextPowerOfTwo. +} OrtROCMProviderOptions; + /// /// Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT /// @@ -1150,6 +1160,13 @@ struct OrtApi { ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptions* cuda_options); + /** + * Append ROCM execution provider to the session options + * If ROCM is not available (due to a non rocm enabled build), this function will return failure. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_ROCM, + _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options); + /** * Append OpenVINO execution provider to the session options * If OpenVINO is not available (due to the OpenVINO provider shared library or its dependencies not being installed), this function will fail. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 55ae18e270..d85ecd776d 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -325,6 +325,7 @@ struct SessionOptions : Base { SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val); SessionOptions& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); + SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index d27d9055f6..64199ac6c3 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -490,6 +490,11 @@ inline SessionOptions& SessionOptions::AppendExecutionProvider_CUDA(const OrtCUD return *this; } +inline SessionOptions& SessionOptions::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(p_, &provider_options)); + return *this; +} + inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) { ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(p_, &provider_options)); return *this; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc index 65521b7203..78654953d8 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc @@ -12,6 +12,7 @@ namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; constexpr const char* kMemLimit = "hip_mem_limit"; constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; +constexpr const char* kConvExhaustiveSearch = "conv_exhaustive_search"; } // namespace provider_option_names } // namespace rocm @@ -30,6 +31,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P // TODO validate info.device_id .AddAssignmentToReference(rocm::provider_option_names::kDeviceId, info.device_id) .AddAssignmentToReference(rocm::provider_option_names::kMemLimit, info.hip_mem_limit) + .AddAssignmentToReference(rocm::provider_option_names::kConvExhaustiveSearch, info.miopen_conv_exhaustive_search) .AddAssignmentToEnumReference( rocm::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) @@ -42,6 +44,7 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution const ProviderOptions options{ {rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, {rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.hip_mem_limit)}, + {rocm::provider_option_names::kConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)}, {rocm::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, }; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h index 3c2383a467..39e92a78ad 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h @@ -15,6 +15,7 @@ struct ROCMExecutionProviderInfo { OrtDevice::DeviceId device_id{0}; size_t hip_mem_limit{std::numeric_limits::max()}; ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; + bool miopen_conv_exhaustive_search{false}; bool do_copy_in_default_stream{true}; bool has_user_compute_stream{false}; void* user_compute_stream{nullptr}; diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc index 1e5ba9f46c..e8c439ce0d 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc @@ -12,6 +12,7 @@ #include "core/providers/rocm/rocm_execution_provider.h" #include "core/providers/rocm/rocm_execution_provider_info.h" #include "core/session/abi_session_options_impl.h" +#include "core/session/ort_apis.h" using namespace onnxruntime; @@ -46,3 +47,17 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_ROCM, _In_ OrtSessi return nullptr; } + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM, + _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options) { + ROCMExecutionProviderInfo info{}; + info.device_id = gsl::narrow(rocm_options->device_id); + info.hip_mem_limit = rocm_options->hip_mem_limit; + info.arena_extend_strategy = static_cast(rocm_options->arena_extend_strategy); + info.miopen_conv_exhaustive_search = rocm_options->miopen_conv_exhaustive_search; + + options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_ROCM(info)); + + return nullptr; +} + diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index a98af18f18..12174163e5 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1838,6 +1838,15 @@ ORT_API_STATUS_IMPL(OrtApis::GetCurrentGpuDeviceId, _In_ int* device_id) { } #endif +#ifndef USE_ROCM +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM, + _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options) { + ORT_UNUSED_PARAMETER(options); + ORT_UNUSED_PARAMETER(rocm_options); + return CreateStatus(ORT_FAIL, "ROCM execution provider is not enabled."); +} +#endif + #if defined(ORT_MINIMAL_BUILD) ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO, _In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options) { @@ -2098,6 +2107,7 @@ static constexpr OrtApi ort_api_1_to_8 = { &OrtApis::AddInitializer, &OrtApis::CreateEnvWithCustomLoggerAndGlobalThreadPools, &OrtApis::SessionOptionsAppendExecutionProvider_CUDA, + &OrtApis::SessionOptionsAppendExecutionProvider_ROCM, &OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO, &OrtApis::SetGlobalDenormalAsZero, &OrtApis::CreateArenaCfg, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index fe6e358771..f19b1d7293 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -248,6 +248,8 @@ ORT_API_STATUS_IMPL(AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ c ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptions* cuda_options); +ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_ROCM, + _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options); ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO, _In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options); ORT_API_STATUS_IMPL(SetGlobalDenormalAsZero, _Inout_ OrtThreadingOptions* options); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 474c5a1a27..ac7b3fc734 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -36,7 +36,7 @@ void usage() { "\t-v: verbose\n" "\t-n [test_case_name]: Specifies a single test case to run.\n" "\t-e [EXECUTION_PROVIDER]: EXECUTION_PROVIDER could be 'cpu', 'cuda', 'dnnl', 'tensorrt', " - "'openvino', 'nuphar', 'migraphx', 'acl', 'armnn', 'nnapi' or 'coreml'. " + "'openvino', 'nuphar', 'rocm', 'migraphx', 'acl', 'armnn', 'nnapi' or 'coreml'. " "Default: 'cpu'.\n" "\t-p: Pause after launch, can attach debugger and continue\n" "\t-x: Use parallel executor, default (without -x): sequential executor.\n" @@ -102,6 +102,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { bool enable_dml = false; bool enable_acl = false; bool enable_armnn = false; + bool enable_rocm = false; bool enable_migraphx = false; int device_id = 0; GraphOptimizationLevel graph_optimization_level = ORT_ENABLE_ALL; @@ -174,6 +175,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) { enable_acl = true; } else if (!CompareCString(optarg, ORT_TSTR("armnn"))) { enable_armnn = true; + } else if (!CompareCString(optarg, ORT_TSTR("rocm"))) { + enable_rocm = true; } else if (!CompareCString(optarg, ORT_TSTR("migraphx"))) { enable_migraphx = true; } else { @@ -412,6 +415,19 @@ int real_main(int argc, char* argv[], Ort::Env& env) { #else fprintf(stderr, "ArmNN is not supported in this build\n"); return -1; +#endif + } + if (enable_rocm) { +#ifdef USE_ROCM + OrtROCMProviderOptions rocm_options{ + 0, + 0, + std::numeric_limits::max(), + 0}; + sf.AppendExecutionProvider_ROCM(rocm_options); +#else + fprintf(stderr, "ROCM is not supported in this build"); + return -1; #endif } if (enable_migraphx) { diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 148409e660..fcee88cbdf 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -162,6 +162,17 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device performance_test_config.run_config.enable_cpu_mem_arena ? 1 : 0)); #else ORT_THROW("ArmNN is not supported in this build\n"); +#endif + } else if (provider_name == onnxruntime::kRocmExecutionProvider) { +#ifdef USE_ROCM + OrtROCMProviderOptions rocm_options{ + 0, + 0, + std::numeric_limits::max(), + 0}; + session_options.AppendExecutionProvider_ROCM(rocm_options); +#else + ORT_THROW("ROCM is not supported in this build\n"); #endif } else if (provider_name == onnxruntime::kMIGraphXExecutionProvider) { #ifdef USE_MIGRAPHX diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 7a278a4b55..b9e1770718 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -512,6 +512,8 @@ TEST_P(ModelTest, Run) { InferenceSession session_object(so, (**ort_env).GetEnvironment()); if (provider_name == "cuda") { ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider())); + } else if (provider_name == "rocm") { + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider())); } else if (provider_name == "dnnl") { ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultDnnlExecutionProvider())); } else if (provider_name == "nuphar") { @@ -631,6 +633,9 @@ TEST_P(ModelTest, Run) { #ifdef USE_CUDA provider_names.push_back(ORT_TSTR("cuda")); #endif +#ifdef USE_ROCM + provider_names.push_back(ORT_TSTR("rocm")); +#endif #ifdef USE_DNNL provider_names.push_back(ORT_TSTR("dnnl")); #endif