mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
### Description This PR adds `BatchNormalization` with `float` support. Remaining TODOs: 1. support inputs that do not all share the same data type (for example, x/y as float16 while bias/scale is float32 or double); 2. training mode support. Many models use `BatchNormalization` ops, but because the op was missing from JSEP they all ran on the CPU, which results in very poor performance. With this PR's support, the densenet-9 model improves from 250.69 ms to 20.29 ms.
150 lines
6.8 KiB
TypeScript
150 lines
6.8 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {env} from 'onnxruntime-common';
|
|
|
|
import {TensorView} from '../../tensor-view';
|
|
import {ShapeUtil} from '../../util';
|
|
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
|
import {ComputeContext, ProgramInfo} from '../types';
|
|
|
|
import {createTensorShapeVariables, enableShapesUniforms, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common';
|
|
|
|
/**
 * Attributes for the BatchNormalization operator (plus a cache key inherited from
 * AttributeWithCacheKey so compiled shaders can be reused).
 */
export interface BatchNormAttributes extends AttributeWithCacheKey {
  // Small constant added to the variance before sqrt to avoid division by zero.
  readonly epsilon: number;
  // Running-statistics momentum; only meaningful in training mode, which is currently rejected.
  readonly momentum: number;
  // Opset < 9 attribute: when false, scale/B/mean/var cover C x D1 x ... x Dn rather than just C
  // (see the shape checks in validateInputs).
  readonly spatial: boolean;
  // Training mode is not supported yet; batchNorm throws when this is true.
  readonly trainingMode: boolean;
  // Data layout of the input/output tensors.
  readonly format: 'NHWC'|'NCHW';
  // Number of outputs requested by the node (taken from ComputeContext.outputCount).
  readonly outputCount: number;
}
|
|
|
|
const validateInputs = (inputs: readonly TensorView[], attributes: BatchNormAttributes): void => {
|
|
if (!inputs || inputs.length !== 5) {
|
|
throw new Error('BatchNormalization requires 5 inputs');
|
|
}
|
|
|
|
const checkShapeEqual = (actual: readonly number[], expected: readonly number[], message: string) => {
|
|
const r = expected.length;
|
|
if (r !== actual.length) {
|
|
throw new Error(`${message}: num dimensions != ${r}`);
|
|
}
|
|
expected.forEach((v, i) => {
|
|
if (v !== actual[i]) {
|
|
throw new Error(`${message}: dim[${i}] do not match`);
|
|
}
|
|
});
|
|
};
|
|
|
|
if (inputs[0].dims.length > 1) {
|
|
const shape = attributes.format === 'NHWC' ?
|
|
(attributes.spatial ? inputs[0].dims.slice(-1) :
|
|
inputs[0].dims.slice(-1).concat(inputs[0].dims.slice(1, inputs[0].dims.length - 1))) :
|
|
inputs[0].dims.slice(1, attributes.spatial ? 2 : undefined);
|
|
checkShapeEqual(inputs[1].dims, shape, 'Invalid input scale');
|
|
checkShapeEqual(inputs[2].dims, shape, 'Invalid input B');
|
|
checkShapeEqual(inputs[3].dims, shape, 'Invalid input mean');
|
|
checkShapeEqual(inputs[4].dims, shape, 'Invalid input var');
|
|
} else {
|
|
checkShapeEqual(inputs[1].dims, [1], 'Invalid input scale');
|
|
checkShapeEqual(inputs[2].dims, [1], 'Invalid input B');
|
|
checkShapeEqual(inputs[3].dims, [1], 'Invalid input mean');
|
|
checkShapeEqual(inputs[4].dims, [1], 'Invalid input var');
|
|
}
|
|
};
|
|
|
|
/**
 * Builds the WebGPU ProgramInfo for BatchNormalization in inference mode (training mode is
 * rejected by the caller). The generated WGSL computes, per element:
 *     y = (x - inputMean) / sqrt(inputVar + epsilon) * scale + bias
 * where scale/bias/inputMean/inputVar are looked up at a per-channel offset (`cOffset`) derived
 * from the output element's indices according to `format` and `spatial`.
 */
const createBatchNormInferenceProgramInfo =
    (inputs: readonly TensorView[], attributes: BatchNormAttributes): ProgramInfo => {
      const {epsilon, spatial, format} = attributes;
      const yShape = inputs[0].dims;
      // Vectorization width: only safe in spatial mode, where the channel dim is contiguous.
      const components = spatial ? getMaxComponents(yShape[yShape.length - 1]) : 1;
      // The per-channel tensors are only vectorized when the layout is NHWC (channels last).
      const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1;
      // Number of (vectorized) output elements; one invocation handles one element.
      const outputSize = ShapeUtil.size(yShape) / components;
      // Only support uniforms for opset version >= 9 (spatial = true).
      const useShapesUniforms = enableShapesUniforms(yShape.length) && spatial;
      // With shape uniforms the variable is declared by rank; otherwise shapes are baked in.
      const shapeOrRank = useShapesUniforms ? yShape.length : yShape;
      const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components);
      const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents);
      const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims, cComponents);
      const inputMean = inputVariable('inputMean', inputs[3].dataType, inputs[3].dims, cComponents);
      const inputVar = inputVariable('inputVar', inputs[4].dataType, inputs[4].dims, cComponents);
      const y = outputVariable('y', inputs[0].dataType, shapeOrRank, components);
      // TODO: support inputs with different data type. Current we need to make sure all inputs have the same data type.
      // Otherwise, the shader compilation will fail.
      // Emits WGSL that computes `cOffset`, the offset into scale/bias/mean/var for the current
      // output element's channel(s).
      const calcCOffset = (): string => {
        let cOffset = '';
        if (spatial) {
          // Spatial: stats are indexed by channel only. NHWC divides by `components` because the
          // channel dimension itself is vectorized.
          cOffset = `let cOffset = ${
              yShape.length === 1 ? '0u' :
                                    format === 'NHWC' ? `outputIndices[${yShape.length - 1}] / ${components}` :
                                                        'outputIndices[1]'};`;
        } else {
          if (format === 'NCHW') {
            // Non-spatial NCHW: zero out the batch index, then reuse the output's own
            // indices-to-offset mapping (stats shape is C x D1 x ... x Dn).
            cOffset = `
            ${y.indicesSet('outputIndices', '0', '0')}
            let cOffset = ${y.indicesToOffset('outputIndices')};`;
          } else {
            // Non-spatial NHWC: build indices into the stats tensor explicitly.
            // update C channel.
            cOffset = `var cIndices = ${scale.type.indices}(0);
            cIndices[0] = outputIndices[${yShape.length - 1}];`;
            // update D1 x ... x Dn channels.
            for (let i = 1; i < scale.rank; i++) {
              cOffset += `cIndices[${i}] = outputIndices[${i}];`;
            }
            cOffset += `let cOffset = ${scale.indicesToOffset('cIndices')};`;
          }
        }
        return cOffset;
      };
      // Assembles the full inference shader: uniform/variable declarations, bounds guard,
      // per-channel stat loads, then the normalization formula.
      const getInferenceModeShaderSource = (helper: ShaderHelper) => `
      const epsilon = ${epsilon};
      ${helper.registerUniform('outputSize', 'u32').declareVariables(x, scale, bias, inputMean, inputVar, y)}
      ${helper.mainStart()}
      ${helper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
      var outputIndices = ${y.offsetToIndices(`global_idx * ${components}`)};
      ${calcCOffset()}
      let scale = ${scale.getByOffset('cOffset')};
      let bias = ${bias.getByOffset('cOffset')};
      let inputMean = ${inputMean.getByOffset('cOffset')};
      let inputVar = ${inputVar.getByOffset('cOffset')};
      let x = ${x.getByOffset('global_idx')};
      let value = (x - inputMean) / sqrt(inputVar + epsilon) * scale + bias;
      ${y.setByOffset('global_idx', 'value')}
      }`;
      return {
        name: 'BatchNormalization',
        shaderCache: {
          // epsilon/format/spatial/components all affect the generated WGSL, so they key the cache.
          hint: `${attributes.epsilon}_${attributes.format}_${spatial}_${components}`,
          inputDependencies: useShapesUniforms ? ['rank', 'type', 'type', 'type', 'type'] : undefined,
        },
        getShaderSource: getInferenceModeShaderSource,
        getRunData: () => ({
          outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}],
          dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
          // Shape uniforms (when enabled) follow the outputSize uniform.
          programUniforms: useShapesUniforms ?
              [
                {type: 'uint32', data: outputSize},
                ...createTensorShapeVariables(yShape),
              ] :
              [
                {type: 'uint32', data: outputSize},
              ],
        }),
      };
    };
|
|
|
|
export const parseBatchNormAttributes = (attributes: Record<string, unknown>): BatchNormAttributes =>
|
|
createAttributeWithCacheKey(attributes as Omit<BatchNormAttributes, keyof AttributeWithCacheKey>);
|
|
|
|
export const batchNorm = (context: ComputeContext, attributes: Record<string, unknown>): void => {
|
|
const {inputs, outputCount} = context;
|
|
const updatedAttributes = parseBatchNormAttributes({...attributes, outputCount});
|
|
if (env.webgpu.validateInputContent) {
|
|
validateInputs(inputs, updatedAttributes);
|
|
}
|
|
if (attributes.trainingMode) {
|
|
throw new Error('BatchNormalization trainingMode is not supported yet.');
|
|
} else {
|
|
context.compute(createBatchNormInferenceProgramInfo(inputs, updatedAttributes));
|
|
}
|
|
};
|