[js/webgpu] Handle negative axis in op Split (#21771)

This is to fix issue #21703, where the axis is a negative value in the
model. According to the spec
(https://onnx.ai/onnx/operators/onnx__Split.html), negative axis means
counting dimensions from the back.
This commit is contained in:
Yang Gu 2024-08-18 07:41:23 +08:00 committed by GitHub
parent d79e3c5791
commit 49fc168eed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 34 additions and 1 deletions

View file

@ -87,7 +87,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split
previousSum += attributes.splitSizes[i];
sizeInSplitAxis[i] = previousSum;
const outputShape = inputShape.slice();
outputShape[attributes.axis] = attributes.splitSizes[i];
outputShape[axis] = attributes.splitSizes[i];
outputShapes.push(outputShape);
outputs[i] = outputVariable(`output${i}`, dataType, outputShape.length);
outputsTensorInfo.push({ dims: outputShapes[i], dataType: inputs[0].dataType });

View file

@ -64,5 +64,38 @@
]
}
]
},
{
"name": "Split on Axis -1 - 2D",
"operator": "Split",
"opset": { "domain": "", "version": 12 },
"attributes": [
{ "name": "axis", "data": -1, "type": "int" },
{ "name": "split", "data": [2, 4], "type": "ints" }
],
"cases": [
{
"name": "T[6]",
"inputs": [
{
"data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
"dims": [2, 6],
"type": "float32"
}
],
"outputs": [
{
"data": [1.0, 2.0, 7.0, 8.0],
"dims": [2, 2],
"type": "float32"
},
{
"data": [3.0, 4.0, 5.0, 6.0, 9.0, 10.0, 11.0, 12.0],
"dims": [2, 4],
"type": "float32"
}
]
}
]
}
]