mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
[JS/Web] Added Uniforms in GatherElements. (#18670)
### Description
Use uniforms in GatherElements and clean up the implementation.
### Motivation and Context
<!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
Improve performance.
This commit is contained in:
parent
f949e0580b
commit
70816001cc
1 changed file with 26 additions and 32 deletions
|
|
@ -4,9 +4,9 @@
|
|||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, ProgramInfo} from '../types';
|
||||
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
|
||||
|
||||
import {inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
|
||||
export interface GatherElementsAttributes extends AttributeWithCacheKey {
|
||||
axis: number;
|
||||
|
|
@ -32,65 +32,59 @@ const createGatherElementsProgramInfo =
|
|||
const inputShape = inputs[0].dims;
|
||||
const inputOutputDataType = inputs[0].dataType;
|
||||
const inputRank = inputShape.length;
|
||||
const inputStrides = ShapeUtil.computeStrides(inputShape);
|
||||
const inputSize = ShapeUtil.size(inputShape);
|
||||
|
||||
const indicesShape = inputs[1].dims;
|
||||
const indicesDataType = inputs[1].dataType;
|
||||
const indicesSize = ShapeUtil.size(indicesShape);
|
||||
|
||||
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank);
|
||||
const axisDimLimit = inputShape[axis];
|
||||
|
||||
const outputShape = indicesShape.slice(0);
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
|
||||
const input = inputVariable('input', inputOutputDataType, inputShape);
|
||||
const indices = inputVariable('indices', indicesDataType, [indicesSize]);
|
||||
const output = outputVariable('output', inputOutputDataType, outputShape);
|
||||
const input = inputVariable('input', inputOutputDataType, inputRank);
|
||||
const indices = inputVariable('indicesInput', indicesDataType, indicesShape.length);
|
||||
const output = outputVariable('output', inputOutputDataType, outputShape.length);
|
||||
|
||||
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}];
|
||||
programUniforms.push(...createTensorShapeVariables(inputShape));
|
||||
programUniforms.push(...createTensorShapeVariables(indicesShape));
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
|
||||
|
||||
// int64 indices would be treated as little endian i32 with assumption they fall in i32 limits
|
||||
// That assumption is safe as it's not possible to allocate >2gb buffer for input tensor
|
||||
// Input data will be treated as u32 or two u32 for 8-byte tensors
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
const inputStrides = array<u32, ${inputStrides.length}>(${inputStrides.map(i => `${i}u`).join(',')});
|
||||
${shaderHelper.declareVariables(input, indices, output)}
|
||||
${
|
||||
shaderHelper.registerUniform('outputSize', 'u32')
|
||||
.registerUniform('axisDimLimit', 'i32')
|
||||
.registerUniform('axis', 'u32')
|
||||
.declareVariables(input, indices, output)}
|
||||
${shaderHelper.mainStart()}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
|
||||
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
|
||||
|
||||
let outputIndices = ${output.offsetToIndices('global_idx')};
|
||||
|
||||
var idx = ${indices.getByOffset('global_idx')};
|
||||
if (idx < 0) {
|
||||
idx = idx + ${axisDimLimit};
|
||||
idx = idx + uniforms.axisDimLimit;
|
||||
}
|
||||
var inputIndices = ${input.type.indices}(outputIndices);
|
||||
${input.indicesSet('inputIndices', 'uniforms.axis', 'u32(idx)')};
|
||||
let value = ${input.getByIndices('inputIndices')};
|
||||
|
||||
var srcOffset = u32(0);
|
||||
|
||||
for (var i = 0; i < ${inputShape.length}; i++) {
|
||||
if (i == ${axis}) {
|
||||
srcOffset += u32(idx) * inputStrides[i];
|
||||
} else {
|
||||
srcOffset += ${output.indicesGet('outputIndices', 'i')} * inputStrides[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Should never hit this with valid values in indices
|
||||
// This is a guard against malicious data in the indices input
|
||||
if (srcOffset < 0 || srcOffset >= ${inputSize}) {
|
||||
return;
|
||||
}
|
||||
|
||||
output[global_idx] = input[srcOffset];
|
||||
${output.setByOffset('global_idx', 'value')};
|
||||
}`;
|
||||
|
||||
return {
|
||||
name: 'GatherElements',
|
||||
shaderCache: {hint: attributes.cacheKey},
|
||||
shaderCache: {inputDependencies},
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
programUniforms
|
||||
}),
|
||||
getShaderSource,
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue