onnxruntime/js/web/lib/onnxjs/backends/webgl/ops/instance-normalization.ts
Yulong Wang 4ebc9c3b5e
[JS] onnxruntime-web (#7394)
* add web

* add script and test

* fix lint

* add test/data/ops

* add test/data/node/ to gitignore

* modify scripts

* add onnxjs

* fix tests

* fix test-runner

* fix sourcemap

* fix onnxjs profiling

* update test list

* update README

* resolve comments

* set wasm as default backend

* rename package

* update copyright header

* do not use class "Buffer" in browser context

* revise readme
2021-04-27 00:04:25 -07:00

149 lines
5.6 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {InstanceNormalization} from '../../../ops/instance-normalization';
import {Tensor} from '../../../tensor';
import {getGlsl} from '../glsl-source';
import {WebGLInferenceHandler} from '../inference-handler';
import {Artifact, ProgramInfo, RunData, TextureLayout} from '../types';
/**
 * WebGL implementation of the InstanceNormalization operator.
 *
 * The work is split into two shader programs that run back to back:
 *  1. a reduction that, for every (indices[0], indices[1]) pair of the input, computes the mean and the
 *     variance over the remaining two dimensions;
 *  2. an element-wise pass that computes `scale * (x - mean) / sqrt(variance + epsilon) + b`.
 *
 * NOTE(review): the built programs are cached on first use and the first shader has the input dimensions
 * baked into its source, so the cache appears to assume the input shape does not change between runs.
 */
export class WebGLInstanceNormalization extends InstanceNormalization {
  /** Compiled programs, built lazily on the first call to `run` and reused afterwards. */
  protected artifacts: Artifact[];

  /**
   * Builds the two programs if they have not been built yet, runs them in order and returns the tensor
   * produced by the second one.
   *
   * @param inferenceHandler handler giving access to the session's program manager and texture helpers
   * @param inputs the tensors X, Scale and B, in that order
   * @returns a single-element array holding the normalized output tensor
   */
  run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
    if (!this.artifacts) {
      // The list is assigned before the programs are created, exactly as the build sequence requires.
      this.artifacts = [];
      const infosToBuild = this.createProgramInfos(inferenceHandler, inputs);
      for (let k = 0; k < infosToBuild.length; k++) {
        const built = inferenceHandler.session.programManager.build(infosToBuild[k]);
        this.artifacts.push(built);
      }
    }

    const cachedInfos = this.artifacts.map(artifact => artifact.programInfo);
    const runDatas = this.createRunDatas(inferenceHandler, cachedInfos, inputs);
    for (let k = 0; k < runDatas.length; k++) {
      inferenceHandler.session.programManager.run(this.artifacts[k], runDatas[k]);
    }

    // The second program writes the final result.
    return [runDatas[1].outputTextureData.tensor];
  }

  /**
   * Accepts the inputs only when the base-class check passes and the first input is 4-D.
   *
   * @param inputs the operator inputs
   * @returns `true` when the inputs can be handled by this implementation
   */
  checkInputTypes(inputs: Tensor[]): boolean {
    if (super.checkInputTypes(inputs)) {
      // currently the webgl implementation only supports 4-D input.
      return inputs[0].dims.length === 4;
    }
    return false;
  }

  /**
   * Describes the first program: for each pair of leading indices it sums the input over the last two
   * dimensions to get the mean, then sums the squared deviations to get the variance (divided by the
   * element count, i.e. not Bessel-corrected). The mean is written to the `r` component and the
   * variance to the `g` component of the result.
   *
   * @param inferenceHandler handler used to create the output texture layout
   * @param xLayout texture layout of the input X; its shape is read as four dimensions
   * @returns the program description for the mean/variance pass
   */
  createMeanAndVarianceProgramInfo(inferenceHandler: WebGLInferenceHandler, xLayout: TextureLayout): ProgramInfo {
    const shape = xLayout.shape;
    const batch = shape[0];
    const channel = shape[1];
    const height = shape[2];
    const width = shape[3];
    const elementsPerChannel = height * width;

    // Output is indexed by the two leading dimensions; the layout is requested with 4 and an unpacked
    // shape whose last dimension is four times larger (presumably one RGBA texel per entry — confirm).
    const statsShape = [batch, channel];
    const statsUnpackedShape = [batch, channel * 4];

    // The shader text below is kept flush-left so the generated source stays unchanged.
    const shaderSource = `
vec4 process(int[2] indices) {
vec4 v = vec4(0.0);
int a[4];
a[0] = indices[0];
a[1] = indices[1];
float temp = 0.0;
for(int a2=0; a2<${height}; a2++) {
a[2] = a2;
for(int a3=0; a3<${width}; a3++) {
a[3] = a3;
float x = _X(a);
temp += x;
}
}
float mean = temp / float(${elementsPerChannel});
temp = 0.0;
for(int a2=0; a2<${height}; a2++) {
a[2] = a2;
for(int a3=0; a3<${width}; a3++) {
a[3] = a3;
float x = _X(a);
temp += (x - mean) * (x - mean);
}
}
v.r = mean;
v.g = temp / float(${elementsPerChannel});
return v;
}`;

    return {
      inputLayouts: [xLayout],
      outputLayout: inferenceHandler.createTextureLayoutFromShape(statsShape, 4, statsUnpackedShape),
      samplers: ['X'],
      shaderSource,
    };
  }

  /**
   * Describes the second program: for each output element it fetches the mean and variance for the
   * element's two leading indices, fetches Scale and B at the element's second index, and evaluates
   * `scale * (x - mean) / sqrt(variance + epsilon) + b`. `epsilon` is declared as a float variable of
   * the program.
   *
   * @param inferenceHandler handler used to look up the GLSL dialect and create the output layout
   * @param xLayout texture layout of the input X; the output uses the same shape
   * @param scaleLayout texture layout of the Scale input
   * @param bLayout texture layout of the B input
   * @param meanAndVarianceLayout texture layout of the first program's output
   * @returns the program description for the normalization pass
   */
  createComputOutputProgramInfo(
      inferenceHandler: WebGLInferenceHandler, xLayout: TextureLayout, scaleLayout: TextureLayout,
      bLayout: TextureLayout, meanAndVarianceLayout: TextureLayout): ProgramInfo {
    const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
    const statsWidth = meanAndVarianceLayout.width;
    const statsHeight = meanAndVarianceLayout.height;
    const sampleFn = glsl.texture2D;

    // The shader text below is kept flush-left so the generated source stays unchanged.
    const shaderSource = `
vec4 get_MeanAndVariance(int[2] mv) {
int offset = indicesToOffset_MeanAndVariance(mv);
vec2 coords = offsetToCoords(offset, ${statsWidth}, ${statsHeight});
return ${sampleFn}(MeanAndVariance, coords);
}
float process(int[4] indices) {
int mv[2];
mv[0] = indices[0];
mv[1] = indices[1];
vec4 mean_and_variance = get_MeanAndVariance(mv);
float mean = mean_and_variance.r;
float variance = mean_and_variance.g;
int sb[1];
sb[0] = indices[1];
float scale = _Scale(sb);
float b = _B(sb);
return scale * (_X(indices) - mean) / sqrt(variance + epsilon) + b;
}`;

    return {
      inputLayouts: [xLayout, meanAndVarianceLayout, scaleLayout, bLayout],
      outputLayout: inferenceHandler.createTextureLayoutFromShape(xLayout.shape),
      samplers: ['X', 'MeanAndVariance', 'Scale', 'B'],
      variables: [{name: 'epsilon', type: 'float'}],
      shaderSource,
    };
  }

  /**
   * Creates the descriptions of both programs. The output layout of the mean/variance program is fed
   * to the normalization program as one of its input layouts.
   *
   * @param inferenceHandler handler used to obtain texture layouts for the inputs
   * @param inputs the tensors X, Scale and B, in that order
   * @returns `[meanAndVarianceProgramInfo, computeOutputProgramInfo]`
   */
  createProgramInfos(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo[] {
    const layoutOfX = inferenceHandler.getOrCreateTextureLayout(inputs[0]);
    const layoutOfScale = inferenceHandler.getOrCreateTextureLayout(inputs[1]);
    const layoutOfB = inferenceHandler.getOrCreateTextureLayout(inputs[2]);

    const statsInfo = this.createMeanAndVarianceProgramInfo(inferenceHandler, layoutOfX);
    const normalizeInfo = this.createComputOutputProgramInfo(
        inferenceHandler, layoutOfX, layoutOfScale, layoutOfB, statsInfo.outputLayout);

    return [statsInfo, normalizeInfo];
  }

  /**
   * Creates the run data for both programs. The output texture of the first run is passed as the
   * second input texture of the second run, and `epsilon` is supplied as uniform data to the second
   * run only. Output textures are created with the data type of the first input.
   *
   * @param inferenceHandler handler used to obtain and create texture data
   * @param programInfos the two program descriptions, in execution order
   * @param inputs the tensors X, Scale and B, in that order
   * @returns one `RunData` per program, in execution order
   */
  createRunDatas(inferenceHandler: WebGLInferenceHandler, programInfos: ProgramInfo[], inputs: Tensor[]): RunData[] {
    const outputType = inputs[0].type;
    const xData = inferenceHandler.getOrCreateTextureData(inputs[0], programInfos[0].inputLayouts[0]);
    const scaleData = inferenceHandler.getOrCreateTextureData(inputs[1], programInfos[1].inputLayouts[2]);
    const biasData = inferenceHandler.getOrCreateTextureData(inputs[2], programInfos[1].inputLayouts[3]);

    const statsRun: RunData = {
      inputTextureDatas: [xData],
      outputTextureData: inferenceHandler.createTextureDataFromLayout(programInfos[0].outputLayout, outputType),
      uniformData: {}
    };

    const normalizeRun: RunData = {
      // Order matches the samplers of the second program: X, MeanAndVariance, Scale, B.
      inputTextureDatas: [xData, statsRun.outputTextureData, scaleData, biasData],
      outputTextureData: inferenceHandler.createTextureDataFromLayout(programInfos[1].outputLayout, outputType),
      uniformData: {'epsilon': this.epsilon}
    };

    return [statsRun, normalizeRun];
  }
}