diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 17ca147a7f..f2132d8912 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -8,25 +8,25 @@ import {TensorView} from './tensor'; import {createGpuDataManager, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; -import {ComputeContext, GpuData, GpuDataType, ProgramInfo, ProgramInfoLoader} from './webgpu/types'; +import {ComputeContext, GpuData, ProgramInfo, ProgramInfoLoader} from './webgpu/types'; /** - * get a unique key representing the program from the program info,input shapes and types. + * get a unique key representing the program from the program info, input shapes and types. * * @returns a unique key is a shorter string than the shader source, which contains all the information to identify a * program. if the key is the same, the program shader source should be the same, so we can reuse the program. * */ const getProgramInfoUniqueKey = - (programInfo: ProgramInfo|ProgramInfoLoader, inputTensorShapes: ReadonlyArray, - inputGpuDataTypes: readonly GpuDataType[]): string => { - const inputTensorShapesToString = inputTensorShapes.map(d => `${d.join(',')}`).join('_'); - const inputGpuDataTypesToString = inputGpuDataTypes.join('_'); + (programInfo: ProgramInfo|ProgramInfoLoader, inputTensors: readonly TensorView[]): string => { + // final key format: + // []:||... + const inputInfos = inputTensors.map(tensor => `${tensor.dataType};${tensor.dims.join(',')}`).join('|'); let key = programInfo.name; if (programInfo.cacheHint) { key += '[' + programInfo.cacheHint + ']'; } - key += ':' + inputTensorShapesToString + ';' + inputGpuDataTypesToString; + key += ':' + inputInfos; return key; }; @@ -221,7 +221,7 @@ export class WebGpuBackend { inputDatas[i] = gpuData; } - const key = getProgramInfoUniqueKey(program, inputs.map(i => i.dims), inputDatas.map(i => i.type)); + const key = getProgramInfoUniqueKey(program, inputs); let artifact = this.programManager.getArtifact(key); const programInfo = artifact ? artifact.programInfo :