mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Change SessionOptions APIs to always return a status (#1171)
* Change SessionOptions APIs to always return a status, for consistency and ease of use (a couple returned 0 or -1 for success/failure)
This commit is contained in:
parent
b23ab6a06e
commit
b68bb51dd0
10 changed files with 112 additions and 95 deletions
|
|
@ -116,49 +116,49 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
#region SessionOptions API
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*OrtSessionOptions* */ OrtCreateSessionOptions();
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtCreateSessionOptions(out IntPtr /*(OrtSessionOptions**)*/ sessionOptions);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern void OrtReleaseSessionOptions(IntPtr /*(OrtSessionOptions*)*/session);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern IntPtr /*(OrtSessionOptions*)*/OrtCloneSessionOptions(IntPtr /*(OrtSessionOptions*)*/ sessionOptions);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtCloneSessionOptions(IntPtr /*(OrtSessionOptions*)*/ sessionOptions, out IntPtr /*(OrtSessionOptions**)*/ output);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern void OrtEnableSequentialExecution(IntPtr /*(OrtSessionOptions*)*/ options);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtEnableSequentialExecution(IntPtr /*(OrtSessionOptions*)*/ options);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern void OrtDisableSequentialExecution(IntPtr /*(OrtSessionOptions*)*/ options);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtDisableSequentialExecution(IntPtr /*(OrtSessionOptions*)*/ options);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern void OrtEnableProfiling(IntPtr /* OrtSessionOptions* */ options, string profilePathPrefix);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtEnableProfiling(IntPtr /* OrtSessionOptions* */ options, string profilePathPrefix);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern void OrtDisableProfiling(IntPtr /* OrtSessionOptions* */ options);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtDisableProfiling(IntPtr /* OrtSessionOptions* */ options);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern void OrtEnableMemPattern(IntPtr /* OrtSessionOptions* */ options);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtEnableMemPattern(IntPtr /* OrtSessionOptions* */ options);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern void OrtDisableMemPattern(IntPtr /* OrtSessionOptions* */ options);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtDisableMemPattern(IntPtr /* OrtSessionOptions* */ options);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern void OrtEnableCpuMemArena(IntPtr /* OrtSessionOptions* */ options);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtEnableCpuMemArena(IntPtr /* OrtSessionOptions* */ options);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern void OrtDisableCpuMemArena(IntPtr /* OrtSessionOptions* */ options);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtDisableCpuMemArena(IntPtr /* OrtSessionOptions* */ options);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern void OrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, string logId);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, string logId);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern void OrtSetSessionLogVerbosityLevel(IntPtr /* OrtSessionOptions* */ options, LogLevel sessionLogVerbosityLevel);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionLogVerbosityLevel(IntPtr /* OrtSessionOptions* */ options, LogLevel sessionLogVerbosityLevel);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern int OrtSetSessionThreadPoolSize(IntPtr /* OrtSessionOptions* */ options, int sessionThreadPoolSize);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionThreadPoolSize(IntPtr /* OrtSessionOptions* */ options, int sessionThreadPoolSize);
|
||||
|
||||
[DllImport(nativeLib, CharSet = charSet)]
|
||||
public static extern int OrtSetSessionGraphOptimizationLevel(IntPtr /* OrtSessionOptions* */ options, uint graphOptimizationLevel);
|
||||
public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionGraphOptimizationLevel(IntPtr /* OrtSessionOptions* */ options, uint graphOptimizationLevel);
|
||||
|
||||
|
||||
///**
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// </summary>
|
||||
public SessionOptions()
|
||||
{
|
||||
_nativePtr = NativeMethods.OrtCreateSessionOptions();
|
||||
NativeMethods.OrtCreateSessionOptions(out _nativePtr);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
@ -32,11 +32,9 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// 0 -> Disable all optimizations
|
||||
/// 1 -> Enable basic optimizations
|
||||
/// 2 -> Enable all optimizations
|
||||
/// <returns>True on success and false otherwise</returns>
|
||||
public bool SetSessionGraphOptimizationLevel(uint optimization_level)
|
||||
public void SetSessionGraphOptimizationLevel(uint optimization_level)
|
||||
{
|
||||
var result = NativeMethods.OrtSetSessionGraphOptimizationLevel(_nativePtr, optimization_level);
|
||||
return result == 0;
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetSessionGraphOptimizationLevel(_nativePtr, optimization_level));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
@ -45,7 +43,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// </param>
|
||||
public void EnableSequentialExecution()
|
||||
{
|
||||
NativeMethods.OrtEnableSequentialExecution(_nativePtr);
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableSequentialExecution(_nativePtr));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
@ -54,7 +52,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
/// </param>
|
||||
public void DisableSequentialExecution()
|
||||
{
|
||||
NativeMethods.OrtDisableSequentialExecution(_nativePtr);
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableSequentialExecution(_nativePtr));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests
|
|||
|
||||
// Set the graph optimization level for this session.
|
||||
SessionOptions options = new SessionOptions();
|
||||
Assert.True(options.SetSessionGraphOptimizationLevel(graphOptimizationLevel));
|
||||
options.SetSessionGraphOptimizationLevel(graphOptimizationLevel);
|
||||
if(disableSequentialExecution) options.DisableSequentialExecution();
|
||||
|
||||
using (var session = new InferenceSession(modelPath, options))
|
||||
|
|
|
|||
|
|
@ -199,48 +199,46 @@ ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess,
|
|||
/**
|
||||
* \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use
|
||||
*/
|
||||
ORT_API(OrtSessionOptions*, OrtCreateSessionOptions);
|
||||
ORT_API_STATUS(OrtCreateSessionOptions, _Out_ OrtSessionOptions** output);
|
||||
|
||||
// create a copy of an existing OrtSessionOptions
|
||||
ORT_API(OrtSessionOptions*, OrtCloneSessionOptions, OrtSessionOptions*);
|
||||
ORT_API(void, OrtEnableSequentialExecution, _In_ OrtSessionOptions* options);
|
||||
ORT_API(void, OrtDisableSequentialExecution, _In_ OrtSessionOptions* options);
|
||||
ORT_API_STATUS(OrtCloneSessionOptions, _In_ OrtSessionOptions* in, _Out_ OrtSessionOptions** output);
|
||||
ORT_API_STATUS(OrtEnableSequentialExecution, _In_ OrtSessionOptions* options);
|
||||
ORT_API_STATUS(OrtDisableSequentialExecution, _In_ OrtSessionOptions* options);
|
||||
|
||||
// Enable profiling for this session.
|
||||
ORT_API(void, OrtEnableProfiling, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix);
|
||||
ORT_API(void, OrtDisableProfiling, _In_ OrtSessionOptions* options);
|
||||
ORT_API_STATUS(OrtEnableProfiling, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix);
|
||||
ORT_API_STATUS(OrtDisableProfiling, _In_ OrtSessionOptions* options);
|
||||
|
||||
// Enable the memory pattern optimization.
|
||||
// The idea is if the input shapes are the same, we could trace the internal memory allocation
|
||||
// and generate a memory pattern for future request. So next time we could just do one allocation
|
||||
// with a big chunk for all the internal memory allocation.
|
||||
// Note: memory pattern optimization is only available when SequentialExecution enabled.
|
||||
ORT_API(void, OrtEnableMemPattern, _In_ OrtSessionOptions* options);
|
||||
ORT_API(void, OrtDisableMemPattern, _In_ OrtSessionOptions* options);
|
||||
ORT_API_STATUS(OrtEnableMemPattern, _In_ OrtSessionOptions* options);
|
||||
ORT_API_STATUS(OrtDisableMemPattern, _In_ OrtSessionOptions* options);
|
||||
|
||||
// Enable the memory arena on CPU
|
||||
// Arena may pre-allocate memory for future usage.
|
||||
// set this option to false if you don't want it.
|
||||
ORT_API(void, OrtEnableCpuMemArena, _In_ OrtSessionOptions* options);
|
||||
ORT_API(void, OrtDisableCpuMemArena, _In_ OrtSessionOptions* options);
|
||||
ORT_API_STATUS(OrtEnableCpuMemArena, _In_ OrtSessionOptions* options);
|
||||
ORT_API_STATUS(OrtDisableCpuMemArena, _In_ OrtSessionOptions* options);
|
||||
|
||||
// < logger id to use for session output
|
||||
ORT_API(void, OrtSetSessionLogId, _In_ OrtSessionOptions* options, const char* logid);
|
||||
ORT_API_STATUS(OrtSetSessionLogId, _In_ OrtSessionOptions* options, const char* logid);
|
||||
|
||||
// < applies to session load, initialization, etc
|
||||
ORT_API(void, OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, uint32_t session_log_verbosity_level);
|
||||
ORT_API_STATUS(OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, uint32_t session_log_verbosity_level);
|
||||
|
||||
// Set Graph optimization level.
|
||||
// Return 0 on success and -1 otherwise
|
||||
// Available options are : 0, 1, 2.
|
||||
// 0 -> Disable all optimizations
|
||||
// 1 -> Enable basic optimizations
|
||||
// 2 -> Enable all optimizations
|
||||
ORT_API(int, OrtSetSessionGraphOptimizationLevel, _In_ OrtSessionOptions* options, uint32_t graph_optimization_level);
|
||||
ORT_API_STATUS(OrtSetSessionGraphOptimizationLevel, _In_ OrtSessionOptions* options, uint32_t graph_optimization_level);
|
||||
|
||||
// How many threads in the session thread pool.
|
||||
// Returns 0 on success, and -1 otherwise
|
||||
ORT_API(int, OrtSetSessionThreadPoolSize, _In_ OrtSessionOptions* options, int session_thread_pool_size);
|
||||
ORT_API_STATUS(OrtSetSessionThreadPoolSize, _In_ OrtSessionOptions* options, int session_thread_pool_size);
|
||||
|
||||
/**
|
||||
* To use additional providers, you must build ORT with the extra providers enabled. Then call one of these
|
||||
|
|
|
|||
|
|
@ -107,67 +107,68 @@ inline RunOptions& RunOptions::SetTerminate(bool flag) {
|
|||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions::SessionOptions() : Base<OrtSessionOptions>{OrtCreateSessionOptions()} {
|
||||
inline SessionOptions::SessionOptions() {
|
||||
ORT_THROW_ON_ERROR(OrtCreateSessionOptions(&p_));
|
||||
}
|
||||
|
||||
inline SessionOptions SessionOptions::Clone() const {
|
||||
return SessionOptions{OrtCloneSessionOptions(p_)};
|
||||
OrtSessionOptions* out;
|
||||
ORT_THROW_ON_ERROR(OrtCloneSessionOptions(p_, &out));
|
||||
return SessionOptions{out};
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::SetThreadPoolSize(int session_thread_pool_size) {
|
||||
if (OrtSetSessionThreadPoolSize(p_, session_thread_pool_size) == -1)
|
||||
throw Exception("Error calling SessionOptions::SetThreadPoolSize", ORT_FAIL);
|
||||
ORT_THROW_ON_ERROR(OrtSetSessionThreadPoolSize(p_, session_thread_pool_size));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(uint32_t graph_optimization_level) {
|
||||
if (OrtSetSessionGraphOptimizationLevel(p_, graph_optimization_level) == -1)
|
||||
throw Exception("Error calling SessionOptions::SetGraphOptimizationLevel", ORT_FAIL);
|
||||
ORT_THROW_ON_ERROR(OrtSetSessionGraphOptimizationLevel(p_, graph_optimization_level));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
|
||||
OrtEnableProfiling(p_, profile_file_prefix);
|
||||
ORT_THROW_ON_ERROR(OrtEnableProfiling(p_, profile_file_prefix));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::DisableProfiling() {
|
||||
OrtDisableProfiling(p_);
|
||||
ORT_THROW_ON_ERROR(OrtDisableProfiling(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::EnableMemPattern() {
|
||||
OrtEnableMemPattern(p_);
|
||||
ORT_THROW_ON_ERROR(OrtEnableMemPattern(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::DisableMemPattern() {
|
||||
OrtDisableMemPattern(p_);
|
||||
ORT_THROW_ON_ERROR(OrtDisableMemPattern(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::EnableCpuMemArena() {
|
||||
OrtEnableCpuMemArena(p_);
|
||||
ORT_THROW_ON_ERROR(OrtEnableCpuMemArena(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::DisableCpuMemArena() {
|
||||
OrtDisableCpuMemArena(p_);
|
||||
ORT_THROW_ON_ERROR(OrtDisableCpuMemArena(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::EnableSequentialExecution() {
|
||||
OrtEnableSequentialExecution(p_);
|
||||
ORT_THROW_ON_ERROR(OrtEnableSequentialExecution(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::DisableSequentialExecution() {
|
||||
OrtDisableSequentialExecution(p_);
|
||||
ORT_THROW_ON_ERROR(OrtDisableSequentialExecution(p_));
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline SessionOptions& SessionOptions::SetLogId(const char* logid) {
|
||||
OrtSetSessionLogId(p_, logid);
|
||||
ORT_THROW_ON_ERROR(OrtSetSessionLogId(p_, logid));
|
||||
return *this;
|
||||
}
|
||||
inline SessionOptions& SessionOptions::Add(OrtCustomOpDomain* custom_op_domain) {
|
||||
|
|
|
|||
|
|
@ -8,3 +8,13 @@
|
|||
namespace onnxruntime {
|
||||
OrtStatus* ToOrtStatus(const onnxruntime::common::Status& st);
|
||||
};
|
||||
|
||||
#define API_IMPL_BEGIN try {
|
||||
#define API_IMPL_END \
|
||||
} \
|
||||
catch (const onnxruntime::NotImplementedException& ex) { \
|
||||
return OrtCreateStatus(ORT_NOT_IMPLEMENTED, ex.what()); \
|
||||
} \
|
||||
catch (const std::exception& ex) { \
|
||||
return OrtCreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/session/onnxruntime_c_api.h"
|
||||
#include "core/framework/error_code_helper.h"
|
||||
#include <cstring>
|
||||
#include <cassert>
|
||||
#include "core/session/inference_session.h"
|
||||
|
|
@ -16,86 +17,97 @@ OrtSessionOptions::OrtSessionOptions(const OrtSessionOptions& other)
|
|||
: value(other.value), provider_factories(other.provider_factories) {
|
||||
}
|
||||
|
||||
ORT_API(OrtSessionOptions*, OrtCreateSessionOptions) {
|
||||
std::unique_ptr<OrtSessionOptions> options = std::make_unique<OrtSessionOptions>();
|
||||
return options.release();
|
||||
ORT_API_STATUS_IMPL(OrtCreateSessionOptions, OrtSessionOptions** out) {
|
||||
API_IMPL_BEGIN
|
||||
*out = new OrtSessionOptions();
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API(void, OrtReleaseSessionOptions, OrtSessionOptions* ptr) {
|
||||
delete ptr;
|
||||
}
|
||||
|
||||
ORT_API(OrtSessionOptions*, OrtCloneSessionOptions, OrtSessionOptions* input) {
|
||||
try {
|
||||
return new OrtSessionOptions(*input);
|
||||
} catch (std::exception&) {
|
||||
return nullptr;
|
||||
}
|
||||
ORT_API_STATUS_IMPL(OrtCloneSessionOptions, OrtSessionOptions* input, OrtSessionOptions** out) {
|
||||
API_IMPL_BEGIN
|
||||
*out = new OrtSessionOptions(*input);
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API(void, OrtEnableSequentialExecution, _In_ OrtSessionOptions* options) {
|
||||
ORT_API_STATUS_IMPL(OrtEnableSequentialExecution, _In_ OrtSessionOptions* options) {
|
||||
options->value.enable_sequential_execution = true;
|
||||
return nullptr;
|
||||
}
|
||||
ORT_API(void, OrtDisableSequentialExecution, _In_ OrtSessionOptions* options) {
|
||||
ORT_API_STATUS_IMPL(OrtDisableSequentialExecution, _In_ OrtSessionOptions* options) {
|
||||
options->value.enable_sequential_execution = false;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// enable profiling for this session.
|
||||
ORT_API(void, OrtEnableProfiling, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix) {
|
||||
ORT_API_STATUS_IMPL(OrtEnableProfiling, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix) {
|
||||
options->value.enable_profiling = true;
|
||||
options->value.profile_file_prefix = profile_file_prefix;
|
||||
return nullptr;
|
||||
}
|
||||
ORT_API(void, OrtDisableProfiling, _In_ OrtSessionOptions* options) {
|
||||
ORT_API_STATUS_IMPL(OrtDisableProfiling, _In_ OrtSessionOptions* options) {
|
||||
options->value.enable_profiling = false;
|
||||
options->value.profile_file_prefix.clear();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// enable the memory pattern optimization.
|
||||
// The idea is if the input shapes are the same, we could trace the internal memory allocation
|
||||
// and generate a memory pattern for future request. So next time we could just do one allocation
|
||||
// with a big chunk for all the internal memory allocation.
|
||||
ORT_API(void, OrtEnableMemPattern, _In_ OrtSessionOptions* options) {
|
||||
ORT_API_STATUS_IMPL(OrtEnableMemPattern, _In_ OrtSessionOptions* options) {
|
||||
options->value.enable_mem_pattern = true;
|
||||
return nullptr;
|
||||
}
|
||||
ORT_API(void, OrtDisableMemPattern, _In_ OrtSessionOptions* options) {
|
||||
ORT_API_STATUS_IMPL(OrtDisableMemPattern, _In_ OrtSessionOptions* options) {
|
||||
options->value.enable_mem_pattern = false;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// enable the memory arena on CPU
|
||||
// Arena may pre-allocate memory for future usage.
|
||||
// set this option to false if you don't want it.
|
||||
ORT_API(void, OrtEnableCpuMemArena, _In_ OrtSessionOptions* options) {
|
||||
ORT_API_STATUS_IMPL(OrtEnableCpuMemArena, _In_ OrtSessionOptions* options) {
|
||||
options->value.enable_cpu_mem_arena = true;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ORT_API(void, OrtDisableCpuMemArena, _In_ OrtSessionOptions* options) {
|
||||
ORT_API_STATUS_IMPL(OrtDisableCpuMemArena, _In_ OrtSessionOptions* options) {
|
||||
options->value.enable_cpu_mem_arena = false;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
///< logger id to use for session output
|
||||
ORT_API(void, OrtSetSessionLogId, _In_ OrtSessionOptions* options, const char* logid) {
|
||||
ORT_API_STATUS_IMPL(OrtSetSessionLogId, _In_ OrtSessionOptions* options, const char* logid) {
|
||||
options->value.session_logid = logid;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
///< applies to session load, initialization, etc
|
||||
ORT_API(void, OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, uint32_t session_log_verbosity_level) {
|
||||
ORT_API_STATUS_IMPL(OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, uint32_t session_log_verbosity_level) {
|
||||
options->value.session_log_verbosity_level = session_log_verbosity_level;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Set Graph optimization level.
|
||||
// Returns 0 on success and -1 otherwise
|
||||
// Available options are : 0, 1, 2.
|
||||
ORT_API(int, OrtSetSessionGraphOptimizationLevel, _In_ OrtSessionOptions* options, uint32_t graph_optimization_level) {
|
||||
if (graph_optimization_level >= static_cast<uint32_t>(onnxruntime::TransformerLevel::MaxTransformerLevel)) {
|
||||
return -1;
|
||||
}
|
||||
ORT_API_STATUS_IMPL(OrtSetSessionGraphOptimizationLevel, _In_ OrtSessionOptions* options, uint32_t graph_optimization_level) {
|
||||
if (graph_optimization_level >= static_cast<uint32_t>(onnxruntime::TransformerLevel::MaxTransformerLevel))
|
||||
return OrtCreateStatus(ORT_INVALID_ARGUMENT, "graph_optimization_level is not valid");
|
||||
options->value.graph_optimization_level = static_cast<onnxruntime::TransformerLevel>(graph_optimization_level);
|
||||
return 0;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
///How many threads in the session thread pool.
|
||||
ORT_API(int, OrtSetSessionThreadPoolSize, _In_ OrtSessionOptions* options, int session_thread_pool_size) {
|
||||
if (session_thread_pool_size <= 0) return -1;
|
||||
ORT_API_STATUS_IMPL(OrtSetSessionThreadPoolSize, _In_ OrtSessionOptions* options, int session_thread_pool_size) {
|
||||
if (session_thread_pool_size <= 0)
|
||||
return OrtCreateStatus(ORT_INVALID_ARGUMENT, "session_thread_pool_size is not valid");
|
||||
options->value.session_thread_pool_size = session_thread_pool_size;
|
||||
return 0;
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,13 +62,6 @@ struct OrtEnv {
|
|||
ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtEnv);
|
||||
};
|
||||
|
||||
#define API_IMPL_BEGIN try {
|
||||
#define API_IMPL_END \
|
||||
} \
|
||||
catch (std::exception & ex) { \
|
||||
return OrtCreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \
|
||||
}
|
||||
|
||||
#define TENSOR_READ_API_BEGIN \
|
||||
API_IMPL_BEGIN \
|
||||
auto v = reinterpret_cast<const ::OrtValue*>(value); \
|
||||
|
|
|
|||
|
|
@ -181,8 +181,8 @@ void verify_input_output_count(OrtSession* session) {
|
|||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
void enable_cuda(OrtSessionOptions* session_option) {
|
||||
ORT_ABORT_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(session_option, 0));
|
||||
void enable_cuda(OrtSessionOptions* session_options) {
|
||||
ORT_ABORT_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0));
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
@ -196,15 +196,16 @@ int main(int argc, char* argv[]) {
|
|||
char* output_file = argv[3];
|
||||
OrtEnv* env;
|
||||
ORT_ABORT_ON_ERROR(OrtCreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env));
|
||||
OrtSessionOptions* session_option = OrtCreateSessionOptions();
|
||||
OrtSessionOptions* session_options;
|
||||
ORT_ABORT_ON_ERROR(OrtCreateSessionOptions(&session_options));
|
||||
#ifdef USE_CUDA
|
||||
enable_cuda(session_option);
|
||||
enable_cuda(session_options);
|
||||
#endif
|
||||
OrtSession* session;
|
||||
ORT_ABORT_ON_ERROR(OrtCreateSession(env, model_path, session_option, &session));
|
||||
ORT_ABORT_ON_ERROR(OrtCreateSession(env, model_path, session_options, &session));
|
||||
verify_input_output_count(session);
|
||||
int ret = run_inference(session, input_file, output_file);
|
||||
OrtReleaseSessionOptions(session_option);
|
||||
OrtReleaseSessionOptions(session_options);
|
||||
OrtReleaseSession(session);
|
||||
OrtReleaseEnv(env);
|
||||
if (ret != 0) {
|
||||
|
|
|
|||
|
|
@ -14,6 +14,10 @@ TEST_F(CApiTest, session_options_graph_optimization_level) {
|
|||
options.SetGraphOptimizationLevel(valid_optimization_level);
|
||||
|
||||
// Test set optimization level fails when invalid level is provided.
|
||||
uint32_t invalid_level = static_cast<uint32_t>(TransformerLevel::MaxTransformerLevel);
|
||||
ASSERT_EQ(OrtSetSessionGraphOptimizationLevel(options, invalid_level), -1);
|
||||
try {
|
||||
uint32_t invalid_level = static_cast<uint32_t>(TransformerLevel::MaxTransformerLevel);
|
||||
options.SetGraphOptimizationLevel(invalid_level);
|
||||
} catch (const Ort::Exception& e) {
|
||||
ASSERT_EQ(e.GetOrtErrorCode(), ORT_INVALID_ARGUMENT);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue