[js/webgpu] Support capture and replay for jsep (#18989)

### Description
This PR expands the graph capture capability to JS EP, which is similar
to #16081. But for JS EP, we don't use the CUDA Graph, instead, we
records all gpu commands and replay them, which removes most of the cpu
overhead to avoid the the situation that gpu waiting for cpu.

mobilenetv2-12 becomes 3.7ms from 6ms on NV 3090 and becomes 3.38ms from
4.58ms on Intel A770.

All limitations are similar with CUDA EP:
1. Models with control-flow ops (i.e. If, Loop and Scan ops) are not
supported.
2. Usage of graph capture is limited to models where-in all ops in the
model can be partitioned to the JS EP or CPU EP and no memory copy
between them.
3. Shapes of inputs/outputs cannot change across inference calls.
4. IObinding is required.

The usage is like below:
Method 1: specify outputs buffers explicitly.
```
    const sessionOptions = {
        executionProviders: [
          {
            name: "webgpu",
          },
        ],
        enableGraphCapture: true,
      };
    const session = await ort.InferenceSession.create('./models/mobilenetv2-12.onnx', sessionOptions);
   
    // prepare the inputBuffer/outputBuffer
    ... ...

   const feeds = {
       'input': ort.Tensor.fromGpuBuffer(inputBuffer, { dataType: 'float32', dims })
   };

   const fetches = {
       'output': ort.Tensor.fromGpuBuffer(outputBuffer, { dataType: 'float32', dims: [1, 1000] })
   };

   let results = await session.run(feeds, fetches);  // The first run will begin to capture the graph.

   // update inputBuffer content
  ... ...
   results = = await session.run(feeds, fetches);  // The 2ed run and after will directly call replay to execute the graph.

  ... ...
   session.release();
```
Method 2: Don't specify outputs buffers explicitly. Internally, when
graph capture is enabled, it will set all outputs location to
'gpu-buffer'.
```
    const sessionOptions = {
        executionProviders: [
          {
            name: "webgpu",
          },
        ],
        enableGraphCapture: true,
      };
    const session = await ort.InferenceSession.create('./models/mobilenetv2-12.onnx', sessionOptions);

    // prepare the inputBuffer
    ... ...

   const feeds = {
       'input': ort.Tensor.fromGpuBuffer(inputBuffer, { dataType: 'float32', dims })
   };

   let results = await session.run(feeds);  // The first run will begin to capture the graph.
   
   // update inputBuffer content
  ... ...
   results = = await session.run(feeds);  // The 2ed run and after will directly call replay to execute the graph.

  ... ...
   session.release();
This commit is contained in:
Jiajia Qin 2024-01-31 10:28:03 +08:00 committed by GitHub
parent 6dd0079d13
commit 85cef0af8c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 437 additions and 137 deletions

View file

@ -111,7 +111,7 @@ export declare namespace InferenceSession {
optimizedModelFilePath?: string;
/**
* Wether enable profiling.
* Whether enable profiling.
*
* This setting is a placeholder for a future use.
*/
@ -154,6 +154,12 @@ export declare namespace InferenceSession {
*/
preferredOutputLocation?: OnnxValueDataLocation|{readonly [outputName: string]: OnnxValueDataLocation};
/**
* Whether enable graph capture.
* This setting is available only in ONNXRuntime Web for WebGPU EP.
*/
enableGraphCapture?: boolean;
/**
* Store configurations for a session. See
* https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/

View file

@ -13,6 +13,9 @@ export declare namespace JSEP {
type ReleaseKernelFunction = (kernel: number) => void;
type RunFunction =
(kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string|null>>) => number;
type CaptureBeginFunction = () => void;
type CaptureEndFunction = () => void;
type ReplayFunction = () => void;
}
export interface OrtWasmModule extends EmscriptenModule {
@ -128,7 +131,8 @@ export interface OrtWasmModule extends EmscriptenModule {
jsepInit?
(backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction,
download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction,
releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void;
releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction,
captureEnd: JSEP.CaptureEndFunction, replay: JSEP.ReplayFunction): void;
/**
* [exported from wasm] Specify a kernel's output when running OpKernel::Compute().
@ -158,12 +162,6 @@ export interface OrtWasmModule extends EmscriptenModule {
* @returns the GPU data ID for the registered GPU buffer.
*/
jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number;
/**
* [exported from js_internal_api.js] Unregister all user GPU buffers for a session.
*
* @param sessionId - specify the session ID.
*/
jsepUnregisterBuffers?: (sessionId: number) => void;
/**
* [exported from js_internal_api.js] Get the GPU buffer by GPU data ID.
*
@ -183,9 +181,18 @@ export interface OrtWasmModule extends EmscriptenModule {
(gpuBuffer: GPUBuffer, size: number,
type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
/**
* [exported from js_internal_api.js] Called when InferenceSession.run started.
* [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before
* _OrtRun[WithBinding]() is called.
* @param sessionId - specify the session ID.
*/
jsepOnRunStart: () => void;
jsepOnRunStart: (sessionId: number) => void;
/**
* [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is
* called.
* @param sessionId - specify the session ID.
* @returns
*/
jsepOnReleaseSession: (sessionId: number) => void;
// #endregion
}

View file

@ -10,7 +10,14 @@ import {createView, TensorView} from './tensor-view';
import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager';
import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules';
import {ProgramManager} from './webgpu/program-manager';
import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, TimestampQuery} from './webgpu/types';
import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types';
interface CommandInfo {
readonly kernelId: number;
readonly computePipeline: GPUComputePipeline;
readonly bindGroup: GPUBindGroup;
readonly dispatchGroup: [number, number, number];
}
interface KernelInfo {
readonly kernelType: string;
@ -103,6 +110,13 @@ export class WebGpuBackend {
*/
programManager: ProgramManager;
/**
* representing the session ID of which is currently being run.
* `null` means no session is being run.
* only valid when session.run is executed.
*/
currentSessionId: number|null = null;
/**
* representing the kernel ID of which is currently being computed (CPU code perspective).
* `null` means no kernel is being computed.
@ -155,6 +169,16 @@ export class WebGpuBackend {
queryType: TimestampQuery;
env: Env;
sessionStatus: SessionState = 'default';
/**
* a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session.
*/
capturedCommandList: Map<number, CommandInfo[]> = new Map();
/**
* a SessionID -> PendingKernelInfo[] mapping for profiling.
*/
private capturedPendingKernels: Map<number, PendingKernelInfo[]> = new Map();
/**
* a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping.
@ -228,6 +252,7 @@ export class WebGpuBackend {
getComputePassEncoder(): GPUComputePassEncoder {
if (!this.computePassEncoder) {
const commandEncoder = this.getCommandEncoder();
const computePassDescriptor: GPUComputePassDescriptor = {};
if (this.queryType === 'at-passes') {
@ -238,7 +263,7 @@ export class WebGpuBackend {
};
}
this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor);
this.computePassEncoder = commandEncoder.beginComputePass(computePassDescriptor);
}
return this.computePassEncoder;
}
@ -494,7 +519,7 @@ export class WebGpuBackend {
() => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${
normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`);
if (this.queryType !== 'none') {
if (this.queryType !== 'none' || this.sessionStatus === 'capturing') {
const pendingKernelInfo: PendingKernelInfo = {
kernelId: this.currentKernelId!,
programName: artifact.programInfo.name,
@ -502,6 +527,9 @@ export class WebGpuBackend {
outputTensorViews,
};
this.pendingKernels.push(pendingKernelInfo);
const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);
sessionPendingKernels!.push(pendingKernelInfo);
}
this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding);
@ -672,7 +700,71 @@ export class WebGpuBackend {
}
}
}
onRunStart(): void {
captureBegin(): void {
LOG_DEBUG('info', 'captureBegin');
let sessionCommandList = this.capturedCommandList.get(this.currentSessionId!);
let sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);
if (!sessionCommandList) {
sessionCommandList = [];
this.capturedCommandList.set(this.currentSessionId!, sessionCommandList);
sessionPendingKernels = [];
this.capturedPendingKernels.set(this.currentSessionId!, sessionPendingKernels);
}
// flush the left commands before we change the status.
this.flush();
this.sessionStatus = 'capturing';
}
captureEnd(): void {
LOG_DEBUG('info', 'captureEnd');
// flush the left commands before we change the status.
this.flush();
this.sessionStatus = 'default';
}
replay(): void {
LOG_DEBUG('info', 'replay');
this.sessionStatus = 'replaying';
const sessionCommandList = this.capturedCommandList.get(this.currentSessionId!);
const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);
const length = sessionCommandList!.length;
this.pendingKernels = [];
for (let i = 0; i < length; i++) {
const computePassEncoder = this.getComputePassEncoder();
const command = sessionCommandList![i];
this.writeTimestamp(this.pendingDispatchNumber * 2);
computePassEncoder.setPipeline(command.computePipeline);
computePassEncoder.setBindGroup(0, command.bindGroup);
computePassEncoder.dispatchWorkgroups(...command.dispatchGroup);
this.writeTimestamp(this.pendingDispatchNumber * 2 + 1);
this.pendingDispatchNumber++;
if (this.queryType !== 'none') {
this.pendingKernels.push(sessionPendingKernels![i]);
}
if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') {
this.endComputePass();
}
if (this.pendingDispatchNumber >= this.maxDispatchNumber) {
this.flush();
}
}
// flush the left commands before we change the status.
this.flush();
this.sessionStatus = 'default';
}
onReleaseSession(sessionId: number): void {
this.unregisterBuffers(sessionId);
if (this.capturedCommandList.has(sessionId)) {
this.capturedCommandList.delete(sessionId);
}
if (this.capturedPendingKernels.has(sessionId)) {
this.capturedPendingKernels.delete(sessionId);
}
this.gpuDataManager.onReleaseSession(sessionId);
}
onRunStart(sessionId: number): void {
this.currentSessionId = sessionId;
this.setQueryType();
}
}

View file

@ -201,5 +201,11 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte
contextDataOffset}`);
const context = new ComputeContextImpl(module, backend, contextDataOffset);
return backend.computeKernel(kernel, context, errors);
});
},
// jsepCaptureBegin
() => backend.captureBegin(),
// jsepCaptureEnd
() => backend.captureEnd(),
// jsepReplay
() => backend.replay());
};

View file

@ -60,9 +60,15 @@ export interface GpuDataManager {
unregisterExternalBuffer(buffer: GPUBuffer): void;
/**
* destroy all gpu buffers. Call this when the session.release is called.
* destroy all gpu buffers.
*/
dispose(): void;
/**
* release session related data.
* @param sessionId - specify the session ID.
*/
onReleaseSession(sessionId: number): void;
}
interface StorageCacheValue {
@ -139,6 +145,10 @@ class GpuDataManagerImpl implements GpuDataManager {
// The external buffers registered users for IO Binding.
private externalBuffers: Map<GPUBuffer, GpuDataId>;
// The pendingBuffers for capture graph.
// a SessionID -> GPUBuffer[] mapping.
private capturedPendingBuffers: Map<number, GPUBuffer[]>;
constructor(private backend: WebGpuBackend) {
this.storageCache = new Map();
this.freeBuffers = new Map();
@ -146,6 +156,7 @@ class GpuDataManagerImpl implements GpuDataManager {
this.buffersForUploadingPending = [];
this.buffersPending = [];
this.externalBuffers = new Map();
this.capturedPendingBuffers = new Map();
}
upload(id: GpuDataId, data: Uint8Array): void {
@ -220,6 +231,9 @@ class GpuDataManagerImpl implements GpuDataManager {
() => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${
id}, buffer is the same, skip.`);
return id;
} else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) {
throw new Error(`Registering a different external buffer under graph capture mode is not supported yet.
Please use the previous external buffer!`);
}
this.externalBuffers.delete(previousBuffer);
} else {
@ -312,20 +326,39 @@ class GpuDataManagerImpl implements GpuDataManager {
buffer.destroy();
}
this.buffersForUploadingPending = [];
for (const buffer of this.buffersPending) {
// eslint-disable-next-line no-bitwise
if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) {
// Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing.
this.freeBuffers.get(buffer.size)!.push(buffer);
// eslint-disable-next-line no-bitwise
} else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) {
// Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing.
this.freeUniformBuffers.get(buffer.size)!.push(buffer);
} else {
buffer.destroy();
}
if (this.buffersPending.length === 0) {
return;
}
if (this.backend.sessionStatus === 'default') {
for (const buffer of this.buffersPending) {
// eslint-disable-next-line no-bitwise
if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) {
// Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing.
this.freeBuffers.get(buffer.size)!.push(buffer);
// eslint-disable-next-line no-bitwise
} else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) {
// Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing.
this.freeUniformBuffers.get(buffer.size)!.push(buffer);
} else {
buffer.destroy();
}
}
this.buffersPending = [];
} else {
// Don't release intermediate tensors in non-default mode.
// TODO: reuse the storage buffers in non-default mode.
let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!);
if (!capturedBuffers) {
capturedBuffers = [];
this.capturedPendingBuffers.set(this.backend.currentSessionId!, capturedBuffers);
}
for (const buffer of this.buffersPending) {
capturedBuffers.push(buffer);
}
this.buffersPending = [];
}
this.buffersPending = [];
}
dispose() {
@ -344,9 +377,26 @@ class GpuDataManagerImpl implements GpuDataManager {
storage.gpuData.buffer.destroy();
});
this.capturedPendingBuffers.forEach((buffers) => {
buffers.forEach(buffer => {
buffer.destroy();
});
});
this.storageCache = new Map();
this.freeBuffers = new Map();
this.freeUniformBuffers = new Map();
this.capturedPendingBuffers = new Map();
}
onReleaseSession(sessionId: number) {
// release the captured pending buffers.
const pendingBuffers = this.capturedPendingBuffers.get(sessionId);
if (pendingBuffers) {
pendingBuffers.forEach(buffer => {
buffer.destroy();
});
this.capturedPendingBuffers.delete(sessionId);
}
}
}

View file

@ -38,7 +38,6 @@ export class ProgramManager {
const device = this.backend.device;
const computePassEncoder = this.backend.getComputePassEncoder();
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2);
computePassEncoder.setPipeline(buildArtifact.computePipeline);
const entries = [];
for (const input of inputs) {
entries.push({binding: entries.length, resource: {buffer: input.buffer}});
@ -51,8 +50,20 @@ export class ProgramManager {
}
const bindGroup = device.createBindGroup(
{layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name});
computePassEncoder.setBindGroup(0, bindGroup);
if (this.backend.sessionStatus === 'capturing') {
const commandInfo = {
kernelId: this.backend.currentKernelId!,
computePipeline: buildArtifact.computePipeline,
bindGroup,
dispatchGroup
};
const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!);
sessionCommandList!.push(commandInfo);
}
computePassEncoder.setPipeline(buildArtifact.computePipeline);
computePassEncoder.setBindGroup(0, bindGroup);
computePassEncoder.dispatchWorkgroups(...dispatchGroup);
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1);
this.backend.pendingDispatchNumber++;

View file

@ -5,6 +5,8 @@ import {TensorView} from '../tensor-view';
import {ShaderHelper} from './ops/common';
export type SessionState = 'default'|'capturing'|'replaying';
export enum GpuDataType {
default = 0,
upload = 1,

View file

@ -168,6 +168,18 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n
setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs);
}
if (sessionOptions.enableGraphCapture !== undefined) {
if (typeof sessionOptions.enableGraphCapture !== 'boolean') {
throw new Error(`enableGraphCapture must be a boolean value: ${sessionOptions.enableGraphCapture}`);
}
const keyDataOffset = allocWasmString('enableGraphCapture', allocs);
const valueDataOffset = allocWasmString(sessionOptions.enableGraphCapture.toString(), allocs);
if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) {
checkLastError(
`Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`);
}
}
if (sessionOptions.freeDimensionOverrides) {
for (const [name, value] of Object.entries(sessionOptions.freeDimensionOverrides)) {
if (typeof name !== 'string') {

View file

@ -139,7 +139,7 @@ type IOBindingState = {
*/
type SessionMetadata = [
inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[],
bindingState: IOBindingState|null
bindingState: IOBindingState|null, enableGraphCapture: boolean, inputOutputBound: boolean
];
const activeSessions = new Map<number, SessionMetadata>();
@ -235,6 +235,8 @@ export const createSession = async(
const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);
const enableGraphCapture = !!options?.enableGraphCapture;
const inputNames = [];
const outputNames = [];
const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = [];
@ -256,12 +258,20 @@ export const createSession = async(
outputNames.push(nameString);
if (!BUILD_DEFS.DISABLE_WEBGPU) {
if (enableGraphCapture && options?.preferredOutputLocation === undefined) {
outputPreferredLocations.push('gpu-buffer');
continue;
}
const location = typeof options?.preferredOutputLocation === 'string' ?
options.preferredOutputLocation :
options?.preferredOutputLocation?.[nameString] ?? 'cpu';
if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') {
throw new Error(`Not supported preferred output location: ${location}.`);
}
if (enableGraphCapture && location !== 'gpu-buffer') {
throw new Error(`Not supported preferred output location: ${
location}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`);
}
outputPreferredLocations.push(location);
}
}
@ -281,7 +291,9 @@ export const createSession = async(
};
}
activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]);
activeSessions.set(
sessionHandle,
[sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState, enableGraphCapture, false]);
return [sessionHandle, inputNames, outputNames];
} catch (e) {
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
@ -313,13 +325,16 @@ export const releaseSession = (sessionId: number): void => {
if (!session) {
throw new Error(`cannot release session. invalid session id: ${sessionId}`);
}
const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session;
const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture] = session;
if (ioBindingState) {
if (enableGraphCapture) {
wasm._OrtClearBoundOutputs(ioBindingState.handle);
}
wasm._OrtReleaseBinding(ioBindingState.handle);
}
wasm.jsepUnregisterBuffers?.(sessionId);
wasm.jsepOnReleaseSession?.(sessionId);
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
@ -328,70 +343,75 @@ export const releaseSession = (sessionId: number): void => {
};
export const prepareInputOutputTensor =
(tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number):
void => {
if (!tensor) {
tensorHandles.push(0);
return;
}
(tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number,
enableGraphCapture = false): void => {
if (!tensor) {
tensorHandles.push(0);
return;
}
const wasm = getInstance();
const wasm = getInstance();
const dataType = tensor[0];
const dims = tensor[1];
const location = tensor[3];
const dataType = tensor[0];
const dims = tensor[1];
const location = tensor[3];
let rawData: number;
let dataByteLength: number;
let rawData: number;
let dataByteLength: number;
if (dataType === 'string' && location === 'gpu-buffer') {
throw new Error('String tensor is not supported on GPU.');
}
if (dataType === 'string' && location === 'gpu-buffer') {
throw new Error('String tensor is not supported on GPU.');
}
if (location === 'gpu-buffer') {
const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;
const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!;
dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;
rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength);
} else {
const data = tensor[2];
if (enableGraphCapture && location !== 'gpu-buffer') {
throw new Error(
`External buffer must be provided for input/output index ${index} when enableGraphCapture is true.`);
}
if (Array.isArray(data)) {
// string tensor
dataByteLength = 4 * data.length;
rawData = wasm._malloc(dataByteLength);
allocs.push(rawData);
let dataIndex = rawData / 4;
for (let i = 0; i < data.length; i++) {
if (typeof data[i] !== 'string') {
throw new TypeError(`tensor data at index ${i} is not a string`);
}
wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs);
}
} else {
dataByteLength = data.byteLength;
rawData = wasm._malloc(dataByteLength);
allocs.push(rawData);
wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData);
if (location === 'gpu-buffer') {
const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;
const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!;
dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;
rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength);
} else {
const data = tensor[2];
if (Array.isArray(data)) {
// string tensor
dataByteLength = 4 * data.length;
rawData = wasm._malloc(dataByteLength);
allocs.push(rawData);
let dataIndex = rawData / 4;
for (let i = 0; i < data.length; i++) {
if (typeof data[i] !== 'string') {
throw new TypeError(`tensor data at index ${i} is not a string`);
}
wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs);
}
} else {
dataByteLength = data.byteLength;
rawData = wasm._malloc(dataByteLength);
allocs.push(rawData);
wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData);
}
}
const stack = wasm.stackSave();
const dimsOffset = wasm.stackAlloc(4 * dims.length);
try {
let dimIndex = dimsOffset / 4;
dims.forEach(d => wasm.HEAP32[dimIndex++] = d);
const tensor = wasm._OrtCreateTensor(
tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length,
dataLocationStringToEnum(location));
if (tensor === 0) {
checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`);
}
tensorHandles.push(tensor);
} finally {
wasm.stackRestore(stack);
}
};
const stack = wasm.stackSave();
const dimsOffset = wasm.stackAlloc(4 * dims.length);
try {
let dimIndex = dimsOffset / 4;
dims.forEach(d => wasm.HEAP32[dimIndex++] = d);
const tensor = wasm._OrtCreateTensor(
tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length,
dataLocationStringToEnum(location));
if (tensor === 0) {
checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`);
}
tensorHandles.push(tensor);
} finally {
wasm.stackRestore(stack);
}
};
/**
* perform inference run
@ -404,7 +424,12 @@ export const run = async(
if (!session) {
throw new Error(`cannot run inference. invalid session id: ${sessionId}`);
}
const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session;
const sessionHandle = session[0];
const inputNamesUTF8Encoded = session[1];
const outputNamesUTF8Encoded = session[2];
const ioBindingState = session[3];
const enableGraphCapture = session[4];
const inputOutputBound = session[5];
const inputCount = inputIndices.length;
const outputCount = outputIndices.length;
@ -427,13 +452,15 @@ export const run = async(
// create input tensors
for (let i = 0; i < inputCount; i++) {
prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]);
prepareInputOutputTensor(
inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i], enableGraphCapture);
}
// create output tensors
for (let i = 0; i < outputCount; i++) {
prepareInputOutputTensor(
outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]);
outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i],
enableGraphCapture);
}
let inputValuesIndex = inputValuesOffset / 4;
@ -449,7 +476,7 @@ export const run = async(
wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]];
}
if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) {
if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState && !inputOutputBound) {
const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState;
if (inputNamesUTF8Encoded.length !== inputCount) {
@ -486,9 +513,12 @@ export const run = async(
}
}
}
activeSessions.set(
sessionId,
[sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, true]);
}
wasm.jsepOnRunStart?.();
wasm.jsepOnRunStart?.(sessionHandle);
let errorCode: number;
if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) {
errorCode = await wasm._OrtRunWithBinding(
@ -595,10 +625,12 @@ export const run = async(
}
}
if (ioBindingState) {
if (ioBindingState && !enableGraphCapture) {
wasm._OrtClearBoundOutputs(ioBindingState.handle);
activeSessions.set(
sessionId,
[sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, false]);
}
return output;
} finally {
wasm.stackRestore(beforeRunStack);

View file

@ -3,6 +3,7 @@
#include "js_execution_provider.h"
#include <emscripten.h>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
@ -681,9 +682,13 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
using namespace js;
JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info)
JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info, const SessionOptions* session_options)
: IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)},
preferred_data_layout_{info.data_layout} {
if (session_options) {
enable_graph_capture_ = session_options->config_options.GetConfigOrDefault("enableGraphCapture", "false") == "true";
LOGS_DEFAULT(VERBOSE) << "Graph capture enable: " << enable_graph_capture_;
}
}
std::vector<AllocatorPtr> JsExecutionProvider::CreatePreferredAllocators() {
@ -751,4 +756,46 @@ std::unique_ptr<onnxruntime::IDataTransfer> JsExecutionProvider::GetDataTransfer
JsExecutionProvider::~JsExecutionProvider() {
}
Status JsExecutionProvider::OnRunStart() {
if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) {
LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model";
EM_ASM({ Module.jsepCaptureBegin(); });
}
return Status::OK();
}
Status JsExecutionProvider::OnRunEnd(bool sync_stream) {
if (IsGraphCaptureEnabled() && !IsGraphCaptured()) {
if (IsGraphCaptureAllowed()) {
EM_ASM({ Module.jsepCaptureEnd(); });
is_graph_captured_ = true;
} else {
IncrementRegularRunCountBeforeGraphCapture();
}
}
return Status::OK();
}
bool JsExecutionProvider::IsGraphCaptureEnabled() const {
return enable_graph_capture_;
}
bool JsExecutionProvider::IsGraphCaptured() const {
return is_graph_captured_;
}
Status JsExecutionProvider::ReplayGraph() {
ORT_ENFORCE(IsGraphCaptured());
EM_ASM({ Module.jsepReplay(); });
return Status::OK();
}
bool JsExecutionProvider::IsGraphCaptureAllowed() const {
return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_;
}
void JsExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() {
++regular_run_count_before_graph_capture_;
}
} // namespace onnxruntime

View file

@ -5,6 +5,7 @@
#pragma once
#include "core/framework/execution_provider.h"
#include "core/framework/session_options.h"
#include "core/graph/constants.h"
#include "core/providers/providers.h"
@ -38,7 +39,7 @@ struct JsExecutionProviderInfo {
class JsExecutionProvider : public IExecutionProvider {
public:
JsExecutionProvider(const JsExecutionProviderInfo& info);
JsExecutionProvider(const JsExecutionProviderInfo& info, const SessionOptions* session_options);
~JsExecutionProvider() override;
std::vector<std::unique_ptr<ComputeCapability>> GetCapability(
@ -57,7 +58,22 @@ class JsExecutionProvider : public IExecutionProvider {
bool ConcurrentRunSupported() const override { return false; }
std::vector<AllocatorPtr> CreatePreferredAllocators() override;
Status OnRunStart() override;
Status OnRunEnd(bool sync_stream) override;
bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured() const override;
Status ReplayGraph() override;
private:
bool IsGraphCaptureAllowed() const;
void IncrementRegularRunCountBeforeGraphCapture();
DataLayout preferred_data_layout_;
bool enable_graph_capture_ = false;
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;
const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations.
};
} // namespace onnxruntime

View file

@ -10,21 +10,22 @@
namespace onnxruntime {
struct JsProviderFactory : IExecutionProviderFactory {
JsProviderFactory(const ProviderOptions& provider_options)
: info_{provider_options} {
JsProviderFactory(const ProviderOptions& provider_options, const SessionOptions* session_options)
: info_{provider_options}, session_options_(session_options) {
}
std::unique_ptr<IExecutionProvider> CreateProvider() override {
return std::make_unique<JsExecutionProvider>(info_);
return std::make_unique<JsExecutionProvider>(info_, session_options_);
}
private:
JsExecutionProviderInfo info_;
const SessionOptions* session_options_;
};
std::shared_ptr<IExecutionProviderFactory> JsProviderFactoryCreator::Create(
const ProviderOptions& provider_options) {
return std::make_shared<JsProviderFactory>(provider_options);
const ProviderOptions& provider_options, const SessionOptions* session_options) {
return std::make_shared<JsProviderFactory>(provider_options, session_options);
}
} // namespace onnxruntime

View file

@ -9,9 +9,11 @@
#include "core/providers/providers.h"
namespace onnxruntime {
struct SessionOptions;
struct JsProviderFactoryCreator {
static std::shared_ptr<IExecutionProviderFactory> Create(const ProviderOptions& provider_options);
static std::shared_ptr<IExecutionProviderFactory> Create(const ProviderOptions& provider_options,
const SessionOptions* session_options);
};
} // namespace onnxruntime

View file

@ -145,28 +145,30 @@ static bool HasMemcpyNodes(const Graph& graph) {
return false;
}
static bool AreAllComputeNodesAssignedToCudaEp(const Graph& graph) {
bool nodes_on_cpu_and_cuda_eps_only = true;
static bool AreAllComputeNodesAssignedToCudaOrJsEp(const Graph& graph) {
bool nodes_on_cpu_and_cuda_and_js_eps_only = true;
for (const auto& node : graph.Nodes()) {
const auto& node_provider = node.GetExecutionProviderType();
// Empty node provider means CPU EP
if (!node_provider.empty() &&
!(node_provider == kCudaExecutionProvider || node_provider == kRocmExecutionProvider) &&
!(node_provider == kCudaExecutionProvider ||
node_provider == kRocmExecutionProvider ||
node_provider == kJsExecutionProvider) &&
node_provider != kCpuExecutionProvider) {
nodes_on_cpu_and_cuda_eps_only = false;
nodes_on_cpu_and_cuda_and_js_eps_only = false;
break;
}
}
// If we see nodes assigned to EPs other than CPU or CUDA
// If we see nodes assigned to EPs other than CPU, or CUDA/JS
// (or) if there are Memcpy nodes, then all compute nodes have
// not been parititoned to the CUDA EP.
// not been parititoned to the CUDA/JS EP.
// We allow CPU EPs to show up in the EP list as long as thre is no Memcpy
// involved as shape subgraphs will be forced onto CPU and these will not have
// Memcpy nodes involved.
return nodes_on_cpu_and_cuda_eps_only && !HasMemcpyNodes(graph);
return nodes_on_cpu_and_cuda_and_js_eps_only && !HasMemcpyNodes(graph);
}
static bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) {
@ -1715,8 +1717,7 @@ common::Status InferenceSession::Initialize() {
// now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve());
// Currently CUDA graph is only considered by CUDA EP and TRT EP, and
// HIP graph is only considered by ROCM EP.
// Currently graph capture is only considered by CUDA EP, TRT EP, ROCM EP and JS EP.
//
// Check for CUDA EP:
// If the CUDA EP is part of the providers list for this session AND
@ -1730,6 +1731,12 @@ common::Status InferenceSession::Initialize() {
// All the graph nodes have been assigned to the TRT EP,
// Then the TRT EP is cached for triggering a ReplayGraph() in Run().
//
// Check for JS EP:
// If the JS EP is part of the providers list for this session AND
// The JS EP is configured to do a graph capture AND
// All the "compute" graph nodes have been assigned to the JS EP,
// Then the JS EP is cached for triggering a ReplayGraph() in Run().
//
// Check for ROCM EP:
// If the ROCM EP is part of the providers list for this session AND
// The ROCM EP is configured to do a graph capture AND
@ -1739,48 +1746,54 @@ common::Status InferenceSession::Initialize() {
std::vector<const char*> graph_support_ep_list = {
onnxruntime::kTensorrtExecutionProvider,
onnxruntime::kCudaExecutionProvider,
onnxruntime::kRocmExecutionProvider};
onnxruntime::kRocmExecutionProvider,
onnxruntime::kJsExecutionProvider};
for (auto& it : graph_support_ep_list) {
auto* target_ep = execution_providers_.Get(it);
if (target_ep && target_ep->IsGraphCaptureEnabled()) {
// CUDA/HIP Graphs can't work with control flow nodes
// Graphs capture can't work with control flow nodes
if (HasControlflowNodes(graph)) {
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA/HIP Graph feature as requested by the user "
<< "as the model has control flow nodes which can't be supported by CUDA/HIP Graphs.";
LOGS(*session_logger_, ERROR) << "This session cannot use the graph capture feature as requested by the user "
<< "as the model has control flow nodes which can't be supported by "
<< target_ep->Type();
ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA/HIP Graph feature as requested by the user "
"as the model has control flow nodes which can't be supported by CUDA/HIP Graphs."));
"This session cannot use the graph capture feature as requested by the user "
"as the model has control flow nodes which can't be supported by" +
target_ep->Type()));
}
if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0 ||
strcmp(target_ep->Type().c_str(), onnxruntime::kRocmExecutionProvider) == 0) {
// Ensure that all nodes have been partitioned to CUDA or CPU EP && there are no memcpy nodes
strcmp(target_ep->Type().c_str(), onnxruntime::kRocmExecutionProvider) == 0 ||
strcmp(target_ep->Type().c_str(), onnxruntime::kJsExecutionProvider) == 0) {
// Ensure that all nodes have been partitioned to CUDA/JS or CPU EP && there are no memcpy nodes
// The reasoning behind this logic is that certain shape nodes will be forced onto CPU
// and as long as there are no memcpy nodes this is confirmation that no compute nodes have been placed on the CPU EP
// which is all we care about.
if (!AreAllComputeNodesAssignedToCudaEp(graph)) {
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA/HIP Graph feature as requested by the user "
<< " as all compute graph nodes have not been partitioned to the CUDA/HIP EP.";
if (!AreAllComputeNodesAssignedToCudaOrJsEp(graph)) {
LOGS(*session_logger_, ERROR) << "This session cannot use the graph capture feature as requested by the user "
<< " as all compute graph nodes have not been partitioned to the "
<< target_ep->Type();
ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA/HIP Graph feature as requested by the user "
" as all compute graph nodes have not been partitioned to the CUDA/HIP EP."));
"This session cannot use the graph capture feature as requested by the user "
" as all compute graph nodes have not been partitioned to the " +
target_ep->Type()));
}
// Log a warning for the user to know that there are shape subgraphs that will execute on CPU
if (HasShapeSubgraphNodes(graph)) {
LOGS(*session_logger_, WARNING) << "This model has shape massaging nodes that will execute on CPU. "
<< "Use the CUDA/HIP Graph feature with caution. "
<< "Use the graph capture feature with caution. "
<< "As long as the intermediate shapes produced in the model "
<< "using the representative input used to capture the CUDA/HIP graph, "
<< "using the representative input used to capture the graph, "
<< "will match the shapes produced in the model for other inputs "
<< "of the same shape as the representative input (common case), "
<< "it is safe to use the CUDA/HIP Graph feature.";
<< "it is safe to use the graph capture feature.";
}
} else {
// Following code path is for TRT EP currently.

View file

@ -145,7 +145,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
if (options->value.config_options.TryGetConfigEntry("preferredLayout", preferred_layout)) {
provider_options["preferred_layout"] = preferred_layout;
}
options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options));
options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options, &(options->value)));
#else
status = create_not_supported_status();
#endif

View file

@ -24,7 +24,7 @@ Module['unmountExternalData'] = () => {
/**
* init JSEP
*/
Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel) => {
Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel, captureBegin, captureEnd, replay) => {
Module.jsepBackend = backend;
Module.jsepAlloc = alloc;
Module.jsepFree = free;
@ -33,6 +33,9 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
Module.jsepCreateKernel = createKernel;
Module.jsepReleaseKernel = releaseKernel;
Module.jsepRunKernel = runKernel;
Module.jsepCaptureBegin = captureBegin;
Module.jsepCaptureEnd = captureEnd;
Module.jsepReplay = replay;
// This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1)
// It removes some overhead in cwarp() and ccall() that we don't need.
@ -181,16 +184,16 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea
Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => {
return backend['registerBuffer'](sessionId, index, buffer, size);
};
Module['jsepUnregisterBuffers'] = sessionId => {
backend['unregisterBuffers'](sessionId);
};
Module['jsepGetBuffer'] = (dataId) => {
return backend['getBuffer'](dataId);
};
Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => {
return backend['createDownloader'](gpuBuffer, size, type);
};
Module['jsepOnRunStart'] = () => {
return backend['onRunStart']();
Module['jsepOnReleaseSession'] = sessionId => {
backend['onReleaseSession'](sessionId);
};
Module['jsepOnRunStart'] = sessionId => {
return backend['onRunStart'](sessionId);
};
};