mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-15 01:23:42 +00:00
[js/webgpu] Use DataType as uniform cpu type (#19281)
This saves turning data type to string by tensorDataTypeEnumToString.
This commit is contained in:
parent
85cef0af8c
commit
d73131cf0f
37 changed files with 148 additions and 108 deletions
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
import {Env, Tensor, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';
|
||||
|
||||
import {tensorDataTypeEnumToString} from '../wasm-common';
|
||||
import {DataType, tensorDataTypeEnumToString} from '../wasm-common';
|
||||
|
||||
import {configureLogger, LOG_DEBUG} from './log';
|
||||
import {createView, TensorView} from './tensor-view';
|
||||
|
|
@ -453,10 +453,10 @@ export class WebGpuBackend {
|
|||
return;
|
||||
}
|
||||
// https://www.w3.org/TR/WGSL/#alignof
|
||||
const sizeOfElement = v.type === 'float16' ? 2 : 4;
|
||||
const sizeOfElement = v.type === DataType.float16 ? 2 : 4;
|
||||
let sizeOfVecOrMat;
|
||||
let baseAlignment;
|
||||
if (v.type === 'float16') {
|
||||
if (v.type === DataType.float16) {
|
||||
baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement);
|
||||
sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length;
|
||||
} else {
|
||||
|
|
@ -470,7 +470,7 @@ export class WebGpuBackend {
|
|||
// SizeOf(vec4<i32|u32|f32>). For float16 type, when data.length > 4, the uniform variable is of type
|
||||
// array<mat2x4<f16>,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4<f16>) = 16. The total byte
|
||||
// length is N * SizeOf(mat2x4<f16>).
|
||||
const elementPerVecOrMat = v.type === 'float16' ? 8 : 4;
|
||||
const elementPerVecOrMat = v.type === DataType.float16 ? 8 : 4;
|
||||
currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat :
|
||||
data.length * sizeOfElement;
|
||||
});
|
||||
|
|
@ -483,15 +483,17 @@ export class WebGpuBackend {
|
|||
programUniforms.forEach((v, i) => {
|
||||
const offset = offsets[i];
|
||||
const data = typeof v.data === 'number' ? [v.data] : v.data;
|
||||
if (v.type === 'int32') {
|
||||
if (v.type === DataType.int32) {
|
||||
new Int32Array(arrayBuffer, offset, data.length).set(data);
|
||||
} else if (v.type === 'uint32') {
|
||||
} else if (v.type === DataType.uint32) {
|
||||
new Uint32Array(arrayBuffer, offset, data.length).set(data);
|
||||
} else if (v.type === 'float16') {
|
||||
} else if (v.type === DataType.float16) {
|
||||
// TODO: use Float16Array.
|
||||
new Uint16Array(arrayBuffer, offset, data.length).set(data);
|
||||
} else {
|
||||
} else if (v.type === DataType.float) {
|
||||
new Float32Array(arrayBuffer, offset, data.length).set(data);
|
||||
} else {
|
||||
throw new Error(`Unsupported uniform type: ${tensorDataTypeEnumToString(v.type)}`);
|
||||
}
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
//
|
||||
// modified to fit the needs of the project
|
||||
|
||||
import {DataType} from '../../../../wasm-common';
|
||||
import {LOG_DEBUG} from '../../../log';
|
||||
import {TensorView} from '../../../tensor-view';
|
||||
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
|
||||
|
|
@ -189,9 +190,9 @@ export const createConv2DMatMulProgramInfo =
|
|||
const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1];
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
|
||||
{type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides},
|
||||
{type: 'int32', data: attributes.dilations}
|
||||
{type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter},
|
||||
{type: DataType.int32, data: dimInner}, {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]},
|
||||
{type: DataType.int32, data: attributes.strides}, {type: DataType.int32, data: attributes.dilations}
|
||||
];
|
||||
appendActivationUniformsData(attributes, programUniforms);
|
||||
programUniforms.push(
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
//
|
||||
// modified to fit the needs of the project
|
||||
|
||||
import {DataType} from '../../../../wasm-common';
|
||||
import {LOG_DEBUG} from '../../../log';
|
||||
import {TensorView} from '../../../tensor-view';
|
||||
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
|
||||
|
|
@ -197,9 +198,10 @@ export const createConv2DTransposeMatMulProgramInfo =
|
|||
];
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner},
|
||||
{type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations},
|
||||
{type: 'int32', data: filterDims}, {type: 'int32', data: pads}
|
||||
{type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter},
|
||||
{type: DataType.int32, data: dimInner}, {type: DataType.int32, data: attributes.strides},
|
||||
{type: DataType.int32, data: attributes.dilations}, {type: DataType.int32, data: filterDims},
|
||||
{type: DataType.int32, data: pads}
|
||||
];
|
||||
appendActivationUniformsData(attributes, programUniforms);
|
||||
programUniforms.push(
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_webgpu.ts
|
||||
|
||||
import {DataType} from '../../../../wasm-common';
|
||||
import {LOG_DEBUG} from '../../../log';
|
||||
import {TensorView} from '../../../tensor-view';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
|
|
@ -264,9 +265,10 @@ export const createConvTranspose2DProgramInfo =
|
|||
const outputChannelsPerGroup = wShape[1];
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'int32', data: outputSize}, {type: 'uint32', data: strides}, {type: 'uint32', data: filterDims},
|
||||
{type: 'uint32', data: dilations}, {type: 'uint32', data: effectiveFilterDims}, {type: 'int32', data: pads},
|
||||
{type: 'uint32', data: inputChannelsPerGroup}, {type: 'uint32', data: outputChannelsPerGroup},
|
||||
{type: DataType.int32, data: outputSize}, {type: DataType.uint32, data: strides},
|
||||
{type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: dilations},
|
||||
{type: DataType.uint32, data: effectiveFilterDims}, {type: DataType.int32, data: pads},
|
||||
{type: DataType.uint32, data: inputChannelsPerGroup}, {type: DataType.uint32, data: outputChannelsPerGroup},
|
||||
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)
|
||||
];
|
||||
if (hasBias) {
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
//
|
||||
// modified to fit the needs of the project
|
||||
|
||||
import {DataType} from '../../../../wasm-common';
|
||||
import {TensorView} from '../../../tensor-view';
|
||||
import {ShapeUtil} from '../../../util';
|
||||
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
|
||||
|
|
@ -447,8 +448,10 @@ export const createMatmulProgramInfo =
|
|||
const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components];
|
||||
const bRank = bShapeTemp.length;
|
||||
const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components];
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter},
|
||||
{type: DataType.int32, data: dimInner}
|
||||
];
|
||||
appendActivationUniformsData(activationAttributes, programUniforms);
|
||||
programUniforms.push(
|
||||
...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp),
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {tensorDataTypeEnumToString} from '../../../wasm-common';
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ComputeContext, GpuDataType, ProgramUniform} from '../types';
|
||||
|
||||
|
|
@ -241,9 +241,10 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
|
|||
WG = Math.ceil(dComp / 8);
|
||||
}
|
||||
const elementsPerWG = Math.ceil(d / components / WG);
|
||||
const tensorDataType = tensorDataTypeEnumToString(input.dataType) as ProgramUniform['type'];
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: tensorDataType, data: 1 / d}, {type: 'uint32', data: dComp}, {type: 'uint32', data: elementsPerWG}];
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: input.dataType, data: 1 / d}, {type: DataType.uint32, data: dComp},
|
||||
{type: DataType.uint32, data: elementsPerWG}
|
||||
];
|
||||
const dataType = tensorTypeToWsglStorageType(input.dataType, components);
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
|
|
@ -336,11 +337,10 @@ const computeAttentionProbs =
|
|||
y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
|
||||
z: parameters.batchSize * parameters.numHeads
|
||||
};
|
||||
const tensorDataType = tensorDataTypeEnumToString(q.dataType) as ProgramUniform['type'];
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: parameters.sequenceLength}, {type: 'uint32', data: vectorizedHeadSize},
|
||||
{type: 'uint32', data: parameters.totalSequenceLength}, {type: 'uint32', data: parameters.kvSequenceLength},
|
||||
{type: tensorDataType, data: alpha}
|
||||
{type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize},
|
||||
{type: DataType.uint32, data: parameters.totalSequenceLength},
|
||||
{type: DataType.uint32, data: parameters.kvSequenceLength}, {type: q.dataType, data: alpha}
|
||||
];
|
||||
|
||||
const inputs = [q, key];
|
||||
|
|
@ -430,9 +430,9 @@ const computeVxAttentionScore =
|
|||
z: params.batchSize * params.numHeads
|
||||
};
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: params.sequenceLength}, {type: 'uint32', data: params.totalSequenceLength},
|
||||
{type: 'uint32', data: params.vHeadSize}, {type: 'uint32', data: params.numHeads},
|
||||
{type: 'uint32', data: params.vHiddenSize}
|
||||
{type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength},
|
||||
{type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
|
||||
{type: DataType.uint32, data: params.vHiddenSize}
|
||||
];
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
|
|
@ -526,10 +526,10 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
|
|||
};
|
||||
const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]];
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: M}, {type: 'uint32', data: K}, {type: 'uint32', data: N},
|
||||
{type: 'uint32', data: parameters.numHeads}, {type: 'uint32', data: parameters.headSize},
|
||||
{type: 'uint32', data: parameters.hiddenSize},
|
||||
{type: 'uint32', data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}
|
||||
{type: DataType.uint32, data: M}, {type: DataType.uint32, data: K}, {type: DataType.uint32, data: N},
|
||||
{type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.headSize},
|
||||
{type: DataType.uint32, data: parameters.hiddenSize},
|
||||
{type: DataType.uint32, data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}
|
||||
];
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
import {env} from 'onnxruntime-common';
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
|
|
@ -123,11 +124,11 @@ const createBatchNormInferenceProgramInfo =
|
|||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
programUniforms: useShapesUniforms ?
|
||||
[
|
||||
{type: 'uint32', data: outputSize},
|
||||
{type: DataType.uint32, data: outputSize},
|
||||
...createTensorShapeVariables(yShape),
|
||||
] :
|
||||
[
|
||||
{type: 'uint32', data: outputSize},
|
||||
{type: DataType.uint32, data: outputSize},
|
||||
],
|
||||
}),
|
||||
};
|
||||
|
|
|
|||
|
|
@ -179,7 +179,7 @@ const createBinaryOpProgramInfo =
|
|||
outputs: [{dims: outputShape, dataType: outputDataType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)},
|
||||
programUniforms: [
|
||||
{type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
|
||||
{type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)},
|
||||
...createTensorShapeVariables(a.dims),
|
||||
...createTensorShapeVariables(b.dims),
|
||||
...createTensorShapeVariables(outputShape),
|
||||
|
|
|
|||
|
|
@ -259,8 +259,9 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 =
|
|||
return typeof mappedType === 'string' ? mappedType : mappedType[1];
|
||||
};
|
||||
|
||||
export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] =>
|
||||
dims.length === 0 ? [] : [{type: 'uint32', data: dims}, {type: 'uint32', data: ShapeUtil.computeStrides(dims)}];
|
||||
export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => dims.length === 0 ?
|
||||
[] :
|
||||
[{type: DataType.uint32, data: dims}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dims)}];
|
||||
|
||||
/**
|
||||
* A helper function to get maximum vector size for specified data length
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
|
|
@ -95,14 +96,14 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P
|
|||
let previousSum = 0;
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = [];
|
||||
const inputRanks = [];
|
||||
const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}];
|
||||
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}];
|
||||
for (let i = 0; i < inputs.length; ++i) {
|
||||
previousSum += inputs[i].dims[adjustedAxis];
|
||||
sizeInConcatAxis[i] = previousSum;
|
||||
inputRanks.push(inputs[i].dims.length);
|
||||
inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]);
|
||||
inputDependencies.push('rank');
|
||||
programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]});
|
||||
programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]});
|
||||
}
|
||||
for (let i = 0; i < inputs.length; ++i) {
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[i].dims));
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
|
||||
|
|
@ -28,9 +29,10 @@ export const createGroupedConvProgramInfo =
|
|||
const outputSize = ShapeUtil.size(outputShape);
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.dilations},
|
||||
{type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]},
|
||||
{type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup}
|
||||
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.dilations},
|
||||
{type: DataType.uint32, data: [attributes.strides[0], attributes.strides[1]]},
|
||||
{type: DataType.uint32, data: [attributes.pads[0], attributes.pads[1]]},
|
||||
{type: DataType.uint32, data: outputChannelsPerGroup}
|
||||
];
|
||||
appendActivationUniformsData(attributes, programUniforms);
|
||||
programUniforms.push(
|
||||
|
|
@ -127,8 +129,9 @@ export const createGroupedConvVectorizeProgramInfo =
|
|||
const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components];
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: outputSize}, {type: 'int32', data: [attributes.strides[0], attributes.strides[1]]},
|
||||
{type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}
|
||||
{type: DataType.uint32, data: outputSize},
|
||||
{type: DataType.int32, data: [attributes.strides[0], attributes.strides[1]]},
|
||||
{type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]}
|
||||
];
|
||||
appendActivationUniformsData(attributes, programUniforms);
|
||||
programUniforms.push(
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ const createCumsumProgramInfo =
|
|||
outputs: [{dims: inputShape, dataType: inputType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
programUniforms: [
|
||||
{type: 'uint32', data: outputSize}, {type: 'int32', data: axis},
|
||||
{type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axis},
|
||||
...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape)
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
|
|
@ -272,8 +273,10 @@ const createEinsumProgramInfo =
|
|||
// filter is added to make sure that dimValue is never 0.
|
||||
const programUniformsInit: ProgramUniform[] =
|
||||
uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol))
|
||||
.map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0}));
|
||||
programUniformsInit.push({type: 'uint32', data: outputSize});
|
||||
.map(
|
||||
(symbol) =>
|
||||
({type: DataType.uint32, data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0}));
|
||||
programUniformsInit.push({type: DataType.uint32, data: outputSize});
|
||||
const programUniforms: ProgramUniform[] =
|
||||
inputShapes.map((dims, _) => [...createTensorShapeVariables(dims)])
|
||||
.reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit);
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
|
|||
};
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape),
|
||||
{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape),
|
||||
...createTensorShapeVariables(outputShape)
|
||||
];
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {MAX_CLIP, MIN_CLIP} from '../../util';
|
||||
import {ProgramUniform} from '../types';
|
||||
|
||||
|
|
@ -36,9 +37,11 @@ export const getActivationSnippet = (attributes: InternalActivationAttributes, v
|
|||
export const appendActivationUniformsData =
|
||||
(attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => {
|
||||
if (attributes.activation === 'Clip') {
|
||||
programUniform.push({type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!});
|
||||
programUniform.push(
|
||||
{type: DataType.float, data: attributes.clipMax!}, {type: DataType.float, data: attributes.clipMin!});
|
||||
} else if (attributes.activation === 'HardSigmoid') {
|
||||
programUniform.push({type: 'float32', data: attributes.alpha!}, {type: 'float32', data: attributes.beta!});
|
||||
programUniform.push(
|
||||
{type: DataType.float, data: attributes.alpha!}, {type: DataType.float, data: attributes.beta!});
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
|
|
@ -46,8 +47,10 @@ const createGatherElementsProgramInfo =
|
|||
const output = outputVariable('output', inputOutputDataType, outputShape.length);
|
||||
|
||||
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}];
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit},
|
||||
{type: DataType.uint32, data: axis}
|
||||
];
|
||||
programUniforms.push(...createTensorShapeVariables(inputShape));
|
||||
programUniforms.push(...createTensorShapeVariables(indicesShape));
|
||||
programUniforms.push(...createTensorShapeVariables(outputShape));
|
||||
|
|
|
|||
|
|
@ -34,9 +34,9 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath
|
|||
const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components);
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis},
|
||||
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims),
|
||||
...createTensorShapeVariables(outputShape)
|
||||
{type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit},
|
||||
{type: DataType.uint32, data: axis}, ...createTensorShapeVariables(inputs[0].dims),
|
||||
...createTensorShapeVariables(inputs[1].dims), ...createTensorShapeVariables(outputShape)
|
||||
];
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {GemmUtil, ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
|
|
@ -45,8 +46,9 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt
|
|||
}
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, {type: 'uint32', data: K},
|
||||
{type: 'float32', data: attributes.alpha}, {type: 'float32', data: attributes.beta}
|
||||
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N},
|
||||
{type: DataType.uint32, data: K}, {type: DataType.float, data: attributes.alpha},
|
||||
{type: DataType.float, data: attributes.beta}
|
||||
];
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
|
||||
if (inputs.length === 3) {
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ const createInstanceNormProgramInfo =
|
|||
const inputShape = [xShape[0], xShape[1], normPackedSize];
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type'];
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'uint32', data: normSize}, {type: 'uint32', data: normPackedSize}];
|
||||
[{type: DataType.uint32, data: normSize}, {type: DataType.uint32, data: normPackedSize}];
|
||||
programUniforms.push(...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape));
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
|
|
@ -132,8 +132,9 @@ const computeMean =
|
|||
|
||||
const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type'];
|
||||
const meanProgramUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: wgSize}, {type: 'uint32', data: h}, {type: 'uint32', data: Math.floor(c / components)},
|
||||
{type: 'uint32', data: Math.floor(h * c / components)}
|
||||
{type: DataType.uint32, data: wgSize}, {type: DataType.uint32, data: h},
|
||||
{type: DataType.uint32, data: Math.floor(c / components)},
|
||||
{type: DataType.uint32, data: Math.floor(h * c / components)}
|
||||
];
|
||||
|
||||
const getMeanShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
|
|
@ -182,8 +183,9 @@ const computeMean =
|
|||
{inputs: [input], outputs: [-1]})[0];
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: unitsOfWork}, {type: 'uint32', data: h},
|
||||
{type: 'uint32', data: Math.floor(c / components)}, {type: 'uint32', data: Math.floor(WG * c / components)}
|
||||
{type: DataType.uint32, data: unitsOfWork}, {type: DataType.uint32, data: h},
|
||||
{type: DataType.uint32, data: Math.floor(c / components)},
|
||||
{type: DataType.uint32, data: Math.floor(WG * c / components)}
|
||||
];
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type'];
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
|
|
@ -246,7 +248,7 @@ const createInstanceNormNHWCProgramInfo =
|
|||
const components = getMaxComponents(C);
|
||||
const outputSize = ShapeUtil.size(outputShape) / components;
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'uint32', data: H}, {type: 'uint32', data: Math.floor(C / components)}];
|
||||
[{type: DataType.uint32, data: H}, {type: DataType.uint32, data: Math.floor(C / components)}];
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
|
||||
// first compute mean
|
||||
const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon);
|
||||
|
|
|
|||
|
|
@ -49,8 +49,9 @@ const createLayerNormProgramInfo =
|
|||
const components = getMaxComponents(normSize);
|
||||
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type'];
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: normCount}, {type: 'float32', data: normSize},
|
||||
{type: 'uint32', data: Math.floor(normSize / components)}, {type: 'float32', data: attributes.epsilon}
|
||||
{type: DataType.uint32, data: normCount}, {type: DataType.float, data: normSize},
|
||||
{type: DataType.uint32, data: Math.floor(normSize / components)},
|
||||
{type: DataType.float, data: attributes.epsilon}
|
||||
];
|
||||
if (bias) {
|
||||
inputDependencies.push('type');
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {BroadcastUtil, ShapeUtil} from '../../util';
|
||||
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
|
||||
|
|
@ -29,8 +30,8 @@ export const createNaiveMatmulProgramInfo =
|
|||
const outputShapeInShader = [batchSize, M, N];
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N},
|
||||
{type: 'uint32', data: K}
|
||||
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N},
|
||||
{type: DataType.uint32, data: K}
|
||||
];
|
||||
appendActivationUniformsData(activationAttributes, programUniforms);
|
||||
programUniforms.push(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
|
|
@ -238,8 +239,10 @@ const addBiasTranspose =
|
|||
hiddenSize: number, biasOffset: number) => {
|
||||
const outputShape = [batchSize, sequenceLength, hiddenSize];
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'uint32', data: outputSize}, {type: 'uint32', data: biasOffset}, {type: 'uint32', data: hiddenSize}];
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: biasOffset},
|
||||
{type: DataType.uint32, data: hiddenSize}
|
||||
];
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const output = outputVariable('qkv_with_bias', qkv.dataType, outputShape);
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType, tensorDataTypeEnumToString} from '../../../wasm-common';
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
|
||||
|
|
@ -153,10 +153,9 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr
|
|||
const inputDims = inputs[0].dims;
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.pads}];
|
||||
[{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.pads}];
|
||||
if (attributes.mode === 0) {
|
||||
const tensorDataType = tensorDataTypeEnumToString(inputs[0].dataType) as ProgramUniform['type'];
|
||||
programUniforms.push({type: tensorDataType, data: attributes.value});
|
||||
programUniforms.push({type: inputs[0].dataType, data: attributes.value});
|
||||
}
|
||||
|
||||
programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape));
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
import {env} from 'onnxruntime-common';
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {PoolConvUtil, ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
|
|
@ -56,7 +57,8 @@ const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes|MaxPoo
|
|||
const isChannelsLast = attributes.format === 'NHWC';
|
||||
const outputSize = ShapeUtil.size(outputShape);
|
||||
const kernelSize = ShapeUtil.size(attributes.kernelShape);
|
||||
const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}, {type: 'uint32', data: kernelSize}];
|
||||
const programUniforms: ProgramUniform[] =
|
||||
[{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: kernelSize}];
|
||||
const uniforms: UniformsArrayType = [{name: 'outputSize', type: 'u32'}, {name: 'kernelSize', type: 'u32'}];
|
||||
if (attributes.kernelShape.length <= 2) {
|
||||
const kw = attributes.kernelShape[attributes.kernelShape.length - 1];
|
||||
|
|
@ -65,10 +67,10 @@ const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes|MaxPoo
|
|||
const pwEnd = attributes.pads[attributes.pads.length - 1];
|
||||
const pwStartEndNotZero = !!(pwStart + pwEnd);
|
||||
programUniforms.push(
|
||||
{type: 'uint32', data: kw},
|
||||
{type: 'uint32', data: sw},
|
||||
{type: 'uint32', data: pwStart},
|
||||
{type: 'uint32', data: pwEnd},
|
||||
{type: DataType.uint32, data: kw},
|
||||
{type: DataType.uint32, data: sw},
|
||||
{type: DataType.uint32, data: pwStart},
|
||||
{type: DataType.uint32, data: pwEnd},
|
||||
);
|
||||
uniforms.push(
|
||||
{name: 'kw', type: 'u32'}, {name: 'sw', type: 'u32'}, {name: 'pwStart', type: 'u32'},
|
||||
|
|
@ -82,8 +84,8 @@ const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes|MaxPoo
|
|||
const phEnd = attributes.pads[attributes.pads.length - 2];
|
||||
phStartEndNotZero = !!(phStart + phEnd);
|
||||
programUniforms.push(
|
||||
{type: 'uint32', data: kh}, {type: 'uint32', data: sh}, {type: 'uint32', data: phStart},
|
||||
{type: 'uint32', data: phEnd});
|
||||
{type: DataType.uint32, data: kh}, {type: DataType.uint32, data: sh}, {type: DataType.uint32, data: phStart},
|
||||
{type: DataType.uint32, data: phEnd});
|
||||
|
||||
uniforms.push(
|
||||
{name: 'kh', type: 'u32'}, {name: 'sh', type: 'u32'}, {name: 'phStart', type: 'u32'},
|
||||
|
|
@ -96,8 +98,8 @@ const getUniformAndPadInfo = <AttributeType extends AveragePoolAttributes|MaxPoo
|
|||
}
|
||||
const kernelStrides = ShapeUtil.computeStrides(attributes.kernelShape);
|
||||
programUniforms.push(
|
||||
{type: 'uint32', data: kernelStrides}, {type: 'uint32', data: attributes.pads},
|
||||
{type: 'uint32', data: attributes.strides});
|
||||
{type: DataType.uint32, data: kernelStrides}, {type: DataType.uint32, data: attributes.pads},
|
||||
{type: DataType.uint32, data: attributes.strides});
|
||||
uniforms.push(
|
||||
{name: 'kernelStrides', type: 'u32', length: kernelStrides.length},
|
||||
{name: 'pads', type: 'u32', length: attributes.pads.length},
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
import {env} from 'onnxruntime-common';
|
||||
|
||||
import {DataType, tensorDataTypeEnumToString} from '../../../wasm-common';
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
|
||||
|
||||
import {createTensorShapeVariables, outputVariable, ShaderHelper, UniformDataElementType, UniformsArrayType} from './common';
|
||||
|
|
@ -22,9 +22,8 @@ const createRangeProgramInfo = (start: number, limit: number, delta: number, dat
|
|||
const numElements = Math.abs(Math.ceil((limit - start) / delta));
|
||||
const outputShape: number[] = [numElements];
|
||||
const outputSize = numElements;
|
||||
const tensorDataType = tensorDataTypeEnumToString(dataType) as ProgramUniform['type'];
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: outputSize}, {type: tensorDataType, data: start}, {type: tensorDataType, data: delta},
|
||||
{type: DataType.uint32, data: outputSize}, {type: dataType, data: start}, {type: dataType, data: delta},
|
||||
...createTensorShapeVariables(outputShape)
|
||||
];
|
||||
|
||||
|
|
|
|||
|
|
@ -185,7 +185,7 @@ export const createReduceSharedProgramInfo =
|
|||
getRunData: () => ({
|
||||
outputs: [{dims: outputShape, dataType: outputDataType}],
|
||||
dispatchGroup: {x: outputSize},
|
||||
programUniforms: [{type: 'uint32', data: reduceSize}]
|
||||
programUniforms: [{type: DataType.uint32, data: reduceSize}]
|
||||
}),
|
||||
};
|
||||
};
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ export const createReduceProgramInfo =
|
|||
outputs: [{dims: outputShape, dataType: outputDataType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
programUniforms: [
|
||||
{type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape),
|
||||
{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape),
|
||||
...createTensorShapeVariables(outputShape)
|
||||
]
|
||||
}),
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
|
|
@ -641,9 +642,9 @@ const createResizeProgramInfo =
|
|||
outputs: [{dims: outputShape, dataType: inputTensor.dataType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
programUniforms: [
|
||||
{type: 'uint32', data: outputSize},
|
||||
{type: 'float32', data: scales},
|
||||
{type: 'float32', data: roi},
|
||||
{type: DataType.uint32, data: outputSize},
|
||||
{type: DataType.float, data: scales},
|
||||
{type: DataType.float, data: roi},
|
||||
...createTensorShapeVariables(inputShape),
|
||||
...createTensorShapeVariables(outputShape),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -88,10 +88,10 @@ const createSkipLayerNormProgramInfo =
|
|||
const components = getMaxComponents(hiddenSize);
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: outputSize},
|
||||
{type: 'uint32', data: components},
|
||||
{type: 'uint32', data: hiddenSize},
|
||||
{type: 'float32', data: attributes.epsilon},
|
||||
{type: DataType.uint32, data: outputSize},
|
||||
{type: DataType.uint32, data: components},
|
||||
{type: DataType.uint32, data: hiddenSize},
|
||||
{type: DataType.float, data: attributes.epsilon},
|
||||
];
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => {
|
||||
const uniformsArray: UniformsArrayType = [
|
||||
|
|
|
|||
|
|
@ -155,9 +155,9 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice
|
|||
];
|
||||
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{type: 'uint32', data: outputSize}, {type: 'uint32', data: starts}, {type: 'int32', data: signs},
|
||||
{type: 'uint32', data: steps}, ...createTensorShapeVariables(inputs[0].dims),
|
||||
...createTensorShapeVariables(outputShape)
|
||||
{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: starts},
|
||||
{type: DataType.int32, data: signs}, {type: DataType.uint32, data: steps},
|
||||
...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape)
|
||||
];
|
||||
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
// performance limitations when the reduced axis is long. Need to add
|
||||
// a optimized codepath for this.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
|
|
@ -136,7 +137,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut
|
|||
getRunData: () => ({
|
||||
outputs: [{dims: shape, dataType: input.dataType}],
|
||||
dispatchGroup: {x: rows},
|
||||
programUniforms: [{type: 'uint32', data: packedCols}]
|
||||
programUniforms: [{type: DataType.uint32, data: packedCols}]
|
||||
}),
|
||||
getShaderSource,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
|
|
@ -72,7 +73,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split
|
|||
const outputsTensorInfo: TensorInfo[] = [];
|
||||
const outputShapes: number[][] = [];
|
||||
let previousSum = 0;
|
||||
const programUniforms: ProgramUniform[] = [{type: 'uint32', data: inputSize}];
|
||||
const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: inputSize}];
|
||||
for (let i = 0; i < attributes.numOutputs; i++) {
|
||||
previousSum += attributes.splitSizes[i];
|
||||
sizeInSplitAxis[i] = previousSum;
|
||||
|
|
@ -82,7 +83,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split
|
|||
outputs[i] = outputVariable(`output${i}`, dataType, outputShape);
|
||||
outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType});
|
||||
}
|
||||
programUniforms.push({type: 'uint32', data: sizeInSplitAxis});
|
||||
programUniforms.push({type: DataType.uint32, data: sizeInSplitAxis});
|
||||
programUniforms.push(...createTensorShapeVariables(inputShape));
|
||||
outputShapes.forEach((outputShape) => programUniforms.push(...createTensorShapeVariables(outputShape)));
|
||||
const getShaderSource = (shaderHelper: ShaderHelper) => `
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf
|
|||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
programUniforms: [
|
||||
{type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims),
|
||||
{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims),
|
||||
...createTensorShapeVariables(outputShape)
|
||||
],
|
||||
}),
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../../wasm-common';
|
||||
import {TensorView} from '../../tensor-view';
|
||||
import {ShapeUtil} from '../../util';
|
||||
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
|
||||
|
|
@ -65,7 +66,7 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
|
|||
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
|
||||
programUniforms: [
|
||||
{type: 'uint32', data: outputSize},
|
||||
{type: DataType.uint32, data: outputSize},
|
||||
...createTensorShapeVariables(inputs[0].dims),
|
||||
...createTensorShapeVariables(outputShape),
|
||||
],
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ const createElementwiseProgramInfo =
|
|||
dispatchGroup:
|
||||
{x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)},
|
||||
programUniforms: [
|
||||
{type: 'uint32', data: Math.ceil(ShapeUtil.size(input.dims) / 4)},
|
||||
{type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4)},
|
||||
],
|
||||
})
|
||||
});
|
||||
|
|
|
|||
|
|
@ -98,8 +98,9 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
|
|||
outputs: [{dims: outputShape, dataType: outputDataType}],
|
||||
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)},
|
||||
programUniforms: [
|
||||
{type: 'uint32', data: vecSize}, ...createTensorShapeVariables(dimsC), ...createTensorShapeVariables(dimsA),
|
||||
...createTensorShapeVariables(dimsB), ...createTensorShapeVariables(outputShape)
|
||||
{type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC),
|
||||
...createTensorShapeVariables(dimsA), ...createTensorShapeVariables(dimsB),
|
||||
...createTensorShapeVariables(outputShape)
|
||||
],
|
||||
}),
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {DataType} from '../../wasm-common';
|
||||
import {TensorView} from '../tensor-view';
|
||||
|
||||
import {ShaderHelper} from './ops/common';
|
||||
|
|
@ -26,7 +27,7 @@ export interface TensorInfo {
|
|||
}
|
||||
|
||||
export interface ProgramUniform {
|
||||
type: 'int32'|'float16'|'float32'|'uint32';
|
||||
type: DataType;
|
||||
data: number|readonly number[];
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue