diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 3f8131be1c..1dc3a206cf 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -87,7 +87,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split previousSum += attributes.splitSizes[i]; sizeInSplitAxis[i] = previousSum; const outputShape = inputShape.slice(); - outputShape[attributes.axis] = attributes.splitSizes[i]; + outputShape[axis] = attributes.splitSizes[i]; outputShapes.push(outputShape); outputs[i] = outputVariable(`output${i}`, dataType, outputShape.length); outputsTensorInfo.push({ dims: outputShapes[i], dataType: inputs[0].dataType }); diff --git a/js/web/test/data/ops/split.jsonc b/js/web/test/data/ops/split.jsonc index 46fc323cc6..f837e4d26d 100644 --- a/js/web/test/data/ops/split.jsonc +++ b/js/web/test/data/ops/split.jsonc @@ -64,5 +64,38 @@ ] } ] + }, + { + "name": "Split on Axis -1 - 2D", + "operator": "Split", + "opset": { "domain": "", "version": 12 }, + "attributes": [ + { "name": "axis", "data": -1, "type": "int" }, + { "name": "split", "data": [2, 4], "type": "ints" } + ], + "cases": [ + { + "name": "T[6]", + "inputs": [ + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [2, 6], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.0, 2.0, 7.0, 8.0], + "dims": [2, 2], + "type": "float32" + }, + { + "data": [3.0, 4.0, 5.0, 6.0, 9.0, 10.0, 11.0, 12.0], + "dims": [2, 4], + "type": "float32" + } + ] + } + ] } ]