[js/webgpu] Remove the limitation on axis in softmax (#22231)

In current implementation, axis in softmax has to be the last, which is
an obvious limitation. This PR removes this limitation and will fix
issues #20710 and #22176.
This commit is contained in:
Yang Gu 2024-10-01 09:27:11 +08:00 committed by GitHub
parent d9de054eb5
commit c75f4a09b7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 110 additions and 33 deletions

View file

@ -9,7 +9,8 @@ import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo } from '../types';
import { ComputeContext } from '../types';
import { createTransposeProgramInfo } from './transpose';
import {
getMaxComponents,
@ -30,19 +31,32 @@ export interface SoftmaxAttributes extends AttributeWithCacheKey {
readonly axis: number;
}
const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => {
const shape = input.dims;
const outputSize = ShapeUtil.size(shape);
const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAttributes) => {
const input = context.inputs[0];
const inputShape = input.dims;
const outputSize = ShapeUtil.size(inputShape);
const WG = 64;
let axis = attributes.axis;
if (axis < 0) {
axis = shape.length + axis;
}
if (axis < shape.length - 1) {
throw new Error('softmax only supports last axis for now.');
const inputRank = inputShape.length;
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank);
const isTransposeRequired = axis < inputShape.length - 1;
let transposedInput: TensorView;
let perm: number[] = [];
if (isTransposeRequired) {
perm = Array.from({ length: inputRank }, (_, i) => i);
perm[axis] = inputRank - 1;
perm[inputRank - 1] = axis;
transposedInput = context.compute(createTransposeProgramInfo(input, perm), {
inputs: [input],
outputs: [-1],
})[0];
} else {
transposedInput = input;
}
const cols = shape[axis];
const transposedInputShape = transposedInput.dims;
const cols = transposedInputShape[inputRank - 1];
const rows = outputSize / cols;
const components = getMaxComponents(cols);
const packedCols = cols / components;
@ -58,12 +72,12 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
return name;
};
const x = inputVariable('x', input.dataType, input.dims, components);
const output = outputVariable('result', input.dataType, input.dims, components);
const x = inputVariable('x', transposedInput.dataType, transposedInput.dims, components);
const output = outputVariable('result', transposedInput.dataType, transposedInput.dims, components);
const valueType = x.type.value;
// 6.2.4 in wgsl spec
const threadMaxDecl =
tensorTypeToWsglStorageType(input.dataType) === 'f32'
tensorTypeToWsglStorageType(transposedInput.dataType) === 'f32'
? `var threadMax = ${valueType}(-3.402823e+38f);`
: `var threadMax = ${valueType}(-65504.0h);`;
const getShaderSource = (shaderHelper: ShaderHelper) => `
@ -139,21 +153,33 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
setValue(row, col, row_stride, value);
}
}`;
return {
name: 'Softmax',
shaderCache: { hint: `${components}`, inputDependencies: ['type'] },
getRunData: () => ({
outputs: [{ dims: shape, dataType: input.dataType }],
dispatchGroup: { x: rows },
programUniforms: [{ type: DataType.int32, data: packedCols }],
}),
getShaderSource,
};
const result = context.compute(
{
name: 'Softmax',
shaderCache: { hint: `${components}`, inputDependencies: ['type'] },
getRunData: () => ({
outputs: [{ dims: transposedInputShape, dataType: transposedInput.dataType }],
dispatchGroup: { x: rows },
programUniforms: [{ type: DataType.int32, data: packedCols }],
}),
getShaderSource,
},
{
inputs: [transposedInput],
outputs: [isTransposeRequired ? -1 : 0],
},
)[0];
if (isTransposeRequired) {
context.compute(createTransposeProgramInfo(result, perm), {
inputs: [result],
});
}
};
export const softmax = (context: ComputeContext, attributes: SoftmaxAttributes): void => {
validateInputs(context.inputs);
context.compute(createSoftmaxProgramInfo(context.inputs[0], attributes));
createSoftmaxProgramInfo(context, attributes);
};
export const parseSoftmaxAttributes = (attributes: Record<string, unknown>): SoftmaxAttributes =>

View file

@ -20,14 +20,7 @@
"type": "float32"
}
]
}
]
},
{
"name": "Softmax with no attributes",
"operator": "Softmax",
"attributes": [],
"cases": [
},
{
"name": "T[2, 2, 2]",
"inputs": [
@ -49,5 +42,63 @@
]
}
]
},
{
"name": "Softmax with attribute axis -1",
"operator": "Softmax",
"attributes": [{ "name": "axis", "data": -1, "type": "int" }],
"cases": [
{
"name": "T[2,2]",
"inputs": [
{
"data": [1.0, 2.0, 3.0, 4.0],
"dims": [2, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [0.2689414322376251, 0.7310585975646973, 0.2689414322376251, 0.7310585975646973],
"dims": [2, 2],
"type": "float32"
}
]
}
]
},
{
"name": "Softmax with attribute axis 1",
"operator": "Softmax",
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
"cases": [
{
"name": "T[1, 2, 3, 4]",
"inputs": [
{
"data": [
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0,
20.0, 21.0, 22.0, 23.0, 24.0
],
"dims": [1, 2, 3, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [
0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878,
0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878,
0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878,
0.9999938011169434, 0.9999938011169434, 0.9999938011169434, 0.9999938011169434, 0.9999938011169434,
0.9999938011169434, 0.9999938011169434, 0.9999938011169434, 0.9999938011169434, 0.9999938011169434,
0.9999938011169434, 0.9999938011169434
],
"dims": [1, 2, 3, 4],
"type": "float32"
}
]
}
]
}
]