onnxruntime/js/web/lib/onnxjs/backends/webgl/ops/reshape-packed.ts
Yulong Wang abdc31de40
[js] change default formatter for JavaScript/TypeScript from clang-format to Prettier (#21728)
### Description

See
454996d496
for manual changes (excluded auto-generated formatting changes)

### Why

Because the tooling around the old clang-format setup is out of date, which
reduces development efficiency.

- The NPM package `clang-format` is already in maintenance mode and has not
been updated for 2 years.
- The VSCode extension for clang-format has not been maintained for a while,
and a recent Node.js security update made it stop working entirely on
Windows.

No one in the community seems interested in fixing those.

Choose Prettier as it is the most popular TS/JS formatter.

### How to merge

It's easy to break the build:
- Be careful of any new commits on main not included in this PR.
- Be careful that after this PR is merged, other PRs that already passed
CI can merge.

So, make sure there are no new commits before merging this one, and
invalidate JS PRs that already passed CI, forcing them to merge with the latest main.
2024-08-14 16:51:22 -07:00

172 lines
5.5 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { Tensor } from '../../../tensor';
import { ShapeUtil } from '../../../util';
import { getGlsl } from '../glsl-source';
import { WebGLInferenceHandler } from '../inference-handler';
import { ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType } from '../types';
import { unpackFromChannel } from './packing-utils';
/**
 * Creates the static metadata for the packed reshape program.
 *
 * The program takes a single packed input named `A`; the cache hint is the
 * stringified target shape (comma-separated dimensions).
 *
 * @param outputShape3D target shape of the reshape
 * @returns the program name, input types, input names and cache hint
 */
const createPackedReshape3DProgramMetadata = (outputShape3D: readonly number[]) => {
  const cacheHint = String(outputShape3D);
  return {
    name: 'Reshape (packed)',
    inputTypes: [TextureType.packed],
    inputNames: ['A'],
    cacheHint,
  };
};
/**
 * Builds the full program info (shader source and output description) for a
 * packed reshape between two 3D shapes.
 *
 * The generated `main()` fills the four channels of the output texel one after
 * another. For each channel it derives an output coordinate from `rc`, turns it
 * into a flat index with `getFlattenedIndex`, maps that index back to an input
 * coordinate with `inputCoordsFromReshapedOutCoords`, and reads the value via
 * `getChannel(getA(...))`.
 *
 * @param handler inference handler; only used to look up the GLSL version
 * @param input3D input tensor, its `dims` are treated as a 3-tuple
 * @param metadata program metadata spread into the result
 * @param outputShape3D target shape, treated as a 3-tuple
 * @returns program info with a packed output of shape `outputShape3D`
 */
const createPackedReshape3DProgramInfo = (
  handler: WebGLInferenceHandler,
  input3D: Tensor,
  metadata: ProgramMetadata,
  outputShape3D: readonly number[],
): ProgramInfo => {
  const inputShape3D = input3D.dims as [number, number, number];
  const squeezedOutputShape = outputShape3D as [number, number, number];

  // Coordinate assignment for texel channels 0..3, in channel order.
  const channelCoordAssignments = [
    'outputCoords = rc;',
    'outputCoords = ivec3(rc.x, rc.y+1, rc.z);',
    'outputCoords = ivec3(rc.x, rc.y, rc.z+1);',
    'outputCoords = ivec3(rc.x, rc.y+1, rc.z+1);',
  ];

  const mainLoop = channelCoordAssignments
    .map((assignment, channel) => {
      // Channel 0 is always written; the other three are bounds-checked.
      const needsBoundsCheck = channel > 0;
      return [
        '',
        assignment,
        needsBoundsCheck ? 'if(outputCoords.y < rows && outputCoords.z < cols){' : '',
        'int flattenedIndex = getFlattenedIndex(outputCoords);',
        'ivec3 inputRC = inputCoordsFromReshapedOutCoords(flattenedIndex);',
        'vec2 innerDims = vec2(float(inputRC.y),float(inputRC.z));',
        `result[${channel}] = getChannel(getA(inputRC.x, inputRC.y, inputRC.z), innerDims);`,
        needsBoundsCheck ? '}' : '',
        '',
      ].join('\n');
    })
    .join('');

  const glsl = getGlsl(handler.session.backend.glContext.version);

  // Note: `rows` is taken from the last output dim and `cols` from the middle
  // one, and they bound `outputCoords.y` / `outputCoords.z` respectively.
  const shaderSource = [
    '',
    getReshapedInputCoords(inputShape3D),
    getFlattenedIndexFrom3D(squeezedOutputShape),
    unpackFromChannel(),
    'void main() {',
    'ivec3 rc = getOutputCoords();',
    'vec4 result = vec4(0.0);',
    'ivec3 outputCoords;',
    `int rows = ${squeezedOutputShape[2]};`,
    `int cols = ${squeezedOutputShape[1]};`,
    mainLoop,
    `${glsl.output} = result;`,
    '}',
    '',
  ].join('\n');

  return {
    ...metadata,
    output: { dims: squeezedOutputShape, type: input3D.type, textureType: TextureType.packed },
    shaderSource,
    hasMain: true,
  };
};
/**
 * Creates a lazy loader for the packed 3D reshape program.
 *
 * The metadata is computed immediately; the program info (including the
 * shader source) is only built when the returned `get` is invoked.
 *
 * @param handler inference handler forwarded to the program info builder
 * @param input3D input tensor forwarded to the program info builder
 * @param outputShape3D target shape of the reshape
 * @returns the metadata fields plus a `get` callback producing the program info
 */
export const createPackedReshape3DProgramInfoLoader = (
  handler: WebGLInferenceHandler,
  input3D: Tensor,
  outputShape3D: readonly number[],
): ProgramInfoLoader => {
  const metadata = createPackedReshape3DProgramMetadata(outputShape3D);
  const get = () => createPackedReshape3DProgramInfo(handler, input3D, metadata, outputShape3D);
  return { ...metadata, get };
};
/**
 * Collapses a shape of any rank into a 3-element shape.
 *
 * All dimensions except the last two are multiplied together into the first
 * element. A rank-1 shape gets `1` for the middle element, and an empty shape
 * yields `[1, 1, 1]`.
 *
 * @param shape the dimensions to collapse
 * @returns `[product of leading dims, second-to-last dim or 1, last dim]`
 */
export function processDims3D(shape: ArrayLike<number>): [number, number, number] {
  const rank = shape.length;
  if (rank === 0) {
    return [1, 1, 1];
  }
  // TODO: squeeze other shapes to 2D case
  const leadingCount = Math.max(rank - 2, 0);
  const leadingDims = Array.from({ length: leadingCount }, (_, axis) => shape[axis]);
  const batch = leadingDims.reduce((product, dim) => product * dim, 1);
  const secondToLast = rank > 1 ? shape[rank - 2] : 1;
  const last = shape[rank - 1];
  return [batch, secondToLast, last];
}
// For a packed reshape, we need to re-arrange the texel data for the output shape.
// Our packing is designed to pack a 2x2 tile in the last h and w dimensions, so
// for the reshaped tensor, we just need to re-arrange the last h and
// w dimensions. For any shape that is not 3D, i.e. not [batch, W, H], we
// first convert it to 3D by collapsing the other dimensions into the batch dim, then
// process the last two dimensions.
// Note: we only need the shape tensor to calculate the output shape, so the
// content of the shape tensor is never uploaded to the GPU. It is always kept on the CPU.
// TODO: optimize the algorithm -- in some cases, if the last two dims are
// the same between the input shape and the output shape, the packed reshape can be
// treated as a no-op.
/**
 * Tells whether a reshape from `dims` to `reshapedDims` is "cheap".
 *
 * - If either shape is empty (scalar), the reshape is cheap.
 * - If either shape has rank 1, it is cheap when the last dimensions match.
 * - Otherwise it is cheap when the last two dimensions both match.
 *
 * @param dims the original shape
 * @param reshapedDims the shape after reshaping
 * @returns `true` when the reshape is considered cheap
 */
export function isReshapeCheap(dims: readonly number[], reshapedDims: readonly number[]) {
  const inRank = dims.length;
  const outRank = reshapedDims.length;

  // scalar
  if (inRank === 0 || outRank === 0) {
    return true;
  }

  const sameLastDim = dims[inRank - 1] === reshapedDims[outRank - 1];

  // 1D
  if (inRank < 2 || outRank < 2) {
    return sameLastDim;
  }

  // 2D +
  return sameLastDim && dims[inRank - 2] === reshapedDims[outRank - 2];
}
/**
 * Generates the GLSL function `inputCoordsFromReshapedOutCoords`, which turns a
 * flat element index into `(b, r, c)` coordinates of the given 3D input shape.
 *
 * For every stride the shader divides the running index by the stride to get
 * one coordinate and then subtracts that coordinate's contribution.
 *
 * The previous implementation always emitted `int ${coords[i + 1]} = ...` on
 * the last stride. When `ShapeUtil.computeStrides` returns one stride per
 * dimension (i.e. three entries for a 3D shape), `coords[i + 1]` is out of
 * range and the shader received a stray declaration of a variable literally
 * named `undefined`. The remainder declaration is now emitted only when a
 * coordinate name actually exists for it, so the output is correct both for
 * per-dimension strides and for strides that omit the trailing unit stride.
 *
 * @param shape the 3D input shape
 * @returns GLSL source of `inputCoordsFromReshapedOutCoords`
 */
function getReshapedInputCoords(shape: [number, number, number]): string {
  const strides = ShapeUtil.computeStrides(shape);
  const coords = ['b', 'r', 'c'];
  const index = 'index';

  const statements: string[] = [];
  const strideCount = Math.min(strides.length, coords.length);
  for (let i = 0; i < strideCount; i++) {
    const coord = coords[i];
    const stride = strides[i];
    statements.push(`int ${coord} = ${index} / ${stride};`);

    if (i < strides.length - 1) {
      // More strides follow: strip this coordinate's contribution from the index.
      statements.push(`${index} -= ${coord} * ${stride};`);
    } else if (i + 1 < coords.length) {
      // Last stride but one coordinate is still undeclared: it is the remainder.
      statements.push(`int ${coords[i + 1]} = ${index} - ${coord} * ${stride};`);
    }
    // Otherwise every coordinate is already declared; nothing more to emit.
  }
  const coordsFromIndexSnippet = statements.join(' ');

  return [
    '',
    'ivec3 inputCoordsFromReshapedOutCoords(int index) {',
    coordsFromIndexSnippet,
    'return ivec3(b, r, c);',
    '}',
    '',
  ].join('\n');
}
/**
 * Generates the GLSL function `getFlattenedIndex`, which converts an `ivec3`
 * coordinate into a flat index using the first two strides of `shape`.
 *
 * The generated code multiplies `coords.x` by the first stride and `coords.z`
 * by the second, and adds `coords.y` unscaled (y and z are deliberately
 * swapped, as the comment inside the shader states).
 *
 * @param shape the 3D shape whose strides are baked into the shader
 * @returns GLSL source of `getFlattenedIndex`
 */
function getFlattenedIndexFrom3D(shape: [number, number, number]): string {
  const [outerStride, innerStride] = ShapeUtil.computeStrides(shape);
  const body = [
    'int getFlattenedIndex(ivec3 coords) {',
    '// reverse y, z order',
    `return coords.x * ${outerStride} + coords.z * ${innerStride} + coords.y;`,
    '}',
  ];
  return ['', ...body, ''].join('\n');
}