[js/web] support opset-13 of softmax (#9493)

* add p50 in test

* support opset-13 of softmax

* update a operators.md

* resolve comments

* fix lint and format

Co-authored-by: Yulong Wang <yulongw@microsoft.com>
This commit is contained in:
Sunghoon 2021-10-26 23:58:50 -07:00 committed by GitHub
parent d079e0d48f
commit c79307e7b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 115 additions and 54 deletions

View file

@ -152,7 +152,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [Sinh](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Sinh) | |
| [Size](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Size) | |
| [Slice](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Slice) | [1-9](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Slice-1), [10](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Slice-10), [11-12](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Slice-11), [13+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Slice-13) |
| [Softmax](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax) | [1-10](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Softmax-1), [11-12](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Softmax-11) |
| [Softmax](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax) | [1-10](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Softmax-1), [11-12](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Softmax-11), [13+](https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Softmax-13) |
| [SoftmaxCrossEntropyLoss](https://github.com/onnx/onnx/blob/master/docs/Operators.md#SoftmaxCrossEntropyLoss) | |
| [Softplus](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softplus) | |
| [Softsign](https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softsign) | |

View file

@ -26,7 +26,7 @@ import {reshape} from './ops/reshape';
import {parseResizeAttributesV10, parseResizeAttributesV11, resize} from './ops/resize-packed';
import {shape} from './ops/shape';
import {parseSliceAttributes, slice, sliceV10} from './ops/slice';
import {parseSoftmaxAttributes, softmax} from './ops/softmax';
import {parseSoftmaxAttributes, parseSoftmaxAttributesV13, softmax, softmaxV13} from './ops/softmax';
import {parseSplitAttributes, split} from './ops/split';
import {parseSqueezeAttributes, squeeze, squeezeV13} from './ops/squeeze';
import {sum} from './ops/sum';
@ -102,6 +102,7 @@ export const WEBGL_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
['Slice', '', '1-9', slice, parseSliceAttributes],
// The "semantic" meaning of axis has changed in opset-13.
['Softmax', '', '1-12', softmax, parseSoftmaxAttributes],
['Softmax', '', '13+', softmaxV13, parseSoftmaxAttributesV13],
// 'Split' operator has an optional attribute 'split'
// this attribute determines how the specified axis of input data is split.
// When the attribute is missing, we need the count of number of outputs

View file

@ -9,6 +9,7 @@ import {ShapeUtil} from '../../../util';
import {getGlsl} from '../glsl-source';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, TextureType} from '../types';
import {transpose, TransposeAttributes} from './transpose';
export interface SoftmaxAttributes extends AttributeWithCacheKey {
readonly axis: number;
@ -38,62 +39,123 @@ export const softmax: OperatorImplementation<SoftmaxAttributes> =
const inputShape = inputs[0].dims.slice();
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
const N = ShapeUtil.sizeToDimension(inputShape, axis);
const D = ShapeUtil.sizeFromDimension(inputShape, axis);
const logicalRowCount = ShapeUtil.sizeToDimension(inputShape, axis);
const featureCount = ShapeUtil.sizeFromDimension(inputShape, axis);
const computeMaxProgramInfo = createComputeMaxProgramInfo(inferenceHandler, inputs[0], N, D, [N]);
const output = computeSoftmax(inferenceHandler, inputs, attributes, logicalRowCount, featureCount);
return output;
};
export const parseSoftmaxAttributes: OperatorInitialization<SoftmaxAttributes> =
(node: Graph.Node): SoftmaxAttributes => createAttributeWithCacheKey({axis: node.attributes.getInt('axis', 1)});
export const parseSoftmaxAttributesV13: OperatorInitialization<SoftmaxAttributes> =
(node: Graph.Node): SoftmaxAttributes => createAttributeWithCacheKey({axis: node.attributes.getInt('axis', -1)});
// The "semantic" meaning of axis has changed in opset-13.
// Please compare: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Softmax
// with https://github.com/onnx/onnx/blob/master/docs/Changelog.md#Softmax-11 for detailed explanations
// To account for the opset-13 behavior, our plan will be to transpose the "axis" dim to the innermost dim
// and perform softmax and then reverse the transpose. We can skip the transposing aspect if the axis is already
// the innermost dim
export const softmaxV13: OperatorImplementation<SoftmaxAttributes> =
(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: SoftmaxAttributes): Tensor[] => {
validateInputs(inputs);
const inputShape = inputs[0].dims.slice();
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
const rank = inputShape.length;
const isTransposeRequired = (axis !== rank - 1) ? true : false;
const transposedInputShape: number[] = [];
let perm: number[] = [];
let transposedInputs: Tensor[] = [];
let transposeAttribute: TransposeAttributes;
if (isTransposeRequired) {
perm = Array.from({length: rank}).map((_, i) => i);
// swap the innermost dim with the dim corresponding to axis
perm[axis] = rank - 1;
perm[rank - 1] = axis;
perm.map(p => transposedInputShape.push(inputShape[p]));
transposeAttribute = createAttributeWithCacheKey({perm});
transposedInputs = transpose(inferenceHandler, inputs, transposeAttribute);
}
const logicalRowCount = isTransposeRequired ? ShapeUtil.sizeToDimension(transposedInputShape, rank - 1) :
ShapeUtil.sizeToDimension(inputShape, rank - 1);
const featureCount = isTransposeRequired ? ShapeUtil.sizeFromDimension(transposedInputShape, rank - 1) :
ShapeUtil.sizeFromDimension(inputShape, rank - 1);
const output = computeSoftmax(
inferenceHandler, isTransposeRequired ? transposedInputs : inputs, attributes, logicalRowCount, featureCount);
if (isTransposeRequired) {
const reversedOutput = transpose(inferenceHandler, output, transposeAttribute!);
return reversedOutput;
} else {
return output;
}
};
const computeSoftmax =
(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: SoftmaxAttributes, logicalRowCount: number,
featureCount: number): Tensor[] => {
const computeMaxProgramInfo =
createComputeMaxProgramInfo(inferenceHandler, inputs[0], logicalRowCount, featureCount, [logicalRowCount]);
const max = inferenceHandler.run(
{...softmaxComputeMaxProgramMetadata, cacheHint: attributes.cacheKey, get: () => computeMaxProgramInfo},
inputs);
const computeScaleProgramInfo =
createComputScaleProgramInfo(inferenceHandler, inputs[0], N, D, computeMaxProgramInfo.output.dims, [N]);
const computeScaleProgramInfo = createComputScaleProgramInfo(
inferenceHandler, inputs[0], logicalRowCount, featureCount, computeMaxProgramInfo.output.dims,
[logicalRowCount]);
const scale = inferenceHandler.run(
{...softmaxComputeScaleProgramMetadata, cacheHint: attributes.cacheKey, get: () => computeScaleProgramInfo},
[inputs[0], max]);
const softMaxProgramInfo = createSoftMaxProgramInfo(
inferenceHandler, inputs[0], N, D, computeMaxProgramInfo.output.dims, computeScaleProgramInfo.output.dims);
inferenceHandler, inputs[0], logicalRowCount, featureCount, computeMaxProgramInfo.output.dims,
computeScaleProgramInfo.output.dims);
const output = inferenceHandler.run(
{...softmaxProgramMetadata, cacheHint: attributes.cacheKey, get: () => softMaxProgramInfo},
[inputs[0], max, scale]);
return [output];
};
export const parseSoftmaxAttributes: OperatorInitialization<SoftmaxAttributes> =
(node: Graph.Node): SoftmaxAttributes => createAttributeWithCacheKey({axis: node.attributes.getInt('axis', 1)});
/**
* Create a texture that contains the maximum value of each of the 'N' rows
*/
const createComputeMaxProgramInfo =
// eslint-disable-next-line @typescript-eslint/naming-convention
(inferenceHandler: WebGLInferenceHandler, input: Tensor, N: number, D: number, outputShape: number[]):
ProgramInfo => {
const [textureWidth, textureHeight] =
inferenceHandler.calculateTextureWidthAndHeight(input.dims, TextureType.unpacked);
const rank = outputShape.length;
(inferenceHandler: WebGLInferenceHandler, input: Tensor, logicalRowCount: number, featureCount: number,
outputShape: number[]): ProgramInfo => {
const [textureWidth, textureHeight] =
inferenceHandler.calculateTextureWidthAndHeight(input.dims, TextureType.unpacked);
const rank = outputShape.length;
if (N < 1 || D < 1) {
throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
}
if (logicalRowCount < 1 || featureCount < 1) {
throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
}
if (outputShape.length !== 1) {
throw new Error('Dimensionality of the output should be 1');
}
if (outputShape.length !== 1) {
throw new Error('Dimensionality of the output should be 1');
}
if (outputShape[0] !== N) {
throw new Error('Shape of the output should be equal to logical row count');
}
if (outputShape[0] !== logicalRowCount) {
throw new Error('Shape of the output should be equal to logical row count');
}
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
const shaderSource = `
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
const shaderSource = `
float process(int[${rank}] indices) {
int logical_row_start_offset = indices[0] * ${D};
int logical_row_start_offset = indices[0] * ${featureCount};
float max = getColorAsFloat(${glsl.texture2D}(A, offsetToCoords(logical_row_start_offset, ${textureWidth},
${textureHeight} )));
for(int i=1; i<${D}; ++i)
for(int i=1; i<${featureCount}; ++i)
{
float current = getColorAsFloat(${glsl.texture2D}(A, offsetToCoords(logical_row_start_offset + i,
${textureWidth}, ${textureHeight})));
@ -103,25 +165,24 @@ const createComputeMaxProgramInfo =
return max;
}`;
return {
...softmaxComputeMaxProgramMetadata,
output: {dims: outputShape, type: input.type, textureType: TextureType.unpacked},
shaderSource
};
};
return {
...softmaxComputeMaxProgramMetadata,
output: {dims: outputShape, type: input.type, textureType: TextureType.unpacked},
shaderSource
};
};
/**
* Create a texture that contains the normalization factor for each of the 'N' rows
*/
const createComputScaleProgramInfo =
// eslint-disable-next-line @typescript-eslint/naming-convention
(inferenceHandler: WebGLInferenceHandler, input: Tensor, N: number, D: number,
(inferenceHandler: WebGLInferenceHandler, input: Tensor, logicalRowCount: number, featureCount: number,
maxElementPerLogicalRow: readonly number[], outputShape: number[]): ProgramInfo => {
const [textureWidth, textureHeight] =
inferenceHandler.calculateTextureWidthAndHeight(input.dims, TextureType.unpacked);
const rank = outputShape.length;
if (N < 1 || D < 1) {
if (logicalRowCount < 1 || featureCount < 1) {
throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
}
@ -129,7 +190,7 @@ const createComputScaleProgramInfo =
throw new Error('Dimensionality of the output should be 1');
}
if (outputShape[0] !== N) {
if (outputShape[0] !== logicalRowCount) {
throw new Error('Shape of the output should be equal to logical row count');
}
@ -137,18 +198,18 @@ const createComputScaleProgramInfo =
throw new Error('Dimensionality of the intermediate results should be 1');
}
if (maxElementPerLogicalRow[0] !== N) {
if (maxElementPerLogicalRow[0] !== logicalRowCount) {
throw new Error('Shape of the intermediate results should be equal to logical row count');
}
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
const shaderSource = `
float process(int[${rank}] indices) {
int logical_row_start_offset = indices[0] * ${D};
int logical_row_start_offset = indices[0] * ${featureCount};
float norm_factor = 0.0;
float max = _Max(indices);
for(int i=0; i<${D}; ++i)
for(int i=0; i<${featureCount}; ++i)
{
norm_factor += exp(getColorAsFloat(${glsl.texture2D}(A, offsetToCoords(logical_row_start_offset + i,
${textureWidth}, ${textureHeight}))) - max);
@ -164,14 +225,13 @@ const createComputScaleProgramInfo =
};
const createSoftMaxProgramInfo =
// eslint-disable-next-line @typescript-eslint/naming-convention
(inferenceHandler: WebGLInferenceHandler, input: Tensor, N: number, D: number,
(inferenceHandler: WebGLInferenceHandler, input: Tensor, logicalRowCount: number, featureCount: number,
maxElementPerLogicalRow: readonly number[], normalizationPerLogicalRow: readonly number[]): ProgramInfo => {
const [textureWidth, textureHeight] =
inferenceHandler.calculateTextureWidthAndHeight(input.dims, TextureType.unpacked);
const rank = input.dims.length;
if (N < 1 || D < 1) {
if (logicalRowCount < 1 || featureCount < 1) {
throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
}
@ -179,7 +239,7 @@ const createSoftMaxProgramInfo =
throw new Error('Dimensionality of the intermediate results should be 1');
}
if (maxElementPerLogicalRow[0] !== N || normalizationPerLogicalRow[0] !== N) {
if (maxElementPerLogicalRow[0] !== logicalRowCount || normalizationPerLogicalRow[0] !== logicalRowCount) {
throw new Error('Shape of the intermediate results should be equal to logical row count');
}
@ -191,7 +251,7 @@ const createSoftMaxProgramInfo =
//determine the logical row for this index
int logical_row_index[1];
logical_row_index[0] = offset / ${D};
logical_row_index[0] = offset / ${featureCount};
float norm_factor = _Norm(logical_row_index);

View file

@ -137,13 +137,13 @@
"test_sigmoid_example",
"test_sin_example",
"test_sin",
"v{7,8,9,10,11,12}/test_softmax_axis_0",
"v{7,8,9,10,11,12}/test_softmax_axis_1",
"v{7,8,9,10,11,12}/test_softmax_axis_2",
"v{7,8,9,10,11,12}/test_softmax_default_axis",
"v{7,8,9,10,11,12}/test_softmax_example",
"test_softmax_axis_0",
"test_softmax_axis_1",
"test_softmax_axis_2",
"test_softmax_default_axis",
"test_softmax_example",
{
"name": "v{7,8,9,10,11,12}/test_softmax_large_number",
"name": "test_softmax_large_number",
"condition": "^((?!iOS).)*$" // does NOT contains 'iOS': large number cannot be handled in a half_float environment
},
"test_sub_bcast",