// Mirror of https://github.com/saymrwulf/onnxruntime.git, synced 2026-05-18 21:21:17 +00:00.
// At sync time this file was 156 lines / 5.8 KiB of TypeScript.
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {Graph} from '../../../graph';
|
|
import {OperatorImplementation, OperatorInitialization} from '../../../operators';
|
|
import {Tensor} from '../../../tensor';
|
|
import {getGlsl} from '../glsl-source';
|
|
import {WebGLInferenceHandler} from '../inference-handler';
|
|
import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types';
|
|
|
|
/**
 * WebGL implementation of the InstanceNormalization operator.
 *
 * Two shader programs are run back to back:
 *  1. a statistics pass over `inputs[0]` that yields one mean/variance pair per (batch, channel);
 *  2. an output pass that combines the input, those statistics, the scale (`inputs[1]`) and the
 *     bias (`inputs[2]`) into the normalized result.
 *
 * @param inferenceHandler handler used to run both shader programs
 * @param inputs the three operator inputs in the order X, scale, B (checked by `validateInputs`)
 * @param epsilon value added to the variance before taking the square root
 * @returns a one-element array holding the normalized tensor
 * @throws Error if `validateInputs` rejects the inputs
 */
export const instanceNormalization: OperatorImplementation<number> =
    (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], epsilon: number): Tensor[] => {
      validateInputs(inputs);

      const [x, scale, bias] = inputs;

      // Statistics pass. The complete input list is forwarded, although the program's metadata
      // only declares the single input 'X'.
      const meanAndVariance = inferenceHandler.run(createMeanAndVarianceProgramInfoLoader(x), inputs);

      // Output pass. The tensor order must match the input names declared for this program:
      // X, MeanAndVariance, Scale, B.
      const outputLoader = createComputeOutputProgramInfoLoader(inferenceHandler, x, epsilon, meanAndVariance.dims);
      const output = inferenceHandler.run(outputLoader, [x, meanAndVariance, scale, bias]);

      return [output];
    };
|
|
|
|
/**
 * Reads the operator's only attribute from the graph node.
 *
 * @param node graph node describing the InstanceNormalization operator
 * @returns the node's float attribute 'epsilon'; 1e-5 is passed to `getFloat` as the default value
 */
export const parseInstanceNormalizationAttributes: OperatorInitialization<number> = (node: Graph.Node): number =>
    node.attributes.getFloat('epsilon', 1e-5);
|
|
|
|
/**
 * Static metadata of the statistics pass: a single input named 'X', read as an unpacked texture.
 */
const meanAndVarianceProgramMetadata = {
  name: 'InstanceNormalization_MeanAndVariance',
  inputNames: ['X'],
  inputTypes: [TextureType.unpacked],
};
|
|
|
|
/**
 * Builds the program that computes, for every (batch, channel) pair, the mean and the variance
 * of `input` over its last two dimensions.
 *
 * The generated shader walks dimensions 2 and 3 twice: once to accumulate the sum (for the mean)
 * and once to accumulate the squared deviations. Both sums are divided by dims[2] * dims[3], so
 * the variance is the population variance. The mean is written to the `r` component and the
 * variance to the `g` component of the output texel; `b` and `a` stay 0.
 *
 * `_X` is not defined in this file; it is presumably the input accessor generated by the WebGL
 * backend for the input named 'X' (confirm against the GLSL preprocessor).
 *
 * @param metadata program metadata that is copied into the returned info
 * @param input the tensor to take statistics of; indices 0..3 of `dims` are read
 * @returns program info whose output has shape [dims[0], dims[1]], the input's type, and the
 *     `packedLastDimension` texture type
 */
const createMeanAndVarianceProgramInfo = (metadata: ProgramMetadata, input: Tensor): ProgramInfo => {
  const [batch, channel, height, width] = input.dims;
  const channelSize = height * width;
  const outputShape = [batch, channel];

  // The loop bounds and the divisor are baked into the shader text as literals.
  const shaderSource = `
      vec4 process(int[2] indices) {
        vec4 v = vec4(0.0);
        int a[4];
        a[0] = indices[0];
        a[1] = indices[1];
        float temp = 0.0;
        for(int a2=0; a2<${height}; a2++) {
          a[2] = a2;
          for(int a3=0; a3<${width}; a3++) {
            a[3] = a3;
            float x = _X(a);
            temp += x;
          }
        }
        float mean = temp / float(${channelSize});
        temp = 0.0;
        for(int a2=0; a2<${height}; a2++) {
          a[2] = a2;
          for(int a3=0; a3<${width}; a3++) {
            a[3] = a3;
            float x = _X(a);
            temp += (x - mean) * (x - mean);
          }
        }
        v.r = mean;
        v.g = temp / float(${channelSize});

        return v;
      }`;

  return {
    ...metadata,
    output: {dims: outputShape, type: input.type, textureType: TextureType.packedLastDimension},
    shaderSource
  };
};
|
|
|
|
/**
 * Wraps the statistics pass in a lazy loader.
 *
 * The returned object carries the static metadata of the pass plus a `get` callback, so the
 * program info (including its shader text) is only built when `get` is invoked.
 *
 * @param input the tensor the statistics program will be generated for
 * @returns loader exposing the program metadata and a `get()` that builds the program info
 */
const createMeanAndVarianceProgramInfoLoader = (input: Tensor): ProgramInfoLoader => ({
  ...meanAndVarianceProgramMetadata,
  get: () => createMeanAndVarianceProgramInfo(meanAndVarianceProgramMetadata, input)
});
|
|
|
|
/**
 * Static metadata of the output pass. The four inputs are, in order: the original input 'X',
 * the statistics tensor 'MeanAndVariance' (read as a `packedLastDimension` texture), 'Scale'
 * and 'B' (all others read as unpacked textures).
 */
const computeOutputProgramMetadata = {
  name: 'InstanceNormalization_ComputeOutput',
  inputNames: ['X', 'MeanAndVariance', 'Scale', 'B'],
  inputTypes: [TextureType.unpacked, TextureType.packedLastDimension, TextureType.unpacked, TextureType.unpacked],
};
|
|
|
|
/**
 * Builds the program that produces the normalized output:
 *
 *   scale[c] * (X[n, c, h, w] - mean[n, c]) / sqrt(variance[n, c] + epsilon) + B[c]
 *
 * The mean and variance are fetched from the statistics texture through a hand-written
 * `get_MeanAndVariance` helper that samples a whole texel and reads `r` (mean) and `g`
 * (variance). `epsilon` is supplied to the shader as a float variable rather than being baked
 * into the shader text.
 *
 * `_X`, `_Scale`, `_B`, `indicesToOffset_MeanAndVariance` and `offsetToCoords` are not defined
 * in this file; they are presumably provided by the WebGL backend's GLSL preprocessing
 * (confirm there).
 *
 * @param inferenceHandler used to look up the GL context version and the texture layout of the
 *     statistics tensor
 * @param metadata program metadata that is copied into the returned info
 * @param input the operator's first input; the output reuses its dims and type
 * @param epsilon value added to the variance before taking the square root
 * @param meanAndVarianceShape shape of the statistics tensor
 * @returns program info with an unpacked output of the same dims and type as `input`
 */
const createComputeOutputProgramInfo =
    (inferenceHandler: WebGLInferenceHandler, metadata: ProgramMetadata, input: Tensor, epsilon: number,
     meanAndVarianceShape: readonly number[]): ProgramInfo => {
      const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
      const [textureWidth, textureHeight] =
          inferenceHandler.calculateTextureWidthAndHeight(meanAndVarianceShape, TextureType.packedLastDimension);
      // NOTE(review): the reported width is divided by 4, presumably because it counts values
      // while every texel of a packedLastDimension texture holds four of them - confirm against
      // calculateTextureWidthAndHeight.
      const meanAndVarianceWidth = textureWidth / 4;
      const meanAndVarianceHeight = textureHeight;

      const shaderSource = `
      vec4 get_MeanAndVariance(int[2] mv) {
        int offset = indicesToOffset_MeanAndVariance(mv);
        vec2 coords = offsetToCoords(offset, ${meanAndVarianceWidth}, ${meanAndVarianceHeight});
        return ${glsl.texture2D}(MeanAndVariance, coords);
      }

      float process(int[4] indices) {
        int mv[2];
        mv[0] = indices[0];
        mv[1] = indices[1];
        vec4 mean_and_variance = get_MeanAndVariance(mv);
        float mean = mean_and_variance.r;
        float variance = mean_and_variance.g;

        int sb[1];
        sb[0] = indices[1];
        float scale = _Scale(sb);
        float b = _B(sb);

        return scale * (_X(indices) - mean) / sqrt(variance + epsilon) + b;
      }`;

      return {
        ...metadata,
        output: {dims: input.dims, type: input.type, textureType: TextureType.unpacked},
        variables: [{name: 'epsilon', type: 'float', data: epsilon}],
        shaderSource
      };
    };
|
|
|
|
/**
 * Wraps the output pass in a lazy loader.
 *
 * The static metadata of the pass is extended with a `cacheHint` holding `epsilon` rendered as
 * a string; the same extended metadata is exposed on the loader and handed to the program
 * builder when `get` is invoked.
 *
 * @param inferenceHandler forwarded to the program builder
 * @param input the operator's first input, forwarded to the program builder
 * @param epsilon value added to the variance before taking the square root
 * @param meanAndVarianceShape shape of the statistics tensor
 * @returns loader exposing the program metadata and a `get()` that builds the program info
 */
const createComputeOutputProgramInfoLoader =
    (inferenceHandler: WebGLInferenceHandler, input: Tensor, epsilon: number, meanAndVarianceShape: readonly number[]):
        ProgramInfoLoader => {
          const metadata = {...computeOutputProgramMetadata, cacheHint: `${epsilon}`};
          return {
            ...metadata,
            get: () => createComputeOutputProgramInfo(inferenceHandler, metadata, input, epsilon, meanAndVarianceShape)
          };
        };
|
|
|
|
/**
 * Checks that the operator received inputs it can handle. The checks run in this order and the
 * first failing one throws:
 *
 *  1. exactly three inputs (X, scale, B) are present;
 *  2. X has at least three dimensions and scale/B are one-dimensional;
 *  3. the length of scale and of B equals X.dims[1];
 *  4. all three tensors are 'float32' or 'float64';
 *  5. X is exactly four-dimensional.
 *
 * @param inputs the operator inputs in the order X, scale, B
 * @throws Error with a message naming the failed check
 */
const validateInputs = (inputs: Tensor[]): void => {
  if (!inputs || inputs.length !== 3) {
    throw new Error('InstanceNormalization requires 3 inputs.');
  }

  const [X, scale, B] = inputs;

  // X must look like N,C,dim1,...,dimn; scale and B hold one value per channel.
  const ranksOk = X.dims.length >= 3 && scale.dims.length === 1 && B.dims.length === 1;
  if (!ranksOk) {
    throw new Error('Invalid input shape.');
  }

  const channels = X.dims[1];
  if (scale.dims[0] !== channels || B.dims[0] !== channels) {
    throw new Error('Input shapes are mismatched.');
  }

  const isFloat = (tensor: Tensor): boolean => tensor.type === 'float32' || tensor.type === 'float64';
  if (!isFloat(X) || !isFloat(scale) || !isFloat(B)) {
    throw new Error('Invalid input type.');
  }

  // The shaders in this file index exactly four dimensions, so anything else is rejected here.
  if (X.dims.length !== 4) {
    throw new Error('Only support 4-D input shape.');
  }
};
|