mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-19 02:03:52 +00:00
[JS] onnxruntime-web (#7394)
* add web * add script and test * fix lint * add test/data/ops * add test/data/node/ to gitignore * modify scripts * add onnxjs * fix tests * fix test-runner * fix sourcemap * fix onnxjs profiling * update test list * update README * resolve comments * set wasm as default backend * rename package * update copyright header * do not use class "Buffer" in browser context * revise readme
This commit is contained in:
parent
d13e5b2fd9
commit
4ebc9c3b5e
191 changed files with 129412 additions and 43 deletions
|
|
@ -108,13 +108,14 @@ module.exports = {
|
|||
'jsdoc/newline-after-description': 'error',
|
||||
}
|
||||
}, {
|
||||
files: ['node/script/**/*.ts', 'node/test/**/*.ts'], rules: {
|
||||
files: ['node/script/**/*.ts', 'node/test/**/*.ts', 'web/script/**/*.ts', 'web/test/**/*.ts'], rules: {
|
||||
'@typescript-eslint/naming-convention': 'off',
|
||||
'@typescript-eslint/no-empty-function': 'off',
|
||||
'@typescript-eslint/no-explicit-any': 'off',
|
||||
'@typescript-eslint/no-require-imports': 'off',
|
||||
'@typescript-eslint/no-var-requires': 'off',
|
||||
'@typescript-eslint/no-non-null-assertion': 'off',
|
||||
'@typescript-eslint/no-unnecessary-type-assertion': 'off',
|
||||
'camelcase': 'off',
|
||||
'prefer-arrow/prefer-arrow-functions': 'off',
|
||||
'import/no-extraneous-dependencies': 'off',
|
||||
|
|
@ -124,6 +125,28 @@ module.exports = {
|
|||
'no-empty': 'off',
|
||||
'no-unused-expressions': 'off',
|
||||
}
|
||||
}, {
|
||||
files: ['web/lib/**/*.ts'], rules: {
|
||||
'no-underscore-dangle': 'off',
|
||||
}
|
||||
}, {
|
||||
files: ['web/lib/onnxjs/**/*.ts'], rules: {
|
||||
// TODO: those rules are useful. should turn on them in future (webgl refactor)
|
||||
'@typescript-eslint/no-empty-function': 'off',
|
||||
'@typescript-eslint/explicit-module-boundary-types': 'off',
|
||||
'@typescript-eslint/no-non-null-assertion': 'off',
|
||||
'@typescript-eslint/no-use-before-define': 'off',
|
||||
'@typescript-eslint/no-unnecessary-type-assertion': 'off',
|
||||
'@typescript-eslint/restrict-plus-operands': 'off',
|
||||
'import/no-internal-modules': 'off',
|
||||
'prefer-arrow/prefer-arrow-functions': 'off',
|
||||
'no-param-reassign': 'off',
|
||||
'guard-for-in': 'off'
|
||||
}
|
||||
}, {
|
||||
files: ['web/lib/wasm/binding/**/*.ts'], rules: {
|
||||
'@typescript-eslint/naming-convention': 'off'
|
||||
}
|
||||
}],
|
||||
extends: [
|
||||
'eslint:recommended',
|
||||
|
|
|
|||
34
js/.vscode/launch.json
vendored
Normal file
34
js/.vscode/launch.json
vendored
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"type": "chrome",
|
||||
"request": "attach",
|
||||
"name": "Attach to Chrome",
|
||||
"port": 9333,
|
||||
"webRoot": "${workspaceFolder}",
|
||||
"sourceMapPathOverrides": {
|
||||
"webpack://ort/*": "${webRoot}/common/*",
|
||||
"webpack:///*": "${webRoot}/web/*"
|
||||
},
|
||||
"sourceMaps": true,
|
||||
"smartStep": true
|
||||
},
|
||||
{
|
||||
"name": "Remote Browser via Webkit Adaptor",
|
||||
"type": "chrome",
|
||||
"request": "attach",
|
||||
"port": 9000,
|
||||
"webRoot": "${workspaceFolder}",
|
||||
"sourceMapPathOverrides": {
|
||||
"webpack://ort/*": "${webRoot}/common/*",
|
||||
"webpack:///*": "${webRoot}/web/*"
|
||||
},
|
||||
"sourceMaps": true,
|
||||
"smartStep": true
|
||||
}
|
||||
]
|
||||
}
|
||||
4
js/.vscode/settings.json
vendored
4
js/.vscode/settings.json
vendored
|
|
@ -26,7 +26,9 @@
|
|||
"common/lib/**/*.js.map": true,
|
||||
"common/lib/**/*.js": true,
|
||||
"node/lib/**/*.js.map": true,
|
||||
"node/lib/**/*.js": true
|
||||
"node/lib/**/*.js": true,
|
||||
"web/lib/**/*.js.map": true,
|
||||
"web/lib/**/*.js": true
|
||||
},
|
||||
"files.insertFinalNewline": true,
|
||||
"files.trimTrailingWhitespace": true,
|
||||
|
|
|
|||
27
js/README.md
27
js/README.md
|
|
@ -102,13 +102,34 @@ It should be able to consumed by from projects that uses NPM packages (through a
|
|||
|
||||
> language: typescript
|
||||
|
||||
> dependency: onnxruntime-common, onnxruntime_wasm.wasm
|
||||
> dependency: onnxruntime-common, ONNXRuntime WebAssembly
|
||||
|
||||
> folder: <ORT_ROOT>/js/web
|
||||
|
||||
NOTE: This is the successor of ONNX.js.
|
||||
This project is a library for running ONNX models on browsers. It is the successor of [ONNX.js](https://github.com/Microsoft/onnxjs).
|
||||
|
||||
<!-- TODO: update this section for onnxruntime web -->
|
||||
### Requirements
|
||||
|
||||
Node.js v12+ (recommended v14+)
|
||||
|
||||
### Build
|
||||
|
||||
1. Install NPM packages
|
||||
|
||||
1. in `/js/`, run `npm ci`.
|
||||
2. in `/js/common/`, run `npm ci`.
|
||||
3. in `/js/web/`, run `npm ci`.
|
||||
|
||||
2. Follow [instructions](https://www.onnxruntime.ai/docs/how-to/build.html#apis-and-language-bindings) for building ONNX Runtime WebAssembly.
|
||||
|
||||
3. Copy files `onnxruntime_wasm.*` from build output folder to `<ORT_ROOT>/js/web/lib/wasm/binding/`.
|
||||
|
||||
4. Use following command in folder `<ORT_ROOT>/js/web` to build:
|
||||
```
|
||||
npm run build
|
||||
```
|
||||
|
||||
### Distribution
|
||||
|
||||
It should be able to consumed by both from projects that uses NPM packages (through a Node.js folder structure of `node_modules` folder that generated by `npm install onnxruntime-web`) and from a CDN service that serves a `.min.js` file and one or multiple `.wasm` file(s).
|
||||
|
||||
|
|
|
|||
13
js/common/README.md
Normal file
13
js/common/README.md
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
# ONNX Runtime JavaScript API
|
||||
|
||||
ONNX Runtime JavaScript API is a unified API for all JavaScript usages. It's dependency of the following NPM packages:
|
||||
|
||||
- onnxruntime-node
|
||||
- onnxruntime-web
|
||||
- onnxruntime-react-native
|
||||
|
||||
This package (onnxruntime-common) is not designed for using directly. Please consider to install one of the NPM packages above according to target platform.
|
||||
|
||||
## License
|
||||
|
||||
License information can be found [here](../../README.md#license).
|
||||
|
|
@ -16,8 +16,11 @@ export declare namespace SessionHandler {
|
|||
export interface SessionHandler {
|
||||
dispose(): Promise<void>;
|
||||
|
||||
readonly inputNames: string[];
|
||||
readonly outputNames: string[];
|
||||
readonly inputNames: readonly string[];
|
||||
readonly outputNames: readonly string[];
|
||||
|
||||
startProfiling(): void;
|
||||
endProfiling(): void;
|
||||
|
||||
run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
|
||||
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;
|
||||
|
|
|
|||
|
|
@ -190,6 +190,13 @@ export class InferenceSession implements InferenceSessionInterface {
|
|||
return new InferenceSession(handler);
|
||||
}
|
||||
|
||||
startProfiling(): void {
|
||||
this.handler.startProfiling();
|
||||
}
|
||||
endProfiling(): void {
|
||||
this.handler.endProfiling();
|
||||
}
|
||||
|
||||
get inputNames(): readonly string[] {
|
||||
return this.handler.inputNames;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -212,6 +212,20 @@ export interface InferenceSession {
|
|||
|
||||
//#endregion
|
||||
|
||||
//#region profiling
|
||||
|
||||
/**
|
||||
* Start profiling.
|
||||
*/
|
||||
startProfiling(): void;
|
||||
|
||||
/**
|
||||
* End profiling.
|
||||
*/
|
||||
endProfiling(): void;
|
||||
|
||||
//#endregion
|
||||
|
||||
//#region metadata
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -3,5 +3,5 @@
|
|||
"compilerOptions": {
|
||||
"declarationDir": "./types"
|
||||
},
|
||||
"include": ["lib/"]
|
||||
"include": ["lib"]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,8 +40,10 @@ function buildConfig({
|
|||
|
||||
module.exports = (env, argv) => {
|
||||
return [
|
||||
buildConfig({ format: 'umd', mode: 'development', devtool: 'inline-source-map', target: 'es5' }),
|
||||
buildConfig({ format: 'umd', suffix: '.min', target: 'es5' }),
|
||||
buildConfig({ suffix: '.es6', mode: 'development', devtool: 'inline-source-map', target: 'es6' }),
|
||||
buildConfig({ mode: 'development', devtool: 'inline-source-map', target: 'es5' }),
|
||||
buildConfig({ suffix: '.es6.min', target: 'es6' }),
|
||||
buildConfig({ suffix: '.min', target: 'es5' }),
|
||||
buildConfig({ format: 'commonjs', suffix: '.node', target: 'es5' }),
|
||||
];
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
# ONNX Runtime Node.js API
|
||||
# ONNX Runtime Node.js Binding
|
||||
|
||||
ONNX Runtime Node.js binding enables Node.js applications to run ONNX model inference.
|
||||
|
||||
|
|
@ -7,13 +7,7 @@ ONNX Runtime Node.js binding enables Node.js applications to run ONNX model infe
|
|||
Install the latest stable version:
|
||||
|
||||
```
|
||||
npm install onnxruntime
|
||||
```
|
||||
|
||||
Install the latest dev version:
|
||||
|
||||
```
|
||||
npm install onnxruntime@dev
|
||||
npm install onnxruntime-node
|
||||
```
|
||||
|
||||
Refer to [Node.js samples](../../samples/nodejs/README.md) for samples and tutorials.
|
||||
|
|
@ -32,4 +26,4 @@ To use on platforms without pre-built binaries, you can build Node.js binding fr
|
|||
|
||||
## License
|
||||
|
||||
License information can be found [here](../README.md#license).
|
||||
License information can be found [here](../../README.md#license).
|
||||
|
|
|
|||
|
|
@ -25,6 +25,12 @@ class OnnxruntimeSessionHandler implements SessionHandler {
|
|||
readonly inputNames: string[];
|
||||
readonly outputNames: string[];
|
||||
|
||||
startProfiling(): void {
|
||||
// TODO: implement profiling
|
||||
}
|
||||
endProfiling(): void {
|
||||
// TODO: implement profiling
|
||||
}
|
||||
|
||||
async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions):
|
||||
Promise<SessionHandler.ReturnType> {
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {execSync, spawnSync} from 'child_process';
|
||||
import * as fs from 'fs-extra';
|
||||
import minimist from 'minimist';
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import * as fs from 'fs-extra';
|
||||
import klawSync from 'klaw-sync';
|
||||
import * as path from 'path';
|
||||
|
|
|
|||
|
|
@ -88,12 +88,12 @@ export function run(testDataFolder: string): void {
|
|||
}
|
||||
|
||||
if (session !== null) {
|
||||
const feeds = {};
|
||||
const feeds: Record<string, Tensor> = {};
|
||||
if (inputs.length !== session.inputNames.length) {
|
||||
throw new RangeError('input length does not match name list');
|
||||
}
|
||||
for (let i = 0; i < inputs.length; i++) {
|
||||
feeds[session.inputNames[i]] = inputs[i];
|
||||
feeds[session.inputNames[i]] = inputs[i]!;
|
||||
}
|
||||
const outputs = await session.run(feeds);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,24 +1,8 @@
|
|||
{
|
||||
"include": ["lib", "script", "test"],
|
||||
"compileOnSave": true,
|
||||
"extends": "../tsconfig.json",
|
||||
"compilerOptions": {
|
||||
"module": "CommonJS",
|
||||
"moduleResolution": "node",
|
||||
"declaration": true,
|
||||
"declarationDir": "./types",
|
||||
"esModuleInterop": true,
|
||||
"target": "es2015",
|
||||
"lib": ["es2015", "ESNext.BigInt"],
|
||||
"sourceMap": true,
|
||||
"noUnusedLocals": true,
|
||||
"noImplicitReturns": true,
|
||||
"noImplicitThis": true,
|
||||
"alwaysStrict": true,
|
||||
"strictNullChecks": true,
|
||||
"noUnusedParameters": false,
|
||||
"pretty": true,
|
||||
"allowUnreachableCode": false,
|
||||
"experimentalDecorators": true,
|
||||
"downlevelIteration": true
|
||||
}
|
||||
"declarationDir": "./types"
|
||||
},
|
||||
"include": ["lib", "script", "test"]
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,5 @@
|
|||
"experimentalDecorators": true,
|
||||
"downlevelIteration": true,
|
||||
"incremental": true
|
||||
},
|
||||
"exclude": ["node_modules/"]
|
||||
}
|
||||
}
|
||||
|
|
|
|||
19
js/web/.gitignore
vendored
Normal file
19
js/web/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
/.vscode/ipch/
|
||||
|
||||
node_modules/
|
||||
types/
|
||||
dist/
|
||||
|
||||
tsconfig.tsbuildinfo
|
||||
|
||||
lib/**/*.js
|
||||
lib/**/*.js.map
|
||||
test/**/*.js
|
||||
test/**/*.js.map
|
||||
script/**/*.js
|
||||
script/**/*.js.map
|
||||
|
||||
lib/wasm/binding/**/*.wasm
|
||||
!lib/wasm/binding/**/*.d.ts
|
||||
|
||||
test/data/node/
|
||||
153
js/web/README.md
Normal file
153
js/web/README.md
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
# ONNX Runtime Web
|
||||
|
||||
ONNX Runtime Web is a Javascript library for running ONNX models on browsers and on Node.js.
|
||||
|
||||
ONNX Runtime Web has adopted WebAssembly and WebGL technologies for providing an optimized ONNX model inference runtime for both CPUs and GPUs.
|
||||
|
||||
### Why ONNX models
|
||||
|
||||
The [Open Neural Network Exchange](http://onnx.ai/) (ONNX) is an open standard for representing machine learning models. The biggest advantage of ONNX is that it allows interoperability across different open source AI frameworks, which itself offers more flexibility for AI frameworks adoption. See [Getting ONNX Models](#Getting-ONNX-models).
|
||||
|
||||
### Why ONNX Runtime Web
|
||||
|
||||
With ONNX Runtime Web, web developers can score pre-trained ONNX models directly on browsers with various benefits of reducing server-client communication and protecting user privacy, as well as offering install-free and cross-platform in-browser ML experience.
|
||||
|
||||
ONNX Runtime Web can run on both CPU and GPU. For running on CPU, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. Furthermore, ONNX Runtime Web utilizes [Web Workers](https://developer.mozilla.org/en-US/docs/Web/API/Web_Workers_API/Using_web_workers) to provide a "multi-threaded" environment to parallelize data processing. Empirical evaluation shows very promising performance gains on CPU by taking full advantage of WebAssembly and Web Workers. For running on GPUs, a popular standard for accessing GPU capabilities - WebGL is adopted. ONNX Runtime Web has further adopted several novel optimization techniques for reducing data transfer between CPU and GPU, as well as some techniques to reduce GPU processing cycles to further push the performance to the maximum.
|
||||
|
||||
See [Compatibility](#Compatibility) and [Operators Supported](#Operators) for a list of platforms and operators ONNX Runtime Web currently supports.
|
||||
|
||||
## Getting Started
|
||||
|
||||
There are multiple ways to use ONNX Runtime Web in a project:
|
||||
|
||||
### Using `<script>` tag
|
||||
|
||||
This is the most straightforward way to use ONNX Runtime Web. The following HTML example shows how to use it:
|
||||
|
||||
```html
|
||||
<html>
|
||||
<head> </head>
|
||||
|
||||
<body>
|
||||
<!-- Load ONNX Runtime Web -->
|
||||
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
|
||||
<!-- Code that consume ONNX Runtime Web -->
|
||||
<script>
|
||||
async function runMyModel() {
|
||||
// create a session
|
||||
const myOrtSession = await ort.InferenceSession.create(
|
||||
"./my-model.onnx"
|
||||
);
|
||||
// generate model input
|
||||
const input0 = new ort.Tensor(
|
||||
new Float32Array([1.0, 2.0, 3.0, 4.0]) /* data */,
|
||||
[2, 2] /* dims */
|
||||
);
|
||||
// execute the model
|
||||
const outputs = await myOrtSession.run({ input_0: input0 });
|
||||
// consume the output
|
||||
const outputTensor = outputs["output_0"];
|
||||
console.log(`model output tensor: ${outputTensor.data}.`);
|
||||
}
|
||||
runMyModel();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
```
|
||||
|
||||
<!-- TODO: Refer to [browser/Add](./examples/browser/add) for an example. -->
|
||||
|
||||
### Using NPM and bundling tools
|
||||
|
||||
Modern browser based applications are usually built by frameworks like [Angular](https://angular.io/), [React](https://reactjs.org/), [Vue.js](https://vuejs.org/) and so on. This solution usually builds the source code into one or more bundle file(s). The following TypeScript example shows how to use ONNX Runtime Web in an async context:
|
||||
|
||||
1. Import `Tensor` and `InferenceSession`.
|
||||
|
||||
```ts
|
||||
import { Tensor, InferenceSession } from "onnxruntime-web";
|
||||
```
|
||||
|
||||
2. Create an instance of `InferenceSession` and load ONNX model.
|
||||
|
||||
```ts
|
||||
// use the following in an async method
|
||||
const url = "./data/models/resnet/model.onnx";
|
||||
const session = await InferenceSession.create(url);
|
||||
```
|
||||
|
||||
3. Create your input Tensor(s) similar to the example below. You need to do any pre-processing required by
|
||||
your model at this stage. For that refer to the documentation of the model you have:
|
||||
|
||||
```javascript
|
||||
// creating an array of input Tensors is the easiest way. For other options see the API documentation
|
||||
const input0 = new Tensor(new Float32Array([1.0, 2.0, 3.0, 4.0]), [2, 2]);
|
||||
```
|
||||
|
||||
4. Run the model with the input Tensors. The output Tensor(s) are available once the run operation is complete:
|
||||
|
||||
```javascript
|
||||
// run this in an async method:
|
||||
// assume model's input name is 'input_0' and output name is 'output_0'
|
||||
const outputs = await session.run({ input_0: input0 });
|
||||
const outputTensor = outputs.output_0;
|
||||
```
|
||||
|
||||
5. Bundle your code. All web application frameworks offer bundling tools and instructions. Specifically, you can specify onnxruntime-web as an external dependency:
|
||||
|
||||
```js
|
||||
// a webpack example
|
||||
externals: {
|
||||
'onnxruntime-web': 'ort', // add this line in your webpack.config.js
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
so that you can consume the file `ort.min.js` from a CDN provider demonstrated as above.
|
||||
|
||||
<!-- TODO More verbose examples on how to use ONNX Runtime Web are located under the `examples` folder. For further info see [Examples](./examples/README.md) -->
|
||||
|
||||
## Documents
|
||||
|
||||
### Developers
|
||||
|
||||
For information on ONNX.js development, please check [Development](./docs/development.md)
|
||||
|
||||
For API reference, please check [API](./docs/api.md).
|
||||
|
||||
### Getting ONNX models
|
||||
|
||||
You can get ONNX models easily in multiple ways:
|
||||
|
||||
- Choose a pre-trained ONNX model from the [ONNX Model Zoo](https://github.com/onnx/models)
|
||||
- Convert models from mainstream frameworks, e.g. PyTorch, TensorFlow and Keras, by following [ONNX tutorials](https://github.com/onnx/tutorials)
|
||||
- Use your data to generate a customized ONNX model from [Azure Custom Vision service](https://docs.microsoft.com/en-us/azure/cognitive-services/Custom-Vision-Service/home)
|
||||
- [Train a custom model in AzureML](https://github.com/Azure/MachineLearningNotebooks/tree/master/training) and save it in the ONNX format
|
||||
|
||||
Learn more about ONNX
|
||||
|
||||
- [ONNX website](http://onnx.ai/)
|
||||
- [ONNX on GitHub](https://github.com/onnx/onnx)
|
||||
|
||||
### Compatibility
|
||||
|
||||
| OS/Browser | Chrome | Edge | Safari | Electron |
|
||||
| :--------------: | :----------------: | :----------------: | :----------------: | :----------------: |
|
||||
| Windows 10 | :heavy_check_mark: | :heavy_check_mark: | - | :heavy_check_mark: |
|
||||
| macOS | :heavy_check_mark: | - | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Ubuntu LTS 18.04 | :heavy_check_mark: | - | - | :heavy_check_mark: |
|
||||
| iOS | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | - |
|
||||
| Android | :heavy_check_mark: | - | - | - |
|
||||
|
||||
### Operators
|
||||
|
||||
#### WebAssembly backend
|
||||
|
||||
ONNX Runtime Web currently support all operators in [ai.onnx](https://github.com/onnx/onnx/blob/master/docs/Operators.md) and [ai.onnx.ml](https://github.com/onnx/onnx/blob/master/docs/Operators-ml.md).
|
||||
|
||||
#### WebGL backend
|
||||
|
||||
ONNX Runtime Web currently supports most operators in [ai.onnx](https://github.com/onnx/onnx/blob/rel-1.2.3/docs/Operators.md) operator set v7 (opset v7). See [operators.md](./docs/operators.md) for a complete, detailed list of which ONNX operators are supported by WebGL backend.
|
||||
|
||||
## License
|
||||
|
||||
License information can be found [here](../../README.md#license).
|
||||
177
js/web/karma.conf.js
Normal file
177
js/web/karma.conf.js
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT license.
|
||||
|
||||
const bundleMode = require('minimist')(process.argv)['bundle-mode'] || 'dev'; // 'dev'|'perf'|undefined;
|
||||
const karmaPlugins = require('minimist')(process.argv)['karma-plugins'] || undefined;
|
||||
const commonFile = bundleMode === 'dev' ? '../common/dist/ort-common.js' : '../common/dist/ort-common.min.js'
|
||||
const mainFile = bundleMode === 'dev' ? 'test/ort.dev.js' : 'test/ort.perf.js';
|
||||
|
||||
// it's a known issue that Safari does not work with "localhost" in BrowserStack:
|
||||
// https://www.browserstack.com/question/663
|
||||
//
|
||||
// we need to read machine IP address to replace "localhost":
|
||||
// https://stackoverflow.com/a/8440736
|
||||
//
|
||||
function getMachineIpAddress() {
|
||||
var os = require('os');
|
||||
var ifaces = os.networkInterfaces();
|
||||
|
||||
for (const ifname in ifaces) {
|
||||
for (const iface of ifaces[ifname]) {
|
||||
if ('IPv4' !== iface.family || iface.internal !== false) {
|
||||
// skip over internal (i.e. 127.0.0.1) and non-ipv4 addresses
|
||||
continue;
|
||||
}
|
||||
|
||||
// returns the first available IP address
|
||||
return iface.address;
|
||||
}
|
||||
}
|
||||
|
||||
// if no available IP address, fallback to "localhost".
|
||||
return 'localhost';
|
||||
}
|
||||
|
||||
module.exports = function (config) {
|
||||
config.set({
|
||||
// global config of your BrowserStack account
|
||||
browserStack: {
|
||||
username: process.env.BROWSER_STACK_USERNAME,
|
||||
accessKey: process.env.BROWSER_STACK_ACCESS_KEY,
|
||||
forceLocal: true,
|
||||
startTunnel: true,
|
||||
},
|
||||
frameworks: ['mocha'],
|
||||
files: [
|
||||
{ pattern: commonFile },
|
||||
{ pattern: 'test/testdata-config.js' },
|
||||
{ pattern: mainFile },
|
||||
{ pattern: 'test/testdata-file-cache-*.json', included: false },
|
||||
//{ pattern: 'test/onnx-worker.js', included: false },
|
||||
{ pattern: 'test/data/**/*', included: false, nocache: true },
|
||||
{ pattern: 'dist/onnxruntime_wasm.wasm', included: false },
|
||||
{ pattern: 'dist/onnxruntime_wasm_threads.wasm', included: false },
|
||||
{ pattern: 'dist/onnxruntime_wasm_threads.worker.js', included: false },
|
||||
],
|
||||
proxies: {
|
||||
'/base/test/onnxruntime_wasm.wasm': '/base/dist/onnxruntime_wasm.wasm',
|
||||
'/onnxruntime_wasm_threads.wasm': '/base/dist/onnxruntime_wasm_threads.wasm',
|
||||
'/onnxruntime_wasm_threads.worker.js': '/base/dist/onnxruntime_wasm_threads.worker.js',
|
||||
},
|
||||
plugins: karmaPlugins,
|
||||
client: { captureConsole: true, mocha: { expose: ['body'], timeout: 60000 } },
|
||||
preprocessors: { mainFile: ['sourcemap'] },
|
||||
reporters: ['mocha', 'BrowserStack'],
|
||||
browsers: [],
|
||||
captureTimeout: 120000,
|
||||
reportSlowerThan: 100,
|
||||
browserDisconnectTimeout: 600000,
|
||||
browserNoActivityTimeout: 300000,
|
||||
browserDisconnectTolerance: 0,
|
||||
browserSocketTimeout: 60000,
|
||||
hostname: getMachineIpAddress(),
|
||||
customLaunchers: {
|
||||
ChromeTest: { base: 'Chrome', flags: ['--window-size=1,1'] },
|
||||
ChromeDebug: { debug: true, base: 'Chrome', flags: ['--remote-debugging-port=9333'] },
|
||||
//
|
||||
// ==== BrowserStack browsers ====
|
||||
//
|
||||
|
||||
// Windows
|
||||
//
|
||||
BS_WIN_10_Chrome_73: {
|
||||
base: 'BrowserStack',
|
||||
browser: 'Chrome',
|
||||
browser_version: '73.0',
|
||||
os: 'Windows',
|
||||
os_version: '10',
|
||||
},
|
||||
BS_WIN_10_Edge_18: {
|
||||
base: 'BrowserStack',
|
||||
os: 'Windows',
|
||||
os_version: '10',
|
||||
browser: 'Edge',
|
||||
browser_version: '18.0',
|
||||
},
|
||||
BS_WIN_10_Firefox_66: {
|
||||
base: 'BrowserStack',
|
||||
os: 'Windows',
|
||||
os_version: '10',
|
||||
browser: 'Firefox',
|
||||
browser_version: '66.0',
|
||||
},
|
||||
BS_WIN_7_Chrome_63: {
|
||||
base: 'BrowserStack',
|
||||
browser: 'Chrome',
|
||||
browser_version: '63.0',
|
||||
os: 'Windows',
|
||||
os_version: '7',
|
||||
},
|
||||
|
||||
// macOS
|
||||
//
|
||||
BS_MAC_10_14_Safari_12: {
|
||||
base: 'BrowserStack',
|
||||
os: 'OS X',
|
||||
os_version: 'Mojave',
|
||||
browser: 'Safari',
|
||||
browser_version: '12.0',
|
||||
},
|
||||
BS_MAC_10_14_Chrome_73: {
|
||||
base: 'BrowserStack',
|
||||
os: 'OS X',
|
||||
os_version: 'Mojave',
|
||||
browser: 'Chrome',
|
||||
browser_version: '73.0',
|
||||
},
|
||||
BS_MAC_10_13_Safari_11_1: {
|
||||
base: 'BrowserStack',
|
||||
os: 'OS X',
|
||||
os_version: 'High Sierra',
|
||||
browser: 'Safari',
|
||||
browser_version: '11.1',
|
||||
},
|
||||
|
||||
// iPhone
|
||||
//
|
||||
BS_IOS_12_1_iPhoneXS: {
|
||||
base: 'BrowserStack',
|
||||
device: 'iPhone XS',
|
||||
real_mobile: true,
|
||||
os: 'ios',
|
||||
os_version: '12.1',
|
||||
},
|
||||
BS_IOS_11_iPhoneX: {
|
||||
base: 'BrowserStack',
|
||||
device: 'iPhone X',
|
||||
real_mobile: true,
|
||||
os: 'ios',
|
||||
os_version: '11',
|
||||
},
|
||||
BS_IOS_10_3_iPhone7: {
|
||||
base: 'BrowserStack',
|
||||
device: 'iPhone 7',
|
||||
real_mobile: true,
|
||||
os: 'ios',
|
||||
os_version: '10.3',
|
||||
},
|
||||
|
||||
// Android
|
||||
//
|
||||
BS_ANDROID_9_Pixel_3: {
|
||||
base: 'BrowserStack',
|
||||
device: 'Google Pixel 3',
|
||||
real_mobile: true,
|
||||
os: 'android',
|
||||
os_version: '9.0',
|
||||
},
|
||||
BS_ANDROID_7_1_Galaxy_Note_8: {
|
||||
base: 'BrowserStack',
|
||||
device: 'Samsung Galaxy Note 8',
|
||||
real_mobile: true,
|
||||
os: 'android',
|
||||
os_version: '7.1',
|
||||
},
|
||||
}
|
||||
});
|
||||
};
|
||||
56
js/web/lib/backend-onnxjs.ts
Normal file
56
js/web/lib/backend-onnxjs.ts
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
/* eslint-disable import/no-internal-modules */
|
||||
import {Backend, env, InferenceSession, SessionHandler} from 'onnxruntime-common';
|
||||
import {Session} from './onnxjs/session';
|
||||
import {OnnxjsSessionHandler} from './onnxjs/session-handler';
|
||||
|
||||
class OnnxjsBackend implements Backend {
|
||||
// eslint-disable-next-line @typescript-eslint/no-empty-function
|
||||
async init(): Promise<void> {}
|
||||
|
||||
async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions):
|
||||
Promise<SessionHandler> {
|
||||
// NOTE: Session.Config(from onnx.js) is not compatible with InferenceSession.SessionOptions(from
|
||||
// onnxruntime-common).
|
||||
// In future we should remove Session.Config and use InferenceSession.SessionOptions.
|
||||
// Currently we allow this to happen to make test runner work.
|
||||
const session = new Session(options as unknown as Session.Config);
|
||||
|
||||
// typescript cannot merge method override correctly (so far in 4.2.3). need if-else to call the method.
|
||||
if (typeof pathOrBuffer === 'string') {
|
||||
await session.loadModel(pathOrBuffer);
|
||||
} else {
|
||||
await session.loadModel(pathOrBuffer);
|
||||
}
|
||||
|
||||
return new OnnxjsSessionHandler(session);
|
||||
}
|
||||
}
|
||||
|
||||
export const onnxjsBackend = new OnnxjsBackend();
|
||||
|
||||
export interface WebGLFlags {
|
||||
/**
|
||||
* set or get the WebGL Context ID (webgl or webgl2)
|
||||
*/
|
||||
contextId?: 'webgl'|'webgl2';
|
||||
/**
|
||||
* set or get the maximum batch size for matmul. 0 means to disable batching.
|
||||
*/
|
||||
matmulMaxBatchSize?: number;
|
||||
/**
|
||||
* set or get the texture cache mode
|
||||
*/
|
||||
textureCacheMode?: 'initializerOnly'|'full';
|
||||
/**
|
||||
* set or get the packed texture mode
|
||||
*/
|
||||
pack?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a set of flags for ONNX.js backend.
|
||||
*/
|
||||
export const flags: WebGLFlags = env.webgl = env.webgl as WebGLFlags || {};
|
||||
50
js/web/lib/backend-wasm.ts
Normal file
50
js/web/lib/backend-wasm.ts
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Backend, env, InferenceSession, SessionHandler} from 'onnxruntime-common';
|
||||
|
||||
import {init, OnnxruntimeWebAssemblySessionHandler} from './wasm';
|
||||
|
||||
class OnnxruntimeWebAssemblyBackend implements Backend {
|
||||
async init(): Promise<void> {
|
||||
await init();
|
||||
}
|
||||
createSessionHandler(path: string, options?: InferenceSession.SessionOptions): Promise<SessionHandler>;
|
||||
createSessionHandler(buffer: Uint8Array, options?: InferenceSession.SessionOptions): Promise<SessionHandler>;
|
||||
async createSessionHandler(pathOrBuffer: string|Uint8Array, _options?: InferenceSession.SessionOptions):
|
||||
Promise<SessionHandler> {
|
||||
let buffer: Uint8Array;
|
||||
if (typeof pathOrBuffer === 'string') {
|
||||
const response = await fetch(pathOrBuffer);
|
||||
const arrayBuffer = await response.arrayBuffer();
|
||||
buffer = new Uint8Array(arrayBuffer);
|
||||
} else {
|
||||
buffer = pathOrBuffer;
|
||||
}
|
||||
const handler = new OnnxruntimeWebAssemblySessionHandler();
|
||||
// TODO: support SessionOptions
|
||||
handler.loadModel(buffer);
|
||||
return Promise.resolve(handler);
|
||||
}
|
||||
}
|
||||
|
||||
export const wasmBackend = new OnnxruntimeWebAssemblyBackend();
|
||||
|
||||
export interface WebAssemblyFlags {
|
||||
/**
|
||||
* set or get number of worker(s)
|
||||
*
|
||||
* This setting is available only when WebAssembly multithread feature is available in current context.
|
||||
*/
|
||||
worker?: number;
|
||||
|
||||
/**
|
||||
* set or get a number specifying the timeout for initialization of WebAssembly backend, in milliseconds.
|
||||
*/
|
||||
initTimeout?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a set of flags for WebAssembly backend.
|
||||
*/
|
||||
export const flags: WebAssemblyFlags = env.wasm = env.wasm as WebAssemblyFlags || {};
|
||||
10
js/web/lib/index.ts
Normal file
10
js/web/lib/index.ts
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
export * from 'onnxruntime-common';
|
||||
import {registerBackend} from 'onnxruntime-common';
|
||||
import {onnxjsBackend} from './backend-onnxjs';
|
||||
import {wasmBackend} from './backend-wasm';
|
||||
|
||||
registerBackend('webgl', onnxjsBackend, 1);
|
||||
registerBackend('wasm', wasmBackend, 2);
|
||||
201
js/web/lib/onnxjs/attribute.ts
Normal file
201
js/web/lib/onnxjs/attribute.ts
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import Long from 'long';
|
||||
import {onnx} from 'onnx-proto';
|
||||
|
||||
import {Tensor} from './tensor';
|
||||
import {LongUtil} from './util';
|
||||
|
||||
export declare namespace Attribute {
|
||||
export interface DataTypeMap {
|
||||
float: number;
|
||||
int: number;
|
||||
string: string;
|
||||
tensor: Tensor;
|
||||
floats: number[];
|
||||
ints: number[];
|
||||
strings: string[];
|
||||
tensors: Tensor[];
|
||||
}
|
||||
|
||||
export type DataType = keyof DataTypeMap;
|
||||
}
|
||||
|
||||
type ValueTypes = Attribute.DataTypeMap[Attribute.DataType];
|
||||
|
||||
type Value = [ValueTypes, Attribute.DataType];
|
||||
|
||||
export class Attribute {
|
||||
constructor(attributes: onnx.IAttributeProto[]|null|undefined) {
|
||||
this._attributes = new Map();
|
||||
if (attributes !== null && attributes !== undefined) {
|
||||
for (const attr of attributes) {
|
||||
this._attributes.set(attr.name!, [Attribute.getValue(attr), Attribute.getType(attr)]);
|
||||
}
|
||||
|
||||
if (this._attributes.size < attributes.length) {
|
||||
throw new Error('duplicated attribute names');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
set(key: string, type: Attribute.DataType, value: ValueTypes): void {
|
||||
this._attributes.set(key, [value, type]);
|
||||
}
|
||||
delete(key: string): void {
|
||||
this._attributes.delete(key);
|
||||
}
|
||||
|
||||
getFloat(key: string, defaultValue?: Attribute.DataTypeMap['float']) {
|
||||
return this.get(key, 'float', defaultValue);
|
||||
}
|
||||
|
||||
getInt(key: string, defaultValue?: Attribute.DataTypeMap['int']) {
|
||||
return this.get(key, 'int', defaultValue);
|
||||
}
|
||||
|
||||
getString(key: string, defaultValue?: Attribute.DataTypeMap['string']) {
|
||||
return this.get(key, 'string', defaultValue);
|
||||
}
|
||||
|
||||
getTensor(key: string, defaultValue?: Attribute.DataTypeMap['tensor']) {
|
||||
return this.get(key, 'tensor', defaultValue);
|
||||
}
|
||||
|
||||
getFloats(key: string, defaultValue?: Attribute.DataTypeMap['floats']) {
|
||||
return this.get(key, 'floats', defaultValue);
|
||||
}
|
||||
|
||||
getInts(key: string, defaultValue?: Attribute.DataTypeMap['ints']) {
|
||||
return this.get(key, 'ints', defaultValue);
|
||||
}
|
||||
|
||||
getStrings(key: string, defaultValue?: Attribute.DataTypeMap['strings']) {
|
||||
return this.get(key, 'strings', defaultValue);
|
||||
}
|
||||
|
||||
getTensors(key: string, defaultValue?: Attribute.DataTypeMap['tensors']) {
|
||||
return this.get(key, 'tensors', defaultValue);
|
||||
}
|
||||
|
||||
private get<V extends Attribute.DataTypeMap[Attribute.DataType]>(
|
||||
key: string, type: Attribute.DataType, defaultValue?: V): V {
|
||||
const valueAndType = this._attributes.get(key);
|
||||
if (valueAndType === undefined) {
|
||||
if (defaultValue !== undefined) {
|
||||
return defaultValue;
|
||||
}
|
||||
throw new Error(`required attribute not found: ${key}`);
|
||||
}
|
||||
if (valueAndType[1] !== type) {
|
||||
throw new Error(`type mismatch: expected ${type} but got ${valueAndType[1]}`);
|
||||
}
|
||||
return valueAndType[0] as V;
|
||||
}
|
||||
|
||||
private static getType(attr: onnx.IAttributeProto): Attribute.DataType {
|
||||
switch (attr.type!) {
|
||||
case onnx.AttributeProto.AttributeType.FLOAT:
|
||||
return 'float';
|
||||
case onnx.AttributeProto.AttributeType.INT:
|
||||
return 'int';
|
||||
case onnx.AttributeProto.AttributeType.STRING:
|
||||
return 'string';
|
||||
case onnx.AttributeProto.AttributeType.TENSOR:
|
||||
return 'tensor';
|
||||
case onnx.AttributeProto.AttributeType.FLOATS:
|
||||
return 'floats';
|
||||
case onnx.AttributeProto.AttributeType.INTS:
|
||||
return 'ints';
|
||||
case onnx.AttributeProto.AttributeType.STRINGS:
|
||||
return 'strings';
|
||||
case onnx.AttributeProto.AttributeType.TENSORS:
|
||||
return 'tensors';
|
||||
default:
|
||||
throw new Error(`attribute type is not supported yet: ${onnx.AttributeProto.AttributeType[attr.type!]}`);
|
||||
}
|
||||
}
|
||||
|
||||
private static getValue(attr: onnx.IAttributeProto) {
|
||||
if (attr.type === onnx.AttributeProto.AttributeType.GRAPH ||
|
||||
attr.type === onnx.AttributeProto.AttributeType.GRAPHS) {
|
||||
throw new Error('graph attribute is not supported yet');
|
||||
}
|
||||
|
||||
const value = this.getValueNoCheck(attr);
|
||||
|
||||
// cast LONG to number
|
||||
if (attr.type === onnx.AttributeProto.AttributeType.INT && Long.isLong(value)) {
|
||||
return value.toNumber();
|
||||
}
|
||||
|
||||
// cast LONG[] to number[]
|
||||
if (attr.type === onnx.AttributeProto.AttributeType.INTS) {
|
||||
const arr = (value as Array<number|Long>);
|
||||
const numberValue: number[] = new Array<number>(arr.length);
|
||||
|
||||
for (let i = 0; i < arr.length; i++) {
|
||||
const maybeLong = arr[i];
|
||||
numberValue[i] = LongUtil.longToNumber(maybeLong);
|
||||
}
|
||||
|
||||
return numberValue;
|
||||
}
|
||||
|
||||
// cast onnx.TensorProto to onnxjs.Tensor
|
||||
if (attr.type === onnx.AttributeProto.AttributeType.TENSOR) {
|
||||
return Tensor.fromProto(value as onnx.ITensorProto);
|
||||
}
|
||||
|
||||
// cast onnx.TensorProto[] to onnxjs.Tensor[]
|
||||
if (attr.type === onnx.AttributeProto.AttributeType.TENSORS) {
|
||||
const tensorProtos = value as onnx.ITensorProto[];
|
||||
return tensorProtos.map(value => Tensor.fromProto(value));
|
||||
}
|
||||
|
||||
// cast Uint8Array to string
|
||||
if (attr.type === onnx.AttributeProto.AttributeType.STRING) {
|
||||
const utf8String = value as Uint8Array;
|
||||
return Buffer.from(utf8String.buffer, utf8String.byteOffset, utf8String.byteLength).toString();
|
||||
}
|
||||
|
||||
// cast Uint8Array[] to string[]
|
||||
if (attr.type === onnx.AttributeProto.AttributeType.STRINGS) {
|
||||
const utf8Strings = value as Uint8Array[];
|
||||
return utf8Strings.map(
|
||||
utf8String => Buffer.from(utf8String.buffer, utf8String.byteOffset, utf8String.byteLength).toString());
|
||||
}
|
||||
|
||||
return value as ValueTypes;
|
||||
}
|
||||
|
||||
private static getValueNoCheck(attr: onnx.IAttributeProto) {
|
||||
switch (attr.type!) {
|
||||
case onnx.AttributeProto.AttributeType.FLOAT:
|
||||
return attr.f;
|
||||
case onnx.AttributeProto.AttributeType.INT:
|
||||
return attr.i;
|
||||
case onnx.AttributeProto.AttributeType.STRING:
|
||||
return attr.s;
|
||||
case onnx.AttributeProto.AttributeType.TENSOR:
|
||||
return attr.t;
|
||||
case onnx.AttributeProto.AttributeType.GRAPH:
|
||||
return attr.g;
|
||||
case onnx.AttributeProto.AttributeType.FLOATS:
|
||||
return attr.floats;
|
||||
case onnx.AttributeProto.AttributeType.INTS:
|
||||
return attr.ints;
|
||||
case onnx.AttributeProto.AttributeType.STRINGS:
|
||||
return attr.strings;
|
||||
case onnx.AttributeProto.AttributeType.TENSORS:
|
||||
return attr.tensors;
|
||||
case onnx.AttributeProto.AttributeType.GRAPHS:
|
||||
return attr.graphs;
|
||||
default:
|
||||
throw new Error(`unsupported attribute type: ${onnx.AttributeProto.AttributeType[attr.type!]}`);
|
||||
}
|
||||
}
|
||||
|
||||
protected _attributes: Map<string, Value>;
|
||||
}
|
||||
146
js/web/lib/onnxjs/backend.ts
Normal file
146
js/web/lib/onnxjs/backend.ts
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {WebGLBackend} from './backends/backend-webgl';
|
||||
import {Graph} from './graph';
|
||||
import {Operator} from './operators';
|
||||
import {OpSet} from './opset';
|
||||
import {Session} from './session';
|
||||
|
||||
export interface InferenceHandler {
|
||||
/**
|
||||
* dispose the inference handler. it will be called as the last step in Session.run()
|
||||
*/
|
||||
dispose(): void;
|
||||
}
|
||||
|
||||
export interface SessionHandler {
|
||||
/**
|
||||
* transform the graph at initialization time
|
||||
* @param graphTransformer the graph transformer to manipulate the model graph
|
||||
*/
|
||||
transformGraph?(graphTransformer: Graph.Transformer): void;
|
||||
|
||||
/**
|
||||
* create an instance of InferenceHandler to use in a Session.run() call
|
||||
*/
|
||||
createInferenceHandler(): InferenceHandler;
|
||||
|
||||
/**
|
||||
* dispose the session handler. it will be called when a session is being disposed explicitly
|
||||
*/
|
||||
dispose(): void;
|
||||
|
||||
/**
|
||||
* Resolves the operator from the name and opset version; backend specific
|
||||
* @param node the node to resolve
|
||||
* @param opsets a list of opsets that exported from the model
|
||||
* @param graph the completely initialized graph
|
||||
*/
|
||||
resolve(node: Graph.Node, opsets: readonly OpSet[], graph: Graph): Operator;
|
||||
|
||||
/**
|
||||
* This method let's the sessionHandler know that the graph initialization is complete
|
||||
* @param graph the completely initialized graph
|
||||
*/
|
||||
onGraphInitialized?(graph: Graph): void;
|
||||
|
||||
/**
|
||||
* a reference to the corresponding backend
|
||||
*/
|
||||
readonly backend: Backend;
|
||||
|
||||
/**
|
||||
* a reference to the session context
|
||||
*/
|
||||
readonly context: Session.Context;
|
||||
}
|
||||
|
||||
export interface Backend {
|
||||
/**
|
||||
* initialize the backend. will be called only once, when the first time the
|
||||
* backend it to be used
|
||||
*/
|
||||
initialize(): boolean|Promise<boolean>;
|
||||
|
||||
/**
|
||||
* create an instance of SessionHandler to use in a Session object's lifecycle
|
||||
*/
|
||||
createSessionHandler(context: Session.Context): SessionHandler;
|
||||
|
||||
/**
|
||||
* dispose the backend. currently this will not be called
|
||||
*/
|
||||
dispose(): void;
|
||||
}
|
||||
|
||||
// caches all initialized backend instances
|
||||
const backendsCache: Map<string, Backend> = new Map();
|
||||
|
||||
export const backend: {[name: string]: Backend} = {
|
||||
webgl: new WebGLBackend(),
|
||||
};
|
||||
|
||||
/**
|
||||
* Resolve a reference to the backend. If a hint is specified, the corresponding
|
||||
* backend will be used.
|
||||
*/
|
||||
export async function resolveBackend(hint?: string|readonly string[]): Promise<Backend> {
|
||||
if (!hint) {
|
||||
return resolveBackend(['webgl']);
|
||||
} else {
|
||||
const hints = typeof hint === 'string' ? [hint] : hint;
|
||||
|
||||
for (const backendHint of hints) {
|
||||
const cache = backendsCache.get(backendHint);
|
||||
if (cache) {
|
||||
return cache;
|
||||
}
|
||||
|
||||
const backend = await tryLoadBackend(backendHint);
|
||||
if (backend) {
|
||||
return backend;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error('no available backend to use');
|
||||
}
|
||||
|
||||
async function tryLoadBackend(backendHint: string): Promise<Backend|undefined> {
|
||||
const backendObj = backend;
|
||||
|
||||
if (typeof backendObj[backendHint] !== 'undefined' && isBackend(backendObj[backendHint])) {
|
||||
const backend = backendObj[backendHint];
|
||||
let init = backend.initialize();
|
||||
if (typeof init === 'object' && 'then' in init) {
|
||||
init = await init;
|
||||
}
|
||||
if (init) {
|
||||
backendsCache.set(backendHint, backend);
|
||||
return backend;
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function isBackend(obj: unknown) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
const o = obj as any;
|
||||
|
||||
// check if an object is a Backend instance
|
||||
if (
|
||||
'initialize' in o && typeof o.initialize === 'function' && // initialize()
|
||||
'createSessionHandler' in o && typeof o.createSessionHandler === 'function' && // createSessionHandler()
|
||||
'dispose' in o && typeof o.dispose === 'function' // dispose()
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
export type BackendType = Backend;
|
||||
export type SessionHandlerType = ReturnType<BackendType['createSessionHandler']>;
|
||||
export type InferenceHandlerType = ReturnType<SessionHandlerType['createInferenceHandler']>;
|
||||
77
js/web/lib/onnxjs/backends/backend-webgl.ts
Normal file
77
js/web/lib/onnxjs/backends/backend-webgl.ts
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {flags} from '../../backend-onnxjs';
|
||||
import {Backend, SessionHandler} from '../backend';
|
||||
import {Logger} from '../instrument';
|
||||
import {Session} from '../session';
|
||||
|
||||
import {WebGLSessionHandler} from './webgl/session-handler';
|
||||
import {WebGLContext} from './webgl/webgl-context';
|
||||
import {createWebGLContext} from './webgl/webgl-context-factory';
|
||||
|
||||
/**
|
||||
* WebGLBackend is the entry point for all WebGL opeartions
|
||||
* When it starts it created the WebGLRenderingContext
|
||||
* and other main framework components such as Program and Texture Managers
|
||||
*/
|
||||
export class WebGLBackend implements Backend {
|
||||
glContext: WebGLContext;
|
||||
|
||||
get contextId(): 'webgl'|'webgl2'|undefined {
|
||||
return flags.contextId;
|
||||
}
|
||||
set contextId(value: 'webgl'|'webgl2'|undefined) {
|
||||
flags.contextId = value;
|
||||
}
|
||||
|
||||
get matmulMaxBatchSize(): number|undefined {
|
||||
return flags.matmulMaxBatchSize;
|
||||
}
|
||||
set matmulMaxBatchSize(value: number|undefined) {
|
||||
flags.matmulMaxBatchSize = value;
|
||||
}
|
||||
|
||||
get textureCacheMode(): 'initializerOnly'|'full'|undefined {
|
||||
return flags.textureCacheMode;
|
||||
}
|
||||
set textureCacheMode(value: 'initializerOnly'|'full'|undefined) {
|
||||
flags.textureCacheMode = value;
|
||||
}
|
||||
|
||||
get pack(): boolean|undefined {
|
||||
return flags.pack;
|
||||
}
|
||||
set pack(value: boolean|undefined) {
|
||||
flags.pack = value;
|
||||
}
|
||||
|
||||
initialize(): boolean {
|
||||
try {
|
||||
this.glContext = createWebGLContext(this.contextId);
|
||||
if (typeof this.matmulMaxBatchSize !== 'number') {
|
||||
this.matmulMaxBatchSize = 16;
|
||||
}
|
||||
if (typeof this.textureCacheMode !== 'string') {
|
||||
this.textureCacheMode = 'full';
|
||||
}
|
||||
if (typeof this.pack !== 'boolean') {
|
||||
this.pack = false;
|
||||
}
|
||||
Logger.verbose(
|
||||
'WebGLBackend',
|
||||
`Created WebGLContext: ${typeof this.glContext} with matmulMaxBatchSize: ${
|
||||
this.matmulMaxBatchSize}; textureCacheMode: ${this.textureCacheMode}; pack: ${this.pack}.`);
|
||||
return true;
|
||||
} catch (e) {
|
||||
Logger.warning('WebGLBackend', `Unable to initialize WebGLBackend. ${e}`);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
createSessionHandler(context: Session.Context): SessionHandler {
|
||||
return new WebGLSessionHandler(this, context);
|
||||
}
|
||||
dispose(): void {
|
||||
this.glContext.dispose();
|
||||
}
|
||||
}
|
||||
74
js/web/lib/onnxjs/backends/webgl/glsl-array-lib.ts
Normal file
74
js/web/lib/onnxjs/backends/webgl/glsl-array-lib.ts
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {GlslContext, GlslLib, GlslLibRoutine} from './glsl-definitions';
|
||||
/**
|
||||
* This library produces routines needed for non-constant access to uniform arrays
|
||||
*/
|
||||
export class ArrayGlslLib extends GlslLib {
|
||||
getFunctions(): {[name: string]: GlslLibRoutine} {
|
||||
return this.generate();
|
||||
}
|
||||
getCustomTypes(): {[name: string]: string} {
|
||||
return {};
|
||||
}
|
||||
constructor(context: GlslContext) {
|
||||
super(context);
|
||||
}
|
||||
protected generate(): {[name: string]: GlslLibRoutine} {
|
||||
const result: {[name: string]: GlslLibRoutine} = {};
|
||||
for (let i = 1; i <= 16; i++) {
|
||||
result[`setItem${i}`] = new GlslLibRoutine(this.generateSetItem(i));
|
||||
result[`getItem${i}`] = new GlslLibRoutine(this.generateGetItem(i));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
protected generateSetItem(length: number): string {
|
||||
let block = `
|
||||
if(index < 0)
|
||||
index = ${length} + index;
|
||||
if (index == 0)
|
||||
a[0] = value;
|
||||
`;
|
||||
for (let i = 1; i < length - 1; ++i) {
|
||||
block += `
|
||||
else if (index == ${i})
|
||||
a[${i}] = value;
|
||||
`;
|
||||
}
|
||||
block += `
|
||||
else
|
||||
a[${length - 1}] = value;
|
||||
`;
|
||||
const body = `
|
||||
void setItem${length}(out float a[${length}], int index, float value) {
|
||||
${block}
|
||||
}
|
||||
`;
|
||||
return body;
|
||||
}
|
||||
protected generateGetItem(length: number): string {
|
||||
let block = `
|
||||
if(index < 0)
|
||||
index = ${length} + index;
|
||||
if (index == 0)
|
||||
return a[0];
|
||||
`;
|
||||
for (let i = 1; i < length - 1; ++i) {
|
||||
block += `
|
||||
else if (index == ${i})
|
||||
return a[${i}];
|
||||
`;
|
||||
}
|
||||
block += `
|
||||
else
|
||||
return a[${length - 1}];
|
||||
`;
|
||||
const body = `
|
||||
float getItem${length}(float a[${length}], int index) {
|
||||
${block}
|
||||
}
|
||||
`;
|
||||
return body;
|
||||
}
|
||||
}
|
||||
1215
js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts
Normal file
1215
js/web/lib/onnxjs/backends/webgl/glsl-coordinate-lib.ts
Normal file
File diff suppressed because it is too large
Load diff
119
js/web/lib/onnxjs/backends/webgl/glsl-definitions.ts
Normal file
119
js/web/lib/onnxjs/backends/webgl/glsl-definitions.ts
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {ProgramInfo} from './types';
|
||||
import {WebGLContext} from './webgl-context';
|
||||
|
||||
/* eslint-disable @typescript-eslint/naming-convention */
|
||||
export enum FunctionType {
|
||||
ValueBased,
|
||||
Positional
|
||||
}
|
||||
export interface GlslFunction<T extends FunctionType> {
|
||||
body: string;
|
||||
name: string;
|
||||
type: T;
|
||||
}
|
||||
export type GlslValueFunction = GlslFunction<FunctionType.ValueBased>;
|
||||
export interface GlslPositionalFunction extends GlslFunction<FunctionType.Positional> {
|
||||
inputShape: readonly number[];
|
||||
outputShape: readonly number[];
|
||||
}
|
||||
|
||||
export class GlslContext {
|
||||
constructor(public glContext: WebGLContext, public programInfo: ProgramInfo) {}
|
||||
}
|
||||
export abstract class GlslLib {
|
||||
constructor(public context: GlslContext) {}
|
||||
abstract getFunctions(): {[name: string]: GlslLibRoutine};
|
||||
abstract getCustomTypes(): {[name: string]: string};
|
||||
}
|
||||
|
||||
// abstraction to represent a GLSL library routine and it's dependencies
|
||||
export class GlslLibRoutine {
|
||||
constructor(public routineBody: string, public dependencies?: string[]) {}
|
||||
}
|
||||
|
||||
// abstraction to represent a GLSL library routine and it's dependencies AS GRAPH Nodes
|
||||
// this level of abstraction is used to topologically sort routines before fragment shade inclusion
|
||||
export class GlslLibRoutineNode {
|
||||
dependencies: GlslLibRoutineNode[];
|
||||
routineBody: string;
|
||||
constructor(public name: string, routineBody?: string, dependencies?: GlslLibRoutineNode[]) {
|
||||
if (dependencies) {
|
||||
this.dependencies = dependencies;
|
||||
} else {
|
||||
this.dependencies = [];
|
||||
}
|
||||
|
||||
if (routineBody) {
|
||||
this.routineBody = routineBody;
|
||||
}
|
||||
}
|
||||
addDependency(node: GlslLibRoutineNode) {
|
||||
if (node) {
|
||||
this.dependencies.push(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// topologically sort GLSL library routines (graph nodes abstraction) before shader script inclusion
|
||||
export class TopologicalSortGlslRoutines {
|
||||
static returnOrderedNodes(nodes: GlslLibRoutineNode[]): GlslLibRoutineNode[] {
|
||||
if (!nodes || nodes.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
if (nodes.length === 1) {
|
||||
return nodes;
|
||||
}
|
||||
|
||||
const cycleCheck = new Set<string>();
|
||||
const alreadyTraversed = new Set<string>();
|
||||
const result = new Array<GlslLibRoutineNode>();
|
||||
|
||||
this.createOrderedNodes(nodes, cycleCheck, alreadyTraversed, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
private static createOrderedNodes(
|
||||
graphNodes: GlslLibRoutineNode[], cycleCheck: Set<string>, alreadyTraversed: Set<string>,
|
||||
result: GlslLibRoutineNode[]) {
|
||||
for (let i = 0; i < graphNodes.length; ++i) {
|
||||
this.dfsTraverse(graphNodes[i], cycleCheck, alreadyTraversed, result);
|
||||
}
|
||||
}
|
||||
|
||||
private static dfsTraverse(
|
||||
root: GlslLibRoutineNode, cycleCheck: Set<string>, alreadyTraversed: Set<string>, result: GlslLibRoutineNode[]) {
|
||||
// if this root has already been traversed return
|
||||
if (!root || alreadyTraversed.has(root.name)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// cyclic dependency has been detected
|
||||
if (cycleCheck.has(root.name)) {
|
||||
throw new Error('Cyclic dependency detected. Can\'t topologically sort routines needed for shader.');
|
||||
}
|
||||
|
||||
// hold this node to detect cycles if any
|
||||
cycleCheck.add(root.name);
|
||||
|
||||
// traverse children in a dfs fashion
|
||||
const dependencies = root.dependencies;
|
||||
if (dependencies && dependencies.length > 0) {
|
||||
for (let i = 0; i < dependencies.length; ++i) {
|
||||
this.dfsTraverse(dependencies[i], cycleCheck, alreadyTraversed, result);
|
||||
}
|
||||
}
|
||||
|
||||
// add to result holder
|
||||
result.push(root);
|
||||
|
||||
// mark this node as traversed so that we don't traverse from this again
|
||||
alreadyTraversed.add(root.name);
|
||||
|
||||
// release the hold
|
||||
cycleCheck.delete(root.name);
|
||||
}
|
||||
}
|
||||
99
js/web/lib/onnxjs/backends/webgl/glsl-encoding-lib.ts
Normal file
99
js/web/lib/onnxjs/backends/webgl/glsl-encoding-lib.ts
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {GlslContext, GlslLib, GlslLibRoutine} from './glsl-definitions';
|
||||
|
||||
/**
|
||||
* This GLSL library handles routines converting
|
||||
* float32 to/from Unsigned byte or float 16
|
||||
*/
|
||||
export class EncodingGlslLib extends GlslLib {
|
||||
constructor(context: GlslContext) {
|
||||
super(context);
|
||||
}
|
||||
getFunctions(): {[name: string]: GlslLibRoutine} {
|
||||
return {...this.encodeFloat32(), ...this.decodeFloat32()};
|
||||
}
|
||||
getCustomTypes(): {[name: string]: string} {
|
||||
return {};
|
||||
}
|
||||
protected encodeFloat32(): {[name: string]: GlslLibRoutine} {
|
||||
return {
|
||||
encode: new GlslLibRoutine(`highp vec4 encode(highp float f) {
|
||||
return vec4(f, 0.0, 0.0, 0.0);
|
||||
}
|
||||
`)
|
||||
};
|
||||
}
|
||||
protected decodeFloat32(): {[name: string]: GlslLibRoutine} {
|
||||
return {
|
||||
decode: new GlslLibRoutine(`highp float decode(highp vec4 rgba) {
|
||||
return rgba.r;
|
||||
}
|
||||
`)
|
||||
};
|
||||
}
|
||||
/**
|
||||
* returns the routine to encode encode a 32bit float to a vec4 (of unsigned bytes)
|
||||
* @credit: https://stackoverflow.com/questions/7059962/how-do-i-convert-a-vec4-rgba-value-to-a-float
|
||||
*/
|
||||
protected encodeUint8(): {[name: string]: GlslLibRoutine} {
|
||||
const endianness = EncodingGlslLib.isLittleEndian() ? 'rgba.rgba=rgba.abgr;' : '';
|
||||
return {
|
||||
encode: new GlslLibRoutine(`
|
||||
highp vec4 encode(highp float f) {
|
||||
highp float F = abs(f);
|
||||
highp float Sign = step(0.0,-f);
|
||||
highp float Exponent = floor(log2(F));
|
||||
highp float Mantissa = (exp2(- Exponent) * F);
|
||||
Exponent = floor(log2(F) + 127.0) + floor(log2(Mantissa));
|
||||
highp vec4 rgba;
|
||||
rgba[0] = 128.0 * Sign + floor(Exponent*exp2(-1.0));
|
||||
rgba[1] = 128.0 * mod(Exponent,2.0) + mod(floor(Mantissa*128.0),128.0);
|
||||
rgba[2] = floor(mod(floor(Mantissa*exp2(23.0 -8.0)),exp2(8.0)));
|
||||
rgba[3] = floor(exp2(23.0)*mod(Mantissa,exp2(-15.0)));
|
||||
${endianness}
|
||||
rgba = rgba / 255.0; // values need to be normalized to [0,1]
|
||||
return rgba;
|
||||
}
|
||||
`)
|
||||
};
|
||||
}
|
||||
/**
|
||||
* returns the routine to encode a vec4 of unsigned bytes to float32
|
||||
* @credit: https://stackoverflow.com/questions/7059962/how-do-i-convert-a-vec4-rgba-value-to-a-float
|
||||
*/
|
||||
protected decodeUint8(): {[name: string]: GlslLibRoutine} {
|
||||
const endianness = EncodingGlslLib.isLittleEndian() ? 'rgba.rgba=rgba.abgr;' : '';
|
||||
return {
|
||||
decode: new GlslLibRoutine(`
|
||||
highp float decode(highp vec4 rgba) {
|
||||
rgba = rgba * 255.0; // values need to be de-normalized from [0,1] to [0,255]
|
||||
${endianness}
|
||||
highp float Sign = 1.0 - step(128.0,rgba[0])*2.0;
|
||||
highp float Exponent = 2.0 * mod(rgba[0],128.0) + step(128.0,rgba[1]) - 127.0;
|
||||
highp float Mantissa = mod(rgba[1],128.0)*65536.0 + rgba[2]*256.0 +rgba[3] + float(0x800000);
|
||||
highp float Result = Sign * exp2(Exponent) * (Mantissa * exp2(-23.0 ));
|
||||
return Result;
|
||||
}
|
||||
`)
|
||||
};
|
||||
}
|
||||
/**
|
||||
* Determines if the machine is little endian or not
|
||||
* @credit: https://gist.github.com/TooTallNate/4750953
|
||||
*/
|
||||
static isLittleEndian(): boolean {
|
||||
const b = new ArrayBuffer(4);
|
||||
const a = new Uint32Array(b);
|
||||
const c = new Uint8Array(b);
|
||||
a[0] = 0xdeadbeef;
|
||||
if (c[0] === 0xef) {
|
||||
return true;
|
||||
}
|
||||
if (c[0] === 0xde) {
|
||||
return false;
|
||||
}
|
||||
throw new Error('unknown endianness');
|
||||
}
|
||||
}
|
||||
45
js/web/lib/onnxjs/backends/webgl/glsl-fragcolor-lib.ts
Normal file
45
js/web/lib/onnxjs/backends/webgl/glsl-fragcolor-lib.ts
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {GlslContext, GlslLib, GlslLibRoutine} from './glsl-definitions';
|
||||
import {getGlsl} from './glsl-source';
|
||||
|
||||
/**
|
||||
* This GLSL library handles routines around reading a texlet and writing to it
|
||||
* Reading and writing could be more than just dealing with one channel
|
||||
* It may require encoding/decoding to/from 4 channels into one
|
||||
*/
|
||||
export class FragColorGlslLib extends GlslLib {
|
||||
constructor(context: GlslContext) {
|
||||
super(context);
|
||||
}
|
||||
getFunctions(): {[name: string]: GlslLibRoutine} {
|
||||
return {...this.setFragColor(), ...this.getColorAsFloat()};
|
||||
}
|
||||
getCustomTypes(): {[name: string]: string} {
|
||||
return {};
|
||||
}
|
||||
protected setFragColor(): {[name: string]: GlslLibRoutine} {
|
||||
const glsl = getGlsl(this.context.glContext.version);
|
||||
return {
|
||||
setFragColor: new GlslLibRoutine(
|
||||
`
|
||||
void setFragColor(float value) {
|
||||
${glsl.output} = encode(value);
|
||||
}
|
||||
`,
|
||||
['encoding.encode'])
|
||||
};
|
||||
}
|
||||
protected getColorAsFloat(): {[name: string]: GlslLibRoutine} {
|
||||
return {
|
||||
getColorAsFloat: new GlslLibRoutine(
|
||||
`
|
||||
float getColorAsFloat(vec4 color) {
|
||||
return decode(color);
|
||||
}
|
||||
`,
|
||||
['encoding.decode'])
|
||||
};
|
||||
}
|
||||
}
|
||||
53
js/web/lib/onnxjs/backends/webgl/glsl-function-inliner.ts
Normal file
53
js/web/lib/onnxjs/backends/webgl/glsl-function-inliner.ts
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
const INLINE_FUNC_DEF_REGEX = /@inline[\s\n\r]+(\w+)[\s\n\r]+([0-9a-zA-Z_]+)\s*\(([^)]*)\)\s*{(([^}]|[\n\r])*)}/gm;
|
||||
const FUNC_CALL_REGEX = '(\\w+)?\\s+([_0-9a-zA-Z]+)\\s+=\\s+__FUNC__\\((.*)\\)\\s*;';
|
||||
/**
|
||||
* GLSL preprocessor responsible for resolving @inline directives
|
||||
*/
|
||||
export function replaceInlines(script: string): string {
|
||||
const inlineDefs: {[name: string]: {params: Array<{type: string; name: string}|null>; body: string}} = {};
|
||||
let match;
|
||||
while ((match = INLINE_FUNC_DEF_REGEX.exec(script)) !== null) {
|
||||
const params = match[3]
|
||||
.split(',')
|
||||
.map(s => {
|
||||
const tokens = s.trim().split(' ');
|
||||
if (tokens && tokens.length === 2) {
|
||||
return {type: tokens[0], name: tokens[1]};
|
||||
}
|
||||
return null;
|
||||
})
|
||||
.filter(v => v !== null);
|
||||
inlineDefs[match[2]] = {params, body: match[4]};
|
||||
}
|
||||
for (const name in inlineDefs) {
|
||||
const regexString = FUNC_CALL_REGEX.replace('__FUNC__', name);
|
||||
const regex = new RegExp(regexString, 'gm');
|
||||
while ((match = regex.exec(script)) !== null) {
|
||||
const type = match[1];
|
||||
const variable = match[2];
|
||||
const params = match[3].split(',');
|
||||
const declLine = (type) ? `${type} ${variable};` : '';
|
||||
let newBody: string = inlineDefs[name].body;
|
||||
let paramRedecLine = '';
|
||||
inlineDefs[name].params.forEach((v, i) => {
|
||||
if (v) {
|
||||
paramRedecLine += `${v.type} ${v.name} = ${params[i]};\n`;
|
||||
}
|
||||
});
|
||||
newBody = `${paramRedecLine}\n ${newBody}`;
|
||||
newBody = newBody.replace('return', `${variable} = `);
|
||||
const replacement = `
|
||||
${declLine}
|
||||
{
|
||||
${newBody}
|
||||
}
|
||||
`;
|
||||
script = script.replace(match[0], replacement);
|
||||
}
|
||||
}
|
||||
script = script.replace(INLINE_FUNC_DEF_REGEX, '');
|
||||
return script;
|
||||
}
|
||||
129
js/web/lib/onnxjs/backends/webgl/glsl-preprocessor.ts
Normal file
129
js/web/lib/onnxjs/backends/webgl/glsl-preprocessor.ts
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {GlslContext, GlslLib, GlslLibRoutineNode, TopologicalSortGlslRoutines} from './glsl-definitions';
|
||||
import {replaceInlines} from './glsl-function-inliner';
|
||||
import {glslRegistry} from './glsl-registered-libs';
|
||||
import {getDefaultFragShaderMain, getFragShaderPreamble} from './glsl-source';
|
||||
import {ProgramInfo, VariableInfo} from './types';
|
||||
import {WebGLContext} from './webgl-context';
|
||||
|
||||
/**
|
||||
* Preprocessor for the additions to the GLSL language
|
||||
* It deals with:
|
||||
* @include directives
|
||||
* @inline
|
||||
* Loop unrolling (not implemented)
|
||||
* Macro resolution (not implemented)
|
||||
*/
|
||||
export class GlslPreprocessor {
|
||||
readonly context: GlslContext;
|
||||
readonly libs: {[name: string]: GlslLib} = {};
|
||||
readonly glslLibRoutineDependencyGraph: {[routineName: string]: GlslLibRoutineNode} = {};
|
||||
|
||||
constructor(glContext: WebGLContext, programInfo: ProgramInfo) {
|
||||
this.context = new GlslContext(glContext, programInfo);
|
||||
|
||||
// construct GlslLibs
|
||||
Object.keys(glslRegistry).forEach((name: string) => {
|
||||
const lib = new glslRegistry[name](this.context);
|
||||
this.libs[name] = lib;
|
||||
});
|
||||
|
||||
// construct GlslRoutineDependencyGraph
|
||||
const map = this.glslLibRoutineDependencyGraph;
|
||||
for (const libName in this.libs) {
|
||||
const lib = this.libs[libName];
|
||||
const routinesInLib = lib.getFunctions();
|
||||
for (const routine in routinesInLib) {
|
||||
const key = libName + '.' + routine;
|
||||
let currentNode: GlslLibRoutineNode;
|
||||
if (map[key]) {
|
||||
currentNode = map[key];
|
||||
currentNode.routineBody = routinesInLib[routine].routineBody;
|
||||
} else {
|
||||
currentNode = new GlslLibRoutineNode(key, routinesInLib[routine].routineBody);
|
||||
map[key] = currentNode;
|
||||
}
|
||||
const dependencies = routinesInLib[routine].dependencies;
|
||||
if (dependencies) {
|
||||
for (let i = 0; i < dependencies.length; ++i) {
|
||||
if (!map[dependencies[i]]) {
|
||||
const node = new GlslLibRoutineNode(dependencies[i]);
|
||||
map[dependencies[i]] = node;
|
||||
currentNode.addDependency(node);
|
||||
} else {
|
||||
currentNode.addDependency(map[dependencies[i]]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
preprocess(): string {
|
||||
const programInfo = this.context.programInfo;
|
||||
let source = programInfo.shaderSource;
|
||||
|
||||
// append main() function
|
||||
if (!this.context.programInfo.hasMain) {
|
||||
source = `${source}
|
||||
${getDefaultFragShaderMain(this.context.glContext.version, programInfo.outputLayout.shape.length)}`;
|
||||
}
|
||||
// replace inlines
|
||||
source = replaceInlines(source);
|
||||
|
||||
// concat final source string
|
||||
return `${getFragShaderPreamble(this.context.glContext.version)}
|
||||
${this.getUniforms(programInfo.samplers, programInfo.variables)}
|
||||
${this.getImports(source)}
|
||||
${source}`;
|
||||
}
|
||||
|
||||
protected getImports(script: string): string {
|
||||
const routinesIncluded = this.selectGlslLibRoutinesToBeIncluded(script);
|
||||
|
||||
if (routinesIncluded.length === 0) {
|
||||
return '';
|
||||
}
|
||||
|
||||
let routines = '';
|
||||
for (let i = 0; i < routinesIncluded.length; ++i) {
|
||||
if (routinesIncluded[i].routineBody) {
|
||||
routines += routinesIncluded[i].routineBody + '\n';
|
||||
} else {
|
||||
throw new Error(`Missing body for the Glsl Library routine: ${routinesIncluded[i].name}`);
|
||||
}
|
||||
}
|
||||
|
||||
return routines;
|
||||
}
|
||||
private selectGlslLibRoutinesToBeIncluded(script: string): GlslLibRoutineNode[] {
|
||||
const nodes: GlslLibRoutineNode[] = [];
|
||||
|
||||
Object.keys(this.glslLibRoutineDependencyGraph).forEach(classAndRoutine => {
|
||||
const routine = classAndRoutine.split('.')[1];
|
||||
if (script.indexOf(routine) !== -1) {
|
||||
nodes.push(this.glslLibRoutineDependencyGraph[classAndRoutine]);
|
||||
}
|
||||
});
|
||||
|
||||
return TopologicalSortGlslRoutines.returnOrderedNodes(nodes);
|
||||
}
|
||||
|
||||
protected getUniforms(samplers?: string[], variables?: VariableInfo[]): string {
|
||||
const uniformLines: string[] = [];
|
||||
if (samplers) {
|
||||
for (const sampler of samplers) {
|
||||
uniformLines.push(`uniform sampler2D ${sampler};`);
|
||||
}
|
||||
}
|
||||
if (variables) {
|
||||
for (const variable of variables) {
|
||||
uniformLines.push(
|
||||
`uniform ${variable.type} ${variable.name}${variable.arrayLength ? `[${variable.arrayLength}]` : ''};`);
|
||||
}
|
||||
}
|
||||
return uniformLines.join('\n');
|
||||
}
|
||||
}
|
||||
18
js/web/lib/onnxjs/backends/webgl/glsl-registered-libs.ts
Normal file
18
js/web/lib/onnxjs/backends/webgl/glsl-registered-libs.ts
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {CoordsGlslLib} from './glsl-coordinate-lib';
|
||||
import {GlslContext, GlslLib} from './glsl-definitions';
|
||||
import {EncodingGlslLib} from './glsl-encoding-lib';
|
||||
import {FragColorGlslLib} from './glsl-fragcolor-lib';
|
||||
import {ShapeUtilsGlslLib} from './glsl-shape-utils-lib';
|
||||
import {VecGlslLib} from './glsl-vec-lib';
|
||||
|
||||
export const glslRegistry: {[name: string]: new (context: GlslContext) => GlslLib} = {
|
||||
'encoding': EncodingGlslLib,
|
||||
'fragcolor': FragColorGlslLib,
|
||||
'vec': VecGlslLib,
|
||||
'shapeUtils': ShapeUtilsGlslLib,
|
||||
'coordinates': CoordsGlslLib,
|
||||
// 'arrays': ArrayGlslSLib
|
||||
};
|
||||
171
js/web/lib/onnxjs/backends/webgl/glsl-shape-utils-lib.ts
Normal file
171
js/web/lib/onnxjs/backends/webgl/glsl-shape-utils-lib.ts
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {GlslContext, GlslLib, GlslLibRoutine} from './glsl-definitions';
|
||||
|
||||
/**
|
||||
* GLSL Library responsible for data types and routines for manipulating
|
||||
* coordinates and mapping to/from tensor indices
|
||||
*/
|
||||
export class ShapeUtilsGlslLib extends GlslLib {
|
||||
constructor(context: GlslContext) {
|
||||
super(context);
|
||||
}
|
||||
getFunctions(): {[name: string]: GlslLibRoutine} {
|
||||
return {
|
||||
...this.bcastIndex(),
|
||||
...this.bcastMatmulIndex(),
|
||||
...this.offsetToIndices(),
|
||||
...this.indicesToOffset(),
|
||||
...this.incrementIndices()
|
||||
};
|
||||
}
|
||||
getCustomTypes() {
|
||||
return {};
|
||||
}
|
||||
protected bcastIndex(): {[name: string]: GlslLibRoutine} {
|
||||
const programInfo = this.context.programInfo;
|
||||
const outputRank = programInfo.outputLayout.shape.length;
|
||||
const result: {[name: string]: GlslLibRoutine} = {};
|
||||
this.context.programInfo.samplers.forEach((name, i) => {
|
||||
const shape = programInfo.inputLayouts[i].shape;
|
||||
if (shape.length <= outputRank) {
|
||||
const rank = shape.length;
|
||||
const dimOffset = outputRank - rank;
|
||||
const funcName = `bcastIndices_${name}`;
|
||||
let block = '';
|
||||
for (let i = 0; i < rank; ++i) {
|
||||
block += `
|
||||
realIndices[${i}] = int( mod(float(bcastedIndices[${dimOffset + i}]), ${shape[i]}.0) );
|
||||
`;
|
||||
}
|
||||
const body = `
|
||||
void ${funcName} (int bcastedIndices[${outputRank}], out int realIndices[${rank}]) {
|
||||
${block}
|
||||
}
|
||||
`;
|
||||
result[funcName] = new GlslLibRoutine(body);
|
||||
}
|
||||
});
|
||||
return result;
|
||||
}
|
||||
protected bcastMatmulIndex(): {[name: string]: GlslLibRoutine} {
|
||||
const programInfo = this.context.programInfo;
|
||||
const outputRank = programInfo.outputLayout.shape.length;
|
||||
const result: {[name: string]: GlslLibRoutine} = {};
|
||||
this.context.programInfo.samplers.forEach((name, i) => {
|
||||
const shape = programInfo.inputLayouts[i].shape;
|
||||
if (!(shape.length < 2 || shape.length > outputRank)) {
|
||||
const rank = shape.length;
|
||||
const dimOffset = outputRank - rank;
|
||||
const funcName = `bcastMatmulIndices_${name}`;
|
||||
let block = '';
|
||||
for (let i = 0; i < rank - 2; ++i) {
|
||||
block += `
|
||||
realIndices[${i}] = int( mod(float(bcastedIndices[${dimOffset + i}]), ${shape[i]}.0) );
|
||||
`;
|
||||
}
|
||||
const body = `
|
||||
void ${funcName}(int bcastedIndices[${outputRank}], out int realIndices[${rank}]) {
|
||||
${block}
|
||||
realIndices[${rank - 1}] = bcastedIndices[${outputRank - 1}];
|
||||
realIndices[${rank - 2}] = bcastedIndices[${outputRank - 2}];
|
||||
}
|
||||
`;
|
||||
result[funcName] = new GlslLibRoutine(body);
|
||||
}
|
||||
});
|
||||
return result;
|
||||
}
|
||||
protected indicesToOffset(): {[name: string]: GlslLibRoutine} {
|
||||
const programInfo = this.context.programInfo;
|
||||
const result: {[name: string]: GlslLibRoutine} = {};
|
||||
this.context.programInfo.samplers.forEach((name, i) => {
|
||||
const shape = programInfo.inputLayouts[i].shape;
|
||||
const strides = programInfo.inputLayouts[i].strides;
|
||||
const rank = shape.length;
|
||||
let funcName = `indicesToOffset_${name}`;
|
||||
result[funcName] = new GlslLibRoutine(ShapeUtilsGlslLib.indexToOffsetSingle(funcName, rank, strides));
|
||||
funcName = `indicesToOffset_${name}_T`;
|
||||
result[funcName] =
|
||||
new GlslLibRoutine(ShapeUtilsGlslLib.indexToOffsetSingle(funcName, rank, strides.slice().reverse()));
|
||||
});
|
||||
return result;
|
||||
}
|
||||
static indexToOffsetSingle(name: string, rank: number, strides: readonly number[]): string {
|
||||
let block = '';
|
||||
for (let i = rank - 1; i >= 0; --i) {
|
||||
block += `
|
||||
offset += indices[${i}] * ${strides[i]};
|
||||
`;
|
||||
}
|
||||
return `
|
||||
int ${name}(int indices[${rank}]) {
|
||||
int offset = 0;
|
||||
${block}
|
||||
return offset;
|
||||
}
|
||||
`;
|
||||
}
|
||||
protected offsetToIndices(): {[name: string]: GlslLibRoutine} {
|
||||
const programInfo = this.context.programInfo;
|
||||
const result: {[name: string]: GlslLibRoutine} = {};
|
||||
this.context.programInfo.samplers.forEach((name, i) => {
|
||||
const shape = programInfo.inputLayouts[i].shape;
|
||||
const strides = programInfo.inputLayouts[i].strides;
|
||||
const rank = shape.length;
|
||||
let funcName = `offsetToIndices_${name}`;
|
||||
result[funcName] = new GlslLibRoutine(ShapeUtilsGlslLib.offsetToIndicesSingle(funcName, rank, strides));
|
||||
funcName = `offsetToIndices_${name}_T`;
|
||||
result[funcName] =
|
||||
new GlslLibRoutine(ShapeUtilsGlslLib.offsetToIndicesSingle(funcName, rank, strides.slice().reverse()));
|
||||
});
|
||||
return result;
|
||||
}
|
||||
static offsetToIndicesSingle(name: string, rank: number, strides: readonly number[]): string {
|
||||
const stridesBlock = [];
|
||||
for (let i = 0; i < rank - 1; ++i) {
|
||||
stridesBlock.push(`
|
||||
indices[${i}] = offset / ${strides[i]};`);
|
||||
stridesBlock.push(`
|
||||
offset -= indices[${i}] * ${strides[i]};`);
|
||||
}
|
||||
stridesBlock.push(`
|
||||
indices[${rank - 1}] = offset;`);
|
||||
return `
|
||||
void ${name}(int offset, out int indices[${rank}]) {
|
||||
${stridesBlock.join('')}
|
||||
}
|
||||
`;
|
||||
}
|
||||
protected incrementIndices(): {[name: string]: GlslLibRoutine} {
|
||||
const programInfo = this.context.programInfo;
|
||||
const result: {[name: string]: GlslLibRoutine} = {};
|
||||
this.context.programInfo.samplers.forEach((name, i) => {
|
||||
const shape = programInfo.inputLayouts[i].shape;
|
||||
const rank = shape.length;
|
||||
const funcName = `incrementIndices_${name}`;
|
||||
let shapeInit = '';
|
||||
for (let i = 0; i < rank; ++i) {
|
||||
shapeInit += `
|
||||
shape[${i}] = ${shape[i]};`;
|
||||
}
|
||||
const body = `
|
||||
void ${funcName}(int axis, out int indices[${rank}]) {
|
||||
int shape[${rank}];
|
||||
${shapeInit};
|
||||
for(int i = ${rank} -1 ; i >= 0; --i) {
|
||||
if(i > axis) continue;
|
||||
indices[i] += 1;
|
||||
if(indices[i] < shape[i]) {
|
||||
break;
|
||||
}
|
||||
indices[i] = 0;
|
||||
}
|
||||
}
|
||||
`;
|
||||
result[funcName] = new GlslLibRoutine(body);
|
||||
});
|
||||
return result;
|
||||
}
|
||||
}
|
||||
103
js/web/lib/onnxjs/backends/webgl/glsl-source.ts
Normal file
103
js/web/lib/onnxjs/backends/webgl/glsl-source.ts
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
/**
|
||||
* represent a version irrelevant abstraction of for GLSL source code
|
||||
*/
|
||||
export interface Glsl {
|
||||
readonly version: string;
|
||||
readonly attribute: string;
|
||||
readonly varyingVertex: string;
|
||||
readonly varyingFrag: string;
|
||||
readonly texture2D: string;
|
||||
readonly output: string;
|
||||
readonly outputDeclaration: string;
|
||||
}
|
||||
|
||||
const GLSL_ES_2_0: Glsl = {
|
||||
version: '',
|
||||
attribute: 'attribute',
|
||||
varyingVertex: 'varying',
|
||||
varyingFrag: 'varying',
|
||||
texture2D: 'texture2D',
|
||||
output: 'gl_FragColor',
|
||||
outputDeclaration: '',
|
||||
};
|
||||
const GLSL_ES_3_0: Glsl = {
|
||||
version: '#version 300 es',
|
||||
attribute: 'in',
|
||||
varyingVertex: 'out',
|
||||
varyingFrag: 'in',
|
||||
texture2D: 'texture',
|
||||
output: 'outputColor',
|
||||
outputDeclaration: 'out vec4 outputColor;',
|
||||
};
|
||||
|
||||
export function getGlsl(version: 1|2) {
|
||||
return version === 1 ? GLSL_ES_2_0 : GLSL_ES_3_0;
|
||||
}
|
||||
|
||||
export function getVertexShaderSource(version: 1|2): string {
|
||||
const glsl = getGlsl(version);
|
||||
return `${glsl.version}
|
||||
precision highp float;
|
||||
${glsl.attribute} vec3 position;
|
||||
${glsl.attribute} vec2 textureCoord;
|
||||
|
||||
${glsl.varyingVertex} vec2 TexCoords;
|
||||
|
||||
void main()
|
||||
{
|
||||
gl_Position = vec4(position, 1.0);
|
||||
TexCoords = textureCoord;
|
||||
}`;
|
||||
}
|
||||
|
||||
export function getFragShaderPreamble(version: 1|2): string {
|
||||
const glsl = getGlsl(version);
|
||||
return `${glsl.version}
|
||||
precision highp float;
|
||||
precision highp int;
|
||||
precision highp sampler2D;
|
||||
${glsl.varyingFrag} vec2 TexCoords;
|
||||
${glsl.outputDeclaration}
|
||||
const vec2 halfCR = vec2(0.5, 0.5);
|
||||
|
||||
// Custom vector types to handle higher dimenalities.
|
||||
struct ivec5
|
||||
{
|
||||
int x;
|
||||
int y;
|
||||
int z;
|
||||
int w;
|
||||
int u;
|
||||
};
|
||||
|
||||
struct ivec6
|
||||
{
|
||||
int x;
|
||||
int y;
|
||||
int z;
|
||||
int w;
|
||||
int u;
|
||||
int v;
|
||||
};
|
||||
|
||||
int imod(int x, int y) {
|
||||
return x - y * (x / y);
|
||||
}
|
||||
|
||||
`;
|
||||
}
|
||||
|
||||
export function getDefaultFragShaderMain(version: 1|2, outputShapeLength: number): string {
|
||||
const glsl = getGlsl(version);
|
||||
return `
|
||||
void main() {
|
||||
int indices[${outputShapeLength}];
|
||||
toVec(TexCoords, indices);
|
||||
vec4 result = vec4(process(indices));
|
||||
${glsl.output} = result;
|
||||
}
|
||||
`;
|
||||
}
|
||||
113
js/web/lib/onnxjs/backends/webgl/glsl-vec-lib.ts
Normal file
113
js/web/lib/onnxjs/backends/webgl/glsl-vec-lib.ts
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {GlslContext, GlslLib, GlslLibRoutine} from './glsl-definitions';
|
||||
|
||||
/**
|
||||
* GLSL Library responsible for vec routines
|
||||
* Vec is an varible length int array. The length is fixed at the time of
|
||||
* generating the library functions from the dimensions of the output.
|
||||
*/
|
||||
export class VecGlslLib extends GlslLib {
|
||||
constructor(context: GlslContext) {
|
||||
super(context);
|
||||
}
|
||||
getCustomTypes(): {[name: string]: string} {
|
||||
return {};
|
||||
}
|
||||
getFunctions(): {[name: string]: GlslLibRoutine} {
|
||||
return {...this.binaryVecFunctions(), ...this.copyVec(), ...this.setVecItem(), ...this.getVecItem()};
|
||||
}
|
||||
protected binaryVecFunctions(): {[name: string]: GlslLibRoutine} {
|
||||
const outputLayout = this.context.programInfo.outputLayout;
|
||||
const rank = outputLayout.shape.length;
|
||||
const nameOp: {[name: string]: string} = {add: '+=', sub: '-=', mul: '*=', div: '/='};
|
||||
const result: {[name: string]: GlslLibRoutine} = {};
|
||||
for (const name in nameOp) {
|
||||
const fname = `${name}Vec`;
|
||||
let assignmentBlock = '';
|
||||
for (let i = 0; i < rank; ++i) {
|
||||
assignmentBlock += `
|
||||
dest[${i}] ${nameOp[name]} src[${i}];
|
||||
`;
|
||||
}
|
||||
const body = `
|
||||
void ${fname}(int src[${rank}], out int dest[${rank}]) {
|
||||
${assignmentBlock}
|
||||
}
|
||||
`;
|
||||
result[fname] = new GlslLibRoutine(body);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
protected copyVec(): {[name: string]: GlslLibRoutine} {
|
||||
const outputLayout = this.context.programInfo.outputLayout;
|
||||
const rank = outputLayout.shape.length;
|
||||
let assignmentBlock = '';
|
||||
for (let i = 0; i < rank; ++i) {
|
||||
assignmentBlock += `
|
||||
dest[${i}] = src[${i}];
|
||||
`;
|
||||
}
|
||||
const body = `
|
||||
void copyVec(int src[${rank}], out int dest[${rank}]) {
|
||||
${assignmentBlock}
|
||||
}
|
||||
`;
|
||||
return {copyVec: new GlslLibRoutine(body)};
|
||||
}
|
||||
|
||||
protected setVecItem(): {[name: string]: GlslLibRoutine} {
|
||||
const outputLayout = this.context.programInfo.outputLayout;
|
||||
const rank = outputLayout.shape.length;
|
||||
let block = `
|
||||
if(index < 0)
|
||||
index =${rank} + index;
|
||||
if (index == 0)
|
||||
m[0] = value;
|
||||
`;
|
||||
for (let i = 1; i < rank - 1; ++i) {
|
||||
block += `
|
||||
else if (index == ${i})
|
||||
m[${i}] = value;
|
||||
`;
|
||||
}
|
||||
block += `
|
||||
else
|
||||
m[${rank - 1}] = value;
|
||||
`;
|
||||
const body = `
|
||||
void setVecItem(out int m[${rank}], int index, int value) {
|
||||
${block}
|
||||
}
|
||||
`;
|
||||
return {setVecItem: new GlslLibRoutine(body)};
|
||||
}
|
||||
protected getVecItem(): {[name: string]: GlslLibRoutine} {
|
||||
const outputLayout = this.context.programInfo.outputLayout;
|
||||
const rank = outputLayout.shape.length;
|
||||
let block = `
|
||||
if(index < 0)
|
||||
index = ${rank} + index;
|
||||
if (index == 0)
|
||||
return m[0];
|
||||
`;
|
||||
for (let i = 1; i < rank - 1; ++i) {
|
||||
block += `
|
||||
else if (index == ${i})
|
||||
return m[${i}];
|
||||
`;
|
||||
}
|
||||
block += `
|
||||
else
|
||||
return m[${rank - 1}];
|
||||
`;
|
||||
const body = `
|
||||
int getVecItem(int m[${rank}], int index) {
|
||||
${block}
|
||||
}
|
||||
`;
|
||||
return {getVecItem: new GlslLibRoutine(body)};
|
||||
}
|
||||
}
|
||||
283
js/web/lib/onnxjs/backends/webgl/inference-handler.ts
Normal file
283
js/web/lib/onnxjs/backends/webgl/inference-handler.ts
Normal file
|
|
@ -0,0 +1,283 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {InferenceHandler} from '../../backend';
|
||||
import {Logger} from '../../instrument';
|
||||
import {Tensor} from '../../tensor';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {WebGLPack} from './ops/pack';
|
||||
|
||||
import {WebGLUint8Encode} from './ops/uint8-encode';
|
||||
import {WebGLUnpack} from './ops/unpack';
|
||||
import {WebGLSessionHandler} from './session-handler';
|
||||
import {Encoder} from './texture-data-encoder';
|
||||
import {WidthHeightPrefs} from './texture-layout-strategy';
|
||||
import {Artifact, RunData, TextureData, TextureLayout, WebGLOperator} from './types';
|
||||
import {getPackedShape} from './utils';
|
||||
|
||||
export class WebGLInferenceHandler implements InferenceHandler {
|
||||
private textureDataCache: Map<Tensor.Id, TextureData>;
|
||||
constructor(public session: WebGLSessionHandler) {
|
||||
this.textureDataCache = new Map();
|
||||
}
|
||||
|
||||
run(op: WebGLOperator, inputs: Tensor[]): Tensor[] {
|
||||
let artifact = this.session.programManager.getArtifact(op);
|
||||
if (!artifact) {
|
||||
const programInfo = op.createProgramInfo(this, inputs);
|
||||
artifact = this.session.programManager.build(programInfo);
|
||||
this.session.programManager.setArtifact(op, artifact);
|
||||
}
|
||||
const runData = op.createRunData(this, artifact.programInfo, inputs);
|
||||
this.runProgram(artifact, runData);
|
||||
return [runData.outputTextureData.tensor];
|
||||
}
|
||||
|
||||
runProgram(artifact: Artifact, runData: RunData) {
|
||||
// pack/unpack inputs
|
||||
runData.inputTextureDatas.forEach(input => {
|
||||
if (input.isPacked && !artifact.programInfo.expectPackedInputs) {
|
||||
// unpack this input
|
||||
const unpacked = this.unpack(input);
|
||||
input.height = unpacked.height;
|
||||
input.isPacked = unpacked.isPacked;
|
||||
input.texture = unpacked.texture;
|
||||
input.width = unpacked.width;
|
||||
|
||||
} else if (!input.isPacked && artifact.programInfo.expectPackedInputs) {
|
||||
// pack this input
|
||||
const packed = this.pack(input);
|
||||
input.height = packed.height;
|
||||
input.isPacked = packed.isPacked;
|
||||
input.texture = packed.texture;
|
||||
input.width = packed.width;
|
||||
}
|
||||
});
|
||||
|
||||
// output should match
|
||||
if (!!runData.outputTextureData.isPacked !== !!artifact.programInfo.expectPackedoutputs) {
|
||||
throw new Error('output property packed inconsistent');
|
||||
}
|
||||
|
||||
this.session.programManager.run(artifact, runData);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a TextureData object from a tensor.
|
||||
* Usage = Encoder.Usage.UploadOnly.
|
||||
* If a related texture data is found in cache, returns it;
|
||||
* Otherwise:
|
||||
* Creates a new texture layout if not provided;
|
||||
* Creates WebGLTexture with the layout;
|
||||
* Upload tensor data to the texture;
|
||||
* Creates a texture data object associated with the given tensor.
|
||||
* @param tensor the tensor with data to upload
|
||||
*/
|
||||
getOrCreateTextureData(tensor: Tensor, layout?: TextureLayout) {
|
||||
let td = this.getTextureData(tensor.dataId);
|
||||
if (!td) {
|
||||
Logger.verbose('InferenceHandler', `Creating new TextureData for dims: [${tensor.dims}]`);
|
||||
if (!layout) {
|
||||
layout = this.createTextureLayoutFromShape(tensor.dims.slice());
|
||||
}
|
||||
// graph inputs or initializers
|
||||
td = this.createTextureData(layout, tensor.type, tensor.numberData, tensor, Encoder.Usage.UploadOnly);
|
||||
} else {
|
||||
Logger.verbose('InferenceHandler', `Retrieving TextureData from cache: [${tensor.dims}]`);
|
||||
}
|
||||
return td;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a TextureData object from the given data type and texture layout.
|
||||
* Usage = Encoder.Usage.Default.
|
||||
* @param dataType the tensor data type
|
||||
*/
|
||||
createTextureDataFromLayout(layout: TextureLayout, dataType: Tensor.DataType): TextureData {
|
||||
return this.createTextureData(layout, dataType);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a TextureData object using the given data and bind to the given tensor.
|
||||
* Usage = Encoder.Usage.UploadOnly.
|
||||
* NOTE: this function is a hack for Conv implementation. should remove this function, after rewriting Conv
|
||||
* implementation by Graph.Transformer
|
||||
* @param dataType the tensor data type
|
||||
* @param data the actual data to upload
|
||||
* @param tensor the tensor to bind. tensor's data is ignored.
|
||||
*/
|
||||
createTextureDataFromLayoutBindTensor(
|
||||
layout: TextureLayout, dataType: Tensor.DataType, data: Tensor.NumberType, tensor: Tensor): TextureData {
|
||||
return this.createTextureData(layout, dataType, data, tensor, Encoder.Usage.UploadOnly);
|
||||
}
|
||||
|
||||
private createTextureData(
|
||||
layout: TextureLayout, dataType: Tensor.DataType, data?: Tensor.NumberType, tensor?: Tensor,
|
||||
usage?: Encoder.Usage): TextureData {
|
||||
Logger.verbose('InferenceHandler', `Creating TextureData: layout:[${JSON.stringify(layout)}]`);
|
||||
const texture = this.session.textureManager.createTextureFromLayout(dataType, layout, data, usage);
|
||||
return this.createTextureDataFromTexture(layout, dataType, texture, tensor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a TextureData object, using the given texture.
|
||||
* This function does not create new texture. Usually used in scenarios using texture sharing. (eg. Reshape)
|
||||
* @param dataType the tensor data type
|
||||
* @param texture the WebGLTexture object to share
|
||||
* @param tensorId the tensor ID of the shared tensor data
|
||||
*/
|
||||
createSharedTextureData(layout: TextureLayout, dataType: Tensor.DataType, texture: WebGLTexture, tensorId: Tensor.Id):
|
||||
TextureData {
|
||||
return this.createTextureDataFromTexture(layout, dataType, texture, undefined, tensorId);
|
||||
}
|
||||
|
||||
private createTextureDataFromTexture(
|
||||
layout: TextureLayout, dataType: Tensor.DataType, texture: WebGLTexture, tensor?: Tensor, tensorId?: Tensor.Id) {
|
||||
const textureData: TextureData = {
|
||||
...layout,
|
||||
tensor: tensor ||
|
||||
new Tensor(
|
||||
layout.unpackedShape, dataType,
|
||||
(_id: Tensor.Id) => {
|
||||
const data = this.readTexture(textureData);
|
||||
if (dataType === 'float32') {
|
||||
return Float32Array.from(data);
|
||||
} else if (dataType === 'bool') {
|
||||
return Uint8Array.from(data);
|
||||
}
|
||||
return data;
|
||||
},
|
||||
undefined, undefined, tensorId),
|
||||
texture
|
||||
};
|
||||
this.setTextureData(textureData.tensor.dataId, textureData);
|
||||
return textureData;
|
||||
}
|
||||
|
||||
getTextureData(tensorId: Tensor.Id): TextureData|undefined {
|
||||
return this.session.isInitializer(tensorId) ? this.session.getTextureData(tensorId) :
|
||||
this.textureDataCache.get(tensorId);
|
||||
}
|
||||
setTextureData(tensorId: Tensor.Id, td: TextureData): void {
|
||||
if (this.session.isInitializer(tensorId)) {
|
||||
this.session.setTextureData(tensorId, td);
|
||||
} else {
|
||||
this.textureDataCache.set(tensorId, td);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a TextureLayout object from a tensor. If a related texture data is found, returns the cached texture layout.
|
||||
*/
|
||||
getOrCreateTextureLayout(
|
||||
tensor: Tensor, channels: 1|4 = 1, isPacked = false, unpackedShape?: readonly number[],
|
||||
reverseWH = false): TextureLayout {
|
||||
const td = this.getTextureData(tensor.dataId);
|
||||
if (td) {
|
||||
return td;
|
||||
}
|
||||
return this.createTextureLayoutFromShape(
|
||||
channels === 1 || isPacked ? tensor.dims : getPackedShape(tensor.dims), channels, unpackedShape,
|
||||
isPacked || reverseWH ? {isPacked, reverseWH} : undefined);
|
||||
}
|
||||
/**
|
||||
* Create a TextureLayout object from shape.
|
||||
*/
|
||||
createTextureLayoutFromShape(
|
||||
shape: readonly number[], channels: 1|4 = 1, unpackedShape?: readonly number[],
|
||||
prefs?: WidthHeightPrefs): TextureLayout {
|
||||
const isPacked = !!(prefs && prefs.isPacked);
|
||||
const [texWidth, texHeight] =
|
||||
this.session.layoutStrategy.computeTextureWH(isPacked ? unpackedShape || shape : shape, prefs);
|
||||
let [width, height] = [texWidth, texHeight];
|
||||
if (prefs && prefs.reverseWH) {
|
||||
width = texHeight;
|
||||
height = texWidth;
|
||||
}
|
||||
const rank = shape.length;
|
||||
let inferredDims = shape.slice(0);
|
||||
if (rank === 0) {
|
||||
inferredDims = [1];
|
||||
}
|
||||
if (channels === 1) {
|
||||
// unpackedShape will take `shape` and not `inferredDims` so as to create a scalar Tensor if need be
|
||||
unpackedShape = shape;
|
||||
} else if (isPacked) {
|
||||
if (channels !== 4) {
|
||||
throw new Error('a packed texture must be 4-channel');
|
||||
}
|
||||
unpackedShape = shape;
|
||||
if (rank > 0) {
|
||||
inferredDims[rank - 1] = Math.ceil(inferredDims[rank - 1] / 2);
|
||||
}
|
||||
if (rank > 1) {
|
||||
inferredDims[rank - 2] = Math.ceil(inferredDims[rank - 2] / 2);
|
||||
}
|
||||
} else if (!unpackedShape) {
|
||||
throw new Error('Unpacked shape is needed when using channels > 1');
|
||||
}
|
||||
return {
|
||||
width,
|
||||
height,
|
||||
channels,
|
||||
isPacked,
|
||||
shape: inferredDims,
|
||||
strides: ShapeUtil.computeStrides(inferredDims),
|
||||
unpackedShape
|
||||
};
|
||||
}
|
||||
|
||||
dispose(): void {
|
||||
this.session.textureManager.clearActiveTextures();
|
||||
this.textureDataCache.forEach(td => this.session.textureManager.releaseTexture(td));
|
||||
this.textureDataCache = new Map();
|
||||
}
|
||||
|
||||
readTexture(textureData: TextureData): Tensor.NumberType {
|
||||
if (textureData.isPacked) {
|
||||
return this.readTexture(this.unpack(textureData));
|
||||
}
|
||||
if (!this.session.backend.glContext.isFloat32DownloadSupported) {
|
||||
const op = new WebGLUint8Encode();
|
||||
const uint8TD = op.runInternal(this, textureData);
|
||||
return this.session.textureManager.readUint8TextureAsFloat(uint8TD);
|
||||
}
|
||||
return this.session.textureManager.readTexture(textureData, textureData.tensor.type, textureData.channels);
|
||||
}
|
||||
|
||||
pack(input: TextureData): TextureData {
|
||||
const key = `${input.shape}`;
|
||||
let op = this.session.packOpCache.get(key);
|
||||
if (!op) {
|
||||
op = new WebGLPack();
|
||||
this.session.packOpCache.set(key, op);
|
||||
}
|
||||
let artifact = this.session.programManager.getArtifact(op);
|
||||
if (!artifact) {
|
||||
const programInfo = op.createProgramInfo(this, [input.tensor]);
|
||||
artifact = this.session.programManager.build(programInfo);
|
||||
this.session.programManager.setArtifact(op, artifact);
|
||||
}
|
||||
const runData = op.createRunData(this, artifact.programInfo, [input.tensor]);
|
||||
this.runProgram(artifact, runData);
|
||||
return runData.outputTextureData;
|
||||
}
|
||||
|
||||
unpack(input: TextureData): TextureData {
|
||||
const key = `${input.shape}`;
|
||||
let op = this.session.unpackOpCache.get(key);
|
||||
if (!op) {
|
||||
op = new WebGLUnpack();
|
||||
this.session.unpackOpCache.set(key, op);
|
||||
}
|
||||
let artifact = this.session.programManager.getArtifact(op);
|
||||
if (!artifact) {
|
||||
const programInfo = op.createProgramInfo(this, [input.tensor]);
|
||||
artifact = this.session.programManager.build(programInfo);
|
||||
this.session.programManager.setArtifact(op, artifact);
|
||||
}
|
||||
const runData = op.createRunData(this, artifact.programInfo, [input.tensor]);
|
||||
this.runProgram(artifact, runData);
|
||||
return runData.outputTextureData;
|
||||
}
|
||||
}
|
||||
108
js/web/lib/onnxjs/backends/webgl/op-resolve-rules.ts
Normal file
108
js/web/lib/onnxjs/backends/webgl/op-resolve-rules.ts
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {FLOAT_TYPES, NUMBER_TYPES} from '../../operators';
|
||||
import {OpSet} from '../../opset';
|
||||
|
||||
import {WebGLBatchNormalization} from './ops/batch-normalization';
|
||||
import * as binaryOps from './ops/binary-op';
|
||||
import {WebGLClip} from './ops/clip';
|
||||
import {WebGLConcat} from './ops/concat';
|
||||
import {WebGLConv} from './ops/conv';
|
||||
import {WebGLDropout} from './ops/dropout';
|
||||
import {WebGLElu} from './ops/elu';
|
||||
import {WebGLFlatten} from './ops/flatten';
|
||||
import {WebGLGather} from './ops/gather';
|
||||
import {WebGLGemm} from './ops/gemm';
|
||||
import {WebGLImageScaler} from './ops/image-scaler';
|
||||
import {WebGLInstanceNormalization} from './ops/instance-normalization';
|
||||
import {WebGLLeakyRelu} from './ops/leaky-relu';
|
||||
import {WebGLMatMul} from './ops/matmul';
|
||||
import {WebGLPad} from './ops/pad';
|
||||
import {WebGLAveragePool, WebGLGlobalAveragePool, WebGLGlobalMaxPool, WebGLMaxPool} from './ops/pool';
|
||||
import * as reduceOps from './ops/reduce';
|
||||
import {WebGLReshape} from './ops/reshape';
|
||||
import {WebGLSlice, WebGLSliceV10} from './ops/slice';
|
||||
import {WebGLSoftmax} from './ops/softmax';
|
||||
import {WebGLSplit} from './ops/split';
|
||||
import {WebGLSqueeze} from './ops/squeeze';
|
||||
import {WebGLSum} from './ops/sum';
|
||||
import {WebGLTile} from './ops/tile';
|
||||
import {WebGLTranspose} from './ops/transpose';
|
||||
import * as unaryOps from './ops/unary-op';
|
||||
import {WebGLUnsqueeze} from './ops/unsqueeze';
|
||||
import {WebGLUpsample} from './ops/upsample';
|
||||
|
||||
export const WEBGL_OP_RESOLVE_RULES: readonly OpSet.ResolveRule[] = [
|
||||
['Abs', '', '6+', () => new unaryOps.WebGLUnaryOp(NUMBER_TYPES, unaryOps.glslAbs())],
|
||||
['Acos', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslAcos())],
|
||||
['Add', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslAdd())],
|
||||
['And', '', '7+', () => new binaryOps.WebGLBinaryOp(['bool'], binaryOps.glslAnd())],
|
||||
['Asin', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslAsin())],
|
||||
['Atan', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslAtan())],
|
||||
['AveragePool', '', '7-10', () => new WebGLAveragePool()], // TODO: support new attributes for AveragePool-10
|
||||
['BatchNormalization', '', '7+', () => new WebGLBatchNormalization()],
|
||||
['Ceil', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslCeil())],
|
||||
['Clip', '', '6-10', () => new WebGLClip()],
|
||||
['Concat', '', '4+', () => new WebGLConcat()],
|
||||
['Conv', '', '1+', () => new WebGLConv()],
|
||||
['Cos', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslCos())],
|
||||
['Div', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslDiv())],
|
||||
['Dropout', '', '7+', () => new WebGLDropout()],
|
||||
['Equal', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslEqual(), undefined, 'bool')],
|
||||
['Elu', '', '6+', () => new WebGLElu()],
|
||||
['Exp', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslExp())],
|
||||
['Flatten', '', '1+', () => new WebGLFlatten()],
|
||||
['Floor', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslFloor())],
|
||||
['Gather', '', '1+', () => new WebGLGather()],
|
||||
['Gemm', '', '7-10', () => new WebGLGemm(false)],
|
||||
['Gemm', '', '11+', () => new WebGLGemm(true)],
|
||||
['GlobalAveragePool', '', '1+', () => new WebGLGlobalAveragePool()],
|
||||
['GlobalMaxPool', '', '1+', () => new WebGLGlobalMaxPool()],
|
||||
['Greater', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslGreater(), undefined, 'bool')],
|
||||
['Identity', '', '1+', () => new unaryOps.WebGLUnaryOp(NUMBER_TYPES, unaryOps.glslIdentity())],
|
||||
['ImageScaler', '', '1+', () => new WebGLImageScaler()],
|
||||
['InstanceNormalization', '', '6+', () => new WebGLInstanceNormalization()],
|
||||
['LeakyRelu', '', '6+', () => new WebGLLeakyRelu()],
|
||||
['Less', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslLess(), undefined, 'bool')],
|
||||
['Log', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslLog())],
|
||||
['MatMul', '', '1+', () => new WebGLMatMul()],
|
||||
['MaxPool', '', '1-9', () => new WebGLMaxPool()], // TODO: support new attributes for MaxPool-8 and MaxPool-10
|
||||
['Mul', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslMul())],
|
||||
['Neg', '', '6+', () => new unaryOps.WebGLUnaryOp(NUMBER_TYPES, unaryOps.glslNeg())],
|
||||
['Not', '', '1+', () => new unaryOps.WebGLUnaryOp(['bool'], unaryOps.glslNot())],
|
||||
['Or', '', '7+', () => new binaryOps.WebGLBinaryOp(['bool'], binaryOps.glslOr())],
|
||||
['Pad', '', '2-10', () => new WebGLPad()],
|
||||
['Pow', '', '7+', () => new binaryOps.WebGLBinaryOp(FLOAT_TYPES, binaryOps.glslPow())],
|
||||
['PRelu', '', '7+', () => new binaryOps.WebGLBinaryOp(FLOAT_TYPES, binaryOps.glslPRelu())],
|
||||
['ReduceLogSum', '', '1+', () => new reduceOps.WebGLReduceLogSum()],
|
||||
['ReduceMax', '', '1+', () => new reduceOps.WebGLReduceMax()],
|
||||
['ReduceMean', '', '1+', () => new reduceOps.WebGLReduceMean()],
|
||||
['ReduceMin', '', '1+', () => new reduceOps.WebGLReduceMin()],
|
||||
['ReduceProd', '', '1+', () => new reduceOps.WebGLReduceProd()],
|
||||
['ReduceSum', '', '1+', () => new reduceOps.WebGLReduceSum()],
|
||||
['ReduceSumSquare', '', '1+', () => new reduceOps.WebGLReduceSumSquare()],
|
||||
['Relu', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslRelu())],
|
||||
['Reshape', '', '5+', () => new WebGLReshape()],
|
||||
['Sigmoid', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSigmoid())],
|
||||
['Sin', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSin())],
|
||||
['Slice', '', '10+', () => new WebGLSliceV10()], // TODO: support 'steps' for Slice-10
|
||||
['Slice', '', '1-9', () => new WebGLSlice()],
|
||||
['Softmax', '', '1+', () => new WebGLSoftmax()],
|
||||
// 'Split' operator has an optional attribute 'split'
|
||||
// this attribute determines how the specified axis of input data
|
||||
// is split. When the attribute is missing, we need the count of number of outputs
|
||||
// so that we can determine the 'split' attribute from the runtime input to the Operator
|
||||
['Split', '', '2+', (node) => new WebGLSplit(node.outputs.length)],
|
||||
['Sqrt', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslSqrt())],
|
||||
['Squeeze', '', '1+', () => new WebGLSqueeze()],
|
||||
['Sub', '', '7+', () => new binaryOps.WebGLBinaryOp(NUMBER_TYPES, binaryOps.glslSub())],
|
||||
['Sum', '', '6+', () => new WebGLSum()], // TODO: support multidirectional broadcast for Sum-8
|
||||
['Tan', '', '7+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslTan())],
|
||||
['Tanh', '', '6+', () => new unaryOps.WebGLUnaryOp(FLOAT_TYPES, unaryOps.glslTanh())],
|
||||
['Tile', '', '6+', () => new WebGLTile()],
|
||||
['Transpose', '', '1+', () => new WebGLTranspose()],
|
||||
['Upsample', '', '7-8', () => new WebGLUpsample()],
|
||||
['Unsqueeze', '', '1+', () => new WebGLUnsqueeze()],
|
||||
['Xor', '', '7+', () => new binaryOps.WebGLBinaryOp(['bool'], binaryOps.glslXor())],
|
||||
];
|
||||
43
js/web/lib/onnxjs/backends/webgl/ops/batch-normalization.ts
Normal file
43
js/web/lib/onnxjs/backends/webgl/ops/batch-normalization.ts
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {BatchNormalization} from '../../../ops/batch-normalization';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {getGlsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData} from '../types';
|
||||
|
||||
export class WebGLBatchNormalization extends BatchNormalization {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const inputLayouts = inputs.map(t => handler.getOrCreateTextureLayout(t));
|
||||
const outputShape = inputs[0].dims.slice();
|
||||
const rank = outputShape.length;
|
||||
const scale = inputLayouts[1];
|
||||
const glsl = getGlsl(handler.session.backend.glContext.version);
|
||||
const shaderSource = `
|
||||
float process(int[${rank}] indices) {
|
||||
vec2 position = offsetToCoords(indices[1], ${scale.width}, ${scale.height});
|
||||
float scale = getColorAsFloat(${glsl.texture2D}(Scale, position));
|
||||
float mean = getColorAsFloat(${glsl.texture2D}(Mean, position));
|
||||
float variance = getColorAsFloat(${glsl.texture2D}(Variance, position));
|
||||
float b = getColorAsFloat(${glsl.texture2D}(B, position));
|
||||
|
||||
return scale * ( (_A(indices) - mean) / sqrt(variance + float(${this.epsilon})) ) + b;
|
||||
}`;
|
||||
return {
|
||||
inputLayouts,
|
||||
outputLayout: handler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['A', 'Scale', 'B', 'Mean', 'Variance'],
|
||||
shaderSource
|
||||
};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = [handler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
|
||||
inputs.slice(1).forEach(t => inputTDs.push(handler.getOrCreateTextureData(t)));
|
||||
const outputTD = handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type);
|
||||
return {inputTextureDatas: inputTDs, outputTextureData: outputTD, uniformData: {}};
|
||||
}
|
||||
}
|
||||
252
js/web/lib/onnxjs/backends/webgl/ops/binary-op.ts
Normal file
252
js/web/lib/onnxjs/backends/webgl/ops/binary-op.ts
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {BinaryOp} from '../../../ops/binary-op';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {BroadcastUtil, ShapeUtil} from '../../../util';
|
||||
import {FunctionType, GlslValueFunction} from '../glsl-definitions';
|
||||
import {getGlsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLBinaryOp extends BinaryOp implements WebGLOperator {
|
||||
constructor(
|
||||
typeConstraint: readonly Tensor.DataType[], protected glslFunc: GlslValueFunction, opType?: string,
|
||||
resultType?: Tensor.DataType) {
|
||||
super(typeConstraint, opType, resultType);
|
||||
}
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const inputLayouts = inputs.map(t => handler.getOrCreateTextureLayout(t));
|
||||
const isBroadcast = !ShapeUtil.areEqual(inputs[0].dims, inputs[1].dims);
|
||||
if (isBroadcast) {
|
||||
const outputShape = BroadcastUtil.calcShape(inputs[0].dims, inputs[1].dims, false);
|
||||
if (!outputShape) {
|
||||
throw new Error('Can\'t perform binary op on the given tensors');
|
||||
}
|
||||
const outputRank = outputShape.length;
|
||||
const aRank = inputs[0].dims.length !== 0 ? inputs[0].dims.length : 1;
|
||||
const bRank = inputs[1].dims.length !== 0 ? inputs[1].dims.length : 1;
|
||||
const aBcast = inputs[0].dims.length !== 0 ? 'bcastIndices_A(indices, aindices);' : 'aindices[0] = 0;';
|
||||
const bBcast = inputs[1].dims.length !== 0 ? 'bcastIndices_B(indices, bindices);' : 'bindices[0] = 0;';
|
||||
const shaderSource = `
|
||||
${this.glslFunc.body}
|
||||
float process(int indices[${outputRank}]) {
|
||||
int aindices[${aRank}];
|
||||
int bindices[${bRank}];
|
||||
${aBcast}
|
||||
${bBcast}
|
||||
return ${this.glslFunc.name}(_A(aindices), _B(bindices));
|
||||
}`;
|
||||
return {
|
||||
inputLayouts,
|
||||
outputLayout: handler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['A', 'B'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
const glsl = getGlsl(handler.session.backend.glContext.version);
|
||||
const shaderSource = `
|
||||
${this.glslFunc.body}
|
||||
void main() {
|
||||
vec4 v1 = ${glsl.texture2D}(A, TexCoords);
|
||||
vec4 v2 = ${glsl.texture2D}(B, TexCoords);
|
||||
vec4 result = ${this.glslFunc.name}(v1, v2);
|
||||
${glsl.output} = result;
|
||||
}
|
||||
`;
|
||||
return {
|
||||
hasMain: true,
|
||||
inputLayouts,
|
||||
outputLayout: handler.createTextureLayoutFromShape(inputs[0].dims),
|
||||
samplers: ['A', 'B'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = inputs.map((t, i) => handler.getOrCreateTextureData(t, programInfo.inputLayouts[i]));
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(
|
||||
programInfo.outputLayout, this.resultType ? this.resultType : inputs[0].type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export function glslAdd(): GlslValueFunction {
|
||||
const name = 'add_';
|
||||
const body = `
|
||||
float ${name}(float a, float b) {
|
||||
return a + b;
|
||||
}
|
||||
vec4 ${name}(vec4 v1, vec4 v2) {
|
||||
return v1 + v2;
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
export function glslDiv(): GlslValueFunction {
|
||||
const name = 'div_';
|
||||
const body = `
|
||||
float ${name}(float a, float b) {
|
||||
return a / b;
|
||||
}
|
||||
vec4 ${name}(vec4 v1, vec4 v2) {
|
||||
return v1 / v2;
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
export function glslMul(): GlslValueFunction {
|
||||
const name = 'mul_';
|
||||
const body = `
|
||||
float ${name}(float a, float b) {
|
||||
return a * b;
|
||||
}
|
||||
vec4 ${name}(vec4 v1, vec4 v2) {
|
||||
return v1 * v2;
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
export function glslSub(): GlslValueFunction {
|
||||
const name = 'sub_';
|
||||
const body = `
|
||||
float ${name}(float a, float b) {
|
||||
return a - b;
|
||||
}
|
||||
vec4 ${name}(vec4 v1, vec4 v2) {
|
||||
return v1 - v2;
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
export function glslEqual(): GlslValueFunction {
|
||||
const name = 'equal_';
|
||||
const body = `
|
||||
float ${name}(float a, float b) {
|
||||
return float(a == b);
|
||||
}
|
||||
vec4 ${name}(vec4 v1, vec4 v2) {
|
||||
return vec4( v1 == v2 );
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
export function glslGreater(): GlslValueFunction {
|
||||
const name = 'greater_';
|
||||
const body = `
|
||||
float ${name}(float a, float b) {
|
||||
return float(a > b);
|
||||
}
|
||||
vec4 ${name}(vec4 v1, vec4 v2) {
|
||||
return vec4( v1.r > v2.r ,
|
||||
v1.g > v2.g,
|
||||
v1.b > v2.b,
|
||||
v1.a > v2.a );
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
export function glslLess(): GlslValueFunction {
|
||||
const name = 'less_';
|
||||
const body = `
|
||||
float ${name}(float a, float b) {
|
||||
return float(a < b);
|
||||
}
|
||||
vec4 ${name}(vec4 v1, vec4 v2) {
|
||||
return vec4( v1.r < v2.r ,
|
||||
v1.g < v2.g,
|
||||
v1.b < v2.b,
|
||||
v1.a < v2.a );
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
export function glslAnd(): GlslValueFunction {
|
||||
const name = 'and_';
|
||||
const body = `
|
||||
float ${name}(float a, float b) {
|
||||
return float( bool(a) && bool(b) );
|
||||
}
|
||||
vec4 ${name}(vec4 v1, vec4 v2) {
|
||||
bvec4 b1 = bvec4(v1);
|
||||
bvec4 b2 = bvec4(v2);
|
||||
return vec4( b1.r && b2.r ,
|
||||
b1.g && b2.g,
|
||||
b1.b && b2.b,
|
||||
b1.a && b2.a );
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
export function glslOr(): GlslValueFunction {
|
||||
const name = 'or_';
|
||||
const body = `
|
||||
float ${name}(float a, float b) {
|
||||
return float( bool(a) || bool(b) );
|
||||
}
|
||||
vec4 ${name}(vec4 v1, vec4 v2) {
|
||||
bvec4 b1 = bvec4(v1);
|
||||
bvec4 b2 = bvec4(v2);
|
||||
return vec4( b1.r || b2.r ,
|
||||
b1.g || b2.g,
|
||||
b1.b || b2.b,
|
||||
b1.a || b2.a );
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
export function glslXor(): GlslValueFunction {
|
||||
const name = 'xor_';
|
||||
const body = `
|
||||
float ${name}(float a, float b) {
|
||||
return float( bool(a) ^^ bool(b) );
|
||||
}
|
||||
vec4 ${name}(vec4 v1, vec4 v2) {
|
||||
bvec4 b1 = bvec4(v1);
|
||||
bvec4 b2 = bvec4(v2);
|
||||
return vec4( b1.r ^^ b2.r ,
|
||||
b1.g ^^ b2.g,
|
||||
b1.b ^^ b2.b,
|
||||
b1.a ^^ b2.a );
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
export function glslPow(): GlslValueFunction {
|
||||
return glslBuiltinBinary('pow');
|
||||
}
|
||||
export function glslPRelu(): GlslValueFunction {
|
||||
const name = 'prelu_';
|
||||
const body = `
|
||||
float ${name}(float a, float b) {
|
||||
return a < 0.0 ? a * b: a;
|
||||
}
|
||||
vec4 ${name}(vec4 v1, vec4 v2) {
|
||||
return vec4(
|
||||
v1.r < 0.0 ? v1.r * v2.r: v1.r,
|
||||
v1.g < 0.0 ? v1.g * v2.g: v1.g,
|
||||
v1.b < 0.0 ? v1.b * v2.b: v1.b,
|
||||
v1.a < 0.0 ? v1.a * v2.a: v1.a
|
||||
);
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
|
||||
function glslBuiltinBinary(fname: string): GlslValueFunction {
|
||||
const name = `${fname}_`;
|
||||
const body = `
|
||||
float ${name}(float a, float b) {
|
||||
return ${fname}(a, b);
|
||||
}
|
||||
vec4 ${name}(vec4 v1, vec4 v2) {
|
||||
return ${fname}(v1, v2);
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
41
js/web/lib/onnxjs/backends/webgl/ops/clip.ts
Normal file
41
js/web/lib/onnxjs/backends/webgl/ops/clip.ts
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Clip} from '../../../ops/clip';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {getGlsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLClip extends Clip implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const outputShape = inputs[0].dims.slice();
|
||||
const glsl = getGlsl(handler.session.backend.glContext.version);
|
||||
const shaderSource = `
|
||||
const float min = float(${this.min});
|
||||
const float max = float(${this.max});
|
||||
void main() {
|
||||
float v = ${glsl.texture2D}(A, TexCoords).r;
|
||||
${glsl.output} = vec4(clamp(v, min, max));
|
||||
}
|
||||
`;
|
||||
return {
|
||||
inputLayouts: [handler.getOrCreateTextureLayout(inputs[0])],
|
||||
outputLayout: handler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['A'],
|
||||
shaderSource,
|
||||
hasMain: true,
|
||||
};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = [handler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
149
js/web/lib/onnxjs/backends/webgl/ops/concat.ts
Normal file
149
js/web/lib/onnxjs/backends/webgl/ops/concat.ts
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Concat} from '../../../ops/concat';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLConcat extends Concat implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const inputShape = inputs[0].dims.slice();
|
||||
if (this.axis >= inputShape.length || this.axis < (-1 * inputShape.length)) {
|
||||
throw new Error('axis specified for concat doesn\'t match input dimensionality');
|
||||
}
|
||||
if (this.axis < 0) {
|
||||
this.axis = inputShape.length + this.axis;
|
||||
}
|
||||
// ensure all of the non-concatenated axes match each other
|
||||
// calculate the shape of the output tensor while we do that
|
||||
const outputShape = inputShape.slice(0);
|
||||
for (let i = 1; i < inputs.length; i++) {
|
||||
const dataNShape = inputs[i].dims.slice();
|
||||
for (let axisIndex = 0; axisIndex < inputShape.length; axisIndex++) {
|
||||
// add to the placeholder for computing output shape
|
||||
if (axisIndex === this.axis) {
|
||||
outputShape[this.axis] += dataNShape[axisIndex];
|
||||
}
|
||||
// ensure all non-cancatenated axes match each other
|
||||
else if (inputShape[axisIndex] !== dataNShape[axisIndex]) {
|
||||
throw new Error('non concat dimensions must match');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const rank = outputShape.length;
|
||||
|
||||
let getTextureIndexWhereDataResidesMethod = '';
|
||||
// in most cases linear search is sufficient, as in most scenarios, only 2 tensors are concatenated
|
||||
if (inputs.length < 5) {
|
||||
getTextureIndexWhereDataResidesMethod = this.getTextureIndexWhereDataResidesLinearSearch(inputs.length);
|
||||
} else {
|
||||
getTextureIndexWhereDataResidesMethod = this.getTextureIndexWhereDataResidesBinarySearch(inputs.length);
|
||||
}
|
||||
|
||||
const fetchDataFromCorrectTextureMethod = this.fetchDataFromCorrectTextureMethod(inputs.length, rank);
|
||||
const getValueFromArrayIndexMethod = this.getValueFromArrayIndexMethod(inputs.length);
|
||||
const samplers = inputs.map((v, i) => `X${i}`);
|
||||
const shaderSource = `
|
||||
${fetchDataFromCorrectTextureMethod}
|
||||
${getValueFromArrayIndexMethod}
|
||||
${getTextureIndexWhereDataResidesMethod}
|
||||
float process(int indices[${rank}]) {
|
||||
int textureIndex = getTextureWhereDataResides (indices[${this.axis}]);
|
||||
|
||||
if(textureIndex != 0) {
|
||||
indices[${this.axis}] = indices[${
|
||||
this.axis}] - int(getValueFromArrayIndex(sizeInConcatAxis, textureIndex-int(1)));
|
||||
}
|
||||
|
||||
return fetchDataFromCorrectTexture(textureIndex, indices);
|
||||
}`;
|
||||
return {
|
||||
inputLayouts: inputs.map(t => handler.getOrCreateTextureLayout(t)),
|
||||
outputLayout: handler.createTextureLayoutFromShape(outputShape),
|
||||
samplers,
|
||||
variables: [{name: 'sizeInConcatAxis', type: 'int', arrayLength: inputs.length}],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = inputs.map((t, i) => handler.getOrCreateTextureData(t, programInfo.inputLayouts[i]));
|
||||
const sizeInConcatAxis = new Array<number>(programInfo.inputLayouts.length);
|
||||
let previousSum = 0;
|
||||
for (let i = 0; i < programInfo.inputLayouts.length; ++i) {
|
||||
previousSum += programInfo.inputLayouts[i].shape[this.axis];
|
||||
sizeInConcatAxis[i] = previousSum;
|
||||
}
|
||||
const uniformData = {sizeInConcatAxis};
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData
|
||||
};
|
||||
}
|
||||
private getTextureIndexWhereDataResidesLinearSearch(numberOfTensors: number): string {
|
||||
return `int getTextureWhereDataResides(int index) {
|
||||
for(int i=0; i<${numberOfTensors}; i++) {
|
||||
if(index < int(sizeInConcatAxis[i])){
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}`;
|
||||
}
|
||||
|
||||
// TODO: Implement BinarySearch in GLSL
|
||||
private getTextureIndexWhereDataResidesBinarySearch(numberOfTensors: number): string {
|
||||
return this.getTextureIndexWhereDataResidesLinearSearch(numberOfTensors);
|
||||
}
|
||||
|
||||
private fetchDataFromCorrectTextureMethod(numberOfTensors: number, tensorRank: number) {
|
||||
const codeLines: string[] = [`float fetchDataFromCorrectTexture(int textureIndex, int indices[${tensorRank}]) {`];
|
||||
for (let i = 0; i < numberOfTensors; ++i) {
|
||||
if (i === 0) {
|
||||
codeLines.push(
|
||||
'\t' +
|
||||
`if (textureIndex == ${i}) { return _X${i}(indices); }`);
|
||||
} else if (i === numberOfTensors - 1) {
|
||||
codeLines.push(
|
||||
'\t' +
|
||||
`else { return _X${i}(indices); }`);
|
||||
} else {
|
||||
codeLines.push(
|
||||
'\t' +
|
||||
`else if (textureIndex == ${i}) { return _X${i}(indices); }`);
|
||||
}
|
||||
}
|
||||
codeLines.push(
|
||||
'\t' +
|
||||
'}');
|
||||
return codeLines.join('\n');
|
||||
}
|
||||
|
||||
private getValueFromArrayIndexMethod(arrayRank: number): string {
|
||||
const codeLines: string[] = [`int getValueFromArrayIndex(int arr[${arrayRank}], int index) {`];
|
||||
for (let i = 0; i < arrayRank; ++i) {
|
||||
if (i === 0) {
|
||||
codeLines.push(
|
||||
'\t' +
|
||||
`if (index == ${i}) { return arr[${i}]; }`);
|
||||
} else if (i === arrayRank - 1) {
|
||||
codeLines.push(
|
||||
'\t' +
|
||||
`else { return arr[${i}]; }`);
|
||||
} else {
|
||||
codeLines.push(
|
||||
'\t' +
|
||||
`else if (index == ${i}) { return arr[${i}]; }`);
|
||||
}
|
||||
}
|
||||
codeLines.push(
|
||||
'\t' +
|
||||
'}');
|
||||
|
||||
return codeLines.join('\n');
|
||||
}
|
||||
}
|
||||
284
js/web/lib/onnxjs/backends/webgl/ops/conv.ts
Normal file
284
js/web/lib/onnxjs/backends/webgl/ops/conv.ts
Normal file
|
|
@ -0,0 +1,284 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Logger} from '../../../instrument';
|
||||
import {Conv} from '../../../ops/conv';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {PoolConvUtil} from '../../../util';
|
||||
import {getGlsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {Artifact, ProgramInfo, RunData, TextureLayout} from '../types';
|
||||
import {WebGLContext} from '../webgl-context';
|
||||
|
||||
export class WebGLConv extends Conv {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
const programManager = inferenceHandler.session.programManager;
|
||||
if (!this.artifacts) {
|
||||
this.artifacts = [];
|
||||
const programInfos = this.createProgramInfos(inferenceHandler, inputs);
|
||||
for (let i = 0; i < programInfos.length; ++i) {
|
||||
const artifact = inferenceHandler.session.programManager.build(programInfos[i]);
|
||||
this.artifacts.push(artifact);
|
||||
}
|
||||
}
|
||||
const runDatas = this.createRunDatas(inferenceHandler, this.artifacts.map(a => a.programInfo), inputs);
|
||||
programManager.run(this.artifacts[0], runDatas[0]);
|
||||
programManager.run(this.artifacts[1], runDatas[1]);
|
||||
return [runDatas[1].outputTextureData.tensor];
|
||||
}
|
||||
createProgramInfos(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo[] {
|
||||
const xshape = inputs[0].dims.slice();
|
||||
const kshape = inputs[1].dims.slice();
|
||||
// if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims
|
||||
if (this.kernelShape.length === 0) {
|
||||
const wDims = inputs[1].dims;
|
||||
for (let i = 2; i < wDims.length; ++i) {
|
||||
this.kernelShape.push(wDims[i]);
|
||||
}
|
||||
}
|
||||
PoolConvUtil.adjustPadsBasedOnAutoPad(
|
||||
inputs[0].dims, this.strides, this.dilations, this.kernelShape, this.pads, this.autoPad);
|
||||
Logger.verbose(
|
||||
'Conv',
|
||||
`autpPad:${this.autoPad}, dilations:${this.dilations}, group:${this.group}, kernelShape:${
|
||||
this.kernelShape}, pads:${this.pads}, strides:${this.strides}`);
|
||||
const outputShape = WebGLConv.calcOutputShape(xshape, kshape, this.dilations, this.pads, this.strides);
|
||||
const im2colProgramInfo = this.createIm2ColProgramInfo(inferenceHandler, inputs, outputShape);
|
||||
const dotProductProgramInfo =
|
||||
this.createDotProductProgramInfo(inferenceHandler, im2colProgramInfo.outputLayout, inputs, outputShape);
|
||||
return [im2colProgramInfo, dotProductProgramInfo];
|
||||
}
|
||||
createRunDatas(inferenceHandler: WebGLInferenceHandler, programInfos: ProgramInfo[], inputs: Tensor[]): RunData[] {
|
||||
const k = inputs[1];
|
||||
const b = inputs.length >= 3 ? inputs[2] : undefined;
|
||||
let kTD = inferenceHandler.getTextureData(k.dataId);
|
||||
if (!kTD) {
|
||||
Logger.verbose('Conv', 'Did not find the adjustedKernel texture in the cache. Creating rew.');
|
||||
const newKernelData =
|
||||
WebGLConv.prepKernelForDotProduct(k.dims.slice(), this.group, 4, k.floatData as Float32Array);
|
||||
// hack: should use graph transformer to rewrite initializer K
|
||||
kTD = inferenceHandler.createTextureDataFromLayoutBindTensor(
|
||||
programInfos[1].inputLayouts[1], k.type, newKernelData, k);
|
||||
}
|
||||
const runtDataIm2Col = {
|
||||
inputTextureDatas: [inferenceHandler.getOrCreateTextureData(inputs[0])],
|
||||
outputTextureData: inferenceHandler.createTextureDataFromLayout(programInfos[0].outputLayout, inputs[0].type),
|
||||
uniformData: {}
|
||||
};
|
||||
const inputTDs = [runtDataIm2Col.outputTextureData, kTD];
|
||||
if (b) {
|
||||
inputTDs.push(inferenceHandler.getOrCreateTextureData(b));
|
||||
}
|
||||
const outputTD = inferenceHandler.createTextureDataFromLayout(programInfos[1].outputLayout, inputs[0].type);
|
||||
const runDataDotProduct = {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: outputTD,
|
||||
uniformData: {},
|
||||
draw: (glContext: WebGLContext, artifact: Artifact) => {
|
||||
const gl = glContext.gl;
|
||||
const sharedDim = artifact.programInfo.params!.sharedDim as number;
|
||||
const sharedDimReadSize = artifact.programInfo.params!.sharedDimReadSize as number;
|
||||
const sharedDimOffsetLocation = artifact.uniformLocations.find(l => l.name === 'sharedDimOffset')!.location;
|
||||
let blend = false;
|
||||
for (let k = 0; k < sharedDim; k += sharedDimReadSize) {
|
||||
Logger.verbose('MatMul2D', `k = ${k}, sharedDim: ${sharedDim}, readSize = ${sharedDimReadSize}`);
|
||||
|
||||
if (k === sharedDimReadSize) {
|
||||
blend = true;
|
||||
gl.enable(gl.BLEND);
|
||||
glContext.checkError();
|
||||
gl.blendEquation(gl.FUNC_ADD);
|
||||
glContext.checkError();
|
||||
gl.blendFunc(gl.ONE, gl.ONE);
|
||||
glContext.checkError();
|
||||
}
|
||||
|
||||
gl.uniform1i(sharedDimOffsetLocation, k);
|
||||
glContext.checkError();
|
||||
glContext.draw();
|
||||
}
|
||||
|
||||
if (blend) {
|
||||
gl.disable(gl.BLEND);
|
||||
glContext.checkError();
|
||||
}
|
||||
}
|
||||
};
|
||||
return [runtDataIm2Col, runDataDotProduct];
|
||||
}
|
||||
createIm2ColProgramInfo(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], outputShape: number[]):
|
||||
ProgramInfo {
|
||||
const xshape = inputs[0].dims.slice();
|
||||
const kshape = inputs[1].dims.slice();
|
||||
|
||||
const rank = outputShape.length;
|
||||
const im2colDims = WebGLConv.calcIm2ColDims(xshape, kshape, outputShape, 4);
|
||||
const outputLayout = inferenceHandler.createTextureLayoutFromShape(
|
||||
im2colDims, 4, [im2colDims[0], im2colDims[1], im2colDims[2], im2colDims[3] * 4], {breakAxis: 3});
|
||||
const shaderSource = `
|
||||
const int XC = ${xshape[1]};
|
||||
const int XH = ${xshape[2]};
|
||||
const int XW = ${xshape[3]};
|
||||
const int KH = ${this.kernelShape[0]};
|
||||
const int KW = ${this.kernelShape[1]};
|
||||
const int dilationH = ${this.dilations[0]};
|
||||
const int dilationW = ${this.dilations[1]};
|
||||
const int strideH = ${this.strides[0]};
|
||||
const int strideW = ${this.strides[1]};
|
||||
const int padH = ${this.pads[0]};
|
||||
const int padW = ${this.pads[1]};
|
||||
const int KHKW = KH*KW;
|
||||
const int XCKHKW = XC * KHKW;
|
||||
const int outputChannels = 4;
|
||||
|
||||
vec4 process(int indices[${rank}]) {
|
||||
int b = indices[0]; // batch size
|
||||
int oh = indices[1] * strideH - padH; //output height
|
||||
int ow = indices[2] * strideW - padW; //output width
|
||||
int p = indices[3] * outputChannels; //patch
|
||||
vec4 v = vec4(0.0);
|
||||
for(int i=0; i < outputChannels; ++i) {
|
||||
if(p < XCKHKW) {
|
||||
int patchC = p / KHKW;
|
||||
int patchH = (p - patchC*KHKW) / KW;
|
||||
int patchW = (p - patchC*KHKW) - patchH * KW;
|
||||
int xh2 = oh + patchH * dilationH;
|
||||
int xw2 = ow + patchW * dilationW;
|
||||
int x[${xshape.length}];
|
||||
x[0] = b;
|
||||
x[1] = patchC;
|
||||
x[2] = xh2;
|
||||
x[3] = xw2;
|
||||
if(xh2 >= 0 &&
|
||||
xh2 < XH &&
|
||||
xw2 >= 0 &&
|
||||
xw2 < XW) {
|
||||
v[i] = _X(x);
|
||||
}
|
||||
}
|
||||
++p;
|
||||
}
|
||||
return v;
|
||||
}
|
||||
`;
|
||||
return {
|
||||
inputLayouts: [inferenceHandler.createTextureLayoutFromShape(xshape)],
|
||||
outputLayout,
|
||||
samplers: ['X'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
createDotProductProgramInfo(
|
||||
inferenceHandler: WebGLInferenceHandler, im2colLayout: TextureLayout, inputs: Tensor[],
|
||||
outputShape: number[]): ProgramInfo {
|
||||
const xshape = inputs[0].dims.slice();
|
||||
const kshape = inputs[1].dims.slice();
|
||||
const adjustedKernelShape = [kshape[0], Math.ceil((xshape[1] * kshape[2] * kshape[3]) / 4)];
|
||||
const kLayout = inferenceHandler.createTextureLayoutFromShape(
|
||||
adjustedKernelShape, 4, [adjustedKernelShape[0], adjustedKernelShape[1] * 4], {breakAxis: 1});
|
||||
|
||||
let bLayout: TextureLayout|undefined;
|
||||
const rank = outputShape.length;
|
||||
|
||||
const inputLayouts = [im2colLayout, kLayout];
|
||||
if (inputs.length === 3) {
|
||||
bLayout = inferenceHandler.createTextureLayoutFromShape(inputs[2].dims.slice());
|
||||
inputLayouts.push(bLayout);
|
||||
}
|
||||
const outputLayout = inferenceHandler.createTextureLayoutFromShape(outputShape);
|
||||
const initValue = (inputs.length < 3) ? '0.0' : '_B(b)';
|
||||
const sharedDim = im2colLayout.shape[3];
|
||||
const blendEnabled = inferenceHandler.session.backend.glContext.isBlendSupported;
|
||||
const sharedDimReadSize = blendEnabled && inferenceHandler.session.backend.matmulMaxBatchSize ?
|
||||
this.calcSharedDimReadSize(inferenceHandler.session.backend.matmulMaxBatchSize, sharedDim) :
|
||||
sharedDim;
|
||||
const samplers = ['Im2Col', 'K'];
|
||||
if (inputs.length === 3) {
|
||||
samplers.push('B');
|
||||
}
|
||||
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
|
||||
const shaderSource = `
|
||||
float process(int indices[${rank}]) {
|
||||
int b[1];
|
||||
b[0] = indices[1];
|
||||
int im2col[${im2colLayout.shape.length}];
|
||||
im2col[0] = indices[0];
|
||||
im2col[1] = indices[2];
|
||||
im2col[2] = indices[3];
|
||||
int im2colOffset = im2col[0] * ${im2colLayout.strides[0]} + im2col[1] * ${
|
||||
im2colLayout.strides[1]} + im2col[2] * ${im2colLayout.strides[2]} + sharedDimOffset;
|
||||
int kernelOffset = indices[1] * ${kLayout.strides[0]} + sharedDimOffset;
|
||||
float sum = sharedDimOffset == 0 ? ${initValue} : 0.0;
|
||||
for (int i = 0; i < ${sharedDimReadSize}; ++i) {
|
||||
vec2 im2colCoords = offsetToCoords(im2colOffset, ${im2colLayout.width}, ${im2colLayout.height});
|
||||
vec2 kernelCoords = offsetToCoords(kernelOffset, ${kLayout.width}, ${kLayout.height});
|
||||
sum += dot(${glsl.texture2D}(Im2Col, im2colCoords), ${glsl.texture2D}(K, kernelCoords));
|
||||
++im2colOffset;
|
||||
++kernelOffset;
|
||||
}
|
||||
return sum;
|
||||
}`;
|
||||
return {
|
||||
inputLayouts: inputs.length === 3 ? [im2colLayout, kLayout, bLayout!] : [im2colLayout, kLayout],
|
||||
outputLayout,
|
||||
shaderSource,
|
||||
samplers,
|
||||
variables: [{name: 'sharedDimOffset', type: 'int'}],
|
||||
params: {sharedDim, sharedDimReadSize}
|
||||
};
|
||||
}
|
||||
static prepKernelForDotProduct(shape: number[], group: number, channels: number, kernel: Float32Array): Float32Array {
|
||||
if (group === 1 && (channels === 1 || (shape[2] * shape[3]) % channels === 0)) {
|
||||
return kernel;
|
||||
}
|
||||
const numFeatureMaps = shape[0];
|
||||
const oldRowSize = shape[1] * shape[2] * shape[3];
|
||||
const newRowSize = Math.ceil(oldRowSize * group / channels) * channels;
|
||||
const newSize = numFeatureMaps * newRowSize;
|
||||
const buffer = new Float32Array(newSize);
|
||||
for (let f = 0; f < numFeatureMaps; ++f) {
|
||||
const oldOffset = f * oldRowSize;
|
||||
const newOffset = f * newRowSize + f % group * oldRowSize;
|
||||
buffer.set(kernel.subarray(oldOffset, oldOffset + oldRowSize), newOffset);
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
static calcIm2ColDims(inputShape: number[], kernelShape: number[], outputShape: number[], channels = 1): number[] {
|
||||
return [
|
||||
outputShape[0], outputShape[2], outputShape[3],
|
||||
Math.ceil(inputShape[1] * kernelShape[2] * kernelShape[3] / channels)
|
||||
];
|
||||
}
|
||||
static calcOutputShape(
|
||||
inputShape: number[], kernelShape: number[], dilations: number[], adjustPads: number[],
|
||||
strides: number[]): number[] {
|
||||
const batchSize = inputShape[0];
|
||||
const inputSpatialShape = inputShape.slice(2);
|
||||
const spatialRank = inputSpatialShape.length;
|
||||
const outChannels = kernelShape[0];
|
||||
const kernelSpatialShape = kernelShape.slice(2);
|
||||
const dilatedKernelShape = kernelSpatialShape.map((v, i) => v + (v - 1) * (dilations[i] - 1));
|
||||
const inputSpatialShapeWithPad = inputSpatialShape.map((v, i) => v + adjustPads[i] + adjustPads[i + spatialRank]);
|
||||
const outputSpatialShape =
|
||||
inputSpatialShapeWithPad.map((v, i) => Math.floor((v - dilatedKernelShape[i] + strides[i]) / strides[i]));
|
||||
const outputShape = [batchSize, outChannels].concat(...outputSpatialShape);
|
||||
return outputShape;
|
||||
}
|
||||
protected calcSharedDimReadSize(preferredBatchSize: number, sharedDim: number): number {
|
||||
if (preferredBatchSize <= 0 || sharedDim < preferredBatchSize || sharedDim % preferredBatchSize !== 0) {
|
||||
return sharedDim;
|
||||
}
|
||||
return preferredBatchSize;
|
||||
}
|
||||
protected calcBlockSize(outputLayout: TextureLayout): [number, number]|undefined {
|
||||
const preferredRowCount = 64;
|
||||
const preferredColCount = 64;
|
||||
if (outputLayout.height < preferredRowCount) {
|
||||
return undefined;
|
||||
}
|
||||
return [preferredColCount, preferredRowCount];
|
||||
}
|
||||
protected artifacts: Artifact[];
|
||||
protected readSize = 8;
|
||||
protected blockSize = 64;
|
||||
}
|
||||
22
js/web/lib/onnxjs/backends/webgl/ops/dropout.ts
Normal file
22
js/web/lib/onnxjs/backends/webgl/ops/dropout.ts
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Dropout} from '../../../ops/dropout';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLDropout extends Dropout implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
if (this.testMode) {
|
||||
return [inputs[0]];
|
||||
}
|
||||
throw new Error('Non test mode Dropout is not implemented yet');
|
||||
}
|
||||
createProgramInfo(_handler: WebGLInferenceHandler, _inputs: Tensor[]): ProgramInfo {
|
||||
throw new Error('Non test mode Dropout is not implemented yet');
|
||||
}
|
||||
createRunData(_handler: WebGLInferenceHandler, _programInfo: ProgramInfo, _inputs: Tensor[]): RunData {
|
||||
throw new Error('Non test mode Dropout is not implemented yet');
|
||||
}
|
||||
}
|
||||
39
js/web/lib/onnxjs/backends/webgl/ops/elu.ts
Normal file
39
js/web/lib/onnxjs/backends/webgl/ops/elu.ts
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Elu} from '../../../ops/elu';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {getGlsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLElu extends Elu implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const outputShape = inputs[0].dims.slice();
|
||||
const glsl = getGlsl(handler.session.backend.glContext.version);
|
||||
const shaderSource = `
|
||||
void main() {
|
||||
float v = ${glsl.texture2D}(A, TexCoords).r;
|
||||
${glsl.output} = vec4(v >= 0.0 ? v: (exp(v) - 1.0) * ${this.alpha.toExponential()}); /* float number format */
|
||||
}
|
||||
`;
|
||||
return {
|
||||
inputLayouts: [handler.getOrCreateTextureLayout(inputs[0])],
|
||||
outputLayout: handler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['A'],
|
||||
shaderSource,
|
||||
hasMain: true,
|
||||
};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = [handler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
17
js/web/lib/onnxjs/backends/webgl/ops/flatten.ts
Normal file
17
js/web/lib/onnxjs/backends/webgl/ops/flatten.ts
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Flatten} from '../../../ops/flatten';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
|
||||
import {reshape} from './reshape';
|
||||
|
||||
export class WebGLFlatten extends Flatten {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
const outputDims = ShapeUtil.flattenShape(inputs[0].dims, this.axis);
|
||||
|
||||
return [reshape(inferenceHandler, inputs[0], outputDims)];
|
||||
}
|
||||
}
|
||||
70
js/web/lib/onnxjs/backends/webgl/ops/gather.ts
Normal file
70
js/web/lib/onnxjs/backends/webgl/ops/gather.ts
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Gather} from '../../../ops/gather';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLGather extends Gather implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const inputShape = inputs[0].dims.slice();
|
||||
const indexDataShape = inputs[1].dims.slice();
|
||||
const outputShape = new Array(inputShape.length + indexDataShape.length - 1);
|
||||
|
||||
const axis = ShapeUtil.normalizeAxis(this.axis, inputShape.length);
|
||||
const indexCopyOps: string[] = [];
|
||||
for (let i = 0; i < outputShape.length; i++) {
|
||||
// outputShape is divided into three parts: A, B, C
|
||||
// |0 axis| axis + indexDataShape.length | end|
|
||||
// | A | B | C |
|
||||
//
|
||||
// inputIdx: [A, inputs[1][B], C]
|
||||
if (i < axis) { // A
|
||||
outputShape[i] = inputShape[i];
|
||||
indexCopyOps.push(`inputIdx[${i}] = outputIdx[${i}];`);
|
||||
} else {
|
||||
if (i < axis + indexDataShape.length) { // B
|
||||
outputShape[i] = indexDataShape[i - axis];
|
||||
indexCopyOps.push(`indexDataIdx[${i - axis}] = outputIdx[${i}];`);
|
||||
} else { // C
|
||||
outputShape[i] = inputShape[i - indexDataShape.length + 1]; // skip 1 for axis
|
||||
indexCopyOps.push(`inputIdx[${i - indexDataShape.length + 1}] = outputIdx[${i}];`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const orank = outputShape.length || 1;
|
||||
const irank = inputShape.length;
|
||||
const iDrank = indexDataShape.length || 1;
|
||||
const shaderSource = `
|
||||
float process(int outputIdx[${orank}]) {
|
||||
int inputIdx[${irank}];
|
||||
int indexDataIdx[${iDrank}];
|
||||
indexDataIdx[0] = 0;
|
||||
${indexCopyOps.join('\n ')}
|
||||
int idx = int(_B(indexDataIdx));
|
||||
inputIdx[${axis}] = idx < 0 ? idx + ${inputShape[axis]} : idx;
|
||||
return _A(inputIdx);
|
||||
}`;
|
||||
return {
|
||||
inputLayouts: inputs.map(t => handler.getOrCreateTextureLayout(t)),
|
||||
outputLayout: handler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['A', 'B'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = inputs.map((t, i) => handler.getOrCreateTextureData(t, programInfo.inputLayouts[i]));
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
80
js/web/lib/onnxjs/backends/webgl/ops/gemm.ts
Normal file
80
js/web/lib/onnxjs/backends/webgl/ops/gemm.ts
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Gemm} from '../../../ops/gemm';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {GemmUtil} from '../../../util';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLGemm extends Gemm implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const aShape = inputs[0].dims.slice();
|
||||
const bShape = inputs[1].dims.slice();
|
||||
const [M, N] = GemmUtil.getShapeOfGemmResult(
|
||||
aShape, this.transA, bShape, this.transB, inputs.length === 3 ? inputs[2].dims : undefined);
|
||||
const oShape = [M, N];
|
||||
if (!oShape) {
|
||||
throw new Error('Can\'t use gemm on the given tensors');
|
||||
}
|
||||
let sharedDim = aShape[aShape.length - 1];
|
||||
let line = '';
|
||||
if (this.transA) {
|
||||
sharedDim = aShape[0];
|
||||
}
|
||||
if (this.transA && this.transB) {
|
||||
line = 'value += _A_T(a) * _B_T(b);';
|
||||
} else if (this.transA && !this.transB) {
|
||||
line = 'value += _A_T(a) * _B(b);';
|
||||
} else if (!this.transA && this.transB) {
|
||||
line = 'value += _A(a) * _B_T(b);';
|
||||
} else if (!this.transA && !this.transB) {
|
||||
line = 'value += _A(a) * _B(b);';
|
||||
}
|
||||
const rank = oShape.length;
|
||||
const declareC = inputs.length === 3 ? `int c[${inputs[2].dims.length}];` : '';
|
||||
const broadcastC = inputs.length === 3 ? 'bcastIndices_C(indices, c);' : '';
|
||||
const calculateC = inputs.length === 3 ? 'value += beta * _C(c);' : '';
|
||||
const shaderSource = `
|
||||
float process(int indices[${rank}]) {
|
||||
int a[${rank}];
|
||||
int b[${rank}];
|
||||
${declareC}
|
||||
|
||||
copyVec(indices, a);
|
||||
copyVec(indices, b);
|
||||
${broadcastC}
|
||||
|
||||
float value = 0.0;
|
||||
for (int k=0; k<${sharedDim}; ++k) {
|
||||
a[${rank - 1}] = k;
|
||||
b[${rank - 2}] = k;
|
||||
${line}
|
||||
}
|
||||
|
||||
value = value * alpha;
|
||||
${calculateC}
|
||||
return value;
|
||||
}`;
|
||||
const inputLayouts = inputs.map(t => inferenceHandler.getOrCreateTextureLayout(t));
|
||||
return {
|
||||
inputLayouts,
|
||||
outputLayout: inferenceHandler.createTextureLayoutFromShape(oShape),
|
||||
samplers: inputs.length === 3 ? ['A', 'B', 'C'] : ['A', 'B'],
|
||||
variables: [{name: 'alpha', type: 'float'}, {name: 'beta', type: 'float'}],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
createRunData(inferenceHandler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = inputs.map((t, i) => inferenceHandler.getOrCreateTextureData(t, programInfo.inputLayouts[i]));
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData:
|
||||
inferenceHandler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {'alpha': this.alpha, 'beta': this.beta}
|
||||
};
|
||||
}
|
||||
}
|
||||
60
js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts
Normal file
60
js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {ImageScaler} from '../../../ops/image-scaler';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLImageScaler extends ImageScaler implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const outputShape = inputs[0].dims.slice();
|
||||
const rank = outputShape.length;
|
||||
const getBiasMethod = this.createGetBiasMethod(this.bias.length);
|
||||
const shaderSource = `
|
||||
${getBiasMethod}
|
||||
float process(int indices[${rank}]) {
|
||||
return _X(indices) * scale + getBias(bias, indices[1]);
|
||||
}`;
|
||||
return {
|
||||
inputLayouts: [handler.getOrCreateTextureLayout(inputs[0])],
|
||||
outputLayout: handler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['X'],
|
||||
variables: [{name: 'bias', type: 'float', arrayLength: this.bias.length}, {name: 'scale', type: 'float'}],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = [handler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {'bias': this.bias, 'scale': this.scale}
|
||||
};
|
||||
}
|
||||
private createGetBiasMethod(numChannels: number): string {
|
||||
const codeLines: string[] = [`float getBias(float bias[${numChannels}], int channel) {`];
|
||||
for (let i = 0; i < numChannels; ++i) {
|
||||
if (i === 0) {
|
||||
codeLines.push(
|
||||
'\t' +
|
||||
`if (channel == ${i}) { return bias[${i}]; }`);
|
||||
} else if (i === numChannels - 1) {
|
||||
codeLines.push(
|
||||
'\t' +
|
||||
`else { return bias[${i}]; }`);
|
||||
} else {
|
||||
codeLines.push(
|
||||
'\t' +
|
||||
`else if (channel == ${i}) { return bias[${i}]; }`);
|
||||
}
|
||||
}
|
||||
codeLines.push(
|
||||
'\t' +
|
||||
'}');
|
||||
return codeLines.join('\n');
|
||||
}
|
||||
}
|
||||
149
js/web/lib/onnxjs/backends/webgl/ops/instance-normalization.ts
Normal file
149
js/web/lib/onnxjs/backends/webgl/ops/instance-normalization.ts
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {InstanceNormalization} from '../../../ops/instance-normalization';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {getGlsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {Artifact, ProgramInfo, RunData, TextureLayout} from '../types';
|
||||
|
||||
export class WebGLInstanceNormalization extends InstanceNormalization {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
if (!this.artifacts) {
|
||||
this.artifacts = [];
|
||||
const programInfos = this.createProgramInfos(inferenceHandler, inputs);
|
||||
programInfos.forEach((pi) => {
|
||||
const artifact = inferenceHandler.session.programManager.build(pi);
|
||||
this.artifacts.push(artifact);
|
||||
});
|
||||
}
|
||||
|
||||
const runDatas = this.createRunDatas(inferenceHandler, this.artifacts.map(a => a.programInfo), inputs);
|
||||
runDatas.forEach((v, i) => inferenceHandler.session.programManager.run(this.artifacts[i], v));
|
||||
return [runDatas[1].outputTextureData.tensor];
|
||||
}
|
||||
|
||||
checkInputTypes(inputs: Tensor[]): boolean {
|
||||
if (!super.checkInputTypes(inputs)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (inputs[0].dims.length !== 4) {
|
||||
// currently webgl implementation only support 4-D input.
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
createMeanAndVarianceProgramInfo(inferenceHandler: WebGLInferenceHandler, xLayout: TextureLayout): ProgramInfo {
|
||||
const xDims = xLayout.shape;
|
||||
const channel = xDims[1];
|
||||
const channelSize = xDims[2] * xDims[3];
|
||||
const outputShape = [xDims[0], channel];
|
||||
const outputUnpackedShape = [xDims[0], channel * 4];
|
||||
|
||||
const shaderSource = `
|
||||
vec4 process(int[2] indices) {
|
||||
vec4 v = vec4(0.0);
|
||||
int a[4];
|
||||
a[0] = indices[0];
|
||||
a[1] = indices[1];
|
||||
float temp = 0.0;
|
||||
for(int a2=0; a2<${xDims[2]}; a2++) {
|
||||
a[2] = a2;
|
||||
for(int a3=0; a3<${xDims[3]}; a3++) {
|
||||
a[3] = a3;
|
||||
float x = _X(a);
|
||||
temp += x;
|
||||
}
|
||||
}
|
||||
float mean = temp / float(${channelSize});
|
||||
temp = 0.0;
|
||||
for(int a2=0; a2<${xDims[2]}; a2++) {
|
||||
a[2] = a2;
|
||||
for(int a3=0; a3<${xDims[3]}; a3++) {
|
||||
a[3] = a3;
|
||||
float x = _X(a);
|
||||
temp += (x - mean) * (x - mean);
|
||||
}
|
||||
}
|
||||
v.r = mean;
|
||||
v.g = temp / float(${channelSize});
|
||||
|
||||
return v;
|
||||
}`;
|
||||
return {
|
||||
inputLayouts: [xLayout],
|
||||
outputLayout: inferenceHandler.createTextureLayoutFromShape(outputShape, 4, outputUnpackedShape),
|
||||
samplers: ['X'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
|
||||
createComputOutputProgramInfo(
|
||||
inferenceHandler: WebGLInferenceHandler, xLayout: TextureLayout, scaleLayout: TextureLayout,
|
||||
bLayout: TextureLayout, meanAndVarianceLayout: TextureLayout): ProgramInfo {
|
||||
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
|
||||
const shaderSource = `
|
||||
vec4 get_MeanAndVariance(int[2] mv) {
|
||||
int offset = indicesToOffset_MeanAndVariance(mv);
|
||||
vec2 coords = offsetToCoords(offset, ${meanAndVarianceLayout.width}, ${meanAndVarianceLayout.height});
|
||||
return ${glsl.texture2D}(MeanAndVariance, coords);
|
||||
}
|
||||
|
||||
float process(int[4] indices) {
|
||||
|
||||
int mv[2];
|
||||
mv[0] = indices[0];
|
||||
mv[1] = indices[1];
|
||||
vec4 mean_and_variance = get_MeanAndVariance(mv);
|
||||
float mean = mean_and_variance.r;
|
||||
float variance = mean_and_variance.g;
|
||||
|
||||
int sb[1];
|
||||
sb[0] = indices[1];
|
||||
float scale = _Scale(sb);
|
||||
float b = _B(sb);
|
||||
|
||||
return scale * (_X(indices) - mean) / sqrt(variance + epsilon) + b;
|
||||
}`;
|
||||
return {
|
||||
inputLayouts: [xLayout, meanAndVarianceLayout, scaleLayout, bLayout],
|
||||
outputLayout: inferenceHandler.createTextureLayoutFromShape(xLayout.shape),
|
||||
samplers: ['X', 'MeanAndVariance', 'Scale', 'B'],
|
||||
variables: [{name: 'epsilon', type: 'float'}],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
createProgramInfos(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo[] {
|
||||
const xLayout = inferenceHandler.getOrCreateTextureLayout(inputs[0]);
|
||||
const scaleLayout = inferenceHandler.getOrCreateTextureLayout(inputs[1]);
|
||||
const bLayout = inferenceHandler.getOrCreateTextureLayout(inputs[2]);
|
||||
const meanAndVarianceProgramInfo = this.createMeanAndVarianceProgramInfo(inferenceHandler, xLayout);
|
||||
const computeOutputProgramInfo = this.createComputOutputProgramInfo(
|
||||
inferenceHandler, xLayout, scaleLayout, bLayout, meanAndVarianceProgramInfo.outputLayout);
|
||||
|
||||
const programInfos: ProgramInfo[] = [meanAndVarianceProgramInfo, computeOutputProgramInfo];
|
||||
return programInfos;
|
||||
}
|
||||
createRunDatas(inferenceHandler: WebGLInferenceHandler, programInfos: ProgramInfo[], inputs: Tensor[]): RunData[] {
|
||||
const dataType = inputs[0].type;
|
||||
const inputTD = inferenceHandler.getOrCreateTextureData(inputs[0], programInfos[0].inputLayouts[0]);
|
||||
const scaleTD = inferenceHandler.getOrCreateTextureData(inputs[1], programInfos[1].inputLayouts[2]);
|
||||
const bTD = inferenceHandler.getOrCreateTextureData(inputs[2], programInfos[1].inputLayouts[3]);
|
||||
const runDatas: RunData[] = [];
|
||||
runDatas.push({
|
||||
inputTextureDatas: [inputTD],
|
||||
outputTextureData: inferenceHandler.createTextureDataFromLayout(programInfos[0].outputLayout, dataType),
|
||||
uniformData: {}
|
||||
});
|
||||
runDatas.push({
|
||||
inputTextureDatas: [inputTD, runDatas[0].outputTextureData, scaleTD, bTD],
|
||||
outputTextureData: inferenceHandler.createTextureDataFromLayout(programInfos[1].outputLayout, dataType),
|
||||
uniformData: {'epsilon': this.epsilon}
|
||||
});
|
||||
return runDatas;
|
||||
}
|
||||
protected artifacts: Artifact[];
|
||||
}
|
||||
39
js/web/lib/onnxjs/backends/webgl/ops/leaky-relu.ts
Normal file
39
js/web/lib/onnxjs/backends/webgl/ops/leaky-relu.ts
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {LeakyRelu} from '../../../ops/leaky-relu';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {getGlsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLLeakyRelu extends LeakyRelu implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const outputShape = inputs[0].dims.slice();
|
||||
const glsl = getGlsl(handler.session.backend.glContext.version);
|
||||
const shaderSource = `
|
||||
void main() {
|
||||
float v = ${glsl.texture2D}(A, TexCoords).r;
|
||||
${glsl.output} = vec4(v < 0.0 ? v * float(${this.alpha}) : v);
|
||||
}
|
||||
`;
|
||||
return {
|
||||
hasMain: true,
|
||||
inputLayouts: [handler.getOrCreateTextureLayout(inputs[0])],
|
||||
outputLayout: handler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['A'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = [handler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
55
js/web/lib/onnxjs/backends/webgl/ops/matmul.ts
Normal file
55
js/web/lib/onnxjs/backends/webgl/ops/matmul.ts
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {MatMul} from '../../../ops/matmul';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {BroadcastUtil} from '../../../util';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLMatMul extends MatMul implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const aShape = inputs[0].dims;
|
||||
const bShape = inputs[1].dims;
|
||||
const outputShape = BroadcastUtil.calcShape(aShape, bShape, true);
|
||||
if (!outputShape) {
|
||||
throw new Error('Can\'t use matmul on the given tensors');
|
||||
}
|
||||
const rank = outputShape.length;
|
||||
const arank = aShape.length;
|
||||
const brank = bShape.length;
|
||||
const sharedDim = aShape[aShape.length - 1];
|
||||
const shaderSource = `
|
||||
float process(int indices[${rank}]) {
|
||||
int a[${arank}];
|
||||
int b[${brank}];
|
||||
bcastMatmulIndices_A(indices, a);
|
||||
bcastMatmulIndices_B(indices, b);
|
||||
|
||||
float value;
|
||||
for (int k=0; k<${sharedDim}; ++k) {
|
||||
a[${arank - 1}] = k;
|
||||
b[${brank - 2}] = k;
|
||||
value += _A(a) * _B(b);
|
||||
}
|
||||
return value;
|
||||
}`;
|
||||
return {
|
||||
inputLayouts: inputs.map(t => handler.getOrCreateTextureLayout(t)),
|
||||
outputLayout: handler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['A', 'B'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = inputs.map((t, i) => handler.getOrCreateTextureData(t, programInfo.inputLayouts[i]));
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
149
js/web/lib/onnxjs/backends/webgl/ops/pack.ts
Normal file
149
js/web/lib/onnxjs/backends/webgl/ops/pack.ts
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {getGlsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
import {getCoordsDataType} from '../utils';
|
||||
|
||||
import {getChannels} from './packing_utils';
|
||||
|
||||
export class WebGLPack implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
if (inputs.length !== 1) {
|
||||
throw new Error('Pack kernel should have input tensor count to 1.');
|
||||
}
|
||||
|
||||
const inputShape = inputs[0].dims;
|
||||
|
||||
const outputLayout =
|
||||
handler.createTextureLayoutFromShape(inputShape, 4, inputShape, {isPacked: true, reverseWH: true});
|
||||
const outputShape = outputLayout.shape;
|
||||
const inputRank = inputShape.length;
|
||||
const outputRank = outputShape.length;
|
||||
|
||||
const coordsDataType = getCoordsDataType(outputRank);
|
||||
const channels = getChannels('rc', outputRank);
|
||||
const setup = getSetup(outputRank, channels, inputShape[inputShape.length - 2], inputShape[inputShape.length - 1]);
|
||||
|
||||
let reversedInputWH;
|
||||
if (inputRank === 0) {
|
||||
reversedInputWH = [1, 1];
|
||||
} else if (inputRank === 1) {
|
||||
reversedInputWH = [inputShape[0], 1];
|
||||
} else {
|
||||
reversedInputWH = [inputShape[outputRank - 1], inputShape[outputRank - 2]];
|
||||
}
|
||||
const outOfBoundsCondition = getOutOfBoundsCondition(outputRank, reversedInputWH, channels);
|
||||
const output = getOutput(inputShape, channels);
|
||||
|
||||
const glsl = getGlsl(handler.session.backend.glContext.version);
|
||||
const shaderSource = `
|
||||
void main() {
|
||||
${coordsDataType} rc = getOutputCoords();
|
||||
|
||||
if(${outOfBoundsCondition}) {
|
||||
${glsl.output} = vec4(0);
|
||||
} else {
|
||||
${setup}
|
||||
|
||||
${glsl.output} = vec4(${output});
|
||||
}
|
||||
}
|
||||
`;
|
||||
|
||||
return {
|
||||
inputLayouts: [handler.getOrCreateTextureLayout(inputs[0], 1, false, [], true)],
|
||||
outputLayout,
|
||||
samplers: ['A'],
|
||||
shaderSource,
|
||||
hasMain: true,
|
||||
expectPackedInputs: false,
|
||||
expectPackedoutputs: true,
|
||||
};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = [handler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* check output coordinate location and return false if it is outside input's width/height boundary
|
||||
*/
|
||||
function getOutOfBoundsCondition(rank: number, shape: readonly number[], dims: string[]): string {
|
||||
if (rank === 1) {
|
||||
return `rc > ${shape[0]}`;
|
||||
}
|
||||
|
||||
let cond = '';
|
||||
for (let i = rank - 2; i < rank; i++) {
|
||||
cond += `${dims[i]} >= ${shape[i - rank + 2]}`;
|
||||
if (i < rank - 1) {
|
||||
cond += '||';
|
||||
}
|
||||
}
|
||||
|
||||
return cond;
|
||||
}
|
||||
|
||||
/**
|
||||
* code snippet to sample input texture with output coordiantes
|
||||
*/
|
||||
function getOutput(shape: readonly number[], dims: string[]): string {
|
||||
const rank = shape.length;
|
||||
|
||||
if (rank === 0) {
|
||||
return 'getA(), 0, 0, 0';
|
||||
}
|
||||
|
||||
if (rank === 1) {
|
||||
return `getA(rc),
|
||||
rc + 1 >= ${shape[0]} ? 0. : getA(rc + 1),
|
||||
0, 0`;
|
||||
}
|
||||
|
||||
const coord00 = 'r, c';
|
||||
const coord01 = 'r, cp1';
|
||||
const coord10 = 'rp1, c';
|
||||
const coord11 = 'rp1, cp1';
|
||||
let D = '';
|
||||
if (rank > 2) {
|
||||
for (let i = 0; i < rank - 2; ++i) {
|
||||
D = D + `${dims[i]},`;
|
||||
}
|
||||
}
|
||||
return `getA(${D}${coord00}),
|
||||
rEdge ? 0. : getA(${D}${coord10}),
|
||||
cEdge ? 0. : getA(${D}${coord01}),
|
||||
rEdge || cEdge ? 0. : getA(${D}${coord11})`;
|
||||
}
|
||||
|
||||
/**
|
||||
* code snippet to setup 4 coordinates and edge conditions
|
||||
*/
|
||||
function getSetup(rank: number, dims: string[], rows: number, cols: number): string {
|
||||
if (rank === 0 || rank === 1) {
|
||||
return '';
|
||||
}
|
||||
// rank >= 2 for width+height pack.
|
||||
else {
|
||||
const setup = `
|
||||
int r = ${dims[rank - 2]};
|
||||
int c = ${dims[rank - 1]};
|
||||
int rp1 = ${dims[rank - 2]} + 1;
|
||||
int cp1 = ${dims[rank - 1]} + 1;
|
||||
bool rEdge = rp1 >= ${cols};
|
||||
bool cEdge = cp1 >= ${rows};
|
||||
`;
|
||||
return setup;
|
||||
}
|
||||
}
|
||||
32
js/web/lib/onnxjs/backends/webgl/ops/packing_utils.ts
Normal file
32
js/web/lib/onnxjs/backends/webgl/ops/packing_utils.ts
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
export function getVecChannels(name: string, rank: number): string[] {
|
||||
return ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank).map(d => `${name}.${d}`);
|
||||
}
|
||||
|
||||
export function getChannels(name: string, rank: number): string[] {
|
||||
if (rank === 1) {
|
||||
return [name];
|
||||
}
|
||||
return getVecChannels(name, rank);
|
||||
}
|
||||
|
||||
export function unpackFromChannel(rank: number): string {
|
||||
if (rank <= 1) {
|
||||
return `
|
||||
float getChannel(vec4 frag, int dim) {
|
||||
int modCoord = imod(dim, 2);
|
||||
return modCoord == 0 ? frag.r : frag.g;
|
||||
}
|
||||
`;
|
||||
}
|
||||
return `
|
||||
float getChannel(vec4 frag, vec2 innerDims) {
|
||||
vec2 modCoord = mod(innerDims, 2.);
|
||||
return modCoord.x == 0. ?
|
||||
(modCoord.y == 0. ? frag.r : frag.g) :
|
||||
(modCoord.y == 0. ? frag.b : frag.a);
|
||||
}
|
||||
`;
|
||||
}
|
||||
137
js/web/lib/onnxjs/backends/webgl/ops/pad.ts
Normal file
137
js/web/lib/onnxjs/backends/webgl/ops/pad.ts
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Pad} from '../../../ops/pad';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {getGlsl, Glsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLPad extends Pad implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const outputShape = ShapeUtil.padShape(inputs[0].dims.slice(), this.pads);
|
||||
const rank = outputShape.length;
|
||||
const alayout = inferenceHandler.getOrCreateTextureLayout(inputs[0]);
|
||||
const padFunction = getPadFunction(
|
||||
getGlsl(inferenceHandler.session.backend.glContext.version), 'A', alayout, this.mode, this.pads, this.value);
|
||||
const shaderSource = `
|
||||
${padFunction}
|
||||
float process(int[${rank}] indices) {
|
||||
return padA(indices);
|
||||
}`;
|
||||
return {
|
||||
inputLayouts: [alayout],
|
||||
outputLayout: inferenceHandler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['A'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
createRunData(inferenceHandler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = [inferenceHandler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData:
|
||||
inferenceHandler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
export function getPadFunction(
|
||||
glsl: Glsl, name: string, inputLayout: TextureLayout, mode: string, pads: number[], value: number): string {
|
||||
switch (mode) {
|
||||
case 'constant':
|
||||
return getPadConstant(
|
||||
glsl, name, inputLayout.shape, inputLayout.strides, inputLayout.width, inputLayout.height, pads, value);
|
||||
case 'reflect':
|
||||
return getPadReflect(
|
||||
glsl, name, inputLayout.shape, inputLayout.strides, inputLayout.width, inputLayout.height, pads);
|
||||
case 'edge':
|
||||
return getPadEdge(
|
||||
glsl, name, inputLayout.shape, inputLayout.strides, inputLayout.width, inputLayout.height, pads);
|
||||
default:
|
||||
throw new Error('Invalid mode');
|
||||
}
|
||||
}
|
||||
function getPadConstant(
|
||||
glsl: Glsl, name: string, shape: readonly number[], strides: readonly number[], width: number, height: number,
|
||||
pads: number[], value: number) {
|
||||
const rank = shape.length;
|
||||
let block = '';
|
||||
for (let i = rank - 1; i >= 0; --i) {
|
||||
block += `
|
||||
k = m[${i}] - ${pads[i]};
|
||||
if (k < 0) return constant;
|
||||
if (k >= ${shape[i]}) return constant;
|
||||
offset += k * ${strides[i]};
|
||||
`;
|
||||
}
|
||||
return `
|
||||
float pad${name}(int m[${rank}]) {
|
||||
const float constant = float(${value});
|
||||
int offset = 0;
|
||||
int k = 0;
|
||||
${block}
|
||||
vec2 coords = offsetToCoords(offset, ${width}, ${height});
|
||||
float value = getColorAsFloat(${glsl.texture2D}(${name}, coords));
|
||||
return value;
|
||||
}
|
||||
`;
|
||||
}
|
||||
function getPadReflect(
|
||||
glsl: Glsl, name: string, shape: readonly number[], strides: readonly number[], width: number, height: number,
|
||||
pads: number[]) {
|
||||
const rank = shape.length;
|
||||
|
||||
let block = '';
|
||||
for (let i = rank - 1; i >= 0; --i) {
|
||||
block += `
|
||||
k = m[${i}] - ${pads[i]};
|
||||
if (k < 0) { k = -k; }
|
||||
{
|
||||
const int _2n_1 = ${2 * (shape[i] - 1)};
|
||||
k = int( mod( float(k), float(_2n_1) ) ) ;
|
||||
if(k >= ${shape[i]}) { k = _2n_1 - k; }
|
||||
}
|
||||
offset += k * ${strides[i]};
|
||||
`;
|
||||
}
|
||||
return `
|
||||
float pad${name}(int m[${rank}]) {
|
||||
int offset = 0;
|
||||
int k = 0;
|
||||
${block}
|
||||
vec2 coords = offsetToCoords(offset, ${width}, ${height});
|
||||
float value = getColorAsFloat(${glsl.texture2D}(${name}, coords));
|
||||
return value;
|
||||
}
|
||||
`;
|
||||
}
|
||||
function getPadEdge(
|
||||
glsl: Glsl, name: string, shape: readonly number[], strides: readonly number[], width: number, height: number,
|
||||
pads: number[]) {
|
||||
const rank = shape.length;
|
||||
|
||||
let block = '';
|
||||
for (let i = rank - 1; i >= 0; --i) {
|
||||
block += `
|
||||
k = m[${i}] - ${pads[i]};
|
||||
if (k < 0) k = 0;
|
||||
if (k >= ${shape[i]}) k = ${shape[i] - 1};
|
||||
offset += k * ${strides[i]};
|
||||
`;
|
||||
}
|
||||
return `
|
||||
float pad${name}(int m[${rank}]) {
|
||||
int offset = 0;
|
||||
int k = 0;
|
||||
${block}
|
||||
vec2 coords = offsetToCoords(offset, ${width}, ${height});
|
||||
float value = getColorAsFloat(${glsl.texture2D}(${name}, coords));
|
||||
return value;
|
||||
}
|
||||
`;
|
||||
}
|
||||
293
js/web/lib/onnxjs/backends/webgl/ops/pool.ts
Normal file
293
js/web/lib/onnxjs/backends/webgl/ops/pool.ts
Normal file
|
|
@ -0,0 +1,293 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {AveragePool, GlobalAveragePool, GlobalMaxPool, MaxPool} from '../../../ops/pool';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {PoolConvUtil, ShapeUtil} from '../../../util';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, TextureLayout, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLGlobalAveragePool extends GlobalAveragePool implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
return createAveragePoolProgramInfo(
|
||||
inferenceHandler, inputs, true, this.kernelShape, this.autoPad, this.strides, this.pads, this.countIncludePad);
|
||||
}
|
||||
createRunData(inferenceHandler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = [inferenceHandler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData:
|
||||
inferenceHandler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export class WebGLAveragePool extends AveragePool implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
return createAveragePoolProgramInfo(
|
||||
inferenceHandler, inputs, false, this.kernelShape, this.autoPad, this.strides, this.pads, this.countIncludePad);
|
||||
}
|
||||
createRunData(inferenceHandler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = [inferenceHandler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData:
|
||||
inferenceHandler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
function createAveragePoolProgramInfo(
|
||||
inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], isGlobalOperator: boolean, kernelShape: number[] = [],
|
||||
autoPad = '', strides: number[] = [], pads: number[] = [], countIncludePad: boolean): ProgramInfo {
|
||||
const inputShape = inputs[0].dims.slice();
|
||||
PoolConvUtil.adjustPoolAttributes(isGlobalOperator, inputShape, kernelShape, strides, pads);
|
||||
const outputShape =
|
||||
PoolConvUtil.computePoolOutputShape(isGlobalOperator, inputShape, strides, kernelShape, pads, autoPad);
|
||||
const kernelSize = ShapeUtil.size(kernelShape);
|
||||
const op1 = 'value += _X(x);';
|
||||
let op2 = '';
|
||||
if (countIncludePad) {
|
||||
op2 += `value /= float(${kernelSize});`;
|
||||
} else {
|
||||
op2 += `value /= float(${kernelSize} - pad);`;
|
||||
}
|
||||
const inputLayout = inferenceHandler.getOrCreateTextureLayout(inputs[0]);
|
||||
const poolingCode = generatePoolingCode(inputLayout, kernelShape, pads, strides, op1, op2, '0.0');
|
||||
const shaderSource = `
|
||||
${poolingCode}
|
||||
`;
|
||||
return {
|
||||
inputLayouts: [inputLayout],
|
||||
outputLayout: inferenceHandler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['X'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
|
||||
export class WebGLGlobalMaxPool extends GlobalMaxPool implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
return createMaxPoolProgramInfo(
|
||||
inferenceHandler, inputs, true, this.kernelShape, this.autoPad, this.strides, this.pads);
|
||||
}
|
||||
createRunData(inferenceHandler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = [inferenceHandler.getOrCreateTextureData(inputs[0])];
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData:
|
||||
inferenceHandler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export class WebGLMaxPool extends MaxPool implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
return createMaxPoolProgramInfo(
|
||||
inferenceHandler, inputs, false, this.kernelShape, this.autoPad, this.strides, this.pads);
|
||||
}
|
||||
createRunData(inferenceHandler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = [inferenceHandler.getOrCreateTextureData(inputs[0])];
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData:
|
||||
inferenceHandler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
function createMaxPoolProgramInfo(
|
||||
inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], isGlobalOperator: boolean, kernelShape: number[] = [],
|
||||
autoPad = '', strides: number[] = [], pads: number[] = []): ProgramInfo {
|
||||
const inputShape = inputs[0].dims.slice();
|
||||
PoolConvUtil.adjustPoolAttributes(isGlobalOperator, inputShape, kernelShape, strides, pads);
|
||||
const outputShape =
|
||||
PoolConvUtil.computePoolOutputShape(isGlobalOperator, inputShape, strides, kernelShape, pads, autoPad);
|
||||
const op1 = `
|
||||
value = max(_X(x), value);
|
||||
`;
|
||||
const op2 = '';
|
||||
const inputLayout = inferenceHandler.createTextureLayoutFromShape(inputShape);
|
||||
const poolingCode = generatePoolingCode(inputLayout, kernelShape, pads, strides, op1, op2, '-1e5');
|
||||
const shaderSource = `
|
||||
${poolingCode}
|
||||
`;
|
||||
return {
|
||||
inputLayouts: [inputLayout],
|
||||
outputLayout: inferenceHandler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['X'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
|
||||
export function generatePoolingCode(
|
||||
x: TextureLayout, kernelShape: number[], pads: number[], strides: number[], op1: string, op2: string,
|
||||
startVal: string): string {
|
||||
const inputDims = x.shape;
|
||||
const rank = x.shape.length;
|
||||
if (kernelShape.length <= 2) {
|
||||
const kw = kernelShape[kernelShape.length - 1];
|
||||
const sw = strides[strides.length - 1];
|
||||
const pwStart = pads[pads.length / 2 - 1];
|
||||
const pwEnd = pads[pads.length - 1];
|
||||
const dimW = inputDims[rank - 1];
|
||||
let codeW = '';
|
||||
let codeH = '';
|
||||
let codeHEnd = '';
|
||||
if (pwStart + pwEnd !== 0) {
|
||||
codeW = `
|
||||
for (int i = 0; i < ${kw}; i++) {
|
||||
x[${rank} - 1] = indices[${rank} - 1] * ${sw} - ${pwStart} + i;
|
||||
if (x[${rank} - 1] < 0 || x[${rank} - 1] >= ${dimW}) {
|
||||
pad++;
|
||||
continue;
|
||||
}
|
||||
${op1}
|
||||
}`;
|
||||
} else {
|
||||
codeW = `
|
||||
for (int i = 0; i < ${kw}; i++) {
|
||||
x[${rank} - 1] = indices[${rank} - 1] * ${sw} - ${pwStart} + i;
|
||||
${op1}
|
||||
}`;
|
||||
}
|
||||
|
||||
if (kernelShape.length === 2) {
|
||||
const kh = kernelShape[kernelShape.length - 2];
|
||||
const sh = strides[strides.length - 2];
|
||||
const phStart = pads[pads.length / 2 - 2];
|
||||
const phEnd = pads[pads.length - 2];
|
||||
const dimH = inputDims[rank - 2];
|
||||
if (phStart + phEnd !== 0) {
|
||||
codeH = `
|
||||
for (int j = 0; j < ${kh}; j++) {
|
||||
x[${rank} - 2] = indices[${rank} - 2] * ${sh} - ${phStart} + j;
|
||||
if (x[${rank} - 2] < 0 || x[${rank} - 2] >= ${dimH}) {
|
||||
pad+= ${kw};
|
||||
continue;
|
||||
}
|
||||
`;
|
||||
} else {
|
||||
codeH = `
|
||||
for (int j = 0; j < ${kh}; j++) {
|
||||
x[${rank} - 2] = indices[${rank} - 2] * ${sh} - ${phStart} + j;
|
||||
`;
|
||||
}
|
||||
codeHEnd = `
|
||||
}
|
||||
`;
|
||||
}
|
||||
|
||||
const poolingCode = `
|
||||
float process(int indices[${rank}]) {
|
||||
int x[${rank}];
|
||||
copyVec(indices, x);
|
||||
|
||||
float value = ${startVal};
|
||||
int pad = 0;
|
||||
${codeH}
|
||||
${codeW}
|
||||
${codeHEnd}
|
||||
${op2}
|
||||
return value;
|
||||
}
|
||||
`;
|
||||
return poolingCode;
|
||||
} else {
|
||||
const kernelSize = ShapeUtil.size(kernelShape);
|
||||
const kernelStrides = ShapeUtil.computeStrides(kernelShape);
|
||||
const stridesRank = kernelStrides.length;
|
||||
const padsRank = pads.length;
|
||||
const offsetToIndicesFunction = offsetToIndices(stridesRank);
|
||||
const copyInputDims = copyArray(inputDims, 'inputDims');
|
||||
const copyPads = copyArray(pads, 'pads');
|
||||
const copyKernelStrides = copyArray(kernelStrides, 'kernelStrides');
|
||||
const copyStrides = copyArray(strides, 'strides');
|
||||
const hasPads = pads.reduce((sum, cur) => sum + cur);
|
||||
let padCode = '';
|
||||
if (hasPads) {
|
||||
padCode = `
|
||||
if (x[j] >= inputDims[j] || x[j] < 0) {
|
||||
pad++;
|
||||
isPad = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!isPad) {
|
||||
${op1}
|
||||
}`;
|
||||
} else {
|
||||
padCode = `
|
||||
}
|
||||
${op1}`;
|
||||
}
|
||||
const poolingCode = `
|
||||
${offsetToIndicesFunction}
|
||||
float process(int indices[${rank}]) {
|
||||
int x[${rank}];
|
||||
copyVec(indices, x);
|
||||
int offset[${stridesRank}];
|
||||
int pads[${padsRank}];
|
||||
int inputDims[${rank}];
|
||||
int kernelStrides[${stridesRank}];
|
||||
int strides[${stridesRank}];
|
||||
${copyPads}
|
||||
${copyInputDims}
|
||||
${copyStrides}
|
||||
${copyKernelStrides}
|
||||
|
||||
float value = ${startVal};
|
||||
int pad = 0;
|
||||
bool isPad = false;
|
||||
for (int i = 0; i < ${kernelSize}; i++) {
|
||||
offsetToIndices(i, kernelStrides, offset);
|
||||
isPad = false;
|
||||
for (int j = ${rank} - ${stridesRank}; j < ${rank}; j++) {
|
||||
x[j] = indices[j] * strides[j - ${rank} + ${stridesRank}]
|
||||
+ offset[j - ${rank} + ${stridesRank}] - pads[j - 2];
|
||||
${padCode}
|
||||
}
|
||||
${op2}
|
||||
|
||||
return value;
|
||||
}`;
|
||||
return poolingCode;
|
||||
}
|
||||
}
|
||||
|
||||
export function copyArray(array: readonly number[], arrayName: string): string {
|
||||
let block = '';
|
||||
for (let i = 0; i < array.length; i++) {
|
||||
block += `
|
||||
${arrayName}[${i}] = ${array[i]};
|
||||
`;
|
||||
}
|
||||
return block;
|
||||
}
|
||||
|
||||
export function offsetToIndices(rank: number): string {
|
||||
return `
|
||||
void offsetToIndices(int offset, int[${rank}] strides, out int[${rank}] indices) {
|
||||
if (${rank} == 0) {
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < ${rank} - 1; ++i) {
|
||||
indices[i] = offset / strides[i];
|
||||
offset -= indices[i] * strides[i];
|
||||
}
|
||||
indices[${rank} - 1] = offset;
|
||||
}`;
|
||||
}
|
||||
138
js/web/lib/onnxjs/backends/webgl/ops/reduce.ts
Normal file
138
js/web/lib/onnxjs/backends/webgl/ops/reduce.ts
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {ReduceBase} from '../../../ops/reduce-op';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
abstract class WebGLGenericReduce extends ReduceBase implements WebGLOperator {
|
||||
abstract getOps(inputs: Tensor[], axes: number[]): string[];
|
||||
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const outputShape: number[] = [];
|
||||
const iRank = inputs[0].dims.length || 1;
|
||||
|
||||
const idxCopy = []; // copy output indexes to input indexes
|
||||
|
||||
const axes = ShapeUtil.normalizeAxes(this.axes, inputs[0].dims.length);
|
||||
const ops = this.getOps(inputs, axes); // [init ops, reduce ops, final ops]
|
||||
let reduceOps = ops[1];
|
||||
|
||||
for (let k = 0; k < inputs[0].dims.length; k++) {
|
||||
// if this axis is reduced
|
||||
if (axes.indexOf(k) >= 0 || axes.length === 0) {
|
||||
if (this.keepDims) {
|
||||
outputShape.push(1);
|
||||
} // else { remove the axis from outputShape; }
|
||||
|
||||
// loop over the d-th axis
|
||||
reduceOps = `
|
||||
for(int j${k} = 0; j${k} < ${inputs[0].dims[k]}; j${k}++) {
|
||||
inputIdx[${k}] = j${k};
|
||||
${reduceOps}
|
||||
}
|
||||
`;
|
||||
} else {
|
||||
idxCopy.push(`inputIdx[${k}] = outputIdx[${outputShape.length}];`);
|
||||
|
||||
outputShape.push(inputs[0].dims[k]);
|
||||
}
|
||||
}
|
||||
|
||||
const oRank = outputShape.length || 1;
|
||||
|
||||
const shaderSource = `
|
||||
float process(int outputIdx[${oRank}]) {
|
||||
float value; // final result
|
||||
int inputIdx[${iRank}]; // addressing input data
|
||||
${idxCopy.join('\n')}
|
||||
${ops[0]} // init ops for reduce max/min
|
||||
${reduceOps}
|
||||
${ops[2]} // final computation for reduce mean
|
||||
return value;
|
||||
}`;
|
||||
|
||||
return {
|
||||
inputLayouts: inputs.map(t => handler.getOrCreateTextureLayout(t)),
|
||||
outputLayout: handler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['A'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = inputs.map((t, i) => handler.getOrCreateTextureData(t, programInfo.inputLayouts[i]));
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export class WebGLReduceSum extends WebGLGenericReduce {
|
||||
getOps(_inputs: Tensor[]): string[] {
|
||||
return ['value = 0.0;', 'value += _A(inputIdx);', ''];
|
||||
}
|
||||
}
|
||||
|
||||
export class WebGLReduceMean extends WebGLGenericReduce {
|
||||
getOps(inputs: Tensor[], axes: number[]): string[] {
|
||||
let size = 1.0;
|
||||
for (let k = 0; k < inputs[0].dims.length; k++) {
|
||||
if (axes.indexOf(k) >= 0 || axes.length === 0) {
|
||||
size *= inputs[0].dims[k];
|
||||
}
|
||||
}
|
||||
|
||||
return ['value = 0.0;', 'value += _A(inputIdx);', `value /= ${size}.;`]; // ensure real number with `.`
|
||||
}
|
||||
}
|
||||
|
||||
export class WebGLReduceMax extends WebGLGenericReduce {
|
||||
getOps(inputs: Tensor[], axes: number[]): string[] {
|
||||
const idxZero = [];
|
||||
for (let k = 0; k < inputs[0].dims.length; k++) {
|
||||
if (axes.indexOf(k) >= 0 || axes.length === 0) {
|
||||
idxZero.push(`inputIdx[${k}] = 0;`); // first element
|
||||
}
|
||||
}
|
||||
|
||||
return [`${idxZero.join('\n')}\nvalue = _A(inputIdx);`, 'value = max(value, _A(inputIdx));', ''];
|
||||
}
|
||||
}
|
||||
|
||||
export class WebGLReduceMin extends WebGLGenericReduce {
|
||||
getOps(inputs: Tensor[], axes: number[]): string[] {
|
||||
const idxZero = [];
|
||||
for (let k = 0; k < inputs[0].dims.length; k++) {
|
||||
if (axes.indexOf(k) >= 0 || axes.length === 0) {
|
||||
idxZero.push(`inputIdx[${k}] = 0;`); // first element
|
||||
}
|
||||
}
|
||||
|
||||
return [`${idxZero.join('\n')}\nvalue = _A(inputIdx);`, 'value = min(value, _A(inputIdx));', ''];
|
||||
}
|
||||
}
|
||||
|
||||
export class WebGLReduceProd extends WebGLGenericReduce {
|
||||
getOps(_inputs: Tensor[]): string[] {
|
||||
return ['value = 1.0;', 'value *= _A(inputIdx);', ''];
|
||||
}
|
||||
}
|
||||
|
||||
export class WebGLReduceLogSum extends WebGLGenericReduce {
|
||||
getOps(_inputs: Tensor[]): string[] {
|
||||
return ['value = 0.0;', 'value += _A(inputIdx);', 'value = log(value);'];
|
||||
}
|
||||
}
|
||||
|
||||
export class WebGLReduceSumSquare extends WebGLGenericReduce {
|
||||
getOps(_inputs: Tensor[]): string[] {
|
||||
return ['float t; value = 0.0;', 't = _A(inputIdx); value += t * t;', ''];
|
||||
}
|
||||
}
|
||||
38
js/web/lib/onnxjs/backends/webgl/ops/reshape.ts
Normal file
38
js/web/lib/onnxjs/backends/webgl/ops/reshape.ts
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Reshape} from '../../../ops/reshape';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {TextureLayout} from '../types';
|
||||
import {getPackedShape} from '../utils';
|
||||
|
||||
export class WebGLReshape extends Reshape {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
const reshapedDims = ShapeUtil.calculateReshapedDims(inputs[0].dims, inputs[1].integerData);
|
||||
const reshapedTensor = reshape(inferenceHandler, inputs[0], reshapedDims);
|
||||
return [reshapedTensor];
|
||||
}
|
||||
}
|
||||
|
||||
export function reshape(
|
||||
inferenceHandler: WebGLInferenceHandler, input: Tensor, reshapedDims: readonly number[]): Tensor {
|
||||
const inputTD = inferenceHandler.getOrCreateTextureData(input);
|
||||
let packedShape = reshapedDims;
|
||||
if (inputTD.channels === 4) {
|
||||
packedShape = getPackedShape(reshapedDims);
|
||||
}
|
||||
const newTextureLayout: TextureLayout = {
|
||||
channels: inputTD.channels,
|
||||
height: inputTD.height,
|
||||
width: inputTD.width,
|
||||
// handle reshaping into scalar Tensors
|
||||
shape: packedShape.length !== 0 ? packedShape : [1],
|
||||
strides: ShapeUtil.computeStrides(packedShape),
|
||||
unpackedShape: reshapedDims,
|
||||
};
|
||||
|
||||
const newTextureData = inferenceHandler.createSharedTextureData(newTextureLayout, input.type, inputTD.texture, {});
|
||||
return newTextureData.tensor;
|
||||
}
|
||||
100
js/web/lib/onnxjs/backends/webgl/ops/slice.ts
Normal file
100
js/web/lib/onnxjs/backends/webgl/ops/slice.ts
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Slice, SliceV10} from '../../../ops/slice';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLSlice extends Slice implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
return createProgramInfo(handler, inputs[0], this.starts, this.ends, this.axes);
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
return createRunData(handler, programInfo, inputs);
|
||||
}
|
||||
}
|
||||
|
||||
export class WebGLSliceV10 extends SliceV10 implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
if (!handler.session.isInitializer(inputs[1].dataId) || !handler.session.isInitializer(inputs[2].dataId) ||
|
||||
(inputs.length >= 4 && !handler.session.isInitializer(inputs[3].dataId)) ||
|
||||
(inputs.length >= 5 && !handler.session.isInitializer(inputs[4].dataId))) {
|
||||
throw new Error('dynamic slice attributes are not allowed');
|
||||
}
|
||||
if (inputs.length >= 5 && inputs[4].integerData.some((i: number) => i !== 1)) {
|
||||
throw new Error('currently non-1 steps is not supported for Slice');
|
||||
}
|
||||
const starts = Array.from(inputs[1].integerData);
|
||||
const ends = Array.from(inputs[2].integerData);
|
||||
const axes = inputs.length >= 4 ? Array.from(inputs[3].integerData) : [];
|
||||
|
||||
return createProgramInfo(handler, inputs[0], starts, ends, axes);
|
||||
}
|
||||
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
return createRunData(handler, programInfo, inputs);
|
||||
}
|
||||
}
|
||||
|
||||
function createProgramInfo(
|
||||
handler: WebGLInferenceHandler, x: Tensor, starts: readonly number[], ends: readonly number[],
|
||||
axes: readonly number[]): ProgramInfo {
|
||||
if (axes.length === 0) {
|
||||
axes = x.dims.slice(0).map((val, ind) => ind);
|
||||
}
|
||||
axes = ShapeUtil.normalizeAxes(axes, x.dims.length);
|
||||
starts = starts.map((start, ind) => {
|
||||
if (start > x.dims[axes[ind]] - 1) {
|
||||
return x.dims[axes[ind]];
|
||||
}
|
||||
return ShapeUtil.normalizeAxis(start, x.dims[axes[ind]]);
|
||||
});
|
||||
ends = ends.map((end, ind) => {
|
||||
if (end > x.dims[axes[ind]] - 1) {
|
||||
return x.dims[axes[ind]];
|
||||
}
|
||||
return ShapeUtil.normalizeAxis(end, x.dims[axes[ind]]);
|
||||
});
|
||||
|
||||
const outputShape = x.dims.slice();
|
||||
|
||||
const sliceOps: string[] = [];
|
||||
for (let i = 0; i < axes.length; i++) {
|
||||
outputShape[axes[i]] = ends[i] - starts[i];
|
||||
if (starts[i] > 0) {
|
||||
sliceOps.push(`outputIdx[${axes[i]}] += ${starts[i]};`);
|
||||
} // else { sliceOps.push(`outputIdx[${axes[i]}] += 0;`); }
|
||||
}
|
||||
|
||||
const rank = outputShape.length;
|
||||
const shaderSource = `
|
||||
float process(int outputIdx[${rank}]) {
|
||||
${sliceOps.join('\n ')}
|
||||
return _A(outputIdx);
|
||||
}`;
|
||||
return {
|
||||
inputLayouts: [handler.getOrCreateTextureLayout(x)],
|
||||
outputLayout: handler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['A'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
|
||||
function createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = [handler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
217
js/web/lib/onnxjs/backends/webgl/ops/softmax.ts
Normal file
217
js/web/lib/onnxjs/backends/webgl/ops/softmax.ts
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Softmax} from '../../../ops/softmax';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {getGlsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {Artifact, ProgramInfo, RunData, TextureLayout} from '../types';
|
||||
|
||||
export class WebGLSoftmax extends Softmax {
|
||||
constructor() {
|
||||
super();
|
||||
}
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
if (!this.artifacts) {
|
||||
this.artifacts = [];
|
||||
const programInfos = this.createProgramInfos(inferenceHandler, inputs);
|
||||
programInfos.forEach((pi) => {
|
||||
const artifact = inferenceHandler.session.programManager.build(pi);
|
||||
this.artifacts.push(artifact);
|
||||
});
|
||||
}
|
||||
|
||||
const runDatas = this.createRunDatas(inferenceHandler, this.artifacts.map(a => a.programInfo), inputs);
|
||||
runDatas.forEach((v, i) => inferenceHandler.session.programManager.run(this.artifacts[i], v));
|
||||
// return only the last output
|
||||
return [runDatas[runDatas.length - 1].outputTextureData.tensor];
|
||||
}
|
||||
createSoftMaxProgramInfo(
|
||||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||
inferenceHandler: WebGLInferenceHandler, input: Tensor, N: number, D: number,
|
||||
maxElementPerLogicalRow: TextureLayout, normalizationPerLogicalRow: TextureLayout): ProgramInfo {
|
||||
const inputShape = input.dims.slice();
|
||||
const inputLayout = inferenceHandler.createTextureLayoutFromShape(inputShape);
|
||||
const outputShape = inputShape;
|
||||
const rank = outputShape.length;
|
||||
const textureWidth = inputLayout.width;
|
||||
const textureHeight = inputLayout.height;
|
||||
|
||||
if (N < 1 || D < 1) {
|
||||
throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
|
||||
}
|
||||
|
||||
if (maxElementPerLogicalRow.shape.length !== 1 || normalizationPerLogicalRow.shape.length !== 1) {
|
||||
throw new Error('Dimensionality of the intermediate results should be 1');
|
||||
}
|
||||
|
||||
if (maxElementPerLogicalRow.shape[0] !== N || normalizationPerLogicalRow.shape[0] !== N) {
|
||||
throw new Error('Shape of the intermediate results should be equal to logical row count');
|
||||
}
|
||||
|
||||
const shaderSource = `
|
||||
float process(int[${rank}] indices) {
|
||||
|
||||
// get offset of current logical tensor index from the 2-D texture coordinates (TexCoords)
|
||||
int offset = coordsToOffset(TexCoords, ${textureWidth}, ${textureHeight});
|
||||
|
||||
//determine the logical row for this index
|
||||
int logical_row_index[1];
|
||||
logical_row_index[0] = offset / ${D};
|
||||
|
||||
float norm_factor = _Norm(logical_row_index);
|
||||
|
||||
// avoid possible division by 0
|
||||
// if norm_facor is 0, all elements are zero
|
||||
// if so, return 0
|
||||
if(norm_factor == 0.0)
|
||||
return 0.0;
|
||||
|
||||
return exp(_A(indices) - _Max(logical_row_index)) / norm_factor;
|
||||
}`;
|
||||
return {
|
||||
inputLayouts: [inputLayout, maxElementPerLogicalRow, normalizationPerLogicalRow],
|
||||
outputLayout: inferenceHandler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['A', 'Max', 'Norm'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a texture that contains the normalization factor for each of the 'N' rows
|
||||
*/
|
||||
createComputScaleProgramInfo(
|
||||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||
inferenceHandler: WebGLInferenceHandler, x: Tensor, N: number, D: number, maxElementPerLogicalRow: TextureLayout,
|
||||
outputShape: number[]): ProgramInfo {
|
||||
const xlayout = inferenceHandler.createTextureLayoutFromShape(x.dims.slice());
|
||||
const rank = outputShape.length;
|
||||
const textureWidth = xlayout.width;
|
||||
const textureHeight = xlayout.height;
|
||||
|
||||
if (N < 1 || D < 1) {
|
||||
throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
|
||||
}
|
||||
|
||||
if (outputShape.length !== 1) {
|
||||
throw new Error('Dimensionality of the output should be 1');
|
||||
}
|
||||
|
||||
if (outputShape[0] !== N) {
|
||||
throw new Error('Shape of the output should be equal to logical row count');
|
||||
}
|
||||
|
||||
if (maxElementPerLogicalRow.shape.length !== 1) {
|
||||
throw new Error('Dimensionality of the intermediate results should be 1');
|
||||
}
|
||||
|
||||
if (maxElementPerLogicalRow.shape[0] !== N) {
|
||||
throw new Error('Shape of the intermediate results should be equal to logical row count');
|
||||
}
|
||||
|
||||
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
|
||||
const shaderSource = `
|
||||
float process(int[${rank}] indices) {
|
||||
|
||||
int logical_row_start_offset = indices[0] * ${D};
|
||||
|
||||
float norm_factor = 0.0;
|
||||
float max = _Max(indices);
|
||||
for(int i=0; i<${D}; ++i)
|
||||
{
|
||||
norm_factor += exp(getColorAsFloat(${glsl.texture2D}(A, offsetToCoords(logical_row_start_offset + i, ${
|
||||
textureWidth}, ${textureHeight}))) - max);
|
||||
}
|
||||
|
||||
return norm_factor;
|
||||
}`;
|
||||
return {
|
||||
inputLayouts: [xlayout, maxElementPerLogicalRow],
|
||||
outputLayout: inferenceHandler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['A', 'Max'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
/**
|
||||
* Create a texture that contains the maximum value of each of the 'N' rows
|
||||
*/
|
||||
createComputeMaxProgramInfo(
|
||||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||
inferenceHandler: WebGLInferenceHandler, x: Tensor, N: number, D: number, outputShape: number[]): ProgramInfo {
|
||||
const xlayout = inferenceHandler.createTextureLayoutFromShape(x.dims.slice());
|
||||
const rank = outputShape.length;
|
||||
const textureWidth = xlayout.width;
|
||||
const textureHeight = xlayout.height;
|
||||
|
||||
if (N < 1 || D < 1) {
|
||||
throw new Error('Logical row count N and feature count D must be greater than or equal to 1');
|
||||
}
|
||||
|
||||
if (outputShape.length !== 1) {
|
||||
throw new Error('Dimensionality of the output should be 1');
|
||||
}
|
||||
|
||||
if (outputShape[0] !== N) {
|
||||
throw new Error('Shape of the output should be equal to logical row count');
|
||||
}
|
||||
|
||||
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
|
||||
const shaderSource = `
|
||||
float process(int[${rank}] indices) {
|
||||
|
||||
int logical_row_start_offset = indices[0] * ${D};
|
||||
|
||||
float max = getColorAsFloat(${glsl.texture2D}(A, offsetToCoords(logical_row_start_offset, ${textureWidth}, ${
|
||||
textureHeight} )));
|
||||
for(int i=1; i<${D}; ++i)
|
||||
{
|
||||
float current = getColorAsFloat(${glsl.texture2D}(A, offsetToCoords(logical_row_start_offset + i, ${
|
||||
textureWidth}, ${textureHeight})));
|
||||
if(current > max)
|
||||
max = current;
|
||||
}
|
||||
|
||||
return max;
|
||||
}`;
|
||||
return {
|
||||
inputLayouts: [xlayout],
|
||||
outputLayout: inferenceHandler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['A'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
createProgramInfos(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo[] {
|
||||
const inputShape = inputs[0].dims.slice();
|
||||
const axis = ShapeUtil.normalizeAxis(this.axis, inputShape.length);
|
||||
const N = ShapeUtil.sizeToDimension(inputShape, axis);
|
||||
const D = ShapeUtil.sizeFromDimension(inputShape, axis);
|
||||
const computeMaxProgramInfo = this.createComputeMaxProgramInfo(inferenceHandler, inputs[0], N, D, [N]);
|
||||
const computeScaleProgramInfo =
|
||||
this.createComputScaleProgramInfo(inferenceHandler, inputs[0], N, D, computeMaxProgramInfo.outputLayout, [N]);
|
||||
const softMaxProgramInfo = this.createSoftMaxProgramInfo(
|
||||
inferenceHandler, inputs[0], N, D, computeMaxProgramInfo.outputLayout, computeScaleProgramInfo.outputLayout);
|
||||
|
||||
const programInfos: ProgramInfo[] = [computeMaxProgramInfo, computeScaleProgramInfo, softMaxProgramInfo];
|
||||
return programInfos;
|
||||
}
|
||||
createRunDatas(inferenceHandler: WebGLInferenceHandler, programInfos: ProgramInfo[], inputs: Tensor[]): RunData[] {
|
||||
const dataType = inputs[0].type;
|
||||
const inputTD = inferenceHandler.getOrCreateTextureData(inputs[0], programInfos[0].inputLayouts[0]);
|
||||
const runDatas: RunData[] = [];
|
||||
runDatas.push({
|
||||
inputTextureDatas: [inputTD],
|
||||
outputTextureData: inferenceHandler.createTextureDataFromLayout(programInfos[0].outputLayout, dataType),
|
||||
uniformData: {}
|
||||
});
|
||||
for (let i = 1; i < programInfos.length; ++i) {
|
||||
runDatas.push({
|
||||
inputTextureDatas: [...runDatas[i - 1].inputTextureDatas, runDatas[i - 1].outputTextureData],
|
||||
outputTextureData: inferenceHandler.createTextureDataFromLayout(programInfos[i].outputLayout, dataType),
|
||||
uniformData: {}
|
||||
});
|
||||
}
|
||||
return runDatas;
|
||||
}
|
||||
protected artifacts: Artifact[];
|
||||
}
|
||||
62
js/web/lib/onnxjs/backends/webgl/ops/split.ts
Normal file
62
js/web/lib/onnxjs/backends/webgl/ops/split.ts
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Split} from '../../../ops/split';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {ShapeUtil, SplitUtil} from '../../../util';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {Artifact, ProgramInfo, RunData} from '../types';
|
||||
|
||||
export class WebGLSplit extends Split {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
if (!this.artifacts) {
|
||||
this.artifacts = [];
|
||||
const axis = ShapeUtil.normalizeAxis(this.axis, inputs[0].dims.length);
|
||||
const count = this.getProgramCount(inferenceHandler, inputs, axis);
|
||||
for (let i = 0; i < count; ++i) {
|
||||
const programInfo = this.createProgramInfo(inferenceHandler, inputs[0], axis, i);
|
||||
const artifact = inferenceHandler.session.programManager.build(programInfo);
|
||||
this.artifacts.push(artifact);
|
||||
}
|
||||
}
|
||||
const results: Tensor[] = [];
|
||||
|
||||
this.artifacts.forEach(artifact => {
|
||||
const rundata = this.createRunData(inferenceHandler, artifact.programInfo, inputs);
|
||||
inferenceHandler.session.programManager.run(artifact, rundata);
|
||||
results.push(rundata.outputTextureData.tensor);
|
||||
});
|
||||
return results;
|
||||
}
|
||||
getProgramCount(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], axis: number): number {
|
||||
const [, offsets] = SplitUtil.splitShape(inputs[0].dims, axis, this.split, this.numOutputs);
|
||||
return offsets.length;
|
||||
}
|
||||
createProgramInfo(inferenceHandler: WebGLInferenceHandler, input: Tensor, axis: number, index: number): ProgramInfo {
|
||||
const [shapes, offsets] = SplitUtil.splitShape(input.dims, axis, this.split, this.numOutputs);
|
||||
const offset = offsets[index];
|
||||
const outputShape = shapes[index];
|
||||
const rank = outputShape.length;
|
||||
const shaderSource = `
|
||||
float process(int indices[${rank}]) {
|
||||
indices[${axis}] += ${offset};
|
||||
return _A(indices);
|
||||
}`;
|
||||
return {
|
||||
inputLayouts: [inferenceHandler.getOrCreateTextureLayout(input)],
|
||||
outputLayout: inferenceHandler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['A'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
createRunData(inferenceHandler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = [inferenceHandler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData:
|
||||
inferenceHandler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
protected artifacts: Artifact[];
|
||||
}
|
||||
15
js/web/lib/onnxjs/backends/webgl/ops/squeeze.ts
Normal file
15
js/web/lib/onnxjs/backends/webgl/ops/squeeze.ts
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Squeeze} from '../../../ops/squeeze';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {reshape} from './reshape';
|
||||
|
||||
export class WebGLSqueeze extends Squeeze {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
const outputDims = ShapeUtil.squeezeShape(inputs[0].dims, this.axes);
|
||||
return [reshape(inferenceHandler, inputs[0], outputDims)];
|
||||
}
|
||||
}
|
||||
39
js/web/lib/onnxjs/backends/webgl/ops/sum.ts
Normal file
39
js/web/lib/onnxjs/backends/webgl/ops/sum.ts
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Sum} from '../../../ops/sum';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {getGlsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLSum extends Sum implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const glsl = getGlsl(handler.session.backend.glContext.version);
|
||||
const outputShape = inputs[0].dims.slice();
|
||||
const sumLine = inputs.map((v, i) => `${glsl.texture2D}(X${i},TexCoords)`).join(' + ');
|
||||
const samplers = inputs.map((v, i) => `X${i}`);
|
||||
return {
|
||||
inputLayouts: inputs.map(t => handler.getOrCreateTextureLayout(t)),
|
||||
outputLayout: handler.createTextureLayoutFromShape(outputShape),
|
||||
samplers,
|
||||
shaderSource: `
|
||||
void main() {
|
||||
vec4 result = ${sumLine};
|
||||
${glsl.output} = result;
|
||||
}`,
|
||||
hasMain: true
|
||||
};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = inputs.map((t, i) => handler.getOrCreateTextureData(t, programInfo.inputLayouts[i]));
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
46
js/web/lib/onnxjs/backends/webgl/ops/tile.ts
Normal file
46
js/web/lib/onnxjs/backends/webgl/ops/tile.ts
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Tile} from '../../../ops/tile';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLTile extends Tile implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const inputShape = inputs[0].dims.slice();
|
||||
const outputShape = new Array(inputShape.length); // inputs[0].dims.slice();
|
||||
|
||||
const tileOps: string[] = [];
|
||||
for (let i = 0; i < inputShape.length; i++) {
|
||||
outputShape[i] = inputShape[i] * inputs[1].numberData[i];
|
||||
tileOps.push(`inputIdx[${i}] = int(mod(float(outputIdx[${i}]), ${inputShape[i]}.));`);
|
||||
}
|
||||
|
||||
const rank = outputShape.length;
|
||||
const shaderSource = `
|
||||
float process(int outputIdx[${rank}]) {
|
||||
int inputIdx[${rank}];
|
||||
${tileOps.join('\n')}
|
||||
return _A(inputIdx);
|
||||
}`;
|
||||
return {
|
||||
inputLayouts: inputs.map(t => handler.getOrCreateTextureLayout(t)),
|
||||
outputLayout: handler.createTextureLayoutFromShape(outputShape),
|
||||
samplers: ['A'],
|
||||
shaderSource,
|
||||
};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = inputs.map((t, i) => handler.getOrCreateTextureData(t, programInfo.inputLayouts[i]));
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
74
js/web/lib/onnxjs/backends/webgl/ops/transpose.ts
Normal file
74
js/web/lib/onnxjs/backends/webgl/ops/transpose.ts
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Transpose} from '../../../ops/transpose';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {FunctionType, GlslPositionalFunction} from '../glsl-definitions';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLTranspose extends Transpose implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
getOutputShape(inputShapes: Array<readonly number[]>): readonly number[] {
|
||||
const perm = this.getAdjustedPerm(inputShapes[0]);
|
||||
return ShapeUtil.sortBasedOnPerm(inputShapes[0], perm);
|
||||
}
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const inputShapes = inputs.map(t => t.dims.slice());
|
||||
const perm = this.getAdjustedPerm(inputShapes[0]);
|
||||
const unpackedOutputShape = this.getOutputShape(inputShapes);
|
||||
const rank = inputs[0].dims.length;
|
||||
// A dims=[${inputs[0].dims.toString()}]
|
||||
// out Dims=[${unpackedOutputShape.toString()}]
|
||||
// based on perm=[${perm.toString()}]
|
||||
const shaderSource = `
|
||||
${this.getPermFunctionBody('perm', perm, rank)}
|
||||
float process(int indices[${rank}]) {
|
||||
int a[${rank}];
|
||||
perm(a, indices);
|
||||
return _A(a);
|
||||
}`;
|
||||
const outputLayout = handler.createTextureLayoutFromShape(unpackedOutputShape, 1, unpackedOutputShape);
|
||||
return {inputLayouts: [handler.getOrCreateTextureLayout(inputs[0])], outputLayout, samplers: ['A'], shaderSource};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = [handler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
getPositionalFunction(handler: WebGLInferenceHandler, inputShape: number[], name?: string): GlslPositionalFunction {
|
||||
const outputShape = this.getOutputShape([inputShape]);
|
||||
if (!name) {
|
||||
name = 'perm';
|
||||
}
|
||||
return {
|
||||
name,
|
||||
body: this.getPermFunctionBody(name, this.getAdjustedPerm(inputShape), outputShape.length),
|
||||
type: FunctionType.Positional,
|
||||
inputShape,
|
||||
outputShape
|
||||
};
|
||||
}
|
||||
protected getAdjustedPerm(inputShape: readonly number[]): number[] {
|
||||
let perm = this.perm;
|
||||
if (perm && perm.length !== inputShape.length) {
|
||||
perm = [...(inputShape.keys())].reverse();
|
||||
}
|
||||
return perm;
|
||||
}
|
||||
protected getPermFunctionBody(name: string, perm: number[], rank: number): string {
|
||||
const reverseFunc = [];
|
||||
reverseFunc.push(`void ${name}(out int a[${rank}], int src[${rank}]) {`);
|
||||
for (let i = 0; i < rank; ++i) {
|
||||
reverseFunc.push(`\ta[${perm[i]}]=src[${i}];`);
|
||||
}
|
||||
reverseFunc.push('\t}');
|
||||
return reverseFunc.join('\n');
|
||||
}
|
||||
}
|
||||
86
js/web/lib/onnxjs/backends/webgl/ops/uint8-encode.ts
Normal file
86
js/web/lib/onnxjs/backends/webgl/ops/uint8-encode.ts
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {getGlsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {TextureData, TextureLayout} from '../types';
|
||||
|
||||
export class WebGLUint8Encode {
|
||||
runInternal(inferenceHandler: WebGLInferenceHandler, input: TextureData): TextureData {
|
||||
const outputShape = input.shape;
|
||||
const [width, height] = inferenceHandler.session.layoutStrategy.computeTextureWH(input.shape);
|
||||
const outputLayout: TextureLayout = {
|
||||
width,
|
||||
height,
|
||||
channels: 4,
|
||||
shape: outputShape,
|
||||
strides: ShapeUtil.computeStrides(outputShape),
|
||||
unpackedShape: outputShape
|
||||
};
|
||||
const glsl = getGlsl(inferenceHandler.session.backend.glContext.version);
|
||||
// TODO: remove this special script. Use graph transformer instead.
|
||||
/**
|
||||
* https://github.com/tensorflow/tfjs-core/blob/master/src/kernels/webgl/encode_float_gpu.ts
|
||||
*/
|
||||
const shaderSource = `
|
||||
const float FLOAT_MAX = 1.70141184e38;
|
||||
const float FLOAT_MIN = 1.17549435e-38;
|
||||
|
||||
bool isNaN(float val) {
|
||||
return (val < 1.0 || 0.0 < val || val == 0.0) ? false : true;
|
||||
}
|
||||
|
||||
highp vec4 encodeAsUint8(highp float v) {
|
||||
if (isNaN(v)) {
|
||||
return vec4(255, 255, 255, 255);
|
||||
}
|
||||
|
||||
highp float av = abs(v);
|
||||
|
||||
if(av < FLOAT_MIN) {
|
||||
return vec4(0.0, 0.0, 0.0, 0.0);
|
||||
} else if(v > FLOAT_MAX) {
|
||||
return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;
|
||||
} else if(v < -FLOAT_MAX) {
|
||||
return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;
|
||||
}
|
||||
|
||||
highp vec4 c = vec4(0,0,0,0);
|
||||
|
||||
highp float e = floor(log2(av));
|
||||
highp float m = exp2(fract(log2(av))) - 1.0;
|
||||
|
||||
c[2] = floor(128.0 * m);
|
||||
m -= c[2] / 128.0;
|
||||
c[1] = floor(32768.0 * m);
|
||||
m -= c[1] / 32768.0;
|
||||
c[0] = floor(8388608.0 * m);
|
||||
|
||||
highp float ebias = e + 127.0;
|
||||
c[3] = floor(ebias / 2.0);
|
||||
ebias -= c[3] * 2.0;
|
||||
c[2] += floor(ebias) * 128.0;
|
||||
|
||||
c[3] += 128.0 * step(0.0, -v);
|
||||
|
||||
return c / 255.0;
|
||||
}
|
||||
|
||||
void main() {
|
||||
float value = ${glsl.texture2D}(X,TexCoords).r;
|
||||
${glsl.output} = encodeAsUint8(value);
|
||||
}`;
|
||||
const programInfo = {inputLayouts: [input], outputLayout, samplers: ['X'], shaderSource, hasMain: true};
|
||||
const artifact = inferenceHandler.session.programManager.build(programInfo);
|
||||
|
||||
const encoder = inferenceHandler.session.backend.glContext.getEncoder('byte', 4);
|
||||
const texture =
|
||||
inferenceHandler.session.backend.glContext.allocateTexture(outputLayout.width, outputLayout.height, encoder);
|
||||
const outputTextureData = inferenceHandler.createSharedTextureData(outputLayout, 'uint8', texture, {});
|
||||
const runData = {inputTextureDatas: [input], outputTextureData, uniformData: {}};
|
||||
|
||||
inferenceHandler.session.programManager.run(artifact, runData);
|
||||
return runData.outputTextureData;
|
||||
}
|
||||
}
|
||||
172
js/web/lib/onnxjs/backends/webgl/ops/unary-op.ts
Normal file
172
js/web/lib/onnxjs/backends/webgl/ops/unary-op.ts
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {UnaryOp} from '../../../ops/unary-op';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {FunctionType, GlslValueFunction} from '../glsl-definitions';
|
||||
import {getGlsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLUnaryOp extends UnaryOp implements WebGLOperator {
|
||||
constructor(protected typeConstraint: readonly Tensor.DataType[], protected glslFunc: GlslValueFunction) {
|
||||
super(typeConstraint);
|
||||
}
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const outputShape = inputs[0].dims.slice();
|
||||
const inputLayout = handler.getOrCreateTextureLayout(inputs[0]);
|
||||
const glsl = getGlsl(handler.session.backend.glContext.version);
|
||||
const shaderSource = `
|
||||
${this.glslFunc.body}
|
||||
void main() {
|
||||
vec4 v = ${glsl.texture2D}(A, TexCoords);
|
||||
v = ${this.glslFunc.name}(v);
|
||||
${glsl.output} = v;
|
||||
}
|
||||
`;
|
||||
const outputLayout = handler.createTextureLayoutFromShape(outputShape);
|
||||
return {inputLayouts: [inputLayout], outputLayout, samplers: ['A'], shaderSource, hasMain: true};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = [handler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export function glslAbs(): GlslValueFunction {
|
||||
return glslBuiltinUnary('abs');
|
||||
}
|
||||
export function glslAcos(): GlslValueFunction {
|
||||
return glslBuiltinUnary('acos');
|
||||
}
|
||||
export function glslAsin(): GlslValueFunction {
|
||||
return glslBuiltinUnary('asin');
|
||||
}
|
||||
export function glslAtan(): GlslValueFunction {
|
||||
return glslBuiltinUnary('atan');
|
||||
}
|
||||
export function glslCeil(): GlslValueFunction {
|
||||
return glslBuiltinUnary('ceil');
|
||||
}
|
||||
export function glslCos(): GlslValueFunction {
|
||||
return glslBuiltinUnary('cos');
|
||||
}
|
||||
export function glslExp(): GlslValueFunction {
|
||||
return glslBuiltinUnary('exp');
|
||||
}
|
||||
export function glslFloor(): GlslValueFunction {
|
||||
return glslBuiltinUnary('floor');
|
||||
}
|
||||
export function glslIdentity(): GlslValueFunction {
|
||||
const name = 'indentity_';
|
||||
const body = `
|
||||
float ${name}(float a) {
|
||||
return a;
|
||||
}
|
||||
vec4 ${name}(vec4 v) {
|
||||
return v;
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
export function glslLog(): GlslValueFunction {
|
||||
return glslBuiltinUnary('log');
|
||||
}
|
||||
export function glslNeg(): GlslValueFunction {
|
||||
const name = 'neg_';
|
||||
const body = `
|
||||
float ${name}(float a) {
|
||||
return -a;
|
||||
}
|
||||
vec4 ${name}(vec4 v) {
|
||||
return -v;
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
export function glslNot(): GlslValueFunction {
|
||||
const name = 'not_';
|
||||
const body = `
|
||||
float ${name}(float a) {
|
||||
return float( ! bool(a) );
|
||||
}
|
||||
bool ${name}(bool a) {
|
||||
return !a;
|
||||
}
|
||||
vec4 ${name}(vec4 v) {
|
||||
return vec4(!bool(v.x), !bool(v.y), !bool(v.z), !bool(v.w));
|
||||
}
|
||||
bvec4 ${name}(bvec4 v) {
|
||||
return bvec4(!v.x, !v.y, !v.z, !v.w);
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
export function glslSin(): GlslValueFunction {
|
||||
return glslBuiltinUnary('sin');
|
||||
}
|
||||
export function glslRelu(): GlslValueFunction {
|
||||
const name = 'relu_';
|
||||
const body = `
|
||||
float ${name}(float a) {
|
||||
return max( a, 0.0 );
|
||||
}
|
||||
vec4 ${name}(vec4 v) {
|
||||
return max( v, 0.0 );
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
export function glslSigmoid(): GlslValueFunction {
|
||||
const name = 'sigmoid_';
|
||||
const body = `
|
||||
float ${name}(float a) {
|
||||
return 1.0 / (1.0 + exp(-a));
|
||||
}
|
||||
vec4 ${name}(vec4 v) {
|
||||
return 1.0 / (1.0 + exp(-v));
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
export function glslSqrt(): GlslValueFunction {
|
||||
return glslBuiltinUnary('sqrt');
|
||||
}
|
||||
export function glslTan(): GlslValueFunction {
|
||||
return glslBuiltinUnary('tan');
|
||||
}
|
||||
export function glslTanh(): GlslValueFunction {
|
||||
const name = 'tanh_';
|
||||
const body = `
|
||||
float ${name}(float a) {
|
||||
a = clamp(a, -10., 10.);
|
||||
a = exp(2.*a);
|
||||
return (a - 1.) / (a + 1.);
|
||||
}
|
||||
vec4 ${name}(vec4 v) {
|
||||
v = clamp(v, -10., 10.);
|
||||
v = exp(2.*v);
|
||||
return (v - 1.) / (v + 1.);
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
function glslBuiltinUnary(fname: string): GlslValueFunction {
|
||||
const name = `${fname}_`;
|
||||
const body = `
|
||||
float ${name}(float a) {
|
||||
return ${fname}(a);
|
||||
}
|
||||
vec4 ${name}(vec4 v) {
|
||||
return ${fname}(v);
|
||||
}
|
||||
`;
|
||||
return {body, name, type: FunctionType.ValueBased};
|
||||
}
|
||||
82
js/web/lib/onnxjs/backends/webgl/ops/unpack.ts
Normal file
82
js/web/lib/onnxjs/backends/webgl/ops/unpack.ts
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {getGlsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
import {getCoordsDataType} from '../utils';
|
||||
|
||||
import {getChannels, unpackFromChannel} from './packing_utils';
|
||||
|
||||
export class WebGLUnpack implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
if (inputs.length !== 1) {
|
||||
throw new Error('Pack kernel should have input tensor count to 1.');
|
||||
}
|
||||
|
||||
const inputTexture = handler.getTextureData(inputs[0].dataId);
|
||||
if (!inputTexture) {
|
||||
throw new Error('packed input texture must exist');
|
||||
}
|
||||
|
||||
const outputLayout = handler.createTextureLayoutFromShape(inputTexture.unpackedShape);
|
||||
const outputShape = outputLayout.shape;
|
||||
const rank = outputShape.length;
|
||||
|
||||
const channels = getChannels('rc', rank);
|
||||
const innerDims = channels.slice(-2);
|
||||
const coordsDataType = getCoordsDataType(rank);
|
||||
const unpackChannel = unpackFromChannel(rank);
|
||||
const sourceCoords = getSourceCoords(rank, channels);
|
||||
const coords = rank <= 1 ? 'rc' : `vec2(${innerDims.join(',')})`;
|
||||
const glsl = getGlsl(handler.session.backend.glContext.version);
|
||||
const shaderSource = `
|
||||
${unpackChannel}
|
||||
void main() {
|
||||
${coordsDataType} rc = getOutputCoords();
|
||||
|
||||
// Sample the texture with the coords to get the rgba channel value.
|
||||
vec4 packedInput = getA(${sourceCoords});
|
||||
|
||||
${glsl.output} = vec4(getChannel(packedInput, ${coords}), 0, 0, 0);
|
||||
}
|
||||
`;
|
||||
|
||||
return {
|
||||
inputLayouts: [handler.getOrCreateTextureLayout(inputs[0])],
|
||||
outputLayout,
|
||||
samplers: ['A'],
|
||||
shaderSource,
|
||||
hasMain: true,
|
||||
expectPackedInputs: true,
|
||||
expectPackedoutputs: false,
|
||||
};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = [handler.getOrCreateTextureData(inputs[0], programInfo.inputLayouts[0])];
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export function getSourceCoords(rank: number, dims: string[]): string {
|
||||
if (rank === 1) {
|
||||
return 'rc';
|
||||
}
|
||||
|
||||
let coords = '';
|
||||
for (let i = 0; i < rank; i++) {
|
||||
coords += dims[i];
|
||||
if (i < rank - 1) {
|
||||
coords += ',';
|
||||
}
|
||||
}
|
||||
return coords;
|
||||
}
|
||||
15
js/web/lib/onnxjs/backends/webgl/ops/unsqueeze.ts
Normal file
15
js/web/lib/onnxjs/backends/webgl/ops/unsqueeze.ts
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Unsqueeze} from '../../../ops/unsqueeze';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {reshape} from './reshape';
|
||||
|
||||
export class WebGLUnsqueeze extends Unsqueeze {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
const outputDims = ShapeUtil.unsqueezeShape(inputs[0].dims, this.axes);
|
||||
return [reshape(inferenceHandler, inputs[0], outputDims)];
|
||||
}
|
||||
}
|
||||
193
js/web/lib/onnxjs/backends/webgl/ops/upsample.ts
Normal file
193
js/web/lib/onnxjs/backends/webgl/ops/upsample.ts
Normal file
|
|
@ -0,0 +1,193 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Upsample} from '../../../ops/upsample';
|
||||
import {Tensor} from '../../../tensor';
|
||||
import {getGlsl} from '../glsl-source';
|
||||
import {WebGLInferenceHandler} from '../inference-handler';
|
||||
import {ProgramInfo, RunData, WebGLOperator} from '../types';
|
||||
|
||||
export class WebGLUpsample extends Upsample implements WebGLOperator {
|
||||
run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
|
||||
return inferenceHandler.run(this, inputs);
|
||||
}
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo {
|
||||
const inputLayout = handler.getOrCreateTextureLayout(inputs[0]);
|
||||
const outputShape = inputs[0].dims.map((dim, i) => Math.floor(dim * this.scales[i]));
|
||||
const outputLayout = handler.createTextureLayoutFromShape(outputShape);
|
||||
const dim = outputShape.length;
|
||||
|
||||
const glsl = getGlsl(handler.session.backend.glContext.version);
|
||||
|
||||
const outputPitches = new Array<number>(dim);
|
||||
const inputPitches = new Array<number>(dim);
|
||||
let precalculatedPitches = `
|
||||
int output_pitches[${dim}];
|
||||
int input_pitches[${dim}];
|
||||
`;
|
||||
for (let d = dim - 1; d >= 0; d--) {
|
||||
outputPitches[d] = (d === dim - 1) ? 1 : outputPitches[d + 1] * outputShape[d + 1];
|
||||
inputPitches[d] = (d === dim - 1) ? 1 : inputPitches[d + 1] * inputs[0].dims[d + 1];
|
||||
|
||||
precalculatedPitches += `
|
||||
output_pitches[${d}] = ${outputPitches[d]};
|
||||
input_pitches[${d}] = ${inputPitches[d]};
|
||||
`;
|
||||
}
|
||||
const getInputFloatFunction = `
|
||||
float getInputFloat(int index) {
|
||||
vec2 coords = offsetToCoords(index, ${inputLayout.width}, ${inputLayout.height});
|
||||
float value = getColorAsFloat(${glsl.texture2D}(X, coords));
|
||||
return value;
|
||||
}
|
||||
`;
|
||||
|
||||
const shaderSource = this.mode === 'nearest' ?
|
||||
// nearest
|
||||
`
|
||||
${getInputFloatFunction}
|
||||
float process(int indices[${dim}]) {
|
||||
int input_index = 0;
|
||||
int output_index = coordsToOffset(TexCoords, ${outputLayout.width}, ${outputLayout.height});
|
||||
|
||||
${precalculatedPitches}
|
||||
|
||||
int d, m;
|
||||
for (int dim = 0; dim < ${dim}; ++dim) {
|
||||
d = output_index / output_pitches[dim];
|
||||
m = output_index - d * output_pitches[dim];
|
||||
output_index = m;
|
||||
|
||||
if (scales[dim] != 1 && d > 0) {
|
||||
int d2 = d / scales[dim];
|
||||
m = d - d2 * scales[dim];
|
||||
d = d2;
|
||||
}
|
||||
input_index += input_pitches[dim] * d;
|
||||
}
|
||||
|
||||
return getInputFloat(input_index);
|
||||
}` :
|
||||
dim === 4 ?
|
||||
// bilinear 4D
|
||||
`
|
||||
${getInputFloatFunction}
|
||||
float process(int indices[4]) {
|
||||
int input_index = 0;
|
||||
int output_index = coordsToOffset(TexCoords, ${outputLayout.width}, ${outputLayout.height});
|
||||
|
||||
${precalculatedPitches}
|
||||
|
||||
int m;
|
||||
int index_of_dim0, index_of_dim1, index_of_dim2, index_of_dim3;
|
||||
index_of_dim0 = output_index / output_pitches[0];
|
||||
m = output_index - index_of_dim0 * output_pitches[0];
|
||||
index_of_dim1 = m / output_pitches[1];
|
||||
m = m - index_of_dim1 * output_pitches[1];
|
||||
index_of_dim2 = m / output_pitches[2];
|
||||
m = m - index_of_dim2 * output_pitches[2];
|
||||
index_of_dim3 = m;
|
||||
|
||||
int index_of_input_dim2, index_of_input_dim3, x_offset, y_offset;
|
||||
index_of_input_dim2 = index_of_dim2 / scales[2];
|
||||
y_offset = index_of_dim2 - index_of_input_dim2 * scales[2];
|
||||
index_of_input_dim3 = index_of_dim3 / scales[3];
|
||||
x_offset = index_of_dim3 - index_of_input_dim3 * scales[3];
|
||||
|
||||
input_index = index_of_dim0 * input_pitches[0] +
|
||||
index_of_dim1 * input_pitches[1] +
|
||||
index_of_input_dim2 * input_pitches[2] +
|
||||
index_of_input_dim3;
|
||||
|
||||
float x00 = getInputFloat(input_index);
|
||||
float x10, x01, x11;
|
||||
|
||||
bool end_of_dim2 = false;
|
||||
if (index_of_input_dim2 == (${inputs[0].dims[2]} - 1)) {
|
||||
// It's the end in dimension 2
|
||||
x01 = x00;
|
||||
end_of_dim2 = true;
|
||||
} else {
|
||||
x01 = getInputFloat(input_index + input_pitches[2]);
|
||||
}
|
||||
|
||||
if (index_of_input_dim3 == (input_pitches[2] - 1)) {
|
||||
// It's the end in dimension 3
|
||||
x10 = x00;
|
||||
x11 = x01;
|
||||
}
|
||||
else {
|
||||
x10 = getInputFloat(input_index + 1);
|
||||
x11 = end_of_dim2 ? x10 : getInputFloat(input_index + input_pitches[2] + 1);
|
||||
}
|
||||
|
||||
float y0 = x00 + float(y_offset) * (x01 - x00) / float(scales[2]);
|
||||
float y1 = x10 + float(y_offset) * (x11 - x10) / float(scales[2]);
|
||||
return y0 + float(x_offset) * (y1 - y0) / float(scales[3]);
|
||||
}` :
|
||||
// bilinear 2D
|
||||
`
|
||||
${getInputFloatFunction}
|
||||
float process(int indices[2]) {
|
||||
int input_index = 0;
|
||||
int output_index = coordsToOffset(TexCoords, ${outputLayout.width}, ${outputLayout.height});
|
||||
|
||||
${precalculatedPitches}
|
||||
|
||||
int m;
|
||||
int index_of_dim0, index_of_dim1;
|
||||
index_of_dim0 = output_index / output_pitches[0];
|
||||
m = output_index - index_of_dim0 * output_pitches[0];
|
||||
index_of_dim1 = m;
|
||||
|
||||
int index_of_input_dim0, index_of_input_dim1, x_offset, y_offset;
|
||||
index_of_input_dim0 = index_of_dim0 / scales[0];
|
||||
y_offset = index_of_dim0 - index_of_input_dim0 * scales[0];
|
||||
index_of_input_dim1 = index_of_dim1 / scales[1];
|
||||
x_offset = index_of_dim1 - index_of_input_dim1 * scales[1];
|
||||
|
||||
input_index = index_of_input_dim0 * input_pitches[0] + index_of_input_dim1;
|
||||
|
||||
float x00 = getInputFloat(input_index);
|
||||
float x10, x01, x11;
|
||||
|
||||
bool end_of_dim0 = false;
|
||||
if (index_of_input_dim0 == (${inputs[0].dims[0]} - 1)) {
|
||||
// It's the end in dimension 0
|
||||
x01 = x00;
|
||||
end_of_dim0 = true;
|
||||
} else {
|
||||
x01 = getInputFloat(input_index + input_pitches[0]);
|
||||
}
|
||||
|
||||
if (index_of_input_dim1 == (input_pitches[0] - 1)) {
|
||||
// It's the end in dimension 1
|
||||
x10 = x00;
|
||||
x11 = x01;
|
||||
}
|
||||
else {
|
||||
x10 = getInputFloat(input_index + 1);
|
||||
x11 = end_of_dim0 ? x10 : getInputFloat(input_index + input_pitches[0] + 1);
|
||||
}
|
||||
|
||||
float y0 = x00 + float(y_offset) * (x01 - x00) / float(scales[0]);
|
||||
float y1 = x10 + float(y_offset) * (x11 - x10) / float(scales[0]);
|
||||
return y0 + float(x_offset) * (y1 - y0) / float(scales[1]);
|
||||
}`;
|
||||
return {
|
||||
inputLayouts: [inputLayout],
|
||||
outputLayout,
|
||||
samplers: ['X'],
|
||||
shaderSource,
|
||||
variables: [{name: 'scales', type: 'int', arrayLength: this.scales.length}]
|
||||
};
|
||||
}
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
|
||||
const inputTDs = inputs.map((t, i) => handler.getOrCreateTextureData(t, programInfo.inputLayouts[i]));
|
||||
return {
|
||||
inputTextureDatas: inputTDs,
|
||||
outputTextureData: handler.createTextureDataFromLayout(programInfo.outputLayout, inputTDs[0].tensor.type),
|
||||
uniformData: {scales: this.scales.map(x => Math.ceil(x))}
|
||||
};
|
||||
}
|
||||
}
|
||||
185
js/web/lib/onnxjs/backends/webgl/program-manager.ts
Normal file
185
js/web/lib/onnxjs/backends/webgl/program-manager.ts
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {env} from 'onnxruntime-common';
|
||||
import {Logger, Profiler} from '../../instrument';
|
||||
|
||||
import {GlslPreprocessor} from './glsl-preprocessor';
|
||||
import {getVertexShaderSource} from './glsl-source';
|
||||
import {TextureLayoutStrategy} from './texture-layout-strategy';
|
||||
import {Artifact, ProgramInfo, RunData, TextureData, UniformData, VariableInfo} from './types';
|
||||
import {WebGLContext} from './webgl-context';
|
||||
|
||||
/**
|
||||
* ProgramManager is the main class behind running computations
|
||||
* It builds ProgramInfo's into Artifacts
|
||||
* It compiles given ProgramInfo's into WebGL Prorams (cached as Artifacts)
|
||||
* Uses the artifact to run the computation by calling Draw on
|
||||
* the WebGL drawing buffer
|
||||
* ProgramManager automatically maps (binds) input variables to their
|
||||
* corresponding Location's in the binary program
|
||||
*/
|
||||
export class ProgramManager {
|
||||
repo: Map<unknown, Artifact>; // this should be per-session object
|
||||
vertexShader: WebGLShader;
|
||||
attributesBound: boolean;
|
||||
|
||||
constructor(
|
||||
public profiler: Readonly<Profiler>, public glContext: WebGLContext,
|
||||
public textureLayoutStrategy: TextureLayoutStrategy) {
|
||||
this.repo = new Map();
|
||||
this.attributesBound = false;
|
||||
}
|
||||
getArtifact(key: unknown): Artifact|undefined {
|
||||
return this.repo.get(key);
|
||||
}
|
||||
setArtifact(key: unknown, artifact: Artifact): void {
|
||||
this.repo.set(key, artifact);
|
||||
}
|
||||
run(buildArtifact: Artifact, runData: RunData): void {
|
||||
const inputInfo = runData.inputTextureDatas.map((d, i) => `input${i}:[${d.shape}]`).join(', ');
|
||||
const outputInfo = `output: [${runData.outputTextureData.shape}]`;
|
||||
|
||||
this.profiler.event('backend', `ProgramManager.run ${inputInfo} ; ${outputInfo}`, () => {
|
||||
const gl = this.glContext.gl;
|
||||
const program = buildArtifact.program;
|
||||
gl.useProgram(program);
|
||||
try {
|
||||
this.bindOutput(runData.outputTextureData);
|
||||
if (!this.attributesBound) {
|
||||
this.bindAttributes(buildArtifact.attribLocations);
|
||||
}
|
||||
this.bindUniforms(buildArtifact.uniformLocations, runData.uniformData, runData.inputTextureDatas);
|
||||
} catch (err) {
|
||||
Logger.error('ProgramManager', buildArtifact.programInfo.shaderSource);
|
||||
throw err;
|
||||
}
|
||||
this.profiler.event('backend', 'GlContext.draw()', () => {
|
||||
this.doDraw(buildArtifact, runData);
|
||||
});
|
||||
});
|
||||
}
|
||||
dispose(): void {
|
||||
if (this.vertexShader) {
|
||||
this.glContext.deleteShader(this.vertexShader);
|
||||
}
|
||||
this.repo.forEach(a => this.glContext.deleteProgram(a.program));
|
||||
}
|
||||
build(programInfo: ProgramInfo): Artifact {
|
||||
return this.profiler.event('backend', 'ProgramManager.build', () => {
|
||||
const preprocessor = new GlslPreprocessor(this.glContext, programInfo);
|
||||
const fragScript = preprocessor.preprocess();
|
||||
const program = this.compile(fragScript);
|
||||
const artifact = {
|
||||
programInfo,
|
||||
program,
|
||||
uniformLocations: this.getUniformLocations(
|
||||
program, preprocessor.context.programInfo.samplers, preprocessor.context.programInfo.variables),
|
||||
attribLocations: this.getAttribLocations(program)
|
||||
};
|
||||
return artifact;
|
||||
});
|
||||
}
|
||||
protected doDraw(artifact: Artifact, runData: RunData): void {
|
||||
if (runData.draw) {
|
||||
Logger.verbose('ProgramManager', 'Custom draw function');
|
||||
runData.draw(this.glContext, artifact);
|
||||
} else {
|
||||
this.glContext.draw();
|
||||
}
|
||||
}
|
||||
protected compile(fragShaderScript: string): WebGLProgram {
|
||||
if (!this.vertexShader) {
|
||||
Logger.verbose('ProrgramManager', 'Compiling and caching Vertex shader for the first time');
|
||||
const vertexShaderScript = getVertexShaderSource(this.glContext.version);
|
||||
this.vertexShader = this.glContext.compileShader(vertexShaderScript, this.glContext.gl.VERTEX_SHADER);
|
||||
}
|
||||
if (env.debug) {
|
||||
Logger.verbose('ProrgramManager', `FragShader:
|
||||
${fragShaderScript}
|
||||
`);
|
||||
}
|
||||
const fragShader = this.glContext.compileShader(fragShaderScript, this.glContext.gl.FRAGMENT_SHADER);
|
||||
const program = this.glContext.createProgram(this.vertexShader, fragShader);
|
||||
this.glContext.deleteShader(fragShader);
|
||||
return program;
|
||||
}
|
||||
bindOutput(td: TextureData): void {
|
||||
Logger.verbose(
|
||||
'ProrgramManager',
|
||||
`Binding output texture to Framebuffer: w/h=${td.width}/${td.height}, shape=${td.shape}, type=${
|
||||
td.tensor.type}`);
|
||||
this.glContext.attachFramebuffer(td.texture, td.width, td.height);
|
||||
}
|
||||
bindAttributes(attribLocations: Artifact.AttribLocations): void {
|
||||
const positionHandle = attribLocations.position;
|
||||
const textureCoordHandle = attribLocations.textureCoord;
|
||||
this.glContext.setVertexAttributes(positionHandle, textureCoordHandle);
|
||||
this.attributesBound = true;
|
||||
}
|
||||
bindUniforms(uniformLocations: Artifact.UniformLocations, uniformData: UniformData, textures: TextureData[]): void {
|
||||
const gl = this.glContext.gl;
|
||||
let texturePosition = 0;
|
||||
for (const {name, type, location, arrayLength} of uniformLocations) {
|
||||
switch (type) {
|
||||
case 'sampler2D':
|
||||
this.bindTexture(textures[texturePosition], location, texturePosition);
|
||||
texturePosition++;
|
||||
break;
|
||||
case 'float':
|
||||
if (arrayLength) {
|
||||
gl.uniform1fv(location, uniformData[name] as number[]);
|
||||
} else {
|
||||
gl.uniform1f(location, uniformData[name] as number);
|
||||
}
|
||||
break;
|
||||
case 'int':
|
||||
if (arrayLength) {
|
||||
gl.uniform1iv(location, uniformData[name] as number[]);
|
||||
} else {
|
||||
gl.uniform1i(location, uniformData[name] as number);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Uniform not implemented: ${type}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
bindTexture(td: TextureData, uniformHandle: WebGLUniformLocation, position: number): void {
|
||||
this.glContext.bindTextureToUniform(td.texture, position, uniformHandle);
|
||||
}
|
||||
getAttribLocations(program: WebGLProgram): Artifact.AttribLocations {
|
||||
return {
|
||||
position: this.getAttribLocation(program, 'position'),
|
||||
textureCoord: this.getAttribLocation(program, 'textureCoord')
|
||||
};
|
||||
}
|
||||
getUniformLocations(program: WebGLProgram, samplers?: string[], variables?: VariableInfo[]):
|
||||
Artifact.UniformLocations {
|
||||
const uniformLocations: Artifact.UniformLocations = [];
|
||||
if (samplers) {
|
||||
for (const sampler of samplers) {
|
||||
uniformLocations.push({name: sampler, type: 'sampler2D', location: this.getUniformLocation(program, sampler)});
|
||||
}
|
||||
}
|
||||
if (variables) {
|
||||
for (const variable of variables) {
|
||||
uniformLocations.push({...variable, location: this.getUniformLocation(program, variable.name)});
|
||||
}
|
||||
}
|
||||
return uniformLocations;
|
||||
}
|
||||
getUniformLocation(program: WebGLProgram, name: string): WebGLUniformLocation {
|
||||
const gl = this.glContext.gl;
|
||||
const reference = gl.getUniformLocation(program, name);
|
||||
if (reference === null) {
|
||||
throw new Error(`Uniform ${name} not found.`);
|
||||
}
|
||||
return reference;
|
||||
}
|
||||
getAttribLocation(program: WebGLProgram, name: string): number {
|
||||
const gl = this.glContext.gl;
|
||||
const attributeLocation: number = gl.getAttribLocation(program, name);
|
||||
return attributeLocation;
|
||||
}
|
||||
}
|
||||
68
js/web/lib/onnxjs/backends/webgl/session-handler.ts
Normal file
68
js/web/lib/onnxjs/backends/webgl/session-handler.ts
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {SessionHandler} from '../../backend';
|
||||
import {Graph} from '../../graph';
|
||||
import {Logger} from '../../instrument';
|
||||
import {Operator} from '../../operators';
|
||||
import {OpSet, resolveOperator} from '../../opset';
|
||||
import {Session} from '../../session';
|
||||
import {Tensor} from '../../tensor';
|
||||
import {WebGLBackend} from '../backend-webgl';
|
||||
|
||||
import {WebGLInferenceHandler} from './inference-handler';
|
||||
import {WEBGL_OP_RESOLVE_RULES} from './op-resolve-rules';
|
||||
import {ProgramManager} from './program-manager';
|
||||
import {PreferLogicalStrategy, TextureLayoutStrategy} from './texture-layout-strategy';
|
||||
import {TextureManager} from './texture-manager';
|
||||
import {TextureData, WebGLOperator} from './types';
|
||||
|
||||
export class WebGLSessionHandler implements SessionHandler {
|
||||
programManager: ProgramManager;
|
||||
textureManager: TextureManager;
|
||||
layoutStrategy: TextureLayoutStrategy;
|
||||
textureDataCache: Map<Tensor.Id, TextureData>;
|
||||
initializers: Set<Tensor.Id>;
|
||||
packOpCache: Map<string, WebGLOperator>;
|
||||
unpackOpCache: Map<string, WebGLOperator>;
|
||||
|
||||
constructor(public readonly backend: WebGLBackend, public readonly context: Session.Context) {
|
||||
this.layoutStrategy = new PreferLogicalStrategy(backend.glContext.maxTextureSize);
|
||||
this.programManager = new ProgramManager(this.context.profiler, backend.glContext, this.layoutStrategy);
|
||||
this.textureManager = new TextureManager(
|
||||
backend.glContext, this.layoutStrategy, this.context.profiler,
|
||||
{reuseTextures: backend.textureCacheMode === 'full'});
|
||||
this.textureDataCache = new Map();
|
||||
this.packOpCache = new Map();
|
||||
this.unpackOpCache = new Map();
|
||||
}
|
||||
|
||||
createInferenceHandler() {
|
||||
return new WebGLInferenceHandler(this);
|
||||
}
|
||||
onGraphInitialized(graph: Graph): void {
|
||||
const initializers = graph.getValues().filter(v => v.from === -1 && v.tensor).map(v => v.tensor!.dataId);
|
||||
this.initializers = new Set(initializers);
|
||||
}
|
||||
isInitializer(tensorId: Tensor.Id): boolean {
|
||||
return this.initializers ? this.initializers.has(tensorId) : false;
|
||||
}
|
||||
getTextureData(tensorId: Tensor.Id): TextureData|undefined {
|
||||
return this.textureDataCache.get(tensorId);
|
||||
}
|
||||
setTextureData(tensorId: Tensor.Id, textureData: TextureData): void {
|
||||
Logger.verbose('WebGLSessionHandler', 'Storing Texture data in cache');
|
||||
this.textureDataCache.set(tensorId, textureData);
|
||||
}
|
||||
dispose(): void {
|
||||
this.programManager.dispose();
|
||||
this.textureManager.clearActiveTextures();
|
||||
this.textureDataCache.forEach(td => this.textureManager.releaseTexture(td, true));
|
||||
this.textureDataCache = new Map();
|
||||
}
|
||||
resolve(node: Graph.Node, opsets: readonly OpSet[], graph: Graph): Operator {
|
||||
const op = resolveOperator(node, opsets, WEBGL_OP_RESOLVE_RULES);
|
||||
op.initialize(node.attributes, node, graph);
|
||||
return op;
|
||||
}
|
||||
}
|
||||
160
js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts
Normal file
160
js/web/lib/onnxjs/backends/webgl/texture-data-encoder.ts
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Logger} from '../../instrument';
|
||||
|
||||
export declare namespace Encoder {
|
||||
export interface DataTypeMap {
|
||||
float: Float32Array;
|
||||
byte: Uint8Array;
|
||||
int: Uint32Array;
|
||||
}
|
||||
export type DataType = keyof DataTypeMap;
|
||||
type DataArrayType = DataTypeMap[DataType];
|
||||
|
||||
/* eslint-disable @typescript-eslint/naming-convention */
|
||||
export const enum Usage {
|
||||
Default = 0,
|
||||
UploadOnly,
|
||||
Download4BytesAsFloat32,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Abstraction for mapping data types to texture texlets
|
||||
* Encoding means how a Float32 is mapped to 1 or 4 channels for each texlet
|
||||
* Decoding means how a texlet's channels are mapped to a resulting Float32
|
||||
*/
|
||||
export interface DataEncoder {
|
||||
internalFormat: number;
|
||||
format: number;
|
||||
textureType: number;
|
||||
channelSize: number;
|
||||
encode(src: Encoder.DataArrayType, textureSize: number): Encoder.DataArrayType;
|
||||
allocate(size: number): Encoder.DataArrayType;
|
||||
decode(buffer: Encoder.DataArrayType, dataSize: number): Encoder.DataArrayType;
|
||||
}
|
||||
/**
|
||||
* WebGL2 data encoder
|
||||
* Uses R32F as the format for texlet
|
||||
*/
|
||||
export class RedFloat32DataEncoder implements DataEncoder {
|
||||
internalFormat: number;
|
||||
format: number;
|
||||
textureType: number;
|
||||
channelSize: number;
|
||||
constructor(gl: WebGL2RenderingContext, channels = 1) {
|
||||
if (channels === 1) {
|
||||
this.internalFormat = gl.R32F;
|
||||
this.format = gl.RED;
|
||||
this.textureType = gl.FLOAT;
|
||||
this.channelSize = channels;
|
||||
} else if (channels === 4) {
|
||||
this.internalFormat = gl.RGBA32F;
|
||||
this.format = gl.RGBA;
|
||||
this.textureType = gl.FLOAT;
|
||||
this.channelSize = channels;
|
||||
} else {
|
||||
throw new Error(`Invalid number of channels: ${channels}`);
|
||||
}
|
||||
}
|
||||
encode(src: Encoder.DataArrayType, textureSize: number): Encoder.DataArrayType {
|
||||
let result: Float32Array;
|
||||
let source: Float32Array;
|
||||
if (src.constructor !== Float32Array) {
|
||||
Logger.warning('Encoder', 'data was not of type Float32; creating new Float32Array');
|
||||
source = new Float32Array(src);
|
||||
}
|
||||
if (textureSize * this.channelSize > src.length) {
|
||||
Logger.warning('Encoder', 'Source data too small. Allocating larger array');
|
||||
source = src as Float32Array;
|
||||
result = this.allocate(textureSize * this.channelSize) as Float32Array;
|
||||
source.forEach((v, i) => result[i] = v);
|
||||
} else {
|
||||
source = src as Float32Array;
|
||||
result = source;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
allocate(size: number): Encoder.DataArrayType {
|
||||
return new Float32Array(size * 4);
|
||||
}
|
||||
decode(buffer: Encoder.DataArrayType, dataSize: number): Float32Array {
|
||||
if (this.channelSize === 1) {
|
||||
const filteredData = (buffer as Float32Array).filter((value, index) => index % 4 === 0).subarray(0, dataSize);
|
||||
return filteredData;
|
||||
}
|
||||
return buffer.subarray(0, dataSize) as Float32Array;
|
||||
}
|
||||
}
|
||||
/**
|
||||
* Data encoder for WebGL 1 with support for floating point texture
|
||||
*/
|
||||
export class RGBAFloatDataEncoder implements DataEncoder {
|
||||
internalFormat: number;
|
||||
format: number;
|
||||
textureType: number;
|
||||
channelSize: number;
|
||||
constructor(gl: WebGLRenderingContext, channels = 1, textureType?: number) {
|
||||
if (channels !== 1 && channels !== 4) {
|
||||
throw new Error(`Invalid number of channels: ${channels}`);
|
||||
}
|
||||
this.internalFormat = gl.RGBA;
|
||||
this.format = gl.RGBA;
|
||||
this.channelSize = channels;
|
||||
this.textureType = textureType || gl.FLOAT;
|
||||
}
|
||||
encode(src: Float32Array, textureSize: number): Encoder.DataArrayType {
|
||||
let dest = src;
|
||||
if (this.channelSize === 1) {
|
||||
Logger.verbose('Encoder', 'Exploding into a larger array');
|
||||
dest = this.allocate(textureSize) as Float32Array;
|
||||
src.forEach((v, i) => dest[i * 4] = v);
|
||||
}
|
||||
return dest;
|
||||
}
|
||||
allocate(size: number): Encoder.DataArrayType {
|
||||
return new Float32Array(size * 4);
|
||||
}
|
||||
decode(buffer: Encoder.DataArrayType, dataSize: number): Float32Array {
|
||||
if (this.channelSize === 1) {
|
||||
const filteredData = (buffer as Float32Array).filter((value, index) => index % 4 === 0).subarray(0, dataSize);
|
||||
return filteredData;
|
||||
}
|
||||
return buffer.subarray(0, dataSize) as Float32Array;
|
||||
}
|
||||
}
|
||||
|
||||
export class Uint8DataEncoder implements DataEncoder {
|
||||
internalFormat: number;
|
||||
format: number;
|
||||
textureType: number;
|
||||
channelSize = 4;
|
||||
constructor(gl: WebGLRenderingContext, channels = 1) {
|
||||
if (channels === 1) {
|
||||
this.internalFormat = gl.ALPHA;
|
||||
this.format = gl.ALPHA; // not tested
|
||||
this.textureType = gl.UNSIGNED_BYTE;
|
||||
this.channelSize = channels;
|
||||
} else if (channels === 4) {
|
||||
this.internalFormat = gl.RGBA;
|
||||
this.format = gl.RGBA;
|
||||
this.textureType = gl.UNSIGNED_BYTE;
|
||||
this.channelSize = channels;
|
||||
} else {
|
||||
throw new Error(`Invalid number of channels: ${channels}`);
|
||||
}
|
||||
}
|
||||
encode(src: Uint8Array, _textureSize: number): Encoder.DataArrayType {
|
||||
return new Uint8Array(src.buffer, src.byteOffset, src.byteLength);
|
||||
}
|
||||
allocate(size: number): Encoder.DataArrayType {
|
||||
return new Uint8Array(size * this.channelSize);
|
||||
}
|
||||
decode(buffer: Encoder.DataArrayType, dataSize: number): Uint8Array {
|
||||
if (buffer instanceof Uint8Array) {
|
||||
return buffer.subarray(0, dataSize);
|
||||
}
|
||||
throw new Error(`Invalid array type: ${buffer.constructor}`);
|
||||
}
|
||||
}
|
||||
227
js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts
Normal file
227
js/web/lib/onnxjs/backends/webgl/texture-layout-strategy.ts
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Logger} from '../../instrument';
|
||||
import {assert} from '../../util';
|
||||
/** Layout preferences */
|
||||
export interface WidthHeightPrefs {
|
||||
breakAxis?: number;
|
||||
isPacked?: boolean;
|
||||
reverseWH?: boolean;
|
||||
}
|
||||
/**
|
||||
* TextureLayoutStrategy is an abstraction for different plans
|
||||
* for mapping n-dimensional arrays to 2D textures (and back)
|
||||
*/
|
||||
export interface TextureLayoutStrategy {
|
||||
computeTextureWH(shape: readonly number[], prefs?: WidthHeightPrefs): [number, number];
|
||||
}
|
||||
|
||||
/**
|
||||
* This strategy try to find the minimal max(W,H) that fulfills (W * H == totalSize)
|
||||
*/
|
||||
export class AlwaysKeepOriginalSizeStrategy implements TextureLayoutStrategy {
|
||||
constructor(public maxTextureSize: number) {}
|
||||
computeTextureWH(shape: readonly number[], prefs?: WidthHeightPrefs): [number, number] {
|
||||
// scalar tensor
|
||||
if (shape.length === 0) {
|
||||
return [1, 1];
|
||||
}
|
||||
const maxTextureSize = this.maxTextureSize;
|
||||
if (prefs && prefs.breakAxis !== undefined) {
|
||||
// check to see if dims fit
|
||||
const wsize = prefs.breakAxis >= shape.length ? 1 : shape.slice(prefs.breakAxis).reduce((a, b) => a * b);
|
||||
const hsize = prefs.breakAxis <= 0 ? 1 : shape.slice(0, prefs.breakAxis).reduce((a, b) => a * b);
|
||||
if (wsize > maxTextureSize || hsize > maxTextureSize) {
|
||||
// ignore preferences
|
||||
// continue with default layout
|
||||
Logger.verbose(
|
||||
'TextureLayout',
|
||||
`Given width/height preferences were unattainable: shape:${shape}, breakAxis:${prefs.breakAxis}`);
|
||||
} else {
|
||||
return [wsize, hsize];
|
||||
}
|
||||
}
|
||||
const totalSize = shape.reduce((a, b) => a * b);
|
||||
|
||||
let width = Math.floor(Math.sqrt(totalSize));
|
||||
|
||||
for (; width < maxTextureSize && width < totalSize; width++) {
|
||||
if (totalSize % width === 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (width >= maxTextureSize || totalSize % width !== 0) {
|
||||
throw new Error(`The given dimensions are outside this GPU's boundaries: ${shape}`);
|
||||
}
|
||||
return [width, totalSize / width];
|
||||
}
|
||||
}
|
||||
|
||||
export class PreferLogicalStrategy implements TextureLayoutStrategy {
|
||||
constructor(public maxTextureSize: number) {}
|
||||
computeTextureWH(shape: readonly number[], prefs?: WidthHeightPrefs): [number, number] {
|
||||
const wh = this.computeTexture(shape, prefs);
|
||||
if (prefs && prefs.isPacked) {
|
||||
wh[0] /= 2;
|
||||
wh[1] /= 2;
|
||||
}
|
||||
|
||||
return wh;
|
||||
}
|
||||
|
||||
computeTexture(shape: readonly number[], prefs?: WidthHeightPrefs): [number, number] {
|
||||
const isPacked = prefs && prefs.isPacked;
|
||||
// scalar tensor
|
||||
if (shape.length === 0) {
|
||||
return isPacked ? [2, 2] : [1, 1];
|
||||
}
|
||||
let maxTextureSize = this.maxTextureSize;
|
||||
if (prefs && prefs.breakAxis !== undefined) {
|
||||
// check to see if dims fit
|
||||
const wsize = prefs.breakAxis >= shape.length ? 1 : shape.slice(prefs.breakAxis).reduce((a, b) => a * b);
|
||||
const hsize = prefs.breakAxis <= 0 ? 1 : shape.slice(0, prefs.breakAxis).reduce((a, b) => a * b);
|
||||
if (wsize > maxTextureSize || hsize > maxTextureSize) {
|
||||
// ignore preferences
|
||||
// continue with default layout
|
||||
Logger.verbose(
|
||||
'TextureLayout',
|
||||
`Given width/height preferences were unattainable: shape:${shape}, breakAxis:${prefs.breakAxis}`);
|
||||
} else {
|
||||
return [wsize, hsize];
|
||||
}
|
||||
}
|
||||
let logShape = shape.slice(0);
|
||||
if (isPacked) {
|
||||
maxTextureSize = maxTextureSize * 2;
|
||||
|
||||
// This logic ensures we accurately count the number of packed texels needed
|
||||
// to accommodate the tensor. We can only pack values in the same texel if
|
||||
// they are from adjacent pairs of rows/cols within the same batch. So if a
|
||||
// tensor has 3 rows, we pretend it has 4 rows in order to account for the
|
||||
// fact that the texels containing the third row are half empty.
|
||||
logShape = logShape.map(
|
||||
(d, i) => i >= logShape.length - 2 ? (logShape[i] % 2 === 0 ? logShape[i] : logShape[i] + 1) : logShape[i]);
|
||||
|
||||
// Packed texture height is at least 2 (the channel height of a single
|
||||
// texel).
|
||||
if (logShape.length === 1) {
|
||||
logShape = [2, logShape[0]];
|
||||
}
|
||||
}
|
||||
|
||||
// If logical shape is 2, we don't squeeze, since we want to match physical.
|
||||
if (logShape.length !== 2) {
|
||||
const squeezeResult = squeezeShape(logShape);
|
||||
logShape = squeezeResult.newShape;
|
||||
}
|
||||
|
||||
const size = sizeFromShape(logShape);
|
||||
if (logShape.length <= 1 && size <= maxTextureSize) {
|
||||
return [1, size];
|
||||
} else if (logShape.length === 2 && logShape[0] <= maxTextureSize && logShape[1] <= maxTextureSize) {
|
||||
return logShape as [number, number];
|
||||
} else if (logShape.length === 3 && logShape[0] * logShape[1] <= maxTextureSize && logShape[2] <= maxTextureSize) {
|
||||
return [logShape[0] * logShape[1], logShape[2]];
|
||||
} else if (logShape.length === 3 && logShape[0] <= maxTextureSize && logShape[1] * logShape[2] <= maxTextureSize) {
|
||||
return [logShape[0], logShape[1] * logShape[2]];
|
||||
} else if (
|
||||
logShape.length === 4 && logShape[0] * logShape[1] * logShape[2] <= maxTextureSize &&
|
||||
logShape[3] <= maxTextureSize) {
|
||||
return [logShape[0] * logShape[1] * logShape[2], logShape[3]];
|
||||
} else if (
|
||||
logShape.length === 4 && logShape[0] <= maxTextureSize &&
|
||||
logShape[1] * logShape[2] * logShape[3] <= maxTextureSize) {
|
||||
return [logShape[0], logShape[1] * logShape[2] * logShape[3]];
|
||||
} else {
|
||||
if (isPacked) {
|
||||
// For packed textures size equals the number of channels required to
|
||||
// accommodate the texture data. However in order to squarify such that
|
||||
// inner dimensions stay even, we rewrite size to equal the number of
|
||||
// texels. Then in the return statement we rehydrate the squarified
|
||||
// dimensions to channel units.
|
||||
return sizeToSquarishShape(size / 4).map(d => d * 2) as [number, number];
|
||||
}
|
||||
return sizeToSquarishShape(size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function squeezeShape(shape: number[], axis?: number[]): {newShape: number[]; keptDims: number[]} {
|
||||
const newShape: number[] = [];
|
||||
const keptDims: number[] = [];
|
||||
const isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0;
|
||||
const axes = (axis == null || isEmptyArray) ? null : parseAxisParam(axis, shape).sort();
|
||||
let j = 0;
|
||||
for (let i = 0; i < shape.length; ++i) {
|
||||
if (axes != null) {
|
||||
if (axes[j] === i && shape[i] !== 1) {
|
||||
throw new Error(`Can't squeeze axis ${i} since its dim '${shape[i]}' is not 1`);
|
||||
}
|
||||
if ((axes[j] == null || axes[j] > i) && shape[i] === 1) {
|
||||
newShape.push(shape[i]);
|
||||
keptDims.push(i);
|
||||
}
|
||||
if (axes[j] <= i) {
|
||||
j++;
|
||||
}
|
||||
}
|
||||
if (shape[i] !== 1) {
|
||||
newShape.push(shape[i]);
|
||||
keptDims.push(i);
|
||||
}
|
||||
}
|
||||
return {newShape, keptDims};
|
||||
}
|
||||
|
||||
export function parseAxisParam(axis: number|number[], shape: number[]): number[] {
|
||||
const rank = shape.length;
|
||||
|
||||
// Normalize input
|
||||
axis = axis == null ? shape.map((s, i) => i) : ([] as number[]).concat(axis);
|
||||
|
||||
// Check for valid range
|
||||
assert(
|
||||
axis.every(ax => ax >= -rank && ax < rank),
|
||||
() => `All values in axis param must be in range [-${rank}, ${rank}) but ` +
|
||||
`got axis ${axis}`);
|
||||
|
||||
// Check for only integers
|
||||
assert(
|
||||
axis.every(isInt),
|
||||
() => 'All values in axis param must be integers but ' +
|
||||
`got axis ${axis}`);
|
||||
|
||||
// Handle negative axis.
|
||||
return axis.map(a => a < 0 ? rank + a : a);
|
||||
}
|
||||
|
||||
export function isInt(a: number): boolean {
|
||||
return a % 1 === 0;
|
||||
}
|
||||
export function sizeFromShape(shape: number[]): number {
|
||||
if (shape.length === 0) {
|
||||
// Scalar.
|
||||
return 1;
|
||||
}
|
||||
let size = shape[0];
|
||||
for (let i = 1; i < shape.length; i++) {
|
||||
size *= shape[i];
|
||||
}
|
||||
return size;
|
||||
}
|
||||
export function getRowsCols(shape: number[]): [number, number] {
|
||||
if (shape.length === 0) {
|
||||
throw Error('Cannot get rows and columns of an empty shape array.');
|
||||
}
|
||||
|
||||
return [shape.length > 1 ? shape[shape.length - 2] : 1, shape[shape.length - 1]];
|
||||
}
|
||||
export function sizeToSquarishShape(size: number): [number, number] {
|
||||
const width = Math.ceil(Math.sqrt(size));
|
||||
return [width, Math.ceil(size / width)];
|
||||
}
|
||||
export function getBatchDim(shape: number[], dimsToSkip = 2): number {
|
||||
return sizeFromShape(shape.slice(0, shape.length - dimsToSkip));
|
||||
}
|
||||
199
js/web/lib/onnxjs/backends/webgl/texture-manager.ts
Normal file
199
js/web/lib/onnxjs/backends/webgl/texture-manager.ts
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Logger, Profiler} from '../../instrument';
|
||||
import {Tensor} from '../../tensor';
|
||||
|
||||
import {Encoder} from './texture-data-encoder';
|
||||
import {TextureLayoutStrategy} from './texture-layout-strategy';
|
||||
import {TextureData, TextureLayout} from './types';
|
||||
import {WebGLContext} from './webgl-context';
|
||||
|
||||
export interface TextureManagerConfig {
|
||||
reuseTextures?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* TextureManager is the mainly responsible for caching Textures
|
||||
* Textures are cached in 2 levels:
|
||||
* 1. the texures which are associated with a dataId (from Tensor)
|
||||
* Caching these is crucial to performance. These are In-use Textures
|
||||
* 2. textures which are not in use by any current ProgramInfo/Tensor
|
||||
* These are called Free Textures
|
||||
* TextureManager is also used to help creating textures. For this it
|
||||
* uses WebGLContext and TextureLayoutStrategy
|
||||
*/
|
||||
export class TextureManager {
|
||||
private readonly inUseTextures: Map<string, WebGLTexture[]>;
|
||||
private readonly idleTextures: Map<string, WebGLTexture[]>;
|
||||
private readonly textureLookup: Map<WebGLTexture, string>;
|
||||
|
||||
constructor(
|
||||
public glContext: WebGLContext, public layoutStrategy: TextureLayoutStrategy, public profiler: Readonly<Profiler>,
|
||||
private config: TextureManagerConfig) {
|
||||
if (config.reuseTextures) {
|
||||
this.inUseTextures = new Map();
|
||||
this.idleTextures = new Map();
|
||||
this.textureLookup = new Map();
|
||||
}
|
||||
}
|
||||
createTextureFromLayout(
|
||||
dataType: Tensor.DataType, layout: TextureLayout, data?: Tensor.NumberType, usage?: Encoder.Usage) {
|
||||
const textureDataType = this.toEncoderType(dataType);
|
||||
|
||||
const encoder = this.glContext.getEncoder(textureDataType, layout.channels || 1, usage);
|
||||
if (layout.isPacked && usage === Encoder.Usage.UploadOnly) {
|
||||
throw new Error('not implemented');
|
||||
}
|
||||
const width = layout.width;
|
||||
const height = layout.height;
|
||||
|
||||
let key: string|undefined;
|
||||
let inUseTextures: WebGLTexture[]|undefined;
|
||||
if (this.config.reuseTextures) {
|
||||
key = `${width}x${height}_${encoder.format}_${encoder.internalFormat}_${encoder.textureType}`;
|
||||
inUseTextures = this.inUseTextures.get(key);
|
||||
if (!inUseTextures) {
|
||||
inUseTextures = [];
|
||||
this.inUseTextures.set(key, inUseTextures);
|
||||
}
|
||||
|
||||
const idleTextures = this.idleTextures.get(key);
|
||||
if (idleTextures && idleTextures.length > 0) {
|
||||
const texture = idleTextures.pop()!;
|
||||
inUseTextures.push(texture);
|
||||
if (usage === Encoder.Usage.UploadOnly) {
|
||||
this.glContext.updateTexture(texture, width, height, encoder, this.toTextureData(dataType, data)!);
|
||||
}
|
||||
return texture;
|
||||
}
|
||||
}
|
||||
|
||||
Logger.verbose('TextureManager', `Creating new texture of size ${layout.width}x${layout.height}`);
|
||||
const texture = this.glContext.allocateTexture(width, height, encoder, this.toTextureData(dataType, data));
|
||||
|
||||
if (this.config.reuseTextures) {
|
||||
inUseTextures!.push(texture);
|
||||
this.textureLookup.set(texture, key!);
|
||||
}
|
||||
return texture;
|
||||
}
|
||||
readTexture(td: TextureData, dataType: Tensor.DataType, channels?: number): Tensor.NumberType {
|
||||
if (!channels) {
|
||||
channels = 1;
|
||||
}
|
||||
return this.profiler.event('backend', 'TextureManager.readTexture', () => {
|
||||
const dataSize = td.shape.reduce((a, b) => a * b) * channels!;
|
||||
const data = this.glContext.readTexture(
|
||||
td.texture, td.width, td.height, dataSize, this.toEncoderType(dataType), channels!);
|
||||
return this.toTensorData(dataType, data);
|
||||
});
|
||||
}
|
||||
readUint8TextureAsFloat(td: TextureData): Float32Array {
|
||||
return this.profiler.event('backend', 'TextureManager.readUint8TextureAsFloat', () => {
|
||||
const dataSize = td.shape.reduce((a, b) => a * b);
|
||||
const data = this.glContext.readTexture(td.texture, td.width, td.height, dataSize * 4, 'byte', 4);
|
||||
return new Float32Array(data.buffer, data.byteOffset, dataSize);
|
||||
});
|
||||
}
|
||||
releaseTexture(textureData: TextureData, deleteTexture?: boolean): void {
|
||||
let key: string|undefined;
|
||||
if (this.config.reuseTextures) {
|
||||
key = this.textureLookup.get(textureData.texture);
|
||||
if (key) {
|
||||
if (deleteTexture) {
|
||||
this.textureLookup.delete(key);
|
||||
}
|
||||
const inUseTextures = this.inUseTextures.get(key);
|
||||
if (inUseTextures) {
|
||||
const index = inUseTextures.indexOf(textureData.texture);
|
||||
if (index !== -1) {
|
||||
inUseTextures.splice(index, 1);
|
||||
let idleTextures = this.idleTextures.get(key);
|
||||
if (!idleTextures) {
|
||||
idleTextures = [];
|
||||
this.idleTextures.set(key, idleTextures);
|
||||
}
|
||||
idleTextures.push(textureData.texture);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!key || deleteTexture) {
|
||||
Logger.verbose('TextureManager', `Deleting texture of size ${textureData.width}x${textureData.height}`);
|
||||
this.glContext.deleteTexture(textureData.texture);
|
||||
}
|
||||
}
|
||||
toTensorData(dataType: Tensor.DataType, data: Encoder.DataArrayType): Tensor.NumberType {
|
||||
return (data instanceof Float32Array) ? data : new Float32Array(data);
|
||||
/*
|
||||
switch (dataType) {
|
||||
case 'int16':
|
||||
return new Int16Array(data);
|
||||
case 'int32':
|
||||
return new Int32Array(data);
|
||||
case 'int8':
|
||||
return new Int8Array(data);
|
||||
case 'uint16':
|
||||
return new Uint16Array(data);
|
||||
case 'uint32':
|
||||
return data as Uint32Array;
|
||||
case 'uint8':
|
||||
case 'bool':
|
||||
return data as Uint8Array;
|
||||
case 'float32':
|
||||
return data as Float32Array;
|
||||
case 'float64':
|
||||
return new Float64Array(data);
|
||||
default:
|
||||
throw new Error(`TensorData type ${dataType} is not supported`);
|
||||
}
|
||||
*/
|
||||
}
|
||||
toTextureData(dataType: Tensor.DataType, data: Tensor.NumberType|undefined): Encoder.DataArrayType|undefined {
|
||||
if (!data) {
|
||||
return undefined;
|
||||
}
|
||||
return (data instanceof Float32Array) ? data : new Float32Array(data);
|
||||
/*
|
||||
switch (dataType) {
|
||||
case 'int16':
|
||||
case 'int32':
|
||||
case 'uint16':
|
||||
case 'uint32':
|
||||
return (data.constructor === Uint32Array) ? data as Uint32Array : new Uint32Array(data);
|
||||
case 'int8':
|
||||
case 'uint8':
|
||||
case 'bool':
|
||||
return (data.constructor === Uint8Array) ? data as Uint8Array : new Uint8Array(data);
|
||||
case 'float32':
|
||||
case 'float64':
|
||||
return (data.constructor === Float32Array) ? data as Float32Array : new Float32Array(data);
|
||||
default:
|
||||
throw new Error(`TensorData type ${dataType} is not supported`);
|
||||
}
|
||||
*/
|
||||
}
|
||||
toEncoderType(_dataType: Tensor.DataType): Encoder.DataType {
|
||||
return 'float';
|
||||
// switch (dataType) {
|
||||
// case 'int16':
|
||||
// case 'int32':
|
||||
// case 'uint16':
|
||||
// case 'uint32':
|
||||
// return 'int';
|
||||
// case 'uint8':
|
||||
// case 'bool':
|
||||
// return 'byte';
|
||||
// case 'float32':
|
||||
// case 'float64':
|
||||
// return 'float';
|
||||
// default:
|
||||
// throw new Error(`TensorData type ${dataType} is not supported`);
|
||||
// }
|
||||
}
|
||||
clearActiveTextures(): void {
|
||||
this.glContext.clearActiveTextures();
|
||||
}
|
||||
}
|
||||
132
js/web/lib/onnxjs/backends/webgl/types.ts
Normal file
132
js/web/lib/onnxjs/backends/webgl/types.ts
Normal file
|
|
@ -0,0 +1,132 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Tensor} from '../../tensor';
|
||||
|
||||
import {WebGLInferenceHandler} from './inference-handler';
|
||||
import {WebGLContext} from './webgl-context';
|
||||
|
||||
/**
|
||||
* Represent an operator instance that can run in WebGL backend
|
||||
*/
|
||||
export interface WebGLOperator {
|
||||
createProgramInfo(handler: WebGLInferenceHandler, inputs: Tensor[]): ProgramInfo;
|
||||
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData;
|
||||
}
|
||||
|
||||
/**
|
||||
* Layout info is used for mapping n-dimensional array to 2D textures
|
||||
* The layout is created by the TextureLayoutStrategy based on
|
||||
* the Tensor's dimensions and strides
|
||||
*/
|
||||
export interface TextureLayout {
|
||||
width: number;
|
||||
height: number;
|
||||
/**
|
||||
* specify the number of value that encoded in a single pixel
|
||||
*/
|
||||
channels: 1|2|3|4;
|
||||
/**
|
||||
* whether in packed mode or not
|
||||
*/
|
||||
isPacked?: boolean;
|
||||
/**
|
||||
* the normalized shape
|
||||
*/
|
||||
shape: readonly number[];
|
||||
/**
|
||||
* the stride of each dimensions, calculated according to shape
|
||||
*/
|
||||
strides: readonly number[];
|
||||
/**
|
||||
* the original shape(dims) of the corresponding tensor
|
||||
*/
|
||||
unpackedShape: readonly number[];
|
||||
}
|
||||
export interface TextureData extends TextureLayout {
|
||||
tensor: Tensor;
|
||||
texture: WebGLTexture;
|
||||
}
|
||||
|
||||
/**
|
||||
* A set of data that represent a shader program
|
||||
*/
|
||||
export interface ProgramInfo {
|
||||
/**
|
||||
* texture layouts for each input
|
||||
*/
|
||||
inputLayouts: TextureLayout[];
|
||||
/**
|
||||
* names of uniform samplers
|
||||
*/
|
||||
samplers: string[];
|
||||
/**
|
||||
* information of uniform variables
|
||||
*/
|
||||
variables?: VariableInfo[];
|
||||
/**
|
||||
* texture layout for output
|
||||
*/
|
||||
outputLayout: TextureLayout;
|
||||
/**
|
||||
* the shader's processing source code
|
||||
*/
|
||||
shaderSource: string;
|
||||
/**
|
||||
* whether the shader source contains a customized main function implementation
|
||||
*/
|
||||
hasMain?: boolean;
|
||||
params?: {[name: string]: number|number[]|string};
|
||||
|
||||
expectPackedInputs?: boolean;
|
||||
expectPackedoutputs?: boolean;
|
||||
}
|
||||
|
||||
export interface VariableInfo {
|
||||
type: 'float'|'int';
|
||||
name: string;
|
||||
arrayLength?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Information of uniforms that shader uses
|
||||
*/
|
||||
export interface UniformInfo {
|
||||
type: 'sampler2D'|VariableInfo['type'];
|
||||
name: string;
|
||||
arrayLength?: number;
|
||||
}
|
||||
|
||||
export interface UniformLocation extends UniformInfo {
|
||||
location: WebGLUniformLocation;
|
||||
}
|
||||
|
||||
/**
|
||||
* Artifact is the result of compilation
|
||||
* It does not contain input of output data
|
||||
* However anything that could be run as a "program"
|
||||
*/
|
||||
export interface Artifact {
|
||||
programInfo: ProgramInfo;
|
||||
program: WebGLProgram;
|
||||
uniformLocations: UniformLocation[];
|
||||
attribLocations: {position: number; textureCoord: number};
|
||||
}
|
||||
export declare namespace Artifact {
|
||||
type UniformLocations = Artifact['uniformLocations'];
|
||||
type AttribLocations = Artifact['attribLocations'];
|
||||
}
|
||||
|
||||
export interface UniformData {
|
||||
[name: string]: number|number[];
|
||||
}
|
||||
|
||||
/**
|
||||
* RunData contains all inputs that required to run a "program"
|
||||
*/
|
||||
export interface RunData {
|
||||
inputTextureDatas: TextureData[];
|
||||
outputTextureData: TextureData;
|
||||
uniformData: UniformData;
|
||||
draw?: (glContext: WebGLContext, artifact: Artifact) => void;
|
||||
}
|
||||
64
js/web/lib/onnxjs/backends/webgl/utils.ts
Normal file
64
js/web/lib/onnxjs/backends/webgl/utils.ts
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {assert} from '../../util';
|
||||
/**
|
||||
* Given a non RGBA shape calculate the R version
|
||||
* It is assumed that the dimensions are multiples of given channels
|
||||
* NOTE: it is always the last dim that gets packed.
|
||||
* @param unpackedShape original shape to create a packed version from
|
||||
*/
|
||||
export function getPackedShape(unpackedShape: readonly number[]): readonly number[] {
|
||||
const len = unpackedShape.length;
|
||||
return unpackedShape.slice(0, len - 1).concat(unpackedShape[len - 1] / 4);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates the function name from an input sampler name.
|
||||
* @param samplerName Name of the sampler.
|
||||
*/
|
||||
export function generateShaderFuncNameFromInputSamplerName(samplerName: string): string {
|
||||
assert(typeof samplerName !== 'undefined' && samplerName.length !== 0, () => 'empty string found for sampler name');
|
||||
return 'get' + samplerName.charAt(0).toUpperCase() + samplerName.slice(1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates the function name from an input sampler name at output coordinates.
|
||||
* @param samplerName Name of the sampler.
|
||||
*/
|
||||
export function generateShaderFuncNameFromInputSamplerNameAtOutCoords(samplerName: string): string {
|
||||
assert(typeof samplerName !== 'undefined' && samplerName.length !== 0, () => 'empty string found for sampler name');
|
||||
return 'get' + samplerName.charAt(0).toUpperCase() + samplerName.slice(1) + 'AtOutCoords';
|
||||
}
|
||||
|
||||
/** Returns a new input shape (a copy) that has a squeezed logical shape. */
|
||||
export function squeezeInputShape(inputShape: readonly number[], squeezedShape: number[]): number[] {
|
||||
// Deep copy.
|
||||
let newInputShape: number[] = JSON.parse(JSON.stringify(inputShape));
|
||||
newInputShape = squeezedShape;
|
||||
return newInputShape;
|
||||
}
|
||||
|
||||
/** Returns a list of squeezed parameters for shader functions */
|
||||
export function getSqueezedParams(params: string[], keptDims: number[]): string {
|
||||
return keptDims.map(d => params[d]).join(', ');
|
||||
}
|
||||
|
||||
/** Returns the data type for different ranks. */
|
||||
export function getCoordsDataType(rank: number): string {
|
||||
if (rank <= 1) {
|
||||
return 'int';
|
||||
} else if (rank === 2) {
|
||||
return 'ivec2';
|
||||
} else if (rank === 3) {
|
||||
return 'ivec3';
|
||||
} else if (rank === 4) {
|
||||
return 'ivec4';
|
||||
} else if (rank === 5) {
|
||||
return 'ivec5';
|
||||
} else if (rank === 6) {
|
||||
return 'ivec6';
|
||||
} else {
|
||||
throw Error(`GPU for rank ${rank} is not yet supported`);
|
||||
}
|
||||
}
|
||||
91
js/web/lib/onnxjs/backends/webgl/webgl-context-factory.ts
Normal file
91
js/web/lib/onnxjs/backends/webgl/webgl-context-factory.ts
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Logger} from '../../instrument';
|
||||
|
||||
import {WebGLContext} from './webgl-context';
|
||||
|
||||
const cache: {[contextId: string]: WebGLContext} = {};
|
||||
|
||||
/**
|
||||
* This factory function creates proper WebGLRenderingContext based on
|
||||
* the current browsers capabilities
|
||||
* The order is from higher/most recent versions to most basic
|
||||
*/
|
||||
export function createWebGLContext(contextId?: 'webgl'|'webgl2'): WebGLContext {
|
||||
let context: WebGLContext|undefined;
|
||||
if ((!contextId || contextId === 'webgl2') && 'webgl2' in cache) {
|
||||
context = cache.webgl2;
|
||||
} else if ((!contextId || contextId === 'webgl') && 'webgl' in cache) {
|
||||
context = cache.webgl;
|
||||
}
|
||||
|
||||
context = context || createNewWebGLContext(contextId);
|
||||
contextId = contextId || context.version === 1 ? 'webgl' : 'webgl2';
|
||||
const gl = context.gl;
|
||||
|
||||
cache[contextId] = context;
|
||||
|
||||
if (gl.isContextLost()) {
|
||||
delete cache[contextId];
|
||||
return createWebGLContext(contextId);
|
||||
}
|
||||
|
||||
gl.disable(gl.DEPTH_TEST);
|
||||
gl.disable(gl.STENCIL_TEST);
|
||||
gl.disable(gl.BLEND);
|
||||
gl.disable(gl.DITHER);
|
||||
gl.disable(gl.POLYGON_OFFSET_FILL);
|
||||
gl.disable(gl.SAMPLE_COVERAGE);
|
||||
gl.enable(gl.SCISSOR_TEST);
|
||||
gl.enable(gl.CULL_FACE);
|
||||
gl.cullFace(gl.BACK);
|
||||
|
||||
return context;
|
||||
}
|
||||
|
||||
export function createNewWebGLContext(contextId?: 'webgl'|'webgl2'): WebGLContext {
|
||||
const canvas = createCanvas();
|
||||
const contextAttributes: WebGLContextAttributes = {
|
||||
alpha: false,
|
||||
depth: false,
|
||||
antialias: false,
|
||||
stencil: false,
|
||||
preserveDrawingBuffer: false,
|
||||
premultipliedAlpha: false,
|
||||
failIfMajorPerformanceCaveat: false
|
||||
};
|
||||
let gl: WebGLRenderingContext|null;
|
||||
const ca = contextAttributes;
|
||||
if (!contextId || contextId === 'webgl2') {
|
||||
gl = canvas.getContext('webgl2', ca);
|
||||
if (gl) {
|
||||
try {
|
||||
return new WebGLContext(gl, 2);
|
||||
} catch (err) {
|
||||
Logger.warning('GlContextFactory', `failed to create WebGLContext using contextId 'webgl2'. Error: ${err}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!contextId || contextId === 'webgl') {
|
||||
gl = canvas.getContext('webgl', ca) || canvas.getContext('experimental-webgl', ca) as WebGLRenderingContext;
|
||||
if (gl) {
|
||||
try {
|
||||
return new WebGLContext(gl, 1);
|
||||
} catch (err) {
|
||||
Logger.warning(
|
||||
'GlContextFactory',
|
||||
`failed to create WebGLContext using contextId 'webgl' or 'experimental-webgl'. Error: ${err}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error('WebGL is not supported');
|
||||
}
|
||||
|
||||
function createCanvas(): HTMLCanvasElement {
|
||||
const canvas: HTMLCanvasElement = document.createElement('canvas');
|
||||
canvas.width = 1;
|
||||
canvas.height = 1;
|
||||
return canvas;
|
||||
}
|
||||
473
js/web/lib/onnxjs/backends/webgl/webgl-context.ts
Normal file
473
js/web/lib/onnxjs/backends/webgl/webgl-context.ts
Normal file
|
|
@ -0,0 +1,473 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {env} from 'onnxruntime-common';
|
||||
|
||||
import * as DataEncoders from './texture-data-encoder';
|
||||
import {DataEncoder, Encoder} from './texture-data-encoder';
|
||||
|
||||
/**
|
||||
* Abstraction and wrapper around WebGLRenderingContext and its operations
|
||||
*/
|
||||
export class WebGLContext {
|
||||
gl: WebGLRenderingContext;
|
||||
version: 1|2;
|
||||
|
||||
private vertexbuffer: WebGLBuffer;
|
||||
private framebuffer: WebGLFramebuffer;
|
||||
|
||||
// WebGL flags and vital parameters
|
||||
private isFloatTextureAttachableToFrameBuffer: boolean;
|
||||
isFloat32DownloadSupported: boolean;
|
||||
isRenderFloat32Supported: boolean;
|
||||
isBlendSupported: boolean;
|
||||
maxTextureSize: number;
|
||||
// private maxCombinedTextureImageUnits: number;
|
||||
private maxTextureImageUnits: number;
|
||||
// private maxCubeMapTextureSize: number;
|
||||
// private shadingLanguageVersion: string;
|
||||
// private webglVendor: string;
|
||||
// private webglVersion: string;
|
||||
|
||||
// WebGL2 flags and vital parameters
|
||||
// private max3DTextureSize: number;
|
||||
// private maxArrayTextureLayers: number;
|
||||
// private maxColorAttachments: number;
|
||||
// private maxDrawBuffers: number;
|
||||
|
||||
// WebGL extensions
|
||||
textureFloatExtension: unknown|null;
|
||||
// eslint-disable-next-line camelcase
|
||||
textureHalfFloatExtension: OES_texture_half_float|null;
|
||||
|
||||
// WebGL2 extensions
|
||||
colorBufferFloatExtension: unknown|null;
|
||||
|
||||
private disposed: boolean;
|
||||
private frameBufferBound = false;
|
||||
|
||||
constructor(gl: WebGLRenderingContext, version: 1|2) {
|
||||
this.gl = gl;
|
||||
this.version = version;
|
||||
|
||||
this.getExtensions();
|
||||
this.vertexbuffer = this.createVertexbuffer();
|
||||
this.framebuffer = this.createFramebuffer();
|
||||
this.queryVitalParameters();
|
||||
}
|
||||
|
||||
allocateTexture(width: number, height: number, encoder: DataEncoder, data?: Encoder.DataArrayType): WebGLTexture {
|
||||
const gl = this.gl;
|
||||
// create the texture
|
||||
const texture = gl.createTexture();
|
||||
if (!texture) {
|
||||
throw new Error('failed to create texture');
|
||||
}
|
||||
// bind the texture so the following methods effect this texture.
|
||||
gl.bindTexture(gl.TEXTURE_2D, texture);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
|
||||
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
|
||||
const buffer = data ? encoder.encode(data, width * height) : null;
|
||||
gl.texImage2D(
|
||||
gl.TEXTURE_2D,
|
||||
0, // Level of detail.
|
||||
encoder.internalFormat, width, height,
|
||||
0, // Always 0 in OpenGL ES.
|
||||
encoder.format, encoder.textureType, buffer);
|
||||
this.checkError();
|
||||
return texture;
|
||||
}
|
||||
updateTexture(
|
||||
texture: WebGLTexture, width: number, height: number, encoder: DataEncoder, data: Encoder.DataArrayType): void {
|
||||
const gl = this.gl;
|
||||
gl.bindTexture(gl.TEXTURE_2D, texture);
|
||||
const buffer = encoder.encode(data, width * height);
|
||||
gl.texSubImage2D(
|
||||
gl.TEXTURE_2D,
|
||||
0, // level
|
||||
0, // xoffset
|
||||
0, // yoffset
|
||||
width, height, encoder.format, encoder.textureType, buffer);
|
||||
this.checkError();
|
||||
}
|
||||
attachFramebuffer(texture: WebGLTexture, width: number, height: number): void {
|
||||
const gl = this.gl;
|
||||
// Make it the target for framebuffer operations - including rendering.
|
||||
gl.bindTexture(gl.TEXTURE_2D, texture);
|
||||
gl.bindFramebuffer(gl.FRAMEBUFFER, this.framebuffer);
|
||||
gl.framebufferTexture2D(
|
||||
gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture,
|
||||
0); // 0, we aren't using MIPMAPs
|
||||
this.checkError();
|
||||
gl.viewport(0, 0, width, height);
|
||||
gl.scissor(0, 0, width, height);
|
||||
}
|
||||
readTexture(
|
||||
texture: WebGLTexture, width: number, height: number, dataSize: number, dataType: Encoder.DataType,
|
||||
channels: number): Encoder.DataArrayType {
|
||||
const gl = this.gl;
|
||||
if (!channels) {
|
||||
channels = 1;
|
||||
}
|
||||
if (!this.frameBufferBound) {
|
||||
this.attachFramebuffer(texture, width, height);
|
||||
}
|
||||
const encoder = this.getEncoder(dataType, channels);
|
||||
const buffer = encoder.allocate(width * height);
|
||||
// bind texture to framebuffer
|
||||
gl.bindTexture(gl.TEXTURE_2D, texture);
|
||||
gl.framebufferTexture2D(
|
||||
gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture,
|
||||
0); // 0, we aren't using MIPMAPs
|
||||
// TODO: Check if framebuffer is ready
|
||||
gl.readPixels(0, 0, width, height, gl.RGBA, encoder.textureType, buffer);
|
||||
this.checkError();
|
||||
// unbind FB
|
||||
return encoder.decode(buffer, dataSize);
|
||||
}
|
||||
isFramebufferReady(): boolean {
|
||||
// TODO: Implement logic to check if the framebuffer is ready
|
||||
return true;
|
||||
}
|
||||
getActiveTexture(): string {
|
||||
const gl = this.gl;
|
||||
const n = gl.getParameter(this.gl.ACTIVE_TEXTURE);
|
||||
return `TEXTURE${(n - gl.TEXTURE0)}`;
|
||||
}
|
||||
getTextureBinding(): WebGLTexture {
|
||||
return this.gl.getParameter(this.gl.TEXTURE_BINDING_2D);
|
||||
}
|
||||
getFramebufferBinding(): WebGLFramebuffer {
|
||||
return this.gl.getParameter(this.gl.FRAMEBUFFER_BINDING);
|
||||
}
|
||||
setVertexAttributes(positionHandle: number, textureCoordHandle: number): void {
|
||||
const gl = this.gl;
|
||||
gl.vertexAttribPointer(positionHandle, 3, gl.FLOAT, false, 20, 0);
|
||||
gl.enableVertexAttribArray(positionHandle);
|
||||
if (textureCoordHandle !== -1) {
|
||||
gl.vertexAttribPointer(textureCoordHandle, 2, gl.FLOAT, false, 20, 12);
|
||||
gl.enableVertexAttribArray(textureCoordHandle);
|
||||
}
|
||||
this.checkError();
|
||||
}
|
||||
createProgram(
|
||||
vertexShader: WebGLShader,
|
||||
fragShader: WebGLShader,
|
||||
): WebGLProgram {
|
||||
const gl = this.gl;
|
||||
const program = gl.createProgram()!;
|
||||
|
||||
// the program consists of our shaders
|
||||
gl.attachShader(program, vertexShader);
|
||||
gl.attachShader(program, fragShader);
|
||||
gl.linkProgram(program);
|
||||
return program;
|
||||
}
|
||||
compileShader(shaderSource: string, shaderType: number): WebGLShader {
|
||||
const gl = this.gl;
|
||||
const shader = gl.createShader(shaderType);
|
||||
if (!shader) {
|
||||
throw new Error(`createShader() returned null with type ${shaderType}`);
|
||||
}
|
||||
|
||||
gl.shaderSource(shader, shaderSource);
|
||||
gl.compileShader(shader);
|
||||
if (gl.getShaderParameter(shader, gl.COMPILE_STATUS) === false) {
|
||||
throw new Error(`Failed to compile shader: ${gl.getShaderInfoLog(shader)}`);
|
||||
}
|
||||
return shader;
|
||||
}
|
||||
deleteShader(shader: WebGLShader): void {
|
||||
this.gl.deleteShader(shader);
|
||||
}
|
||||
bindTextureToUniform(texture: WebGLTexture, position: number, uniformHandle: WebGLUniformLocation): void {
|
||||
const gl = this.gl;
|
||||
gl.activeTexture(gl.TEXTURE0 + position);
|
||||
this.checkError();
|
||||
gl.bindTexture(gl.TEXTURE_2D, texture);
|
||||
this.checkError();
|
||||
gl.uniform1i(uniformHandle, position);
|
||||
this.checkError();
|
||||
}
|
||||
draw(): void {
|
||||
this.gl.drawArrays(this.gl.TRIANGLE_STRIP, 0, 4);
|
||||
this.checkError();
|
||||
}
|
||||
checkError(): void {
|
||||
if (env.debug) {
|
||||
const gl = this.gl;
|
||||
const error = gl.getError();
|
||||
let label = '';
|
||||
switch (error) {
|
||||
case (gl.NO_ERROR):
|
||||
return;
|
||||
case (gl.INVALID_ENUM):
|
||||
label = 'INVALID_ENUM';
|
||||
break;
|
||||
case (gl.INVALID_VALUE):
|
||||
label = 'INVALID_VALUE';
|
||||
break;
|
||||
case (gl.INVALID_OPERATION):
|
||||
label = 'INVALID_OPERATION';
|
||||
break;
|
||||
case (gl.INVALID_FRAMEBUFFER_OPERATION):
|
||||
label = 'INVALID_FRAMEBUFFER_OPERATION';
|
||||
break;
|
||||
case (gl.OUT_OF_MEMORY):
|
||||
label = 'OUT_OF_MEMORY';
|
||||
break;
|
||||
case (gl.CONTEXT_LOST_WEBGL):
|
||||
label = 'CONTEXT_LOST_WEBGL';
|
||||
break;
|
||||
default:
|
||||
label = `Unknown WebGL Error: ${error.toString(16)}`;
|
||||
}
|
||||
throw new Error(label);
|
||||
}
|
||||
}
|
||||
deleteTexture(texture: WebGLTexture): void {
|
||||
this.gl.deleteTexture(texture);
|
||||
}
|
||||
deleteProgram(program: WebGLProgram): void {
|
||||
this.gl.deleteProgram(program);
|
||||
}
|
||||
getEncoder(dataType: Encoder.DataType, channels: number, usage: Encoder.Usage = Encoder.Usage.Default): DataEncoder {
|
||||
if (this.version === 2) {
|
||||
return new DataEncoders.RedFloat32DataEncoder(this.gl as WebGL2RenderingContext, channels);
|
||||
}
|
||||
|
||||
switch (dataType) {
|
||||
case 'float':
|
||||
if (usage === Encoder.Usage.UploadOnly || this.isRenderFloat32Supported) {
|
||||
return new DataEncoders.RGBAFloatDataEncoder(this.gl, channels);
|
||||
} else {
|
||||
return new DataEncoders.RGBAFloatDataEncoder(
|
||||
this.gl, channels, this.textureHalfFloatExtension!.HALF_FLOAT_OES);
|
||||
}
|
||||
case 'int':
|
||||
throw new Error('not implemented');
|
||||
case 'byte':
|
||||
return new DataEncoders.Uint8DataEncoder(this.gl, channels);
|
||||
default:
|
||||
throw new Error(`Invalid dataType: ${dataType}`);
|
||||
}
|
||||
}
|
||||
clearActiveTextures(): void {
|
||||
const gl = this.gl;
|
||||
for (let unit = 0; unit < this.maxTextureImageUnits; ++unit) {
|
||||
gl.activeTexture(gl.TEXTURE0 + unit);
|
||||
gl.bindTexture(gl.TEXTURE_2D, null);
|
||||
}
|
||||
}
|
||||
dispose(): void {
|
||||
if (this.disposed) {
|
||||
return;
|
||||
}
|
||||
const gl = this.gl;
|
||||
gl.bindFramebuffer(gl.FRAMEBUFFER, null);
|
||||
gl.deleteFramebuffer(this.framebuffer);
|
||||
gl.bindBuffer(gl.ARRAY_BUFFER, null);
|
||||
gl.deleteBuffer(this.vertexbuffer);
|
||||
gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null);
|
||||
gl.finish();
|
||||
this.disposed = true;
|
||||
}
|
||||
|
||||
private createDefaultGeometry(): Float32Array {
|
||||
// Sets of x,y,z(=0),s,t coordinates.
|
||||
return new Float32Array([
|
||||
-1.0, 1.0, 0.0, 0.0, 1.0, // upper left
|
||||
-1.0, -1.0, 0.0, 0.0, 0.0, // lower left
|
||||
1.0, 1.0, 0.0, 1.0, 1.0, // upper right
|
||||
1.0, -1.0, 0.0, 1.0, 0.0 // lower right
|
||||
]);
|
||||
}
|
||||
private createVertexbuffer(): WebGLBuffer {
|
||||
const gl = this.gl;
|
||||
const buffer = gl.createBuffer();
|
||||
if (!buffer) {
|
||||
throw new Error('createBuffer() returned null');
|
||||
}
|
||||
const geometry = this.createDefaultGeometry();
|
||||
gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
|
||||
gl.bufferData(gl.ARRAY_BUFFER, geometry, gl.STATIC_DRAW);
|
||||
this.checkError();
|
||||
return buffer;
|
||||
}
|
||||
private createFramebuffer(): WebGLFramebuffer {
|
||||
const fb = this.gl.createFramebuffer();
|
||||
if (!fb) {
|
||||
throw new Error('createFramebuffer returned null');
|
||||
}
|
||||
return fb;
|
||||
}
|
||||
|
||||
private queryVitalParameters(): void {
|
||||
const gl = this.gl;
|
||||
|
||||
this.isFloatTextureAttachableToFrameBuffer = this.checkFloatTextureAttachableToFrameBuffer();
|
||||
this.isRenderFloat32Supported = this.checkRenderFloat32();
|
||||
this.isFloat32DownloadSupported = this.checkFloat32Download();
|
||||
|
||||
if (this.version === 1 && !this.textureHalfFloatExtension && !this.isRenderFloat32Supported) {
|
||||
throw new Error('both float32 and float16 TextureType are not supported');
|
||||
}
|
||||
|
||||
this.isBlendSupported = !this.isRenderFloat32Supported || this.checkFloat32Blend();
|
||||
|
||||
// this.maxCombinedTextureImageUnits = gl.getParameter(gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS);
|
||||
this.maxTextureSize = gl.getParameter(gl.MAX_TEXTURE_SIZE);
|
||||
this.maxTextureImageUnits = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS);
|
||||
// this.maxCubeMapTextureSize = gl.getParameter(gl.MAX_CUBE_MAP_TEXTURE_SIZE);
|
||||
// this.shadingLanguageVersion = gl.getParameter(gl.SHADING_LANGUAGE_VERSION);
|
||||
// this.webglVendor = gl.getParameter(gl.VENDOR);
|
||||
// this.webglVersion = gl.getParameter(gl.VERSION);
|
||||
|
||||
if (this.version === 2) {
|
||||
// this.max3DTextureSize = gl.getParameter(WebGL2RenderingContext.MAX_3D_TEXTURE_SIZE);
|
||||
// this.maxArrayTextureLayers = gl.getParameter(WebGL2RenderingContext.MAX_ARRAY_TEXTURE_LAYERS);
|
||||
// this.maxColorAttachments = gl.getParameter(WebGL2RenderingContext.MAX_COLOR_ATTACHMENTS);
|
||||
// this.maxDrawBuffers = gl.getParameter(WebGL2RenderingContext.MAX_DRAW_BUFFERS);
|
||||
}
|
||||
}
|
||||
private getExtensions(): void {
|
||||
if (this.version === 2) {
|
||||
this.colorBufferFloatExtension = this.gl.getExtension('EXT_color_buffer_float');
|
||||
} else {
|
||||
this.textureFloatExtension = this.gl.getExtension('OES_texture_float');
|
||||
this.textureHalfFloatExtension = this.gl.getExtension('OES_texture_half_float');
|
||||
}
|
||||
}
|
||||
|
||||
private checkFloatTextureAttachableToFrameBuffer(): boolean {
|
||||
// test whether Float32 texture is supported:
|
||||
// STEP.1 create a float texture
|
||||
const gl = this.gl;
|
||||
const texture = gl.createTexture();
|
||||
gl.bindTexture(gl.TEXTURE_2D, texture);
|
||||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||
const internalFormat = this.version === 2 ? (gl as unknown as {RGBA32F: number}).RGBA32F : gl.RGBA;
|
||||
gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, 1, 1, 0, gl.RGBA, gl.FLOAT, null);
|
||||
// STEP.2 bind a frame buffer
|
||||
const frameBuffer = gl.createFramebuffer();
|
||||
gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
|
||||
// STEP.3 attach texture to framebuffer
|
||||
gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
|
||||
// STEP.4 test whether framebuffer is complete
|
||||
const isComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
|
||||
gl.bindTexture(gl.TEXTURE_2D, null);
|
||||
gl.bindFramebuffer(gl.FRAMEBUFFER, null);
|
||||
gl.deleteTexture(texture);
|
||||
gl.deleteFramebuffer(frameBuffer);
|
||||
return isComplete;
|
||||
}
|
||||
|
||||
private checkRenderFloat32(): boolean {
|
||||
if (this.version === 2) {
|
||||
if (!this.colorBufferFloatExtension) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (!this.textureFloatExtension) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return this.isFloatTextureAttachableToFrameBuffer;
|
||||
}
|
||||
|
||||
private checkFloat32Download(): boolean {
|
||||
if (this.version === 2) {
|
||||
if (!this.colorBufferFloatExtension) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (!this.textureFloatExtension) {
|
||||
return false;
|
||||
}
|
||||
if (!this.gl.getExtension('WEBGL_color_buffer_float')) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return this.isFloatTextureAttachableToFrameBuffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check whether GL_BLEND is supported
|
||||
*/
|
||||
private checkFloat32Blend(): boolean {
|
||||
// it looks like currently (2019-05-08) there is no easy way to detect whether BLEND is supported
|
||||
// https://github.com/microsoft/onnxjs/issues/145
|
||||
|
||||
const gl = this.gl;
|
||||
|
||||
let texture: WebGLTexture|null|undefined;
|
||||
let frameBuffer: WebGLFramebuffer|null|undefined;
|
||||
let vertexShader: WebGLShader|null|undefined;
|
||||
let fragmentShader: WebGLShader|null|undefined;
|
||||
let program: WebGLProgram|null|undefined;
|
||||
|
||||
try {
|
||||
texture = gl.createTexture();
|
||||
frameBuffer = gl.createFramebuffer();
|
||||
gl.bindTexture(gl.TEXTURE_2D, texture);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||
const internalFormat = this.version === 2 ? (gl as unknown as {RGBA32F: number}).RGBA32F : gl.RGBA;
|
||||
gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, 1, 1, 0, gl.RGBA, gl.FLOAT, null);
|
||||
|
||||
gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer);
|
||||
gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
|
||||
|
||||
gl.enable(gl.BLEND);
|
||||
|
||||
vertexShader = gl.createShader(gl.VERTEX_SHADER);
|
||||
if (!vertexShader) {
|
||||
return false;
|
||||
}
|
||||
gl.shaderSource(vertexShader, 'void main(){}');
|
||||
gl.compileShader(vertexShader);
|
||||
|
||||
fragmentShader = gl.createShader(gl.FRAGMENT_SHADER);
|
||||
if (!fragmentShader) {
|
||||
return false;
|
||||
}
|
||||
gl.shaderSource(fragmentShader, 'precision highp float;void main(){gl_FragColor=vec4(0.5);}');
|
||||
gl.compileShader(fragmentShader);
|
||||
|
||||
program = gl.createProgram();
|
||||
if (!program) {
|
||||
return false;
|
||||
}
|
||||
gl.attachShader(program, vertexShader);
|
||||
gl.attachShader(program, fragmentShader);
|
||||
gl.linkProgram(program);
|
||||
gl.useProgram(program);
|
||||
|
||||
gl.drawArrays(gl.POINTS, 0, 1);
|
||||
return gl.getError() === gl.NO_ERROR;
|
||||
|
||||
} finally {
|
||||
gl.disable(gl.BLEND);
|
||||
|
||||
if (program) {
|
||||
gl.deleteProgram(program);
|
||||
}
|
||||
if (vertexShader) {
|
||||
gl.deleteShader(vertexShader);
|
||||
}
|
||||
if (fragmentShader) {
|
||||
gl.deleteShader(fragmentShader);
|
||||
}
|
||||
if (frameBuffer) {
|
||||
gl.bindFramebuffer(gl.FRAMEBUFFER, null);
|
||||
gl.deleteFramebuffer(frameBuffer);
|
||||
}
|
||||
if (texture) {
|
||||
gl.bindTexture(gl.TEXTURE_2D, null);
|
||||
gl.deleteTexture(texture);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
163
js/web/lib/onnxjs/execution-plan.ts
Normal file
163
js/web/lib/onnxjs/execution-plan.ts
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {SessionHandler} from './backend';
|
||||
import {Graph} from './graph';
|
||||
import {Logger, Profiler} from './instrument';
|
||||
import {Operator} from './operators';
|
||||
import {Tensor} from './tensor';
|
||||
|
||||
class KernelOp {
|
||||
constructor(public op: Operator, public node: Graph.Node) {}
|
||||
}
|
||||
|
||||
export class ExecutionPlan {
|
||||
constructor(private graph: Graph, ops: Operator[], private profiler: Readonly<Profiler>) {
|
||||
this.initialize(ops);
|
||||
}
|
||||
|
||||
initialize(ops: Operator[]) {
|
||||
this.profiler.event('session', 'ExecutionPlan.initialize', () => {
|
||||
const graphNodes = this.graph.getNodes();
|
||||
if (graphNodes.length !== ops.length) {
|
||||
throw new Error('The size of nodes and OPs do not match.');
|
||||
}
|
||||
|
||||
this._ops = ops.map((op, i) => new KernelOp(op, graphNodes[i]));
|
||||
this.reset();
|
||||
|
||||
// look for starter node(s)
|
||||
this._starter = [];
|
||||
this._ops.forEach((op, i) => {
|
||||
let resolved = true;
|
||||
for (const input of op.node.inputs) {
|
||||
if (
|
||||
!this._values[input] // not an initialized input
|
||||
&& this.graph.getInputIndices().indexOf(input) === -1 // not model input
|
||||
) {
|
||||
resolved = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (resolved) {
|
||||
this._starter.push(i);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
reset() {
|
||||
this._values = this.graph.getValues().map(i => i.tensor);
|
||||
}
|
||||
|
||||
async execute(sessionHandler: SessionHandler, modelInputs: Tensor[]): Promise<Tensor[]> {
|
||||
return this.profiler.event('session', 'ExecutionPlan.execute', async () => {
|
||||
// reset mediem result
|
||||
this.reset();
|
||||
|
||||
// create inference handler
|
||||
const inferenceHandler = sessionHandler.createInferenceHandler();
|
||||
|
||||
// populate inputs value
|
||||
const graphInputs = this.graph.getInputIndices();
|
||||
if (modelInputs.length !== graphInputs.length) {
|
||||
throw new Error(`number of input tensors don't match the number of inputs to the model: actual: ${
|
||||
modelInputs.length} expected: ${graphInputs.length}`);
|
||||
}
|
||||
|
||||
modelInputs.forEach((input, i) => {
|
||||
const index = graphInputs[i];
|
||||
this._values[index] = input;
|
||||
});
|
||||
|
||||
// prepare running sequence
|
||||
const sequence: number[] = this._starter.slice(0);
|
||||
|
||||
// execution iterations
|
||||
const graphValues = this.graph.getValues();
|
||||
const graphNodes = this.graph.getNodes();
|
||||
|
||||
let rear = 0;
|
||||
while (rear < sequence.length) {
|
||||
const thisOpIndex = sequence[rear++];
|
||||
const thisOp = this._ops[thisOpIndex];
|
||||
|
||||
// check input
|
||||
const inputList = thisOp.node.inputs.map(i => this._values[i]);
|
||||
if (inputList.indexOf(undefined) !== -1) {
|
||||
throw new Error(`unresolved input detected: op: ${thisOp.node}`);
|
||||
}
|
||||
|
||||
// run
|
||||
const inputTensors = inputList as Tensor[];
|
||||
Logger.verbose(
|
||||
'ExecPlan',
|
||||
`Runing op:${thisOp.node.name} (${
|
||||
inputTensors.map((t, i) => `'${thisOp.node.inputs[i]}': ${t.type}[${t.dims.join(',')}]`).join(', ')})`);
|
||||
|
||||
const outputList = await this.profiler.event('node', thisOp.node.name, async () => {
|
||||
const op = thisOp.op;
|
||||
if (!op.checkInputs(inputTensors)) {
|
||||
throw new Error(`invalid inputs detected; op: ${thisOp.node.name}`);
|
||||
}
|
||||
|
||||
const result = op.run(inferenceHandler, inputTensors);
|
||||
|
||||
return result;
|
||||
});
|
||||
|
||||
// check output
|
||||
if (outputList.length !== thisOp.node.outputs.length) {
|
||||
throw new Error('the size of output does not match model definition.');
|
||||
}
|
||||
|
||||
// fill value
|
||||
outputList.forEach((output, i) => {
|
||||
const j = thisOp.node.outputs[i];
|
||||
if (this._values[j]) {
|
||||
throw new Error(`output [${j}] already has value: op:${thisOp.node.name}`);
|
||||
}
|
||||
this._values[j] = output;
|
||||
});
|
||||
|
||||
// resolve downstream nodes
|
||||
const downstreamNodes = new Set<number>();
|
||||
outputList.forEach((output, i) => {
|
||||
const j = thisOp.node.outputs[i];
|
||||
for (const currentDownstreamNodeIndex of graphValues[j].to) {
|
||||
const currentDownstreamNode = graphNodes[currentDownstreamNodeIndex];
|
||||
let resolved = true;
|
||||
for (const k of currentDownstreamNode.inputs) {
|
||||
if (!this._values[k]) {
|
||||
resolved = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (resolved) {
|
||||
downstreamNodes.add(currentDownstreamNodeIndex);
|
||||
}
|
||||
}
|
||||
});
|
||||
sequence.push(...downstreamNodes);
|
||||
}
|
||||
|
||||
const output: Tensor[] = [];
|
||||
this.graph.getOutputIndices().forEach((outputIndex) => {
|
||||
const thisValue = this._values[outputIndex];
|
||||
if (thisValue === undefined) {
|
||||
throw new Error(`required output [${outputIndex}] does not have value`);
|
||||
}
|
||||
// eslint-disable-next-line no-unused-expressions
|
||||
thisValue.data;
|
||||
output.push(thisValue);
|
||||
});
|
||||
Logger.verbose('ExecPlan', 'disposing of inferenceHandler');
|
||||
inferenceHandler.dispose();
|
||||
return output;
|
||||
});
|
||||
}
|
||||
|
||||
_values: Array<Tensor|undefined>;
|
||||
_ops: KernelOp[];
|
||||
_starter: number[];
|
||||
}
|
||||
553
js/web/lib/onnxjs/graph.ts
Normal file
553
js/web/lib/onnxjs/graph.ts
Normal file
|
|
@ -0,0 +1,553 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {onnx} from 'onnx-proto';
|
||||
|
||||
import {Attribute} from './attribute';
|
||||
import {Tensor} from './tensor';
|
||||
import {ProtoUtil} from './util';
|
||||
|
||||
export declare namespace Graph {
|
||||
export interface Shape {
|
||||
readonly dims: readonly number[];
|
||||
}
|
||||
export interface ValueType {
|
||||
readonly tensorType: Tensor.DataType;
|
||||
readonly shape: Shape;
|
||||
}
|
||||
export interface Value {
|
||||
// the tensor data. empty for non-initialized inputs
|
||||
readonly tensor?: Tensor;
|
||||
|
||||
// index to the Node where the value comes from. -1 for initializer.
|
||||
readonly from: number;
|
||||
|
||||
// indices to the Nodes where the values go to.
|
||||
readonly to: readonly number[];
|
||||
|
||||
// value type specification. empty for non-input values.
|
||||
readonly type?: ValueType;
|
||||
}
|
||||
export interface Node {
|
||||
// name of the node
|
||||
readonly name: string;
|
||||
|
||||
// the operator type
|
||||
readonly opType: string;
|
||||
|
||||
// indices to the Values where the inputs come from.
|
||||
readonly inputs: readonly number[];
|
||||
|
||||
// indices to the Values where the outpus go to.
|
||||
readonly outputs: readonly number[];
|
||||
|
||||
// the attributes that used by the operator
|
||||
readonly attributes: Attribute;
|
||||
}
|
||||
|
||||
/**
|
||||
* a Transformer is an instance that allows all possible transformation operations that applied to a graph
|
||||
*/
|
||||
export interface Transformer {
|
||||
removeAllIdentityNodes(): void;
|
||||
removeAllDropoutNodes(): void;
|
||||
// TODO: add generic functions to manipulate the graph
|
||||
}
|
||||
|
||||
// an initializer can use transformer to transform the graph
|
||||
export interface Initializer {
|
||||
transformGraph(transformer: Transformer): void;
|
||||
}
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-redeclare
|
||||
export interface Graph {
|
||||
getInputIndices(): readonly number[];
|
||||
getInputNames(): readonly string[];
|
||||
getOutputIndices(): readonly number[];
|
||||
getOutputNames(): readonly string[];
|
||||
getValues(): readonly Graph.Value[];
|
||||
getNodes(): readonly Graph.Node[];
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-redeclare
|
||||
export const Graph = {
|
||||
/**
|
||||
* construct a graph from a graph protobuf type
|
||||
*/
|
||||
from: (graphProto: onnx.IGraphProto, initializer?: Graph.Initializer) => new GraphImpl(graphProto, initializer)
|
||||
};
|
||||
|
||||
class Value implements Graph.Value {
|
||||
constructor(valueInfo?: onnx.IValueInfoProto) {
|
||||
this._from = undefined;
|
||||
this._to = [];
|
||||
this.tensor = undefined;
|
||||
this.type = undefined;
|
||||
|
||||
if (valueInfo) {
|
||||
this.type = ProtoUtil.tensorValueTypeFromProto(valueInfo.type!.tensorType!);
|
||||
}
|
||||
}
|
||||
|
||||
_from?: number; // -1 represent from initializer
|
||||
get from() {
|
||||
return this._from!;
|
||||
}
|
||||
_to: number[];
|
||||
get to() {
|
||||
return this._to;
|
||||
}
|
||||
type?: Graph.ValueType;
|
||||
tensor?: Tensor;
|
||||
}
|
||||
|
||||
class Node implements Graph.Node {
|
||||
constructor(_nodeProto: onnx.INodeProto) {
|
||||
this.name = _nodeProto.name!;
|
||||
this.opType = _nodeProto.opType!;
|
||||
this.inputs = [];
|
||||
this.outputs = [];
|
||||
this.attributes = new Attribute(_nodeProto.attribute);
|
||||
this.executeNode = true;
|
||||
}
|
||||
|
||||
name: string;
|
||||
opType: string;
|
||||
inputs: number[];
|
||||
outputs: number[];
|
||||
attributes: Attribute;
|
||||
executeNode: boolean;
|
||||
}
|
||||
|
||||
class GraphImpl implements Graph, Graph.Transformer {
|
||||
private _allData: Value[];
|
||||
|
||||
private _allInputIndices: number[];
|
||||
private _allInputNames: string[];
|
||||
|
||||
private _allOutputIndices: number[];
|
||||
private _allOutputNames: string[];
|
||||
|
||||
private _nodes: Node[];
|
||||
|
||||
constructor(graph: onnx.IGraphProto, graphInitializer?: Graph.Initializer) {
|
||||
if (!graph) {
|
||||
throw new TypeError('graph is empty');
|
||||
}
|
||||
|
||||
// build the graph - will throw exceptions if something fatal is detected
|
||||
this.buildGraph(graph);
|
||||
|
||||
// execute any transformation logic for the graph (if applicable)
|
||||
this.transformGraph(graphInitializer);
|
||||
|
||||
// check for cycles and other inconsistencies - will throw exceptions if something fatal is detected
|
||||
this.checkIsAcyclic();
|
||||
}
|
||||
|
||||
getInputIndices(): readonly number[] {
|
||||
return this._allInputIndices;
|
||||
}
|
||||
|
||||
getInputNames(): readonly string[] {
|
||||
return this._allInputNames;
|
||||
}
|
||||
|
||||
getOutputIndices(): readonly number[] {
|
||||
return this._allOutputIndices;
|
||||
}
|
||||
|
||||
getOutputNames(): readonly string[] {
|
||||
return this._allOutputNames;
|
||||
}
|
||||
|
||||
getValues(): readonly Graph.Value[] {
|
||||
return this._allData;
|
||||
}
|
||||
|
||||
getNodes(): readonly Graph.Node[] {
|
||||
return this._nodes;
|
||||
}
|
||||
|
||||
private buildGraph(graph: onnx.IGraphProto) {
|
||||
const dataIndices = new Map<string, number>();
|
||||
this._allData = [];
|
||||
|
||||
this._allInputIndices = [];
|
||||
this._allInputNames = [];
|
||||
|
||||
this._allOutputIndices = [];
|
||||
this._allOutputNames = [];
|
||||
|
||||
this._nodes = [];
|
||||
|
||||
const nodesIndices = new Map<string, number>();
|
||||
|
||||
// scan all inputs
|
||||
if (!graph.input) {
|
||||
throw new Error('missing information in graph: input');
|
||||
}
|
||||
const inputValueNames = [];
|
||||
for (const i of graph.input) {
|
||||
if (dataIndices.has(i.name!)) {
|
||||
throw new Error(`duplicated input name: ${i.name}`);
|
||||
}
|
||||
const currentIndex = this._allData.push(new Value(i)) - 1;
|
||||
dataIndices.set(i.name!, currentIndex);
|
||||
inputValueNames.push(i.name!);
|
||||
}
|
||||
|
||||
// scan all initializers
|
||||
if (!graph.initializer) {
|
||||
throw new Error('missing information in graph: initializer');
|
||||
}
|
||||
for (const i of graph.initializer) {
|
||||
let index = dataIndices.get(i.name!);
|
||||
if (index === undefined) {
|
||||
const value = new Value();
|
||||
value.type = {
|
||||
shape: {dims: ProtoUtil.tensorDimsFromProto(i.dims!)},
|
||||
tensorType: ProtoUtil.tensorDataTypeFromProto(i.dataType!)
|
||||
};
|
||||
index = this._allData.push(value) - 1;
|
||||
dataIndices.set(i.name!, index);
|
||||
}
|
||||
this._allData[index]._from = -1;
|
||||
this._allData[index].tensor = Tensor.fromProto(i);
|
||||
}
|
||||
|
||||
// filter out input indices
|
||||
for (let i = 0; i < this._allData.length; i++) {
|
||||
if (!this._allData[i].tensor) {
|
||||
this._allInputIndices.push(i);
|
||||
this._allInputNames.push(inputValueNames[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// scan all outputs
|
||||
if (!graph.output) {
|
||||
throw new Error('missing information in graph: output');
|
||||
}
|
||||
for (const i of graph.output) {
|
||||
if (dataIndices.has(i.name!)) {
|
||||
throw new Error(`duplicated output name: ${i.name}`);
|
||||
}
|
||||
const currentIndex = this._allData.push(new Value(i)) - 1;
|
||||
dataIndices.set(i.name!, currentIndex);
|
||||
this._allOutputIndices.push(currentIndex);
|
||||
this._allOutputNames.push(i.name!);
|
||||
}
|
||||
|
||||
// scan all nodes
|
||||
if (!graph.node) {
|
||||
throw new Error('missing information in graph: node');
|
||||
}
|
||||
for (const nodeProto of graph.node) {
|
||||
if (!nodeProto.name) {
|
||||
// assign a name to the node if it doesn't have one
|
||||
for (let pick = 0;; pick++) {
|
||||
const name = `unnamed_${nodeProto.opType}_${pick}`;
|
||||
if (!nodesIndices.has(name)) {
|
||||
nodeProto.name = name;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (nodesIndices.has(nodeProto.name)) {
|
||||
throw new Error(`duplicated node name: ${nodeProto.name}`);
|
||||
}
|
||||
const currentIndex = this._nodes.push(new Node(nodeProto)) - 1;
|
||||
nodesIndices.set(nodeProto.name, currentIndex);
|
||||
}
|
||||
|
||||
// scan node's outputs
|
||||
for (let i = 0; i < this._nodes.length; i++) {
|
||||
const node = this._nodes[i];
|
||||
const nodeProto = graph.node[i];
|
||||
if (!nodeProto.output) {
|
||||
throw new Error(`missing output for node: ${nodeProto.name}`);
|
||||
}
|
||||
for (const output of nodeProto.output) {
|
||||
let dataIndex = dataIndices.get(output);
|
||||
if (typeof dataIndex === 'undefined') {
|
||||
dataIndex = this._allData.push(new Value()) - 1;
|
||||
dataIndices.set(output, dataIndex);
|
||||
}
|
||||
node.outputs.push(dataIndex);
|
||||
|
||||
if (this._allData[dataIndex]._from !== undefined) {
|
||||
throw new Error(`multiple nodes output to one data value: ${dataIndex}`);
|
||||
}
|
||||
this._allData[dataIndex]._from = i;
|
||||
|
||||
// for the 'Constant' operator, just create a new edge in the graph corresponding to the 'output' of the
|
||||
// operator and ignore the node from the graph
|
||||
if (nodeProto.opType === 'Constant') {
|
||||
if (!nodeProto.attribute || nodeProto.attribute.length !== 1 || !nodeProto.attribute[0].t) {
|
||||
throw new Error('missing attributes or missing tensor value in attributes for this Constant operator');
|
||||
}
|
||||
if (!nodeProto.output || nodeProto.output.length !== 1) {
|
||||
throw new Error('missing output or incorrect number of outputs for this Constant operator');
|
||||
}
|
||||
node.outputs.pop();
|
||||
node.executeNode = false;
|
||||
|
||||
this._allData[dataIndex]._from = -1;
|
||||
this._allData[dataIndex].tensor = Tensor.fromProto(nodeProto.attribute[0].t);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// scan node's inputs
|
||||
for (let i = 0; i < this._nodes.length; i++) {
|
||||
const node = this._nodes[i];
|
||||
const nodeProto = graph.node[i];
|
||||
|
||||
if (!nodeProto.input) {
|
||||
throw new Error(`missing input for node: ${nodeProto.name}`);
|
||||
}
|
||||
for (const input of nodeProto.input) {
|
||||
const dataIndex = dataIndices.get(input);
|
||||
if (typeof dataIndex === 'undefined') {
|
||||
throw new Error(`unrecognized input '${input}' for node: ${nodeProto.name}`);
|
||||
}
|
||||
node.inputs.push(dataIndex);
|
||||
|
||||
this._allData[dataIndex]._to.push(i);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private checkIsAcyclic() {
|
||||
// go through the graph and check for cycles or other fatal inconsistencies
|
||||
const starters: Set<number> = new Set<number>();
|
||||
this._allInputIndices.forEach(i => {
|
||||
const data = this._allData[i];
|
||||
data._to.forEach(j => {
|
||||
starters.add(j);
|
||||
});
|
||||
});
|
||||
|
||||
// Iterative DFS to check for cycles
|
||||
const nodesStack = Array.from(starters);
|
||||
const nodesState = new Array<string>(this._nodes.length).fill('white');
|
||||
|
||||
while (nodesStack.length > 0) {
|
||||
const nodeIndex = nodesStack.pop()!;
|
||||
// this node has now been processed completely. Mark this node 'black' to denote this.
|
||||
if (nodesState[nodeIndex] === 'gray') {
|
||||
nodesState[nodeIndex] = 'black';
|
||||
} else {
|
||||
// this node is under processing stage. mark this node 'gray' to denote this.
|
||||
nodesStack.push(nodeIndex);
|
||||
nodesState[nodeIndex] = 'gray';
|
||||
|
||||
this._nodes[nodeIndex].outputs.forEach((outgoingEdgeIndex) => {
|
||||
const data = this._allData[outgoingEdgeIndex];
|
||||
if (typeof data.tensor !== 'undefined') {
|
||||
throw new Error('node outputs should not be initialized');
|
||||
}
|
||||
if (data._from !== nodeIndex) {
|
||||
throw new Error('from property of the Value object doesn\'t match index of Node being processed');
|
||||
}
|
||||
data._to.forEach((downstreamNodeIndex) => {
|
||||
// back edge found - cyclic
|
||||
if (nodesState[downstreamNodeIndex] === 'gray') {
|
||||
throw new Error('model graph is cyclic');
|
||||
}
|
||||
// tree edge found - continue processing by adding it to stack
|
||||
else if (nodesState[downstreamNodeIndex] === 'white') {
|
||||
nodesStack.push(downstreamNodeIndex);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private transformGraph(graphInitializer?: Graph.Initializer): void {
|
||||
// apply common transform
|
||||
this.removeAllIdentityNodes();
|
||||
this.removeAllDropoutNodes();
|
||||
|
||||
// apply initializer specific transform
|
||||
if (graphInitializer) {
|
||||
graphInitializer.transformGraph(this);
|
||||
}
|
||||
|
||||
// finalize graph
|
||||
this.finalizeGraph();
|
||||
}
|
||||
|
||||
/**
|
||||
* finalize the graph.
|
||||
*
|
||||
* this function should be called after all the transformation completed.
|
||||
* this function removes all unnecessary nodes and values from the graph
|
||||
*/
|
||||
finalizeGraph() {
|
||||
let offset = 0;
|
||||
// delete all nodes that are not being executed
|
||||
for (let i = 0; i < this._nodes.length; i++) {
|
||||
if (!this._nodes[i].executeNode) {
|
||||
// delete this node and shift all subsequent nodes up
|
||||
offset++;
|
||||
// delete all output values
|
||||
this._nodes[i].outputs.forEach(ind => {
|
||||
this._allData[ind]._from = -2;
|
||||
});
|
||||
this._nodes.splice(i, 1);
|
||||
i--;
|
||||
continue;
|
||||
}
|
||||
if (offset > 0) {
|
||||
// update the value table
|
||||
this._nodes[i].inputs.forEach(value => {
|
||||
const ind = this._allData[value]._to.indexOf(i + offset);
|
||||
if (ind !== -1) {
|
||||
this._allData[value]._to[ind] = i;
|
||||
}
|
||||
});
|
||||
this._nodes[i].outputs.forEach(value => {
|
||||
if (this._allData[value]._from && this._allData[value]._from! === i + offset) {
|
||||
this._allData[value]._from = i;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
offset = 0;
|
||||
// delete all values that are not being referenced
|
||||
for (let i = 0; i < this._allData.length; i++) {
|
||||
// if current value is neither linked to next node, nor an output value, remove it.
|
||||
if (this._allData[i].from === -2 && this._allOutputIndices.indexOf(i + offset) === -1) {
|
||||
offset++;
|
||||
this._allData.splice(i, 1);
|
||||
i--;
|
||||
continue;
|
||||
}
|
||||
if (offset > 0) {
|
||||
let ind = -1;
|
||||
// if current value is neither an input value nor an initializer, find the node it's
|
||||
// coming from and update the corresponding node output
|
||||
if (this._allData[i].from !== undefined && this._allData[i].from !== -1) {
|
||||
ind = this._nodes[this._allData[i].from].outputs.indexOf(i + offset);
|
||||
if (ind !== -1) {
|
||||
this._nodes[this._allData[i].from].outputs[ind] = i;
|
||||
}
|
||||
} else {
|
||||
// if current value is an input value, update its reference in inputIndices
|
||||
ind = this._allInputIndices.indexOf(i + offset);
|
||||
if (ind !== -1) {
|
||||
this._allInputIndices[ind] = i;
|
||||
}
|
||||
}
|
||||
|
||||
// find the node that the current value is linking to and update its input reference
|
||||
this._allData[i].to.forEach(node => {
|
||||
ind = this._nodes[node].inputs.indexOf(i + offset);
|
||||
if (ind !== -1) {
|
||||
this._nodes[node].inputs[ind] = i;
|
||||
}
|
||||
});
|
||||
if (this._allData[i].to.length === 0) {
|
||||
// if current value is a graph output, update its reference in outputIndices
|
||||
ind = this._allOutputIndices.indexOf(i + offset);
|
||||
if (ind !== -1) {
|
||||
this._allOutputIndices[ind] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete the specifed node. Assume the node has only one input and the first output connected to other nodes
|
||||
* @param nodeIndex The index of node to be deleted
|
||||
*/
|
||||
private deleteNode(nodeIndex: number) {
|
||||
const node = this._nodes[nodeIndex];
|
||||
if (node.inputs.length > 1) {
|
||||
throw new Error('Node deletion with multiple inputs is not supported. ');
|
||||
}
|
||||
if (node.outputs.length > 1) {
|
||||
for (let i = 1; i < node.outputs.length; i++) {
|
||||
if (this._allData[node.outputs[i]].to.length > 0) {
|
||||
throw new Error('Node deletion with more than one output connected to other nodes is not supported. ');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// this node wil not be executed
|
||||
node.executeNode = false;
|
||||
const inputValueIndex = node.inputs[0];
|
||||
const outputValueIndex = node.outputs[0];
|
||||
const nodesConsumingOutput = this._allData[outputValueIndex].to;
|
||||
|
||||
// remove this node from the to property of the input Value
|
||||
const delIndex = this._allData[inputValueIndex].to.indexOf(nodeIndex);
|
||||
// should not happen
|
||||
if (delIndex === -1) {
|
||||
throw new Error('The Value object doesn\'t have the current Node in it\'s \'to\' property ');
|
||||
}
|
||||
this._allData[inputValueIndex].to.splice(delIndex, 1);
|
||||
|
||||
// clear node indices consuming this output Value
|
||||
this._allData[outputValueIndex]._to = [];
|
||||
|
||||
// if the output of this node is a graph output, adjust the index appropriately
|
||||
const index = this._allOutputIndices.indexOf(outputValueIndex);
|
||||
if (index !== -1) {
|
||||
this._allOutputIndices[index] = inputValueIndex;
|
||||
}
|
||||
|
||||
// override the inputs for nodes consuming this node's output with the input to this node
|
||||
if (nodesConsumingOutput && nodesConsumingOutput.length > 0) {
|
||||
for (const nodeIndex of nodesConsumingOutput) {
|
||||
const replaceIndex = this._nodes[nodeIndex].inputs.indexOf(outputValueIndex);
|
||||
// should not happen
|
||||
if (replaceIndex === -1) {
|
||||
throw new Error('The Node object doesn\'t have the output Value in it\'s \'inputs\' property ');
|
||||
}
|
||||
this._nodes[nodeIndex].inputs[replaceIndex] = inputValueIndex;
|
||||
this._allData[inputValueIndex].to.push(nodeIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
removeAllDropoutNodes() {
|
||||
let nodeIndex = 0;
|
||||
for (const node of this._nodes) {
|
||||
// weed out 'Dropout' nodes so that no time is wasted in execution
|
||||
if (node.opType === 'Dropout') {
|
||||
// the node should have exactly 1 input and 1 or 2 outputs
|
||||
if (node.inputs.length !== 1) {
|
||||
throw new Error('Dropout nodes should only contain one input. ');
|
||||
}
|
||||
if (node.outputs.length !== 1 && node.outputs.length !== 2) {
|
||||
throw new Error('Dropout nodes should contain either 1 or 2 output(s)');
|
||||
}
|
||||
// the second output should not be referenced by any other node
|
||||
if (node.outputs.length === 2 && this._allData[node.outputs[1]]._to.length !== 0) {
|
||||
throw new Error('Dropout nodes\'s second output should not be referenced by other nodes');
|
||||
}
|
||||
this.deleteNode(nodeIndex);
|
||||
}
|
||||
nodeIndex++;
|
||||
}
|
||||
}
|
||||
|
||||
removeAllIdentityNodes() {
|
||||
let nodeIndex = 0;
|
||||
for (const node of this._nodes) {
|
||||
// weed out 'Identity' nodes so that no time is wasted in execution
|
||||
if (node.opType === 'Identity') {
|
||||
this.deleteNode(nodeIndex);
|
||||
}
|
||||
nodeIndex++;
|
||||
}
|
||||
}
|
||||
}
|
||||
386
js/web/lib/onnxjs/instrument.ts
Normal file
386
js/web/lib/onnxjs/instrument.ts
Normal file
|
|
@ -0,0 +1,386 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
export declare namespace Logger {
|
||||
export interface SeverityTypeMap {
|
||||
verbose: 'v';
|
||||
info: 'i';
|
||||
warning: 'w';
|
||||
error: 'e';
|
||||
}
|
||||
|
||||
export type Severity = keyof SeverityTypeMap;
|
||||
|
||||
export type Provider = 'none'|'console';
|
||||
|
||||
/**
|
||||
* Logging config that used to control the behavior of logger
|
||||
*/
|
||||
export interface Config {
|
||||
/**
|
||||
* Specify the logging provider. 'console' by default
|
||||
*/
|
||||
provider?: Provider;
|
||||
/**
|
||||
* Specify the minimal logger serverity. 'info' by default
|
||||
*/
|
||||
minimalSeverity?: Logger.Severity;
|
||||
/**
|
||||
* Whether to output date time in log. true by default
|
||||
*/
|
||||
logDateTime?: boolean;
|
||||
/**
|
||||
* Whether to output source information (Not yet supported). false by default
|
||||
*/
|
||||
logSourceLocation?: boolean;
|
||||
}
|
||||
|
||||
export interface CategorizedLogger {
|
||||
verbose(content: string): void;
|
||||
info(content: string): void;
|
||||
warning(content: string): void;
|
||||
error(content: string): void;
|
||||
}
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-redeclare
|
||||
export interface Logger {
|
||||
(category: string): Logger.CategorizedLogger;
|
||||
|
||||
verbose(content: string): void;
|
||||
verbose(category: string, content: string): void;
|
||||
info(content: string): void;
|
||||
info(category: string, content: string): void;
|
||||
warning(content: string): void;
|
||||
warning(category: string, content: string): void;
|
||||
error(content: string): void;
|
||||
error(category: string, content: string): void;
|
||||
|
||||
/**
|
||||
* Reset the logger configuration.
|
||||
* @param config specify an optional default config
|
||||
*/
|
||||
reset(config?: Logger.Config): void;
|
||||
/**
|
||||
* Set the logger's behavior on the given category
|
||||
* @param category specify a category string. If '*' is specified, all previous configuration will be overwritten. If
|
||||
* '' is specified, the default behavior will be updated.
|
||||
* @param config the config object to indicate the logger's behavior
|
||||
*/
|
||||
set(category: string, config: Logger.Config): void;
|
||||
}
|
||||
|
||||
interface LoggerProvider {
|
||||
log(severity: Logger.Severity, content: string, category?: string): void;
|
||||
}
|
||||
class NoOpLoggerProvider implements LoggerProvider {
|
||||
log(_severity: Logger.Severity, _content: string, _category?: string) {
|
||||
// do nothing
|
||||
}
|
||||
}
|
||||
class ConsoleLoggerProvider implements LoggerProvider {
|
||||
log(severity: Logger.Severity, content: string, category?: string) {
|
||||
// eslint-disable-next-line no-console
|
||||
console.log(`${this.color(severity)} ${category ? '\x1b[35m' + category + '\x1b[0m ' : ''}${content}`);
|
||||
}
|
||||
|
||||
private color(severity: Logger.Severity) {
|
||||
switch (severity) {
|
||||
case 'verbose':
|
||||
return '\x1b[34;40mv\x1b[0m';
|
||||
case 'info':
|
||||
return '\x1b[32mi\x1b[0m';
|
||||
case 'warning':
|
||||
return '\x1b[30;43mw\x1b[0m';
|
||||
case 'error':
|
||||
return '\x1b[31;40me\x1b[0m';
|
||||
default:
|
||||
throw new Error(`unsupported severity: ${severity}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const SEVERITY_VALUE = {
|
||||
verbose: 1000,
|
||||
info: 2000,
|
||||
warning: 4000,
|
||||
error: 5000
|
||||
};
|
||||
|
||||
const LOGGER_PROVIDER_MAP: {readonly [provider: string]: Readonly<LoggerProvider>} = {
|
||||
['none']: new NoOpLoggerProvider(),
|
||||
['console']: new ConsoleLoggerProvider()
|
||||
};
|
||||
const LOGGER_DEFAULT_CONFIG = {
|
||||
provider: 'console',
|
||||
minimalSeverity: 'info',
|
||||
logDateTime: true,
|
||||
logSourceLocation: false
|
||||
};
|
||||
let LOGGER_CONFIG_MAP:
|
||||
{[category: string]: Readonly<Required<Logger.Config>>} = {['']: LOGGER_DEFAULT_CONFIG as Required<Logger.Config>};
|
||||
|
||||
function log(category: string): Logger.CategorizedLogger;
|
||||
function log(severity: Logger.Severity, content: string): void;
|
||||
function log(severity: Logger.Severity, category: string, content: string): void;
|
||||
function log(severity: Logger.Severity, arg1: string, arg2?: string): void;
|
||||
function log(
|
||||
arg0: string|Logger.Severity, arg1?: string, arg2?: string|number, arg3?: number): Logger.CategorizedLogger|void {
|
||||
if (arg1 === undefined) {
|
||||
// log(category: string): Logger.CategorizedLogger;
|
||||
return createCategorizedLogger(arg0);
|
||||
} else if (arg2 === undefined) {
|
||||
// log(severity, content);
|
||||
logInternal(arg0 as Logger.Severity, arg1, 1);
|
||||
} else if (typeof arg2 === 'number' && arg3 === undefined) {
|
||||
// log(severity, content, stack)
|
||||
logInternal(arg0 as Logger.Severity, arg1, arg2);
|
||||
} else if (typeof arg2 === 'string' && arg3 === undefined) {
|
||||
// log(severity, category, content)
|
||||
logInternal(arg0 as Logger.Severity, arg2, 1, arg1);
|
||||
} else if (typeof arg2 === 'string' && typeof arg3 === 'number') {
|
||||
// log(severity, category, content, stack)
|
||||
logInternal(arg0 as Logger.Severity, arg2, arg3, arg1);
|
||||
} else {
|
||||
throw new TypeError('input is valid');
|
||||
}
|
||||
}
|
||||
|
||||
function createCategorizedLogger(category: string): Logger.CategorizedLogger {
|
||||
return {
|
||||
verbose: log.verbose.bind(null, category),
|
||||
info: log.info.bind(null, category),
|
||||
warning: log.warning.bind(null, category),
|
||||
error: log.error.bind(null, category)
|
||||
};
|
||||
}
|
||||
|
||||
// NOTE: argument 'category' is put the last parameter beacause typescript
|
||||
// doesn't allow optional argument put in front of required argument. This
|
||||
// order is different from a usual logging API.
|
||||
function logInternal(severity: Logger.Severity, content: string, stack: number, category?: string) {
|
||||
const config = LOGGER_CONFIG_MAP[category || ''] || LOGGER_CONFIG_MAP[''];
|
||||
if (SEVERITY_VALUE[severity] < SEVERITY_VALUE[config.minimalSeverity]) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (config.logDateTime) {
|
||||
content = `${new Date().toISOString()}|${content}`;
|
||||
}
|
||||
|
||||
if (config.logSourceLocation) {
|
||||
// TODO: calculate source location from 'stack'
|
||||
}
|
||||
|
||||
LOGGER_PROVIDER_MAP[config.provider].log(severity, content, category);
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-namespace
|
||||
namespace log {
|
||||
export function verbose(content: string): void;
|
||||
export function verbose(category: string, content: string): void;
|
||||
export function verbose(arg0: string, arg1?: string) {
|
||||
log('verbose', arg0, arg1);
|
||||
}
|
||||
export function info(content: string): void;
|
||||
export function info(category: string, content: string): void;
|
||||
export function info(arg0: string, arg1?: string) {
|
||||
log('info', arg0, arg1);
|
||||
}
|
||||
export function warning(content: string): void;
|
||||
export function warning(category: string, content: string): void;
|
||||
export function warning(arg0: string, arg1?: string) {
|
||||
log('warning', arg0, arg1);
|
||||
}
|
||||
export function error(content: string): void;
|
||||
export function error(category: string, content: string): void;
|
||||
export function error(arg0: string, arg1?: string) {
|
||||
log('error', arg0, arg1);
|
||||
}
|
||||
|
||||
export function reset(config?: Logger.Config): void {
|
||||
LOGGER_CONFIG_MAP = {};
|
||||
// tslint:disable-next-line:no-backbone-get-set-outside-model
|
||||
set('', config || {});
|
||||
}
|
||||
export function set(category: string, config: Logger.Config): void {
|
||||
if (category === '*') {
|
||||
reset(config);
|
||||
} else {
|
||||
const previousConfig = LOGGER_CONFIG_MAP[category] || LOGGER_DEFAULT_CONFIG;
|
||||
LOGGER_CONFIG_MAP[category] = {
|
||||
provider: config.provider || previousConfig.provider,
|
||||
minimalSeverity: config.minimalSeverity || previousConfig.minimalSeverity,
|
||||
logDateTime: (config.logDateTime === undefined) ? previousConfig.logDateTime : config.logDateTime,
|
||||
logSourceLocation: (config.logSourceLocation === undefined) ? previousConfig.logSourceLocation :
|
||||
config.logSourceLocation
|
||||
};
|
||||
}
|
||||
|
||||
// TODO: we want to support wildcard or regex?
|
||||
}
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-redeclare, @typescript-eslint/naming-convention
|
||||
export const Logger: Logger = log;
|
||||
|
||||
export declare namespace Profiler {
|
||||
export interface Config {
|
||||
maxNumberEvents?: number;
|
||||
flushBatchSize?: number;
|
||||
flushIntervalInMilliseconds?: number;
|
||||
}
|
||||
|
||||
export type EventCategory = 'session'|'node'|'op'|'backend';
|
||||
|
||||
export interface Event {
|
||||
end(): void;
|
||||
}
|
||||
}
|
||||
|
||||
class Event implements Profiler.Event {
|
||||
constructor(
|
||||
public category: Profiler.EventCategory, public name: string, public startTime: number,
|
||||
private endCallback: (e: Event) => void) {}
|
||||
|
||||
end() {
|
||||
this.endCallback(this);
|
||||
}
|
||||
}
|
||||
|
||||
class EventRecord {
|
||||
constructor(
|
||||
public category: Profiler.EventCategory, public name: string, public startTime: number, public endTime: number) {}
|
||||
}
|
||||
|
||||
export class Profiler {
|
||||
static create(config?: Profiler.Config): Profiler {
|
||||
if (config === undefined) {
|
||||
return new this();
|
||||
}
|
||||
return new this(config.maxNumberEvents, config.flushBatchSize, config.flushIntervalInMilliseconds);
|
||||
}
|
||||
|
||||
private constructor(maxNumberEvents?: number, flushBatchSize?: number, flushIntervalInMilliseconds?: number) {
|
||||
this._started = false;
|
||||
this._maxNumberEvents = maxNumberEvents === undefined ? 10000 : maxNumberEvents;
|
||||
this._flushBatchSize = flushBatchSize === undefined ? 10 : flushBatchSize;
|
||||
this._flushIntervalInMilliseconds = flushIntervalInMilliseconds === undefined ? 5000 : flushIntervalInMilliseconds;
|
||||
}
|
||||
|
||||
// start profiling
|
||||
start() {
|
||||
this._started = true;
|
||||
this._timingEvents = [];
|
||||
this._flushTime = now();
|
||||
this._flushPointer = 0;
|
||||
}
|
||||
|
||||
// stop profiling
|
||||
stop() {
|
||||
this._started = false;
|
||||
for (; this._flushPointer < this._timingEvents.length; this._flushPointer++) {
|
||||
this.logOneEvent(this._timingEvents[this._flushPointer]);
|
||||
}
|
||||
}
|
||||
|
||||
// create an event scope for the specific function
|
||||
event<T>(category: Profiler.EventCategory, name: string, func: () => T): T;
|
||||
event<T>(category: Profiler.EventCategory, name: string, func: () => Promise<T>): Promise<T>;
|
||||
|
||||
event<T>(category: Profiler.EventCategory, name: string, func: () => T | Promise<T>): T|Promise<T> {
|
||||
const event = this._started ? this.begin(category, name) : undefined;
|
||||
let isPromise = false;
|
||||
|
||||
try {
|
||||
const res = func();
|
||||
|
||||
// we consider a then-able object is a promise
|
||||
if (res && typeof (res as Promise<T>).then === 'function') {
|
||||
isPromise = true;
|
||||
return new Promise<T>((resolve, reject) => {
|
||||
(res as Promise<T>)
|
||||
.then(
|
||||
value => { // fulfilled
|
||||
resolve(value);
|
||||
if (event) {
|
||||
event.end();
|
||||
}
|
||||
},
|
||||
reason => { // rejected
|
||||
reject(reason);
|
||||
if (event) {
|
||||
event.end();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
return res;
|
||||
|
||||
} finally {
|
||||
if (!isPromise && event) {
|
||||
event.end();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// begin an event
|
||||
begin(category: Profiler.EventCategory, name: string): Event {
|
||||
if (!this._started) {
|
||||
throw new Error('profiler is not started yet');
|
||||
}
|
||||
const startTime = now();
|
||||
this.flush(startTime);
|
||||
return new Event(category, name, startTime, e => this.end(e));
|
||||
}
|
||||
|
||||
// end the specific event
|
||||
private end(event: Event) {
|
||||
if (this._timingEvents.length < this._maxNumberEvents) {
|
||||
const endTime = now();
|
||||
this._timingEvents.push(new EventRecord(event.category, event.name, event.startTime, endTime));
|
||||
this.flush(endTime);
|
||||
}
|
||||
}
|
||||
|
||||
private logOneEvent(event: EventRecord) {
|
||||
Logger.verbose(
|
||||
`Profiler.${event.category}`,
|
||||
`${(event.endTime - event.startTime).toFixed(2)}ms on event '${event.name}' at ${event.endTime.toFixed(2)}`);
|
||||
}
|
||||
|
||||
private flush(currentTime: number) {
|
||||
if (this._timingEvents.length - this._flushPointer >= this._flushBatchSize ||
|
||||
currentTime - this._flushTime >= this._flushIntervalInMilliseconds) {
|
||||
// should flush when either batch size accumlated or interval elepsed
|
||||
|
||||
for (const previousPointer = this._flushPointer; this._flushPointer < previousPointer + this._flushBatchSize &&
|
||||
this._flushPointer < this._timingEvents.length;
|
||||
this._flushPointer++) {
|
||||
this.logOneEvent(this._timingEvents[this._flushPointer]);
|
||||
}
|
||||
|
||||
this._flushTime = now();
|
||||
}
|
||||
}
|
||||
|
||||
get started() {
|
||||
return this._started;
|
||||
}
|
||||
private _started = false;
|
||||
private _timingEvents: EventRecord[];
|
||||
|
||||
private readonly _maxNumberEvents: number;
|
||||
|
||||
private readonly _flushBatchSize: number;
|
||||
private readonly _flushIntervalInMilliseconds: number;
|
||||
|
||||
private _flushTime: number;
|
||||
private _flushPointer = 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* returns a number to represent the current timestamp in a resolution as high as possible.
|
||||
*/
|
||||
export const now = (typeof performance !== 'undefined' && performance.now) ? () => performance.now() : Date.now;
|
||||
36
js/web/lib/onnxjs/model.ts
Normal file
36
js/web/lib/onnxjs/model.ts
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {onnx} from 'onnx-proto';
|
||||
|
||||
import {Graph} from './graph';
|
||||
import {OpSet} from './opset';
|
||||
import {LongUtil} from './util';
|
||||
|
||||
export class Model {
|
||||
// empty model
|
||||
constructor() {}
|
||||
|
||||
load(buf: Uint8Array, graphInitializer?: Graph.Initializer): void {
|
||||
const modelProto = onnx.ModelProto.decode(buf);
|
||||
const irVersion = LongUtil.longToNumber(modelProto.irVersion);
|
||||
if (irVersion < 3) {
|
||||
throw new Error('only support ONNX model with IR_VERSION>=3');
|
||||
}
|
||||
|
||||
this._opsets =
|
||||
modelProto.opsetImport.map(i => ({domain: i.domain as string, version: LongUtil.longToNumber(i.version!)}));
|
||||
|
||||
this._graph = Graph.from(modelProto.graph!, graphInitializer);
|
||||
}
|
||||
|
||||
private _graph: Graph;
|
||||
get graph(): Graph {
|
||||
return this._graph;
|
||||
}
|
||||
|
||||
private _opsets: OpSet[];
|
||||
get opsets(): readonly OpSet[] {
|
||||
return this._opsets;
|
||||
}
|
||||
}
|
||||
18
js/web/lib/onnxjs/operators.ts
Normal file
18
js/web/lib/onnxjs/operators.ts
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Attribute} from './attribute';
|
||||
import {InferenceHandler} from './backend';
|
||||
import {Graph} from './graph';
|
||||
import {Tensor} from './tensor';
|
||||
|
||||
export interface Operator {
|
||||
initialize(attributes: Attribute, node: Graph.Node, graph: Graph): void;
|
||||
checkInputs(inputs: Tensor[]): boolean;
|
||||
run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
|
||||
}
|
||||
|
||||
export const NUMBER_TYPES: readonly Tensor.DataType[] =
|
||||
['float32', 'float64', 'int32', 'int16', 'int8', 'uint16', 'uint32', 'uint8'];
|
||||
export const INT_TYPES: readonly Tensor.DataType[] = ['int32', 'int16', 'int8', 'uint16', 'uint32', 'uint8'];
|
||||
export const FLOAT_TYPES: readonly Tensor.DataType[] = ['float32', 'float64'];
|
||||
35
js/web/lib/onnxjs/ops/argMax.ts
Normal file
35
js/web/lib/onnxjs/ops/argMax.ts
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Attribute} from '../attribute';
|
||||
import {InferenceHandler} from '../backend';
|
||||
import {NUMBER_TYPES, Operator} from '../operators';
|
||||
import {Tensor} from '../tensor';
|
||||
|
||||
export abstract class ArgMax implements Operator {
|
||||
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
|
||||
|
||||
initialize(attributes: Attribute): void {
|
||||
this.axis = attributes.getInt('axis', 0);
|
||||
this.keepDims = attributes.getInt('keepdims', 1) === 1;
|
||||
}
|
||||
|
||||
checkInputs(inputs: Tensor[]): boolean {
|
||||
if (!inputs || inputs.length !== 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return this.checkInputTypes(inputs);
|
||||
}
|
||||
|
||||
protected checkInputTypes(inputs: Tensor[]): boolean {
|
||||
if (NUMBER_TYPES.indexOf(inputs[0].type) === -1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected axis: number;
|
||||
protected keepDims: boolean;
|
||||
}
|
||||
57
js/web/lib/onnxjs/ops/batch-normalization.ts
Normal file
57
js/web/lib/onnxjs/ops/batch-normalization.ts
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Attribute} from '../attribute';
|
||||
import {InferenceHandler} from '../backend';
|
||||
import {Operator} from '../operators';
|
||||
import {Tensor} from '../tensor';
|
||||
|
||||
// currently this operator supports ONLY 'test' mode
|
||||
// inputs/outputs and parameters will reflect that
|
||||
// the operator implementation only supports test mode
|
||||
export abstract class BatchNormalization implements Operator {
|
||||
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
|
||||
|
||||
initialize(attributes: Attribute): void {
|
||||
this.epsilon = attributes.getFloat('epsilon', 1e-5);
|
||||
this.momentum = attributes.getFloat('momentum', 0.9);
|
||||
this.spatial = attributes.getInt('spatial', 1);
|
||||
}
|
||||
|
||||
checkInputs(inputs: Tensor[]): boolean {
|
||||
if (!inputs || inputs.length !== 5) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return this.checkInputTypes(inputs);
|
||||
}
|
||||
|
||||
protected checkInputTypes(inputs: Tensor[]): boolean {
|
||||
const X = inputs[0];
|
||||
const scale = inputs[1];
|
||||
const B = inputs[2];
|
||||
const mean = inputs[3];
|
||||
const var_ = inputs[4];
|
||||
|
||||
// input should atleast have three dimensions - N,C,dim1,...,dimn
|
||||
// other inputs can have only one dimensions
|
||||
if (X.dims.length < 3 || scale.dims.length !== 1 || B.dims.length !== 1 || mean.dims.length !== 1 ||
|
||||
var_.dims.length !== 1) {
|
||||
return false;
|
||||
}
|
||||
if (scale.dims[0] !== X.dims[1] || B.dims[0] !== X.dims[1] || mean.dims[0] !== X.dims[1] ||
|
||||
var_.dims[0] !== X.dims[1]) {
|
||||
return false;
|
||||
}
|
||||
if ((X.type !== 'float32' && X.type !== 'float64') || (scale.type !== 'float32' && scale.type !== 'float64') ||
|
||||
(B.type !== 'float32' && B.type !== 'float64') || (mean.type !== 'float32' && mean.type !== 'float64') ||
|
||||
(var_.type !== 'float32' && var_.type !== 'float64')) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
protected epsilon: number;
|
||||
protected momentum: number;
|
||||
protected spatial: number;
|
||||
}
|
||||
35
js/web/lib/onnxjs/ops/binary-op.ts
Normal file
35
js/web/lib/onnxjs/ops/binary-op.ts
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Attribute} from '../attribute';
|
||||
import {InferenceHandler} from '../backend';
|
||||
import {Operator} from '../operators';
|
||||
import {Tensor} from '../tensor';
|
||||
|
||||
export abstract class BinaryOp implements Operator {
|
||||
constructor(
|
||||
protected typeConstraint: readonly Tensor.DataType[], protected opType?: string,
|
||||
protected resultType?: Tensor.DataType) {}
|
||||
|
||||
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
|
||||
|
||||
initialize(_attributes: Attribute): void {}
|
||||
|
||||
checkInputs(inputs: Tensor[]): boolean {
|
||||
if (!inputs || inputs.length !== 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return this.checkInputTypes(inputs);
|
||||
}
|
||||
|
||||
protected checkInputTypes(inputs: Tensor[]): boolean {
|
||||
if (this.typeConstraint.indexOf(inputs[0].type) === -1) {
|
||||
return false;
|
||||
}
|
||||
if (inputs[0].type !== inputs[1].type) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
33
js/web/lib/onnxjs/ops/cast.ts
Normal file
33
js/web/lib/onnxjs/ops/cast.ts
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Attribute} from '../attribute';
|
||||
import {InferenceHandler} from '../backend';
|
||||
import {Operator} from '../operators';
|
||||
import {Tensor} from '../tensor';
|
||||
import {ProtoUtil} from '../util';
|
||||
|
||||
export abstract class Cast implements Operator {
|
||||
protected to: Tensor.DataType;
|
||||
|
||||
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
|
||||
|
||||
initialize(attributes: Attribute): void {
|
||||
this.to = ProtoUtil.tensorDataTypeFromProto(attributes.getInt('to'));
|
||||
}
|
||||
|
||||
checkInputs(inputs: Tensor[]): boolean {
|
||||
if (!inputs || inputs.length !== 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return this.checkInputTypes(inputs);
|
||||
}
|
||||
|
||||
protected checkInputTypes(inputs: Tensor[]): boolean {
|
||||
if (inputs[0].type === 'string') {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
35
js/web/lib/onnxjs/ops/clip.ts
Normal file
35
js/web/lib/onnxjs/ops/clip.ts
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Attribute} from '../attribute';
|
||||
import {InferenceHandler} from '../backend';
|
||||
import {Operator} from '../operators';
|
||||
import {Tensor} from '../tensor';
|
||||
|
||||
export abstract class Clip implements Operator {
|
||||
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
|
||||
|
||||
initialize(attributes: Attribute): void {
|
||||
this.min = attributes.getFloat('min', -3.4028234663852886e+38);
|
||||
this.max = attributes.getFloat('max', 3.4028234663852886e+38);
|
||||
}
|
||||
|
||||
checkInputs(inputs: Tensor[]): boolean {
|
||||
if (!inputs || inputs.length !== 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return this.checkInputTypes(inputs);
|
||||
}
|
||||
|
||||
protected checkInputTypes(inputs: Tensor[]): boolean {
|
||||
if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected min: number;
|
||||
protected max: number;
|
||||
}
|
||||
49
js/web/lib/onnxjs/ops/concat.ts
Normal file
49
js/web/lib/onnxjs/ops/concat.ts
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Attribute} from '../attribute';
|
||||
import {InferenceHandler} from '../backend';
|
||||
import {Operator} from '../operators';
|
||||
import {Tensor} from '../tensor';
|
||||
|
||||
export abstract class Concat implements Operator {
|
||||
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
|
||||
|
||||
initialize(attributes: Attribute): void {
|
||||
this.axis = attributes.getInt('axis');
|
||||
}
|
||||
|
||||
checkInputs(inputs: Tensor[]): boolean {
|
||||
if (!inputs || inputs.length < 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return this.checkInputTypes(inputs);
|
||||
}
|
||||
|
||||
protected checkInputTypes(inputs: Tensor[]): boolean {
|
||||
const inputType = inputs[0].type;
|
||||
const inputDimensionality = inputs[0].dims.length;
|
||||
|
||||
// TODO: Support string concat
|
||||
if (inputType === 'string') {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const input of inputs) {
|
||||
// make sure types of all inputs match
|
||||
if (input.type !== inputType) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// make sure the dimensionality of all inputs are the same
|
||||
if (input.dims.length !== inputDimensionality) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected axis: number;
|
||||
}
|
||||
91
js/web/lib/onnxjs/ops/conv.ts
Normal file
91
js/web/lib/onnxjs/ops/conv.ts
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Attribute} from '../attribute';
|
||||
import {InferenceHandler} from '../backend';
|
||||
import {Operator} from '../operators';
|
||||
import {Tensor} from '../tensor';
|
||||
|
||||
export abstract class Conv implements Operator {
|
||||
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
|
||||
|
||||
initialize(attributes: Attribute): void {
|
||||
// TODO : Make this generic enough to compute default attributes for multi-dimensional conv
|
||||
this.autoPad = attributes.getString('auto_pad', 'NOTSET');
|
||||
this.dilations = attributes.getInts('dilations', [1, 1]);
|
||||
this.group = attributes.getInt('group', 1);
|
||||
this.kernelShape = attributes.getInts('kernel_shape', []);
|
||||
this.pads = attributes.getInts('pads', [0, 0, 0, 0]);
|
||||
this.strides = attributes.getInts('strides', [1, 1]);
|
||||
}
|
||||
|
||||
checkInputs(inputs: Tensor[]): boolean {
|
||||
// Refer to the below link for all input checks
|
||||
// https://github.com/onnx/onnx/blob/master/docs/Operators.md#Conv
|
||||
if (!inputs || (inputs.length !== 2 && inputs.length !== 3)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO : Need to add support for multi-dimensional conv
|
||||
// currently only support 2-dimensional conv
|
||||
if (inputs[0].dims.length !== 4 || inputs[1].dims.length !== 4) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// FILTER_IN_CHANNEL should be equal to DATA_CHANNEL
|
||||
const dataChannel = inputs[0].dims[1];
|
||||
const filterInChannel = inputs[1].dims[1] * this.group;
|
||||
if (dataChannel !== filterInChannel) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// if bias is provided it should be 1D and the number of elements should be equal to the number of feature maps
|
||||
if (inputs.length === 3 && (inputs[2].dims.length !== 1 || inputs[1].dims[0] !== inputs[2].dims[0])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const spatialRank = inputs[0].dims.length - 2;
|
||||
// wrong dilations dimension
|
||||
if (this.dilations.length !== spatialRank) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Wrong strides dimension
|
||||
if (this.strides.length !== spatialRank) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Wrong pads dimension
|
||||
if (this.pads.length !== spatialRank * 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// if kernelShape is specified, it's data length must be 2 less than dims length of the weights tensor
|
||||
// (the first 2 dims are batch_size and channels)
|
||||
if (this.kernelShape.length !== 0 && this.kernelShape.length !== inputs[1].dims.length - 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return this.checkInputTypes(inputs);
|
||||
}
|
||||
|
||||
protected checkInputTypes(inputs: Tensor[]): boolean {
|
||||
// TODO : Need to add support for float64
|
||||
if (inputs[0].type !== 'float32' || inputs[1].type !== 'float32') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (inputs.length === 3 && inputs[2].type !== 'float32') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected autoPad: string;
|
||||
protected dilations: number[];
|
||||
protected group: number;
|
||||
protected kernelShape: number[];
|
||||
protected pads: number[];
|
||||
protected strides: number[];
|
||||
}
|
||||
35
js/web/lib/onnxjs/ops/dropout.ts
Normal file
35
js/web/lib/onnxjs/ops/dropout.ts
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Attribute} from '../attribute';
|
||||
import {InferenceHandler} from '../backend';
|
||||
import {Operator} from '../operators';
|
||||
import {Tensor} from '../tensor';
|
||||
|
||||
export abstract class Dropout implements Operator {
|
||||
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
|
||||
|
||||
initialize(attributes: Attribute): void {
|
||||
this.ratio = attributes.getFloat('ratio', 0.5);
|
||||
this.testMode = true; // this is a hack to reflect that test mode is hardcoded
|
||||
}
|
||||
|
||||
checkInputs(inputs: Tensor[]): boolean {
|
||||
if (!inputs || inputs.length !== 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return this.checkInputTypes(inputs);
|
||||
}
|
||||
|
||||
protected checkInputTypes(inputs: Tensor[]): boolean {
|
||||
if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected ratio: number;
|
||||
protected testMode: boolean;
|
||||
}
|
||||
33
js/web/lib/onnxjs/ops/elu.ts
Normal file
33
js/web/lib/onnxjs/ops/elu.ts
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Attribute} from '../attribute';
|
||||
import {InferenceHandler} from '../backend';
|
||||
import {Operator} from '../operators';
|
||||
import {Tensor} from '../tensor';
|
||||
|
||||
export abstract class Elu implements Operator {
|
||||
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
|
||||
|
||||
initialize(attributes: Attribute): void {
|
||||
this.alpha = attributes.getFloat('alpha', 1.0);
|
||||
}
|
||||
|
||||
checkInputs(inputs: Tensor[]): boolean {
|
||||
if (!inputs || inputs.length !== 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return this.checkInputTypes(inputs);
|
||||
}
|
||||
|
||||
protected checkInputTypes(inputs: Tensor[]): boolean {
|
||||
if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected alpha: number;
|
||||
}
|
||||
33
js/web/lib/onnxjs/ops/expand.ts
Normal file
33
js/web/lib/onnxjs/ops/expand.ts
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Attribute} from '../attribute';
|
||||
import {InferenceHandler} from '../backend';
|
||||
import {NUMBER_TYPES, Operator} from '../operators';
|
||||
import {Tensor} from '../tensor';
|
||||
|
||||
export abstract class Expand implements Operator {
|
||||
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
|
||||
|
||||
initialize(_attributes: Attribute): void {}
|
||||
|
||||
checkInputs(inputs: Tensor[]): boolean {
|
||||
if (!inputs || inputs.length !== 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return this.checkInputTypes(inputs);
|
||||
}
|
||||
|
||||
protected checkInputTypes(inputs: Tensor[]): boolean {
|
||||
if (NUMBER_TYPES.indexOf(inputs[0].type) === -1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (inputs[1].type !== 'int32') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
43
js/web/lib/onnxjs/ops/flatten.ts
Normal file
43
js/web/lib/onnxjs/ops/flatten.ts
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Attribute} from '../attribute';
|
||||
import {InferenceHandler} from '../backend';
|
||||
import {Operator} from '../operators';
|
||||
import {Tensor} from '../tensor';
|
||||
|
||||
export abstract class Flatten implements Operator {
|
||||
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
|
||||
|
||||
initialize(attributes: Attribute): void {
|
||||
this.axis = attributes.getInt('axis', 1); // default axis is 1
|
||||
}
|
||||
|
||||
checkInputs(inputs: Tensor[]): boolean {
|
||||
if (!inputs || inputs.length !== 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const r = inputs[0].dims.length;
|
||||
if (r === 0) {
|
||||
return false; // scalar tensor is not supported
|
||||
}
|
||||
|
||||
if (this.axis < -r || this.axis > r) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return this.checkInputTypes(inputs);
|
||||
}
|
||||
|
||||
protected checkInputTypes(inputs: Tensor[]): boolean {
|
||||
// TODO: Support string type
|
||||
if (inputs[0].type === 'string') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected axis: number;
|
||||
}
|
||||
42
js/web/lib/onnxjs/ops/gather.ts
Normal file
42
js/web/lib/onnxjs/ops/gather.ts
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Attribute} from '../attribute';
|
||||
import {InferenceHandler} from '../backend';
|
||||
import {NUMBER_TYPES, Operator} from '../operators';
|
||||
import {Tensor} from '../tensor';
|
||||
|
||||
export abstract class Gather implements Operator {
|
||||
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
|
||||
|
||||
initialize(attributes: Attribute): void {
|
||||
this.axis = attributes.getInt('axis', 0);
|
||||
}
|
||||
|
||||
checkInputs(inputs: Tensor[]): boolean {
|
||||
if (!inputs || inputs.length !== 2) {
|
||||
return false;
|
||||
}
|
||||
const tensorRank = inputs[0].dims.length;
|
||||
if (tensorRank < 1) {
|
||||
return false;
|
||||
}
|
||||
if (this.axis < -tensorRank || this.axis > tensorRank - 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return this.checkInputTypes(inputs);
|
||||
}
|
||||
|
||||
protected checkInputTypes(inputs: Tensor[]): boolean {
|
||||
if (NUMBER_TYPES.indexOf(inputs[0].type) === -1) {
|
||||
return false;
|
||||
}
|
||||
if (inputs[1].type !== 'int32' && inputs[1].type !== 'int16') {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
protected axis: number;
|
||||
}
|
||||
62
js/web/lib/onnxjs/ops/gemm.ts
Normal file
62
js/web/lib/onnxjs/ops/gemm.ts
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Attribute} from '../attribute';
|
||||
import {InferenceHandler} from '../backend';
|
||||
import {Operator} from '../operators';
|
||||
import {Tensor} from '../tensor';
|
||||
|
||||
export abstract class Gemm implements Operator {
|
||||
constructor(isOptionalC: boolean) {
|
||||
this.isOptionalC = isOptionalC;
|
||||
}
|
||||
|
||||
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
|
||||
|
||||
initialize(attributes: Attribute): void {
|
||||
this.transA = attributes.getInt('transA', 0) !== 0;
|
||||
this.transB = attributes.getInt('transB', 0) !== 0;
|
||||
this.alpha = attributes.getFloat('alpha', 1);
|
||||
this.beta = attributes.getFloat('beta', 1);
|
||||
}
|
||||
|
||||
checkInputs(inputs: Tensor[]): boolean {
|
||||
if (!inputs) {
|
||||
return false;
|
||||
}
|
||||
if (this.isOptionalC && (inputs.length < 2 || inputs.length > 3)) {
|
||||
return false;
|
||||
}
|
||||
if (!this.isOptionalC && inputs.length !== 3) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 'C' can be of dimensionality 1 or 2 only
|
||||
if (inputs.length === 3 && inputs[2].dims.length !== 1 && inputs[2].dims.length !== 2) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return this.checkInputTypes(inputs);
|
||||
}
|
||||
|
||||
protected checkInputTypes(inputs: Tensor[]): boolean {
|
||||
if ((inputs[0].type !== 'float32' && inputs[0].type !== 'float64') ||
|
||||
(inputs[1].type !== 'float32' && inputs[1].type !== 'float64') ||
|
||||
(inputs.length === 3 && inputs[2].type !== 'float32' && inputs[2].type !== 'float64')) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((inputs[0].type !== inputs[1].type) || (inputs.length === 3 && inputs[0].type !== inputs[2].type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected transA: boolean;
|
||||
protected transB: boolean;
|
||||
protected alpha: number;
|
||||
protected beta: number;
|
||||
|
||||
protected isOptionalC: boolean; // in opset 11, C becomes optional
|
||||
}
|
||||
39
js/web/lib/onnxjs/ops/image-scaler.ts
Normal file
39
js/web/lib/onnxjs/ops/image-scaler.ts
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Attribute} from '../attribute';
|
||||
import {InferenceHandler} from '../backend';
|
||||
import {Operator} from '../operators';
|
||||
import {Tensor} from '../tensor';
|
||||
|
||||
export abstract class ImageScaler implements Operator {
|
||||
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
|
||||
|
||||
initialize(attributes: Attribute): void {
|
||||
this.scale = attributes.getFloat('scale');
|
||||
this.bias = attributes.getFloats('bias');
|
||||
}
|
||||
|
||||
checkInputs(inputs: Tensor[]): boolean {
|
||||
if (!inputs || inputs.length !== 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (inputs[0].dims.length !== 4) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return this.checkInputTypes(inputs);
|
||||
}
|
||||
|
||||
protected checkInputTypes(inputs: Tensor[]): boolean {
|
||||
if (inputs[0].type !== 'float32' && inputs[0].type !== 'float64') {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
protected scale: number;
|
||||
protected bias: number[];
|
||||
}
|
||||
45
js/web/lib/onnxjs/ops/instance-normalization.ts
Normal file
45
js/web/lib/onnxjs/ops/instance-normalization.ts
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {Attribute} from '../attribute';
|
||||
import {InferenceHandler} from '../backend';
|
||||
import {Operator} from '../operators';
|
||||
import {Tensor} from '../tensor';
|
||||
|
||||
export abstract class InstanceNormalization implements Operator {
|
||||
abstract run(inferenceHandler: InferenceHandler, inputs: Tensor[]): Tensor[]|Promise<Tensor[]>;
|
||||
|
||||
initialize(attributes: Attribute): void {
|
||||
this.epsilon = attributes.getFloat('epsilon', 1e-5);
|
||||
}
|
||||
|
||||
checkInputs(inputs: Tensor[]): boolean {
|
||||
if (!inputs || inputs.length !== 3) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return this.checkInputTypes(inputs);
|
||||
}
|
||||
|
||||
protected checkInputTypes(inputs: Tensor[]): boolean {
|
||||
const X = inputs[0];
|
||||
const scale = inputs[1];
|
||||
const B = inputs[2];
|
||||
|
||||
// input should atleast have three dimensions - N,C,dim1,...,dimn
|
||||
// other inputs can have only one dimensions
|
||||
if (X.dims.length < 3 || scale.dims.length !== 1 || B.dims.length !== 1) {
|
||||
return false;
|
||||
}
|
||||
if (scale.dims[0] !== X.dims[1] || B.dims[0] !== X.dims[1]) {
|
||||
return false;
|
||||
}
|
||||
if ((X.type !== 'float32' && X.type !== 'float64') || (scale.type !== 'float32' && scale.type !== 'float64') ||
|
||||
(B.type !== 'float32' && B.type !== 'float64')) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
protected epsilon: number;
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue