onnxruntime/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts
Guenther Schmuelling 0df2e14038
js/webgpu: argmax,argmin,softmax support (#16882)
argmax and argmin are similar to reduce. Eventually we need to add
optimized flavors of the shader.

softmax is optimized but only works on the last axis for now which
should be the common use case.

todo: enable more ut for argmax/argmin
2023-08-02 18:16:19 -07:00

147 lines
4.9 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// TODO: this is the same naive implementation we use for reduce that has
// performance limitations when the reduced axis is long. Need to add
// a optimized codepath for this.
import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo} from '../types';
import {ShaderHelper} from './common';
const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 1) {
throw new Error('Softmax op requires 1 input.');
}
if (inputs[0].dataType !== DataType.float) {
throw new Error('Softmax input needs to be float.');
}
};
export interface SoftmaxAttributes extends AttributeWithCacheKey {
readonly axis: number;
}
export const softmaxProgramMetadata = {
name: 'Softmax',
inputTypes: [GpuDataType.default]
};
const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttributes): ProgramInfo => {
const dataType = 'f32';
const shape = input.dims;
const outputSize = ShapeUtil.size(shape);
const WG = 64;
let axis = attributes.axis;
if (axis < 0) {
axis = shape.length + axis;
}
if (axis < shape.length - 1) {
throw new Error('softmax only supports last axis for now.');
}
const cols = shape[axis];
const rows = outputSize / cols;
const getShaderSource = (_shaderHelper: ShaderHelper) => `
var<workgroup> rowMaxShared : ${dataType};
var<workgroup> rowSumShared : ${dataType};
var<workgroup> threadShared : array<${dataType}, ${WG}>;
@group(0) @binding(0) var<storage, read> x : array<${dataType}>;
@group(0) @binding(1) var<storage, read_write> result : array<${dataType}>;
fn getValue(row: i32, col: i32, row_stride: i32) -> ${dataType} {
let index = row * row_stride + col;
return x[index];
}
fn setValue(row: i32, col: i32, row_stride: i32, value: ${dataType}) {
let index = row * row_stride + col;
result[index] = value;
}
@compute @workgroup_size(${WG}, 1, 1)
fn main(@builtin(local_invocation_id) local_id : vec3<u32>, @builtin(global_invocation_id) global_id : vec3u) {
let gindex = i32(global_id.x);
let lindex = i32(local_id.x);
const wg = ${WG};
let row = gindex / wg;
let cols = ${cols};
let row_stride : i32 = ${cols};
// find the rows max
var threadMax = -3.402823e+38f; // 6.2.4 in wgsl spec
for (var col = lindex; col < cols; col += wg) {
let value = getValue(row, col, row_stride);
threadMax = max(threadMax, value);
}
if (lindex < cols) {
threadShared[lindex] = threadMax;
}
workgroupBarrier();
var reduceSize = min(cols, wg);
for (var currSize = reduceSize >> 1; currSize > 0; currSize = reduceSize >> 1) {
reduceSize = currSize + (reduceSize & 1);
if (lindex < currSize) {
threadShared[lindex] = max(threadShared[lindex], threadShared[lindex + reduceSize]);
}
workgroupBarrier();
}
if (lindex == 0) {
rowMaxShared = threadShared[0];
}
workgroupBarrier();
// find the rows sum
var threadSum = 0.0;
for (var col = lindex; col < cols; col += wg) {
let subExp = exp(getValue(row, col, row_stride) - rowMaxShared);
threadSum += subExp;
}
threadShared[lindex] = threadSum;
workgroupBarrier();
for (var currSize = wg >> 1; currSize > 0; currSize = currSize >> 1) {
if (lindex < currSize) {
threadShared[lindex] = threadShared[lindex] + threadShared[lindex + currSize];
}
workgroupBarrier();
}
if (lindex == 0) {
rowSumShared = threadShared[0];
}
workgroupBarrier();
// calculate final value for each element in the row
for (var col = lindex; col < cols; col += wg) {
let value = exp(getValue(row, col, row_stride) - rowMaxShared) / rowSumShared;
setValue(row, col, row_stride, value);
}
}`;
return {
...softmaxProgramMetadata,
outputs: [{dims: shape, dataType: input.dataType, gpuDataType: GpuDataType.default}],
getShaderSource,
dispatchGroup: () => ({x: rows})
};
};
export const softmax = (context: ComputeContext, attributes: SoftmaxAttributes): void => {
validateInputs(context.inputs);
context.compute({
...softmaxProgramMetadata,
cacheHint: attributes.cacheKey,
get: () => createSoftmaxProgramInfo(context.inputs[0], attributes)
});
};
export const parseSoftmaxAttributes = (attributes: Record<string, unknown>): SoftmaxAttributes =>
createAttributeWithCacheKey({axis: attributes.axis as number});