onnxruntime/js/web/lib/wasm/session-options.ts
Jiajia Qin 85cef0af8c
[js/webgpu] Support capture and replay for jsep (#18989)
### Description
This PR expands the graph capture capability to the JS EP, similar
to #16081. For the JS EP, however, we don't use CUDA Graph; instead, we
record all GPU commands and replay them, which removes most of the CPU
overhead and avoids the situation where the GPU waits for the CPU.

mobilenetv2-12 becomes 3.7ms from 6ms on NV 3090 and becomes 3.38ms from
4.58ms on Intel A770.

All limitations are similar to those of the CUDA EP:
1. Models with control-flow ops (i.e. If, Loop and Scan ops) are not
supported.
2. Usage of graph capture is limited to models in which all ops can be
partitioned to the JS EP or the CPU EP with no memory copies
between them.
3. Shapes of inputs/outputs cannot change across inference calls.
4. IObinding is required.

The usage is like below:
Method 1: specify output buffers explicitly.
```
    const sessionOptions = {
        executionProviders: [
          {
            name: "webgpu",
          },
        ],
        enableGraphCapture: true,
      };
    const session = await ort.InferenceSession.create('./models/mobilenetv2-12.onnx', sessionOptions);
   
    // prepare the inputBuffer/outputBuffer
    ... ...

   const feeds = {
       'input': ort.Tensor.fromGpuBuffer(inputBuffer, { dataType: 'float32', dims })
   };

   const fetches = {
       'output': ort.Tensor.fromGpuBuffer(outputBuffer, { dataType: 'float32', dims: [1, 1000] })
   };

   let results = await session.run(feeds, fetches);  // The first run will begin to capture the graph.

   // update inputBuffer content
  ... ...
   results = await session.run(feeds, fetches);  // The 2nd run and after will directly call replay to execute the graph.

  ... ...
   session.release();
```
Method 2: Don't specify output buffers explicitly. Internally, when
graph capture is enabled, the location of all outputs is set to
'gpu-buffer'.
```
    const sessionOptions = {
        executionProviders: [
          {
            name: "webgpu",
          },
        ],
        enableGraphCapture: true,
      };
    const session = await ort.InferenceSession.create('./models/mobilenetv2-12.onnx', sessionOptions);

    // prepare the inputBuffer
    ... ...

   const feeds = {
       'input': ort.Tensor.fromGpuBuffer(inputBuffer, { dataType: 'float32', dims })
   };

   let results = await session.run(feeds);  // The first run will begin to capture the graph.
   
   // update inputBuffer content
  ... ...
   results = await session.run(feeds);  // The 2nd run and after will directly call replay to execute the graph.

  ... ...
   session.release();
2024-01-30 18:28:03 -08:00

217 lines
9.3 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {InferenceSession} from 'onnxruntime-common';
import {getInstance} from './wasm-factory';
import {allocWasmString, checkLastError, iterateExtraOptions} from './wasm-utils';
/**
 * Translates a graph optimization level name into its numeric code.
 *
 * Recognized names and their codes: 'disabled' -> 0, 'basic' -> 1, 'extended' -> 2, 'all' -> 99.
 *
 * @param graphOptimizationLevel - the level name to translate; any other value is rejected.
 * @returns the numeric code associated with the given name.
 * @throws Error when the value is not one of the recognized names.
 */
const getGraphOptimzationLevel = (graphOptimizationLevel: string|unknown): number => {
  // A Map (rather than a plain object) keeps the lookup strict: only these four exact strings match.
  const levelCodes = new Map<unknown, number>([
    ['disabled', 0],
    ['basic', 1],
    ['extended', 2],
    ['all', 99],
  ]);

  const code = levelCodes.get(graphOptimizationLevel);
  if (code === undefined) {
    throw new Error(`unsupported graph optimization level: ${graphOptimizationLevel}`);
  }
  return code;
};
/**
 * Translates an execution mode name into its numeric code.
 *
 * @param executionMode - either 'sequential' (code 0) or 'parallel' (code 1).
 * @returns the numeric code associated with the given mode.
 * @throws Error when a value outside the declared union is passed at runtime.
 */
const getExecutionMode = (executionMode: 'sequential'|'parallel'): number => {
  if (executionMode === 'sequential') {
    return 0;
  }
  if (executionMode === 'parallel') {
    return 1;
  }
  // Only reachable when the static type is bypassed (e.g. untyped JavaScript callers).
  throw new Error(`unsupported execution mode: ${executionMode}`);
};
/**
 * Fills in default values on the given session options object. The object is modified in place.
 *
 * - Ensures `options.extra.session` exists and sets `use_ort_model_bytes_directly` to '1' unless a truthy value
 *   is already present.
 * - Forces `enableMemPattern` to `false` when any requested execution provider is named 'webgpu'.
 *
 * @param options - the session options to update.
 */
const appendDefaultOptions = (options: InferenceSession.SessionOptions): void => {
  // Create the nested `extra.session` container on demand, keeping any existing objects untouched.
  const extra = options.extra || (options.extra = {});
  const session = (extra.session || (extra.session = {})) as Record<string, string>;

  if (!session.use_ort_model_bytes_directly) {
    // eslint-disable-next-line camelcase
    session.use_ort_model_bytes_directly = '1';
  }

  // if using JSEP with WebGPU, always disable memory pattern
  const providers = options.executionProviders;
  const usesWebGpu = providers && providers.some(ep => (typeof ep === 'string' ? ep : ep.name) === 'webgpu');
  if (usesWebGpu) {
    options.enableMemPattern = false;
  }
};
/**
 * Registers the requested execution providers (EPs) on a session options handle.
 *
 * Each entry is either an EP name or an option object carrying a `name`:
 * - 'webnn' is appended under the name 'WEBNN'. When given as an object, its `deviceType`, `numThreads` and
 *   `powerPreference` options (if truthy) are written as session config entries first.
 * - 'webgpu' is appended under the name 'JS'. When given as an object, its `preferredLayout` option (if truthy)
 *   is validated and written as a session config entry first.
 * - 'wasm' and 'cpu' are skipped; nothing is appended for them.
 *
 * @param sessionOptionsHandle - the session options handle passed through to the wasm calls.
 * @param executionProviders - the execution providers to register, in order.
 * @param allocs - passed to every `allocWasmString` call made here (presumably it collects the allocated
 *     pointers so the caller can free them - confirm against `allocWasmString`).
 * @throws Error when an EP name is not supported or `preferredLayout` is neither 'NCHW' nor 'NHWC'. Failures
 *     reported by the wasm calls are routed to `checkLastError` (expected to throw - confirm in wasm-utils).
 */
const setExecutionProviders =
    (sessionOptionsHandle: number, executionProviders: readonly InferenceSession.ExecutionProviderConfig[],
     allocs: number[]): void => {
      // Writes one key/value session config entry. `reported` is the value shown in the error message; it
      // defaults to the value that is written.
      const addConfigEntry = (key: string, value: string, reported: unknown = value): void => {
        const keyDataOffset = allocWasmString(key, allocs);
        const valueDataOffset = allocWasmString(value, allocs);
        if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) {
          checkLastError(`Can't set a session config entry: '${key}' - ${reported}.`);
        }
      };

      for (const ep of executionProviders) {
        let epName = typeof ep === 'string' ? ep : ep.name;

        // check EP name
        switch (epName) {
          case 'webnn':
            epName = 'WEBNN';
            if (typeof ep !== 'string') {
              const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption;
              if (webnnOptions.deviceType) {
                addConfigEntry('deviceType', webnnOptions.deviceType);
              }
              if (webnnOptions.numThreads) {
                let numThreads = webnnOptions.numThreads;
                // Just ignore invalid webnnOptions.numThreads.
                if (typeof numThreads !== 'number' || !Number.isInteger(numThreads) || numThreads < 0) {
                  numThreads = 0;
                }
                // The error message keeps reporting the caller-supplied value, not the sanitized one.
                addConfigEntry('numThreads', numThreads.toString(), webnnOptions.numThreads);
              }
              if (webnnOptions.powerPreference) {
                addConfigEntry('powerPreference', webnnOptions.powerPreference);
              }
            }
            break;
          case 'webgpu':
            epName = 'JS';
            if (typeof ep !== 'string') {
              const webgpuOptions = ep as InferenceSession.WebGpuExecutionProviderOption;
              if (webgpuOptions.preferredLayout) {
                if (webgpuOptions.preferredLayout !== 'NCHW' && webgpuOptions.preferredLayout !== 'NHWC') {
                  throw new Error(`preferredLayout must be either 'NCHW' or 'NHWC': ${webgpuOptions.preferredLayout}`);
                }
                addConfigEntry('preferredLayout', webgpuOptions.preferredLayout);
              }
            }
            break;
          case 'wasm':
          case 'cpu':
            // Nothing is appended for these names; move on to the next entry.
            continue;
          default:
            throw new Error(`not supported execution provider: ${epName}`);
        }

        const epNameDataOffset = allocWasmString(epName, allocs);
        if (getInstance()._OrtAppendExecutionProvider(sessionOptionsHandle, epNameDataOffset) !== 0) {
          checkLastError(`Can't append execution provider: ${epName}.`);
        }
      }
    };
/**
 * Creates a wasm session options handle from the given JavaScript session options.
 *
 * Steps, in order:
 * 1. Apply defaults via `appendDefaultOptions`. Note that this writes into the object passed by the caller.
 * 2. Validate and convert the scalar options (graph optimization level, execution mode, log levels, ...).
 * 3. Create the handle with `_OrtCreateSessionOptions`.
 * 4. Register execution providers, the `enableGraphCapture` flag, free dimension overrides and the entries
 *    found under `extra`.
 *
 * @param options - the session options; when omitted, an empty options object is used.
 * @returns a tuple `[sessionOptionsHandle, allocs]`, where `allocs` is the array handed to every
 *     `allocWasmString` call made while building the options.
 * @throws Error when an option fails validation; errors raised by the helpers used here propagate as well.
 *     Before rethrowing, the handle (if it was created) is released and every entry of `allocs` is freed.
 */
export const setSessionOptions = (options?: InferenceSession.SessionOptions): [number, number[]] => {
  const wasm = getInstance();
  let sessionOptionsHandle = 0;
  const allocs: number[] = [];

  const sessionOptions: InferenceSession.SessionOptions = options || {};
  appendDefaultOptions(sessionOptions);

  try {
    const graphOptimizationLevel = getGraphOptimzationLevel(sessionOptions.graphOptimizationLevel ?? 'all');
    const executionMode = getExecutionMode(sessionOptions.executionMode ?? 'sequential');
    const logIdDataOffset =
        typeof sessionOptions.logId === 'string' ? allocWasmString(sessionOptions.logId, allocs) : 0;

    const logSeverityLevel = sessionOptions.logSeverityLevel ?? 2;  // Default to 2 - warning
    if (!Number.isInteger(logSeverityLevel) || logSeverityLevel < 0 || logSeverityLevel > 4) {
      throw new Error(`log severity level is not valid: ${logSeverityLevel}`);
    }

    const logVerbosityLevel = sessionOptions.logVerbosityLevel ?? 0;  // Default to 0 - verbose
    if (!Number.isInteger(logVerbosityLevel) || logVerbosityLevel < 0 || logVerbosityLevel > 4) {
      throw new Error(`log verbosity level is not valid: ${logVerbosityLevel}`);
    }

    const optimizedModelFilePathOffset = typeof sessionOptions.optimizedModelFilePath === 'string' ?
        allocWasmString(sessionOptions.optimizedModelFilePath, allocs) :
        0;

    sessionOptionsHandle = wasm._OrtCreateSessionOptions(
        graphOptimizationLevel, !!sessionOptions.enableCpuMemArena, !!sessionOptions.enableMemPattern, executionMode,
        !!sessionOptions.enableProfiling, 0, logIdDataOffset, logSeverityLevel, logVerbosityLevel,
        optimizedModelFilePathOffset);
    if (sessionOptionsHandle === 0) {
      checkLastError('Can\'t create session options.');
    }

    if (sessionOptions.executionProviders) {
      setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs);
    }

    if (sessionOptions.enableGraphCapture !== undefined) {
      if (typeof sessionOptions.enableGraphCapture !== 'boolean') {
        throw new Error(`enableGraphCapture must be a boolean value: ${sessionOptions.enableGraphCapture}`);
      }
      const keyDataOffset = allocWasmString('enableGraphCapture', allocs);
      const valueDataOffset = allocWasmString(sessionOptions.enableGraphCapture.toString(), allocs);
      if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) {
        checkLastError(
            `Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`);
      }
    }

    if (sessionOptions.freeDimensionOverrides) {
      // Object.entries() only yields string keys, so only the values need validating.
      for (const [name, value] of Object.entries(sessionOptions.freeDimensionOverrides)) {
        if (typeof value !== 'number' || !Number.isInteger(value) || value < 0) {
          throw new Error(`free dimension override value must be a non-negative integer: ${value}`);
        }
        const nameOffset = allocWasmString(name, allocs);
        if (wasm._OrtAddFreeDimensionOverride(sessionOptionsHandle, nameOffset, value) !== 0) {
          checkLastError(`Can't set a free dimension override: ${name} - ${value}.`);
        }
      }
    }

    if (sessionOptions.extra !== undefined) {
      iterateExtraOptions(sessionOptions.extra, '', new WeakSet<Record<string, unknown>>(), (key, value) => {
        const keyDataOffset = allocWasmString(key, allocs);
        const valueDataOffset = allocWasmString(value, allocs);

        if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) {
          checkLastError(`Can't set a session config entry: ${key} - ${value}.`);
        }
      });
    }

    return [sessionOptionsHandle, allocs];
  } catch (e) {
    // Undo everything created so far so that a failed setup does not leak wasm resources.
    if (sessionOptionsHandle !== 0) {
      wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
    }
    allocs.forEach(alloc => wasm._free(alloc));
    throw e;
  }
};