mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-28 22:56:32 +00:00
[js/webgpu] Remove the limitation on axis in softmax (#22231)
In current implementation, axis in softmax has to be the last, which is an obvious limitation. This PR removes this limitation and will fix issues #20710 and #22176.
This commit is contained in:
parent
d9de054eb5
commit
c75f4a09b7
2 changed files with 110 additions and 33 deletions
|
|
@ -9,7 +9,8 @@ import { DataType } from '../../../wasm-common';
|
|||
import { TensorView } from '../../tensor-view';
|
||||
import { ShapeUtil } from '../../util';
|
||||
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
|
||||
import { ComputeContext, ProgramInfo } from '../types';
|
||||
import { ComputeContext } from '../types';
|
||||
import { createTransposeProgramInfo } from './transpose';
|
||||
|
||||
import {
|
||||
getMaxComponents,
|
||||
|
|
@ -30,19 +31,32 @@ export interface SoftmaxAttributes extends AttributeWithCacheKey {
|
|||
readonly axis: number;
|
||||
}
|
||||
|
||||
const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => {
|
||||
const shape = input.dims;
|
||||
const outputSize = ShapeUtil.size(shape);
|
||||
const createSoftmaxProgramInfo = (context: ComputeContext, attributes: SoftmaxAttributes) => {
|
||||
const input = context.inputs[0];
|
||||
const inputShape = input.dims;
|
||||
const outputSize = ShapeUtil.size(inputShape);
|
||||
const WG = 64;
|
||||
let axis = attributes.axis;
|
||||
if (axis < 0) {
|
||||
axis = shape.length + axis;
|
||||
}
|
||||
if (axis < shape.length - 1) {
|
||||
throw new Error('softmax only supports last axis for now.');
|
||||
const inputRank = inputShape.length;
|
||||
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank);
|
||||
const isTransposeRequired = axis < inputShape.length - 1;
|
||||
let transposedInput: TensorView;
|
||||
let perm: number[] = [];
|
||||
|
||||
if (isTransposeRequired) {
|
||||
perm = Array.from({ length: inputRank }, (_, i) => i);
|
||||
perm[axis] = inputRank - 1;
|
||||
perm[inputRank - 1] = axis;
|
||||
|
||||
transposedInput = context.compute(createTransposeProgramInfo(input, perm), {
|
||||
inputs: [input],
|
||||
outputs: [-1],
|
||||
})[0];
|
||||
} else {
|
||||
transposedInput = input;
|
||||
}
|
||||
|
||||
const cols = shape[axis];
|
||||
const transposedInputShape = transposedInput.dims;
|
||||
const cols = transposedInputShape[inputRank - 1];
|
||||
const rows = outputSize / cols;
|
||||
const components = getMaxComponents(cols);
|
||||
const packedCols = cols / components;
|
||||
|
|
@ -58,12 +72,12 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
|
|||
|
||||
return name;
|
||||
};
|
||||
const x = inputVariable('x', input.dataType, input.dims, components);
|
||||
const output = outputVariable('result', input.dataType, input.dims, components);
|
||||
const x = inputVariable('x', transposedInput.dataType, transposedInput.dims, components);
|
||||
const output = outputVariable('result', transposedInput.dataType, transposedInput.dims, components);
|
||||
const valueType = x.type.value;
|
||||
// 6.2.4 in wgsl spec
|
||||
const threadMaxDecl =
|
||||
tensorTypeToWsglStorageType(input.dataType) === 'f32'
|
||||
tensorTypeToWsglStorageType(transposedInput.dataType) === 'f32'
|
||||
? `var threadMax = ${valueType}(-3.402823e+38f);`
|
||||
: `var threadMax = ${valueType}(-65504.0h);`;
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
|
|
@ -139,21 +153,33 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
|
|||
setValue(row, col, row_stride, value);
|
||||
}
|
||||
}`;
|
||||
return {
|
||||
name: 'Softmax',
|
||||
shaderCache: { hint: `${components}`, inputDependencies: ['type'] },
|
||||
getRunData: () => ({
|
||||
outputs: [{ dims: shape, dataType: input.dataType }],
|
||||
dispatchGroup: { x: rows },
|
||||
programUniforms: [{ type: DataType.int32, data: packedCols }],
|
||||
}),
|
||||
getShaderSource,
|
||||
};
|
||||
const result = context.compute(
|
||||
{
|
||||
name: 'Softmax',
|
||||
shaderCache: { hint: `${components}`, inputDependencies: ['type'] },
|
||||
getRunData: () => ({
|
||||
outputs: [{ dims: transposedInputShape, dataType: transposedInput.dataType }],
|
||||
dispatchGroup: { x: rows },
|
||||
programUniforms: [{ type: DataType.int32, data: packedCols }],
|
||||
}),
|
||||
getShaderSource,
|
||||
},
|
||||
{
|
||||
inputs: [transposedInput],
|
||||
outputs: [isTransposeRequired ? -1 : 0],
|
||||
},
|
||||
)[0];
|
||||
|
||||
if (isTransposeRequired) {
|
||||
context.compute(createTransposeProgramInfo(result, perm), {
|
||||
inputs: [result],
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
export const softmax = (context: ComputeContext, attributes: SoftmaxAttributes): void => {
|
||||
validateInputs(context.inputs);
|
||||
context.compute(createSoftmaxProgramInfo(context.inputs[0], attributes));
|
||||
createSoftmaxProgramInfo(context, attributes);
|
||||
};
|
||||
|
||||
export const parseSoftmaxAttributes = (attributes: Record<string, unknown>): SoftmaxAttributes =>
|
||||
|
|
|
|||
|
|
@ -20,14 +20,7 @@
|
|||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Softmax with no attributes",
|
||||
"operator": "Softmax",
|
||||
"attributes": [],
|
||||
"cases": [
|
||||
},
|
||||
{
|
||||
"name": "T[2, 2, 2]",
|
||||
"inputs": [
|
||||
|
|
@ -49,5 +42,63 @@
|
|||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Softmax with attribute axis -1",
|
||||
"operator": "Softmax",
|
||||
"attributes": [{ "name": "axis", "data": -1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[2,2]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1.0, 2.0, 3.0, 4.0],
|
||||
"dims": [2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [0.2689414322376251, 0.7310585975646973, 0.2689414322376251, 0.7310585975646973],
|
||||
"dims": [2, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "Softmax with attribute axis 1",
|
||||
"operator": "Softmax",
|
||||
"attributes": [{ "name": "axis", "data": 1, "type": "int" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[1, 2, 3, 4]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0,
|
||||
20.0, 21.0, 22.0, 23.0, 24.0
|
||||
],
|
||||
"dims": [1, 2, 3, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878,
|
||||
0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878,
|
||||
0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878, 0.000006144174221844878,
|
||||
0.9999938011169434, 0.9999938011169434, 0.9999938011169434, 0.9999938011169434, 0.9999938011169434,
|
||||
0.9999938011169434, 0.9999938011169434, 0.9999938011169434, 0.9999938011169434, 0.9999938011169434,
|
||||
0.9999938011169434, 0.9999938011169434
|
||||
],
|
||||
"dims": [1, 2, 3, 4],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in a new issue