[WebNN EP] Fix issues with MLTensor caching (#22701)

This PR fixes a bug that occurs when searching the cache for a compatible
`MLTensor`. The comparison did not check the number of dimensions in the
shape, so a cached tensor of shape `[1]` could be matched against a request
for `[1, 1, 256, 256]`.

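This PR also adds better handling when attempting to force an `MLTensor`
to a different shape: copying the old data into a tensor of a different byte
size now throws, and uploading data whose size does not match the tensor
releases the tensor instead of writing to it.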
This commit is contained in:
Enrico Galli 2024-11-06 09:17:11 -08:00 committed by GitHub
parent 811231e418
commit 1cb5ceedf3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -54,6 +54,33 @@ export interface TensorManager {
// Counter backing createNewTensorId; holds the next id to hand out, starting at 1.
let tensorGuid = 1;
/**
 * Return the current counter value as a new tensor id, then advance the counter by one.
 */
const createNewTensorId = (): TensorId => {
  const id = tensorGuid;
  tensorGuid += 1;
  return id;
};
/**
 * Map from MLOperandDataType to element size in bits. Bits are used instead of bytes to avoid
 * possible precision loss on int4 and uint4, which occupy half a byte per element.
 */
const webnnDataTypeToSize = new Map<MLOperandDataType, number>([
  ['float32', 32],
  ['float16', 16],
  ['int32', 32],
  ['uint32', 32],
  ['int64', 64],
  ['uint64', 64],
  ['int8', 8],
  ['uint8', 8],
  ['int4', 4],
  ['uint4', 4],
]);

/**
 * Calculate the byte length of a tensor with the given data type and shape.
 *
 * The element count is the product of all dimensions. A scalar (empty shape) counts as one
 * element, and any zero-sized dimension yields a length of 0. The total number of bits is rounded
 * up to whole bytes.
 *
 * @param dataType - element data type; must be a key of webnnDataTypeToSize.
 * @param shape - tensor dimensions; may be empty for a scalar.
 * @returns the number of bytes needed to hold the tensor data.
 * @throws Error if the data type has no entry in webnnDataTypeToSize.
 */
const calculateByteLength = (dataType: MLOperandDataType, shape: readonly number[]): number => {
  const bitsPerElement = webnnDataTypeToSize.get(dataType);
  if (!bitsPerElement) {
    throw new Error('Unsupported data type.');
  }
  // Seed the product with 1: reduce() without an initial value throws on an empty array,
  // which is exactly the shape of a scalar tensor.
  const elementCount = shape.reduce((count, dim) => count * dim, 1);
  return Math.ceil((elementCount * bitsPerElement) / 8);
};
/**
* TensorWrapper wraps an MLTensor and provides a way to track the last session that used it.
*/
@ -92,6 +119,10 @@ class TensorWrapper {
return this.tensorShape;
}
/**
 * Size of this tensor in bytes, computed by calculateByteLength from the wrapper's data type
 * and shape.
 */
public get byteLength(): number {
  const { dataType, tensorShape } = this;
  return calculateByteLength(dataType, tensorShape);
}
public destroy(): void {
LOG_DEBUG('verbose', () => '[WebNN] TensorWrapper.destroy');
this.mlTensor.destroy();
@ -111,7 +142,11 @@ class TensorWrapper {
}
/**
 * Check whether this tensor has exactly the given data type and shape.
 *
 * The rank is compared before the per-dimension comparison because `every` only walks this
 * tensor's own dimensions: without the length check a stored shape of `[1]` would be reported
 * as equal to `[1, 1, 256, 256]`.
 *
 * @param dataType - data type to compare against.
 * @param shape - shape to compare against.
 * @returns true only when the data type, the number of dimensions, and every dimension match.
 */
public sameTypeAndShape(dataType: MLOperandDataType, shape: readonly number[]): boolean {
  return (
    this.dataType === dataType &&
    this.tensorShape.length === shape.length &&
    this.tensorShape.every((v, i) => v === shape[i])
  );
}
}
@ -136,6 +171,7 @@ class TensorIdTracker {
/**
 * Hand the currently held tensor wrapper, if any, back to the tensor manager and clear the
 * `wrapper` field. Does nothing when no wrapper is held.
 */
public releaseTensor(): void {
  // NOTE(review): `tensorWrapper` is presumably an accessor over `wrapper` — confirm against
  // the class definition.
  if (!this.tensorWrapper) {
    return;
  }
  this.tensorManager.releaseTensor(this.tensorWrapper);
  this.wrapper = undefined;
}
@ -149,6 +185,9 @@ class TensorIdTracker {
return this.wrapper.tensor;
} else {
if (copyOld) {
if (this.wrapper.byteLength !== calculateByteLength(dataType, shape)) {
throw new Error('Unable to copy data to tensor with different size.');
}
this.activeUpload = new Uint8Array(await this.wrapper.read());
}
this.tensorManager.releaseTensor(this.wrapper);
@ -169,8 +208,13 @@ class TensorIdTracker {
public upload(data: Uint8Array): void {
if (this.wrapper) {
this.wrapper.write(data);
return;
if (data.byteLength === this.wrapper.byteLength) {
this.wrapper.write(data);
return;
} else {
LOG_DEBUG('verbose', () => 'Data size does not match tensor size. Releasing tensor.');
this.releaseTensor();
}
}
if (this.activeUpload) {
@ -312,6 +356,7 @@ class TensorManagerImpl implements TensorManager {
const sessionId = this.backend.currentSessionId;
for (const [index, tensor] of this.freeTensors.entries()) {
if (tensor.sameTypeAndShape(dataType, shape)) {
LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`);
const wrapper = this.freeTensors.splice(index, 1)[0];
wrapper.sessionId = sessionId;
return wrapper;