mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[WebNN] Fixes MLTensor caching across different contexts (#23100)
We weren't checking that MLTensors were from the same context before reusing them. Found while debugging microsoft/webnn-developer-preview#69
This commit is contained in:
parent
5afab787db
commit
54edb43e77
1 changed files with 7 additions and 5 deletions
|
|
@ -141,8 +141,9 @@ class TensorWrapper {
|
|||
return this.mlContext.readTensor(this.mlTensor);
|
||||
}
|
||||
|
||||
public sameTypeAndShape(dataType: MLOperandDataType, shape: readonly number[]): boolean {
|
||||
public canReuseTensor(context: MLContext, dataType: MLOperandDataType, shape: readonly number[]): boolean {
|
||||
return (
|
||||
this.mlContext === context &&
|
||||
this.dataType === dataType &&
|
||||
this.tensorShape.length === shape.length &&
|
||||
this.tensorShape.every((v, i) => v === shape[i])
|
||||
|
|
@ -176,12 +177,13 @@ class TensorIdTracker {
|
|||
}
|
||||
|
||||
public async ensureTensor(
|
||||
context: MLContext,
|
||||
dataType: MLOperandDataType,
|
||||
shape: readonly number[],
|
||||
copyOld: boolean,
|
||||
): Promise<MLTensor> {
|
||||
if (this.wrapper) {
|
||||
if (this.wrapper.sameTypeAndShape(dataType, shape)) {
|
||||
if (this.wrapper.canReuseTensor(context, dataType, shape)) {
|
||||
return this.wrapper.tensor;
|
||||
} else {
|
||||
if (copyOld) {
|
||||
|
|
@ -288,7 +290,7 @@ class TensorManagerImpl implements TensorManager {
|
|||
if (!tensor) {
|
||||
throw new Error('Tensor not found.');
|
||||
}
|
||||
return tensor.ensureTensor(dataType, shape, copyOld);
|
||||
return tensor.ensureTensor(this.backend.currentContext, dataType, shape, copyOld);
|
||||
}
|
||||
|
||||
public upload(tensorId: TensorId, data: Uint8Array): void {
|
||||
|
|
@ -354,15 +356,15 @@ class TensorManagerImpl implements TensorManager {
|
|||
readable: boolean,
|
||||
): Promise<TensorWrapper> {
|
||||
const sessionId = this.backend.currentSessionId;
|
||||
const context = this.backend.currentContext;
|
||||
for (const [index, tensor] of this.freeTensors.entries()) {
|
||||
if (tensor.sameTypeAndShape(dataType, shape)) {
|
||||
if (tensor.canReuseTensor(context, dataType, shape)) {
|
||||
LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`);
|
||||
const wrapper = this.freeTensors.splice(index, 1)[0];
|
||||
wrapper.sessionId = sessionId;
|
||||
return wrapper;
|
||||
}
|
||||
}
|
||||
const context = this.backend.currentContext;
|
||||
LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`);
|
||||
const tensor = await context.createTensor({
|
||||
dataType,
|
||||
|
|
|
|||
Loading…
Reference in a new issue