mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
[js/webgpu] Support pad uniforms (#19057)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
eb92681bfb
commit
42ba2aed54
2 changed files with 106 additions and 118 deletions
|
|
@ -20,7 +20,7 @@ import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm';
|
|||
import {layerNorm} from './ops/layer-norm';
|
||||
import {matMul} from './ops/matmul';
|
||||
import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion';
|
||||
import {pad, parsePadAttributes} from './ops/pad';
|
||||
import {pad} from './ops/pad';
|
||||
import * as pool from './ops/pool';
|
||||
import {range} from './ops/range';
|
||||
import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
|
||||
|
|
@ -95,7 +95,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]],
|
||||
['Neg', [unaryOps.neg]],
|
||||
['Not', [unaryOps.not]],
|
||||
['Pad', [pad, parsePadAttributes]],
|
||||
['Pad', [pad]],
|
||||
['Pow', [binaryOps.pow]],
|
||||
['Range', [range]],
|
||||
['Reciprocal', [unaryOps.reciprocal]],
|
||||
|
|
|
|||
|
|
@ -1,15 +1,14 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {DataType, tensorDataTypeEnumToString} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, ProgramInfo} from '../types';
|
||||
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
|
||||
|
||||
import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformDataElementType, UniformsArrayType} from './common';
|
||||
|
||||
export interface PadAttributes extends AttributeWithCacheKey {
|
||||
interface PadAttributes {
|
||||
// 0-constant, 1-reflect, 2-edge, 3-wrap
|
||||
readonly mode: number;
|
||||
readonly value: number;
|
||||
|
|
@ -35,27 +34,23 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
|
|||
}
|
||||
};
|
||||
|
||||
const getPadConstant =
|
||||
(output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[],
|
||||
dataType: string, constantValue: number): string => {
|
||||
const inputRank = inputDims.length;
|
||||
|
||||
let block = '';
|
||||
for (let i = inputRank - 1; i >= 0; --i) {
|
||||
block += `
|
||||
k = i32(${output.indicesGet('indices', i)}) - ${pads[i]};
|
||||
const getPadConstant = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
|
||||
let block = '';
|
||||
for (let i = inputRank - 1; i >= 0; --i) {
|
||||
block += `
|
||||
k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
|
||||
if (k < 0) {
|
||||
break;
|
||||
}
|
||||
if (k >= ${inputDims[i]}) {
|
||||
if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) {
|
||||
break;
|
||||
}
|
||||
offset += k * ${inputStrides[i]};
|
||||
offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
|
||||
`;
|
||||
}
|
||||
}
|
||||
|
||||
return `
|
||||
value = ${dataType}(${constantValue});
|
||||
return `
|
||||
value = ${output.type.value}(uniforms.constant_value);
|
||||
for (var i = 0; i < 1; i++) {
|
||||
var offset = 0;
|
||||
var k = 0;
|
||||
|
|
@ -63,143 +58,143 @@ const getPadConstant =
|
|||
value = x[offset];
|
||||
}
|
||||
`;
|
||||
};
|
||||
};
|
||||
|
||||
const getPadReflect =
|
||||
(output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => {
|
||||
const inputRank = inputDims.length;
|
||||
|
||||
let block = '';
|
||||
for (let i = inputRank - 1; i >= 0; --i) {
|
||||
block += `
|
||||
k = i32(${output.indicesGet('indices', i)}) - ${pads[i]};
|
||||
const getPadReflect = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
|
||||
let block = '';
|
||||
for (let i = inputRank - 1; i >= 0; --i) {
|
||||
block += `
|
||||
k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
|
||||
if (k < 0) {
|
||||
k = -k;
|
||||
}
|
||||
{
|
||||
let _2n_1 = ${2 * (inputDims[i] - 1)};
|
||||
let _2n_1 = 2 * (i32(${getElementAt('uniforms.x_shape', i, inputRank)}) - 1);
|
||||
k = k % _2n_1;
|
||||
if(k >= ${inputDims[i]}) {
|
||||
if(k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) {
|
||||
k = _2n_1 - k;
|
||||
}
|
||||
}
|
||||
offset += k * ${inputStrides[i]};
|
||||
offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
|
||||
`;
|
||||
}
|
||||
}
|
||||
|
||||
return `
|
||||
return `
|
||||
var offset = 0;
|
||||
var k = 0;
|
||||
${block}
|
||||
value = x[offset];
|
||||
`;
|
||||
};
|
||||
};
|
||||
|
||||
const getPadEdge =
|
||||
(output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => {
|
||||
const inputRank = inputDims.length;
|
||||
|
||||
let block = '';
|
||||
for (let i = inputRank - 1; i >= 0; --i) {
|
||||
block += `
|
||||
k = i32(${output.indicesGet('indices', i)}) - ${pads[i]};
|
||||
const getPadEdge = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
|
||||
let block = '';
|
||||
for (let i = inputRank - 1; i >= 0; --i) {
|
||||
block += `
|
||||
k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
|
||||
if (k < 0) {
|
||||
k = 0;
|
||||
}
|
||||
if (k >= ${inputDims[i]}) {
|
||||
k = ${inputDims[i] - 1};
|
||||
if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) {
|
||||
k = i32(${getElementAt('uniforms.x_shape', i, inputRank)}) - 1;
|
||||
}
|
||||
offset += k * ${inputStrides[i]};
|
||||
offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
|
||||
`;
|
||||
}
|
||||
}
|
||||
|
||||
return `
|
||||
return `
|
||||
var offset = 0;
|
||||
var k = 0;
|
||||
${block}
|
||||
value = x[offset];
|
||||
`;
|
||||
};
|
||||
};
|
||||
|
||||
const getPadWrap =
|
||||
(output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], pads: number[]): string => {
|
||||
const inputRank = inputDims.length;
|
||||
|
||||
let block = '';
|
||||
for (let i = inputRank - 1; i >= 0; --i) {
|
||||
block += `
|
||||
k = i32(${output.indicesGet('indices', i)}) - ${pads[i]};
|
||||
const getPadWrap = (output: IndicesHelper, inputRank: number, padsLength: number): string => {
|
||||
let block = '';
|
||||
for (let i = inputRank - 1; i >= 0; --i) {
|
||||
block += `
|
||||
k = i32(${output.indicesGet('indices', i)}) - ${getElementAt('uniforms.pads', i, padsLength)};
|
||||
if (k < 0) {
|
||||
k += ${inputDims[i]};
|
||||
k += i32(${getElementAt('uniforms.x_shape', i, inputRank)}]);
|
||||
}
|
||||
if (k >= ${inputDims[i]}) {
|
||||
k -= ${inputDims[i]};
|
||||
if (k >= i32(${getElementAt('uniforms.x_shape', i, inputRank)})) {
|
||||
k -= i32(${getElementAt('uniforms.x_shape', i, inputRank)});
|
||||
}
|
||||
offset += k * ${inputStrides[i]};
|
||||
offset += k * i32(${getElementAt('uniforms.x_strides', i, inputRank)});
|
||||
`;
|
||||
}
|
||||
}
|
||||
|
||||
return `
|
||||
return `
|
||||
var offset = 0;
|
||||
var k = 0;
|
||||
${block}
|
||||
value = x[offset];
|
||||
`;
|
||||
};
|
||||
};
|
||||
|
||||
const getPadSnippet =
|
||||
(output: IndicesHelper, inputDims: readonly number[], inputStrides: readonly number[], attributes: PadAttributes,
|
||||
dataType: string): string => {
|
||||
switch (attributes.mode) {
|
||||
case 0:
|
||||
return getPadConstant(output, inputDims, inputStrides, attributes.pads, dataType, attributes.value);
|
||||
case 1:
|
||||
return getPadReflect(output, inputDims, inputStrides, attributes.pads);
|
||||
case 2:
|
||||
return getPadEdge(output, inputDims, inputStrides, attributes.pads);
|
||||
case 3:
|
||||
return getPadWrap(output, inputDims, inputStrides, attributes.pads);
|
||||
default:
|
||||
throw new Error('Invalid mode');
|
||||
}
|
||||
};
|
||||
|
||||
const generatePadCode =
|
||||
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: PadAttributes, dataType: string):
|
||||
string => {
|
||||
const inputDims = inputs[0].dims;
|
||||
const outputDims = ShapeUtil.padShape(inputDims.slice(), attributes.pads);
|
||||
const outputSize = ShapeUtil.size(outputDims);
|
||||
const inputStrides = ShapeUtil.computeStrides(inputDims);
|
||||
|
||||
const output = outputVariable('output', inputs[0].dataType, outputDims);
|
||||
const input = inputVariable('x', inputs[0].dataType, inputDims);
|
||||
|
||||
const padSnippet = getPadSnippet(output, inputDims, inputStrides, attributes, dataType);
|
||||
const padCode = `
|
||||
${shaderHelper.declareVariables(input, output)}
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
|
||||
let indices = ${output.offsetToIndices('global_idx')};
|
||||
|
||||
var value = ${dataType}(0);
|
||||
${padSnippet}
|
||||
output[global_idx] = value;
|
||||
}`;
|
||||
return padCode;
|
||||
};
|
||||
const getPadSnippet = (output: IndicesHelper, inputRank: number, attributes: PadAttributes): string => {
|
||||
switch (attributes.mode) {
|
||||
case 0:
|
||||
return getPadConstant(output, inputRank, attributes.pads.length);
|
||||
case 1:
|
||||
return getPadReflect(output, inputRank, attributes.pads.length);
|
||||
case 2:
|
||||
return getPadEdge(output, inputRank, attributes.pads.length);
|
||||
case 3:
|
||||
return getPadWrap(output, inputRank, attributes.pads.length);
|
||||
default:
|
||||
throw new Error('Invalid mode');
|
||||
}
|
||||
};
|
||||
|
||||
const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttributes): ProgramInfo => {
|
||||
const outputShape = ShapeUtil.padShape(inputs[0].dims.slice(), attributes.pads);
|
||||
const inputDims = inputs[0].dims;
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.pads}];
|
||||
if (attributes.mode === 0) {
|
||||
const tensorDataType = tensorDataTypeEnumToString(inputs[0].dataType) as ProgramUniform['type'];
|
||||
programUniforms.push({type: tensorDataType, data: attributes.value});
|
||||
}
|
||||
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape));
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
|
||||
const input = inputVariable('x', inputs[0].dataType, inputDims.length);
|
||||
const dataType = input.type.value;
|
||||
const padSnippet = getPadSnippet(output, inputDims.length, attributes);
|
||||
const uniforms: UniformsArrayType =
|
||||
[{name: 'output_size', type: 'u32'}, {name: 'pads', type: 'i32', length: attributes.pads.length}];
|
||||
if (attributes.mode === 0) {
|
||||
uniforms.push({name: 'constant_value', type: dataType as UniformDataElementType});
|
||||
}
|
||||
|
||||
return `
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
|
||||
|
||||
let indices = ${output.offsetToIndices('global_idx')};
|
||||
|
||||
var value = ${dataType}(0);
|
||||
${padSnippet}
|
||||
output[global_idx] = value;
|
||||
}`;
|
||||
};
|
||||
|
||||
return {
|
||||
name: 'Pad',
|
||||
shaderCache: {hint: attributes.cacheKey},
|
||||
shaderCache: {hint: `${attributes.mode}`, inputDependencies},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)}
|
||||
dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)},
|
||||
programUniforms
|
||||
}),
|
||||
getShaderSource: shaderHelper => generatePadCode(shaderHelper, inputs, attributes, 'f32'),
|
||||
getShaderSource,
|
||||
};
|
||||
};
|
||||
|
||||
|
|
@ -223,7 +218,7 @@ const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes
|
|||
const pads: number[] = [];
|
||||
updatePads.forEach(v => pads.push(v));
|
||||
|
||||
return createAttributeWithCacheKey({mode: attributes.mode, value, pads});
|
||||
return {mode: attributes.mode, value, pads};
|
||||
} else {
|
||||
return attributes;
|
||||
}
|
||||
|
|
@ -234,10 +229,3 @@ export const pad = (context: ComputeContext, attributes: PadAttributes): void =>
|
|||
const updatedAttributes = createPadAttributesFromInputs(context.inputs, attributes);
|
||||
context.compute(createPadProgramInfo(context.inputs, updatedAttributes), {inputs: [0]});
|
||||
};
|
||||
|
||||
export const parsePadAttributes = (attributes: Record<string, unknown>): PadAttributes => {
|
||||
const mode = attributes.mode as number;
|
||||
const value = attributes.value as number;
|
||||
const pads = attributes.pads as number[];
|
||||
return createAttributeWithCacheKey({mode, value, pads});
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue