mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-08 00:23:03 +00:00
[JS/WebGPU] Preserve zero size input tensor dims. (#19737)
### Description For Concat operation, the zero-size input tensor shape need to be preserved and, unlike non-zero tensors, the dims are not constrained to match other input tensors' dims. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
6c3bed6740
commit
24b72d2613
2 changed files with 149 additions and 77 deletions
|
|
@ -13,25 +13,32 @@ export interface ConcatAttributes extends AttributeWithCacheKey {
|
|||
readonly axis: number;
|
||||
}
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[]): void => {
|
||||
const validateInputs = (inputs: readonly TensorView[], axis: number): void => {
|
||||
if (!inputs || inputs.length < 1) {
|
||||
throw new Error('too few inputs');
|
||||
}
|
||||
|
||||
const inputType = inputs[0].dataType;
|
||||
const inputDimensionality = inputs[0].dims.length;
|
||||
|
||||
for (const input of inputs) {
|
||||
const referenceIndex = 0;
|
||||
const referenceInput = inputs[referenceIndex];
|
||||
const inputType = referenceInput.dataType;
|
||||
const inputRank = referenceInput.dims.length;
|
||||
inputs.forEach((input, i) => {
|
||||
if (i === referenceIndex) {
|
||||
return;
|
||||
}
|
||||
// make sure types of all inputs match
|
||||
if (input.dataType !== inputType) {
|
||||
throw new Error('input tensors should be one type');
|
||||
}
|
||||
|
||||
// make sure the dimensionality of all inputs are the same
|
||||
if (input.dims.length !== inputDimensionality) {
|
||||
if (input.dims.length !== inputRank) {
|
||||
throw new Error('input tensors should have the same shape');
|
||||
}
|
||||
}
|
||||
input.dims.forEach((dim, i) => {
|
||||
if (i !== axis && dim !== referenceInput.dims[i]) {
|
||||
throw new Error('non concat dimensions must match');
|
||||
}
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
const calculateInputIndexImpl = (numberOfTensors: number, sizeInConcatAxisStr: string): string => `
|
||||
|
|
@ -64,65 +71,43 @@ const assignOutputData = (inputs: readonly IndicesHelper[], output: IndicesHelpe
|
|||
return codeLines.join('\n');
|
||||
};
|
||||
|
||||
const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): ProgramInfo => {
|
||||
const inputShape = inputs[0].dims.slice();
|
||||
if (axis >= inputShape.length || axis < (-1 * inputShape.length)) {
|
||||
throw new Error('axis specified for concat doesn\'t match input dimensionality');
|
||||
}
|
||||
const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis;
|
||||
// ensure all of the non-concatenated axes match each other
|
||||
// calculate the shape of the output tensor while we do that
|
||||
const outputShape = inputShape.slice(0);
|
||||
for (let i = 1; i < inputs.length; i++) {
|
||||
const dataNShape = inputs[i].dims.slice();
|
||||
for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) {
|
||||
// add to the placeholder for computing output shape
|
||||
if (axisIndex === adjustedAxis) {
|
||||
outputShape[adjustedAxis] += dataNShape[axisIndex];
|
||||
const createConcatProgramInfo =
|
||||
(inputs: readonly TensorView[], adjustedAxis: number, outputShape: number[], dataType: DataType): ProgramInfo => {
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
|
||||
const sizeInConcatAxis = new Array<number>(inputs.length);
|
||||
const inputVars = new Array<IndicesHelper>(inputs.length);
|
||||
|
||||
let previousSum = 0;
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
|
||||
const inputRanks = [];
|
||||
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
|
||||
for (let i = 0; i < inputs.length; ++i) {
|
||||
previousSum += inputs[i].dims[adjustedAxis];
|
||||
sizeInConcatAxis[i] = previousSum;
|
||||
inputRanks.push(inputs[i].dims.length);
|
||||
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
|
||||
inputDependencies.push('rank');
|
||||
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
|
||||
}
|
||||
// ensure all non-cancatenated axes match each other
|
||||
else if (inputShape[axisIndex] !== dataNShape[axisIndex]) {
|
||||
throw new Error('non concat dimensions must match');
|
||||
for (let i = 0; i < inputs.length; ++i) {
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
|
||||
}
|
||||
}
|
||||
}
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
|
||||
const sizeInConcatAxis = new Array<number>(inputs.length);
|
||||
const inputVars = new Array<IndicesHelper>(inputs.length);
|
||||
const dataType = inputs[0].dataType;
|
||||
|
||||
let previousSum = 0;
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
|
||||
const inputRanks = [];
|
||||
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
|
||||
for (let i = 0; i < inputs.length; ++i) {
|
||||
previousSum += inputs[i].dims[adjustedAxis];
|
||||
sizeInConcatAxis[i] = previousSum;
|
||||
inputRanks.push(inputs[i].dims.length);
|
||||
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
|
||||
inputDependencies.push('rank');
|
||||
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
|
||||
}
|
||||
for (let i = 0; i < inputs.length; ++i) {
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
|
||||
}
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
|
||||
const output = outputVariable('output', dataType, outputShape.length);
|
||||
const indicesAxis = output.indicesGet('indices', adjustedAxis);
|
||||
const sizeInConcatAxisStr =
|
||||
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const output = outputVariable('output', dataType, outputShape.length);
|
||||
const indicesAxis = output.indicesGet('indices', adjustedAxis);
|
||||
const sizeInConcatAxisStr =
|
||||
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
|
||||
${(() => {
|
||||
shaderHelper.registerUniform('outputSize', 'u32');
|
||||
for (let i = 0; i < inputs.length; i++) {
|
||||
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
|
||||
}
|
||||
return shaderHelper.declareVariables(...inputVars, output);
|
||||
})()}
|
||||
shaderHelper.registerUniform('outputSize', 'u32');
|
||||
for (let i = 0; i < inputs.length; i++) {
|
||||
shaderHelper.registerUniform(`sizeInConcatAxis${i}`, 'u32');
|
||||
}
|
||||
return shaderHelper.declareVariables(...inputVars, output);
|
||||
})()}
|
||||
|
||||
${calculateInputIndexImpl(sizeInConcatAxis.length, sizeInConcatAxisStr)}
|
||||
|
||||
|
|
@ -140,23 +125,30 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
|
|||
${assignOutputData(inputVars, output)}
|
||||
}`;
|
||||
|
||||
return {
|
||||
name: 'Concat',
|
||||
shaderCache: {hint: `${axis}`, inputDependencies},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
programUniforms,
|
||||
}),
|
||||
getShaderSource,
|
||||
};
|
||||
};
|
||||
return {
|
||||
name: 'Concat',
|
||||
shaderCache: {hint: `${adjustedAxis}`, inputDependencies},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
programUniforms,
|
||||
}),
|
||||
getShaderSource,
|
||||
};
|
||||
};
|
||||
|
||||
export const concat = (context: ComputeContext, attributes: ConcatAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
const inputs = context.inputs;
|
||||
const inputShape = inputs[0].dims;
|
||||
const adjustedAxis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
|
||||
validateInputs(inputs, adjustedAxis);
|
||||
const outputShape = inputShape.slice();
|
||||
outputShape[adjustedAxis] =
|
||||
inputs.reduce((sum, input) => sum + (input.dims.length > adjustedAxis ? input.dims[adjustedAxis] : 0), 0);
|
||||
// 0 length tensors are valid for concat, remove them
|
||||
const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0);
|
||||
context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis), {inputs: nonEmptyInputs});
|
||||
const nonEmptyInputs = inputs.filter(input => ShapeUtil.size(input.dims) > 0);
|
||||
context.compute(
|
||||
createConcatProgramInfo(nonEmptyInputs, adjustedAxis, outputShape, inputs[0].dataType), {inputs: nonEmptyInputs});
|
||||
};
|
||||
|
||||
export const parseConcatAttributes = (attributes: Record<string, unknown>): ConcatAttributes =>
|
||||
|
|
|
|||
|
|
@ -557,5 +557,85 @@
|
|||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Concat 2D axis=1; Preserve dims",
|
||||
"operator": "Concat",
|
||||
"attributes": [
|
||||
{
|
||||
"name": "axis",
|
||||
"data": 0,
|
||||
"type": "int"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "Some but not all input tensors are zero-sized",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [],
|
||||
"dims": [0, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1],
|
||||
"dims": [1, 1],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Concat 2D axis=1; Preserve dims",
|
||||
"operator": "Concat",
|
||||
"attributes": [
|
||||
{
|
||||
"name": "axis",
|
||||
"data": 1,
|
||||
"type": "int"
|
||||
}
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "All input tensors are zero-sized",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [],
|
||||
"dims": [0, 0],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [],
|
||||
"dims": [0, 1],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [],
|
||||
"dims": [0, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [],
|
||||
"dims": [0, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [],
|
||||
"dims": [0, 6],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in a new issue