onnxruntime/js/web/lib/onnxjs/backends/webgl/glsl-preprocessor.ts
Yulong Wang 4ebc9c3b5e
[JS] onnxruntime-web (#7394)
* add web

* add script and test

* fix lint

* add test/data/ops

* add test/data/node/ to gitignore

* modify scripts

* add onnxjs

* fix tests

* fix test-runner

* fix sourcemap

* fix onnxjs profiling

* update test list

* update README

* resolve comments

* set wasm as default backend

* rename package

* update copyright header

* do not use class "Buffer" in browser context

* revise readme
2021-04-27 00:04:25 -07:00

129 lines
4.4 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {GlslContext, GlslLib, GlslLibRoutineNode, TopologicalSortGlslRoutines} from './glsl-definitions';
import {replaceInlines} from './glsl-function-inliner';
import {glslRegistry} from './glsl-registered-libs';
import {getDefaultFragShaderMain, getFragShaderPreamble} from './glsl-source';
import {ProgramInfo, VariableInfo} from './types';
import {WebGLContext} from './webgl-context';
/**
 * Preprocessor for the additions to the GLSL language.
 *
 * `preprocess()` turns a program's shader source into a complete fragment shader. Along the way it:
 * - appends a default `main()` when the program does not define one,
 * - expands functions marked `@inline` (delegated to `replaceInlines`),
 * - declares the program's samplers and variables as uniforms,
 * - prepends the body of every registered GLSL library routine the shader references by name,
 *   ordered by `TopologicalSortGlslRoutines` (which presumably also pulls in each routine's
 *   declared dependencies — see glsl-definitions).
 *
 * Loop unrolling and macro resolution are not implemented.
 */
export class GlslPreprocessor {
  /** Context (WebGL context + program info) handed to every GLSL library instance. */
  readonly context: GlslContext;
  /** One instance of every library in `glslRegistry`, keyed by its registry name. */
  readonly libs: {[name: string]: GlslLib} = {};
  /**
   * Routine nodes keyed by `<libName>.<routineName>`; each node is linked to the nodes named in
   * its routine's `dependencies` list.
   */
  readonly glslLibRoutineDependencyGraph: {[routineName: string]: GlslLibRoutineNode} = {};

  /**
   * @param glContext - WebGL context the shader is generated for.
   * @param programInfo - Program whose shader source will be preprocessed.
   */
  constructor(glContext: WebGLContext, programInfo: ProgramInfo) {
    this.context = new GlslContext(glContext, programInfo);

    // Instantiate every registered GLSL library against this program's context.
    Object.keys(glslRegistry).forEach((name: string) => {
      this.libs[name] = new glslRegistry[name](this.context);
    });

    // Build the routine dependency graph. A routine can be named as a dependency before the
    // library defining it is visited; in that case a body-less placeholder node is created first
    // and its body is filled in once the defining library is processed.
    const map = this.glslLibRoutineDependencyGraph;
    for (const libName in this.libs) {
      const routinesInLib = this.libs[libName].getFunctions();
      for (const routine in routinesInLib) {
        const key = libName + '.' + routine;
        const definition = routinesInLib[routine];

        let currentNode = map[key];
        if (currentNode) {
          currentNode.routineBody = definition.routineBody;
        } else {
          currentNode = new GlslLibRoutineNode(key, definition.routineBody);
          map[key] = currentNode;
        }

        if (definition.dependencies) {
          for (const dependency of definition.dependencies) {
            let dependencyNode = map[dependency];
            if (!dependencyNode) {
              dependencyNode = new GlslLibRoutineNode(dependency);
              map[dependency] = dependencyNode;
            }
            currentNode.addDependency(dependencyNode);
          }
        }
      }
    }
  }

  /**
   * Produces the final fragment shader source.
   *
   * @returns The newline-separated concatenation of: the fragment shader preamble, the uniform
   *   declarations, the referenced library routines, and the program source (with a default
   *   `main()` appended when `programInfo.hasMain` is false, and `@inline` functions expanded).
   * @throws Error if a referenced library routine has no body (see `getImports`).
   */
  preprocess(): string {
    const programInfo = this.context.programInfo;
    let source = programInfo.shaderSource;

    if (!programInfo.hasMain) {
      source = `${source}\n${
          getDefaultFragShaderMain(this.context.glContext.version, programInfo.outputLayout.shape.length)}`;
    }

    source = replaceInlines(source);

    return [
      getFragShaderPreamble(this.context.glContext.version),
      this.getUniforms(programInfo.samplers, programInfo.variables),
      this.getImports(source),
      source,
    ].join('\n');
  }

  /**
   * Collects the bodies of the library routines that `script` references.
   *
   * @param script - Shader source to scan for routine names.
   * @returns The routine bodies, each followed by `'\n'`; `''` when nothing is referenced.
   * @throws Error when a selected routine has no body — e.g. a placeholder node that was only
   *   created because it appeared in some routine's dependency list and no library defines it.
   */
  protected getImports(script: string): string {
    let routines = '';
    for (const node of this.selectGlslLibRoutinesToBeIncluded(script)) {
      if (!node.routineBody) {
        throw new Error(`Missing body for the Glsl Library routine: ${node.name}`);
      }
      routines += node.routineBody + '\n';
    }
    return routines;
  }

  /**
   * Selects the graph nodes whose routine name occurs in `script` as a whole identifier, then
   * orders them with `TopologicalSortGlslRoutines.returnOrderedNodes`.
   */
  private selectGlslLibRoutinesToBeIncluded(script: string): GlslLibRoutineNode[] {
    const nodes: GlslLibRoutineNode[] = [];
    Object.keys(this.glslLibRoutineDependencyGraph).forEach(classAndRoutine => {
      const routine = classAndRoutine.split('.')[1];
      // Whole-identifier match: a plain substring test would select `foo` when the shader only
      // uses `fooBar`, prepending code that may reference symbols this program never declares.
      if (routine && GlslPreprocessor.containsIdentifier(script, routine)) {
        nodes.push(this.glslLibRoutineDependencyGraph[classAndRoutine]);
      }
    });
    return TopologicalSortGlslRoutines.returnOrderedNodes(nodes);
  }

  /**
   * Returns true when `identifier` occurs in `script` as a whole token, i.e. an occurrence that is
   * neither preceded nor followed by a letter, digit or underscore (the characters that make up
   * GLSL identifiers). Implemented without regex lookbehind, which older browsers lack.
   */
  private static containsIdentifier(script: string, identifier: string): boolean {
    const isIdentifierChar = (ch: string|undefined): boolean => ch !== undefined && /\w/.test(ch);
    for (let at = script.indexOf(identifier); at !== -1; at = script.indexOf(identifier, at + 1)) {
      if (!isIdentifierChar(script[at - 1]) && !isIdentifierChar(script[at + identifier.length])) {
        return true;
      }
    }
    return false;
  }

  /**
   * Emits one `uniform` declaration per line: `sampler2D` for each sampler, then each variable
   * with its declared type (and `[arrayLength]` suffix when `arrayLength` is truthy).
   */
  protected getUniforms(samplers?: string[], variables?: VariableInfo[]): string {
    const uniformLines: string[] = [];
    if (samplers) {
      for (const sampler of samplers) {
        uniformLines.push(`uniform sampler2D ${sampler};`);
      }
    }
    if (variables) {
      for (const variable of variables) {
        const arraySuffix = variable.arrayLength ? `[${variable.arrayLength}]` : '';
        uniformLines.push(`uniform ${variable.type} ${variable.name}${arraySuffix};`);
      }
    }
    return uniformLines.join('\n');
  }
}