mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-29 03:30:52 +00:00
[WebNN EP] Add cache for MLContexts in the WebNNBackend (#22510)
### Description
This change adds a cache of `MLContext`s, keyed by their options, to the
`WebNNBackend`. This makes it so that multiple `InferenceSession`s
created with the same options will share the same context.
### Motivation and Context
Since `MLTensor`s are tied to `MLContext`s, developers can't easily share
tensors between `InferenceSession`s (outside of manually creating an `MLContext`
and specifying the `context` option). This leads to strange behaviors such
as,
```js
const sessionA = await ort.InferenceSession.create(urlA, {
executionProviders: ["webnn"],
preferredOutputLocation: "ml-buffer",
});
const sessionB = await ort.InferenceSession.create(urlB, {
executionProviders: ["webnn"],
});
const temp = await sessionA.run({/* arguments */});
const result = await sessionB.run({"input":temp["output"]}); // ERROR: Failed to execute 'dispatch' on 'MLContext': Invalid inputs: The context of MLGraph doesn't match the context of the MLTensor with name "input".
```
We encountered this behavior when updating the transformers.js version
in the developer preview demos. microsoft/webnn-developer-preview#46
This commit is contained in:
parent
46ff240821
commit
df236c7894
4 changed files with 76 additions and 6 deletions
|
|
@ -32,6 +32,24 @@ const onnxDataTypeToWebnnDataType = new Map<DataType, MLOperandDataType>([
|
|||
[DataType.bool, 'uint8'],
|
||||
]);
|
||||
|
||||
/**
 * One slot of the MLContext cache: a context plus the key it was created from.
 *
 * At most one of `gpuDevice` / `options` is set by the code that fills the cache; an entry with neither
 * represents a context created with no arguments.
 */
type MLContextEntry = {
  /** The cached context instance. */
  mlContext: MLContext;
  /** The GPUDevice the context was created from, when it was created from a device. */
  gpuDevice?: GPUDevice;
  /** The options the context was created with, when it was created from an options dictionary. */
  options?: MLContextOptions;
};
|
||||
|
||||
/**
 * Decide whether two MLContextOptions dictionaries describe the same context configuration.
 *
 * Members are compared by name with strict equality, independent of property order. A member whose value is
 * `undefined` is treated exactly like an omitted member, so `{ deviceType: 'gpu', powerPreference: undefined }`
 * equals `{ deviceType: 'gpu' }`.
 *
 * @param a - first options dictionary, or undefined when no options were given.
 * @param b - second options dictionary, or undefined when no options were given.
 * @returns true when both arguments are the same reference (including both undefined) or have the same defined
 * members; false otherwise. An options object never equals `undefined`, not even an empty one, because callers
 * use `undefined` to mean "no options object was supplied at all".
 */
const compareMLContextOptions = (a?: MLContextOptions, b?: MLContextOptions): boolean => {
  if (a === b) {
    return true;
  }
  if (a === undefined || b === undefined) {
    return false;
  }

  // Only members that carry a value take part in the comparison.
  const definedEntries = (options: MLContextOptions): Array<[string, unknown]> =>
    Object.entries(options).filter(([, value]) => value !== undefined);

  const aEntries = definedEntries(a);
  const bMembers = new Map<string, unknown>(definedEntries(b));

  // Same number of defined members + every member of `a` present in `b` with an identical value
  // implies the two member sets are equal.
  return (
    aEntries.length === bMembers.size &&
    aEntries.every(([key, value]) => bMembers.has(key) && bMembers.get(key) === value)
  );
};
|
||||
|
||||
/**
|
||||
* WebNN backend implementation. This class is used to keep track of the MLTensors created by the backend and keep track
|
||||
* of the current MLContext being used by the sessions.
|
||||
|
|
@ -49,6 +67,10 @@ export class WebNNBackend {
|
|||
* Maps from MLContext to session ids.
|
||||
*/
|
||||
private sessionIdsByMLContext = new Map<MLContext, Set<number>>();
|
||||
/**
|
||||
* Cache of MLContexts.
|
||||
*/
|
||||
private mlContextCache: MLContextEntry[] = [];
|
||||
/**
|
||||
* Current session id.
|
||||
*/
|
||||
|
|
@ -69,6 +91,41 @@ export class WebNNBackend {
|
|||
this.activeSessionId = sessionId;
|
||||
}
|
||||
|
||||
/**
 * Return an MLContext for the given GPUDevice or options, reusing a cached context when an equivalent one
 * already exists and creating (and caching) a new one otherwise.
 *
 * Cache keys:
 *  - a GPUDevice is matched by identity;
 *  - `undefined` matches only the entry that was created with no arguments;
 *  - an options dictionary is matched with `compareMLContextOptions`.
 *
 * @param optionsOrDevice - the MLContextOptions or GPUDevice to create the context from; omit for defaults.
 * @returns the cached or newly created MLContext.
 */
public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise<MLContext> {
  // `GPUDevice` is not a defined global in every environment; a bare `instanceof GPUDevice` would throw a
  // ReferenceError there, so check that the constructor exists first.
  const isGpuDevice = (value: unknown): value is GPUDevice =>
    typeof GPUDevice !== 'undefined' && value instanceof GPUDevice;

  let matches: (entry: MLContextEntry) => boolean;
  let createEntry: () => Promise<MLContextEntry>;

  if (isGpuDevice(optionsOrDevice)) {
    const gpuDevice = optionsOrDevice;
    matches = (entry) => entry.gpuDevice === gpuDevice;
    createEntry = async () => ({ gpuDevice, mlContext: await navigator.ml.createContext(gpuDevice) });
  } else if (optionsOrDevice === undefined) {
    matches = (entry) => entry.options === undefined && entry.gpuDevice === undefined;
    createEntry = async () => ({ mlContext: await navigator.ml.createContext() });
  } else {
    const options = optionsOrDevice;
    matches = (entry) => compareMLContextOptions(entry.options, options);
    createEntry = async () => ({ options, mlContext: await navigator.ml.createContext(options) });
  }

  const cached = this.mlContextCache.find(matches);
  if (cached !== undefined) {
    return cached.mlContext;
  }

  const created = await createEntry();

  // Context creation is asynchronous, so another call with an equivalent key may have populated the cache
  // while this one was awaiting. Prefer that entry so all callers end up sharing a single context.
  const racedIn = this.mlContextCache.find(matches);
  if (racedIn !== undefined) {
    return racedIn.mlContext;
  }

  this.mlContextCache.push(created);
  return created.mlContext;
}
|
||||
|
||||
public get currentContext(): MLContext {
|
||||
const mlContext = this.getMLContext(this.currentSessionId);
|
||||
if (!mlContext) {
|
||||
|
|
@ -99,6 +156,10 @@ export class WebNNBackend {
|
|||
sessionIds.delete(sessionId);
|
||||
if (sessionIds.size === 0) {
|
||||
this.sessionIdsByMLContext.delete(mlContext);
|
||||
const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.mlContext === mlContext);
|
||||
if (mlContextIndex !== -1) {
|
||||
this.mlContextCache.splice(mlContextIndex, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -303,12 +303,12 @@ export const createSession = async (
|
|||
if (context) {
|
||||
wasm.currentContext = context as MLContext;
|
||||
} else if (gpuDevice) {
|
||||
wasm.currentContext = await navigator.ml.createContext(gpuDevice);
|
||||
wasm.currentContext = await wasm.jsepCreateMLContext!(gpuDevice);
|
||||
} else {
|
||||
wasm.currentContext = await navigator.ml.createContext({ deviceType, powerPreference });
|
||||
wasm.currentContext = await wasm.jsepCreateMLContext!({ deviceType, powerPreference });
|
||||
}
|
||||
} else {
|
||||
wasm.currentContext = await navigator.ml.createContext();
|
||||
wasm.currentContext = await wasm.jsepCreateMLContext!();
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -225,6 +225,13 @@ export declare namespace JSEP {
|
|||
* @returns the MLTensor ID for the external MLTensor.
|
||||
*/
|
||||
jsepRegisterMLTensor: (tensor: MLTensor, onnxDataType: DataType, dimensions: readonly number[]) => number;
|
||||
|
||||
/**
|
||||
* [exported from pre-jsep.js] Create an MLContext from a GPUDevice or MLContextOptions.
|
||||
* @param optionsOrGpuDevice - specify the options or GPUDevice.
|
||||
* @returns
|
||||
*/
|
||||
jsepCreateMLContext(optionsOrGpuDevice?: MLContextOptions | GPUDevice): Promise<MLContext>;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -237,11 +237,13 @@ Module['jsepInit'] = (name, params) => {
|
|||
}
|
||||
Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => {
|
||||
return backend['registerMLTensor'](tensor, dataType, shape);
|
||||
}
|
||||
|
||||
};
|
||||
Module['jsepCreateMLContext'] = (optionsOrGpuDevice) => {
|
||||
return backend['createMLContext'](optionsOrGpuDevice);
|
||||
};
|
||||
Module.jsepRegisterMLConstant = (externalFilePath, dataOffset, dataLength, builder, desc) => {
|
||||
return backend['registerMLConstant'](
|
||||
externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles);
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue