mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
[js/web] support opset-13 of softmax (#9493)
* add p50 in test * support opset-13 of softmax * update a operators.md * resolve comments * fix lint and format Co-authored-by: Yulong Wang <yulongw@microsoft.com>
This commit is contained in:
parent
d079e0d48f
commit
c79307e7b4
4 changed files with 115 additions and 54 deletions
|
|
@ -152,7 +152,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
|
|||
| [Sinh](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sinh) | |
|
||||
| [Size](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Size) | |
|
||||
| [Slice](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Slice) | [1-9](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Slice-1), [10](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Slice-10), [11-12](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Slice-11), [13+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Slice-13) |
|
||||
| [Softmax](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax) | [1-10](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Softmax-1), [11-12](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Softmax-11) |
|
||||
| [Softmax](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax) | [1-10](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Softmax-1), [11-12](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Softmax-11), [13+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Softmax-13) |
|
||||
| [SoftmaxCrossEntropyLoss](https://github.com/onnx/onnx/blob/master/docs/Operators.md#SoftmaxCrossEntropyLoss) | |
|
||||
| [Softplus](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softplus) | |
|
||||
| [Softsign](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softsign) | |
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ import {reshape} from './ops/reshape';
|
|||
import {parseResizeAttributesV10, parseResizeAttributesV11, resize} from './ops/resize-packed';
|
||||
import {shape} from './ops/shape';
|
||||
import {parseSliceAttributes, slice, sliceV10} from './ops/slice';
|
||||
import {parseSoftmaxAttributes, softmax} from './ops/softmax';
|
||||
import {parseSoftmaxAttributes, parseSoftmaxAttributesV13, softmax, softmaxV13} from './ops/softmax';
|
||||
import {parseSplitAttributes, split} from './ops/split';
|
||||
import {parseSqueezeAttributes, squeeze, squeezeV13} from './ops/squeeze';
|
||||
import {sum} from './ops/sum';
|
||||
|
|
@ -102,6 +102,7 @@ export const WEBGL_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
|
|||
['Slice', '', '1-9', slice, parseSliceAttributes],
|
||||
// The "semantic" meaning of axis has changed in opset-13.
|
||||
['Softmax', '', '1-12', softmax, parseSoftmaxAttributes],
|
||||
['Softmax', '', '13+', softmaxV13, parseSoftmaxAttributesV13],
|
||||
// 'Split' operator has an optional attribute 'split'
|
||||
// this attribute determines how the specified axis of input data is split.
|
||||
// When the attribute is missing, we need the count of number of outputs
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import {ShapeUtil} from '../../../util';
|
|||
import {getGlsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, TextureType} from '../types';
|
||||
import {transpose, TransposeAttributes} from './transpose';
|
||||
|
||||
export interface SoftmaxAttributes extends AttributeWithCacheKey {
|
||||
readonly axis: number;
|
||||
|
|
@ -38,62 +39,123 @@ export const softmax: OperatorImplementation<SoftmaxAttributes> =
|
|||
|
||||
const inputShape = inputs[0].dims.slice();
|
||||
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
|
||||
const N = ShapeUtil.sizeToDimension(inputShape, axis);
|
||||
const D = ShapeUtil.sizeFromDimension(inputShape, axis);
|
||||
const logicalRowCount = ShapeUtil.sizeToDimension(inputShape, axis);
|
||||
const featureCount = ShapeUtil.sizeFromDimension(inputShape, axis);
|
||||
|
||||
const computeMaxProgramInfo = createComputeMaxProgramInfo(inferenceHandler, inputs[0], N, D, [N]);
|
||||
const output = computeSoftmax(inferenceHandler, inputs, attributes, logicalRowCount, featureCount);
|
||||
return output;
|
||||
};
|
||||
|
||||
/**
 * Builds the Softmax attributes for the resolve rule registered for opsets 1-12.
 *
 * Reads the integer `axis` attribute from the node, falling back to 1 when it is absent,
 * and wraps it with createAttributeWithCacheKey.
 */
export const parseSoftmaxAttributes: OperatorInitialization<SoftmaxAttributes> =
    (node: Graph.Node): SoftmaxAttributes => {
      const axis = node.attributes.getInt('axis', 1);
      return createAttributeWithCacheKey({axis});
    };
|
||||
|
||||
/**
 * Builds the Softmax attributes for the resolve rule registered for opset 13 and later.
 *
 * Reads the integer `axis` attribute from the node, falling back to -1 when it is absent,
 * and wraps it with createAttributeWithCacheKey.
 */
export const parseSoftmaxAttributesV13: OperatorInitialization<SoftmaxAttributes> =
    (node: Graph.Node): SoftmaxAttributes => {
      const axis = node.attributes.getInt('axis', -1);
      return createAttributeWithCacheKey({axis});
    };
|
||||
|
||||
/**
 * Softmax implementation used for opset 13 and later.
 *
 * The "semantic" meaning of axis has changed in opset-13.
 * Please compare: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax
 * with https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Softmax-11 for detailed explanations.
 *
 * To account for the opset-13 behavior, the "axis" dim is swapped with the innermost dim by a transpose,
 * softmax is computed over the innermost dim, and the same transpose is applied again to undo the swap.
 * Both transposes are skipped when the axis is already the innermost dim.
 *
 * @param inferenceHandler handler passed through to transpose and computeSoftmax
 * @param inputs operator inputs; the shape is taken from inputs[0].dims
 * @param attributes parsed Softmax attributes; `axis` is normalized against the input rank
 * @returns the result of computeSoftmax, transposed back when a transpose was needed
 */
export const softmaxV13: OperatorImplementation<SoftmaxAttributes> =
    (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: SoftmaxAttributes): Tensor[] => {
      validateInputs(inputs);

      const inputShape = inputs[0].dims.slice();
      const rank = inputShape.length;
      const axis = ShapeUtil.normalizeAxis(attributes.axis, rank);
      const innermost = rank - 1;

      // Fast path: the axis already is the innermost dim, so no transpose is needed.
      if (axis === innermost) {
        const logicalRowCount = ShapeUtil.sizeToDimension(inputShape, innermost);
        const featureCount = ShapeUtil.sizeFromDimension(inputShape, innermost);
        return computeSoftmax(inferenceHandler, inputs, attributes, logicalRowCount, featureCount);
      }

      // Identity permutation with the innermost dim and the dim corresponding to axis swapped.
      const perm = Array.from({length: rank}, (_, i) => i);
      perm[axis] = innermost;
      perm[innermost] = axis;

      const transposedInputShape = perm.map(p => inputShape[p]);
      const transposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm});
      const transposedInputs = transpose(inferenceHandler, inputs, transposeAttribute);

      const logicalRowCount = ShapeUtil.sizeToDimension(transposedInputShape, innermost);
      const featureCount = ShapeUtil.sizeFromDimension(transposedInputShape, innermost);
      const output = computeSoftmax(inferenceHandler, transposedInputs, attributes, logicalRowCount, featureCount);

      // A swap of two dims is its own inverse, so the same permutation restores the original layout.
      return transpose(inferenceHandler, output, transposeAttribute);
    };
|
||||
|
||||
const computeSoftmax =
|
||||
(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: SoftmaxAttributes, logicalRowCount: number,
|
||||
featureCount: number): Tensor[] => {
|
||||
const computeMaxProgramInfo =
|
||||
createComputeMaxProgramInfo(inferenceHandler, inputs[0], logicalRowCount, featureCount, [logicalRowCount]);
|
||||
const max = inferenceHandler.run(
|
||||
{...softmaxComputeMaxProgramMetadata, cacheHint: attributes.cacheKey, get: () => computeMaxProgramInfo},
|
||||
inputs);
|
||||
|
||||
const computeScaleProgramInfo =
|
||||
createComputScaleProgramInfo(inferenceHandler, inputs[0], N, D, computeMaxProgramInfo.output.dims, [N]);
|
||||
const computeScaleProgramInfo = createComputScaleProgramInfo(
|
||||
inferenceHandler, inputs[0], logicalRowCount, featureCount, computeMaxProgramInfo.output.dims,
|
||||
[logicalRowCount]);
|
||||
const scale = inferenceHandler.run(
|
||||
{...softmaxComputeScaleProgramMetadata, cacheHint: attributes.cacheKey, get: () => computeScaleProgramInfo},
|
||||
[inputs[0], max]);
|
||||
|
||||
const softMaxProgramInfo = createSoftMaxProgramInfo(
|
||||
inferenceHandler, inputs[0], N, D, computeMaxProgramInfo.output.dims, computeScaleProgramInfo.output.dims);
|
||||
inferenceHandler, inputs[0], logicalRowCount, featureCount, computeMaxProgramInfo.output.dims,
|
||||
computeScaleProgramInfo.output.dims);
|
||||
const output = inferenceHandler.run(
|
||||
{...softmaxProgramMetadata, cacheHint: attributes.cacheKey, get: () => softMaxProgramInfo},
|
||||
[inputs[0], max, scale]);
|
||||
return [output];
|
||||
};
|
||||
|
||||
export const parseSoftmaxAttributes: OperatorInitialization<SoftmaxAttributes> =
|
||||
(node: Graph.Node): SoftmaxAttributes => createAttributeWithCacheKey({axis: node.attributes.getInt('axis', 1)});
|
||||
|
||||
/**
|
||||
* Create a texture that contains the maximum value of each of the 'N' rows
|
||||
*/
|
||||
const createComputeMaxProgramInfo =
|
||||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||
(inferenceHandler: WebGLInferenceHandler, input: Tensor, N: number, D: number, outputShape: number[]):
|
||||
ProgramInfo => {
|
||||
const [textureWidth, textureHeight] =
|
||||
inferenceHandler.calculateTextureWidthAndHeight(input.dims, TextureType.unpacked);
|
||||
const rank = outputShape.length;
|
||||
(inferenceHandler: WebGLInferenceHandler, input: Tensor, logicalRowCount: number, featureCount: number,
|
||||
outputShape: number[]): ProgramInfo => {
|
||||
const [textureWidth, textureHeight] =
|
||||
inferenceHandler.calculateTextureWidthAndHeight(input.dims, TextureType.unpacked);
|
||||
const rank = outputShape.length;
|
||||
|
||||
if (N < 1 || D < 1) {
|
||||
throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
|
||||
}
|
||||
if (logicalRowCount < 1 || featureCount < 1) {
|
||||
throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
|
||||
}
|
||||
|
||||
if (outputShape.length !== 1) {
|
||||
throw new Error('Dimensionality of the output should be 1');
|
||||
}
|
||||
if (outputShape.length !== 1) {
|
||||
throw new Error('Dimensionality of the output should be 1');
|
||||
}
|
||||
|
||||
if (outputShape[0] !== N) {
|
||||
throw new Error('Shape of the output should be equal to logical row count');
|
||||
}
|
||||
if (outputShape[0] !== logicalRowCount) {
|
||||
throw new Error('Shape of the output should be equal to logical row count');
|
||||
}
|
||||
|
||||
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
|
||||
const shaderSource = `
|
||||
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
|
||||
const shaderSource = `
|
||||
float process(int[${rank}] indices) {
|
||||
int logical_row_start_offset = indices[0] * ${D};
|
||||
int logical_row_start_offset = indices[0] * ${featureCount};
|
||||
|
||||
float max = getColorAsFloat(${glsl.texture2D}(A, offsetToCoords(logical_row_start_offset, ${textureWidth},
|
||||
${textureHeight} )));
|
||||
for(int i=1; i<${D}; ++i)
|
||||
for(int i=1; i<${featureCount}; ++i)
|
||||
{
|
||||
float current = getColorAsFloat(${glsl.texture2D}(A, offsetToCoords(logical_row_start_offset + i,
|
||||
${textureWidth}, ${textureHeight})));
|
||||
|
|
@ -103,25 +165,24 @@ const createComputeMaxProgramInfo =
|
|||
|
||||
return max;
|
||||
}`;
|
||||
return {
|
||||
...softmaxComputeMaxProgramMetadata,
|
||||
output: {dims: outputShape, type: input.type, textureType: TextureType.unpacked},
|
||||
shaderSource
|
||||
};
|
||||
};
|
||||
return {
|
||||
...softmaxComputeMaxProgramMetadata,
|
||||
output: {dims: outputShape, type: input.type, textureType: TextureType.unpacked},
|
||||
shaderSource
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* Create a texture that contains the normalization factor for each of the 'N' rows
|
||||
*/
|
||||
const createComputScaleProgramInfo =
|
||||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||
(inferenceHandler: WebGLInferenceHandler, input: Tensor, N: number, D: number,
|
||||
(inferenceHandler: WebGLInferenceHandler, input: Tensor, logicalRowCount: number, featureCount: number,
|
||||
maxElementPerLogicalRow: readonly number[], outputShape: number[]): ProgramInfo => {
|
||||
const [textureWidth, textureHeight] =
|
||||
inferenceHandler.calculateTextureWidthAndHeight(input.dims, TextureType.unpacked);
|
||||
const rank = outputShape.length;
|
||||
|
||||
if (N < 1 || D < 1) {
|
||||
if (logicalRowCount < 1 || featureCount < 1) {
|
||||
throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
|
||||
}
|
||||
|
||||
|
|
@ -129,7 +190,7 @@ const createComputScaleProgramInfo =
|
|||
throw new Error('Dimensionality of the output should be 1');
|
||||
}
|
||||
|
||||
if (outputShape[0] !== N) {
|
||||
if (outputShape[0] !== logicalRowCount) {
|
||||
throw new Error('Shape of the output should be equal to logical row count');
|
||||
}
|
||||
|
||||
|
|
@ -137,18 +198,18 @@ const createComputScaleProgramInfo =
|
|||
throw new Error('Dimensionality of the intermediate results should be 1');
|
||||
}
|
||||
|
||||
if (maxElementPerLogicalRow[0] !== N) {
|
||||
if (maxElementPerLogicalRow[0] !== logicalRowCount) {
|
||||
throw new Error('Shape of the intermediate results should be equal to logical row count');
|
||||
}
|
||||
|
||||
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
|
||||
const shaderSource = `
|
||||
float process(int[${rank}] indices) {
|
||||
int logical_row_start_offset = indices[0] * ${D};
|
||||
int logical_row_start_offset = indices[0] * ${featureCount};
|
||||
|
||||
float norm_factor = 0.0;
|
||||
float max = _Max(indices);
|
||||
for(int i=0; i<${D}; ++i)
|
||||
for(int i=0; i<${featureCount}; ++i)
|
||||
{
|
||||
norm_factor += exp(getColorAsFloat(${glsl.texture2D}(A, offsetToCoords(logical_row_start_offset + i,
|
||||
${textureWidth}, ${textureHeight}))) - max);
|
||||
|
|
@ -164,14 +225,13 @@ const createComputScaleProgramInfo =
|
|||
};
|
||||
|
||||
const createSoftMaxProgramInfo =
|
||||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||
(inferenceHandler: WebGLInferenceHandler, input: Tensor, N: number, D: number,
|
||||
(inferenceHandler: WebGLInferenceHandler, input: Tensor, logicalRowCount: number, featureCount: number,
|
||||
maxElementPerLogicalRow: readonly number[], normalizationPerLogicalRow: readonly number[]): ProgramInfo => {
|
||||
const [textureWidth, textureHeight] =
|
||||
inferenceHandler.calculateTextureWidthAndHeight(input.dims, TextureType.unpacked);
|
||||
const rank = input.dims.length;
|
||||
|
||||
if (N < 1 || D < 1) {
|
||||
if (logicalRowCount < 1 || featureCount < 1) {
|
||||
throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
|
||||
}
|
||||
|
||||
|
|
@ -179,7 +239,7 @@ const createSoftMaxProgramInfo =
|
|||
throw new Error('Dimensionality of the intermediate results should be 1');
|
||||
}
|
||||
|
||||
if (maxElementPerLogicalRow[0] !== N || normalizationPerLogicalRow[0] !== N) {
|
||||
if (maxElementPerLogicalRow[0] !== logicalRowCount || normalizationPerLogicalRow[0] !== logicalRowCount) {
|
||||
throw new Error('Shape of the intermediate results should be equal to logical row count');
|
||||
}
|
||||
|
||||
|
|
@ -191,7 +251,7 @@ const createSoftMaxProgramInfo =
|
|||
|
||||
//determine the logical row for this index
|
||||
int logical_row_index[1];
|
||||
logical_row_index[0] = offset / ${D};
|
||||
logical_row_index[0] = offset / ${featureCount};
|
||||
|
||||
float norm_factor = _Norm(logical_row_index);
|
||||
|
||||
|
|
|
|||
|
|
@ -137,13 +137,13 @@
|
|||
"test_sigmoid_example",
|
||||
"test_sin_example",
|
||||
"test_sin",
|
||||
"v{7,8,9,10,11,12}/test_softmax_axis_0",
|
||||
"v{7,8,9,10,11,12}/test_softmax_axis_1",
|
||||
"v{7,8,9,10,11,12}/test_softmax_axis_2",
|
||||
"v{7,8,9,10,11,12}/test_softmax_default_axis",
|
||||
"v{7,8,9,10,11,12}/test_softmax_example",
|
||||
"test_softmax_axis_0",
|
||||
"test_softmax_axis_1",
|
||||
"test_softmax_axis_2",
|
||||
"test_softmax_default_axis",
|
||||
"test_softmax_example",
|
||||
{
|
||||
"name": "v{7,8,9,10,11,12}/test_softmax_large_number",
|
||||
"name": "test_softmax_large_number",
|
||||
"condition": "^((?!iOS).)*$" // does NOT contains 'iOS': large number cannot be handled in a half_float environment
|
||||
},
|
||||
"test_sub_bcast",
|
||||
|
|
|
|||
Loading…
Reference in a new issue