diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 47bf9bc044..38bd56b5d0 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -93,6 +93,20 @@ export declare namespace InferenceSession { */ executionMode?: 'sequential'|'parallel'; + /** + * Wether enable profiling. + * + * This setting is a placeholder for a future use. + */ + enableProfiling?: boolean; + + /** + * File prefix for profiling. + * + * This setting is a placeholder for a future use. + */ + profileFilePrefix?: string; + /** * Log ID. * @@ -107,6 +121,36 @@ export declare namespace InferenceSession { * This setting is available only in ONNXRuntime (Node.js binding and react-native) or WebAssembly backend */ logSeverityLevel?: 0|1|2|3|4; + + /** + * Log verbosity level. + * + * This setting is available only in WebAssembly backend. Will support Node.js binding and react-native later + */ + logVerbosityLevel?: number; + + /** + * Store configurations for a session. See + * https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/ + * onnxruntime_session_options_config_keys.h + * + * In example, + * + * ```js + * extra: { + * session: { + * set_denormal_as_zero: "1", + * disable_prepacking: "1" + * }, + * optimization: { + * enable_gelu_approximation: "1" + * } + * } + * ``` + * + * This setting is available only in WebAssembly backend. Will support Node.js binding and react-native later + */ + extra?: Record; } //#region execution providers @@ -163,12 +207,45 @@ export declare namespace InferenceSession { */ logSeverityLevel?: 0|1|2|3|4; + /** + * Log verbosity level. + * + * This setting is available only in WebAssembly backend. Will support Node.js binding and react-native later + */ + logVerbosityLevel?: number; + + /** + * Terminate all incomplete OrtRun calls as soon as possible if true + * + * This setting is available only in WebAssembly backend. Will support Node.js binding and react-native later + */ + terminate?: boolean; + /** * A tag for the Run() calls using this * * This setting is available only in ONNXRuntime (Node.js binding and react-native) or WebAssembly backend */ tag?: string; + + /** + * Set a single run configuration entry. See + * https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/ + * onnxruntime_run_options_config_keys.h + * + * In example, + * + * ```js + * extra: { + * memory: { + * enable_memory_arena_shrinkage: "1", + * } + * } + * ``` + * + * This setting is available only in WebAssembly backend. Will support Node.js binding and react-native later + */ + extra?: Record; } //#endregion diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index c8173b38e8..9c579530fd 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -33,21 +33,16 @@ export interface OrtWasmModule extends EmscriptenModule { sessionHandle: number, inputNamesOffset: number, inputsOffset: number, inputCount: number, outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): number; - _OrtCreateSessionOptions(): number; + _OrtCreateSessionOptions( + graphOptimizationLevel: number, enableCpuMemArena: boolean, enableMemPattern: boolean, executionMode: number, + enableProfiling: boolean, profileFilePrefix: number, logId: number, logSeverityLevel: number, + logVerbosityLevel: number): number; + _OrtAddSessionConfigEntry(sessionOptionsHandle: number, configKey: number, configValue: number): number; _OrtReleaseSessionOptions(sessionOptionsHandle: number): void; - _OrtSetSessionGraphOptimizationLevel(sessionOptionsHandle: number, level: number): number; - _OrtEnableCpuMemArena(sessionOptionsHandle: number): number; - _OrtDisableCpuMemArena(sessionOptionsHandle: number): number; - _OrtEnableMemPattern(sessionOptionsHandle: number): number; - _OrtDisableMemPattern(sessionOptionsHandle: number): number; - _OrtSetSessionExecutionMode(sessionOptionsHandle: number, mode: number): number; - _OrtSetSessionLogId(sessionOptionsHandle: number, logid: number): number; - _OrtSetSessionLogSeverityLevel(sessionOptionsHandle: number, level: number): number; - _OrtCreateRunOptions(): number; + _OrtCreateRunOptions(logSeverityLevel: number, logVerbosityLevel: number, terminate: boolean, tag: number): number; + _OrtAddRunConfigEntry(runOptionsHandle: number, configKey: number, configValue: number): number; _OrtReleaseRunOptions(runOptionsHandle: number): void; - _OrtRunOptionsSetRunLogSeverityLevel(runOptionsHandle: number, level: number): number; - _OrtRunOptionsSetRunTag(runOptionsHandle: number, tag: number): number; //#endregion //#region config diff --git a/js/web/lib/wasm/options-utils.ts b/js/web/lib/wasm/options-utils.ts new file mode 100644 index 0000000000..2a123921c4 --- /dev/null +++ b/js/web/lib/wasm/options-utils.ts @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {getInstance} from './wasm-factory'; + +interface ExtraOptionsHandler { + (name: string, value: string): void; +} + +export const iterateExtraOptions = + (options: Record, prefix: string, seen: WeakSet>, + handler: ExtraOptionsHandler): void => { + if (typeof options == 'object' && options !== null) { + if (seen.has(options)) { + throw new Error('Circular reference in options'); + } else { + seen.add(options); + } + } + + Object.entries(options).forEach(([key, value]) => { + const name = (prefix) ? prefix + key : key; + if (typeof value === 'object') { + iterateExtraOptions(value as Record, name + '.', seen, handler); + } else if (typeof value === 'string' || typeof value === 'number') { + handler(name, value.toString()); + } else if (typeof value === 'boolean') { + handler(name, (value) ? '1' : '0'); + } else { + throw new Error(`Can't handle extra config type: ${typeof value}`); + } + }); + }; + +export const allocWasmString = (data: string, allocs: number[]): number => { + const wasm = getInstance(); + + const dataLength = wasm.lengthBytesUTF8(data) + 1; + const dataOffset = wasm._malloc(dataLength); + wasm.stringToUTF8(data, dataOffset, dataLength); + allocs.push(dataOffset); + + return dataOffset; +}; diff --git a/js/web/lib/wasm/run-options.ts b/js/web/lib/wasm/run-options.ts new file mode 100644 index 0000000000..365f79b3c4 --- /dev/null +++ b/js/web/lib/wasm/run-options.ts @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {InferenceSession} from 'onnxruntime-common'; + +import {allocWasmString, iterateExtraOptions} from './options-utils'; +import {getInstance} from './wasm-factory'; + +export const setRunOptions = (options: InferenceSession.RunOptions): [number, number[]] => { + const wasm = getInstance(); + let runOptionsHandle = 0; + const allocs: number[] = []; + + const runOptions: InferenceSession.RunOptions = options || {}; + + try { + if (options?.logSeverityLevel === undefined) { + runOptions.logSeverityLevel = 2; // Default to warning + } else if ( + typeof options.logSeverityLevel !== 'number' || !Number.isInteger(options.logSeverityLevel) || + options.logSeverityLevel < 0 || options.logSeverityLevel > 4) { + throw new Error(`log serverity level is not valid: ${options.logSeverityLevel}`); + } + + if (options?.logVerbosityLevel === undefined) { + runOptions.logVerbosityLevel = 0; // Default to 0 + } else if (typeof options.logVerbosityLevel !== 'number' || !Number.isInteger(options.logVerbosityLevel)) { + throw new Error(`log verbosity level is not valid: ${options.logVerbosityLevel}`); + } + + if (options?.terminate === undefined) { + runOptions.terminate = false; + } + + let tagDataOffset = 0; + if (options?.tag !== undefined) { + tagDataOffset = allocWasmString(options.tag, allocs); + } + + runOptionsHandle = wasm._OrtCreateRunOptions( + runOptions.logSeverityLevel!, runOptions.logVerbosityLevel!, !!runOptions.terminate!, tagDataOffset); + if (runOptionsHandle === 0) { + throw new Error('Can\'t create run options'); + } + + if (options?.extra !== undefined) { + iterateExtraOptions(options.extra, '', new WeakSet>(), (key, value) => { + const keyDataOffset = allocWasmString(key, allocs); + const valueDataOffset = allocWasmString(value, allocs); + + if (wasm._OrtAddRunConfigEntry(runOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + throw new Error(`Can't set a run config entry: ${key} - ${value}`); + } + }); + } + + return [runOptionsHandle, allocs]; + } catch (e) { + if (runOptionsHandle !== 0) { + wasm._OrtReleaseRunOptions(runOptionsHandle); + } + allocs.forEach(wasm._free); + throw e; + } +}; diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts index dbed4b5dfc..f8726c3796 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler.ts @@ -3,6 +3,9 @@ import {onnx} from 'onnx-proto'; import {env, InferenceSession, SessionHandler, Tensor, TypedTensor} from 'onnxruntime-common'; + +import {setRunOptions} from './run-options'; +import {setSessionOptions} from './session-options'; import {getInstance} from './wasm-factory'; let ortInit: boolean; @@ -119,133 +122,6 @@ const getLogLevel = (logLevel: 'verbose'|'info'|'warning'|'error'|'fatal'): numb } }; -const setSessionOptions = (options?: InferenceSession.SessionOptions): [number, number[]] => { - const wasm = getInstance(); - const sessionOptionsHandle = wasm._OrtCreateSessionOptions(); - const allocs: number[] = []; - - if (sessionOptionsHandle === 0) { - throw new Error('Can\'t create session options'); - } - - if (options === undefined) { - return [sessionOptionsHandle, allocs]; - } - - let errorCode = 0; - - if (options.graphOptimizationLevel !== undefined) { - switch (options.graphOptimizationLevel) { - case 'disabled': - errorCode = wasm._OrtSetSessionGraphOptimizationLevel(sessionOptionsHandle, 0); - break; - case 'basic': - errorCode = wasm._OrtSetSessionGraphOptimizationLevel(sessionOptionsHandle, 1); - break; - case 'extended': - errorCode = wasm._OrtSetSessionGraphOptimizationLevel(sessionOptionsHandle, 2); - break; - case 'all': - errorCode = wasm._OrtSetSessionGraphOptimizationLevel(sessionOptionsHandle, 99); - break; - default: - throw new Error(`unsupported graph optimization level: ${options.graphOptimizationLevel}`); - } - if (errorCode !== 0) { - throw new Error(`Can't set a graph optimization level as a session option. error code = ${errorCode}`); - } - } - - if (options.enableCpuMemArena !== undefined) { - if (options.enableCpuMemArena) { - errorCode = wasm._OrtEnableCpuMemArena(sessionOptionsHandle); - } else { - errorCode = wasm._OrtDisableCpuMemArena(sessionOptionsHandle); - } - if (errorCode !== 0) { - throw new Error(`Can't set a CPU memory arena as a session option. error code = ${errorCode}`); - } - } - - if (options.enableMemPattern !== undefined) { - if (options.enableMemPattern) { - errorCode = wasm._OrtEnableMemPattern(sessionOptionsHandle); - } else { - errorCode = wasm._OrtDisableMemPattern(sessionOptionsHandle); - } - if (errorCode !== 0) { - throw new Error(`Can't set a memory pattern as a session option. error code = ${errorCode}`); - } - } - - if (options.executionMode !== undefined) { - switch (options.executionMode) { - case 'sequential': - errorCode = wasm._OrtSetSessionExecutionMode(sessionOptionsHandle, 0); - break; - case 'parallel': - errorCode = wasm._OrtSetSessionExecutionMode(sessionOptionsHandle, 1); - break; - default: - throw new Error(`unsupported execution mode: ${options.executionMode}`); - } - if (errorCode !== 0) { - throw new Error(`Can't set an execution mode as a session option. error code = ${errorCode}`); - } - } - - if (options.logId !== undefined) { - const logIdDataLength = wasm.lengthBytesUTF8(options.logId) + 1; - const logIdDataOffset = wasm._malloc(logIdDataLength); - wasm.stringToUTF8(options.logId, logIdDataOffset, logIdDataLength); - errorCode = wasm._OrtSetSessionLogId(sessionOptionsHandle, logIdDataOffset); - allocs.push(logIdDataOffset); - if (errorCode !== 0) { - throw new Error(`Can't set a log id as a session option. error code = ${errorCode}`); - } - } - - if (options.logSeverityLevel !== undefined) { - errorCode = wasm._OrtSetSessionLogSeverityLevel(sessionOptionsHandle, options.logSeverityLevel); - if (errorCode !== 0) { - throw new Error(`Can't set a log severity level as a session option. error code = ${errorCode}`); - } - } - - return [sessionOptionsHandle, allocs]; -}; - -const setRunOptions = (options: InferenceSession.RunOptions): [number, number[]] => { - const wasm = getInstance(); - const runOptionsHandle = wasm._OrtCreateRunOptions(); - if (runOptionsHandle === 0) { - throw new Error('Can\'t create run options'); - } - - const allocs: number[] = []; - let errorCode = 0; - - if (options.logSeverityLevel !== undefined) { - errorCode = wasm._OrtRunOptionsSetRunLogSeverityLevel(runOptionsHandle, options.logSeverityLevel); - if (errorCode !== 0) { - throw new Error(`Can't set a log severity level as a run option. error code = ${errorCode}`); - } - } - - if (options.tag !== undefined) { - const tagDataLength = wasm.lengthBytesUTF8(options.tag) + 1; - const tagDataOffset = wasm._malloc(tagDataLength); - wasm.stringToUTF8(options.tag, tagDataOffset, tagDataLength); - errorCode = wasm._OrtRunOptionsSetRunTag(runOptionsHandle, tagDataOffset); - allocs.push(tagDataOffset); - if (errorCode !== 0) { - throw new Error(`Can't set a tag as a run option. error code = ${errorCode}`); - } - } - - return [runOptionsHandle, allocs]; -}; - export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler { private sessionHandle: number; diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts new file mode 100644 index 0000000000..9199a85e7c --- /dev/null +++ b/js/web/lib/wasm/session-options.ts @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {InferenceSession} from 'onnxruntime-common'; + +import {allocWasmString, iterateExtraOptions} from './options-utils'; +import {getInstance} from './wasm-factory'; + +const getGraphOptimzationLevel = (graphOptimizationLevel: string|unknown): number => { + switch (graphOptimizationLevel) { + case 'disabled': + return 0; + case 'basic': + return 1; + case 'extended': + return 2; + case 'all': + return 99; + default: + throw new Error(`unsupported graph optimization level: ${graphOptimizationLevel}`); + } +}; + +const getExecutionMode = (executionMode: 'sequential'|'parallel'): number => { + switch (executionMode) { + case 'sequential': + return 0; + case 'parallel': + return 1; + default: + throw new Error(`unsupported execution mode: ${executionMode}`); + } +}; + +export const setSessionOptions = (options?: InferenceSession.SessionOptions): [number, number[]] => { + const wasm = getInstance(); + let sessionOptionsHandle = 0; + const allocs: number[] = []; + + const sessionOptions: InferenceSession.SessionOptions = options || {}; + + try { + if (options?.graphOptimizationLevel === undefined) { + sessionOptions.graphOptimizationLevel = 'all'; + } + const graphOptimizationLevel = getGraphOptimzationLevel(sessionOptions.graphOptimizationLevel!); + + if (options?.enableCpuMemArena === undefined) { + sessionOptions.enableCpuMemArena = true; + } + + if (options?.enableMemPattern === undefined) { + sessionOptions.enableMemPattern = true; + } + + if (options?.executionMode === undefined) { + sessionOptions.executionMode = 'sequential'; + } + const executionMode = getExecutionMode(sessionOptions.executionMode!); + + let logIdDataOffset = 0; + if (options?.logId !== undefined) { + logIdDataOffset = allocWasmString(options.logId, allocs); + } + + if (options?.logSeverityLevel === undefined) { + sessionOptions.logSeverityLevel = 2; // Default to warning + } else if ( + typeof options.logSeverityLevel !== 'number' || !Number.isInteger(options.logSeverityLevel) || + options.logSeverityLevel < 0 || options.logSeverityLevel > 4) { + throw new Error(`log serverity level is not valid: ${options.logSeverityLevel}`); + } + + if (options?.logVerbosityLevel === undefined) { + sessionOptions.logVerbosityLevel = 0; // Default to 0 + } else if (typeof options.logVerbosityLevel !== 'number' || !Number.isInteger(options.logVerbosityLevel)) { + throw new Error(`log verbosity level is not valid: ${options.logVerbosityLevel}`); + } + + // TODO: Support profiling + sessionOptions.enableProfiling = false; + + sessionOptionsHandle = wasm._OrtCreateSessionOptions( + graphOptimizationLevel, !!sessionOptions.enableCpuMemArena!, !!sessionOptions.enableMemPattern!, executionMode, + sessionOptions.enableProfiling, 0, logIdDataOffset, sessionOptions.logSeverityLevel!, + sessionOptions.logVerbosityLevel!); + if (sessionOptionsHandle === 0) { + throw new Error('Can\'t create session options'); + } + + if (options?.extra !== undefined) { + iterateExtraOptions(options.extra, '', new WeakSet>(), (key, value) => { + const keyDataOffset = allocWasmString(key, allocs); + const valueDataOffset = allocWasmString(value, allocs); + + if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + throw new Error(`Can't set a session config entry: ${key} - ${value}`); + } + }); + } + + return [sessionOptionsHandle, allocs]; + } catch (e) { + if (sessionOptionsHandle !== 0) { + wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + } + allocs.forEach(wasm._free); + throw e; + } +}; diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 93d7d09f17..c9c490906b 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -60,49 +60,62 @@ int OrtInit(int num_threads, int logging_level) { #endif } -OrtSessionOptions* OrtCreateSessionOptions() { +OrtSessionOptions* OrtCreateSessionOptions(size_t graph_optimization_level, + bool enable_cpu_mem_arena, + bool enable_mem_pattern, + size_t execution_mode, + bool /* enable_profiling */, + const char* /* profile_file_prefix */, + const char* log_id, + size_t log_severity_level, + size_t log_verbosity_level) { OrtSessionOptions* session_options = nullptr; - return (CHECK_STATUS(CreateSessionOptions, &session_options) == ORT_OK) ? session_options : nullptr; + RETURN_NULLPTR_IF_ERROR(CreateSessionOptions, &session_options); + + // assume that a graph optimization level is checked and properly set at JavaScript + RETURN_NULLPTR_IF_ERROR(SetSessionGraphOptimizationLevel, + session_options, + static_cast(graph_optimization_level)); + + if (enable_cpu_mem_arena) { + RETURN_NULLPTR_IF_ERROR(EnableCpuMemArena, session_options); + } else { + RETURN_NULLPTR_IF_ERROR(DisableCpuMemArena, session_options); + } + + if (enable_mem_pattern) { + RETURN_NULLPTR_IF_ERROR(EnableCpuMemArena, session_options); + } else { + RETURN_NULLPTR_IF_ERROR(DisableCpuMemArena, session_options); + } + + // assume that an execution mode is checked and properly set at JavaScript + RETURN_NULLPTR_IF_ERROR(SetSessionExecutionMode, session_options, static_cast(execution_mode)); + + // TODO: support profling + + if (log_id != nullptr) { + RETURN_NULLPTR_IF_ERROR(SetSessionLogId, session_options, log_id); + } + + // assume that a log severity level is checked and properly set at JavaScript + RETURN_NULLPTR_IF_ERROR(SetSessionLogSeverityLevel, session_options, log_severity_level); + + RETURN_NULLPTR_IF_ERROR(SetSessionLogVerbosityLevel, session_options, log_verbosity_level); + + return session_options; +} + +int OrtAddSessionConfigEntry(OrtSessionOptions* session_options, + const char* config_key, + const char* config_value) { + return CHECK_STATUS(AddSessionConfigEntry, session_options, config_key, config_value); } void OrtReleaseSessionOptions(OrtSessionOptions* session_options) { Ort::GetApi().ReleaseSessionOptions(session_options); } -int OrtSetSessionGraphOptimizationLevel(OrtSessionOptions* session_options, size_t level) { - // Assume that a graph optimization level is check and properly set at JavaScript - return CHECK_STATUS(SetSessionGraphOptimizationLevel, session_options, static_cast(level)); -} - -int OrtEnableCpuMemArena(OrtSessionOptions* session_options) { - return CHECK_STATUS(EnableCpuMemArena, session_options); -} - -int OrtDisableCpuMemArena(OrtSessionOptions* session_options) { - return CHECK_STATUS(DisableCpuMemArena, session_options); -} - -int OrtEnableMemPattern(OrtSessionOptions* session_options) { - return CHECK_STATUS(EnableMemPattern, session_options); -} - -int OrtDisableMemPattern(OrtSessionOptions* session_options) { - return CHECK_STATUS(DisableMemPattern, session_options); -} - -int OrtSetSessionExecutionMode(OrtSessionOptions* session_options, size_t mode) { - // Assume that an execution mode is check and properly set at JavaScript - return CHECK_STATUS(SetSessionExecutionMode, session_options, static_cast(mode)); -} - -int OrtSetSessionLogId(OrtSessionOptions* session_options, const char* logid) { - return CHECK_STATUS(SetSessionLogId, session_options, logid); -} - -int OrtSetSessionLogSeverityLevel(OrtSessionOptions* session_options, size_t level) { - return CHECK_STATUS(SetSessionLogSeverityLevel, session_options, level); -} - OrtSession* OrtCreateSession(void* data, size_t data_length, OrtSessionOptions* session_options) { // OrtSessionOptions must not be nullptr. if (session_options == nullptr) { @@ -228,23 +241,41 @@ void OrtReleaseTensor(OrtValue* tensor) { Ort::GetApi().ReleaseValue(tensor); } -OrtRunOptions* OrtCreateRunOptions() { +OrtRunOptions* OrtCreateRunOptions(size_t log_severity_level, + size_t log_verbosity_level, + bool terminate, + const char* tag) { OrtRunOptions* run_options = nullptr; - return (CHECK_STATUS(CreateRunOptions, &run_options) == ORT_OK) ? run_options : nullptr; + RETURN_NULLPTR_IF_ERROR(CreateRunOptions, &run_options); + + // Assume that a logging level is check and properly set at JavaScript + RETURN_NULLPTR_IF_ERROR(RunOptionsSetRunLogSeverityLevel, run_options, log_severity_level); + + RETURN_NULLPTR_IF_ERROR(RunOptionsSetRunLogVerbosityLevel, run_options, log_verbosity_level); + + if (terminate) { + RETURN_NULLPTR_IF_ERROR(RunOptionsSetTerminate, run_options); + } else { + RETURN_NULLPTR_IF_ERROR(RunOptionsUnsetTerminate, run_options); + } + + if (tag != nullptr) { + RETURN_NULLPTR_IF_ERROR(RunOptionsSetRunTag, run_options, tag); + } + + return run_options; +} + +int OrtAddRunConfigEntry(OrtRunOptions* run_options, + const char* config_key, + const char* config_value) { + return CHECK_STATUS(AddRunConfigEntry, run_options, config_key, config_value); } void OrtReleaseRunOptions(OrtRunOptions* run_options) { Ort::GetApi().ReleaseRunOptions(run_options); } -int OrtRunOptionsSetRunLogSeverityLevel(OrtRunOptions* run_options, size_t level) { - return CHECK_STATUS(RunOptionsSetRunLogSeverityLevel, run_options, level); -} - -int OrtRunOptionsSetRunTag(OrtRunOptions* run_options, const char* tag) { - return CHECK_STATUS(RunOptionsSetRunTag, run_options, tag); -} - int OrtRun(OrtSession* session, const char** input_names, const ort_tensor_handle_t* inputs, size_t input_count, const char** output_names, size_t output_count, ort_tensor_handle_t* outputs, diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 4c54cd7d52..457f81de1c 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -28,62 +28,52 @@ extern "C" { /** * perform global initialization. should be called only once. - * @param numThreads number of total threads to use. + * @param num_threads number of total threads to use. * @param logging_level default logging level. */ -int EMSCRIPTEN_KEEPALIVE OrtInit(int numThreads, int logging_level); +int EMSCRIPTEN_KEEPALIVE OrtInit(int num_threads, int logging_level); /** * create an instance of ORT session options. + * assume that all enum type parameters, such as graph_optimization_level, execution_mode, and log_severity_level, + * are checked and set properly at JavaScript. + * @param graph_optimization_level disabled, basic, extended, or enable all + * @param enable_cpu_mem_arena enable or disable cpu memory arena + * @param enable_mem_pattern enable or disable memory pattern + * @param execution_mode sequential or parallel execution mode + * @param enable_profiling enable or disable profiling. it's a no-op and for a future use. + * @param profile_file_prefix file prefix for profiling data. it's a no-op and for a future use. + * @param log_id logger id for session output + * @param log_severity_level verbose, info, warning, error or fatal + * @param log_verbosity_level vlog level * @returns a pointer to a session option handle and must be freed by calling OrtReleaseSessionOptions(). */ -ort_session_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionOptions(); +ort_session_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionOptions(size_t graph_optimization_level, + bool enable_cpu_mem_arena, + bool enable_mem_pattern, + size_t execution_mode, + bool enable_profiling, + const char* profile_file_prefix, + const char* log_id, + size_t log_severity_level, + size_t log_verbosity_level); + +/** + * store configurations for a session. + * @param session_options a handle to session options created by OrtCreateSessionOptions + * @param config_key configuration keys and value formats are defined in + * include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h + * @param config_value value for config_key + */ +int EMSCRIPTEN_KEEPALIVE OrtAddSessionConfigEntry(ort_session_options_handle_t session_options, + const char* config_key, + const char* config_value); /** * release the specified ORT session options. */ void EMSCRIPTEN_KEEPALIVE OrtReleaseSessionOptions(ort_session_options_handle_t session_options); -/** - * set an optimization level for session. - */ -int EMSCRIPTEN_KEEPALIVE OrtSetSessionGraphOptimizationLevel(ort_session_options_handle_t session_options, size_t level); - -/** - * enable CPU memory arena for session. - */ -int EMSCRIPTEN_KEEPALIVE OrtEnableCpuMemArena(ort_session_options_handle_t session_options); - -/** - * disable CPU memory arena for session. - */ -int EMSCRIPTEN_KEEPALIVE OrtDisableCpuMemArena(ort_session_options_handle_t session_options); - -/** - * enable memory pattern for session. - */ -int EMSCRIPTEN_KEEPALIVE OrtEnableMemPattern(ort_session_options_handle_t session_options); - -/** - * disable memory pattern for session. - */ -int EMSCRIPTEN_KEEPALIVE OrtDisableMemPattern(ort_session_options_handle_t session_options); - -/** - * set an execution mode for session. - */ -int EMSCRIPTEN_KEEPALIVE OrtSetSessionExecutionMode(ort_session_options_handle_t session_options, size_t mode); - -/** - * set a log ID for session. - */ -int EMSCRIPTEN_KEEPALIVE OrtSetSessionLogId(ort_session_options_handle_t session_options, const char* logid); - -/** - * set a log severity level for session. - */ -int EMSCRIPTEN_KEEPALIVE OrtSetSessionLogSeverityLevel(ort_session_options_handle_t session_options, size_t level); - /** * create an instance of ORT session. * @param data a pointer to a buffer that contains the ONNX or ORT format model. @@ -154,25 +144,33 @@ void EMSCRIPTEN_KEEPALIVE OrtReleaseTensor(ort_tensor_handle_t tensor); /** * create an instance of ORT run options. + * @param log_severity_level verbose, info, warning, error or fatal + * @param log_verbosity_level vlog level + * @param terminate if true, all incomplete OrtRun calls will exit as soon as possible + * @param tag tag for this run * @returns a pointer to a run option handle and must be freed by calling OrtReleaseRunOptions(). */ -ort_run_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateRunOptions(); +ort_run_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateRunOptions(size_t log_severity_level, + size_t log_verbosity_level, + bool terminate, + const char* tag); + +/** + * set a single run configuration entry + * @param run_options a handle to run options created by OrtCreateRunOptions + * @param config_key configuration keys and value formats are defined in + * include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h + * @param config_value value for config_key + */ +int EMSCRIPTEN_KEEPALIVE OrtAddRunConfigEntry(ort_run_options_handle_t run_options, + const char* config_key, + const char* config_value); /** * release the specified ORT run options. */ void EMSCRIPTEN_KEEPALIVE OrtReleaseRunOptions(ort_run_options_handle_t run_options); -/** - * set log severity level for run. - */ -int EMSCRIPTEN_KEEPALIVE OrtRunOptionsSetRunLogSeverityLevel(ort_run_options_handle_t run_options, size_t level); - -/** - * set a tag for the Run() calls using this. - */ -int EMSCRIPTEN_KEEPALIVE OrtRunOptionsSetRunTag(ort_run_options_handle_t run_options, const char* tag); - /** * inference the model. * @param session handle of the specified session