onnxruntime/js/web/lib/onnxjs/backends/webgl/utils.ts
Yulong Wang 4ebc9c3b5e
[JS] onnxruntime-web (#7394)
* add web

* add script and test

* fix lint

* add test/data/ops

* add test/data/node/ to gitignore

* modify scripts

* add onnxjs

* fix tests

* fix test-runner

* fix sourcemap

* fix onnxjs profiling

* update test list

* update README

* resolve comments

* set wasm as default backend

* rename package

* update copyright header

* do not use class "Buffer" in browser context

* revise readme
2021-04-27 00:04:25 -07:00

64 lines
2.4 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {assert} from '../../util';
/**
 * Computes the packed counterpart of an unpacked shape.
 *
 * Only the last dimension is packed: it is divided by 4, while every leading
 * dimension is carried over unchanged. The last dimension is assumed to be a
 * multiple of 4; no check is performed, so other values yield a fractional
 * last dimension.
 *
 * @param unpackedShape original shape to derive the packed shape from
 * @returns a new array holding the packed shape; the input is not modified
 */
export function getPackedShape(unpackedShape: readonly number[]): readonly number[] {
  const lastIndex = unpackedShape.length - 1;
  // slice() yields a fresh array, so appending to it leaves the input untouched.
  const packedShape = unpackedShape.slice(0, lastIndex);
  packedShape.push(unpackedShape[lastIndex] / 4);
  return packedShape;
}
/**
 * Derives the shader accessor function name for an input sampler.
 *
 * The result is the prefix `get` followed by the sampler name with its first
 * character upper-cased, e.g. `x` -> `getX`.
 *
 * @param samplerName Name of the sampler; asserted to be defined and non-empty.
 * @returns the generated function name
 */
export function generateShaderFuncNameFromInputSamplerName(samplerName: string): string {
  const hasName = !(typeof samplerName === 'undefined' || samplerName.length === 0);
  assert(hasName, () => 'empty string found for sampler name');
  const head = samplerName.charAt(0).toUpperCase();
  const tail = samplerName.slice(1);
  return `get${head}${tail}`;
}
/**
 * Derives the shader accessor function name that reads an input sampler at
 * the output coordinates.
 *
 * The result is `get`, then the sampler name with its first character
 * upper-cased, then the suffix `AtOutCoords`, e.g. `x` -> `getXAtOutCoords`.
 *
 * @param samplerName Name of the sampler; asserted to be defined and non-empty.
 * @returns the generated function name
 */
export function generateShaderFuncNameFromInputSamplerNameAtOutCoords(samplerName: string): string {
  const hasName = !(typeof samplerName === 'undefined' || samplerName.length === 0);
  assert(hasName, () => 'empty string found for sampler name');
  const head = samplerName.charAt(0).toUpperCase();
  const tail = samplerName.slice(1);
  return `get${head}${tail}AtOutCoords`;
}
/**
 * Returns a new input shape (a copy) that has a squeezed logical shape.
 *
 * The returned array is a fresh copy of `squeezedShape`, so callers may mutate
 * the result without affecting the array they passed in.
 *
 * @param inputShape original input shape; it does not influence the result and
 *     is kept only so the signature stays compatible with existing callers
 * @param squeezedShape the squeezed logical shape to return a copy of
 * @returns a newly allocated array with the same entries as `squeezedShape`
 */
export function squeezeInputShape(inputShape: readonly number[], squeezedShape: number[]): number[] {
  // The previous implementation JSON-cloned `inputShape` and then immediately
  // discarded the clone, returning `squeezedShape` itself. Keep the parameter
  // referenced in case unused-parameter checks are enabled for this project.
  void inputShape;
  return squeezedShape.slice();
}
/**
 * Builds the comma-separated parameter list for a shader function after
 * squeezing.
 *
 * @param params full list of parameter names
 * @param keptDims indices into `params` selecting the entries to keep, in the
 *     order they should appear
 * @returns the selected parameters joined with `', '`
 */
export function getSqueezedParams(params: string[], keptDims: number[]): string {
  const selected: string[] = [];
  for (const dim of keptDims) {
    selected.push(params[dim]);
  }
  return selected.join(', ');
}
/**
 * Maps a tensor rank to the name of the shader coordinate type used for it.
 *
 * Ranks of 1 or less map to `int`; ranks 2 through 6 map to `ivec2` ... `ivec6`.
 * NOTE(review): `ivec5` and `ivec6` are presumably declared elsewhere by the
 * shader code that consumes these names; confirm against the callers.
 *
 * @param rank number of dimensions
 * @returns the coordinate type name for the given rank
 * @throws Error when the rank is not one of the supported values
 */
export function getCoordsDataType(rank: number): string {
  // Scalars and 1-D tensors (and any smaller rank value) use a plain int.
  if (rank <= 1) {
    return 'int';
  }
  switch (rank) {
    case 2:
      return 'ivec2';
    case 3:
      return 'ivec3';
    case 4:
      return 'ivec4';
    case 5:
      return 'ivec5';
    case 6:
      return 'ivec6';
    default:
      throw Error(`GPU for rank ${rank} is not yet supported`);
  }
}