diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake
index a4ea17d7b4..95dcb0873d 100644
--- a/cmake/onnxruntime_providers.cmake
+++ b/cmake/onnxruntime_providers.cmake
@@ -918,11 +918,11 @@ if (onnxruntime_USE_ROCM)
add_definitions(-DUSE_ROCM=1)
# Add search paths for default hip installation
- list(APPEND CMAKE_PREFIX_PATH ${onnxruntime_ROCM_HOME} ${onnxruntime_ROCM_HOME}/hip ${onnxruntime_ROCM_HOME}/hcc ${onnxruntime_ROCM_HOME}/miopen)
+ list(APPEND CMAKE_PREFIX_PATH ${onnxruntime_ROCM_HOME} ${onnxruntime_ROCM_HOME}/hip ${onnxruntime_ROCM_HOME}/hcc ${onnxruntime_ROCM_HOME}/miopen ${onnxruntime_ROCM_HOME}/hiprand ${onnxruntime_ROCM_HOME}/rocrand)
set(CMAKE_MODULE_PATH "${onnxruntime_ROCM_HOME}/hip/cmake" ${CMAKE_MODULE_PATH})
find_package(HIP)
-
+ find_package(hiprand REQUIRED)
find_library(HIP_LIB amdhip64 REQUIRED)
find_library(ROC_BLAS rocblas REQUIRED)
find_library(MIOPEN_LIB MIOpen REQUIRED)
@@ -1027,7 +1027,7 @@ if (onnxruntime_USE_ROCM)
target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-undefined-var-template)
endif()
# During transition to separate hipFFT repo, put hipfft/include early
- target_include_directories(onnxruntime_providers_rocm PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/include/hipcub ${onnxruntime_ROCM_HOME}/include/hiprand ${onnxruntime_ROCM_HOME}/include/rocrand)
+ target_include_directories(onnxruntime_providers_rocm PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/include/hipcub ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include)
target_include_directories(onnxruntime_providers_rocm PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${MPI_INCLUDE_DIRS} ${SAFEINT_INCLUDE_DIR} ${ONNXRUNTIME_ROOT}/../cmake/external/eigen)
if (onnxruntime_ENABLE_TRAINING)
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
index 9fdea780df..0df4c77404 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
@@ -184,6 +184,7 @@ namespace Microsoft.ML.OnnxRuntime
public IntPtr AddInitializer;
public IntPtr CreateEnvWithCustomLoggerAndGlobalThreadPools;
public IntPtr SessionOptionsAppendExecutionProvider_CUDA;
+ public IntPtr SessionOptionsAppendExecutionProvider_ROCM;
public IntPtr SessionOptionsAppendExecutionProvider_OpenVINO;
public IntPtr SetGlobalDenormalAsZero;
public IntPtr CreateArenaCfg;
@@ -560,6 +561,9 @@ namespace Microsoft.ML.OnnxRuntime
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_CUDA(IntPtr /*(OrtSessionOptions*) */ options, int device_id);
+ [DllImport(nativeLib, CharSet = charSet)]
+ public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_ROCM(IntPtr /*(OrtSessionOptions*) */ options, int device_id);
+
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_DML(IntPtr /*(OrtSessionOptions*) */ options, int device_id);
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
index fb9c840bd5..6bc48a0d70 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
@@ -94,6 +94,30 @@ namespace Microsoft.ML.OnnxRuntime
return options;
}
+ ///
+ /// A helper method to construct a SessionOptions object for ROCM execution.
+ /// Use only if ROCM is installed and you have the onnxruntime package specific to this Execution Provider.
+ ///
+ /// A SessionsOptions() object configured for execution on deviceId=0
+ public static SessionOptions MakeSessionOptionWithRocmProvider()
+ {
+ return MakeSessionOptionWithRocmProvider(0);
+ }
+
+ ///
+ /// A helper method to construct a SessionOptions object for ROCM execution.
+ /// Use only if ROCM is installed and you have the onnxruntime package specific to this Execution Provider.
+ ///
+ ///
+ /// A SessionsOptions() object configured for execution on deviceId
+ public static SessionOptions MakeSessionOptionWithRocmProvider(int deviceId = 0)
+ {
+ SessionOptions options = new SessionOptions();
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_ROCM(options.Handle, deviceId));
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options.Handle, 1));
+ return options;
+ }
+
#endregion
#region ExecutionProviderAppends
@@ -156,6 +180,15 @@ namespace Microsoft.ML.OnnxRuntime
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Tensorrt(handle, deviceId));
}
+ ///
+ /// Use only if you have the onnxruntime package specific to this Execution Provider.
+ ///
+ /// integer device ID
+ public void AppendExecutionProvider_ROCM(int deviceId)
+ {
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_ROCM(handle, deviceId));
+ }
+
///
/// Use only if you have the onnxruntime package specific to this Execution Provider.
///
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
index c21c4ec607..ad401aba1d 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -96,6 +96,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
#if USE_CUDA
opt.AppendExecutionProvider_CUDA(0);
#endif
+#if USE_ROCM
+ opt.AppendExecutionProvider_ROCM(0);
+#endif
#if USE_DML
// Explicitly set dll probe path so that the (potentially) stale system DirectML.dll
// doesn't get loaded by the test process when it is eventually delay loaded by onnruntime.dll
@@ -185,6 +188,10 @@ namespace Microsoft.ML.OnnxRuntime.Tests
# if USE_CUDA
Assert.True(Array.Exists(providers, provider => provider == "CUDAExecutionProvider"););
#endif
+# if USE_ROCM
+ Assert.True(Array.Exists(providers, provider => provider == "ROCMExecutionProvider"););
+#endif
+
}
[Fact]
@@ -2019,6 +2026,34 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
#endif
+#if USE_ROCM
+ void TestROCMAllocatorInternal(InferenceSession session)
+ {
+ int device_id = 0;
+ using (var info_rocm = new OrtMemoryInfo(OrtMemoryInfo.allocatorROCM, OrtAllocatorType.ArenaAllocator, device_id, OrtMemType.Default))
+ {
+ Assert.Equal("Rocm", info_rocm.Name);
+ Assert.Equal(device_id, info_rocm.Id);
+ Assert.Equal(OrtAllocatorType.ArenaAllocator, info_rocm.GetAllocatorType());
+ Assert.Equal(OrtMemType.Default, info_rocm.GetMemoryType());
+
+ using (var allocator = new OrtAllocator(session, info_rocm))
+ {
+ var alloc_info = allocator.Info;
+ Assert.True(info_rocm.Equals(alloc_info));
+
+ uint size = 1024;
+ OrtMemoryAllocation chunk = allocator.Allocate(size);
+ Assert.Equal(chunk.Size, size);
+ Assert.True(chunk.Info.Equals(alloc_info));
+ chunk.Dispose();
+ alloc_info.Dispose();
+ }
+ }
+ }
+#endif
+
+
[Fact]
private void TestAllocator()
{
@@ -2029,12 +2064,21 @@ namespace Microsoft.ML.OnnxRuntime.Tests
#if USE_CUDA
options.AppendExecutionProvider_CUDA(0);
#endif
+
+#if USE_ROCM
+ options.AppendExecutionProvider_ROCM(0);
+#endif
+
using (var session = new InferenceSession(modelPath, options))
{
TestCPUAllocatorInternal(session);
#if USE_CUDA
TestCUDAAllocatorInternal(session);
#endif
+#if USE_ROCM
+ TestROCMAllocatorInternal(session);
+#endif
+
}
}
}
@@ -2294,6 +2338,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
#if USE_CUDA
,"OrtSessionOptionsAppendExecutionProvider_CUDA"
#endif
+#if USE_ROCM
+ ,"OrtSessionOptionsAppendExecutionProvider_ROCM"
+#endif
#if USE_DML
,"OrtSessionOptionsAppendExecutionProvider_DML"
#endif
@@ -2566,6 +2613,15 @@ namespace Microsoft.ML.OnnxRuntime.Tests
{
option.AppendExecutionProvider_CPU(1);
}
+#elif USE_ROCM
+ using (var option = (deviceId.HasValue) ?
+ SessionOptions.MakeSessionOptionWithRocmProvider(deviceId.Value) :
+ new SessionOptions())
+ {
+ if(!deviceId.HasValue)
+ {
+ option.AppendExecutionProvider_CPU(1);
+ }
#else
using (var option = new SessionOptions())
{
diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm
new file mode 100644
index 0000000000..1fdc467180
--- /dev/null
+++ b/dockerfiles/Dockerfile.rocm
@@ -0,0 +1,49 @@
+# --------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------
+# Dockerfile to run ONNXRuntime with MIGraphX integration
+#--------------------------------------------------------------------------
+
+FROM ubuntu:18.04
+
+ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
+ARG ONNXRUNTIME_BRANCH=master
+
+ENV DEBIAN_FRONTEND noninteractive
+RUN apt-get clean && apt-get update && apt-get install -y locales
+RUN locale-gen en_US.UTF-8
+RUN update-locale LANG=en_US.UTF-8
+ENV LC_ALL C.UTF-8
+ENV LANG C.UTF-8
+
+# Install rocm
+RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && \
+ curl -sL http://repo.radeon.com/rocm/apt/debian/rocm.gpg.key | apt-key add - && \
+ sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/4.0/ xenial main > /etc/apt/sources.list.d/rocm.list'
+
+RUN apt-get update &&\
+ apt-get install -y --no-install-recommends sudo git bash build-essential cmake libelf1 rocm-dkms libpython3.6-dev python3-pip miopen-hip rocblas\
+ libnuma-dev kmod half hipsparse rocfft hipblas
+
+# Install yapf
+RUN pip3 install yapf==0.28.0
+
+ENV PATH /opt/miniconda/bin:/code/cmake-3.14.3-Linux-x86_64/bin:${PATH}
+
+# Install dependencies
+COPY ./scripts/install_rocm_deps.sh /
+RUN chmod +x /install_rocm_deps.sh && /install_rocm_deps.sh && rm /install_rocm_deps.sh
+
+WORKDIR /code
+
+# Prepare onnxruntime repository & build onnxruntime
+RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\
+ /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\
+ cd onnxruntime &&\
+ /bin/sh ./build.sh --config Release --build_wheel --update --build --parallel --cmake_extra_defines\
+ ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_rocm --rocm_home=/opt/rocm &&\
+ pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\
+ cd .. &&\
+ rm -rf onnxruntime cmake-3.14.3-Linux-x86_64
+
diff --git a/dockerfiles/scripts/install_rocm_deps.sh b/dockerfiles/scripts/install_rocm_deps.sh
new file mode 100644
index 0000000000..eed8125b74
--- /dev/null
+++ b/dockerfiles/scripts/install_rocm_deps.sh
@@ -0,0 +1,79 @@
+#!/bin/bash
+prefix=/opt/rocm
+DEBIAN_FRONTEND=noninteractive
+apt-get update && apt-get install -y --no-install-recommends \
+ wget \
+ zip \
+ ca-certificates \
+ build-essential \
+ curl \
+ libcurl4-openssl-dev \
+ libssl-dev \
+ python3-dev
+
+# rocm-cmake
+wget --quiet https://github.com/RadeonOpenCompute/rocm-cmake/archive/rocm-3.8.0.tar.gz
+tar -xzvf rocm-3.8.0.tar.gz
+rm rocm-3.8.0.tar.gz
+cd rocm-cmake-rocm-3.8.0
+mkdir build
+cd build
+cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
+make -j8
+make install
+cd ../..
+rm -rf rocm-cmake-rocm-3.8.0
+
+# rccl
+wget --quiet https://github.com/ROCmSoftwarePlatform/rccl/archive/rocm-4.0.0.tar.gz
+tar -xzvf rocm-4.0.0.tar.gz
+rm rocm-4.0.0.tar.gz
+cd rccl-rocm-4.0.0
+mkdir build
+cd build
+CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
+make -j8
+make install
+cd ../..
+rm -rf rccl-rocm-4.0.0
+
+#rocrand
+wget --quiet https://github.com/ROCmSoftwarePlatform/rocRAND/archive/rocm-4.0.0.tar.gz
+tar -xzvf rocm-4.0.0.tar.gz
+rm rocm-4.0.0.tar.gz
+cd rocRAND-rocm-4.0.0
+mkdir build
+cd build
+CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
+make -j8
+make install
+cd ../..
+rm -rf rocRAND-rocm-4.0.0
+
+#hipcub
+wget --quiet https://github.com/ROCmSoftwarePlatform/hipCUB/archive/rocm-4.0.0.tar.gz
+tar -xzvf rocm-4.0.0.tar.gz
+rm rocm-4.0.0.tar.gz
+cd hipCUB-rocm-4.0.0
+mkdir build
+cd build
+CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
+make -j8
+make package
+make install
+cd ../..
+rm -rf hipCUB-rocm-4.0.0
+
+#rocprim
+wget --quiet https://github.com/ROCmSoftwarePlatform/rocPRIM/archive/rocm-4.0.0.tar.gz
+tar -xzvf rocm-4.0.0.tar.gz
+rm rocm-4.0.0.tar.gz
+cd rocPRIM-rocm-4.0.0
+mkdir build
+cd build
+CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
+make -j8
+make install
+cd ../..
+rm -rf rocPRIM-rocm-4.0.0
+
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 719a881a0e..e9c0fea093 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -275,6 +275,16 @@ typedef struct OrtCUDAProviderOptions {
void* user_compute_stream;
} OrtCUDAProviderOptions;
+///
+/// Options for the ROCM provider that are passed to SessionOptionsAppendExecutionProvider_ROCM
+///
+typedef struct OrtROCMProviderOptions {
+ int device_id; // hip device with id=0 as default device.
+ int miopen_conv_exhaustive_search; // miopen conv algo exhaustive search option
+ size_t hip_mem_limit; // default hip memory limitation to maximum finite value of size_t.
+ int arena_extend_strategy; // default area extend strategy to KNextPowerOfTwo.
+} OrtROCMProviderOptions;
+
///
/// Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT
///
@@ -1150,6 +1160,13 @@ struct OrtApi {
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CUDA,
_In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptions* cuda_options);
+ /**
+ * Append ROCM execution provider to the session options
+ * If ROCM is not available (due to a non rocm enabled build), this function will return failure.
+ */
+ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_ROCM,
+ _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options);
+
/**
* Append OpenVINO execution provider to the session options
* If OpenVINO is not available (due to the OpenVINO provider shared library or its dependencies not being installed), this function will fail.
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 55ae18e270..d85ecd776d 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -325,6 +325,7 @@ struct SessionOptions : Base {
SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);
SessionOptions& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options);
+ SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options);
SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options);
SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options);
};
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index d27d9055f6..64199ac6c3 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -490,6 +490,11 @@ inline SessionOptions& SessionOptions::AppendExecutionProvider_CUDA(const OrtCUD
return *this;
}
+inline SessionOptions& SessionOptions::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(p_, &provider_options));
+ return *this;
+}
+
inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(p_, &provider_options));
return *this;
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc
index 65521b7203..78654953d8 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc
@@ -12,6 +12,7 @@ namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kMemLimit = "hip_mem_limit";
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
+constexpr const char* kConvExhaustiveSearch = "conv_exhaustive_search";
} // namespace provider_option_names
} // namespace rocm
@@ -30,6 +31,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
// TODO validate info.device_id
.AddAssignmentToReference(rocm::provider_option_names::kDeviceId, info.device_id)
.AddAssignmentToReference(rocm::provider_option_names::kMemLimit, info.hip_mem_limit)
+ .AddAssignmentToReference(rocm::provider_option_names::kConvExhaustiveSearch, info.miopen_conv_exhaustive_search)
.AddAssignmentToEnumReference(
rocm::provider_option_names::kArenaExtendStrategy,
arena_extend_strategy_mapping, info.arena_extend_strategy)
@@ -42,6 +44,7 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution
const ProviderOptions options{
{rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.hip_mem_limit)},
+ {rocm::provider_option_names::kConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)},
{rocm::provider_option_names::kArenaExtendStrategy,
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
};
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h
index 3c2383a467..39e92a78ad 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h
@@ -15,6 +15,7 @@ struct ROCMExecutionProviderInfo {
OrtDevice::DeviceId device_id{0};
size_t hip_mem_limit{std::numeric_limits::max()};
ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo};
+ bool miopen_conv_exhaustive_search{false};
bool do_copy_in_default_stream{true};
bool has_user_compute_stream{false};
void* user_compute_stream{nullptr};
diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc
index 1e5ba9f46c..e8c439ce0d 100644
--- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc
+++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc
@@ -12,6 +12,7 @@
#include "core/providers/rocm/rocm_execution_provider.h"
#include "core/providers/rocm/rocm_execution_provider_info.h"
#include "core/session/abi_session_options_impl.h"
+#include "core/session/ort_apis.h"
using namespace onnxruntime;
@@ -46,3 +47,17 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_ROCM, _In_ OrtSessi
return nullptr;
}
+
+ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM,
+ _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options) {
+ ROCMExecutionProviderInfo info{};
+ info.device_id = gsl::narrow(rocm_options->device_id);
+ info.hip_mem_limit = rocm_options->hip_mem_limit;
+ info.arena_extend_strategy = static_cast(rocm_options->arena_extend_strategy);
+ info.miopen_conv_exhaustive_search = rocm_options->miopen_conv_exhaustive_search;
+
+ options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_ROCM(info));
+
+ return nullptr;
+}
+
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index a98af18f18..12174163e5 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -1838,6 +1838,15 @@ ORT_API_STATUS_IMPL(OrtApis::GetCurrentGpuDeviceId, _In_ int* device_id) {
}
#endif
+#ifndef USE_ROCM
+ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM,
+ _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options) {
+ ORT_UNUSED_PARAMETER(options);
+ ORT_UNUSED_PARAMETER(rocm_options);
+ return CreateStatus(ORT_FAIL, "ROCM execution provider is not enabled.");
+}
+#endif
+
#if defined(ORT_MINIMAL_BUILD)
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO,
_In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options) {
@@ -2098,6 +2107,7 @@ static constexpr OrtApi ort_api_1_to_8 = {
&OrtApis::AddInitializer,
&OrtApis::CreateEnvWithCustomLoggerAndGlobalThreadPools,
&OrtApis::SessionOptionsAppendExecutionProvider_CUDA,
+ &OrtApis::SessionOptionsAppendExecutionProvider_ROCM,
&OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO,
&OrtApis::SetGlobalDenormalAsZero,
&OrtApis::CreateArenaCfg,
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index fe6e358771..f19b1d7293 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -248,6 +248,8 @@ ORT_API_STATUS_IMPL(AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ c
ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_CUDA,
_In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptions* cuda_options);
+ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_ROCM,
+ _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options);
ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO,
_In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options);
ORT_API_STATUS_IMPL(SetGlobalDenormalAsZero, _Inout_ OrtThreadingOptions* options);
diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc
index 474c5a1a27..ac7b3fc734 100644
--- a/onnxruntime/test/onnx/main.cc
+++ b/onnxruntime/test/onnx/main.cc
@@ -36,7 +36,7 @@ void usage() {
"\t-v: verbose\n"
"\t-n [test_case_name]: Specifies a single test case to run.\n"
"\t-e [EXECUTION_PROVIDER]: EXECUTION_PROVIDER could be 'cpu', 'cuda', 'dnnl', 'tensorrt', "
- "'openvino', 'nuphar', 'migraphx', 'acl', 'armnn', 'nnapi' or 'coreml'. "
+ "'openvino', 'nuphar', 'rocm', 'migraphx', 'acl', 'armnn', 'nnapi' or 'coreml'. "
"Default: 'cpu'.\n"
"\t-p: Pause after launch, can attach debugger and continue\n"
"\t-x: Use parallel executor, default (without -x): sequential executor.\n"
@@ -102,6 +102,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
bool enable_dml = false;
bool enable_acl = false;
bool enable_armnn = false;
+ bool enable_rocm = false;
bool enable_migraphx = false;
int device_id = 0;
GraphOptimizationLevel graph_optimization_level = ORT_ENABLE_ALL;
@@ -174,6 +175,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
enable_acl = true;
} else if (!CompareCString(optarg, ORT_TSTR("armnn"))) {
enable_armnn = true;
+ } else if (!CompareCString(optarg, ORT_TSTR("rocm"))) {
+ enable_rocm = true;
} else if (!CompareCString(optarg, ORT_TSTR("migraphx"))) {
enable_migraphx = true;
} else {
@@ -412,6 +415,19 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
#else
fprintf(stderr, "ArmNN is not supported in this build\n");
return -1;
+#endif
+ }
+ if (enable_rocm) {
+#ifdef USE_ROCM
+ OrtROCMProviderOptions rocm_options{
+ 0,
+ 0,
+ std::numeric_limits::max(),
+ 0};
+ sf.AppendExecutionProvider_ROCM(rocm_options);
+#else
+ fprintf(stderr, "ROCM is not supported in this build");
+ return -1;
#endif
}
if (enable_migraphx) {
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index 148409e660..fcee88cbdf 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -162,6 +162,17 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
performance_test_config.run_config.enable_cpu_mem_arena ? 1 : 0));
#else
ORT_THROW("ArmNN is not supported in this build\n");
+#endif
+ } else if (provider_name == onnxruntime::kRocmExecutionProvider) {
+#ifdef USE_ROCM
+ OrtROCMProviderOptions rocm_options{
+ 0,
+ 0,
+ std::numeric_limits::max(),
+ 0};
+ session_options.AppendExecutionProvider_ROCM(rocm_options);
+#else
+ ORT_THROW("ROCM is not supported in this build\n");
#endif
} else if (provider_name == onnxruntime::kMIGraphXExecutionProvider) {
#ifdef USE_MIGRAPHX
diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc
index 7a278a4b55..b9e1770718 100644
--- a/onnxruntime/test/providers/cpu/model_tests.cc
+++ b/onnxruntime/test/providers/cpu/model_tests.cc
@@ -512,6 +512,8 @@ TEST_P(ModelTest, Run) {
InferenceSession session_object(so, (**ort_env).GetEnvironment());
if (provider_name == "cuda") {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
+ } else if (provider_name == "rocm") {
+ ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider()));
} else if (provider_name == "dnnl") {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultDnnlExecutionProvider()));
} else if (provider_name == "nuphar") {
@@ -631,6 +633,9 @@ TEST_P(ModelTest, Run) {
#ifdef USE_CUDA
provider_names.push_back(ORT_TSTR("cuda"));
#endif
+#ifdef USE_ROCM
+ provider_names.push_back(ORT_TSTR("rocm"));
+#endif
#ifdef USE_DNNL
provider_names.push_back(ORT_TSTR("dnnl"));
#endif