mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[WebNN EP] Cache MLTensors between runs (#22278)
### Description This change enables caching `MLTensor`s between inferences runs. This is done by keeping a reference to `MLTensor`s alive after they have been released. `MLTensor`s are only destroyed once the sessions goes out of scope. ### Motivation and Context Creating and destroying `MTensor`s on every run has a non-trivial performance penalty. This performance penalty materializes when using `ort.Tensors`[location=cpu] for inputs/outputs or when using the CPU EP as a fallback EP for unsupported operators. The former could be mitigated by developer using `ort.Tensors`[location=ml-tensor]. The latter cannot be mitigated by developers.
This commit is contained in:
parent
b4cb937440
commit
1e5bda88f0
2 changed files with 170 additions and 132 deletions
|
|
@ -91,12 +91,12 @@ export class WebNNBackend {
|
|||
// Current session is not a WebNN session.
|
||||
return;
|
||||
}
|
||||
this.tensorManager.releaseTensorsForSession(sessionId);
|
||||
this.mlContextBySessionId.delete(sessionId);
|
||||
const sessionIds = this.sessionIdsByMLContext.get(mlContext)!;
|
||||
sessionIds.delete(sessionId);
|
||||
if (sessionIds.size === 0) {
|
||||
this.sessionIdsByMLContext.delete(mlContext);
|
||||
this.tensorManager.releaseTensorsForContext(mlContext);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -42,9 +42,9 @@ export interface TensorManager {
|
|||
download(tensorId: TensorId): Promise<ArrayBuffer>;
|
||||
download(tensorId: TensorId, dstTensor: ArrayBufferView | ArrayBuffer): Promise<undefined>;
|
||||
/**
|
||||
* Release all tensors for a MLContext.
|
||||
* Release all tensors for a given session.
|
||||
*/
|
||||
releaseTensorsForContext(mlContext: MLContext): void;
|
||||
releaseTensorsForSession(session: number): void;
|
||||
/**
|
||||
* Register an externally created MLTensor with a given MLContext and return a TensorId.
|
||||
*/
|
||||
|
|
@ -54,65 +54,89 @@ export interface TensorManager {
|
|||
let tensorGuid = 1;
|
||||
const createNewTensorId = (): TensorId => tensorGuid++;
|
||||
|
||||
export type MLTensorEntry = [MLTensor, MLOperandDataType, readonly number[]];
|
||||
/**
|
||||
* TensorWrapper wraps an MLTensor and provides a way to track the last session that used it.
|
||||
*/
|
||||
class TensorWrapper {
|
||||
// The id of the last session that used this tensor.
|
||||
public sessionId: number;
|
||||
|
||||
private mlContext: MLContext;
|
||||
private mlTensor: MLTensor;
|
||||
private dataType: MLOperandDataType;
|
||||
private tensorShape: readonly number[];
|
||||
|
||||
constructor(descriptor: {
|
||||
sessionId: number;
|
||||
context: MLContext;
|
||||
tensor: MLTensor;
|
||||
dataType: MLOperandDataType;
|
||||
shape: readonly number[];
|
||||
}) {
|
||||
this.sessionId = descriptor.sessionId;
|
||||
this.mlContext = descriptor.context;
|
||||
this.mlTensor = descriptor.tensor;
|
||||
this.dataType = descriptor.dataType;
|
||||
this.tensorShape = descriptor.shape;
|
||||
}
|
||||
|
||||
public get tensor(): MLTensor {
|
||||
return this.mlTensor;
|
||||
}
|
||||
|
||||
public get type(): MLOperandDataType {
|
||||
return this.dataType;
|
||||
}
|
||||
|
||||
public get shape(): readonly number[] {
|
||||
return this.tensorShape;
|
||||
}
|
||||
|
||||
public destroy(): void {
|
||||
LOG_DEBUG('verbose', () => '[WebNN] TensorWrapper.destroy');
|
||||
this.mlTensor.destroy();
|
||||
}
|
||||
|
||||
public write(data: Uint8Array): void {
|
||||
this.mlContext.writeTensor(this.mlTensor, data);
|
||||
}
|
||||
|
||||
public async read(): Promise<ArrayBuffer>;
|
||||
public async read(dstBuffer: ArrayBufferView | ArrayBuffer): Promise<undefined>;
|
||||
async read(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise<ArrayBuffer | undefined> {
|
||||
if (dstBuffer) {
|
||||
return this.mlContext.readTensor(this.mlTensor, dstBuffer);
|
||||
}
|
||||
return this.mlContext.readTensor(this.mlTensor);
|
||||
}
|
||||
|
||||
public sameTypeAndShape(dataType: MLOperandDataType, shape: readonly number[]): boolean {
|
||||
return this.dataType === dataType && this.tensorShape.every((v, i) => v === shape[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* TensorTracker tracks the MLTensor and pending upload data.
|
||||
*
|
||||
* We need to track the MLTensor and pending upload data because we delay the creation of MLTensor until
|
||||
* we know the data type and shape. This is because future implementations of WebNN will only support creating
|
||||
* MLTensors with dataTypes and shape.
|
||||
* we know the data type and shape. This is because WebNN only support creating MLTensors with dataTypes and shape.
|
||||
*/
|
||||
class TensorTracker {
|
||||
private tensorEntry?: MLTensorEntry;
|
||||
class TensorIdTracker {
|
||||
private activeUpload?: Uint8Array;
|
||||
private tensorCache: MLTensorEntry[];
|
||||
|
||||
constructor(
|
||||
private mlContext?: MLContext,
|
||||
tensorEntry?: MLTensorEntry,
|
||||
) {
|
||||
this.tensorEntry = tensorEntry;
|
||||
this.tensorCache = tensorEntry ? [tensorEntry] : [];
|
||||
private tensorManager: TensorManagerImpl,
|
||||
private wrapper?: TensorWrapper,
|
||||
) {}
|
||||
|
||||
public get tensorWrapper(): TensorWrapper | undefined {
|
||||
return this.wrapper;
|
||||
}
|
||||
|
||||
public get tensor(): MLTensor | undefined {
|
||||
return this.tensorEntry?.[0];
|
||||
}
|
||||
|
||||
public get context(): MLContext {
|
||||
if (!this.mlContext) {
|
||||
throw new Error('MLContext has not been set.');
|
||||
public releaseTensor(): void {
|
||||
if (this.tensorWrapper) {
|
||||
this.tensorManager.releaseTensor(this.tensorWrapper);
|
||||
}
|
||||
return this.mlContext;
|
||||
}
|
||||
|
||||
public set context(mlContext: MLContext) {
|
||||
if (this.mlContext && this.mlContext !== mlContext) {
|
||||
throw new Error('MLTensor in use in a different MLContext.');
|
||||
}
|
||||
this.mlContext = mlContext;
|
||||
}
|
||||
|
||||
public destroy(): void {
|
||||
for (const [mlTensor] of this.tensorCache) {
|
||||
mlTensor.destroy();
|
||||
}
|
||||
this.tensorCache = [];
|
||||
this.tensorEntry = undefined;
|
||||
}
|
||||
|
||||
public trySelectTensor(context: MLContext, tryMLTensor: MLTensor): boolean {
|
||||
for (const [mlTensor, dataType, shape] of this.tensorCache) {
|
||||
if (tryMLTensor === mlTensor) {
|
||||
if (this.context !== context) {
|
||||
throw new Error('MLTensor cannot be registered with a different MLContext.');
|
||||
}
|
||||
this.tensorEntry = [mlTensor, dataType, shape];
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public async ensureTensor(
|
||||
|
|
@ -120,55 +144,40 @@ class TensorTracker {
|
|||
shape: readonly number[],
|
||||
copyOld: boolean,
|
||||
): Promise<MLTensor> {
|
||||
if (this.tensorEntry) {
|
||||
const [mlTensor, existingDataType, existingShape] = this.tensorEntry;
|
||||
if (existingDataType === dataType && existingShape.every((v, i) => v === shape[i])) {
|
||||
return mlTensor;
|
||||
if (this.wrapper) {
|
||||
if (this.wrapper.sameTypeAndShape(dataType, shape)) {
|
||||
return this.wrapper.tensor;
|
||||
} else {
|
||||
if (copyOld) {
|
||||
this.activeUpload = new Uint8Array(await this.wrapper.read());
|
||||
}
|
||||
this.tensorManager.releaseTensor(this.wrapper);
|
||||
}
|
||||
}
|
||||
|
||||
for (const [mlTensor, existingDataType, existingShape] of this.tensorCache) {
|
||||
if (existingDataType === dataType && existingShape.every((v, i) => v === shape[i])) {
|
||||
if (copyOld && this.tensorEntry) {
|
||||
// WebNN does not support copyTensorToTensor, so we need to read and write the tensors.
|
||||
LOG_DEBUG(
|
||||
'verbose',
|
||||
() => `[WebNN] Slowdown may occur, having to copy existing tensor {dataType: ${dataType}, shape: ${shape}}`,
|
||||
);
|
||||
const data = await this.context.readTensor(this.tensorEntry[0]);
|
||||
this.context.writeTensor(mlTensor, data);
|
||||
}
|
||||
this.tensorEntry = [mlTensor, existingDataType, existingShape];
|
||||
return mlTensor;
|
||||
}
|
||||
}
|
||||
LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`);
|
||||
// eslint-disable-next-line no-bitwise
|
||||
const usage = MLTensorUsage.READ | MLTensorUsage.WRITE;
|
||||
const tensor = await this.context.createTensor({
|
||||
dataType,
|
||||
shape,
|
||||
// Assign both shape and dimensions while transitioning to new API.
|
||||
dimensions: shape,
|
||||
usage,
|
||||
});
|
||||
this.tensorEntry = [tensor, dataType, shape];
|
||||
this.tensorCache.push(this.tensorEntry);
|
||||
this.wrapper = await this.tensorManager.getCachedTensor(dataType, shape, usage);
|
||||
|
||||
if (this.activeUpload) {
|
||||
this.mlContext?.writeTensor(tensor, this.activeUpload);
|
||||
if (copyOld && this.activeUpload) {
|
||||
this.wrapper.write(this.activeUpload);
|
||||
this.activeUpload = undefined;
|
||||
}
|
||||
|
||||
return tensor;
|
||||
return this.wrapper.tensor;
|
||||
}
|
||||
|
||||
public upload(data: Uint8Array): void {
|
||||
if (!this.tensorEntry) {
|
||||
this.activeUpload = new Uint8Array(data);
|
||||
if (this.wrapper) {
|
||||
this.wrapper.write(data);
|
||||
return;
|
||||
}
|
||||
this.mlContext?.writeTensor(this.tensorEntry[0], data);
|
||||
|
||||
if (this.activeUpload) {
|
||||
this.activeUpload.set(data);
|
||||
} else {
|
||||
this.activeUpload = new Uint8Array(data);
|
||||
}
|
||||
}
|
||||
|
||||
public async download(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise<ArrayBuffer | undefined> {
|
||||
|
|
@ -179,49 +188,42 @@ class TensorTracker {
|
|||
} else {
|
||||
new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength).set(this.activeUpload);
|
||||
}
|
||||
|
||||
return;
|
||||
} else {
|
||||
return this.activeUpload.buffer;
|
||||
}
|
||||
}
|
||||
if (!this.tensorEntry) {
|
||||
if (!this.wrapper) {
|
||||
throw new Error('Tensor has not been created.');
|
||||
}
|
||||
if (dstBuffer) {
|
||||
return this.context.readTensor(this.tensorEntry[0], dstBuffer);
|
||||
if (!dstBuffer) {
|
||||
return this.wrapper.read();
|
||||
}
|
||||
return this.context.readTensor(this.tensorEntry[0]);
|
||||
return this.wrapper.read(dstBuffer);
|
||||
}
|
||||
}
|
||||
|
||||
class TensorManagerImpl implements TensorManager {
|
||||
private tensorsById = new Map<TensorId, TensorTracker>();
|
||||
private tensorIdsByContext = new Map<MLContext, Set<TensorId>>();
|
||||
private tensorTrackersById: Map<TensorId, TensorIdTracker> = new Map();
|
||||
private freeTensors: TensorWrapper[] = [];
|
||||
private externalTensors: Set<TensorWrapper> = new Set();
|
||||
|
||||
constructor(private backend: WebNNBackend) {}
|
||||
|
||||
public reserveTensorId(): TensorId {
|
||||
const tensorId = createNewTensorId();
|
||||
this.tensorsById.set(tensorId, new TensorTracker());
|
||||
this.tensorTrackersById.set(tensorId, new TensorIdTracker(this));
|
||||
return tensorId;
|
||||
}
|
||||
|
||||
public releaseTensorId(tensorId: TensorId): void {
|
||||
const tensorTracker = this.tensorsById.get(tensorId);
|
||||
const tensorTracker = this.tensorTrackersById.get(tensorId);
|
||||
if (!tensorTracker) {
|
||||
return;
|
||||
}
|
||||
tensorTracker.destroy();
|
||||
this.tensorsById.delete(tensorId);
|
||||
for (const [mlContext, tensors] of this.tensorIdsByContext) {
|
||||
if (tensors.has(tensorId)) {
|
||||
tensors.delete(tensorId);
|
||||
if (tensors.size === 0) {
|
||||
this.tensorIdsByContext.delete(mlContext);
|
||||
}
|
||||
break;
|
||||
}
|
||||
this.tensorTrackersById.delete(tensorId);
|
||||
if (tensorTracker.tensorWrapper) {
|
||||
this.releaseTensor(tensorTracker.tensorWrapper);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -238,20 +240,19 @@ class TensorManagerImpl implements TensorManager {
|
|||
dataType
|
||||
}, shape: ${shape}, copyOld: ${copyOld}}`,
|
||||
);
|
||||
const tensor = this.tensorsById.get(tensorId);
|
||||
const tensor = this.tensorTrackersById.get(tensorId);
|
||||
if (!tensor) {
|
||||
throw new Error('Tensor not found.');
|
||||
}
|
||||
tensor.context = this.backend.currentContext;
|
||||
if (!this.tensorIdsByContext.has(this.backend.currentContext)) {
|
||||
this.tensorIdsByContext.set(this.backend.currentContext, new Set());
|
||||
}
|
||||
this.tensorIdsByContext.get(this.backend.currentContext)?.add(tensorId);
|
||||
return tensor.ensureTensor(dataType, shape, copyOld);
|
||||
}
|
||||
|
||||
public upload(tensorId: TensorId, data: Uint8Array): void {
|
||||
this.tensorsById.get(tensorId)!.upload(data);
|
||||
const tensor = this.tensorTrackersById.get(tensorId);
|
||||
if (!tensor) {
|
||||
throw new Error('Tensor not found.');
|
||||
}
|
||||
tensor.upload(data);
|
||||
}
|
||||
|
||||
public async download(tensorId: TensorId): Promise<ArrayBuffer>;
|
||||
|
|
@ -261,19 +262,20 @@ class TensorManagerImpl implements TensorManager {
|
|||
'verbose',
|
||||
() => `[WebNN] TensorManager.download {tensorId: ${tensorId}, dstBuffer: ${dstBuffer?.byteLength}}`,
|
||||
);
|
||||
return this.tensorsById.get(tensorId)!.download(dstBuffer);
|
||||
const tensorTracker = this.tensorTrackersById.get(tensorId);
|
||||
if (!tensorTracker) {
|
||||
throw new Error('Tensor not found.');
|
||||
}
|
||||
return tensorTracker.download(dstBuffer);
|
||||
}
|
||||
|
||||
public releaseTensorsForContext(mlContext: MLContext): void {
|
||||
const tensors = this.tensorIdsByContext.get(mlContext);
|
||||
if (!tensors) {
|
||||
return;
|
||||
public releaseTensorsForSession(sessionId: number): void {
|
||||
for (const tensor of this.freeTensors) {
|
||||
if (tensor.sessionId === sessionId) {
|
||||
tensor.destroy();
|
||||
}
|
||||
}
|
||||
for (const tensorId of tensors) {
|
||||
this.tensorsById.get(tensorId)!.destroy();
|
||||
this.tensorsById.delete(tensorId);
|
||||
}
|
||||
this.tensorIdsByContext.delete(mlContext);
|
||||
this.freeTensors = this.freeTensors.filter((tensor) => tensor.sessionId !== sessionId);
|
||||
}
|
||||
|
||||
public registerTensor(
|
||||
|
|
@ -282,20 +284,56 @@ class TensorManagerImpl implements TensorManager {
|
|||
dataType: MLOperandDataType,
|
||||
shape: readonly number[],
|
||||
): TensorId {
|
||||
for (const [tensorId, tensorTracker] of this.tensorsById) {
|
||||
if (tensorTracker.trySelectTensor(mlContext, mlTensor)) {
|
||||
return tensorId;
|
||||
const tensorId = createNewTensorId();
|
||||
// Defaulting to READ | WRITE if usage is not provided.
|
||||
// eslint-disable-next-line no-bitwise
|
||||
const wrapper = new TensorWrapper({
|
||||
sessionId: this.backend.currentSessionId,
|
||||
context: mlContext,
|
||||
tensor: mlTensor,
|
||||
dataType,
|
||||
shape,
|
||||
});
|
||||
this.tensorTrackersById.set(tensorId, new TensorIdTracker(this, wrapper));
|
||||
this.externalTensors.add(wrapper);
|
||||
return tensorId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get or create an MLTensor with the given data type and shape.
|
||||
*/
|
||||
public async getCachedTensor(
|
||||
dataType: MLOperandDataType,
|
||||
shape: readonly number[],
|
||||
usage: MLTensorUsageFlags,
|
||||
): Promise<TensorWrapper> {
|
||||
const sessionId = this.backend.currentSessionId;
|
||||
for (const [index, tensor] of this.freeTensors.entries()) {
|
||||
if (tensor.sameTypeAndShape(dataType, shape)) {
|
||||
const wrapper = this.freeTensors.splice(index, 1)[0];
|
||||
wrapper.sessionId = sessionId;
|
||||
return wrapper;
|
||||
}
|
||||
}
|
||||
const tensorId = createNewTensorId();
|
||||
this.tensorsById.set(tensorId, new TensorTracker(mlContext, [mlTensor, dataType, shape]));
|
||||
let tensors = this.tensorIdsByContext.get(mlContext);
|
||||
if (!tensors) {
|
||||
tensors = new Set();
|
||||
this.tensorIdsByContext.set(mlContext, tensors);
|
||||
const context = this.backend.currentContext;
|
||||
LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`);
|
||||
const tensor = await context.createTensor({
|
||||
dataType,
|
||||
shape,
|
||||
dimensions: shape,
|
||||
usage,
|
||||
});
|
||||
return new TensorWrapper({ sessionId, context, tensor, dataType, shape });
|
||||
}
|
||||
|
||||
/**
|
||||
* Release tensor for reuse unless external.
|
||||
*/
|
||||
public releaseTensor(tensorWrapper: TensorWrapper) {
|
||||
if (this.externalTensors.has(tensorWrapper)) {
|
||||
this.externalTensors.delete(tensorWrapper);
|
||||
}
|
||||
tensors.add(tensorId);
|
||||
return tensorId;
|
||||
this.freeTensors.push(tensorWrapper);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue