mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Add rocm execution provider to prover_list (#6306)
* code changes to add rocm ep to ep_list
This commit is contained in:
parent
031587814b
commit
c588d5d13a
17 changed files with 311 additions and 4 deletions
|
|
@ -918,11 +918,11 @@ if (onnxruntime_USE_ROCM)
|
|||
add_definitions(-DUSE_ROCM=1)
|
||||
|
||||
# Add search paths for default hip installation
|
||||
list(APPEND CMAKE_PREFIX_PATH ${onnxruntime_ROCM_HOME} ${onnxruntime_ROCM_HOME}/hip ${onnxruntime_ROCM_HOME}/hcc ${onnxruntime_ROCM_HOME}/miopen)
|
||||
list(APPEND CMAKE_PREFIX_PATH ${onnxruntime_ROCM_HOME} ${onnxruntime_ROCM_HOME}/hip ${onnxruntime_ROCM_HOME}/hcc ${onnxruntime_ROCM_HOME}/miopen ${onnxruntime_ROCM_HOME}/hiprand ${onnxruntime_ROCM_HOME}/rocrand)
|
||||
|
||||
set(CMAKE_MODULE_PATH "${onnxruntime_ROCM_HOME}/hip/cmake" ${CMAKE_MODULE_PATH})
|
||||
find_package(HIP)
|
||||
|
||||
find_package(hiprand REQUIRED)
|
||||
find_library(HIP_LIB amdhip64 REQUIRED)
|
||||
find_library(ROC_BLAS rocblas REQUIRED)
|
||||
find_library(MIOPEN_LIB MIOpen REQUIRED)
|
||||
|
|
@ -1027,7 +1027,7 @@ if (onnxruntime_USE_ROCM)
|
|||
target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-undefined-var-template)
|
||||
endif()
|
||||
# During transition to separate hipFFT repo, put hipfft/include early
|
||||
target_include_directories(onnxruntime_providers_rocm PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/include/hipcub ${onnxruntime_ROCM_HOME}/include/hiprand ${onnxruntime_ROCM_HOME}/include/rocrand)
|
||||
target_include_directories(onnxruntime_providers_rocm PRIVATE ${onnxruntime_ROCM_HOME}/hipfft/include ${onnxruntime_ROCM_HOME}/include ${onnxruntime_ROCM_HOME}/include/hipcub ${onnxruntime_ROCM_HOME}/hiprand/include ${onnxruntime_ROCM_HOME}/rocrand/include)
|
||||
target_include_directories(onnxruntime_providers_rocm PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${MPI_INCLUDE_DIRS} ${SAFEINT_INCLUDE_DIR} ${ONNXRUNTIME_ROOT}/../cmake/external/eigen)
|
||||
|
||||
if (onnxruntime_ENABLE_TRAINING)
|
||||
|
|
|
|||
|
|
@ -184,6 +184,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
public IntPtr AddInitializer;
|
||||
public IntPtr CreateEnvWithCustomLoggerAndGlobalThreadPools;
|
||||
public IntPtr SessionOptionsAppendExecutionProvider_CUDA;
|
||||
public IntPtr SessionOptionsAppendExecutionProvider_ROCM;
|
||||
public IntPtr SessionOptionsAppendExecutionProvider_OpenVINO;
|
||||
public IntPtr SetGlobalDenormalAsZero;
|
||||
public IntPtr CreateArenaCfg;
|
||||
|
|
@ -560,6 +561,9 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_CUDA(IntPtr /*(OrtSessionOptions*) */ options, int device_id);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_ROCM(IntPtr /*(OrtSessionOptions*) */ options, int device_id);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_DML(IntPtr /*(OrtSessionOptions*) */ options, int device_id);
|
||||
|
||||
|
|
|
|||
|
|
@ -94,6 +94,30 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
return options;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A helper method to construct a SessionOptions object for ROCM execution.
|
||||
/// Use only if ROCM is installed and you have the onnxruntime package specific to this Execution Provider.
|
||||
/// </summary>
|
||||
/// <returns>A SessionsOptions() object configured for execution on deviceId=0</returns>
|
||||
public static SessionOptions MakeSessionOptionWithRocmProvider()
|
||||
{
|
||||
return MakeSessionOptionWithRocmProvider(0);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// A helper method to construct a SessionOptions object for ROCM execution.
|
||||
/// Use only if ROCM is installed and you have the onnxruntime package specific to this Execution Provider.
|
||||
/// </summary>
|
||||
/// <param name="deviceId"></param>
|
||||
/// <returns>A SessionsOptions() object configured for execution on deviceId</returns>
|
||||
public static SessionOptions MakeSessionOptionWithRocmProvider(int deviceId = 0)
|
||||
{
|
||||
SessionOptions options = new SessionOptions();
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_ROCM(options.Handle, deviceId));
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options.Handle, 1));
|
||||
return options;
|
||||
}
|
||||
|
||||
#endregion
|
||||
|
||||
#region ExecutionProviderAppends
|
||||
|
|
@ -156,6 +180,15 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Tensorrt(handle, deviceId));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Use only if you have the onnxruntime package specific to this Execution Provider.
|
||||
/// </summary>
|
||||
/// <param name="deviceId">integer device ID</param>
|
||||
public void AppendExecutionProvider_ROCM(int deviceId)
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_ROCM(handle, deviceId));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Use only if you have the onnxruntime package specific to this Execution Provider.
|
||||
/// </summary>
|
||||
|
|
|
|||
|
|
@ -96,6 +96,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
#if USE_CUDA
|
||||
opt.AppendExecutionProvider_CUDA(0);
|
||||
#endif
|
||||
#if USE_ROCM
|
||||
opt.AppendExecutionProvider_ROCM(0);
|
||||
#endif
|
||||
#if USE_DML
|
||||
// Explicitly set dll probe path so that the (potentially) stale system DirectML.dll
|
||||
// doesn't get loaded by the test process when it is eventually delay loaded by onnruntime.dll
|
||||
|
|
@ -185,6 +188,10 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
# if USE_CUDA
|
||||
Assert.True(Array.Exists(providers, provider => provider == "CUDAExecutionProvider"););
|
||||
#endif
|
||||
# if USE_ROCM
|
||||
Assert.True(Array.Exists(providers, provider => provider == "ROCMExecutionProvider"););
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
[Fact]
|
||||
|
|
@ -2019,6 +2026,34 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
}
|
||||
#endif
|
||||
|
||||
#if USE_ROCM
|
||||
void TestROCMAllocatorInternal(InferenceSession session)
|
||||
{
|
||||
int device_id = 0;
|
||||
using (var info_rocm = new OrtMemoryInfo(OrtMemoryInfo.allocatorROCM, OrtAllocatorType.ArenaAllocator, device_id, OrtMemType.Default))
|
||||
{
|
||||
Assert.Equal("Rocm", info_rocm.Name);
|
||||
Assert.Equal(device_id, info_rocm.Id);
|
||||
Assert.Equal(OrtAllocatorType.ArenaAllocator, info_rocm.GetAllocatorType());
|
||||
Assert.Equal(OrtMemType.Default, info_rocm.GetMemoryType());
|
||||
|
||||
using (var allocator = new OrtAllocator(session, info_rocm))
|
||||
{
|
||||
var alloc_info = allocator.Info;
|
||||
Assert.True(info_rocm.Equals(alloc_info));
|
||||
|
||||
uint size = 1024;
|
||||
OrtMemoryAllocation chunk = allocator.Allocate(size);
|
||||
Assert.Equal(chunk.Size, size);
|
||||
Assert.True(chunk.Info.Equals(alloc_info));
|
||||
chunk.Dispose();
|
||||
alloc_info.Dispose();
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
[Fact]
|
||||
private void TestAllocator()
|
||||
{
|
||||
|
|
@ -2029,12 +2064,21 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
#if USE_CUDA
|
||||
options.AppendExecutionProvider_CUDA(0);
|
||||
#endif
|
||||
|
||||
#if USE_ROCM
|
||||
options.AppendExecutionProvider_ROCM(0);
|
||||
#endif
|
||||
|
||||
using (var session = new InferenceSession(modelPath, options))
|
||||
{
|
||||
TestCPUAllocatorInternal(session);
|
||||
#if USE_CUDA
|
||||
TestCUDAAllocatorInternal(session);
|
||||
#endif
|
||||
#if USE_ROCM
|
||||
TestROCMAllocatorInternal(session);
|
||||
#endif
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2294,6 +2338,9 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
#if USE_CUDA
|
||||
,"OrtSessionOptionsAppendExecutionProvider_CUDA"
|
||||
#endif
|
||||
#if USE_ROCM
|
||||
,"OrtSessionOptionsAppendExecutionProvider_ROCM"
|
||||
#endif
|
||||
#if USE_DML
|
||||
,"OrtSessionOptionsAppendExecutionProvider_DML"
|
||||
#endif
|
||||
|
|
@ -2566,6 +2613,15 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
{
|
||||
option.AppendExecutionProvider_CPU(1);
|
||||
}
|
||||
#elif USE_ROCM
|
||||
using (var option = (deviceId.HasValue) ?
|
||||
SessionOptions.MakeSessionOptionWithRocmProvider(deviceId.Value) :
|
||||
new SessionOptions())
|
||||
{
|
||||
if(!deviceId.HasValue)
|
||||
{
|
||||
option.AppendExecutionProvider_CPU(1);
|
||||
}
|
||||
#else
|
||||
using (var option = new SessionOptions())
|
||||
{
|
||||
|
|
|
|||
49
dockerfiles/Dockerfile.rocm
Normal file
49
dockerfiles/Dockerfile.rocm
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
# --------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------
|
||||
# Dockerfile to run ONNXRuntime with MIGraphX integration
|
||||
#--------------------------------------------------------------------------
|
||||
|
||||
FROM ubuntu:18.04
|
||||
|
||||
ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
|
||||
ARG ONNXRUNTIME_BRANCH=master
|
||||
|
||||
ENV DEBIAN_FRONTEND noninteractive
|
||||
RUN apt-get clean && apt-get update && apt-get install -y locales
|
||||
RUN locale-gen en_US.UTF-8
|
||||
RUN update-locale LANG=en_US.UTF-8
|
||||
ENV LC_ALL C.UTF-8
|
||||
ENV LANG C.UTF-8
|
||||
|
||||
# Install rocm
|
||||
RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && \
|
||||
curl -sL http://repo.radeon.com/rocm/apt/debian/rocm.gpg.key | apt-key add - && \
|
||||
sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/4.0/ xenial main > /etc/apt/sources.list.d/rocm.list'
|
||||
|
||||
RUN apt-get update &&\
|
||||
apt-get install -y --no-install-recommends sudo git bash build-essential cmake libelf1 rocm-dkms libpython3.6-dev python3-pip miopen-hip rocblas\
|
||||
libnuma-dev kmod half hipsparse rocfft hipblas
|
||||
|
||||
# Install yapf
|
||||
RUN pip3 install yapf==0.28.0
|
||||
|
||||
ENV PATH /opt/miniconda/bin:/code/cmake-3.14.3-Linux-x86_64/bin:${PATH}
|
||||
|
||||
# Install dependencies
|
||||
COPY ./scripts/install_rocm_deps.sh /
|
||||
RUN chmod +x /install_rocm_deps.sh && /install_rocm_deps.sh && rm /install_rocm_deps.sh
|
||||
|
||||
WORKDIR /code
|
||||
|
||||
# Prepare onnxruntime repository & build onnxruntime
|
||||
RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\
|
||||
/bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\
|
||||
cd onnxruntime &&\
|
||||
/bin/sh ./build.sh --config Release --build_wheel --update --build --parallel --cmake_extra_defines\
|
||||
ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) --use_rocm --rocm_home=/opt/rocm &&\
|
||||
pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\
|
||||
cd .. &&\
|
||||
rm -rf onnxruntime cmake-3.14.3-Linux-x86_64
|
||||
|
||||
79
dockerfiles/scripts/install_rocm_deps.sh
Normal file
79
dockerfiles/scripts/install_rocm_deps.sh
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
#!/bin/bash
|
||||
prefix=/opt/rocm
|
||||
DEBIAN_FRONTEND=noninteractive
|
||||
apt-get update && apt-get install -y --no-install-recommends \
|
||||
wget \
|
||||
zip \
|
||||
ca-certificates \
|
||||
build-essential \
|
||||
curl \
|
||||
libcurl4-openssl-dev \
|
||||
libssl-dev \
|
||||
python3-dev
|
||||
|
||||
# rocm-cmake
|
||||
wget --quiet https://github.com/RadeonOpenCompute/rocm-cmake/archive/rocm-3.8.0.tar.gz
|
||||
tar -xzvf rocm-3.8.0.tar.gz
|
||||
rm rocm-3.8.0.tar.gz
|
||||
cd rocm-cmake-rocm-3.8.0
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
|
||||
make -j8
|
||||
make install
|
||||
cd ../..
|
||||
rm -rf rocm-cmake-rocm-3.8.0
|
||||
|
||||
# rccl
|
||||
wget --quiet https://github.com/ROCmSoftwarePlatform/rccl/archive/rocm-4.0.0.tar.gz
|
||||
tar -xzvf rocm-4.0.0.tar.gz
|
||||
rm rocm-4.0.0.tar.gz
|
||||
cd rccl-rocm-4.0.0
|
||||
mkdir build
|
||||
cd build
|
||||
CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
|
||||
make -j8
|
||||
make install
|
||||
cd ../..
|
||||
rm -rf rccl-rocm-4.0.0
|
||||
|
||||
#rocrand
|
||||
wget --quiet https://github.com/ROCmSoftwarePlatform/rocRAND/archive/rocm-4.0.0.tar.gz
|
||||
tar -xzvf rocm-4.0.0.tar.gz
|
||||
rm rocm-4.0.0.tar.gz
|
||||
cd rocRAND-rocm-4.0.0
|
||||
mkdir build
|
||||
cd build
|
||||
CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
|
||||
make -j8
|
||||
make install
|
||||
cd ../..
|
||||
rm -rf rocRAND-rocm-4.0.0
|
||||
|
||||
#hipcub
|
||||
wget --quiet https://github.com/ROCmSoftwarePlatform/hipCUB/archive/rocm-4.0.0.tar.gz
|
||||
tar -xzvf rocm-4.0.0.tar.gz
|
||||
rm rocm-4.0.0.tar.gz
|
||||
cd hipCUB-rocm-4.0.0
|
||||
mkdir build
|
||||
cd build
|
||||
CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
|
||||
make -j8
|
||||
make package
|
||||
make install
|
||||
cd ../..
|
||||
rm -rf hipCUB-rocm-4.0.0
|
||||
|
||||
#rocprim
|
||||
wget --quiet https://github.com/ROCmSoftwarePlatform/rocPRIM/archive/rocm-4.0.0.tar.gz
|
||||
tar -xzvf rocm-4.0.0.tar.gz
|
||||
rm rocm-4.0.0.tar.gz
|
||||
cd rocPRIM-rocm-4.0.0
|
||||
mkdir build
|
||||
cd build
|
||||
CXX=/opt/rocm/bin/hipcc cmake -DCMAKE_INSTALL_PREFIX=$prefix ..
|
||||
make -j8
|
||||
make install
|
||||
cd ../..
|
||||
rm -rf rocPRIM-rocm-4.0.0
|
||||
|
||||
|
|
@ -275,6 +275,16 @@ typedef struct OrtCUDAProviderOptions {
|
|||
void* user_compute_stream;
|
||||
} OrtCUDAProviderOptions;
|
||||
|
||||
/// <summary>
|
||||
/// Options for the ROCM provider that are passed to SessionOptionsAppendExecutionProvider_ROCM
|
||||
/// </summary>
|
||||
typedef struct OrtROCMProviderOptions {
|
||||
int device_id; // hip device with id=0 as default device.
|
||||
int miopen_conv_exhaustive_search; // miopen conv algo exhaustive search option
|
||||
size_t hip_mem_limit; // default hip memory limitation to maximum finite value of size_t.
|
||||
int arena_extend_strategy; // default area extend strategy to KNextPowerOfTwo.
|
||||
} OrtROCMProviderOptions;
|
||||
|
||||
/// <summary>
|
||||
/// Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT
|
||||
/// </summary>
|
||||
|
|
@ -1150,6 +1160,13 @@ struct OrtApi {
|
|||
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CUDA,
|
||||
_In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptions* cuda_options);
|
||||
|
||||
/**
|
||||
* Append ROCM execution provider to the session options
|
||||
* If ROCM is not available (due to a non rocm enabled build), this function will return failure.
|
||||
*/
|
||||
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_ROCM,
|
||||
_In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options);
|
||||
|
||||
/**
|
||||
* Append OpenVINO execution provider to the session options
|
||||
* If OpenVINO is not available (due to the OpenVINO provider shared library or its dependencies not being installed), this function will fail.
|
||||
|
|
|
|||
|
|
@ -325,6 +325,7 @@ struct SessionOptions : Base<OrtSessionOptions> {
|
|||
SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);
|
||||
|
||||
SessionOptions& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options);
|
||||
SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options);
|
||||
SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options);
|
||||
SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -490,6 +490,11 @@ inline SessionOptions& SessionOptions::AppendExecutionProvider_CUDA(const OrtCUD
|
|||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
|
||||
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(p_, &provider_options));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
|
||||
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(p_, &provider_options));
|
||||
return *this;
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ namespace provider_option_names {
|
|||
constexpr const char* kDeviceId = "device_id";
|
||||
constexpr const char* kMemLimit = "hip_mem_limit";
|
||||
constexpr const char* kArenaExtendStrategy = "arena_extend_strategy";
|
||||
constexpr const char* kConvExhaustiveSearch = "conv_exhaustive_search";
|
||||
} // namespace provider_option_names
|
||||
} // namespace rocm
|
||||
|
||||
|
|
@ -30,6 +31,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
|
|||
// TODO validate info.device_id
|
||||
.AddAssignmentToReference(rocm::provider_option_names::kDeviceId, info.device_id)
|
||||
.AddAssignmentToReference(rocm::provider_option_names::kMemLimit, info.hip_mem_limit)
|
||||
.AddAssignmentToReference(rocm::provider_option_names::kConvExhaustiveSearch, info.miopen_conv_exhaustive_search)
|
||||
.AddAssignmentToEnumReference(
|
||||
rocm::provider_option_names::kArenaExtendStrategy,
|
||||
arena_extend_strategy_mapping, info.arena_extend_strategy)
|
||||
|
|
@ -42,6 +44,7 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution
|
|||
const ProviderOptions options{
|
||||
{rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)},
|
||||
{rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.hip_mem_limit)},
|
||||
{rocm::provider_option_names::kConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)},
|
||||
{rocm::provider_option_names::kArenaExtendStrategy,
|
||||
EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)},
|
||||
};
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ struct ROCMExecutionProviderInfo {
|
|||
OrtDevice::DeviceId device_id{0};
|
||||
size_t hip_mem_limit{std::numeric_limits<size_t>::max()};
|
||||
ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo};
|
||||
bool miopen_conv_exhaustive_search{false};
|
||||
bool do_copy_in_default_stream{true};
|
||||
bool has_user_compute_stream{false};
|
||||
void* user_compute_stream{nullptr};
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
#include "core/providers/rocm/rocm_execution_provider.h"
|
||||
#include "core/providers/rocm/rocm_execution_provider_info.h"
|
||||
#include "core/session/abi_session_options_impl.h"
|
||||
#include "core/session/ort_apis.h"
|
||||
|
||||
using namespace onnxruntime;
|
||||
|
||||
|
|
@ -46,3 +47,17 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_ROCM, _In_ OrtSessi
|
|||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM,
|
||||
_In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options) {
|
||||
ROCMExecutionProviderInfo info{};
|
||||
info.device_id = gsl::narrow<OrtDevice::DeviceId>(rocm_options->device_id);
|
||||
info.hip_mem_limit = rocm_options->hip_mem_limit;
|
||||
info.arena_extend_strategy = static_cast<onnxruntime::ArenaExtendStrategy>(rocm_options->arena_extend_strategy);
|
||||
info.miopen_conv_exhaustive_search = rocm_options->miopen_conv_exhaustive_search;
|
||||
|
||||
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_ROCM(info));
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1838,6 +1838,15 @@ ORT_API_STATUS_IMPL(OrtApis::GetCurrentGpuDeviceId, _In_ int* device_id) {
|
|||
}
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_ROCM,
|
||||
_In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options) {
|
||||
ORT_UNUSED_PARAMETER(options);
|
||||
ORT_UNUSED_PARAMETER(rocm_options);
|
||||
return CreateStatus(ORT_FAIL, "ROCM execution provider is not enabled.");
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(ORT_MINIMAL_BUILD)
|
||||
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO,
|
||||
_In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options) {
|
||||
|
|
@ -2098,6 +2107,7 @@ static constexpr OrtApi ort_api_1_to_8 = {
|
|||
&OrtApis::AddInitializer,
|
||||
&OrtApis::CreateEnvWithCustomLoggerAndGlobalThreadPools,
|
||||
&OrtApis::SessionOptionsAppendExecutionProvider_CUDA,
|
||||
&OrtApis::SessionOptionsAppendExecutionProvider_ROCM,
|
||||
&OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO,
|
||||
&OrtApis::SetGlobalDenormalAsZero,
|
||||
&OrtApis::CreateArenaCfg,
|
||||
|
|
|
|||
|
|
@ -248,6 +248,8 @@ ORT_API_STATUS_IMPL(AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ c
|
|||
|
||||
ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_CUDA,
|
||||
_In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptions* cuda_options);
|
||||
ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_ROCM,
|
||||
_In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options);
|
||||
ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO,
|
||||
_In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options);
|
||||
ORT_API_STATUS_IMPL(SetGlobalDenormalAsZero, _Inout_ OrtThreadingOptions* options);
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ void usage() {
|
|||
"\t-v: verbose\n"
|
||||
"\t-n [test_case_name]: Specifies a single test case to run.\n"
|
||||
"\t-e [EXECUTION_PROVIDER]: EXECUTION_PROVIDER could be 'cpu', 'cuda', 'dnnl', 'tensorrt', "
|
||||
"'openvino', 'nuphar', 'migraphx', 'acl', 'armnn', 'nnapi' or 'coreml'. "
|
||||
"'openvino', 'nuphar', 'rocm', 'migraphx', 'acl', 'armnn', 'nnapi' or 'coreml'. "
|
||||
"Default: 'cpu'.\n"
|
||||
"\t-p: Pause after launch, can attach debugger and continue\n"
|
||||
"\t-x: Use parallel executor, default (without -x): sequential executor.\n"
|
||||
|
|
@ -102,6 +102,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
bool enable_dml = false;
|
||||
bool enable_acl = false;
|
||||
bool enable_armnn = false;
|
||||
bool enable_rocm = false;
|
||||
bool enable_migraphx = false;
|
||||
int device_id = 0;
|
||||
GraphOptimizationLevel graph_optimization_level = ORT_ENABLE_ALL;
|
||||
|
|
@ -174,6 +175,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
enable_acl = true;
|
||||
} else if (!CompareCString(optarg, ORT_TSTR("armnn"))) {
|
||||
enable_armnn = true;
|
||||
} else if (!CompareCString(optarg, ORT_TSTR("rocm"))) {
|
||||
enable_rocm = true;
|
||||
} else if (!CompareCString(optarg, ORT_TSTR("migraphx"))) {
|
||||
enable_migraphx = true;
|
||||
} else {
|
||||
|
|
@ -412,6 +415,19 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
|
|||
#else
|
||||
fprintf(stderr, "ArmNN is not supported in this build\n");
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
if (enable_rocm) {
|
||||
#ifdef USE_ROCM
|
||||
OrtROCMProviderOptions rocm_options{
|
||||
0,
|
||||
0,
|
||||
std::numeric_limits<size_t>::max(),
|
||||
0};
|
||||
sf.AppendExecutionProvider_ROCM(rocm_options);
|
||||
#else
|
||||
fprintf(stderr, "ROCM is not supported in this build");
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
if (enable_migraphx) {
|
||||
|
|
|
|||
|
|
@ -162,6 +162,17 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
|
|||
performance_test_config.run_config.enable_cpu_mem_arena ? 1 : 0));
|
||||
#else
|
||||
ORT_THROW("ArmNN is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kRocmExecutionProvider) {
|
||||
#ifdef USE_ROCM
|
||||
OrtROCMProviderOptions rocm_options{
|
||||
0,
|
||||
0,
|
||||
std::numeric_limits<size_t>::max(),
|
||||
0};
|
||||
session_options.AppendExecutionProvider_ROCM(rocm_options);
|
||||
#else
|
||||
ORT_THROW("ROCM is not supported in this build\n");
|
||||
#endif
|
||||
} else if (provider_name == onnxruntime::kMIGraphXExecutionProvider) {
|
||||
#ifdef USE_MIGRAPHX
|
||||
|
|
|
|||
|
|
@ -512,6 +512,8 @@ TEST_P(ModelTest, Run) {
|
|||
InferenceSession session_object(so, (**ort_env).GetEnvironment());
|
||||
if (provider_name == "cuda") {
|
||||
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCudaExecutionProvider()));
|
||||
} else if (provider_name == "rocm") {
|
||||
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultRocmExecutionProvider()));
|
||||
} else if (provider_name == "dnnl") {
|
||||
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultDnnlExecutionProvider()));
|
||||
} else if (provider_name == "nuphar") {
|
||||
|
|
@ -631,6 +633,9 @@ TEST_P(ModelTest, Run) {
|
|||
#ifdef USE_CUDA
|
||||
provider_names.push_back(ORT_TSTR("cuda"));
|
||||
#endif
|
||||
#ifdef USE_ROCM
|
||||
provider_names.push_back(ORT_TSTR("rocm"));
|
||||
#endif
|
||||
#ifdef USE_DNNL
|
||||
provider_names.push_back(ORT_TSTR("dnnl"));
|
||||
#endif
|
||||
|
|
|
|||
Loading…
Reference in a new issue