mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-12 00:59:23 +00:00
[js/webgpu] Remove enableShapesUniforms (#19279)
This commit is contained in:
parent
00d048121b
commit
624b4e2063
9 changed files with 68 additions and 134 deletions
|
|
@ -443,9 +443,9 @@ export const createMatmulProgramInfo =
|
|||
|
||||
const components = isVec4 ? 4 : 1;
|
||||
const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components];
|
||||
const aShapeOrRank = aShapeTemp.length;
|
||||
const aRank = aShapeTemp.length;
|
||||
const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components];
|
||||
const bShapeOrRank = bShapeTemp.length;
|
||||
const bRank = bShapeTemp.length;
|
||||
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
|
||||
|
|
@ -467,12 +467,12 @@ export const createMatmulProgramInfo =
|
|||
programUniforms.push(...createTensorShapeVariables(outputShapeTemp));
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const batchShapeOrRank = outerDims.length;
|
||||
const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1);
|
||||
const batchRank = outerDims.length;
|
||||
const batchDims = internalVariable('batchDims', inputs[0].dataType, batchRank, 1);
|
||||
const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
|
||||
|
||||
const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components);
|
||||
const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components);
|
||||
const A = inputVariable('a', inputs[0].dataType, aRank, components);
|
||||
const B = inputVariable('b', inputs[1].dataType, bRank, components);
|
||||
const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components);
|
||||
const inputVariables = [A, B];
|
||||
if (hasBias) {
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import {ShapeUtil} from '../../util';
|
|||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, ProgramInfo} from '../types';
|
||||
|
||||
import {createTensorShapeVariables, enableShapesUniforms, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
|
||||
export interface BatchNormAttributes extends AttributeWithCacheKey {
|
||||
readonly epsilon: number;
|
||||
|
|
@ -61,7 +61,7 @@ const createBatchNormInferenceProgramInfo =
|
|||
const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1;
|
||||
const outputSize = ShapeUtil.size(yShape) / components;
|
||||
// Only support uniforms for opset version >= 9 (spatial = true).
|
||||
const useShapesUniforms = enableShapesUniforms(yShape.length) && spatial;
|
||||
const useShapesUniforms = spatial;
|
||||
const shapeOrRank = useShapesUniforms ? yShape.length : yShape;
|
||||
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components);
|
||||
const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents);
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
|
|||
import {BroadcastUtil, ShapeUtil} from '../../util';
|
||||
import {ComputeContext, ProgramInfo} from '../types';
|
||||
|
||||
import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
|
||||
type BuiltinFunctionName = string;
|
||||
type BinaryCustomExpression = (expressionA: string, expressionB: string) => string;
|
||||
|
|
@ -18,8 +18,7 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{
|
|||
const createBinaryOpProgramShader =
|
||||
(shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[],
|
||||
vectorize: boolean, doBroadcast: boolean, sharedDimensionDivisibleBy4: boolean, funcCall: BinaryFunctionCall,
|
||||
typeA: number, typeB: number, typeOutput: number, useShapesUniforms: boolean,
|
||||
additionalImplementation?: string) => {
|
||||
typeA: number, typeB: number, typeOutput: number, additionalImplementation?: string) => {
|
||||
let expressionScalar: BinaryCustomExpression;
|
||||
let expressionVector: BinaryCustomExpression;
|
||||
if (typeof funcCall === 'string') {
|
||||
|
|
@ -31,12 +30,9 @@ const createBinaryOpProgramShader =
|
|||
expressionVector = funcCall.vector;
|
||||
}
|
||||
|
||||
const inputAShapeOrRank = useShapesUniforms ? dimsA.length : dimsA;
|
||||
const inputBShapeOrRank = useShapesUniforms ? dimsB.length : dimsB;
|
||||
const outputShapeOrRank = useShapesUniforms ? dimsOutput.length : dimsOutput;
|
||||
const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4);
|
||||
const a = inputVariable('aData', typeA, inputAShapeOrRank, 4);
|
||||
const b = inputVariable('bData', typeB, inputBShapeOrRank, 4);
|
||||
const output = outputVariable('outputData', typeOutput, dimsOutput.length, 4);
|
||||
const a = inputVariable('aData', typeA, dimsA.length, 4);
|
||||
const b = inputVariable('bData', typeB, dimsB.length, 4);
|
||||
|
||||
let assignment: string;
|
||||
if (vectorize) {
|
||||
|
|
@ -169,30 +165,25 @@ const createBinaryOpProgramInfo =
|
|||
vectorize = true;
|
||||
}
|
||||
cacheKeyAux.push(vectorize);
|
||||
const useShapesUniforms = enableShapesUniforms(a.dims.length) && enableShapesUniforms(b.dims.length) &&
|
||||
enableShapesUniforms(outputShape.length);
|
||||
|
||||
return {
|
||||
name,
|
||||
shaderCache: {
|
||||
hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'),
|
||||
inputDependencies: useShapesUniforms ? ['rank', 'rank'] : ['dims', 'dims'],
|
||||
inputDependencies: ['rank', 'rank'],
|
||||
},
|
||||
getShaderSource: (shaderHelper) => createBinaryOpProgramShader(
|
||||
shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall,
|
||||
a.dataType, b.dataType, outputDataType, useShapesUniforms, additionalImplementation),
|
||||
a.dataType, b.dataType, outputDataType, additionalImplementation),
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: outputDataType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)},
|
||||
programUniforms: useShapesUniforms ?
|
||||
[
|
||||
{type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
|
||||
...createTensorShapeVariables(a.dims),
|
||||
...createTensorShapeVariables(b.dims),
|
||||
...createTensorShapeVariables(outputShape),
|
||||
] :
|
||||
[
|
||||
{type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
|
||||
],
|
||||
programUniforms: [
|
||||
{type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
|
||||
...createTensorShapeVariables(a.dims),
|
||||
...createTensorShapeVariables(b.dims),
|
||||
...createTensorShapeVariables(outputShape),
|
||||
],
|
||||
}),
|
||||
};
|
||||
};
|
||||
|
|
|
|||
|
|
@ -922,6 +922,3 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly
|
|||
}
|
||||
return dims;
|
||||
};
|
||||
|
||||
// TODO: remove this when all related uses have been removed.
|
||||
export const enableShapesUniforms = (_rank: number): boolean => true;
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util';
|
|||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
|
||||
|
||||
import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
|
||||
export interface ConcatAttributes extends AttributeWithCacheKey {
|
||||
readonly axis: number;
|
||||
|
|
@ -94,32 +94,22 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
|
|||
|
||||
let previousSum = 0;
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
|
||||
const inputShapeOrRanks = [];
|
||||
const enableInputShapesUniforms = [];
|
||||
const inputRanks = [];
|
||||
const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}];
|
||||
for (let i = 0; i < inputs.length; ++i) {
|
||||
previousSum += inputs[i].dims[adjustedAxis];
|
||||
sizeInConcatAxis[i] = previousSum;
|
||||
enableInputShapesUniforms.push(enableShapesUniforms(inputs[i].dims.length));
|
||||
inputShapeOrRanks.push(enableInputShapesUniforms[i] ? inputs[i].dims.length : inputs[i].dims);
|
||||
inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]);
|
||||
inputDependencies.push(enableInputShapesUniforms[i] ? 'rank' : 'dims');
|
||||
inputRanks.push(inputs[i].dims.length);
|
||||
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
|
||||
inputDependencies.push('rank');
|
||||
programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]});
|
||||
}
|
||||
for (let i = 0; i < inputs.length; ++i) {
|
||||
if (enableInputShapesUniforms[i]) {
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
|
||||
}
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
|
||||
}
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
|
||||
const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
|
||||
if (enableOutputShapesUniforms) {
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
}
|
||||
|
||||
const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;
|
||||
const output = outputVariable('output', dataType, outputShapeOrRank);
|
||||
|
||||
const output = outputVariable('output', dataType, outputShape.length);
|
||||
const indicesAxis = output.indicesGet('indices', adjustedAxis);
|
||||
const sizeInConcatAxisStr =
|
||||
Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(',');
|
||||
|
|
|
|||
|
|
@ -6,8 +6,7 @@ import {ShapeUtil} from '../../util';
|
|||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
|
||||
|
||||
import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
|
||||
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
|
||||
export interface EinsumAttributes extends AttributeWithCacheKey {
|
||||
readonly equation: string;
|
||||
|
|
@ -181,14 +180,12 @@ class EinsumEquation {
|
|||
const appendMax = (name: string): string => name + '_max';
|
||||
|
||||
const createEinsumProgramInfo =
|
||||
(enableInputShapesUniforms: readonly boolean[], inputShapes: Array<readonly number[]>, dataType: number,
|
||||
einsumEquation: EinsumEquation, outputShape: readonly number[]): ProgramInfo => {
|
||||
const shapeOrRanks = inputShapes.map((dims, index) => enableInputShapesUniforms[index] ? dims.length : dims);
|
||||
const inputVars = shapeOrRanks.map((shapeOrRank, index) => inputVariable(`input${index}`, dataType, shapeOrRank));
|
||||
(inputShapes: Array<readonly number[]>, dataType: number, einsumEquation: EinsumEquation,
|
||||
outputShape: readonly number[]): ProgramInfo => {
|
||||
const ranks = inputShapes.map((dims) => dims.length);
|
||||
const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank));
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
|
||||
const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;
|
||||
const output = outputVariable('output', dataType, outputShapeOrRank);
|
||||
const output = outputVariable('output', dataType, outputShape.length);
|
||||
const uniformsSymbols =
|
||||
[...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol));
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
|
|
@ -269,10 +266,7 @@ const createEinsumProgramInfo =
|
|||
};
|
||||
return {
|
||||
name: 'Einsum',
|
||||
shaderCache: {
|
||||
hint: einsumEquation.equation,
|
||||
inputDependencies: enableInputShapesUniforms.map((enableShapeUniform) => enableShapeUniform ? 'rank' : 'dims')
|
||||
},
|
||||
shaderCache: {hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank')},
|
||||
getRunData: () => {
|
||||
// The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The
|
||||
// filter is added to make sure that dimValue is never 0.
|
||||
|
|
@ -281,12 +275,9 @@ const createEinsumProgramInfo =
|
|||
.map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0}));
|
||||
programUniformsInit.push({type: 'uint32', data: outputSize});
|
||||
const programUniforms: ProgramUniform[] =
|
||||
inputShapes.filter((_, index) => enableInputShapesUniforms[index])
|
||||
.map((dims, _) => [...createTensorShapeVariables(dims)])
|
||||
inputShapes.map((dims, _) => [...createTensorShapeVariables(dims)])
|
||||
.reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit);
|
||||
if (enableOutputShapesUniforms) {
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
}
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
return ({
|
||||
outputs: [{dims: outputShape, dataType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
|
|
@ -299,11 +290,9 @@ const createEinsumProgramInfo =
|
|||
|
||||
export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => {
|
||||
const einsumEquation = new EinsumEquation(context.inputs, attributes.equation);
|
||||
const enableInputShapesUniforms = context.inputs.map((input, _) => enableShapesUniforms(input.dims.length));
|
||||
const outputShape = einsumEquation.outputDims;
|
||||
const inputShapes = context.inputs.map((input, _) => input.dims);
|
||||
context.compute(createEinsumProgramInfo(
|
||||
enableInputShapesUniforms, inputShapes, context.inputs[0].dataType, einsumEquation, outputShape));
|
||||
context.compute(createEinsumProgramInfo(inputShapes, context.inputs[0].dataType, einsumEquation, outputShape));
|
||||
};
|
||||
|
||||
export const parseEinsumAttributes = (attributes: Record<string, unknown>): EinsumAttributes => {
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
|
|||
import {ShapeUtil} from '../../util';
|
||||
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
|
||||
|
||||
import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
|
||||
const validateInputs = (inputs: readonly TensorView[]): void => {
|
||||
if (!inputs || inputs.length !== 2) {
|
||||
|
|
@ -49,15 +49,9 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
|
|||
const components = dataType === DataType.bool ? 4 : 1;
|
||||
const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components);
|
||||
|
||||
const enableInputShapeUniform = enableShapesUniforms(inputShape.length);
|
||||
const enableOutputShapeUniform = enableShapesUniforms(outputShape.length);
|
||||
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape;
|
||||
const outputShapeOrRank = enableOutputShapeUniform ? outputShape.length : outputShape;
|
||||
const input = inputVariable('input', dataType, inputShapeOrRank, components);
|
||||
const output = outputVariable('output', dataType, outputShapeOrRank, components);
|
||||
const input = inputVariable('input', dataType, inputShape.length, components);
|
||||
const output = outputVariable('output', dataType, outputShape.length, components);
|
||||
let assignment: string;
|
||||
if (dataType === DataType.bool) {
|
||||
const singleAssignment = (resStr: string, x: number, typeCast = '') => `
|
||||
|
|
@ -90,16 +84,13 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
|
|||
${assignment}`;
|
||||
};
|
||||
|
||||
const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}];
|
||||
if (enableInputShapeUniform) {
|
||||
programUniforms.push(...createTensorShapeVariables(inputShape));
|
||||
}
|
||||
if (enableOutputShapeUniform) {
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
}
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape),
|
||||
...createTensorShapeVariables(outputShape)
|
||||
];
|
||||
return {
|
||||
name: 'Expand',
|
||||
shaderCache: {hint: `${outputShape.length}`, inputDependencies: [enableInputShapeUniform ? 'rank' : 'dims']},
|
||||
shaderCache: {hint: `${outputShape.length}`, inputDependencies: ['rank']},
|
||||
getShaderSource,
|
||||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common';
|
|||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
|
||||
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
|
||||
|
||||
import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
|
||||
export interface GatherAttributes extends AttributeWithCacheKey {
|
||||
axis: number;
|
||||
|
|
@ -33,33 +33,16 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath
|
|||
const components = inputs[0].dataType === DataType.bool ? 4 : 1;
|
||||
const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components);
|
||||
|
||||
const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length);
|
||||
const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims;
|
||||
const enableIndicesShapesUniforms = enableShapesUniforms(inputs[1].dims.length);
|
||||
const indicesShapeOrRank = enableIndicesShapesUniforms ? inputs[1].dims.length : inputs[1].dims;
|
||||
const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
|
||||
const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;
|
||||
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}];
|
||||
if (enableInputShapesUniforms) {
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
|
||||
}
|
||||
if (enableIndicesShapesUniforms) {
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
|
||||
}
|
||||
if (enableOutputShapesUniforms) {
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
}
|
||||
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
|
||||
inputDependencies.push(enableInputShapesUniforms ? 'rank' : 'dims');
|
||||
inputDependencies.push(enableIndicesShapesUniforms ? 'rank' : 'dims');
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis},
|
||||
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims),
|
||||
...createTensorShapeVariables(outputShape)
|
||||
];
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank, components);
|
||||
const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank);
|
||||
const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank, components);
|
||||
const data = inputVariable('data', inputs[0].dataType, inputs[0].dims.length, components);
|
||||
const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims.length);
|
||||
const output = outputVariable('output', inputs[0].dataType, outputShape.length, components);
|
||||
|
||||
const calcDataIndices = (x: number|string): string => {
|
||||
const indicesRank = indicesShape.length;
|
||||
|
|
@ -127,7 +110,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath
|
|||
};
|
||||
return {
|
||||
name: 'Gather',
|
||||
shaderCache: {hint: attributes.cacheKey, inputDependencies},
|
||||
shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank', 'rank']},
|
||||
getRunData: () => ({
|
||||
outputs: [
|
||||
{dims: outputShape, dataType: inputs[0].dataType},
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util';
|
|||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
import {ComputeContext, ProgramInfo} from '../types';
|
||||
|
||||
import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
|
||||
|
||||
export interface TransposeAttributes extends AttributeWithCacheKey {
|
||||
readonly perm: number[];
|
||||
|
|
@ -39,12 +39,9 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
|
|||
const inputDataType = inputTensor.dataType;
|
||||
const inputRank = inputTensor.dims.length;
|
||||
const perm = getAdjustedPerm(inputRank, permAttr);
|
||||
const useShapesUniforms = enableShapesUniforms(inputRank);
|
||||
const outputShape = getOutputShape(inputTensor.dims, perm);
|
||||
const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape;
|
||||
const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims;
|
||||
const output = outputVariable('output', inputDataType, outShapeOrRank);
|
||||
const input = inputVariable('a', inputDataType, inShapeOrRank);
|
||||
const output = outputVariable('output', inputDataType, outputShape.length);
|
||||
const input = inputVariable('a', inputDataType, inputRank);
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
|
||||
|
|
@ -61,21 +58,17 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
|
|||
}`;
|
||||
return {
|
||||
name: 'Transpose',
|
||||
shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? ['rank'] : ['dims']},
|
||||
shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']},
|
||||
getRunData: (inputs) => {
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
return {
|
||||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
programUniforms: useShapesUniforms ?
|
||||
[
|
||||
{type: 'uint32', data: outputSize},
|
||||
...createTensorShapeVariables(inputs[0].dims),
|
||||
...createTensorShapeVariables(outputShape),
|
||||
] :
|
||||
[
|
||||
{type: 'uint32', data: outputSize},
|
||||
],
|
||||
programUniforms: [
|
||||
{type: 'uint32', data: outputSize},
|
||||
...createTensorShapeVariables(inputs[0].dims),
|
||||
...createTensorShapeVariables(outputShape),
|
||||
],
|
||||
};
|
||||
},
|
||||
getShaderSource,
|
||||
|
|
|
|||
Loading…
Reference in a new issue