[WebNN EP] Cache MLTensors between runs (#22278)

### Description
This change enables caching `MLTensor`s between inferences runs. This is
done by keeping a reference to `MLTensor`s alive after they have been
released. `MLTensor`s are only destroyed once the sessions goes out of
scope.

### Motivation and Context
Creating and destroying `MTensor`s on every run has a non-trivial
performance penalty. This performance penalty materializes when using
`ort.Tensors`[location=cpu] for inputs/outputs or when using the CPU EP
as a fallback EP for unsupported operators. The former could be
mitigated by developer using `ort.Tensors`[location=ml-tensor]. The
latter cannot be mitigated by developers.
This commit is contained in:
Enrico Galli 2024-10-18 08:07:00 -07:00 committed by GitHub
parent b4cb937440
commit 1e5bda88f0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 170 additions and 132 deletions

View file

@ -91,12 +91,12 @@ export class WebNNBackend {
// Current session is not a WebNN session.
return;
}
this.tensorManager.releaseTensorsForSession(sessionId);
this.mlContextBySessionId.delete(sessionId);
const sessionIds = this.sessionIdsByMLContext.get(mlContext)!;
sessionIds.delete(sessionId);
if (sessionIds.size === 0) {
this.sessionIdsByMLContext.delete(mlContext);
this.tensorManager.releaseTensorsForContext(mlContext);
}
}

View file

@ -42,9 +42,9 @@ export interface TensorManager {
download(tensorId: TensorId): Promise<ArrayBuffer>;
download(tensorId: TensorId, dstTensor: ArrayBufferView | ArrayBuffer): Promise<undefined>;
/**
* Release all tensors for a MLContext.
* Release all tensors for a given session.
*/
releaseTensorsForContext(mlContext: MLContext): void;
releaseTensorsForSession(session: number): void;
/**
* Register an externally created MLTensor with a given MLContext and return a TensorId.
*/
@ -54,65 +54,89 @@ export interface TensorManager {
let tensorGuid = 1;
const createNewTensorId = (): TensorId => tensorGuid++;
export type MLTensorEntry = [MLTensor, MLOperandDataType, readonly number[]];
/**
* TensorWrapper wraps an MLTensor and provides a way to track the last session that used it.
*/
class TensorWrapper {
// The id of the last session that used this tensor.
public sessionId: number;
private mlContext: MLContext;
private mlTensor: MLTensor;
private dataType: MLOperandDataType;
private tensorShape: readonly number[];
constructor(descriptor: {
sessionId: number;
context: MLContext;
tensor: MLTensor;
dataType: MLOperandDataType;
shape: readonly number[];
}) {
this.sessionId = descriptor.sessionId;
this.mlContext = descriptor.context;
this.mlTensor = descriptor.tensor;
this.dataType = descriptor.dataType;
this.tensorShape = descriptor.shape;
}
public get tensor(): MLTensor {
return this.mlTensor;
}
public get type(): MLOperandDataType {
return this.dataType;
}
public get shape(): readonly number[] {
return this.tensorShape;
}
public destroy(): void {
LOG_DEBUG('verbose', () => '[WebNN] TensorWrapper.destroy');
this.mlTensor.destroy();
}
public write(data: Uint8Array): void {
this.mlContext.writeTensor(this.mlTensor, data);
}
public async read(): Promise<ArrayBuffer>;
public async read(dstBuffer: ArrayBufferView | ArrayBuffer): Promise<undefined>;
async read(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise<ArrayBuffer | undefined> {
if (dstBuffer) {
return this.mlContext.readTensor(this.mlTensor, dstBuffer);
}
return this.mlContext.readTensor(this.mlTensor);
}
public sameTypeAndShape(dataType: MLOperandDataType, shape: readonly number[]): boolean {
return this.dataType === dataType && this.tensorShape.every((v, i) => v === shape[i]);
}
}
/**
* TensorTracker tracks the MLTensor and pending upload data.
*
* We need to track the MLTensor and pending upload data because we delay the creation of MLTensor until
* we know the data type and shape. This is because future implementations of WebNN will only support creating
* MLTensors with dataTypes and shape.
* we know the data type and shape. This is because WebNN only support creating MLTensors with dataTypes and shape.
*/
class TensorTracker {
private tensorEntry?: MLTensorEntry;
class TensorIdTracker {
private activeUpload?: Uint8Array;
private tensorCache: MLTensorEntry[];
constructor(
private mlContext?: MLContext,
tensorEntry?: MLTensorEntry,
) {
this.tensorEntry = tensorEntry;
this.tensorCache = tensorEntry ? [tensorEntry] : [];
private tensorManager: TensorManagerImpl,
private wrapper?: TensorWrapper,
) {}
public get tensorWrapper(): TensorWrapper | undefined {
return this.wrapper;
}
public get tensor(): MLTensor | undefined {
return this.tensorEntry?.[0];
}
public get context(): MLContext {
if (!this.mlContext) {
throw new Error('MLContext has not been set.');
public releaseTensor(): void {
if (this.tensorWrapper) {
this.tensorManager.releaseTensor(this.tensorWrapper);
}
return this.mlContext;
}
public set context(mlContext: MLContext) {
if (this.mlContext && this.mlContext !== mlContext) {
throw new Error('MLTensor in use in a different MLContext.');
}
this.mlContext = mlContext;
}
public destroy(): void {
for (const [mlTensor] of this.tensorCache) {
mlTensor.destroy();
}
this.tensorCache = [];
this.tensorEntry = undefined;
}
public trySelectTensor(context: MLContext, tryMLTensor: MLTensor): boolean {
for (const [mlTensor, dataType, shape] of this.tensorCache) {
if (tryMLTensor === mlTensor) {
if (this.context !== context) {
throw new Error('MLTensor cannot be registered with a different MLContext.');
}
this.tensorEntry = [mlTensor, dataType, shape];
return true;
}
}
return false;
}
public async ensureTensor(
@ -120,55 +144,40 @@ class TensorTracker {
shape: readonly number[],
copyOld: boolean,
): Promise<MLTensor> {
if (this.tensorEntry) {
const [mlTensor, existingDataType, existingShape] = this.tensorEntry;
if (existingDataType === dataType && existingShape.every((v, i) => v === shape[i])) {
return mlTensor;
if (this.wrapper) {
if (this.wrapper.sameTypeAndShape(dataType, shape)) {
return this.wrapper.tensor;
} else {
if (copyOld) {
this.activeUpload = new Uint8Array(await this.wrapper.read());
}
this.tensorManager.releaseTensor(this.wrapper);
}
}
for (const [mlTensor, existingDataType, existingShape] of this.tensorCache) {
if (existingDataType === dataType && existingShape.every((v, i) => v === shape[i])) {
if (copyOld && this.tensorEntry) {
// WebNN does not support copyTensorToTensor, so we need to read and write the tensors.
LOG_DEBUG(
'verbose',
() => `[WebNN] Slowdown may occur, having to copy existing tensor {dataType: ${dataType}, shape: ${shape}}`,
);
const data = await this.context.readTensor(this.tensorEntry[0]);
this.context.writeTensor(mlTensor, data);
}
this.tensorEntry = [mlTensor, existingDataType, existingShape];
return mlTensor;
}
}
LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`);
// eslint-disable-next-line no-bitwise
const usage = MLTensorUsage.READ | MLTensorUsage.WRITE;
const tensor = await this.context.createTensor({
dataType,
shape,
// Assign both shape and dimensions while transitioning to new API.
dimensions: shape,
usage,
});
this.tensorEntry = [tensor, dataType, shape];
this.tensorCache.push(this.tensorEntry);
this.wrapper = await this.tensorManager.getCachedTensor(dataType, shape, usage);
if (this.activeUpload) {
this.mlContext?.writeTensor(tensor, this.activeUpload);
if (copyOld && this.activeUpload) {
this.wrapper.write(this.activeUpload);
this.activeUpload = undefined;
}
return tensor;
return this.wrapper.tensor;
}
public upload(data: Uint8Array): void {
if (!this.tensorEntry) {
this.activeUpload = new Uint8Array(data);
if (this.wrapper) {
this.wrapper.write(data);
return;
}
this.mlContext?.writeTensor(this.tensorEntry[0], data);
if (this.activeUpload) {
this.activeUpload.set(data);
} else {
this.activeUpload = new Uint8Array(data);
}
}
public async download(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise<ArrayBuffer | undefined> {
@ -179,49 +188,42 @@ class TensorTracker {
} else {
new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength).set(this.activeUpload);
}
return;
} else {
return this.activeUpload.buffer;
}
}
if (!this.tensorEntry) {
if (!this.wrapper) {
throw new Error('Tensor has not been created.');
}
if (dstBuffer) {
return this.context.readTensor(this.tensorEntry[0], dstBuffer);
if (!dstBuffer) {
return this.wrapper.read();
}
return this.context.readTensor(this.tensorEntry[0]);
return this.wrapper.read(dstBuffer);
}
}
class TensorManagerImpl implements TensorManager {
private tensorsById = new Map<TensorId, TensorTracker>();
private tensorIdsByContext = new Map<MLContext, Set<TensorId>>();
private tensorTrackersById: Map<TensorId, TensorIdTracker> = new Map();
private freeTensors: TensorWrapper[] = [];
private externalTensors: Set<TensorWrapper> = new Set();
constructor(private backend: WebNNBackend) {}
public reserveTensorId(): TensorId {
const tensorId = createNewTensorId();
this.tensorsById.set(tensorId, new TensorTracker());
this.tensorTrackersById.set(tensorId, new TensorIdTracker(this));
return tensorId;
}
public releaseTensorId(tensorId: TensorId): void {
const tensorTracker = this.tensorsById.get(tensorId);
const tensorTracker = this.tensorTrackersById.get(tensorId);
if (!tensorTracker) {
return;
}
tensorTracker.destroy();
this.tensorsById.delete(tensorId);
for (const [mlContext, tensors] of this.tensorIdsByContext) {
if (tensors.has(tensorId)) {
tensors.delete(tensorId);
if (tensors.size === 0) {
this.tensorIdsByContext.delete(mlContext);
}
break;
}
this.tensorTrackersById.delete(tensorId);
if (tensorTracker.tensorWrapper) {
this.releaseTensor(tensorTracker.tensorWrapper);
}
}
@ -238,20 +240,19 @@ class TensorManagerImpl implements TensorManager {
dataType
}, shape: ${shape}, copyOld: ${copyOld}}`,
);
const tensor = this.tensorsById.get(tensorId);
const tensor = this.tensorTrackersById.get(tensorId);
if (!tensor) {
throw new Error('Tensor not found.');
}
tensor.context = this.backend.currentContext;
if (!this.tensorIdsByContext.has(this.backend.currentContext)) {
this.tensorIdsByContext.set(this.backend.currentContext, new Set());
}
this.tensorIdsByContext.get(this.backend.currentContext)?.add(tensorId);
return tensor.ensureTensor(dataType, shape, copyOld);
}
public upload(tensorId: TensorId, data: Uint8Array): void {
this.tensorsById.get(tensorId)!.upload(data);
const tensor = this.tensorTrackersById.get(tensorId);
if (!tensor) {
throw new Error('Tensor not found.');
}
tensor.upload(data);
}
public async download(tensorId: TensorId): Promise<ArrayBuffer>;
@ -261,19 +262,20 @@ class TensorManagerImpl implements TensorManager {
'verbose',
() => `[WebNN] TensorManager.download {tensorId: ${tensorId}, dstBuffer: ${dstBuffer?.byteLength}}`,
);
return this.tensorsById.get(tensorId)!.download(dstBuffer);
const tensorTracker = this.tensorTrackersById.get(tensorId);
if (!tensorTracker) {
throw new Error('Tensor not found.');
}
return tensorTracker.download(dstBuffer);
}
public releaseTensorsForContext(mlContext: MLContext): void {
const tensors = this.tensorIdsByContext.get(mlContext);
if (!tensors) {
return;
public releaseTensorsForSession(sessionId: number): void {
for (const tensor of this.freeTensors) {
if (tensor.sessionId === sessionId) {
tensor.destroy();
}
}
for (const tensorId of tensors) {
this.tensorsById.get(tensorId)!.destroy();
this.tensorsById.delete(tensorId);
}
this.tensorIdsByContext.delete(mlContext);
this.freeTensors = this.freeTensors.filter((tensor) => tensor.sessionId !== sessionId);
}
public registerTensor(
@ -282,20 +284,56 @@ class TensorManagerImpl implements TensorManager {
dataType: MLOperandDataType,
shape: readonly number[],
): TensorId {
for (const [tensorId, tensorTracker] of this.tensorsById) {
if (tensorTracker.trySelectTensor(mlContext, mlTensor)) {
return tensorId;
const tensorId = createNewTensorId();
// Defaulting to READ | WRITE if usage is not provided.
// eslint-disable-next-line no-bitwise
const wrapper = new TensorWrapper({
sessionId: this.backend.currentSessionId,
context: mlContext,
tensor: mlTensor,
dataType,
shape,
});
this.tensorTrackersById.set(tensorId, new TensorIdTracker(this, wrapper));
this.externalTensors.add(wrapper);
return tensorId;
}
/**
* Get or create an MLTensor with the given data type and shape.
*/
public async getCachedTensor(
dataType: MLOperandDataType,
shape: readonly number[],
usage: MLTensorUsageFlags,
): Promise<TensorWrapper> {
const sessionId = this.backend.currentSessionId;
for (const [index, tensor] of this.freeTensors.entries()) {
if (tensor.sameTypeAndShape(dataType, shape)) {
const wrapper = this.freeTensors.splice(index, 1)[0];
wrapper.sessionId = sessionId;
return wrapper;
}
}
const tensorId = createNewTensorId();
this.tensorsById.set(tensorId, new TensorTracker(mlContext, [mlTensor, dataType, shape]));
let tensors = this.tensorIdsByContext.get(mlContext);
if (!tensors) {
tensors = new Set();
this.tensorIdsByContext.set(mlContext, tensors);
const context = this.backend.currentContext;
LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`);
const tensor = await context.createTensor({
dataType,
shape,
dimensions: shape,
usage,
});
return new TensorWrapper({ sessionId, context, tensor, dataType, shape });
}
/**
* Release tensor for reuse unless external.
*/
public releaseTensor(tensorWrapper: TensorWrapper) {
if (this.externalTensors.has(tensorWrapper)) {
this.externalTensors.delete(tensorWrapper);
}
tensors.add(tensorId);
return tensorId;
this.freeTensors.push(tensorWrapper);
}
}