mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Support additional session options and run options in WebAssembly (#7712)
* add all session options and run options in C API except AddInitializer and AddFreeDimensionOverride * remove unnecessary comment * change extra session and run options to object notation * resolve comments * use an optional chaining for options * resolve comments
This commit is contained in:
parent
6d9f541442
commit
da5f24bd2d
8 changed files with 435 additions and 239 deletions
|
|
@ -93,6 +93,20 @@ export declare namespace InferenceSession {
|
|||
*/
|
||||
executionMode?: 'sequential'|'parallel';
|
||||
|
||||
/**
|
||||
* Wether enable profiling.
|
||||
*
|
||||
* This setting is a placeholder for a future use.
|
||||
*/
|
||||
enableProfiling?: boolean;
|
||||
|
||||
/**
|
||||
* File prefix for profiling.
|
||||
*
|
||||
* This setting is a placeholder for a future use.
|
||||
*/
|
||||
profileFilePrefix?: string;
|
||||
|
||||
/**
|
||||
* Log ID.
|
||||
*
|
||||
|
|
@ -107,6 +121,36 @@ export declare namespace InferenceSession {
|
|||
* This setting is available only in ONNXRuntime (Node.js binding and react-native) or WebAssembly backend
|
||||
*/
|
||||
logSeverityLevel?: 0|1|2|3|4;
|
||||
|
||||
/**
|
||||
* Log verbosity level.
|
||||
*
|
||||
* This setting is available only in WebAssembly backend. Will support Node.js binding and react-native later
|
||||
*/
|
||||
logVerbosityLevel?: number;
|
||||
|
||||
/**
|
||||
* Store configurations for a session. See
|
||||
* https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/
|
||||
* onnxruntime_session_options_config_keys.h
|
||||
*
|
||||
* In example,
|
||||
*
|
||||
* ```js
|
||||
* extra: {
|
||||
* session: {
|
||||
* set_denormal_as_zero: "1",
|
||||
* disable_prepacking: "1"
|
||||
* },
|
||||
* optimization: {
|
||||
* enable_gelu_approximation: "1"
|
||||
* }
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* This setting is available only in WebAssembly backend. Will support Node.js binding and react-native later
|
||||
*/
|
||||
extra?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
//#region execution providers
|
||||
|
|
@ -163,12 +207,45 @@ export declare namespace InferenceSession {
|
|||
*/
|
||||
logSeverityLevel?: 0|1|2|3|4;
|
||||
|
||||
/**
|
||||
* Log verbosity level.
|
||||
*
|
||||
* This setting is available only in WebAssembly backend. Will support Node.js binding and react-native later
|
||||
*/
|
||||
logVerbosityLevel?: number;
|
||||
|
||||
/**
|
||||
* Terminate all incomplete OrtRun calls as soon as possible if true
|
||||
*
|
||||
* This setting is available only in WebAssembly backend. Will support Node.js binding and react-native later
|
||||
*/
|
||||
terminate?: boolean;
|
||||
|
||||
/**
|
||||
* A tag for the Run() calls using this
|
||||
*
|
||||
* This setting is available only in ONNXRuntime (Node.js binding and react-native) or WebAssembly backend
|
||||
*/
|
||||
tag?: string;
|
||||
|
||||
/**
|
||||
* Set a single run configuration entry. See
|
||||
* https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/session/
|
||||
* onnxruntime_run_options_config_keys.h
|
||||
*
|
||||
* In example,
|
||||
*
|
||||
* ```js
|
||||
* extra: {
|
||||
* memory: {
|
||||
* enable_memory_arena_shrinkage: "1",
|
||||
* }
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* This setting is available only in WebAssembly backend. Will support Node.js binding and react-native later
|
||||
*/
|
||||
extra?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
//#endregion
|
||||
|
|
|
|||
19
js/web/lib/wasm/binding/ort-wasm.d.ts
vendored
19
js/web/lib/wasm/binding/ort-wasm.d.ts
vendored
|
|
@ -33,21 +33,16 @@ export interface OrtWasmModule extends EmscriptenModule {
|
|||
sessionHandle: number, inputNamesOffset: number, inputsOffset: number, inputCount: number,
|
||||
outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): number;
|
||||
|
||||
_OrtCreateSessionOptions(): number;
|
||||
_OrtCreateSessionOptions(
|
||||
graphOptimizationLevel: number, enableCpuMemArena: boolean, enableMemPattern: boolean, executionMode: number,
|
||||
enableProfiling: boolean, profileFilePrefix: number, logId: number, logSeverityLevel: number,
|
||||
logVerbosityLevel: number): number;
|
||||
_OrtAddSessionConfigEntry(sessionOptionsHandle: number, configKey: number, configValue: number): number;
|
||||
_OrtReleaseSessionOptions(sessionOptionsHandle: number): void;
|
||||
_OrtSetSessionGraphOptimizationLevel(sessionOptionsHandle: number, level: number): number;
|
||||
_OrtEnableCpuMemArena(sessionOptionsHandle: number): number;
|
||||
_OrtDisableCpuMemArena(sessionOptionsHandle: number): number;
|
||||
_OrtEnableMemPattern(sessionOptionsHandle: number): number;
|
||||
_OrtDisableMemPattern(sessionOptionsHandle: number): number;
|
||||
_OrtSetSessionExecutionMode(sessionOptionsHandle: number, mode: number): number;
|
||||
_OrtSetSessionLogId(sessionOptionsHandle: number, logid: number): number;
|
||||
_OrtSetSessionLogSeverityLevel(sessionOptionsHandle: number, level: number): number;
|
||||
|
||||
_OrtCreateRunOptions(): number;
|
||||
_OrtCreateRunOptions(logSeverityLevel: number, logVerbosityLevel: number, terminate: boolean, tag: number): number;
|
||||
_OrtAddRunConfigEntry(runOptionsHandle: number, configKey: number, configValue: number): number;
|
||||
_OrtReleaseRunOptions(runOptionsHandle: number): void;
|
||||
_OrtRunOptionsSetRunLogSeverityLevel(runOptionsHandle: number, level: number): number;
|
||||
_OrtRunOptionsSetRunTag(runOptionsHandle: number, tag: number): number;
|
||||
//#endregion
|
||||
|
||||
//#region config
|
||||
|
|
|
|||
44
js/web/lib/wasm/options-utils.ts
Normal file
44
js/web/lib/wasm/options-utils.ts
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {getInstance} from './wasm-factory';
|
||||
|
||||
interface ExtraOptionsHandler {
|
||||
(name: string, value: string): void;
|
||||
}
|
||||
|
||||
export const iterateExtraOptions =
|
||||
(options: Record<string, unknown>, prefix: string, seen: WeakSet<Record<string, unknown>>,
|
||||
handler: ExtraOptionsHandler): void => {
|
||||
if (typeof options == 'object' && options !== null) {
|
||||
if (seen.has(options)) {
|
||||
throw new Error('Circular reference in options');
|
||||
} else {
|
||||
seen.add(options);
|
||||
}
|
||||
}
|
||||
|
||||
Object.entries(options).forEach(([key, value]) => {
|
||||
const name = (prefix) ? prefix + key : key;
|
||||
if (typeof value === 'object') {
|
||||
iterateExtraOptions(value as Record<string, unknown>, name + '.', seen, handler);
|
||||
} else if (typeof value === 'string' || typeof value === 'number') {
|
||||
handler(name, value.toString());
|
||||
} else if (typeof value === 'boolean') {
|
||||
handler(name, (value) ? '1' : '0');
|
||||
} else {
|
||||
throw new Error(`Can't handle extra config type: ${typeof value}`);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
export const allocWasmString = (data: string, allocs: number[]): number => {
|
||||
const wasm = getInstance();
|
||||
|
||||
const dataLength = wasm.lengthBytesUTF8(data) + 1;
|
||||
const dataOffset = wasm._malloc(dataLength);
|
||||
wasm.stringToUTF8(data, dataOffset, dataLength);
|
||||
allocs.push(dataOffset);
|
||||
|
||||
return dataOffset;
|
||||
};
|
||||
65
js/web/lib/wasm/run-options.ts
Normal file
65
js/web/lib/wasm/run-options.ts
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {InferenceSession} from 'onnxruntime-common';
|
||||
|
||||
import {allocWasmString, iterateExtraOptions} from './options-utils';
|
||||
import {getInstance} from './wasm-factory';
|
||||
|
||||
export const setRunOptions = (options: InferenceSession.RunOptions): [number, number[]] => {
|
||||
const wasm = getInstance();
|
||||
let runOptionsHandle = 0;
|
||||
const allocs: number[] = [];
|
||||
|
||||
const runOptions: InferenceSession.RunOptions = options || {};
|
||||
|
||||
try {
|
||||
if (options?.logSeverityLevel === undefined) {
|
||||
runOptions.logSeverityLevel = 2; // Default to warning
|
||||
} else if (
|
||||
typeof options.logSeverityLevel !== 'number' || !Number.isInteger(options.logSeverityLevel) ||
|
||||
options.logSeverityLevel < 0 || options.logSeverityLevel > 4) {
|
||||
throw new Error(`log serverity level is not valid: ${options.logSeverityLevel}`);
|
||||
}
|
||||
|
||||
if (options?.logVerbosityLevel === undefined) {
|
||||
runOptions.logVerbosityLevel = 0; // Default to 0
|
||||
} else if (typeof options.logVerbosityLevel !== 'number' || !Number.isInteger(options.logVerbosityLevel)) {
|
||||
throw new Error(`log verbosity level is not valid: ${options.logVerbosityLevel}`);
|
||||
}
|
||||
|
||||
if (options?.terminate === undefined) {
|
||||
runOptions.terminate = false;
|
||||
}
|
||||
|
||||
let tagDataOffset = 0;
|
||||
if (options?.tag !== undefined) {
|
||||
tagDataOffset = allocWasmString(options.tag, allocs);
|
||||
}
|
||||
|
||||
runOptionsHandle = wasm._OrtCreateRunOptions(
|
||||
runOptions.logSeverityLevel!, runOptions.logVerbosityLevel!, !!runOptions.terminate!, tagDataOffset);
|
||||
if (runOptionsHandle === 0) {
|
||||
throw new Error('Can\'t create run options');
|
||||
}
|
||||
|
||||
if (options?.extra !== undefined) {
|
||||
iterateExtraOptions(options.extra, '', new WeakSet<Record<string, unknown>>(), (key, value) => {
|
||||
const keyDataOffset = allocWasmString(key, allocs);
|
||||
const valueDataOffset = allocWasmString(value, allocs);
|
||||
|
||||
if (wasm._OrtAddRunConfigEntry(runOptionsHandle, keyDataOffset, valueDataOffset) !== 0) {
|
||||
throw new Error(`Can't set a run config entry: ${key} - ${value}`);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return [runOptionsHandle, allocs];
|
||||
} catch (e) {
|
||||
if (runOptionsHandle !== 0) {
|
||||
wasm._OrtReleaseRunOptions(runOptionsHandle);
|
||||
}
|
||||
allocs.forEach(wasm._free);
|
||||
throw e;
|
||||
}
|
||||
};
|
||||
|
|
@ -3,6 +3,9 @@
|
|||
|
||||
import {onnx} from 'onnx-proto';
|
||||
import {env, InferenceSession, SessionHandler, Tensor, TypedTensor} from 'onnxruntime-common';
|
||||
|
||||
import {setRunOptions} from './run-options';
|
||||
import {setSessionOptions} from './session-options';
|
||||
import {getInstance} from './wasm-factory';
|
||||
|
||||
let ortInit: boolean;
|
||||
|
|
@ -119,133 +122,6 @@ const getLogLevel = (logLevel: 'verbose'|'info'|'warning'|'error'|'fatal'): numb
|
|||
}
|
||||
};
|
||||
|
||||
const setSessionOptions = (options?: InferenceSession.SessionOptions): [number, number[]] => {
|
||||
const wasm = getInstance();
|
||||
const sessionOptionsHandle = wasm._OrtCreateSessionOptions();
|
||||
const allocs: number[] = [];
|
||||
|
||||
if (sessionOptionsHandle === 0) {
|
||||
throw new Error('Can\'t create session options');
|
||||
}
|
||||
|
||||
if (options === undefined) {
|
||||
return [sessionOptionsHandle, allocs];
|
||||
}
|
||||
|
||||
let errorCode = 0;
|
||||
|
||||
if (options.graphOptimizationLevel !== undefined) {
|
||||
switch (options.graphOptimizationLevel) {
|
||||
case 'disabled':
|
||||
errorCode = wasm._OrtSetSessionGraphOptimizationLevel(sessionOptionsHandle, 0);
|
||||
break;
|
||||
case 'basic':
|
||||
errorCode = wasm._OrtSetSessionGraphOptimizationLevel(sessionOptionsHandle, 1);
|
||||
break;
|
||||
case 'extended':
|
||||
errorCode = wasm._OrtSetSessionGraphOptimizationLevel(sessionOptionsHandle, 2);
|
||||
break;
|
||||
case 'all':
|
||||
errorCode = wasm._OrtSetSessionGraphOptimizationLevel(sessionOptionsHandle, 99);
|
||||
break;
|
||||
default:
|
||||
throw new Error(`unsupported graph optimization level: ${options.graphOptimizationLevel}`);
|
||||
}
|
||||
if (errorCode !== 0) {
|
||||
throw new Error(`Can't set a graph optimization level as a session option. error code = ${errorCode}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (options.enableCpuMemArena !== undefined) {
|
||||
if (options.enableCpuMemArena) {
|
||||
errorCode = wasm._OrtEnableCpuMemArena(sessionOptionsHandle);
|
||||
} else {
|
||||
errorCode = wasm._OrtDisableCpuMemArena(sessionOptionsHandle);
|
||||
}
|
||||
if (errorCode !== 0) {
|
||||
throw new Error(`Can't set a CPU memory arena as a session option. error code = ${errorCode}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (options.enableMemPattern !== undefined) {
|
||||
if (options.enableMemPattern) {
|
||||
errorCode = wasm._OrtEnableMemPattern(sessionOptionsHandle);
|
||||
} else {
|
||||
errorCode = wasm._OrtDisableMemPattern(sessionOptionsHandle);
|
||||
}
|
||||
if (errorCode !== 0) {
|
||||
throw new Error(`Can't set a memory pattern as a session option. error code = ${errorCode}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (options.executionMode !== undefined) {
|
||||
switch (options.executionMode) {
|
||||
case 'sequential':
|
||||
errorCode = wasm._OrtSetSessionExecutionMode(sessionOptionsHandle, 0);
|
||||
break;
|
||||
case 'parallel':
|
||||
errorCode = wasm._OrtSetSessionExecutionMode(sessionOptionsHandle, 1);
|
||||
break;
|
||||
default:
|
||||
throw new Error(`unsupported execution mode: ${options.executionMode}`);
|
||||
}
|
||||
if (errorCode !== 0) {
|
||||
throw new Error(`Can't set an execution mode as a session option. error code = ${errorCode}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (options.logId !== undefined) {
|
||||
const logIdDataLength = wasm.lengthBytesUTF8(options.logId) + 1;
|
||||
const logIdDataOffset = wasm._malloc(logIdDataLength);
|
||||
wasm.stringToUTF8(options.logId, logIdDataOffset, logIdDataLength);
|
||||
errorCode = wasm._OrtSetSessionLogId(sessionOptionsHandle, logIdDataOffset);
|
||||
allocs.push(logIdDataOffset);
|
||||
if (errorCode !== 0) {
|
||||
throw new Error(`Can't set a log id as a session option. error code = ${errorCode}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (options.logSeverityLevel !== undefined) {
|
||||
errorCode = wasm._OrtSetSessionLogSeverityLevel(sessionOptionsHandle, options.logSeverityLevel);
|
||||
if (errorCode !== 0) {
|
||||
throw new Error(`Can't set a log severity level as a session option. error code = ${errorCode}`);
|
||||
}
|
||||
}
|
||||
|
||||
return [sessionOptionsHandle, allocs];
|
||||
};
|
||||
|
||||
const setRunOptions = (options: InferenceSession.RunOptions): [number, number[]] => {
|
||||
const wasm = getInstance();
|
||||
const runOptionsHandle = wasm._OrtCreateRunOptions();
|
||||
if (runOptionsHandle === 0) {
|
||||
throw new Error('Can\'t create run options');
|
||||
}
|
||||
|
||||
const allocs: number[] = [];
|
||||
let errorCode = 0;
|
||||
|
||||
if (options.logSeverityLevel !== undefined) {
|
||||
errorCode = wasm._OrtRunOptionsSetRunLogSeverityLevel(runOptionsHandle, options.logSeverityLevel);
|
||||
if (errorCode !== 0) {
|
||||
throw new Error(`Can't set a log severity level as a run option. error code = ${errorCode}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (options.tag !== undefined) {
|
||||
const tagDataLength = wasm.lengthBytesUTF8(options.tag) + 1;
|
||||
const tagDataOffset = wasm._malloc(tagDataLength);
|
||||
wasm.stringToUTF8(options.tag, tagDataOffset, tagDataLength);
|
||||
errorCode = wasm._OrtRunOptionsSetRunTag(runOptionsHandle, tagDataOffset);
|
||||
allocs.push(tagDataOffset);
|
||||
if (errorCode !== 0) {
|
||||
throw new Error(`Can't set a tag as a run option. error code = ${errorCode}`);
|
||||
}
|
||||
}
|
||||
|
||||
return [runOptionsHandle, allocs];
|
||||
};
|
||||
|
||||
export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler {
|
||||
private sessionHandle: number;
|
||||
|
||||
|
|
|
|||
110
js/web/lib/wasm/session-options.ts
Normal file
110
js/web/lib/wasm/session-options.ts
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {InferenceSession} from 'onnxruntime-common';
|
||||
|
||||
import {allocWasmString, iterateExtraOptions} from './options-utils';
|
||||
import {getInstance} from './wasm-factory';
|
||||
|
||||
const getGraphOptimzationLevel = (graphOptimizationLevel: string|unknown): number => {
|
||||
switch (graphOptimizationLevel) {
|
||||
case 'disabled':
|
||||
return 0;
|
||||
case 'basic':
|
||||
return 1;
|
||||
case 'extended':
|
||||
return 2;
|
||||
case 'all':
|
||||
return 99;
|
||||
default:
|
||||
throw new Error(`unsupported graph optimization level: ${graphOptimizationLevel}`);
|
||||
}
|
||||
};
|
||||
|
||||
const getExecutionMode = (executionMode: 'sequential'|'parallel'): number => {
|
||||
switch (executionMode) {
|
||||
case 'sequential':
|
||||
return 0;
|
||||
case 'parallel':
|
||||
return 1;
|
||||
default:
|
||||
throw new Error(`unsupported execution mode: ${executionMode}`);
|
||||
}
|
||||
};
|
||||
|
||||
export const setSessionOptions = (options?: InferenceSession.SessionOptions): [number, number[]] => {
|
||||
const wasm = getInstance();
|
||||
let sessionOptionsHandle = 0;
|
||||
const allocs: number[] = [];
|
||||
|
||||
const sessionOptions: InferenceSession.SessionOptions = options || {};
|
||||
|
||||
try {
|
||||
if (options?.graphOptimizationLevel === undefined) {
|
||||
sessionOptions.graphOptimizationLevel = 'all';
|
||||
}
|
||||
const graphOptimizationLevel = getGraphOptimzationLevel(sessionOptions.graphOptimizationLevel!);
|
||||
|
||||
if (options?.enableCpuMemArena === undefined) {
|
||||
sessionOptions.enableCpuMemArena = true;
|
||||
}
|
||||
|
||||
if (options?.enableMemPattern === undefined) {
|
||||
sessionOptions.enableMemPattern = true;
|
||||
}
|
||||
|
||||
if (options?.executionMode === undefined) {
|
||||
sessionOptions.executionMode = 'sequential';
|
||||
}
|
||||
const executionMode = getExecutionMode(sessionOptions.executionMode!);
|
||||
|
||||
let logIdDataOffset = 0;
|
||||
if (options?.logId !== undefined) {
|
||||
logIdDataOffset = allocWasmString(options.logId, allocs);
|
||||
}
|
||||
|
||||
if (options?.logSeverityLevel === undefined) {
|
||||
sessionOptions.logSeverityLevel = 2; // Default to warning
|
||||
} else if (
|
||||
typeof options.logSeverityLevel !== 'number' || !Number.isInteger(options.logSeverityLevel) ||
|
||||
options.logSeverityLevel < 0 || options.logSeverityLevel > 4) {
|
||||
throw new Error(`log serverity level is not valid: ${options.logSeverityLevel}`);
|
||||
}
|
||||
|
||||
if (options?.logVerbosityLevel === undefined) {
|
||||
sessionOptions.logVerbosityLevel = 0; // Default to 0
|
||||
} else if (typeof options.logVerbosityLevel !== 'number' || !Number.isInteger(options.logVerbosityLevel)) {
|
||||
throw new Error(`log verbosity level is not valid: ${options.logVerbosityLevel}`);
|
||||
}
|
||||
|
||||
// TODO: Support profiling
|
||||
sessionOptions.enableProfiling = false;
|
||||
|
||||
sessionOptionsHandle = wasm._OrtCreateSessionOptions(
|
||||
graphOptimizationLevel, !!sessionOptions.enableCpuMemArena!, !!sessionOptions.enableMemPattern!, executionMode,
|
||||
sessionOptions.enableProfiling, 0, logIdDataOffset, sessionOptions.logSeverityLevel!,
|
||||
sessionOptions.logVerbosityLevel!);
|
||||
if (sessionOptionsHandle === 0) {
|
||||
throw new Error('Can\'t create session options');
|
||||
}
|
||||
|
||||
if (options?.extra !== undefined) {
|
||||
iterateExtraOptions(options.extra, '', new WeakSet<Record<string, unknown>>(), (key, value) => {
|
||||
const keyDataOffset = allocWasmString(key, allocs);
|
||||
const valueDataOffset = allocWasmString(value, allocs);
|
||||
|
||||
if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) {
|
||||
throw new Error(`Can't set a session config entry: ${key} - ${value}`);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return [sessionOptionsHandle, allocs];
|
||||
} catch (e) {
|
||||
if (sessionOptionsHandle !== 0) {
|
||||
wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
|
||||
}
|
||||
allocs.forEach(wasm._free);
|
||||
throw e;
|
||||
}
|
||||
};
|
||||
|
|
@ -60,49 +60,62 @@ int OrtInit(int num_threads, int logging_level) {
|
|||
#endif
|
||||
}
|
||||
|
||||
OrtSessionOptions* OrtCreateSessionOptions() {
|
||||
OrtSessionOptions* OrtCreateSessionOptions(size_t graph_optimization_level,
|
||||
bool enable_cpu_mem_arena,
|
||||
bool enable_mem_pattern,
|
||||
size_t execution_mode,
|
||||
bool /* enable_profiling */,
|
||||
const char* /* profile_file_prefix */,
|
||||
const char* log_id,
|
||||
size_t log_severity_level,
|
||||
size_t log_verbosity_level) {
|
||||
OrtSessionOptions* session_options = nullptr;
|
||||
return (CHECK_STATUS(CreateSessionOptions, &session_options) == ORT_OK) ? session_options : nullptr;
|
||||
RETURN_NULLPTR_IF_ERROR(CreateSessionOptions, &session_options);
|
||||
|
||||
// assume that a graph optimization level is checked and properly set at JavaScript
|
||||
RETURN_NULLPTR_IF_ERROR(SetSessionGraphOptimizationLevel,
|
||||
session_options,
|
||||
static_cast<GraphOptimizationLevel>(graph_optimization_level));
|
||||
|
||||
if (enable_cpu_mem_arena) {
|
||||
RETURN_NULLPTR_IF_ERROR(EnableCpuMemArena, session_options);
|
||||
} else {
|
||||
RETURN_NULLPTR_IF_ERROR(DisableCpuMemArena, session_options);
|
||||
}
|
||||
|
||||
if (enable_mem_pattern) {
|
||||
RETURN_NULLPTR_IF_ERROR(EnableCpuMemArena, session_options);
|
||||
} else {
|
||||
RETURN_NULLPTR_IF_ERROR(DisableCpuMemArena, session_options);
|
||||
}
|
||||
|
||||
// assume that an execution mode is checked and properly set at JavaScript
|
||||
RETURN_NULLPTR_IF_ERROR(SetSessionExecutionMode, session_options, static_cast<ExecutionMode>(execution_mode));
|
||||
|
||||
// TODO: support profling
|
||||
|
||||
if (log_id != nullptr) {
|
||||
RETURN_NULLPTR_IF_ERROR(SetSessionLogId, session_options, log_id);
|
||||
}
|
||||
|
||||
// assume that a log severity level is checked and properly set at JavaScript
|
||||
RETURN_NULLPTR_IF_ERROR(SetSessionLogSeverityLevel, session_options, log_severity_level);
|
||||
|
||||
RETURN_NULLPTR_IF_ERROR(SetSessionLogVerbosityLevel, session_options, log_verbosity_level);
|
||||
|
||||
return session_options;
|
||||
}
|
||||
|
||||
int OrtAddSessionConfigEntry(OrtSessionOptions* session_options,
|
||||
const char* config_key,
|
||||
const char* config_value) {
|
||||
return CHECK_STATUS(AddSessionConfigEntry, session_options, config_key, config_value);
|
||||
}
|
||||
|
||||
void OrtReleaseSessionOptions(OrtSessionOptions* session_options) {
|
||||
Ort::GetApi().ReleaseSessionOptions(session_options);
|
||||
}
|
||||
|
||||
int OrtSetSessionGraphOptimizationLevel(OrtSessionOptions* session_options, size_t level) {
|
||||
// Assume that a graph optimization level is check and properly set at JavaScript
|
||||
return CHECK_STATUS(SetSessionGraphOptimizationLevel, session_options, static_cast<GraphOptimizationLevel>(level));
|
||||
}
|
||||
|
||||
int OrtEnableCpuMemArena(OrtSessionOptions* session_options) {
|
||||
return CHECK_STATUS(EnableCpuMemArena, session_options);
|
||||
}
|
||||
|
||||
int OrtDisableCpuMemArena(OrtSessionOptions* session_options) {
|
||||
return CHECK_STATUS(DisableCpuMemArena, session_options);
|
||||
}
|
||||
|
||||
int OrtEnableMemPattern(OrtSessionOptions* session_options) {
|
||||
return CHECK_STATUS(EnableMemPattern, session_options);
|
||||
}
|
||||
|
||||
int OrtDisableMemPattern(OrtSessionOptions* session_options) {
|
||||
return CHECK_STATUS(DisableMemPattern, session_options);
|
||||
}
|
||||
|
||||
int OrtSetSessionExecutionMode(OrtSessionOptions* session_options, size_t mode) {
|
||||
// Assume that an execution mode is check and properly set at JavaScript
|
||||
return CHECK_STATUS(SetSessionExecutionMode, session_options, static_cast<ExecutionMode>(mode));
|
||||
}
|
||||
|
||||
int OrtSetSessionLogId(OrtSessionOptions* session_options, const char* logid) {
|
||||
return CHECK_STATUS(SetSessionLogId, session_options, logid);
|
||||
}
|
||||
|
||||
int OrtSetSessionLogSeverityLevel(OrtSessionOptions* session_options, size_t level) {
|
||||
return CHECK_STATUS(SetSessionLogSeverityLevel, session_options, level);
|
||||
}
|
||||
|
||||
OrtSession* OrtCreateSession(void* data, size_t data_length, OrtSessionOptions* session_options) {
|
||||
// OrtSessionOptions must not be nullptr.
|
||||
if (session_options == nullptr) {
|
||||
|
|
@ -228,23 +241,41 @@ void OrtReleaseTensor(OrtValue* tensor) {
|
|||
Ort::GetApi().ReleaseValue(tensor);
|
||||
}
|
||||
|
||||
OrtRunOptions* OrtCreateRunOptions() {
|
||||
OrtRunOptions* OrtCreateRunOptions(size_t log_severity_level,
|
||||
size_t log_verbosity_level,
|
||||
bool terminate,
|
||||
const char* tag) {
|
||||
OrtRunOptions* run_options = nullptr;
|
||||
return (CHECK_STATUS(CreateRunOptions, &run_options) == ORT_OK) ? run_options : nullptr;
|
||||
RETURN_NULLPTR_IF_ERROR(CreateRunOptions, &run_options);
|
||||
|
||||
// Assume that a logging level is check and properly set at JavaScript
|
||||
RETURN_NULLPTR_IF_ERROR(RunOptionsSetRunLogSeverityLevel, run_options, log_severity_level);
|
||||
|
||||
RETURN_NULLPTR_IF_ERROR(RunOptionsSetRunLogVerbosityLevel, run_options, log_verbosity_level);
|
||||
|
||||
if (terminate) {
|
||||
RETURN_NULLPTR_IF_ERROR(RunOptionsSetTerminate, run_options);
|
||||
} else {
|
||||
RETURN_NULLPTR_IF_ERROR(RunOptionsUnsetTerminate, run_options);
|
||||
}
|
||||
|
||||
if (tag != nullptr) {
|
||||
RETURN_NULLPTR_IF_ERROR(RunOptionsSetRunTag, run_options, tag);
|
||||
}
|
||||
|
||||
return run_options;
|
||||
}
|
||||
|
||||
int OrtAddRunConfigEntry(OrtRunOptions* run_options,
|
||||
const char* config_key,
|
||||
const char* config_value) {
|
||||
return CHECK_STATUS(AddRunConfigEntry, run_options, config_key, config_value);
|
||||
}
|
||||
|
||||
void OrtReleaseRunOptions(OrtRunOptions* run_options) {
|
||||
Ort::GetApi().ReleaseRunOptions(run_options);
|
||||
}
|
||||
|
||||
int OrtRunOptionsSetRunLogSeverityLevel(OrtRunOptions* run_options, size_t level) {
|
||||
return CHECK_STATUS(RunOptionsSetRunLogSeverityLevel, run_options, level);
|
||||
}
|
||||
|
||||
int OrtRunOptionsSetRunTag(OrtRunOptions* run_options, const char* tag) {
|
||||
return CHECK_STATUS(RunOptionsSetRunTag, run_options, tag);
|
||||
}
|
||||
|
||||
int OrtRun(OrtSession* session,
|
||||
const char** input_names, const ort_tensor_handle_t* inputs, size_t input_count,
|
||||
const char** output_names, size_t output_count, ort_tensor_handle_t* outputs,
|
||||
|
|
|
|||
|
|
@ -28,62 +28,52 @@ extern "C" {
|
|||
|
||||
/**
|
||||
* perform global initialization. should be called only once.
|
||||
* @param numThreads number of total threads to use.
|
||||
* @param num_threads number of total threads to use.
|
||||
* @param logging_level default logging level.
|
||||
*/
|
||||
int EMSCRIPTEN_KEEPALIVE OrtInit(int numThreads, int logging_level);
|
||||
int EMSCRIPTEN_KEEPALIVE OrtInit(int num_threads, int logging_level);
|
||||
|
||||
/**
|
||||
* create an instance of ORT session options.
|
||||
* assume that all enum type parameters, such as graph_optimization_level, execution_mode, and log_severity_level,
|
||||
* are checked and set properly at JavaScript.
|
||||
* @param graph_optimization_level disabled, basic, extended, or enable all
|
||||
* @param enable_cpu_mem_arena enable or disable cpu memory arena
|
||||
* @param enable_mem_pattern enable or disable memory pattern
|
||||
* @param execution_mode sequential or parallel execution mode
|
||||
* @param enable_profiling enable or disable profiling. it's a no-op and for a future use.
|
||||
* @param profile_file_prefix file prefix for profiling data. it's a no-op and for a future use.
|
||||
* @param log_id logger id for session output
|
||||
* @param log_severity_level verbose, info, warning, error or fatal
|
||||
* @param log_verbosity_level vlog level
|
||||
* @returns a pointer to a session option handle and must be freed by calling OrtReleaseSessionOptions().
|
||||
*/
|
||||
ort_session_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionOptions();
|
||||
ort_session_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionOptions(size_t graph_optimization_level,
|
||||
bool enable_cpu_mem_arena,
|
||||
bool enable_mem_pattern,
|
||||
size_t execution_mode,
|
||||
bool enable_profiling,
|
||||
const char* profile_file_prefix,
|
||||
const char* log_id,
|
||||
size_t log_severity_level,
|
||||
size_t log_verbosity_level);
|
||||
|
||||
/**
|
||||
* store configurations for a session.
|
||||
* @param session_options a handle to session options created by OrtCreateSessionOptions
|
||||
* @param config_key configuration keys and value formats are defined in
|
||||
* include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
|
||||
* @param config_value value for config_key
|
||||
*/
|
||||
int EMSCRIPTEN_KEEPALIVE OrtAddSessionConfigEntry(ort_session_options_handle_t session_options,
|
||||
const char* config_key,
|
||||
const char* config_value);
|
||||
|
||||
/**
|
||||
* release the specified ORT session options.
|
||||
*/
|
||||
void EMSCRIPTEN_KEEPALIVE OrtReleaseSessionOptions(ort_session_options_handle_t session_options);
|
||||
|
||||
/**
|
||||
* set an optimization level for session.
|
||||
*/
|
||||
int EMSCRIPTEN_KEEPALIVE OrtSetSessionGraphOptimizationLevel(ort_session_options_handle_t session_options, size_t level);
|
||||
|
||||
/**
|
||||
* enable CPU memory arena for session.
|
||||
*/
|
||||
int EMSCRIPTEN_KEEPALIVE OrtEnableCpuMemArena(ort_session_options_handle_t session_options);
|
||||
|
||||
/**
|
||||
* disable CPU memory arena for session.
|
||||
*/
|
||||
int EMSCRIPTEN_KEEPALIVE OrtDisableCpuMemArena(ort_session_options_handle_t session_options);
|
||||
|
||||
/**
|
||||
* enable memory pattern for session.
|
||||
*/
|
||||
int EMSCRIPTEN_KEEPALIVE OrtEnableMemPattern(ort_session_options_handle_t session_options);
|
||||
|
||||
/**
|
||||
* disable memory pattern for session.
|
||||
*/
|
||||
int EMSCRIPTEN_KEEPALIVE OrtDisableMemPattern(ort_session_options_handle_t session_options);
|
||||
|
||||
/**
|
||||
* set an execution mode for session.
|
||||
*/
|
||||
int EMSCRIPTEN_KEEPALIVE OrtSetSessionExecutionMode(ort_session_options_handle_t session_options, size_t mode);
|
||||
|
||||
/**
|
||||
* set a log ID for session.
|
||||
*/
|
||||
int EMSCRIPTEN_KEEPALIVE OrtSetSessionLogId(ort_session_options_handle_t session_options, const char* logid);
|
||||
|
||||
/**
|
||||
* set a log severity level for session.
|
||||
*/
|
||||
int EMSCRIPTEN_KEEPALIVE OrtSetSessionLogSeverityLevel(ort_session_options_handle_t session_options, size_t level);
|
||||
|
||||
/**
|
||||
* create an instance of ORT session.
|
||||
* @param data a pointer to a buffer that contains the ONNX or ORT format model.
|
||||
|
|
@ -154,25 +144,33 @@ void EMSCRIPTEN_KEEPALIVE OrtReleaseTensor(ort_tensor_handle_t tensor);
|
|||
|
||||
/**
|
||||
* create an instance of ORT run options.
|
||||
* @param log_severity_level verbose, info, warning, error or fatal
|
||||
* @param log_verbosity_level vlog level
|
||||
* @param terminate if true, all incomplete OrtRun calls will exit as soon as possible
|
||||
* @param tag tag for this run
|
||||
* @returns a pointer to a run option handle and must be freed by calling OrtReleaseRunOptions().
|
||||
*/
|
||||
ort_run_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateRunOptions();
|
||||
ort_run_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateRunOptions(size_t log_severity_level,
|
||||
size_t log_verbosity_level,
|
||||
bool terminate,
|
||||
const char* tag);
|
||||
|
||||
/**
|
||||
* set a single run configuration entry
|
||||
* @param run_options a handle to run options created by OrtCreateRunOptions
|
||||
* @param config_key configuration keys and value formats are defined in
|
||||
* include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h
|
||||
* @param config_value value for config_key
|
||||
*/
|
||||
int EMSCRIPTEN_KEEPALIVE OrtAddRunConfigEntry(ort_run_options_handle_t run_options,
|
||||
const char* config_key,
|
||||
const char* config_value);
|
||||
|
||||
/**
|
||||
* release the specified ORT run options.
|
||||
*/
|
||||
void EMSCRIPTEN_KEEPALIVE OrtReleaseRunOptions(ort_run_options_handle_t run_options);
|
||||
|
||||
/**
|
||||
* set log severity level for run.
|
||||
*/
|
||||
int EMSCRIPTEN_KEEPALIVE OrtRunOptionsSetRunLogSeverityLevel(ort_run_options_handle_t run_options, size_t level);
|
||||
|
||||
/**
|
||||
* set a tag for the Run() calls using this.
|
||||
*/
|
||||
int EMSCRIPTEN_KEEPALIVE OrtRunOptionsSetRunTag(ort_run_options_handle_t run_options, const char* tag);
|
||||
|
||||
/**
|
||||
* inference the model.
|
||||
* @param session handle of the specified session
|
||||
|
|
|
|||
Loading…
Reference in a new issue