mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
[js/webgpu] Optimize softmax by vector (#18153)
### Description This PR makes `softmax` output the maximum supported number of components (a vector) per thread instead of a single scalar. For Softmax with input[0] of shape [12,4096,4096], the time drops from 55.11 ms to 47.86 ms.
This commit is contained in:
parent
90d1f537cb
commit
785e2b1eae
1 changed file with 30 additions and 14 deletions
|
|
@ -10,7 +10,7 @@ import {ShapeUtil} from '../../util';
|
|||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, ProgramInfo} from '../types';
|
||||
|
||||
import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
|
||||
import {getMaxComponents, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common';
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[]): void => {
|
||||
if (!inputs || inputs.length !== 1) {
|
||||
|
|
@ -37,23 +37,39 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
|
|||
|
||||
const cols = shape[axis];
|
||||
const rows = outputSize / cols;
|
||||
const components = getMaxComponents(cols);
|
||||
const packedCols = cols / components;
|
||||
const valueType = components === 1 ? dataType : `vec${components}<${dataType}>`;
|
||||
|
||||
const maxVector = (name: string, components: number) => {
|
||||
if (components === 4) {
|
||||
return `max(max(${name}.x, ${name}.y), max(${name}.z, ${name}.w))`;
|
||||
} else if (components === 2) {
|
||||
return `max(${name}.x, ${name}.y)`;
|
||||
} else if (components === 3) {
|
||||
return `max(max(${name}.x, ${name}.y), ${name}.z)`;
|
||||
}
|
||||
|
||||
return name;
|
||||
};
|
||||
|
||||
// 6.2.4 in wgsl spec
|
||||
const threadMaxDecl = dataType === 'f32' ? 'var threadMax: f32 = -3.402823e+38f;' : 'var threadMax: f16 = -65504.0h;';
|
||||
const threadMaxDecl =
|
||||
dataType === 'f32' ? `var threadMax = ${valueType}(-3.402823e+38f);` : `var threadMax = ${valueType}(-65504.0h);`;
|
||||
const getShaderSource = (_shaderHelper: ShaderHelper) => `
|
||||
var<workgroup> rowMaxShared : ${dataType};
|
||||
var<workgroup> rowSumShared : ${dataType};
|
||||
var<workgroup> threadShared : array<${dataType}, ${WG}>;
|
||||
var<workgroup> rowMaxShared : ${valueType};
|
||||
var<workgroup> rowSumShared : ${valueType};
|
||||
var<workgroup> threadShared : array<${valueType}, ${WG}>;
|
||||
|
||||
@group(0) @binding(0) var<storage, read> x : array<${dataType}>;
|
||||
@group(0) @binding(1) var<storage, read_write> result : array<${dataType}>;
|
||||
@group(0) @binding(0) var<storage, read> x : array<${valueType}>;
|
||||
@group(0) @binding(1) var<storage, read_write> result : array<${valueType}>;
|
||||
|
||||
fn getValue(row: i32, col: i32, row_stride: i32) -> ${dataType} {
|
||||
fn getValue(row: i32, col: i32, row_stride: i32) -> ${valueType} {
|
||||
let index = row * row_stride + col;
|
||||
return x[index];
|
||||
}
|
||||
|
||||
fn setValue(row: i32, col: i32, row_stride: i32, value: ${dataType}) {
|
||||
fn setValue(row: i32, col: i32, row_stride: i32, value: ${valueType}) {
|
||||
let index = row * row_stride + col;
|
||||
result[index] = value;
|
||||
}
|
||||
|
|
@ -64,8 +80,8 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
|
|||
let lindex = i32(local_id.x);
|
||||
const wg = ${WG};
|
||||
let row = gindex / wg;
|
||||
let cols = ${cols};
|
||||
let row_stride : i32 = ${cols};
|
||||
let cols = ${packedCols};
|
||||
let row_stride : i32 = ${packedCols};
|
||||
|
||||
// find the rows max
|
||||
${threadMaxDecl}
|
||||
|
|
@ -87,12 +103,12 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
|
|||
workgroupBarrier();
|
||||
}
|
||||
if (lindex == 0) {
|
||||
rowMaxShared = threadShared[0];
|
||||
rowMaxShared = ${valueType}(${maxVector('threadShared[0]', components)});
|
||||
}
|
||||
workgroupBarrier();
|
||||
|
||||
// find the rows sum
|
||||
var threadSum: ${dataType} = 0.0;
|
||||
var threadSum = ${valueType}(0.0);
|
||||
for (var col = lindex; col < cols; col += wg) {
|
||||
let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);
|
||||
threadSum += subExp;
|
||||
|
|
@ -107,7 +123,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
|
|||
workgroupBarrier();
|
||||
}
|
||||
if (lindex == 0) {
|
||||
rowSumShared = threadShared[0];
|
||||
rowSumShared = ${valueType}(${sumVector('threadShared[0]', components)});
|
||||
}
|
||||
workgroupBarrier();
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue