mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
### Description

Also update the op test suite.

### Motivation and Context

Previously, the *total* size in the test case `Expand - last dim is not divisible by 4` was a multiple of 4, even though the *last dimension* was not, so the bug was never caught.
149 lines
6.2 KiB
TypeScript
149 lines
6.2 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {DataType} from '../../../wasm-common';
|
|
import {TensorView} from '../../tensor-view';
|
|
import {ShapeUtil} from '../../util';
|
|
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
|
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
|
|
|
|
import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
|
|
|
|
/**
 * Attributes accepted by the Gather operator, combined with the cache key
 * contributed by `AttributeWithCacheKey`.
 */
export interface GatherAttributes extends AttributeWithCacheKey {
  /**
   * Axis of the data tensor to gather along. It is passed through
   * `ShapeUtil.normalizeAxis` together with the data tensor's rank before use.
   */
  axis: number;
}
|
|
|
|
/**
 * Ensures the Gather operator was handed exactly two input tensors.
 *
 * @param inputs - tensors supplied to the operator.
 * @throws Error when `inputs` is missing or does not contain exactly two entries.
 */
const validateInputs = (inputs: readonly TensorView[]): void => {
  const hasBothOperands = Boolean(inputs) && inputs.length === 2;
  if (hasBothOperands) {
    return;
  }
  throw new Error('Gather requires 2 inputs.');
};
|
|
|
|
/**
 * Builds the program description (shader source generator, uniforms, dispatch size and cache key)
 * for the Gather operator.
 *
 * `inputs[0]` is the data tensor and `inputs[1]` the indices tensor. The output shape is the data
 * shape with the `axis` dimension replaced by the complete indices shape. Negative index values are
 * wrapped in the shader by adding the size of the gathered dimension.
 *
 * @param inputs - the two operator inputs: data followed by indices.
 * @param attributes - parsed Gather attributes; `axis` is normalized against the data rank.
 * @returns the program description handed to `ComputeContext.compute`.
 */
const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: GatherAttributes): ProgramInfo => {
  const inputShape = inputs[0].dims;
  const indicesShape = inputs[1].dims;

  const inputRank = inputShape.length;
  const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank);

  // Output shape: data shape with dimension `axis` replaced by every dimension of the indices tensor.
  const outputShape = inputShape.slice(0);
  outputShape.splice(axis, 1, ...indicesShape);

  const axisDimLimit = inputShape[axis];
  // For bool data each shader invocation produces four output elements (see the bool branch below),
  // so the number of invocations is the element count divided by 4, rounded up.
  const components = inputs[0].dataType === DataType.bool ? 4 : 1;
  const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components);

  // For every tensor, decide whether its shape is passed via uniforms (only the rank is given to the
  // variable helper) or the concrete dims are given to the variable helper directly.
  const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length);
  const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims;
  const enableIndicesShapesUniforms = enableShapesUniforms(inputs[1].dims.length);
  const indicesShapeOrRank = enableIndicesShapesUniforms ? inputs[1].dims.length : inputs[1].dims;
  const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
  const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;

  // The first three entries correspond, in order, to the uniforms registered in getShaderSource().
  const programUniforms: ProgramUniform[] =
      [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}];
  if (enableInputShapesUniforms) {
    programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
  }
  if (enableIndicesShapesUniforms) {
    programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
  }
  if (enableOutputShapesUniforms) {
    programUniforms.push(...createTensorShapeVariables(outputShape));
  }

  const inputDependencies: ProgramInputTensorInfoDependency[] = [
    enableInputShapesUniforms ? 'rank' : 'dims',
    enableIndicesShapesUniforms ? 'rank' : 'dims',
  ];

  // The output rank (inputRank - 1 + indicesRank) can differ from both input ranks, so the output
  // may be given to outputVariable() as concrete dims while both inputs are keyed by rank only.
  // In that case the concrete output dims are appended to the cache hint so that programs built
  // for different output shapes cannot share a cached shader.
  // NOTE(review): assumes outputVariable() bakes a concrete shape into the generated WGSL when it
  // receives dims instead of a rank - confirm in ./common.
  const cacheHint =
      enableOutputShapesUniforms ? attributes.cacheKey : `${attributes.cacheKey};${outputShape.join(',')}`;

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank, components);
    const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank);
    const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank, components);

    // Emits WGSL that turns `outputIndices<x>` into `dataIndices<x>`. The suffix `x` keeps the
    // variable names unique when the snippet is emitted several times in one shader (bool branch).
    const calcDataIndices = (x: number|string): string => {
      const indicesRank = indicesShape.length;
      let calcStr = `var indicesIndices${x} = ${indices.type.indices}(0);`;
      // The indices-tensor coordinates are the output coordinates starting at position `axis`.
      for (let i = 0; i < indicesRank; i++) {
        calcStr += `${indicesRank > 1 ? `indicesIndices${x}[${i}]` : `indicesIndices${x}`} = ${
            outputShape.length > 1 ? `outputIndices${x}[uniforms.axis + ${i}]` : `outputIndices${x}`};`;
      }
      // Fetch the gathered index; a negative value is wrapped by adding the size of the axis dimension.
      calcStr += `
          var idx${x} = ${indices.getByIndices(`indicesIndices${x}`)};
          if (idx${x} < 0) {
            idx${x} = idx${x} + uniforms.axisDimLimit;
          }
          var dataIndices${x} = ${data.type.indices}(0);
        `;
      // `j` walks the output coordinates: the `axis` slot of the data tensor consumes `indicesRank`
      // output dimensions, every other data dimension maps to exactly one output dimension.
      for (let i = 0, j = 0; i < inputRank; i++) {
        if (i === axis) {
          calcStr += `${inputRank > 1 ? `dataIndices${x}[${i}]` : `dataIndices${x}`} = u32(idx${x});`;
          j += indicesRank;
        } else {
          calcStr += `${inputRank > 1 ? `dataIndices${x}[${i}]` : `dataIndices${x}`} = ${
              outputShape.length > 1 ? `outputIndices${x}[${j}]` : `outputIndices${x}`};`;
          j++;
        }
      }
      return calcStr;
    };

    let assignment: string;
    if (inputs[0].dataType === DataType.bool) {
      // bool: each invocation resolves four consecutive output elements one by one and stores them
      // as a single vec4<u32>.
      const singleAssignment = (resStr: string, x: number, typeCast = '') => `
          let outputIndices${x} = ${output.offsetToIndices(`outputOffset + ${x}u`)};
          ${calcDataIndices(x)};
          let offset${x} = ${data.indicesToOffset(`dataIndices${x}`)};
          let index${x} = offset${x} / 4u;
          let component${x} = offset${x} % 4u;
          ${resStr}[${x}] = ${typeCast}(${data.getByOffset(`index${x}`)}[component${x}]);
        `;
      assignment = `
        let outputOffset = global_idx * ${components};
        var value = vec4<u32>(0);
        ${singleAssignment('value', 0, 'u32')}
        ${singleAssignment('value', 1, 'u32')}
        ${singleAssignment('value', 2, 'u32')}
        ${singleAssignment('value', 3, 'u32')}
        ${output.setByOffset('global_idx', 'value')}
      `;
    } else {
      // Non-bool: one output element per invocation.
      assignment = `
      let outputIndices = ${output.offsetToIndices('global_idx')};
      ${calcDataIndices('')};
      let value = ${data.getByIndices('dataIndices')};
      ${output.setByOffset('global_idx', 'value')};
      `;
    }
    return `
      ${
        shaderHelper.registerUniform('outputSize', 'u32')
            .registerUniform('axisDimLimit', 'i32')
            .registerUniform('axis', 'u32')
            .declareVariables(data, indices, output)}
      ${shaderHelper.mainStart()}
        ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
        ${assignment}
      }`;
  };

  return {
    name: 'Gather',
    shaderCache: {hint: cacheHint, inputDependencies},
    getRunData: () => ({
      outputs: [
        {dims: outputShape, dataType: inputs[0].dataType},
      ],
      dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
      programUniforms
    }),
    getShaderSource,
  };
};
|
|
|
|
/**
 * Converts a raw attribute record into `GatherAttributes`.
 *
 * @param attributes - raw attribute record; its `axis` entry is treated as a number without any
 *     runtime validation.
 * @returns the value produced by `createAttributeWithCacheKey` for `{axis}`.
 */
export const parseGatherAttributes = (attributes: Record<string, unknown>): GatherAttributes => {
  const axis = attributes.axis as number;
  return createAttributeWithCacheKey({axis});
};
|
|
|
|
/**
 * Entry point of the Gather operator: validates the inputs found on `context`, builds the program
 * description for them and hands it to `context.compute`.
 *
 * @param context - compute context providing the input tensors and the `compute` method.
 * @param attributes - parsed Gather attributes.
 * @throws Error when the inputs fail validation.
 */
export const gather = (context: ComputeContext, attributes: GatherAttributes): void => {
  validateInputs(context.inputs);
  const programInfo = createGatherProgramInfo(context.inputs, attributes);
  context.compute(programInfo);
};
|