mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-17 21:10:43 +00:00
416 lines
18 KiB
TypeScript
416 lines
18 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {env} from 'onnxruntime-common';
|
|
|
|
import {TensorView} from '../../tensor-view';
|
|
import {PoolConvUtil, ShapeUtil} from '../../util';
|
|
import {AttributeWithCacheKey} from '../attribute-with-cache-key';
|
|
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
|
|
|
|
import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
|
|
|
|
// TODO: support:
|
|
// - ceil_mode "test_maxpool_2d_ceil"
|
|
// - storage_order "test_maxpool_with_argmax_2d_precomputed_strides"
|
|
// - [MaxPool] dilations "test_maxpool_2d_dilations"
|
|
// - [MaxPool] output[1] "test_maxpool_with_argmax_2d_precomputed_pads"
|
|
|
|
const validateInputs = (inputs: readonly TensorView[]): void => {
|
|
if (env.webgpu.validateInputContent && (!inputs || inputs.length !== 1)) {
|
|
throw new Error('Pool ops requires 1 input.');
|
|
}
|
|
};
|
|
|
|
const getAdjustedPoolAttributesAndOutputShape = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
|
|
input: TensorView, attributes: AttributeType, isGlobalOperator: boolean): [AttributeType, number[]] => {
|
|
const isChannelsLast = attributes.format === 'NHWC';
|
|
const inputShapeAsChannelFirst = input.dims.slice();
|
|
if (isChannelsLast) {
|
|
inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!); // Move channel to the second position.
|
|
}
|
|
const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations');
|
|
const kernelShape = attributes.kernelShape.slice();
|
|
const strides = attributes.strides.slice();
|
|
const dilations: number[] = hasDilations ? (attributes as MaxPoolAttributes).dilations.slice() : [];
|
|
const pads = attributes.pads.slice();
|
|
PoolConvUtil.adjustPoolAttributes(isGlobalOperator, inputShapeAsChannelFirst, kernelShape, strides, dilations, pads);
|
|
|
|
const outputShapeAsChannelFirst = PoolConvUtil.computePoolOutputShape(
|
|
isGlobalOperator, inputShapeAsChannelFirst, strides, dilations, kernelShape, pads, attributes.autoPad);
|
|
|
|
const newAttributes = Object.assign({}, attributes);
|
|
if (hasDilations) {
|
|
Object.assign(newAttributes, {kernelShape, strides, pads, dilations, cacheKey: attributes.cacheKey});
|
|
} else {
|
|
Object.assign(newAttributes, {kernelShape, strides, pads, cacheKey: attributes.cacheKey});
|
|
}
|
|
const outputShapeAsChannelLast = outputShapeAsChannelFirst.slice();
|
|
outputShapeAsChannelLast.push(outputShapeAsChannelLast.splice(1, 1)[0]);
|
|
return [newAttributes, isChannelsLast ? outputShapeAsChannelLast : outputShapeAsChannelFirst];
|
|
};
|
|
|
|
const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
|
|
outputShape: readonly number[],
|
|
attributes: AttributeType): [ProgramUniform[], UniformsArrayType, boolean, boolean, boolean] => {
|
|
const isChannelsLast = attributes.format === 'NHWC';
|
|
const outputSize = ShapeUtil.size(outputShape);
|
|
const kernelSize = ShapeUtil.size(attributes.kernelShape);
|
|
const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}, {type: 'uint32', data: kernelSize}];
|
|
const uniforms: UniformsArrayType = [{name: 'outputSize', type: 'u32'}, {name: 'kernelSize', type: 'u32'}];
|
|
if (attributes.kernelShape.length <= 2) {
|
|
const kw = attributes.kernelShape[attributes.kernelShape.length - 1];
|
|
const sw = attributes.strides[attributes.strides.length - 1];
|
|
const pwStart = attributes.pads[attributes.pads.length / 2 - 1];
|
|
const pwEnd = attributes.pads[attributes.pads.length - 1];
|
|
const pwStartEndNotZero = !!(pwStart + pwEnd);
|
|
programUniforms.push(
|
|
{type: 'uint32', data: kw},
|
|
{type: 'uint32', data: sw},
|
|
{type: 'uint32', data: pwStart},
|
|
{type: 'uint32', data: pwEnd},
|
|
);
|
|
uniforms.push(
|
|
{name: 'kw', type: 'u32'}, {name: 'sw', type: 'u32'}, {name: 'pwStart', type: 'u32'},
|
|
{name: 'pwEnd', type: 'u32'});
|
|
|
|
let phStartEndNotZero = false;
|
|
if (attributes.kernelShape.length === 2) {
|
|
const kh = attributes.kernelShape[attributes.kernelShape.length - 2];
|
|
const sh = attributes.strides[attributes.strides.length - 2];
|
|
const phStart = attributes.pads[attributes.pads.length / 2 - 2];
|
|
const phEnd = attributes.pads[attributes.pads.length - 2];
|
|
phStartEndNotZero = !!(phStart + phEnd);
|
|
programUniforms.push(
|
|
{type: 'uint32', data: kh}, {type: 'uint32', data: sh}, {type: 'uint32', data: phStart},
|
|
{type: 'uint32', data: phEnd});
|
|
|
|
uniforms.push(
|
|
{name: 'kh', type: 'u32'}, {name: 'sh', type: 'u32'}, {name: 'phStart', type: 'u32'},
|
|
{name: 'phEnd', type: 'u32'});
|
|
}
|
|
return [programUniforms, uniforms, true, pwStartEndNotZero, phStartEndNotZero];
|
|
} else {
|
|
if (isChannelsLast) {
|
|
throw new Error('Pooling with kernelShape.length > 2 is not supported for NHWC format.');
|
|
}
|
|
const kernelStrides = ShapeUtil.computeStrides(attributes.kernelShape);
|
|
programUniforms.push(
|
|
{type: 'uint32', data: kernelStrides}, {type: 'uint32', data: attributes.pads},
|
|
{type: 'uint32', data: attributes.strides});
|
|
uniforms.push(
|
|
{name: 'kernelStrides', type: 'u32', length: kernelStrides.length},
|
|
{name: 'pads', type: 'u32', length: attributes.pads.length},
|
|
{name: 'strides', type: 'u32', length: attributes.strides.length});
|
|
|
|
const hasPads = attributes.pads.reduce((sum, cur) => sum + cur);
|
|
return [programUniforms, uniforms, !!hasPads, false, false];
|
|
}
|
|
};
|
|
|
|
const generatePoolingCode = <AttributeType extends AveragePoolAttributes|MaxPoolAttributes>(
|
|
shaderHelper: ShaderHelper, x: IndicesHelper, rank: number, outputShapeRank: number, attributes: AttributeType,
|
|
op1: string, op2: string, start: number, uniforms: UniformsArrayType, hasPads: boolean, pwStartEndNotZero: boolean,
|
|
phStartEndNotZero: boolean): string => {
|
|
const isChannelsLast = attributes.format === 'NHWC';
|
|
const dataType = x.type.value;
|
|
const output = outputVariable('output', x.type.tensor, outputShapeRank);
|
|
|
|
if (attributes.kernelShape.length <= 2) {
|
|
let codeW = '';
|
|
let codeH = '';
|
|
let codeHEnd = '';
|
|
const dimIdxW = rank - (isChannelsLast ? 2 : 1);
|
|
if (pwStartEndNotZero) {
|
|
codeW = `
|
|
for (var i: u32 = 0u; i < uniforms.kw; i++) {
|
|
xIndices[${dimIdxW}] = indices[${dimIdxW}] * uniforms.sw - uniforms.pwStart + i;
|
|
if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}]
|
|
>= uniforms.x_shape[${dimIdxW}]) {
|
|
pad++;
|
|
continue;
|
|
}
|
|
let x_val = x[${x.indicesToOffset('xIndices')}];
|
|
${op1}
|
|
}`;
|
|
} else {
|
|
codeW = `
|
|
for (var i: u32 = 0u; i < uniforms.kw; i++) {
|
|
xIndices[${dimIdxW}] = indices[${dimIdxW}] * uniforms.sw - uniforms.pwStart + i;
|
|
let x_val = x[${x.indicesToOffset('xIndices')}];
|
|
${op1}
|
|
}`;
|
|
}
|
|
|
|
if (attributes.kernelShape.length === 2) {
|
|
const dimIdxH = rank - (isChannelsLast ? 3 : 2);
|
|
if (phStartEndNotZero) {
|
|
codeH = `
|
|
for (var j: u32 = 0u; j < uniforms.kh; j++) {
|
|
xIndices[${dimIdxH}] = indices[${dimIdxH}] * uniforms.sh - uniforms.phStart + j;
|
|
if (xIndices[${dimIdxH}] < 0 || xIndices[${dimIdxH}] >= uniforms.x_shape[${dimIdxH}]) {
|
|
pad += i32(uniforms.kw);
|
|
continue;
|
|
}
|
|
`;
|
|
} else {
|
|
codeH = `
|
|
for (var j: u32 = 0u; j < uniforms.kh; j++) {
|
|
xIndices[${dimIdxH}] = indices[${dimIdxH}] * uniforms.sh - uniforms.phStart + j;
|
|
`;
|
|
}
|
|
codeHEnd = `
|
|
}
|
|
`;
|
|
}
|
|
|
|
const poolingCode = `
|
|
${shaderHelper.registerUniforms(uniforms).declareVariables(x, output)}
|
|
|
|
${shaderHelper.mainStart()}
|
|
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
|
|
|
|
let indices = ${output.offsetToIndices('global_idx')};
|
|
var xIndices = ${output.offsetToIndices('global_idx')};
|
|
|
|
var value = ${dataType}(${start});
|
|
var pad = 0;
|
|
${codeH}
|
|
${codeW}
|
|
${codeHEnd}
|
|
${op2}
|
|
|
|
output[global_idx] = value;
|
|
}`;
|
|
return poolingCode;
|
|
} else {
|
|
if (isChannelsLast) {
|
|
throw new Error('Pooling with kernelShape.length > 2 is not supported for NHWC format.');
|
|
}
|
|
const stridesRank = attributes.kernelShape.length;
|
|
const padsRank = attributes.pads.length;
|
|
let padCode = '';
|
|
if (hasPads) {
|
|
padCode = `
|
|
if (xIndices[j] >= uniforms.x_shape[j]) {
|
|
pad++;
|
|
isPad = true;
|
|
break;
|
|
}
|
|
}
|
|
if (!isPad) {
|
|
let x_val = x[${x.indicesToOffset('xIndices')}];
|
|
${op1}
|
|
}`;
|
|
} else {
|
|
padCode = `
|
|
}
|
|
let x_val = x[${x.indicesToOffset('xIndices')}];
|
|
${op1}
|
|
`;
|
|
}
|
|
const poolingCode = `
|
|
${shaderHelper.registerUniforms(uniforms).declareVariables(x, output)}
|
|
|
|
${shaderHelper.mainStart()}
|
|
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
|
|
let indices = ${output.offsetToIndices('global_idx')};
|
|
var xIndices = ${output.offsetToIndices('global_idx')};
|
|
|
|
var offsets: array<u32, ${stridesRank}>;
|
|
|
|
var value = ${dataType}(${start});
|
|
var pad = 0;
|
|
var isPad = false;
|
|
|
|
for (var i: u32 = 0u; i < uniforms.kernelSize; i++) {
|
|
var offset = i;
|
|
for (var j = 0u; j < ${stridesRank - 1}u; j++) {
|
|
offsets[j] = offset / ${getElementAt('uniforms.kernelStrides', 'j', stridesRank)};
|
|
offset -= offsets[j] * ${getElementAt('uniforms.kernelStrides', 'j', stridesRank)};
|
|
}
|
|
offsets[${stridesRank - 1}] = offset;
|
|
|
|
isPad = false;
|
|
for (var j = ${rank - stridesRank}u; j < ${rank}u; j++) {
|
|
xIndices[j] = indices[j] * ${
|
|
getElementAt('uniforms.strides', `j - ${rank - stridesRank}u`, stridesRank)}
|
|
+ offsets[j - ${rank - stridesRank}u] - ${getElementAt('uniforms.pads', 'j - 2u', padsRank)};
|
|
${padCode}
|
|
}
|
|
${op2}
|
|
|
|
output[global_idx] = value;
|
|
}`;
|
|
return poolingCode;
|
|
}
|
|
};
|
|
|
|
export interface FormatAttributes {
|
|
readonly format: 'NHWC'|'NCHW';
|
|
}
|
|
|
|
export interface PoolCommonAttributes extends FormatAttributes {
|
|
readonly autoPad: string;
|
|
readonly ceilMode: number;
|
|
readonly kernelShape: readonly number[];
|
|
readonly strides: readonly number[];
|
|
readonly pads: readonly number[];
|
|
}
|
|
|
|
const createShaderKeyFromAttributes = (attributes: PoolCommonAttributes): string =>
|
|
(`${attributes.format};${attributes.ceilMode};${attributes.autoPad};${attributes.kernelShape.length}`);
|
|
|
|
const createAveragePoolShaderKeyFromAttributes = (attributes: AveragePoolAttributes): string =>
|
|
(`${createShaderKeyFromAttributes(attributes)};${attributes.countIncludePad}`);
|
|
|
|
const createMaxPoolShaderKeyFromAttributes = (attributes: MaxPoolAttributes): string =>
|
|
(`${createShaderKeyFromAttributes(attributes)};${attributes.storageOrder};${attributes.dilations}`);
|
|
|
|
const parsePoolCommonAttributes = (attributes: Record<string, unknown>): PoolCommonAttributes => ({
|
|
format: attributes.format as FormatAttributes['format'],
|
|
autoPad: ['NOTSET', 'VALID', 'SAME_UPPER', 'SAME_LOWER'][attributes.auto_pad as number],
|
|
ceilMode: attributes.ceil_mode as number,
|
|
kernelShape: attributes.kernel_shape as [number, number],
|
|
strides: attributes.strides as [number, number],
|
|
pads: attributes.pads as [number, number, number, number]
|
|
});
|
|
|
|
export interface AveragePoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
|
|
readonly countIncludePad: boolean;
|
|
}
|
|
|
|
const createAveragePoolProgramInfo =
|
|
(name: string, input: TensorView, isGlobalOperator: boolean, attributes: AveragePoolAttributes): ProgramInfo => {
|
|
const [adjustedAttributes, outputShape] =
|
|
getAdjustedPoolAttributesAndOutputShape(input, attributes, isGlobalOperator);
|
|
const x = inputVariable('x', input.dataType, input.dims.length);
|
|
const dataType = x.type.value;
|
|
|
|
const op1 = 'value += x_val;';
|
|
let op2 = '';
|
|
if (adjustedAttributes.countIncludePad) {
|
|
op2 += `value /= ${dataType}(uniforms.kernelSize);`;
|
|
} else {
|
|
op2 += `value /= ${dataType}(i32(uniforms.kernelSize) - pad);`;
|
|
}
|
|
const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] =
|
|
getUniformAndPadInfo(outputShape, adjustedAttributes);
|
|
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape));
|
|
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
|
|
return {
|
|
name,
|
|
shaderCache:
|
|
{hint: `${attributes.cacheKey};${hasPads};${pwStartEndNotZero};${phStartEndNotZero}`, inputDependencies},
|
|
getRunData: () => ({
|
|
outputs: [{dims: outputShape, dataType: input.dataType}],
|
|
dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)},
|
|
programUniforms
|
|
}),
|
|
getShaderSource: shaderHelper => generatePoolingCode(
|
|
shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, 0.0, uniforms,
|
|
hasPads, pwStartEndNotZero, phStartEndNotZero),
|
|
};
|
|
};
|
|
|
|
export const parseAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
|
|
const countIncludePad = (attributes.count_include_pad as number) === 0 ? false : true;
|
|
|
|
const attr = parsePoolCommonAttributes(attributes);
|
|
// TODO: support attribute 'ceil_mode'
|
|
if (attr.ceilMode !== 0) {
|
|
throw new Error('using ceil() in shape computation is not yet supported for AveragePool');
|
|
}
|
|
const averagePoolAttributes = {countIncludePad, ...attr, cacheKey: ''};
|
|
return {...averagePoolAttributes, cacheKey: createAveragePoolShaderKeyFromAttributes(averagePoolAttributes)};
|
|
};
|
|
|
|
export const averagePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
|
|
validateInputs(context.inputs);
|
|
context.compute(createAveragePoolProgramInfo('AveragePool', context.inputs[0], false, attributes));
|
|
};
|
|
|
|
const globalPoolAttributes = {
|
|
autoPad: '',
|
|
ceilMode: 0,
|
|
countIncludePad: false,
|
|
kernelShape: [],
|
|
strides: [],
|
|
pads: [],
|
|
storageOrder: 0,
|
|
dilations: []
|
|
};
|
|
|
|
export const parseGlobalAveragePoolAttributes = (attributes: Record<string, unknown>): AveragePoolAttributes => {
|
|
const format = attributes.format as FormatAttributes['format'];
|
|
return {format, ...globalPoolAttributes, cacheKey: format};
|
|
};
|
|
|
|
export const globalAveragePool = (context: ComputeContext, attributes: AveragePoolAttributes): void => {
|
|
validateInputs(context.inputs);
|
|
context.compute(createAveragePoolProgramInfo('GlobalAveragePool', context.inputs[0], true, attributes));
|
|
};
|
|
|
|
export interface MaxPoolAttributes extends PoolCommonAttributes, AttributeWithCacheKey {
|
|
readonly storageOrder: number;
|
|
readonly dilations: number[];
|
|
}
|
|
|
|
const createMaxPoolProgramInfo =
|
|
(name: string, input: TensorView, isGlobalOperator: boolean, attributes: MaxPoolAttributes): ProgramInfo => {
|
|
const [adjustedAttributes, outputShape] =
|
|
getAdjustedPoolAttributesAndOutputShape(input, attributes, isGlobalOperator);
|
|
const op1 = `
|
|
value = max(x_val, value);
|
|
`;
|
|
const op2 = '';
|
|
const x = inputVariable('x', input.dataType, input.dims.length);
|
|
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
|
|
const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] =
|
|
getUniformAndPadInfo(outputShape, adjustedAttributes);
|
|
programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape));
|
|
return {
|
|
name,
|
|
shaderCache:
|
|
{hint: `${attributes.cacheKey};${hasPads};${pwStartEndNotZero};${phStartEndNotZero}`, inputDependencies},
|
|
getRunData: () => ({
|
|
outputs: [{dims: outputShape, dataType: input.dataType}],
|
|
dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)},
|
|
programUniforms
|
|
}),
|
|
getShaderSource: shaderHelper => generatePoolingCode(
|
|
shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, -1e5, uniforms,
|
|
hasPads, pwStartEndNotZero, phStartEndNotZero),
|
|
};
|
|
};
|
|
|
|
export const maxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
|
|
validateInputs(context.inputs);
|
|
context.compute(createMaxPoolProgramInfo('MaxPool', context.inputs[0], false, attributes));
|
|
};
|
|
|
|
export const parseMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
|
|
const storageOrder = attributes.storage_order as number;
|
|
const dilations = attributes.dilations as [number, number];
|
|
|
|
const attr = parsePoolCommonAttributes(attributes);
|
|
// TODO: support attribute 'ceil_mode' and 'storage_order'
|
|
if (storageOrder !== 0) {
|
|
throw new Error('column major storage order is not yet supported for MaxPool');
|
|
}
|
|
if (attr.ceilMode !== 0) {
|
|
throw new Error('using ceil() in shape computation is not yet supported for MaxPool');
|
|
}
|
|
const maxPoolAttributes = {storageOrder, dilations, ...attr, cacheKey: ''};
|
|
return {...maxPoolAttributes, cacheKey: createMaxPoolShaderKeyFromAttributes(maxPoolAttributes)};
|
|
};
|
|
|
|
export const parseGlobalMaxPoolAttributes = (attributes: Record<string, unknown>): MaxPoolAttributes => {
|
|
const format = attributes.format as FormatAttributes['format'];
|
|
return {format, ...globalPoolAttributes, cacheKey: format};
|
|
};
|
|
|
|
export const globalMaxPool = (context: ComputeContext, attributes: MaxPoolAttributes): void => {
|
|
validateInputs(context.inputs);
|
|
context.compute(createMaxPoolProgramInfo('GlobalMaxPool', context.inputs[0], true, attributes));
|
|
};
|