[WebNN] Fixes MLTensor caching across different contexts (#23100)

We weren't checking that MLTensors were from the same context before
reusing them.

Found while debugging microsoft/webnn-developer-preview#69
This commit is contained in:
Enrico Galli 2024-12-17 12:51:16 -08:00 committed by GitHub
parent 5afab787db
commit 54edb43e77
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -141,8 +141,9 @@ class TensorWrapper {
return this.mlContext.readTensor(this.mlTensor);
}
public sameTypeAndShape(dataType: MLOperandDataType, shape: readonly number[]): boolean {
public canReuseTensor(context: MLContext, dataType: MLOperandDataType, shape: readonly number[]): boolean {
return (
this.mlContext === context &&
this.dataType === dataType &&
this.tensorShape.length === shape.length &&
this.tensorShape.every((v, i) => v === shape[i])
@ -176,12 +177,13 @@ class TensorIdTracker {
}
public async ensureTensor(
context: MLContext,
dataType: MLOperandDataType,
shape: readonly number[],
copyOld: boolean,
): Promise<MLTensor> {
if (this.wrapper) {
if (this.wrapper.sameTypeAndShape(dataType, shape)) {
if (this.wrapper.canReuseTensor(context, dataType, shape)) {
return this.wrapper.tensor;
} else {
if (copyOld) {
@ -288,7 +290,7 @@ class TensorManagerImpl implements TensorManager {
if (!tensor) {
throw new Error('Tensor not found.');
}
return tensor.ensureTensor(dataType, shape, copyOld);
return tensor.ensureTensor(this.backend.currentContext, dataType, shape, copyOld);
}
public upload(tensorId: TensorId, data: Uint8Array): void {
@ -354,15 +356,15 @@ class TensorManagerImpl implements TensorManager {
readable: boolean,
): Promise<TensorWrapper> {
const sessionId = this.backend.currentSessionId;
const context = this.backend.currentContext;
for (const [index, tensor] of this.freeTensors.entries()) {
if (tensor.sameTypeAndShape(dataType, shape)) {
if (tensor.canReuseTensor(context, dataType, shape)) {
LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`);
const wrapper = this.freeTensors.splice(index, 1)[0];
wrapper.sessionId = sessionId;
return wrapper;
}
}
const context = this.backend.currentContext;
LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`);
const tensor = await context.createTensor({
dataType,