2021-04-27 07:04:25 +00:00
|
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
|
// Licensed under the MIT License.
|
|
|
|
|
|
2024-08-14 23:51:22 +00:00
|
|
|
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../../../attribute-with-cache-key';
|
|
|
|
|
import { Graph } from '../../../graph';
|
|
|
|
|
import { OperatorImplementation, OperatorInitialization } from '../../../operators';
|
|
|
|
|
import { Tensor } from '../../../tensor';
|
|
|
|
|
import { getGlsl } from '../glsl-source';
|
|
|
|
|
import { WebGLInferenceHandler } from '../inference-handler';
|
|
|
|
|
import { ProgramInfo, TextureType } from '../types';
|
2021-04-27 07:04:25 +00:00
|
|
|
|
2021-08-12 19:30:49 +00:00
|
|
|
/**
 * Attributes of an ONNX BatchNormalization node, as produced by `parseBatchNormalizationAttributes`.
 *
 * The `cacheKey` inherited from `AttributeWithCacheKey` is what `batchNormalization` passes to the
 * inference handler as the program cache hint.
 */
export interface BatchNormalizationAttributes extends AttributeWithCacheKey {
  /**
   * Added to the variance before the square root in the generated shader.
   * Parsed from the node's `epsilon` attribute (default 1e-5).
   */
  epsilon: number;

  /**
   * Parsed from the node's `momentum` attribute (default 0.9).
   * The shader generated in this module does not read it.
   */
  momentum: number;

  /**
   * Parsed from the node's `spatial` attribute (default 1).
   * The shader generated in this module does not read it.
   * NOTE(review): `spatial = 0` is not handled specially; confirm this is intended.
   */
  spatial: number;
}
|
|
|
|
|
|
|
|
|
|
/**
 * Program metadata shared by the loader built in `batchNormalization` and by the ProgramInfo that
 * `createBatchNormalizationProgramInfo` returns.
 *
 * Input names follow the operator's input order [X, scale, B, mean, var]. The shader refers to them
 * as `_A(...)`, `Scale`, `B`, `Mean` and `Variance`. Every input is bound as an unpacked texture.
 */
const batchNormalizationProgramMetadata = (() => {
  const inputNames = ['A', 'Scale', 'B', 'Mean', 'Variance'];
  return {
    name: 'BatchNormalization',
    inputNames,
    // One unpacked texture type per input name, so the two lists always have the same length.
    inputTypes: inputNames.map(() => TextureType.unpacked),
  };
})();
|
|
|
|
|
|
2024-08-14 23:51:22 +00:00
|
|
|
/**
 * WebGL kernel for ONNX BatchNormalization.
 *
 * Normalizes X with the supplied per-channel mean/variance, then applies scale and bias.
 * It validates the five inputs and then hands a program loader to the inference handler. The loader
 * uses `attributes.cacheKey` as its cache hint.
 *
 * @param inferenceHandler - WebGL handler that executes the program.
 * @param inputs - The five operator inputs, in order [X, Scale, B, Mean, Variance].
 * @param attributes - Parsed node attributes (see `parseBatchNormalizationAttributes`).
 * @returns A one-element array holding the normalized output tensor.
 * @throws Error when the inputs fail `validateInputs`.
 */
export const batchNormalization: OperatorImplementation<BatchNormalizationAttributes> = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: BatchNormalizationAttributes,
): Tensor[] => {
  validateInputs(inputs);

  // `get` is only a factory for the ProgramInfo. Whether and when it is invoked is up to
  // inferenceHandler.run (presumably it is skipped when a program for this cacheHint already exists).
  const programInfoLoader = {
    ...batchNormalizationProgramMetadata,
    cacheHint: attributes.cacheKey,
    get: () => createBatchNormalizationProgramInfo(inferenceHandler, inputs, attributes),
  };

  return [inferenceHandler.run(programInfoLoader, inputs)];
};
|
2021-08-12 19:30:49 +00:00
|
|
|
|
2024-08-14 23:51:22 +00:00
|
|
|
/**
 * Reads the BatchNormalization attributes from an ONNX graph node.
 *
 * Defaults when an attribute is absent: epsilon = 1e-5, momentum = 0.9, spatial = 1.
 * The values are wrapped with `createAttributeWithCacheKey`. The resulting `cacheKey` is what
 * `batchNormalization` uses as its program cache hint.
 *
 * @param node - Graph node carrying the operator's attributes.
 * @returns The parsed attributes, with a cache key attached.
 */
export const parseBatchNormalizationAttributes: OperatorInitialization<BatchNormalizationAttributes> = (
  node: Graph.Node,
): BatchNormalizationAttributes => {
  const attrs = node.attributes;
  return createAttributeWithCacheKey({
    epsilon: attrs.getFloat('epsilon', 1e-5),
    momentum: attrs.getFloat('momentum', 0.9),
    spatial: attrs.getInt('spatial', 1),
  });
};
|
2021-08-12 19:30:49 +00:00
|
|
|
|
2024-08-14 23:51:22 +00:00
|
|
|
/**
 * Builds the WebGL ProgramInfo that computes, for every output element,
 *
 *   y = scale[c] * (x - mean[c]) / sqrt(variance[c] + epsilon) + B[c]
 *
 * where `c = indices[1]` is the element's channel index.
 *
 * @param inferenceHandler - Supplies the GL context version and the texture layout helper.
 * @param inputs - [X, Scale, B, Mean, Variance]. Assumed to have passed `validateInputs`.
 * @param attributes - Only `epsilon` is baked into the generated shader source.
 * @returns ProgramInfo whose output has X's dims and type, stored in an unpacked texture.
 */
const createBatchNormalizationProgramInfo = (
  inferenceHandler: WebGLInferenceHandler,
  inputs: Tensor[],
  attributes: BatchNormalizationAttributes,
): ProgramInfo => {
  const x = inputs[0];
  const scaleTensor = inputs[1];
  const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);

  // A single texture coordinate, derived from Scale's layout, is used to read all four per-channel
  // textures. This assumes Scale, B, Mean and Variance share one layout. validateInputs checks they
  // are 1-D tensors of the same length.
  const [scaleWidth, scaleHeight] = inferenceHandler.calculateTextureWidthAndHeight(
    scaleTensor.dims,
    TextureType.unpacked,
  );

  // GLSL expression that reads one float from the named per-channel texture at `position`.
  const sampleAt = (textureName: string): string => `getColorAsFloat(${glsl.texture2D}(${textureName}, position))`;

  const shaderSource = `
  float process(int[${x.dims.length}] indices) {
    vec2 position = offsetToCoords(indices[1], ${scaleWidth}, ${scaleHeight});
    float scale = ${sampleAt('Scale')};
    float mean = ${sampleAt('Mean')};
    float variance = ${sampleAt('Variance')};
    float b = ${sampleAt('B')};

    return scale * ( (_A(indices) - mean) / sqrt(variance + float(${attributes.epsilon})) ) + b;
  }`;

  return {
    ...batchNormalizationProgramMetadata,
    output: { dims: x.dims, type: x.type, textureType: TextureType.unpacked },
    shaderSource,
  };
};
|
2021-08-12 19:30:49 +00:00
|
|
|
|
|
|
|
|
/**
 * Checks the five BatchNormalization inputs [X, scale, B, mean, var] before a program is built.
 *
 * Requirements, checked in this order:
 * 1. exactly five inputs are supplied;
 * 2. X has at least three dimensions (N, C, D1, ..., Dn) and scale/B/mean/var are 1-D;
 * 3. scale/B/mean/var each have length C, where C = X.dims[1];
 * 4. every input is float32 or float64.
 *
 * @param inputs - The operator inputs, in order [X, scale, B, mean, var].
 * @throws Error with a message identifying which requirement failed.
 */
const validateInputs = (inputs: Tensor[]): void => {
  if (!inputs || inputs.length !== 5) {
    throw new Error('BatchNormalization requires 5 inputs.');
  }

  const [x, ...perChannel] = inputs;

  // X must have at least three dimensions (N, C, dim1, ..., dimn); the other inputs must be vectors.
  const ranksOk = x.dims.length >= 3 && perChannel.every((t) => t.dims.length === 1);
  if (!ranksOk) {
    throw new Error('invalid input shape.');
  }

  // Each per-channel vector must hold exactly one entry per channel of X.
  const channels = x.dims[1];
  if (perChannel.some((t) => t.dims[0] !== channels)) {
    throw new Error('invalid input shape.');
  }

  const allFloat = inputs.every((t) => t.type === 'float32' || t.type === 'float64');
  if (!allFloat) {
    throw new Error('invalid input tensor types.');
  }
};
|