Add rocm execution provider to prover_list (#6306)

* code changes to add rocm ep to ep_list
This commit is contained in:
Shucai Xiao 2021-03-12 09:51:08 -06:00 committed by GitHub
parent 031587814b
commit c588d5d13a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 311 additions and 4 deletions

View file

@ -918,11 +918,11 @@ if (onnxruntime_USE_ROCM)
add_definitions(-DUSE_ROCM=1)
# Add search paths for default hip installation
list(APPEND CMAKE_PREFIX_PATH ${onnxruntime_ROCM_HOME} ${onnxruntime_ROCM_HOME}/hip ${onnxruntime_ROCM_HOME}/hcc ${onnxruntime_ROCM_HOME}/miopen)
list(APPEND CMAKE_PREFIX_PATH ${onnxruntime_ROCM_HOME} ${onnxruntime_ROCM_HOME}/hip ${onnxruntime_ROCM_HOME}/hcc ${onnxruntime_ROCM_HOME}/miopen ${onnxruntime_ROCM_HOME}/hiprand ${onnxruntime_ROCM_HOME}/rocrand)
set(CMAKE_MODULE_PATH "${onnxruntime_ROCM_HOME}/hip/cmake" ${CMAKE_MODULE_PATH})
find_package(HIP)
find_package(hiprand REQUIRED)
find_library(HIP_LIB amdhip64 REQUIRED)
find_library(ROC_BLAS rocblas REQUIRED)
find_library(MIOPEN_LIB MIOpen REQUIRED)
@ -1027,7 +1027,7 @@ if (onnxruntime_USE_ROCM)
target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-undefined-var-template)
endif()
# During transition to separate hipFFT repo, put hipfft/include early
target_include_directories(onnxruntime_providers_rocm PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/include/hipcub ${onnxruntime_ROCM_HOME}/include/hiprand ${onnxruntime_ROCM_HOME}/include/rocrand)
target_include_directories(onnxruntime_providers_rocm PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/include/hipcub ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include)
target_include_directories(onnxruntime_providers_rocm PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${MPI_INCLUDE_DIRS} ${SAFEINT_INCLUDE_DIR} ${ONNXRUNTIME_ROOT}/../cmake/external/eigen)
if (onnxruntime_ENABLE_TRAINING)

View file

@ -184,6 +184,7 @@ namespace Microsoft.ML.OnnxRuntime
public IntPtr AddInitializer;
public IntPtr CreateEnvWithCustomLoggerAndGlobalThreadPools;
public IntPtr SessionOptionsAppendExecutionProvider_CUDA;
public IntPtr SessionOptionsAppendExecutionProvider_ROCM;
public IntPtr SessionOptionsAppendExecutionProvider_OpenVINO;
public IntPtr SetGlobalDenormalAsZero;
public IntPtr CreateArenaCfg;
@ -560,6 +561,9 @@ namespace Microsoft.ML.OnnxRuntime
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_CUDA(IntPtr /*(OrtSessionOptions*) */ options, int device_id);
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_ROCM(IntPtr /*(OrtSessionOptions*) */ options, int device_id);
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_DML(IntPtr /*(OrtSessionOptions*) */ options, int device_id);

View file

@ -94,6 +94,30 @@ namespace Microsoft.ML.OnnxRuntime
return options;
}
/// <summary>
/// A helper method to construct a SessionOptions object for ROCM execution.
/// Use only if ROCM is installed and you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <returns>A SessionsOptions() object configured for execution on deviceId=0</returns>
public static SessionOptions MakeSessionOptionWithRocmProvider()
{
return MakeSessionOptionWithRocmProvider(0);
}
/// <summary>
/// A helper method to construct a SessionOptions object for ROCM execution.
/// Use only if ROCM is installed and you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="deviceId"></param>
/// <returns>A SessionsOptions() object configured for execution on deviceId</returns>
public static SessionOptions MakeSessionOptionWithRocmProvider(int deviceId = 0)
{
SessionOptions options = new SessionOptions();
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_ROCM(options.Handle, deviceId));
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options.Handle, 1));
return options;
}
#endregion
#region ExecutionProviderAppends
@ -156,6 +180,15 @@ namespace Microsoft.ML.OnnxRuntime
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Tensorrt(handle, deviceId));
}
/// <summary>
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="deviceId">integer device ID</param>
public void AppendExecutionProvider_ROCM(int deviceId)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_ROCM(handle, deviceId));
}
/// <summary>
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>

View file

@ -96,6 +96,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
#if USE_CUDA
opt.AppendExecutionProvider_CUDA(0);
#endif
#if USE_ROCM
opt.AppendExecutionProvider_ROCM(0);
#endif
#if USE_DML
// Explicitly set dll probe path so that the (potentially) stale system DirectML.dll
// doesn't get loaded by the test process when it is eventually delay loaded by onnruntime.dll
@ -185,6 +188,10 @@ namespace Microsoft.ML.OnnxRuntime.Tests
# if USE_CUDA
Assert.True(Array.Exists(providers, provider => provider == "CUDAExecutionProvider"););
#endif
# if USE_ROCM
Assert.True(Array.Exists(providers, provider => provider == "ROCMExecutionProvider"););
#endif
}
[Fact]
@ -2019,6 +2026,34 @@ namespace Microsoft.ML.OnnxRuntime.Tests
}
#endif
#if USE_ROCM
void TestROCMAllocatorInternal(InferenceSession session)
{
int device_id = 0;
using (var info_rocm = new OrtMemoryInfo(OrtMemoryInfo.allocatorROCM, OrtAllocatorType.ArenaAllocator, device_id, OrtMemType.Default))
{
Assert.Equal("Rocm", info_rocm.Name);
Assert.Equal(device_id, info_rocm.Id);
Assert.Equal(OrtAllocatorType.ArenaAllocator, info_rocm.GetAllocatorType());
Assert.Equal(OrtMemType.Default, info_rocm.GetMemoryType());
using (var allocator = new OrtAllocator(session, info_rocm))
{
var alloc_info = allocator.Info;
Assert.True(info_rocm.Equals(alloc_info));
uint size = 1024;
OrtMemoryAllocation chunk = allocator.Allocate(size);
Assert.Equal(chunk.Size, size);
Assert.True(chunk.Info.Equals(alloc_info));
chunk.Dispose();
alloc_info.Dispose();
}
}
}
#endif
[Fact]
private void TestAllocator()
{
@ -2029,12 +2064,21 @@ namespace Microsoft.ML.OnnxRuntime.Tests
#if USE_CUDA
options.AppendExecutionProvider_CUDA(0);
#endif
#if USE_ROCM
options.AppendExecutionProvider_ROCM(0);
#endif
using (var session = new InferenceSession(modelPath, options))
{
TestCPUAllocatorInternal(session);
#if USE_CUDA
TestCUDAAllocatorInternal(session);
#endif
#if USE_ROCM
TestROCMAllocatorInternal(session);
#endif
}
}
}
@ -2294,6 +2338,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
#if USE_CUDA
,"OrtSessionOptionsAppendExecutionProvider_CUDA"
#endif
#if USE_ROCM
,"OrtSessionOptionsAppendExecutionProvider_ROCM"
#endif
#if USE_DML
,"OrtSessionOptionsAppendExecutionProvider_DML"
#endif
@ -2566,6 +2613,15 @@ namespace Microsoft.ML.OnnxRuntime.Tests
{
option.AppendExecutionProvider_CPU(1);
}
#elif USE_ROCM
using (var option = (deviceId.HasValue) ?
SessionOptions.MakeSessionOptionWithRocmProvider(deviceId.Value) :
new SessionOptions())
{
if(!deviceId.HasValue)
{
option.AppendExecutionProvider_CPU(1);
}
#else
using (var option = new SessionOptions())
{

View file

@ -0,0 +1,49 @@
# --------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------
# Dockerfile to run ONNXRuntime with MIGraphX integration
#--------------------------------------------------------------------------
FROM ubuntu:18.04
ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
ARG ONNXRUNTIME_BRANCH=master
ENV DEBIAN_FRONTEND noninteractive
RUN apt-get clean && apt-get update && apt-get install -y locales
RUN locale-gen en_US.UTF-8
RUN update-locale LANG=en_US.UTF-8
ENV LC_ALL C.UTF-8
ENV LANG C.UTF-8
# Install rocm
RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && \
curl -sL http://repo.radeon.com/rocm/apt/debian/rocm.gpg.key | apt-key add - && \
sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/4.0/ xenial main > /etc/apt/sources.list.d/rocm.list'
RUN apt-get update &&\
apt-get install -y --no-install-recommends sudo git bash build-essential cmake libelf1 rocm-dkms libpython3.6-dev python3-pip miopen-hip rocblas\
libnuma-dev kmod half hipsparse rocfft hipblas
# Install yapf
RUN pip3 install yapf==0.28.0
ENV PATH /opt/miniconda/bin:/code/cmake-3.14.3-Linux-x86_64/bin:${PATH}
# Install dependencies
COPY ./scripts/install_rocm_deps.sh /
RUN chmod +x /install_rocm_deps.sh && /install_rocm_deps.sh && rm /install_rocm_deps.sh
WORKDIR /code
# Prepare onnxruntime repository & build onnxruntime
RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\
/bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\
cd onnxruntime &&\
/bin/sh ./build.sh --config Release --build_wheel --update --build --parallel --cmake_extra_defines\
ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_rocm --rocm_home=/opt/rocm &&\
pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\
cd .. &&\
rm -rf onnxruntime cmake-3.14.3-Linux-x86_64

View file

@ -0,0 +1,79 @@
#!/bin/bash
prefix=/opt/rocm
DEBIAN_FRONTEND=noninteractive
apt-get update && apt-get install -y --no-install-recommends \
wget \
zip \
ca-certificates \
build-essential \
curl \
libcurl4-openssl-dev \
libssl-dev \
python3-dev
# rocm-cmake
wget --quiet https://github.com/RadeonOpenCompute/rocm-cmake/archive/rocm-3.8.0.tar.gz
tar -xzvf rocm-3.8.0.tar.gz
rm rocm-3.8.0.tar.gz
cd rocm-cmake-rocm-3.8.0
mkdir build
cd build
cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
make -j8
make install
cd ../..
rm -rf rocm-cmake-rocm-3.8.0
# rccl
wget --quiet https://github.com/ROCmSoftwarePlatform/rccl/archive/rocm-4.0.0.tar.gz
tar -xzvf rocm-4.0.0.tar.gz
rm rocm-4.0.0.tar.gz
cd rccl-rocm-4.0.0
mkdir build
cd build
CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
make -j8
make install
cd ../..
rm -rf rccl-rocm-4.0.0
#rocrand
wget --quiet https://github.com/ROCmSoftwarePlatform/rocRAND/archive/rocm-4.0.0.tar.gz
tar -xzvf rocm-4.0.0.tar.gz
rm rocm-4.0.0.tar.gz
cd rocRAND-rocm-4.0.0
mkdir build
cd build
CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
make -j8
make install
cd ../..
rm -rf rocRAND-rocm-4.0.0
#hipcub
wget --quiet https://github.com/ROCmSoftwarePlatform/hipCUB/archive/rocm-4.0.0.tar.gz
tar -xzvf rocm-4.0.0.tar.gz
rm rocm-4.0.0.tar.gz
cd hipCUB-rocm-4.0.0
mkdir build
cd build
CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
make -j8
make package
make install
cd ../..
rm -rf hipCUB-rocm-4.0.0
#rocprim
wget --quiet https://github.com/ROCmSoftwarePlatform/rocPRIM/archive/rocm-4.0.0.tar.gz
tar -xzvf rocm-4.0.0.tar.gz
rm rocm-4.0.0.tar.gz
cd rocPRIM-rocm-4.0.0
mkdir build
cd build
CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
make -j8
make install
cd ../..
rm -rf rocPRIM-rocm-4.0.0

View file

@ -275,6 +275,16 @@ typedef struct OrtCUDAProviderOptions {
void* user_compute_stream;
} OrtCUDAProviderOptions;
/// <summary>
/// Options for the ROCM provider that are passed to SessionOptionsAppendExecutionProvider_ROCM
/// </summary>
typedef struct OrtROCMProviderOptions {
int device_id; // hip device with id=0 as default device.
int miopen_conv_exhaustive_search; // miopen conv algo exhaustive search option
size_t hip_mem_limit; // default hip memory limitation to maximum finite value of size_t.
int arena_extend_strategy; // default area extend strategy to KNextPowerOfTwo.
} OrtROCMProviderOptions;
/// <summary>
/// Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT
/// </summary>
@ -1150,6 +1160,13 @@ struct OrtApi {
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CUDA,
_In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptions* cuda_options);
/**
* Append ROCM execution provider to the session options
* If ROCM is not available (due to a non rocm enabled build), this function will return failure.
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_ROCM,
_In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options);
/**
* Append OpenVINO execution provider to the session options
* If OpenVINO is not available (due to the OpenVINO provider shared library or its dependencies not being installed), this function will fail.

View file

@ -325,6 +325,7 @@ struct SessionOptions : Base<OrtSessionOptions> {
SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);
SessionOptions& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options);
SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options);
SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options);
SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options);
};

View file

@ -490,6 +490,11 @@ inline SessionOptions& SessionOptions::AppendExecutionProvider_CUDA(const OrtCUD
return *this;
}
inline SessionOptions& SessionOptions::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(p_, &provider_options));
return *this;
}
inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(p_, &provider_options));
return *this;

View file

@ -12,6 +12,7 @@ namespace provider_option_names {
constexpr const char* kDeviceId = "device_id";
constexpr const char* kMemLimit = "hip_mem_limit";
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
constexpr const char* kConvExhaustiveSearch = "conv_exhaustive_search";
} // namespace provider_option_names
} // namespace rocm
@ -30,6 +31,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
// TODO validate info.device_id
.AddAssignmentToReference(rocm::provider_option_names::kDeviceId, info.device_id)
.AddAssignmentToReference(rocm::provider_option_names::kMemLimit, info.hip_mem_limit)
.AddAssignmentToReference(rocm::provider_option_names::kConvExhaustiveSearch, info.miopen_conv_exhaustive_search)
.AddAssignmentToEnumReference(
rocm::provider_option_names::kArenaExtendStrategy,
arena_extend_strategy_mapping, info.arena_extend_strategy)
@ -42,6 +44,7 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution
const ProviderOptions options{
{rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
{rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.hip_mem_limit)},
{rocm::provider_option_names::kConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)},
{rocm::provider_option_names::kArenaExtendStrategy,
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
};

View file

@ -15,6 +15,7 @@ struct ROCMExecutionProviderInfo {
OrtDevice::DeviceId device_id{0};
size_t hip_mem_limit{std::numeric_limits<size_t>::max()};
ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo};
bool miopen_conv_exhaustive_search{false};
bool do_copy_in_default_stream{true};
bool has_user_compute_stream{false};
void* user_compute_stream{nullptr};

View file

@ -12,6 +12,7 @@
#include "core/providers/rocm/rocm_execution_provider.h"
#include "core/providers/rocm/rocm_execution_provider_info.h"
#include "core/session/abi_session_options_impl.h"
#include "core/session/ort_apis.h"
using namespace onnxruntime;
@ -46,3 +47,17 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_ROCM, _In_ OrtSessi
return nullptr;
}
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM,
_In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options) {
ROCMExecutionProviderInfo info{};
info.device_id = gsl::narrow<OrtDevice::DeviceId>(rocm_options->device_id);
info.hip_mem_limit = rocm_options->hip_mem_limit;
info.arena_extend_strategy = static_cast<onnxruntime::ArenaExtendStrategy>(rocm_options->arena_extend_strategy);
info.miopen_conv_exhaustive_search = rocm_options->miopen_conv_exhaustive_search;
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_ROCM(info));
return nullptr;
}

View file

@ -1838,6 +1838,15 @@ ORT_API_STATUS_IMPL(OrtApis::GetCurrentGpuDeviceId, _In_ int* device_id) {
}
#endif
#ifndef USE_ROCM
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM,
_In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options) {
ORT_UNUSED_PARAMETER(options);
ORT_UNUSED_PARAMETER(rocm_options);
return CreateStatus(ORT_FAIL, "ROCM execution provider is not enabled.");
}
#endif
#if defined(ORT_MINIMAL_BUILD)
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO,
_In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options) {
@ -2098,6 +2107,7 @@ static constexpr OrtApi ort_api_1_to_8 = {
&OrtApis::AddInitializer,
&OrtApis::CreateEnvWithCustomLoggerAndGlobalThreadPools,
&OrtApis::SessionOptionsAppendExecutionProvider_CUDA,
&OrtApis::SessionOptionsAppendExecutionProvider_ROCM,
&OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO,
&OrtApis::SetGlobalDenormalAsZero,
&OrtApis::CreateArenaCfg,

View file

@ -248,6 +248,8 @@ ORT_API_STATUS_IMPL(AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ c
ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_CUDA,
_In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptions* cuda_options);
ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_ROCM,
_In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options);
ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO,
_In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options);
ORT_API_STATUS_IMPL(SetGlobalDenormalAsZero, _Inout_ OrtThreadingOptions* options);

View file

@ -36,7 +36,7 @@ void usage() {
"\t-v: verbose\n"
"\t-n [test_case_name]: Specifies a single test case to run.\n"
"\t-e [EXECUTION_PROVIDER]: EXECUTION_PROVIDER could be 'cpu', 'cuda', 'dnnl', 'tensorrt', "
"'openvino', 'nuphar', 'migraphx', 'acl', 'armnn', 'nnapi' or 'coreml'. "
"'openvino', 'nuphar', 'rocm', 'migraphx', 'acl', 'armnn', 'nnapi' or 'coreml'. "
"Default: 'cpu'.\n"
"\t-p: Pause after launch, can attach debugger and continue\n"
"\t-x: Use parallel executor, default (without -x): sequential executor.\n"
@ -102,6 +102,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
bool enable_dml = false;
bool enable_acl = false;
bool enable_armnn = false;
bool enable_rocm = false;
bool enable_migraphx = false;
int device_id = 0;
GraphOptimizationLevel graph_optimization_level = ORT_ENABLE_ALL;
@ -174,6 +175,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
enable_acl = true;
} else if (!CompareCString(optarg, ORT_TSTR("armnn"))) {
enable_armnn = true;
} else if (!CompareCString(optarg, ORT_TSTR("rocm"))) {
enable_rocm = true;
} else if (!CompareCString(optarg, ORT_TSTR("migraphx"))) {
enable_migraphx = true;
} else {
@ -412,6 +415,19 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
#else
fprintf(stderr, "ArmNN is not supported in this build\n");
return -1;
#endif
}
if (enable_rocm) {
#ifdef USE_ROCM
OrtROCMProviderOptions rocm_options{
0,
0,
std::numeric_limits<size_t>::max(),
0};
sf.AppendExecutionProvider_ROCM(rocm_options);
#else
fprintf(stderr, "ROCM is not supported in this build");
return -1;
#endif
}
if (enable_migraphx) {

View file

@ -162,6 +162,17 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
performance_test_config.run_config.enable_cpu_mem_arena ? 1 : 0));
#else
ORT_THROW("ArmNN is not supported in this build\n");
#endif
} else if (provider_name == onnxruntime::kRocmExecutionProvider) {
#ifdef USE_ROCM
OrtROCMProviderOptions rocm_options{
0,
0,
std::numeric_limits<size_t>::max(),
0};
session_options.AppendExecutionProvider_ROCM(rocm_options);
#else
ORT_THROW("ROCM is not supported in this build\n");
#endif
} else if (provider_name == onnxruntime::kMIGraphXExecutionProvider) {
#ifdef USE_MIGRAPHX

View file

@ -512,6 +512,8 @@ TEST_P(ModelTest, Run) {
InferenceSession session_object(so, (**ort_env).GetEnvironment());
if (provider_name == "cuda") {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
} else if (provider_name == "rocm") {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider()));
} else if (provider_name == "dnnl") {
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultDnnlExecutionProvider()));
} else if (provider_name == "nuphar") {
@ -631,6 +633,9 @@ TEST_P(ModelTest, Run) {
#ifdef USE_CUDA
provider_names.push_back(ORT_TSTR("cuda"));
#endif
#ifdef USE_ROCM
provider_names.push_back(ORT_TSTR("rocm"));
#endif
#ifdef USE_DNNL
provider_names.push_back(ORT_TSTR("dnnl"));
#endif