onnxruntime/js/web/lib/onnxjs/backends/webgl/ops/image-scaler.ts
Yulong Wang 586f06f5a1
[js/web] set noUnusedParameters to true and fix a few bugs (#18404)
### Description
- set tsconfig "noUnusedParameters" to `true` and fix a few bugs
discovered by typescript.
   How unused parameters are fixed:
- for most code (webgl), add underscore as prefix, which is the standard
ignore pattern for typescript check.
- remove unused parameter from function and modify corresponding
function calls (jsep)
- fix a bug in ArgMinMax: these two operators never have more than one
input, so `createArgMinMaxAttributesFromInputs()` is removed.
- add the proxy main.ts to the TypeScript check and fix a bug in parameter
passing
   - fix the `run()` function call and add a type-check workaround (hack)
2023-11-15 09:16:29 -08:00

98 lines
3.6 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../../../attribute-with-cache-key';
import {Graph} from '../../../graph';
import {OperatorImplementation, OperatorInitialization} from '../../../operators';
import {Tensor} from '../../../tensor';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, ProgramInfoLoader, ProgramMetadata, TextureType} from '../types';
/**
 * Parsed attributes of the ONNX `ImageScaler` operator.
 *
 * Extends `AttributeWithCacheKey` so the attribute values can be used as a program cache hint.
 */
export interface ImageScalerAttributes extends AttributeWithCacheKey {
/** Multiplier applied to every input element before the bias is added. */
scale: number;
/**
 * Additive bias, selected per element by the index along axis 1 of the input
 * (presumably the channel axis of an NCHW tensor; confirm with the model layout).
 */
bias: number[];
}
/**
 * WebGL kernel for ImageScaler: computes `Y = X * scale + bias[c]`, where `c` is the element's index along axis 1.
 *
 * @param inferenceHandler WebGL handler that builds and runs the generated program
 * @param inputs exactly one rank-4 float32/float64 tensor
 * @param attributes parsed `scale` and per-channel `bias`
 * @returns a single-element array holding the output tensor (same shape and type as the input)
 * @throws Error if the input is invalid or `bias` does not hold exactly one value per channel
 */
export const imageScaler: OperatorImplementation<ImageScalerAttributes> =
    (inferenceHandler: WebGLInferenceHandler, inputs: Tensor[], attributes: ImageScalerAttributes): Tensor[] => {
      validateInputs(inputs);

      // The generated getBias() helper falls back to the last bias entry for any channel it does not test
      // explicitly. A bias shorter than the channel count would therefore silently produce wrong values, and an
      // empty bias yields a zero-length GLSL array that fails shader compilation. Reject both up front.
      const numChannels = inputs[0].dims[1];
      const biasLength = attributes.bias.length;
      if (biasLength !== numChannels) {
        throw new Error(
            `ImageScaler bias must have one value per channel (expected ${numChannels}, got ${biasLength}).`);
      }

      const output =
          inferenceHandler.run(createImageScalerProgramInfoLoader(inferenceHandler, inputs, attributes), inputs);
      return [output];
    };
/**
 * Reads ImageScaler attributes from a graph node.
 *
 * `scale` is optional in the ONNX ImageScaler schema (default 1.0), so 1.0 is used when the node omits it.
 * `bias` has no schema default and is read as given.
 *
 * @param node graph node carrying the `scale` and `bias` attributes
 * @returns the attributes wrapped with a cache key derived from their values
 */
export const parseImageScalerAttributes: OperatorInitialization<ImageScalerAttributes> =
    (node: Graph.Node): ImageScalerAttributes => {
      const scale = node.attributes.getFloat('scale', 1.0);
      const bias = node.attributes.getFloats('bias');
      return createAttributeWithCacheKey({scale, bias});
    };
/**
 * Static metadata shared by every ImageScaler program: a single input, exposed to the shader as `X`
 * and read from an unpacked texture. `createImageScalerProgramInfoLoader` spreads this object and adds
 * an attribute-specific `cacheHint`.
 */
const imageScalerProgramMetadata = {
name: 'ImageScaler',
inputNames: ['X'],
inputTypes: [TextureType.unpacked],
};
/**
 * Builds the WebGL program description for ImageScaler.
 *
 * The shader's `process()` returns `X * scale + bias[indices[1]]` for every output element. `scale` and `bias`
 * are not baked into the shader text; they are supplied through the program's `variables` list. The output
 * has the same dims and data type as the input and is written to an unpacked texture.
 *
 * @param _handler unused; kept so the signature matches the other program-info factories
 * @param metadata program metadata, including the cache hint
 * @param inputs single input tensor `X`
 * @param attributes parsed `scale` and `bias`
 */
const createImageScalerProgramInfo =
    (_handler: WebGLInferenceHandler, metadata: ProgramMetadata, inputs: Tensor[], attributes: ImageScalerAttributes):
        ProgramInfo => {
          const [input] = inputs;
          const {scale, bias} = attributes;
          const dims = input.dims.slice();
          const biasLookupSource = createGetBiasMethod(bias.length);

          return {
            ...metadata,
            output: {dims, type: input.type, textureType: TextureType.unpacked},
            variables: [
              {name: 'bias', type: 'float', arrayLength: bias.length, data: bias},
              {name: 'scale', type: 'float', data: scale},
            ],
            shaderSource: `
${biasLookupSource}
float process(int indices[${dims.length}]) {
return _X(indices) * scale + getBias(bias, indices[1]);
}`,
          };
        };
/**
 * Wraps ImageScaler program creation in a lazy loader.
 *
 * The loader carries the static metadata plus the attributes' cache key as its `cacheHint`. Its `get`
 * builds the full program only when invoked.
 */
const createImageScalerProgramInfoLoader =
    (handler: WebGLInferenceHandler, inputs: Tensor[], attributes: ImageScalerAttributes): ProgramInfoLoader => {
      const cacheHint = attributes.cacheKey;
      const metadata = {...imageScalerProgramMetadata, cacheHint};
      const buildProgramInfo = (): ProgramInfo => createImageScalerProgramInfo(handler, metadata, inputs, attributes);
      return {...metadata, get: buildProgramInfo};
    };
/**
 * Generates the GLSL helper `float getBias(float bias[N], int channel)`, which returns `bias[channel]`.
 *
 * The lookup is an if/else chain over constant indices rather than `bias[channel]`. This is presumably because
 * GLSL ES 1.00 only guarantees indexing uniform arrays with constant-index expressions. Any channel not matched
 * explicitly falls through to the last entry. Every control path ends in a `return`, including the
 * single-channel case.
 *
 * @param numChannels number of bias entries (N); must be at least 1
 * @returns GLSL source of the helper, one statement per line
 * @throws Error if `numChannels` is less than 1, because a zero-length GLSL array cannot compile
 */
const createGetBiasMethod = (numChannels: number): string => {
  if (numChannels < 1) {
    throw new Error('ImageScaler requires a non-empty bias.');
  }

  const codeLines: string[] = [`float getBias(float bias[${numChannels}], int channel) {`];
  if (numChannels === 1) {
    // A lone `if (channel == 0)` would leave the other path without a return value; return unconditionally.
    codeLines.push('\treturn bias[0];');
  } else {
    for (let i = 0; i < numChannels; ++i) {
      if (i === 0) {
        codeLines.push(`\tif (channel == ${i}) { return bias[${i}]; }`);
      } else if (i === numChannels - 1) {
        codeLines.push(`\telse { return bias[${i}]; }`);
      } else {
        codeLines.push(`\telse if (channel == ${i}) { return bias[${i}]; }`);
      }
    }
  }
  codeLines.push('}');
  return codeLines.join('\n');
};
/**
 * Checks ImageScaler's input contract: exactly one input tensor, of rank 4, with float32 or float64 data.
 *
 * @param inputs operator inputs to validate
 * @throws Error naming the first violated condition (input count, then shape, then type)
 */
const validateInputs = (inputs: Tensor[]): void => {
  const inputCount = inputs ? inputs.length : 0;
  if (inputCount !== 1) {
    throw new Error('ImageScaler requires 1 input.');
  }

  const [x] = inputs;
  if (x.dims.length !== 4) {
    throw new Error('Invalid input shape.');
  }

  const supportedTypes: string[] = ['float32', 'float64'];
  if (!supportedTypes.includes(x.type)) {
    throw new Error('Invalid input type.');
  }
};