diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs index a7ee28c96c..eb0aee5154 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs @@ -116,49 +116,49 @@ namespace Microsoft.ML.OnnxRuntime #region SessionOptions API [DllImport(nativeLib, CharSet = charSet)] - public static extern IntPtr /*OrtSessionOptions* */ OrtCreateSessionOptions(); + public static extern IntPtr /*(OrtStatus*)*/ OrtCreateSessionOptions(out IntPtr /*(OrtSessionOptions**)*/ sessionOptions); [DllImport(nativeLib, CharSet = charSet)] public static extern void OrtReleaseSessionOptions(IntPtr /*(OrtSessionOptions*)*/session); [DllImport(nativeLib, CharSet = charSet)] - public static extern IntPtr /*(OrtSessionOptions*)*/OrtCloneSessionOptions(IntPtr /*(OrtSessionOptions*)*/ sessionOptions); + public static extern IntPtr /*(OrtStatus*)*/ OrtCloneSessionOptions(IntPtr /*(OrtSessionOptions*)*/ sessionOptions, out IntPtr /*(OrtSessionOptions**)*/ output); [DllImport(nativeLib, CharSet = charSet)] - public static extern void OrtEnableSequentialExecution(IntPtr /*(OrtSessionOptions*)*/ options); + public static extern IntPtr /*(OrtStatus*)*/ OrtEnableSequentialExecution(IntPtr /*(OrtSessionOptions*)*/ options); [DllImport(nativeLib, CharSet = charSet)] - public static extern void OrtDisableSequentialExecution(IntPtr /*(OrtSessionOptions*)*/ options); + public static extern IntPtr /*(OrtStatus*)*/ OrtDisableSequentialExecution(IntPtr /*(OrtSessionOptions*)*/ options); [DllImport(nativeLib, CharSet = charSet)] - public static extern void OrtEnableProfiling(IntPtr /* OrtSessionOptions* */ options, string profilePathPrefix); + public static extern IntPtr /*(OrtStatus*)*/ OrtEnableProfiling(IntPtr /* OrtSessionOptions* */ options, string profilePathPrefix); [DllImport(nativeLib, CharSet = charSet)] - public static extern void OrtDisableProfiling(IntPtr /* OrtSessionOptions* */ options); + public static extern IntPtr /*(OrtStatus*)*/ OrtDisableProfiling(IntPtr /* OrtSessionOptions* */ options); [DllImport(nativeLib, CharSet = charSet)] - public static extern void OrtEnableMemPattern(IntPtr /* OrtSessionOptions* */ options); + public static extern IntPtr /*(OrtStatus*)*/ OrtEnableMemPattern(IntPtr /* OrtSessionOptions* */ options); [DllImport(nativeLib, CharSet = charSet)] - public static extern void OrtDisableMemPattern(IntPtr /* OrtSessionOptions* */ options); + public static extern IntPtr /*(OrtStatus*)*/ OrtDisableMemPattern(IntPtr /* OrtSessionOptions* */ options); [DllImport(nativeLib, CharSet = charSet)] - public static extern void OrtEnableCpuMemArena(IntPtr /* OrtSessionOptions* */ options); + public static extern IntPtr /*(OrtStatus*)*/ OrtEnableCpuMemArena(IntPtr /* OrtSessionOptions* */ options); [DllImport(nativeLib, CharSet = charSet)] - public static extern void OrtDisableCpuMemArena(IntPtr /* OrtSessionOptions* */ options); + public static extern IntPtr /*(OrtStatus*)*/ OrtDisableCpuMemArena(IntPtr /* OrtSessionOptions* */ options); [DllImport(nativeLib, CharSet = charSet)] - public static extern void OrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, string logId); + public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, string logId); [DllImport(nativeLib, CharSet = charSet)] - public static extern void OrtSetSessionLogVerbosityLevel(IntPtr /* OrtSessionOptions* */ options, LogLevel sessionLogVerbosityLevel); + public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionLogVerbosityLevel(IntPtr /* OrtSessionOptions* */ options, LogLevel sessionLogVerbosityLevel); [DllImport(nativeLib, CharSet = charSet)] - public static extern int OrtSetSessionThreadPoolSize(IntPtr /* OrtSessionOptions* */ options, int sessionThreadPoolSize); + public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionThreadPoolSize(IntPtr /* OrtSessionOptions* */ options, int sessionThreadPoolSize); [DllImport(nativeLib, CharSet = charSet)] - public static extern int OrtSetSessionGraphOptimizationLevel(IntPtr /* OrtSessionOptions* */ options, uint graphOptimizationLevel); + public static extern IntPtr /*(OrtStatus*)*/ OrtSetSessionGraphOptimizationLevel(IntPtr /* OrtSessionOptions* */ options, uint graphOptimizationLevel); ///** diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs index c5d36c6a05..b4157a616a 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs @@ -21,7 +21,7 @@ namespace Microsoft.ML.OnnxRuntime /// public SessionOptions() { - _nativePtr = NativeMethods.OrtCreateSessionOptions(); + NativeMethods.OrtCreateSessionOptions(out _nativePtr); } /// @@ -32,11 +32,9 @@ namespace Microsoft.ML.OnnxRuntime /// 0 -> Disable all optimizations /// 1 -> Enable basic optimizations /// 2 -> Enable all optimizations - /// True on success and false otherwise - public bool SetSessionGraphOptimizationLevel(uint optimization_level) + public void SetSessionGraphOptimizationLevel(uint optimization_level) { - var result = NativeMethods.OrtSetSessionGraphOptimizationLevel(_nativePtr, optimization_level); - return result == 0; + NativeApiStatus.VerifySuccess(NativeMethods.OrtSetSessionGraphOptimizationLevel(_nativePtr, optimization_level)); } /// @@ -45,7 +43,7 @@ namespace Microsoft.ML.OnnxRuntime /// public void EnableSequentialExecution() { - NativeMethods.OrtEnableSequentialExecution(_nativePtr); + NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableSequentialExecution(_nativePtr)); } /// @@ -54,7 +52,7 @@ namespace Microsoft.ML.OnnxRuntime /// public void DisableSequentialExecution() { - NativeMethods.OrtDisableSequentialExecution(_nativePtr); + NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableSequentialExecution(_nativePtr)); } /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index f0d4362b0a..1be67b5fa1 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -61,7 +61,7 @@ namespace Microsoft.ML.OnnxRuntime.Tests // Set the graph optimization level for this session. SessionOptions options = new SessionOptions(); - Assert.True(options.SetSessionGraphOptimizationLevel(graphOptimizationLevel)); + options.SetSessionGraphOptimizationLevel(graphOptimizationLevel); if(disableSequentialExecution) options.DisableSequentialExecution(); using (var session = new InferenceSession(modelPath, options)) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 6f1b723882..bfdc6b47f1 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -199,48 +199,46 @@ ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess, /** * \return A pointer of the newly created object. The pointer should be freed by OrtReleaseSessionOptions after use */ -ORT_API(OrtSessionOptions*, OrtCreateSessionOptions); +ORT_API_STATUS(OrtCreateSessionOptions, _Out_ OrtSessionOptions** output); // create a copy of an existing OrtSessionOptions -ORT_API(OrtSessionOptions*, OrtCloneSessionOptions, OrtSessionOptions*); -ORT_API(void, OrtEnableSequentialExecution, _In_ OrtSessionOptions* options); -ORT_API(void, OrtDisableSequentialExecution, _In_ OrtSessionOptions* options); +ORT_API_STATUS(OrtCloneSessionOptions, _In_ OrtSessionOptions* in, _Out_ OrtSessionOptions** output); +ORT_API_STATUS(OrtEnableSequentialExecution, _In_ OrtSessionOptions* options); +ORT_API_STATUS(OrtDisableSequentialExecution, _In_ OrtSessionOptions* options); // Enable profiling for this session. -ORT_API(void, OrtEnableProfiling, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix); -ORT_API(void, OrtDisableProfiling, _In_ OrtSessionOptions* options); +ORT_API_STATUS(OrtEnableProfiling, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix); +ORT_API_STATUS(OrtDisableProfiling, _In_ OrtSessionOptions* options); // Enable the memory pattern optimization. // The idea is if the input shapes are the same, we could trace the internal memory allocation // and generate a memory pattern for future request. So next time we could just do one allocation // with a big chunk for all the internal memory allocation. // Note: memory pattern optimization is only available when SequentialExecution enabled. -ORT_API(void, OrtEnableMemPattern, _In_ OrtSessionOptions* options); -ORT_API(void, OrtDisableMemPattern, _In_ OrtSessionOptions* options); +ORT_API_STATUS(OrtEnableMemPattern, _In_ OrtSessionOptions* options); +ORT_API_STATUS(OrtDisableMemPattern, _In_ OrtSessionOptions* options); // Enable the memory arena on CPU // Arena may pre-allocate memory for future usage. // set this option to false if you don't want it. -ORT_API(void, OrtEnableCpuMemArena, _In_ OrtSessionOptions* options); -ORT_API(void, OrtDisableCpuMemArena, _In_ OrtSessionOptions* options); +ORT_API_STATUS(OrtEnableCpuMemArena, _In_ OrtSessionOptions* options); +ORT_API_STATUS(OrtDisableCpuMemArena, _In_ OrtSessionOptions* options); // < logger id to use for session output -ORT_API(void, OrtSetSessionLogId, _In_ OrtSessionOptions* options, const char* logid); +ORT_API_STATUS(OrtSetSessionLogId, _In_ OrtSessionOptions* options, const char* logid); // < applies to session load, initialization, etc -ORT_API(void, OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, uint32_t session_log_verbosity_level); +ORT_API_STATUS(OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, uint32_t session_log_verbosity_level); // Set Graph optimization level. -// Return 0 on success and -1 otherwise // Available options are : 0, 1, 2. // 0 -> Disable all optimizations // 1 -> Enable basic optimizations // 2 -> Enable all optimizations -ORT_API(int, OrtSetSessionGraphOptimizationLevel, _In_ OrtSessionOptions* options, uint32_t graph_optimization_level); +ORT_API_STATUS(OrtSetSessionGraphOptimizationLevel, _In_ OrtSessionOptions* options, uint32_t graph_optimization_level); // How many threads in the session thread pool. -// Returns 0 on success, and -1 otherwise -ORT_API(int, OrtSetSessionThreadPoolSize, _In_ OrtSessionOptions* options, int session_thread_pool_size); +ORT_API_STATUS(OrtSetSessionThreadPoolSize, _In_ OrtSessionOptions* options, int session_thread_pool_size); /** * To use additional providers, you must build ORT with the extra providers enabled. Then call one of these diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 6a6dc758c1..63f91f0c78 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -107,67 +107,68 @@ inline RunOptions& RunOptions::SetTerminate(bool flag) { return *this; } -inline SessionOptions::SessionOptions() : Base{OrtCreateSessionOptions()} { +inline SessionOptions::SessionOptions() { + ORT_THROW_ON_ERROR(OrtCreateSessionOptions(&p_)); } inline SessionOptions SessionOptions::Clone() const { - return SessionOptions{OrtCloneSessionOptions(p_)}; + OrtSessionOptions* out; + ORT_THROW_ON_ERROR(OrtCloneSessionOptions(p_, &out)); + return SessionOptions{out}; } inline SessionOptions& SessionOptions::SetThreadPoolSize(int session_thread_pool_size) { - if (OrtSetSessionThreadPoolSize(p_, session_thread_pool_size) == -1) - throw Exception("Error calling SessionOptions::SetThreadPoolSize", ORT_FAIL); + ORT_THROW_ON_ERROR(OrtSetSessionThreadPoolSize(p_, session_thread_pool_size)); return *this; } inline SessionOptions& SessionOptions::SetGraphOptimizationLevel(uint32_t graph_optimization_level) { - if (OrtSetSessionGraphOptimizationLevel(p_, graph_optimization_level) == -1) - throw Exception("Error calling SessionOptions::SetGraphOptimizationLevel", ORT_FAIL); + ORT_THROW_ON_ERROR(OrtSetSessionGraphOptimizationLevel(p_, graph_optimization_level)); return *this; } inline SessionOptions& SessionOptions::EnableProfiling(const ORTCHAR_T* profile_file_prefix) { - OrtEnableProfiling(p_, profile_file_prefix); + ORT_THROW_ON_ERROR(OrtEnableProfiling(p_, profile_file_prefix)); return *this; } inline SessionOptions& SessionOptions::DisableProfiling() { - OrtDisableProfiling(p_); + ORT_THROW_ON_ERROR(OrtDisableProfiling(p_)); return *this; } inline SessionOptions& SessionOptions::EnableMemPattern() { - OrtEnableMemPattern(p_); + ORT_THROW_ON_ERROR(OrtEnableMemPattern(p_)); return *this; } inline SessionOptions& SessionOptions::DisableMemPattern() { - OrtDisableMemPattern(p_); + ORT_THROW_ON_ERROR(OrtDisableMemPattern(p_)); return *this; } inline SessionOptions& SessionOptions::EnableCpuMemArena() { - OrtEnableCpuMemArena(p_); + ORT_THROW_ON_ERROR(OrtEnableCpuMemArena(p_)); return *this; } inline SessionOptions& SessionOptions::DisableCpuMemArena() { - OrtDisableCpuMemArena(p_); + ORT_THROW_ON_ERROR(OrtDisableCpuMemArena(p_)); return *this; } inline SessionOptions& SessionOptions::EnableSequentialExecution() { - OrtEnableSequentialExecution(p_); + ORT_THROW_ON_ERROR(OrtEnableSequentialExecution(p_)); return *this; } inline SessionOptions& SessionOptions::DisableSequentialExecution() { - OrtDisableSequentialExecution(p_); + ORT_THROW_ON_ERROR(OrtDisableSequentialExecution(p_)); return *this; } inline SessionOptions& SessionOptions::SetLogId(const char* logid) { - OrtSetSessionLogId(p_, logid); + ORT_THROW_ON_ERROR(OrtSetSessionLogId(p_, logid)); return *this; } inline SessionOptions& SessionOptions::Add(OrtCustomOpDomain* custom_op_domain) { diff --git a/onnxruntime/core/framework/error_code_helper.h b/onnxruntime/core/framework/error_code_helper.h index 90c02f5f7b..73c6706a4f 100644 --- a/onnxruntime/core/framework/error_code_helper.h +++ b/onnxruntime/core/framework/error_code_helper.h @@ -8,3 +8,13 @@ namespace onnxruntime { OrtStatus* ToOrtStatus(const onnxruntime::common::Status& st); }; + +#define API_IMPL_BEGIN try { +#define API_IMPL_END \ + } \ + catch (const onnxruntime::NotImplementedException& ex) { \ + return OrtCreateStatus(ORT_NOT_IMPLEMENTED, ex.what()); \ + } \ + catch (const std::exception& ex) { \ + return OrtCreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \ + } diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index a283fad921..aeaab0b248 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/session/onnxruntime_c_api.h" +#include "core/framework/error_code_helper.h" #include #include #include "core/session/inference_session.h" @@ -16,86 +17,97 @@ OrtSessionOptions::OrtSessionOptions(const OrtSessionOptions& other) : value(other.value), provider_factories(other.provider_factories) { } -ORT_API(OrtSessionOptions*, OrtCreateSessionOptions) { - std::unique_ptr options = std::make_unique(); - return options.release(); +ORT_API_STATUS_IMPL(OrtCreateSessionOptions, OrtSessionOptions** out) { + API_IMPL_BEGIN + *out = new OrtSessionOptions(); + return nullptr; + API_IMPL_END } ORT_API(void, OrtReleaseSessionOptions, OrtSessionOptions* ptr) { delete ptr; } -ORT_API(OrtSessionOptions*, OrtCloneSessionOptions, OrtSessionOptions* input) { - try { - return new OrtSessionOptions(*input); - } catch (std::exception&) { - return nullptr; - } +ORT_API_STATUS_IMPL(OrtCloneSessionOptions, OrtSessionOptions* input, OrtSessionOptions** out) { + API_IMPL_BEGIN + *out = new OrtSessionOptions(*input); + return nullptr; + API_IMPL_END } -ORT_API(void, OrtEnableSequentialExecution, _In_ OrtSessionOptions* options) { +ORT_API_STATUS_IMPL(OrtEnableSequentialExecution, _In_ OrtSessionOptions* options) { options->value.enable_sequential_execution = true; + return nullptr; } -ORT_API(void, OrtDisableSequentialExecution, _In_ OrtSessionOptions* options) { +ORT_API_STATUS_IMPL(OrtDisableSequentialExecution, _In_ OrtSessionOptions* options) { options->value.enable_sequential_execution = false; + return nullptr; } // enable profiling for this session. -ORT_API(void, OrtEnableProfiling, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix) { +ORT_API_STATUS_IMPL(OrtEnableProfiling, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix) { options->value.enable_profiling = true; options->value.profile_file_prefix = profile_file_prefix; + return nullptr; } -ORT_API(void, OrtDisableProfiling, _In_ OrtSessionOptions* options) { +ORT_API_STATUS_IMPL(OrtDisableProfiling, _In_ OrtSessionOptions* options) { options->value.enable_profiling = false; options->value.profile_file_prefix.clear(); + return nullptr; } // enable the memory pattern optimization. // The idea is if the input shapes are the same, we could trace the internal memory allocation // and generate a memory pattern for future request. So next time we could just do one allocation // with a big chunk for all the internal memory allocation. -ORT_API(void, OrtEnableMemPattern, _In_ OrtSessionOptions* options) { +ORT_API_STATUS_IMPL(OrtEnableMemPattern, _In_ OrtSessionOptions* options) { options->value.enable_mem_pattern = true; + return nullptr; } -ORT_API(void, OrtDisableMemPattern, _In_ OrtSessionOptions* options) { +ORT_API_STATUS_IMPL(OrtDisableMemPattern, _In_ OrtSessionOptions* options) { options->value.enable_mem_pattern = false; + return nullptr; } // enable the memory arena on CPU // Arena may pre-allocate memory for future usage. // set this option to false if you don't want it. -ORT_API(void, OrtEnableCpuMemArena, _In_ OrtSessionOptions* options) { +ORT_API_STATUS_IMPL(OrtEnableCpuMemArena, _In_ OrtSessionOptions* options) { options->value.enable_cpu_mem_arena = true; + return nullptr; } -ORT_API(void, OrtDisableCpuMemArena, _In_ OrtSessionOptions* options) { +ORT_API_STATUS_IMPL(OrtDisableCpuMemArena, _In_ OrtSessionOptions* options) { options->value.enable_cpu_mem_arena = false; + return nullptr; } ///< logger id to use for session output -ORT_API(void, OrtSetSessionLogId, _In_ OrtSessionOptions* options, const char* logid) { +ORT_API_STATUS_IMPL(OrtSetSessionLogId, _In_ OrtSessionOptions* options, const char* logid) { options->value.session_logid = logid; + return nullptr; } ///< applies to session load, initialization, etc -ORT_API(void, OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, uint32_t session_log_verbosity_level) { +ORT_API_STATUS_IMPL(OrtSetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, uint32_t session_log_verbosity_level) { options->value.session_log_verbosity_level = session_log_verbosity_level; + return nullptr; } // Set Graph optimization level. // Returns 0 on success and -1 otherwise // Available options are : 0, 1, 2. -ORT_API(int, OrtSetSessionGraphOptimizationLevel, _In_ OrtSessionOptions* options, uint32_t graph_optimization_level) { - if (graph_optimization_level >= static_cast(onnxruntime::TransformerLevel::MaxTransformerLevel)) { - return -1; - } +ORT_API_STATUS_IMPL(OrtSetSessionGraphOptimizationLevel, _In_ OrtSessionOptions* options, uint32_t graph_optimization_level) { + if (graph_optimization_level >= static_cast(onnxruntime::TransformerLevel::MaxTransformerLevel)) + return OrtCreateStatus(ORT_INVALID_ARGUMENT, "graph_optimization_level is not valid"); options->value.graph_optimization_level = static_cast(graph_optimization_level); - return 0; + return nullptr; } ///How many threads in the session thread pool. -ORT_API(int, OrtSetSessionThreadPoolSize, _In_ OrtSessionOptions* options, int session_thread_pool_size) { - if (session_thread_pool_size <= 0) return -1; +ORT_API_STATUS_IMPL(OrtSetSessionThreadPoolSize, _In_ OrtSessionOptions* options, int session_thread_pool_size) { + if (session_thread_pool_size <= 0) + return OrtCreateStatus(ORT_INVALID_ARGUMENT, "session_thread_pool_size is not valid"); options->value.session_thread_pool_size = session_thread_pool_size; - return 0; + return nullptr; } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index a6eeb5f369..b9596180bd 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -62,13 +62,6 @@ struct OrtEnv { ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtEnv); }; -#define API_IMPL_BEGIN try { -#define API_IMPL_END \ - } \ - catch (std::exception & ex) { \ - return OrtCreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \ - } - #define TENSOR_READ_API_BEGIN \ API_IMPL_BEGIN \ auto v = reinterpret_cast(value); \ diff --git a/onnxruntime/test/shared_lib/fns_candy_style_transfer.c b/onnxruntime/test/shared_lib/fns_candy_style_transfer.c index 5aa2f74c3f..a22a1a449f 100644 --- a/onnxruntime/test/shared_lib/fns_candy_style_transfer.c +++ b/onnxruntime/test/shared_lib/fns_candy_style_transfer.c @@ -181,8 +181,8 @@ void verify_input_output_count(OrtSession* session) { } #ifdef USE_CUDA -void enable_cuda(OrtSessionOptions* session_option) { - ORT_ABORT_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(session_option, 0)); +void enable_cuda(OrtSessionOptions* session_options) { + ORT_ABORT_ON_ERROR(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); } #endif @@ -196,15 +196,16 @@ int main(int argc, char* argv[]) { char* output_file = argv[3]; OrtEnv* env; ORT_ABORT_ON_ERROR(OrtCreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &env)); - OrtSessionOptions* session_option = OrtCreateSessionOptions(); + OrtSessionOptions* session_options; + ORT_ABORT_ON_ERROR(OrtCreateSessionOptions(&session_options)); #ifdef USE_CUDA - enable_cuda(session_option); + enable_cuda(session_options); #endif OrtSession* session; - ORT_ABORT_ON_ERROR(OrtCreateSession(env, model_path, session_option, &session)); + ORT_ABORT_ON_ERROR(OrtCreateSession(env, model_path, session_options, &session)); verify_input_output_count(session); int ret = run_inference(session, input_file, output_file); - OrtReleaseSessionOptions(session_option); + OrtReleaseSessionOptions(session_options); OrtReleaseSession(session); OrtReleaseEnv(env); if (ret != 0) { diff --git a/onnxruntime/test/shared_lib/test_session_options.cc b/onnxruntime/test/shared_lib/test_session_options.cc index 597a96aaa8..60fbc6b819 100644 --- a/onnxruntime/test/shared_lib/test_session_options.cc +++ b/onnxruntime/test/shared_lib/test_session_options.cc @@ -14,6 +14,10 @@ TEST_F(CApiTest, session_options_graph_optimization_level) { options.SetGraphOptimizationLevel(valid_optimization_level); // Test set optimization level fails when invalid level is provided. - uint32_t invalid_level = static_cast(TransformerLevel::MaxTransformerLevel); - ASSERT_EQ(OrtSetSessionGraphOptimizationLevel(options, invalid_level), -1); + try { + uint32_t invalid_level = static_cast(TransformerLevel::MaxTransformerLevel); + options.SetGraphOptimizationLevel(invalid_level); + } catch (const Ort::Exception& e) { + ASSERT_EQ(e.GetOrtErrorCode(), ORT_INVALID_ARGUMENT); + } }