From b4be9e1bbb20e1e03528f73df71e9f141ae04fcf Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 12 Dec 2023 10:11:38 +0800 Subject: [PATCH] [js/webgpu] Fix shader compilation errors in cumsum (#18779) ### Description This PR fixes below shader compilation errors: ``` Tint WGSL reader failure: :39:31 error: no matching overload for operator + (f32, i32) 5 candidate operators: operator + (T, T) -> T where: T is abstract-float, abstract-int, f32, i32, u32 or f16 operator + (vecN, T) -> vecN where: T is abstract-float, abstract-int, f32, i32, u32 or f16 operator + (T, vecN) -> vecN where: T is abstract-float, abstract-int, f32, i32, u32 or f16 operator + (vecN, vecN) -> vecN where: T is abstract-float, abstract-int, f32, i32, u32 or f16 operator + (matNxM, matNxM) -> matNxM where: T is abstract-float, f32 or f16 sum = sum + get_inputByIndices(inputIndices); ^ - While validating [ShaderModuleDescriptor "CumSum"] - While calling [Device].CreateShaderModule([ShaderModuleDescriptor "CumSum"]). --- js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts | 2 +- js/web/test/data/ops/cumsum.jsonc | 36 +++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts index e7208ce34d..85682f0b47 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts @@ -37,7 +37,7 @@ const createCumsumProgramInfo = ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} var inputIndices = ${output.offsetToIndices('global_idx')}; - var sum = 0.0; + var sum = ${output.type.value}(0); let first : i32 = ${lowerLimit}; let last : i32 = ${upperLimit}; for (var i : i32 = first; i < last; i++) { diff --git a/js/web/test/data/ops/cumsum.jsonc b/js/web/test/data/ops/cumsum.jsonc index cac9be734b..b3173afb69 100644 --- a/js/web/test/data/ops/cumsum.jsonc +++ b/js/web/test/data/ops/cumsum.jsonc @@ -1322,5 +1322,41 @@ ] } ] + }, + { + "name": "CumSum", + "operator": "CumSum", + "attributes": [ + { "name": "exclusive", "data": 0, "type": "int" }, + { "name": "reverse", "data": 0, "type": "int" } + ], + "opset": { + "domain": "", + "version": 11 + }, + "cases": [ + { + "name": "CumSum int32; axis = 0; exclusive = 0, reverse = 0", + "inputs": [ + { + "data": [1, 2, 3, 4, 5], + "dims": [1, 1, 1, 1, 5], + "type": "int32" + }, + { + "data": [4], + "dims": [], + "type": "int32" + } + ], + "outputs": [ + { + "data": [1, 3, 6, 10, 15], + "dims": [1, 1, 1, 1, 5], + "type": "int32" + } + ] + } + ] } ]