mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[WebNN EP] Fix issues with MLTensor caching (#22701)
This PR fixes a bug that occurs when searching for a compatible `MLTensor` in the cache. We were not checking the number of dimensions in the shape, which meant that a cached buffer of shape `[1]` could match a request for `[1, 1, 256, 256]`. This PR also adds better handling when attempting to force an `MLTensor` to a different shape.
This commit is contained in:
parent
811231e418
commit
1cb5ceedf3
1 changed file with 48 additions and 3 deletions
|
|
@ -54,6 +54,33 @@ export interface TensorManager {
|
|||
let tensorGuid = 1;
|
||||
const createNewTensorId = (): TensorId => tensorGuid++;
|
||||
|
||||
/**
|
||||
* Map from MLOperandDataType to size in bits. Using bits instead of bytes to avoid possible precision loss on int4 and uint4.
|
||||
*/
|
||||
const webnnDataTypeToSize = new Map<MLOperandDataType, number>([
|
||||
['float32', 32],
|
||||
['float16', 16],
|
||||
['int32', 32],
|
||||
['uint32', 32],
|
||||
['int64', 64],
|
||||
['uint64', 64],
|
||||
['int8', 8],
|
||||
['uint8', 8],
|
||||
['int4', 4],
|
||||
['uint4', 4],
|
||||
]);
|
||||
|
||||
/**
|
||||
* Calculate the byte length of a tensor with the given data type and shape.
|
||||
*/
|
||||
const calculateByteLength = (dataType: MLOperandDataType, shape: readonly number[]): number => {
|
||||
const size = webnnDataTypeToSize.get(dataType);
|
||||
if (!size) {
|
||||
throw new Error('Unsupported data type.');
|
||||
}
|
||||
return Math.ceil((shape.reduce((a, b) => a * b) * size) / 8);
|
||||
};
|
||||
|
||||
/**
|
||||
* TensorWrapper wraps an MLTensor and provides a way to track the last session that used it.
|
||||
*/
|
||||
|
|
@ -92,6 +119,10 @@ class TensorWrapper {
|
|||
return this.tensorShape;
|
||||
}
|
||||
|
||||
/**
 * Size in bytes of this tensor's data, derived from its data type and shape.
 */
public get byteLength(): number {
  const bytes = calculateByteLength(this.dataType, this.tensorShape);
  return bytes;
}
|
||||
|
||||
public destroy(): void {
|
||||
LOG_DEBUG('verbose', () => '[WebNN] TensorWrapper.destroy');
|
||||
this.mlTensor.destroy();
|
||||
|
|
@ -111,7 +142,11 @@ class TensorWrapper {
|
|||
}
|
||||
|
||||
/**
 * Whether this wrapper's MLTensor can be reused for the given type and shape.
 *
 * A cached tensor is only interchangeable when both the element type and the
 * exact shape — including the number of dimensions — match. Comparing
 * element-wise alone would let a cached shape `[1]` match `[1, 1, 256, 256]`,
 * so the rank check must come first.
 */
public sameTypeAndShape(dataType: MLOperandDataType, shape: readonly number[]): boolean {
  return (
    this.dataType === dataType &&
    this.tensorShape.length === shape.length &&
    this.tensorShape.every((dimension, axis) => dimension === shape[axis])
  );
}
|
||||
}
|
||||
|
||||
|
|
@ -136,6 +171,7 @@ class TensorIdTracker {
|
|||
public releaseTensor(): void {
|
||||
if (this.tensorWrapper) {
|
||||
this.tensorManager.releaseTensor(this.tensorWrapper);
|
||||
this.wrapper = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -149,6 +185,9 @@ class TensorIdTracker {
|
|||
return this.wrapper.tensor;
|
||||
} else {
|
||||
if (copyOld) {
|
||||
if (this.wrapper.byteLength !== calculateByteLength(dataType, shape)) {
|
||||
throw new Error('Unable to copy data to tensor with different size.');
|
||||
}
|
||||
this.activeUpload = new Uint8Array(await this.wrapper.read());
|
||||
}
|
||||
this.tensorManager.releaseTensor(this.wrapper);
|
||||
|
|
@ -169,8 +208,13 @@ class TensorIdTracker {
|
|||
|
||||
public upload(data: Uint8Array): void {
|
||||
if (this.wrapper) {
|
||||
this.wrapper.write(data);
|
||||
return;
|
||||
if (data.byteLength === this.wrapper.byteLength) {
|
||||
this.wrapper.write(data);
|
||||
return;
|
||||
} else {
|
||||
LOG_DEBUG('verbose', () => 'Data size does not match tensor size. Releasing tensor.');
|
||||
this.releaseTensor();
|
||||
}
|
||||
}
|
||||
|
||||
if (this.activeUpload) {
|
||||
|
|
@ -312,6 +356,7 @@ class TensorManagerImpl implements TensorManager {
|
|||
const sessionId = this.backend.currentSessionId;
|
||||
for (const [index, tensor] of this.freeTensors.entries()) {
|
||||
if (tensor.sameTypeAndShape(dataType, shape)) {
|
||||
LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`);
|
||||
const wrapper = this.freeTensors.splice(index, 1)[0];
|
||||
wrapper.sessionId = sessionId;
|
||||
return wrapper;
|
||||
|
|
|
|||
Loading…
Reference in a new issue