[js/webgpu] Refactor createTensorShapeVariables (#18883)

This commit is contained in:
Xu Xing 2024-02-02 09:59:00 +08:00 committed by GitHub
parent 13ad922e7f
commit 3a2ab1963a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 40 additions and 64 deletions

View file

@ -195,8 +195,7 @@ export const createConv2DMatMulProgramInfo =
{type: DataType.int32, data: attributes.strides}, {type: DataType.int32, data: attributes.dilations}
];
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));
programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
if (hasBias) {
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));

View file

@ -204,8 +204,7 @@ export const createConv2DTransposeMatMulProgramInfo =
{type: DataType.int32, data: pads}
];
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims));
programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
if (hasBias) {

View file

@ -269,7 +269,7 @@ export const createConvTranspose2DProgramInfo =
{type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: dilations},
{type: DataType.uint32, data: effectiveFilterDims}, {type: DataType.int32, data: pads},
{type: DataType.uint32, data: inputChannelsPerGroup}, {type: DataType.uint32, data: outputChannelsPerGroup},
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)
...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)
];
if (hasBias) {
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));

View file

@ -453,9 +453,7 @@ export const createMatmulProgramInfo =
{type: DataType.int32, data: dimInner}
];
appendActivationUniformsData(activationAttributes, programUniforms);
programUniforms.push(
...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp),
...createTensorShapeVariables(bShapeTemp));
programUniforms.push(...createTensorShapeVariables(outerDims, aShapeTemp, bShapeTemp));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
const hasBias = inputs.length > 2;

View file

@ -180,9 +180,7 @@ const createBinaryOpProgramInfo =
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)},
programUniforms: [
{type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
...createTensorShapeVariables(a.dims),
...createTensorShapeVariables(b.dims),
...createTensorShapeVariables(outputShape),
...createTensorShapeVariables(a.dims, b.dims, outputShape)
],
}),
};

View file

@ -259,9 +259,16 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 =
return typeof mappedType === 'string' ? mappedType : mappedType[1];
};
export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => dims.length === 0 ?
[] :
[{type: DataType.uint32, data: dims}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dims)}];
export const createTensorShapeVariables = (...dims: ReadonlyArray<readonly number[]>): ProgramUniform[] => {
const programUniforms: ProgramUniform[] = [];
dims.forEach(dim => {
if (dim.length !== 0) {
programUniforms.push(
{type: DataType.uint32, data: dim}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dim)});
}
});
return programUniforms;
};
/**
* A helper function to get maximum vector size for specified data length

View file

@ -35,9 +35,7 @@ export const createGroupedConvProgramInfo =
{type: DataType.uint32, data: outputChannelsPerGroup}
];
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(
...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape),
...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(xShape, wShape, outputShape));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
if (hasBias) {
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
@ -134,9 +132,7 @@ export const createGroupedConvVectorizeProgramInfo =
{type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]}
];
appendActivationUniformsData(attributes, programUniforms);
programUniforms.push(
...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape),
...createTensorShapeVariables(outputShapeInShader));
programUniforms.push(...createTensorShapeVariables(xShape, wShape, outputShapeInShader));
const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1];
const getShaderSource = (shaderHelper: ShaderHelper) => {
const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);

View file

@ -55,7 +55,7 @@ const createCumsumProgramInfo =
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: [
{type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axis},
...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape)
...createTensorShapeVariables(inputShape, inputShape)
]
}),

View file

@ -84,10 +84,8 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
${assignment}`;
};
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape),
...createTensorShapeVariables(outputShape)
];
const programUniforms: ProgramUniform[] =
[{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape, outputShape)];
return {
name: 'Expand',
shaderCache: {hint: `${outputShape.length}`, inputDependencies: ['rank']},

View file

@ -51,9 +51,7 @@ const createGatherElementsProgramInfo =
{type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit},
{type: DataType.uint32, data: axis}
];
programUniforms.push(...createTensorShapeVariables(inputShape));
programUniforms.push(...createTensorShapeVariables(indicesShape));
programUniforms.push(...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(inputShape, indicesShape, outputShape));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
// int64 indices would be treated as little endian i32 with assumption they fall in i32 limits

View file

@ -35,8 +35,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit},
{type: DataType.uint32, data: axis}, ...createTensorShapeVariables(inputs[0].dims),
...createTensorShapeVariables(inputs[1].dims), ...createTensorShapeVariables(outputShape)
{type: DataType.uint32, data: axis}, ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims, outputShape)
];
const getShaderSource = (shaderHelper: ShaderHelper) => {

View file

@ -26,7 +26,7 @@ const createInstanceNormProgramInfo =
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type'];
const programUniforms: ProgramUniform[] =
[{type: DataType.uint32, data: normSize}, {type: DataType.uint32, data: normPackedSize}];
programUniforms.push(...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape));
programUniforms.push(...createTensorShapeVariables(inputShape, inputShape));
const getShaderSource = (shaderHelper: ShaderHelper) => {
const x = inputVariable('x', inputs[0].dataType, inputShape.length, components);

View file

@ -34,9 +34,7 @@ export const createNaiveMatmulProgramInfo =
{type: DataType.uint32, data: K}
];
appendActivationUniformsData(activationAttributes, programUniforms);
programUniforms.push(
...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape),
...createTensorShapeVariables(bShape));
programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape));
if (hasBias) {
programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
}

View file

@ -158,7 +158,7 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr
programUniforms.push({type: inputs[0].dataType, data: attributes.value});
}
programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(inputs[0].dims, outputShape));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
const getShaderSource = (shaderHelper: ShaderHelper) => {

View file

@ -298,7 +298,7 @@ const createAveragePoolProgramInfo =
}
const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] =
getUniformAndPadInfo(outputShape, adjustedAttributes);
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(input.dims, outputShape));
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
return {
name,
@ -370,7 +370,7 @@ const createMaxPoolProgramInfo =
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] =
getUniformAndPadInfo(outputShape, adjustedAttributes);
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape));
programUniforms.push(...createTensorShapeVariables(input.dims, outputShape));
return {
name,
shaderCache:

View file

@ -100,10 +100,8 @@ export const createReduceProgramInfo =
getRunData: () => ({
outputs: [{dims: outputShape, dataType: outputDataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: [
{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape),
...createTensorShapeVariables(outputShape)
]
programUniforms:
[{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape, outputShape)]
}),
};
};

View file

@ -642,11 +642,8 @@ const createResizeProgramInfo =
outputs: [{dims: outputShape, dataType: inputTensor.dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: [
{type: DataType.uint32, data: outputSize},
{type: DataType.float, data: scales},
{type: DataType.float, data: roi},
...createTensorShapeVariables(inputShape),
...createTensorShapeVariables(outputShape),
{type: DataType.uint32, data: outputSize}, {type: DataType.float, data: scales},
{type: DataType.float, data: roi}, ...createTensorShapeVariables(inputShape, outputShape)
]
})
};

View file

@ -157,7 +157,7 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: starts},
{type: DataType.int32, data: signs}, {type: DataType.uint32, data: steps},
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape)
...createTensorShapeVariables(inputs[0].dims, outputShape)
];
const getShaderSource = (shaderHelper: ShaderHelper) => `

View file

@ -83,9 +83,8 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split
outputs[i] = outputVariable(`output${i}`, dataType, outputShape);
outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType});
}
programUniforms.push({type: DataType.uint32, data: sizeInSplitAxis});
programUniforms.push(...createTensorShapeVariables(inputShape));
outputShapes.forEach((outputShape) => programUniforms.push(...createTensorShapeVariables(outputShape)));
programUniforms.push(
{type: DataType.uint32, data: sizeInSplitAxis}, ...createTensorShapeVariables(inputShape, ...outputShapes));
const getShaderSource = (shaderHelper: ShaderHelper) => `
${
shaderHelper.registerUniform('input_size', 'u32')

View file

@ -79,10 +79,8 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: [
{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims),
...createTensorShapeVariables(outputShape)
],
programUniforms:
[{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims, outputShape)],
}),
getShaderSource,
};

View file

@ -65,11 +65,8 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
return {
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
programUniforms: [
{type: DataType.uint32, data: outputSize},
...createTensorShapeVariables(inputs[0].dims),
...createTensorShapeVariables(outputShape),
],
programUniforms:
[{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims, outputShape)],
};
},
getShaderSource,

View file

@ -97,11 +97,8 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
getRunData: () => ({
outputs: [{dims: outputShape, dataType: outputDataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)},
programUniforms: [
{type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC),
...createTensorShapeVariables(dimsA), ...createTensorShapeVariables(dimsB),
...createTensorShapeVariables(outputShape)
],
programUniforms:
[{type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC, dimsA, dimsB, outputShape)],
}),
};
};