mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
231 lines
8.4 KiB
TypeScript
231 lines
8.4 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {DataType, tensorDataTypeEnumToString} from '../../../wasm-common';
|
|
import {TensorView} from '../../tensor-view';
|
|
import {ShapeUtil} from '../../util';
|
|
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
|
|
|
|
import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformDataElementType, UniformsArrayType} from './common';
|
|
|
|
/**
 * Pad operator attributes as consumed by the WebGPU kernel.
 */
interface PadAttributes {
  /**
   * Pad amounts laid out as [begin_0, ..., begin_{r-1}, end_0, ..., end_{r-1}];
   * entry `i` (for i < rank) is the number of elements inserted before axis `i`.
   */
  readonly pads: number[];
  /** Padding mode: 0 = constant, 1 = reflect, 2 = edge, 3 = wrap. */
  readonly mode: number;
  /** Fill value written into the padded region when `mode` is 0 (constant). */
  readonly value: number;
}
|
|
|
|
/**
 * Validates the Pad inputs before a program is built.
 *
 * - There must be at least one input (the data tensor), and its type must be `DataType.float`.
 * - When a pads tensor is present (input 1), its first dimension must be twice the number of
 *   padded axes: the first dimension of the axes tensor when exactly four inputs are given,
 *   otherwise the rank of the data tensor.
 *
 * @throws Error if any of the checks above fails.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  if (!inputs || inputs.length < 1) {
    throw new Error('Too few inputs');
  }

  const data = inputs[0];
  if (data.dataType !== DataType.float) {
    throw new Error('Input type must be float.');
  }

  // Without a pads tensor the pads come from attributes; nothing more to check here.
  if (inputs.length < 2) {
    return;
  }

  const padsLength = inputs[1].dims[0];
  const paddedAxisCount = inputs.length === 4 ? inputs[3].dims[0] : data.dims.length;
  if (paddedAxisCount * 2 !== padsLength) {
    throw new Error('The pads should be a 1D tensor of shape [2 * input_rank] or [2 * num_axes].');
  }
};
|
|
|
|
/**
 * Builds the WGSL body for 'constant' padding.
 *
 * The emitted code seeds `value` with `uniforms.constant_value`, then walks the axes from
 * the last to the first, shifting each output coordinate by its begin pad to obtain the
 * input coordinate and accumulating the flat input offset. The walk sits inside a
 * single-iteration `for` loop so that `break` can abandon it as soon as a coordinate falls
 * outside the input; in that case `value` keeps the constant. Only when every axis is in
 * range is `value` replaced by `x[offset]`.
 *
 * @param output helper for the output variable; supplies `indicesGet` and the WGSL value type.
 * @param inputRank number of input dimensions.
 * @param padsLength length of the `uniforms.pads` array, forwarded to `getElementAt`.
 * @returns the WGSL snippet. It references `indices`, `value` and `x`, which the enclosing
 *     shader must declare.
 */
const getPadConstant = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
  const axisChecks = Array.from({length: inputRank}, (_, n) => inputRank - 1 - n).map((axis) => {
    const outputIndex = output.indicesGet('indices', axis);
    const padBegin = getElementAt('uniforms.pads', axis, padsLength);
    const inputDim = getElementAt('uniforms.x_shape', axis, inputRank);
    const inputStride = getElementAt('uniforms.x_strides', axis, inputRank);
    return `
        k = i32(${outputIndex}) - ${padBegin};
        if (k < 0) {
          break;
        }
        if (k >= i32(${inputDim})) {
          break;
        }
        offset += k * i32(${inputStride});`;
  });

  return `
      value = ${output.type.value}(uniforms.constant_value);
      for (var i = 0; i < 1; i++) {
        var offset = 0;
        var k = 0;
        ${axisChecks.join('\n')}
        value = x[offset];
      }
    `;
};
|
|
|
|
/**
 * Builds the WGSL body for 'reflect' padding.
 *
 * For each axis (last to first) the output coordinate is shifted by the begin pad. A
 * negative result is mirrored to positive. It is then folded into one period of length
 * 2 * (dim - 1), and values past the last element are reflected back. This mirrors around
 * the border elements without repeating them, e.g. -1 -> 1 and dim -> dim - 2. The
 * resulting coordinates build the flat offset used for `value = x[offset]`.
 *
 * NOTE(review): for a dimension of size 1 the period is 0, so this relies on WGSL's
 * integer `%` by zero yielding 0 — confirm against the WGSL spec.
 *
 * @param output helper for the output variable; supplies `indicesGet`.
 * @param inputRank number of input dimensions.
 * @param padsLength length of the `uniforms.pads` array, forwarded to `getElementAt`.
 * @returns the WGSL snippet. It references `indices`, `value` and `x`, which the enclosing
 *     shader must declare.
 */
const getPadReflect = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
  // WGSL for a single axis; `dim` is the input extent of that axis converted to i32.
  const reflectAxis = (axis: number): string => {
    const dim = `i32(${getElementAt('uniforms.x_shape', axis, inputRank)})`;
    return `
        k = i32(${output.indicesGet('indices', axis)}) - ${getElementAt('uniforms.pads', axis, padsLength)};
        if (k < 0) {
          k = -k;
        }
        {
          let _2n_1 = 2 * (${dim} - 1);
          k = k % _2n_1;
          if(k >= ${dim}) {
            k = _2n_1 - k;
          }
        }
        offset += k * i32(${getElementAt('uniforms.x_strides', axis, inputRank)});`;
  };

  const axisBlocks: string[] = [];
  for (let axis = inputRank - 1; axis >= 0; axis--) {
    axisBlocks.push(reflectAxis(axis));
  }

  return `
      var offset = 0;
      var k = 0;
      ${axisBlocks.join('')}
      value = x[offset];
    `;
};
|
|
|
|
/**
 * Builds the WGSL body for 'edge' padding.
 *
 * For each axis (last to first) the output coordinate is shifted by the begin pad and then
 * clamped to [0, dim - 1], so padded positions repeat the nearest border element. The
 * clamped coordinates build the flat offset used for `value = x[offset]`.
 *
 * @param output helper for the output variable; supplies `indicesGet`.
 * @param inputRank number of input dimensions.
 * @param padsLength length of the `uniforms.pads` array, forwarded to `getElementAt`.
 * @returns the WGSL snippet. It references `indices`, `value` and `x`, which the enclosing
 *     shader must declare.
 */
const getPadEdge = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
  const axes = Array.from({length: inputRank}, (_, axis) => axis);
  // reduceRight visits the last axis first, matching the offset accumulation order.
  const clampBlocks = axes.reduceRight((acc, axis) => {
    const dim = `i32(${getElementAt('uniforms.x_shape', axis, inputRank)})`;
    return acc + `
        k = i32(${output.indicesGet('indices', axis)}) - ${getElementAt('uniforms.pads', axis, padsLength)};
        if (k < 0) {
          k = 0;
        }
        if (k >= ${dim}) {
          k = ${dim} - 1;
        }
        offset += k * i32(${getElementAt('uniforms.x_strides', axis, inputRank)});`;
  }, '');

  return `
      var offset = 0;
      var k = 0;
      ${clampBlocks}
      value = x[offset];
    `;
};
|
|
|
|
/**
 * Builds the WGSL body for 'wrap' padding.
 *
 * The input is treated as periodic along every axis. For each axis (last to first) the
 * output coordinate is shifted by the begin pad and reduced modulo the input dimension,
 * so any pad amount wraps correctly, not only pads up to one dimension in size. The
 * wrapped coordinates build the flat offset used for `value = x[offset]`.
 *
 * @param output helper for the output variable; supplies `indicesGet`.
 * @param inputRank number of input dimensions.
 * @param padsLength length of the `uniforms.pads` array, forwarded to `getElementAt`.
 * @returns the WGSL snippet. It references `indices`, `value` and `x`, which the enclosing
 *     shader must declare.
 */
const getPadWrap = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
  let block = '';
  for (let i = inputRank - 1; i >= 0; --i) {
    block += `
            k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
            {
              let dim_size = i32(${getElementAt('uniforms.x_shape', i, inputRank)});
              // WGSL's i32 '%' keeps the sign of the dividend, so add the extent back and
              // reduce again to land in [0, dim_size) for any (possibly negative) k.
              k = ((k % dim_size) + dim_size) % dim_size;
            }
            offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
        `;
  }

  return `
            var offset = 0;
            var k = 0;
            ${block}
            value = x[offset];
        `;
};
|
|
|
|
/**
 * Selects and builds the mode-specific WGSL snippet for the Pad shader.
 *
 * @param output helper for the output variable, forwarded to the mode builder.
 * @param inputRank number of input dimensions.
 * @param attributes resolved Pad attributes; `mode` picks the builder and `pads.length` is
 *     forwarded to it.
 * @returns the WGSL snippet produced by the selected builder.
 * @throws Error('Invalid mode') when `mode` is not 0, 1, 2 or 3.
 */
const getPadSnippet = (output: IndicesHelper, inputRank: number, attributes: PadAttributes): string => {
  // Indexed by mode: 0-constant, 1-reflect, 2-edge, 3-wrap.
  const builders = [getPadConstant, getPadReflect, getPadEdge, getPadWrap];
  const build = builders[attributes.mode];
  if (build === undefined) {
    throw new Error('Invalid mode');
  }
  return build(output, inputRank, attributes.pads.length);
};
|
|
|
|
/**
 * Creates the WebGPU program that pads `inputs[0]` according to `attributes`.
 *
 * Each shader invocation handles one output element. It converts `global_idx` to output
 * indices, runs the mode-specific snippet from `getPadSnippet` to compute `value`, and
 * writes it to `output[global_idx]`.
 *
 * Uniforms, in upload order:
 * - `output_size` (u32);
 * - `pads` (i32 array);
 * - `constant_value` (constant mode only);
 * - the shape/stride variables of the input and then the output.
 *
 * @param inputs operator inputs; only `inputs[0]` (the data tensor) is read here.
 * @param attributes resolved Pad attributes (mode, fill value, full-rank pads).
 * @returns the program description consumed by `ComputeContext.compute`.
 */
const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttributes): ProgramInfo => {
  const inputDims = inputs[0].dims;
  const outputShape = ShapeUtil.padShape(inputDims.slice(), attributes.pads);
  const outputSize = ShapeUtil.size(outputShape);
  // `pads` is declared as i32 in the shader and comes from a signed int64 tensor (negative
  // pads crop), so upload it with the matching signed type.
  const programUniforms: ProgramUniform[] =
      [{type: 'uint32', data: outputSize}, {type: 'int32', data: attributes.pads}];
  if (attributes.mode === 0) {
    const tensorDataType = tensorDataTypeEnumToString(inputs[0].dataType) as ProgramUniform['type'];
    programUniforms.push({type: tensorDataType, data: attributes.value});
  }

  programUniforms.push(...createTensorShapeVariables(inputDims), ...createTensorShapeVariables(outputShape));
  const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const output = outputVariable('output', inputs[0].dataType, outputShape.length);
    const input = inputVariable('x', inputs[0].dataType, inputDims.length);
    const dataType = input.type.value;
    const padSnippet = getPadSnippet(output, inputDims.length, attributes);
    const uniforms: UniformsArrayType =
        [{name: 'output_size', type: 'u32'}, {name: 'pads', type: 'i32', length: attributes.pads.length}];
    if (attributes.mode === 0) {
      uniforms.push({name: 'constant_value', type: dataType as UniformDataElementType});
    }

    return `
            ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
            ${shaderHelper.mainStart()}
            ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}

            let indices = ${output.offsetToIndices('global_idx')};

            var value = ${dataType}(0);
            ${padSnippet}
            output[global_idx] = value;
        }`;
  };

  return {
    name: 'Pad',
    // The snippet shape depends on the mode; rank (and thus pads length) is covered by inputDependencies.
    shaderCache: {hint: `${attributes.mode}`, inputDependencies},
    getRunData: () => ({
      outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
      dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
      programUniforms,
    }),
    getShaderSource,
  };
};
|
|
|
|
/**
 * Resolves the effective Pad attributes.
 *
 * With a single input, `attributes` is returned unchanged (pads and fill value were
 * supplied as node attributes). Otherwise the values come from the inputs:
 * - pads come from input 1;
 * - the fill value is the first element of input 2 when that input has data, else 0;
 * - the optional axes come from input 3.
 *
 * The result always holds full-rank pads [begin_0 .. begin_{r-1}, end_0 .. end_{r-1}].
 * Axes not listed in input 3 get 0. Negative axes count from the back.
 *
 * @param inputs operator inputs (data, pads, constant_value, axes).
 * @param attributes attributes from the node; only `mode` is kept when pads come from inputs.
 * @returns attributes with resolved `value` and full-rank `pads`.
 * @throws Error when an axis lies outside [-rank, rank - 1].
 */
const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes: PadAttributes): PadAttributes => {
  if (inputs.length <= 1) {
    return attributes;
  }

  const inputRank = inputs[0].dims.length;
  const bigInt64Pads = inputs[1].getBigInt64Array();
  // `?? 0.0` also covers a constant_value tensor that has data but no elements.
  const value = (inputs.length >= 3 && inputs[2].data) ? (inputs[2].getFloat32Array()[0] ?? 0.0) : 0.0;

  // Zero-initialized, so axes that are not explicitly padded keep 0 on both sides.
  const updatePads = new Int32Array(2 * inputRank);
  if (inputs.length >= 4) {
    const axes = inputs[3].getBigInt64Array();
    for (let i = 0; i < axes.length; i++) {
      const rawAxis = Number(axes[i]);
      // A negative index on a typed array is silently ignored, which would drop this
      // axis' begin pad; normalize negative axes and reject out-of-range ones instead.
      if (rawAxis < -inputRank || rawAxis >= inputRank) {
        throw new Error(`Pad: axis ${rawAxis} is out of range for input of rank ${inputRank}.`);
      }
      const axis = rawAxis < 0 ? rawAxis + inputRank : rawAxis;
      updatePads[axis] = Number(bigInt64Pads[i]);
      updatePads[axis + inputRank] = Number(bigInt64Pads[i + axes.length]);
    }
  } else {
    bigInt64Pads.forEach((v, i) => {
      updatePads[i] = Number(v);
    });
  }

  return {mode: attributes.mode, value, pads: Array.from(updatePads)};
};
|
|
|
|
/**
 * Pad kernel entry point.
 *
 * The steps run in order:
 * 1. Validate the context inputs.
 * 2. Resolve the effective attributes (from inputs when present).
 * 3. Build the program and dispatch it with only input 0 bound to the GPU program.
 *
 * @param context compute context providing the inputs and `compute`.
 * @param attributes attributes parsed from the node.
 */
export const pad = (context: ComputeContext, attributes: PadAttributes): void => {
  const {inputs} = context;
  validateInputs(inputs);
  const resolvedAttributes = createPadAttributesFromInputs(inputs, attributes);
  const programInfo = createPadProgramInfo(inputs, resolvedAttributes);
  context.compute(programInfo, {inputs: [0]});
};
|