2024-08-09 21:44:19 +00:00
|
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
|
|
|
// Licensed under the MIT License.
|
|
|
|
|
|
2024-08-14 23:51:22 +00:00
|
|
|
import { DataType } from '../../../wasm-common';
|
|
|
|
|
import { TensorView } from '../../tensor-view';
|
|
|
|
|
import { ShapeUtil } from '../../util';
|
|
|
|
|
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
|
|
|
|
|
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
|
2024-08-09 21:44:19 +00:00
|
|
|
|
2024-08-14 23:51:22 +00:00
|
|
|
import {
|
|
|
|
|
createTensorShapeVariables,
|
|
|
|
|
getMaxComponents,
|
|
|
|
|
inputVariable,
|
|
|
|
|
outputVariable,
|
|
|
|
|
ShaderHelper,
|
|
|
|
|
UniformsArrayType,
|
|
|
|
|
} from './common';
|
2024-08-09 21:44:19 +00:00
|
|
|
|
|
|
|
|
/**
 * Attributes consumed by the DequantizeLinear kernel in this file.
 */
export interface DequantizeLinerAttributes extends AttributeWithCacheKey {
  /**
   * Axis of the input tensor used for per-axis and block quantization. It is normalized against the
   * input rank when the program is built.
   */
  axis: number;

  /**
   * Number of consecutive elements along `axis` that share one scale (and zero-point) entry.
   * A value greater than zero triggers the block-quantization validation.
   */
  blockSize: number;
}
|
|
|
|
|
|
|
|
|
|
/**
 * Validates the inputs and attributes of a DequantizeLinear call and throws on the first violation found.
 *
 * Checks performed, in order:
 *  - there are 2 or 3 inputs (`x`, `scale`, optional `zeroPoint`);
 *  - `zeroPoint`, when present, has the same data type as `x`, and is absent for int32 `x`;
 *  - `scale` is a scalar, a 1-D tensor, or has the same rank as `x`;
 *  - `zeroPoint`, when present, has exactly the same shape as `scale`;
 *  - when `attributes.blockSize > 0`, the block-quantization constraints on the `scale` shape and block size.
 *
 * @param inputs input tensors `[x, scale]` or `[x, scale, zeroPoint]`.
 * @param attributes operator attributes; `axis` may be negative and is normalized against the rank of `x`.
 * @throws Error describing the violated constraint.
 */
const validateInputs = (inputs: readonly TensorView[], attributes: DequantizeLinerAttributes): void => {
  if (inputs.length < 2 || inputs.length > 3) {
    throw new Error('DequantizeLinear requires 2 or 3 inputs.');
  }

  const x = inputs[0];
  const scale = inputs[1];
  const zeroPoint = inputs.length === 3 ? inputs[2] : undefined;

  // zero-point input type should be the same as input data type.
  if (zeroPoint && x.dataType !== zeroPoint.dataType) {
    throw new Error('x and x-zero-point must have the same data type.');
  }
  if (x.dataType === DataType.int32 && zeroPoint) {
    throw new Error('In the case of dequantizing int32 there is no zero point.');
  }
  if (scale.dims.length !== 0 && scale.dims.length !== 1 && scale.dims.length !== x.dims.length) {
    throw new Error('scale input must be a scalar, a 1D tensor, or have the same rank as the input tensor.');
  }

  // Scale and zero-point inputs must have the same shape. Shapes are compared element by element:
  // comparing the dims arrays with `===` would only test object identity, not shape equality.
  if (zeroPoint) {
    if (scale.dims.length !== zeroPoint.dims.length) {
      throw new Error('scale and zero-point inputs must have the same rank.');
    }
    if (!scale.dims.every((d, i) => d === zeroPoint.dims[i])) {
      throw new Error('scale and zero-point inputs must have the same shape.');
    }
  }

  // Validate blockSize
  if (attributes.blockSize > 0) {
    // Block quantization: a single-element scale means per-tensor quantization, which takes no block size.
    if (scale.dims.length === 0 || (scale.dims.length === 1 && scale.dims[0] === 1)) {
      throw new Error('blockSize must be set only for block quantization.');
    }

    // Scale input rank should be same as the input rank. This is checked before the per-dimension
    // comparison below so that `axis` indexes both shapes consistently.
    const rank = x.dims.length;
    if (scale.dims.length !== rank) {
      throw new Error('For block qunatization the scale input rank must be the same as the x rank.');
    }

    // A negative axis counts from the last dimension. Without normalization the comparisons below would
    // never match the axis dimension and the block-size range check would operate on `undefined`.
    const axis = attributes.axis < 0 ? attributes.axis + rank : attributes.axis;
    if (axis < 0 || axis >= rank) {
      throw new Error('axis is out of range for the rank of the input tensor.');
    }

    if (!scale.dims.every((d, i) => i === axis || d === x.dims[i])) {
      throw new Error('For block qunatization, scale input shape to match the input shape except for the axis');
    }

    const dI = x.dims[axis];
    const si = scale.dims[axis];
    // ceil(dI / (si - 1) - 1) equals ceil(dI / (si - 1)) - 1; when si === 1 the upper bound is Infinity.
    if (attributes.blockSize < Math.ceil(dI / si) || attributes.blockSize > Math.ceil(dI / (si - 1) - 1)) {
      throw new Error('blockSize must be with in the range [ceil(dI / Si), ceil(dI / (Si - 1) - 1)].');
    }
  }
};
|
|
|
|
|
|
2024-08-14 23:51:22 +00:00
|
|
|
/**
 * Builds the program description (shader source, uniforms, output and dispatch size) for DequantizeLinear.
 *
 * The generated shader computes `output = T(x - zero_point) * scale`, where `T` is the output value type
 * (the output data type is the data type of the scale input) and `zero_point` defaults to 0 when no
 * zero-point input is given.
 *
 * Three quantization layouts are distinguished by the shape of the scale input:
 *  - per-layer: scale is a scalar or a 1-D tensor with a single element;
 *  - per-axis: scale is any other 1-D tensor, indexed by the output index along `axis`;
 *  - block: otherwise; the scale index along `axis` is the output index divided by `blockSize`.
 *
 * int8/uint8 inputs (and zero-points) are bound as u32 words holding four elements each and unpacked in
 * the shader with `unpack4xI8` / `unpack4xU8`.
 *
 * @param inputs `[x, scale]` or `[x, scale, zeroPoint]`; expected to have passed `validateInputs`.
 * @param attributes operator attributes (`axis`, `blockSize`, cache key).
 * @returns the program info handed to `context.compute`.
 */
const createDequantizeLinearProgramInfo = (
  inputs: readonly TensorView[],
  attributes: DequantizeLinerAttributes,
): ProgramInfo => {
  const axis = ShapeUtil.normalizeAxis(attributes.axis, inputs[0].dims.length);
  const inputType = inputs[0].dataType;
  const isSigned = inputType === DataType.int8;
  const outputShape = inputs[0].dims; // output shape is same as the input shape
  const dataType = inputs[1].dataType; // output type is same as the scale input type
  const outputSize = ShapeUtil.size(outputShape);
  // 8-bit inputs are addressed as u32 words, four elements per word.
  const isPacked = inputType === DataType.int8 || inputType === DataType.uint8;
  const inputShape = isPacked ? [Math.ceil(ShapeUtil.size(inputs[0].dims) / 4)] : inputs[0].dims;
  const scaleShape = inputs[1].dims;
  const zeroPointInput = inputs.length > 2 ? inputs[2] : undefined;
  const zeroPointShape = zeroPointInput
    ? isPacked
      ? [Math.ceil(ShapeUtil.size(zeroPointInput.dims) / 4)]
      : zeroPointInput.dims
    : undefined;

  // Scales input is a scalar for per-tensor/per-layer quantization, 1-D tensor for per-axis quantization
  // or tensor with same rank as input for blocked quantization.
  const perLayerQuantization = scaleShape.length === 0 || (scaleShape.length === 1 && scaleShape[0] === 1);
  const perAxisQuantization = perLayerQuantization === false && scaleShape.length === 1;
  // Left unnecessary commented-out assignment for documentation
  // const blockQuantization = perLayerQuantization === false && perAxisQuantization === false;

  // Vectorization is only used for per-layer quantization; packed inputs additionally require a full
  // group of four so that one output vector corresponds to exactly one packed word.
  const maxComponents = getMaxComponents(outputSize);
  const useComponents = perLayerQuantization && (!isPacked || maxComponents === 4);
  const components = useComponents ? maxComponents : 1;
  const inputComponent = useComponents && !isPacked ? maxComponents : 1;

  const input = inputVariable('input', isPacked ? DataType.uint32 : inputType, inputShape.length, inputComponent);
  const scale = inputVariable('scale', dataType, scaleShape.length);
  const zeroPoint = zeroPointInput
    ? inputVariable('zero_point', isPacked ? DataType.uint32 : inputType, zeroPointShape!.length)
    : undefined;
  const output = outputVariable('output', dataType, outputShape.length, components);

  const inputVariables = [input, scale];
  if (zeroPoint) {
    inputVariables.push(zeroPoint);
  }
  const inputShapes = [inputShape, scaleShape];
  if (zeroPointInput) {
    inputShapes.push(zeroPointShape!);
  }

  const programUniforms: ProgramUniform[] = [
    { type: DataType.uint32, data: outputSize / components },
    { type: DataType.uint32, data: axis },
    { type: DataType.uint32, data: attributes.blockSize },
    ...createTensorShapeVariables(...inputShapes, outputShape),
  ];

  const getShaderSource = (shaderHelper: ShaderHelper) => {
    const uniforms: UniformsArrayType = [
      { name: 'output_size', type: 'u32' },
      { name: 'axis', type: 'u32' },
      { name: 'block_size', type: 'u32' },
    ];

    // Emits the statements that define `x_value` for the current invocation.
    const loadInput = (): string => {
      if (isPacked) {
        // Scalar path: four consecutive invocations share one packed word (word index global_idx / 4).
        // Vectorized path (components === 4): every invocation produces one vec4 of output and therefore
        // consumes exactly one packed word, so the word index is global_idx itself.
        const wordOffset = components === 1 ? 'global_idx / 4' : 'global_idx';
        return `
            let input = ${input.getByOffset(wordOffset)};
            let x_vec = ${isSigned ? 'unpack4xI8(input)' : 'unpack4xU8(input)'};
            let x_value = ${components === 1 ? 'x_vec[global_idx % 4]' : 'x_vec'};`;
      }
      return `let x_value = ${input.getByOffset('global_idx')};`;
    };

    // Emits the statements that define `scale_value` (and `scale_indices` for block quantization).
    const loadScale = (): string => {
      if (perLayerQuantization) {
        // scale input is a scalar ()
        return `let scale_value= ${scale.getByOffset('0')}`;
      }
      if (perAxisQuantization) {
        // scale input is a 1D tensor
        return `
            let scale_index = ${output.indicesGet('output_indices', 'uniforms.axis')};
            let scale_value= ${scale.getByOffset('scale_index')};`;
      }
      // Block quantization. Scale input rank is same as input/output rank.
      return `
            var scale_indices: ${scale.type.indices} = output_indices;
            let index = ${scale.indicesGet('scale_indices', 'uniforms.axis')} / uniforms.block_size;
            ${scale.indicesSet('scale_indices', 'uniforms.axis', 'index')};
            let scale_value= ${scale.getByIndices('scale_indices')};`;
    };

    // Emits the statements that define `zero_point_value`.
    const loadZeroPoint = (): string => {
      if (!zeroPoint) {
        // No zero-point input: use a zero of the same value type as x_value's elements.
        return `let zero_point_value = ${isPacked ? (isSigned ? 'i32' : 'u32') : input.type.value}(0);`;
      }
      const unpackZeroPoint = isSigned ? 'unpack4xI8(zero_point_input)' : 'unpack4xU8(zero_point_input)';
      if (perLayerQuantization) {
        // zero-point input is a scalar
        if (isPacked) {
          return `
            let zero_point_input = ${zeroPoint.getByOffset('0')};
            let zero_point_vec = ${unpackZeroPoint};
            let zero_point_value= zero_point_vec[0]`;
        }
        return `let zero_point_value = ${zeroPoint.getByOffset('0')}`;
      }
      if (perAxisQuantization) {
        // zero-point input is a 1D tensor
        if (isPacked) {
          return `
            let zero_point_index = ${output.indicesGet('output_indices', 'uniforms.axis')};
            let zero_point_input = ${zeroPoint.getByOffset('zero_point_index / 4')};
            let zero_point_vec = ${unpackZeroPoint};
            let zero_point_value = zero_point_vec[zero_point_index % 4]`;
        }
        return `
            let zero_point_index = ${output.indicesGet('output_indices', 'uniforms.axis')};
            let zero_point_value = ${zeroPoint.getByOffset('zero_point_index')};`;
      }
      // BlockedQuantization. The zero-point input shape is same as the input shape except along axis.
      // `scale_indices` is defined by the block-quantization scale snippet emitted earlier in the shader.
      if (isPacked) {
        return `
            let zero_point_offset = ${scale.indicesToOffset('scale_indices')};
            let zero_point_input = ${zeroPoint.getByOffset('zero_point_offset / 4')};
            let zero_point_vec = ${unpackZeroPoint};
            let zero_point_value = zero_point_vec[zero_point_offset % 4];`;
      }
      return `let zero_point_value = ${zeroPoint.getByIndices('scale_indices')};`;
    };

    // The helper calls below are evaluated in the same order as the shader text they produce.
    return `
      ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
      ${shaderHelper.mainStart()}
          ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
          let output_indices = ${output.offsetToIndices('global_idx')};

          // Set input x
          ${loadInput()};

          // Set scale input
          ${loadScale()};

          // Set zero-point input
          ${loadZeroPoint()};

      // Compute and write output
      ${output.setByOffset('global_idx', `${output.type.value}(x_value - zero_point_value) * scale_value`)};
      }`;
  };

  return {
    name: 'DequantizeLinear',
    shaderCache: {
      // NOTE(review): the generated shader also depends on `components` (derived from the output size) and
      // on whether the scale holds a single element, neither of which is determined by rank alone —
      // confirm that the cache key built from this hint and the 'rank' dependencies distinguishes them.
      hint: attributes.cacheKey,
      inputDependencies: zeroPoint ? ['rank', 'rank', 'rank'] : ['rank', 'rank'],
    },
    getShaderSource,
    getRunData: () => ({
      outputs: [{ dims: outputShape, dataType }],
      // One invocation per output element (or per vector of `components` elements), 64 per workgroup.
      dispatchGroup: { x: Math.ceil(outputSize / components / 64), y: 1, z: 1 },
      programUniforms,
    }),
  };
};
|
2024-08-09 21:44:19 +00:00
|
|
|
|
|
|
|
|
/**
 * Runs DequantizeLinear on the inputs of the given compute context.
 *
 * The inputs are validated first (which throws on invalid input); the program is then built and
 * submitted through `context.compute`.
 *
 * @param context compute context supplying the input tensors and the `compute` entry point.
 * @param attributes parsed operator attributes.
 */
export const dequantizeLinear = (context: ComputeContext, attributes: DequantizeLinerAttributes): void => {
  validateInputs(context.inputs, attributes);

  const programInfo = createDequantizeLinearProgramInfo(context.inputs, attributes);
  context.compute(programInfo);
};
|
|
|
|
|
|
|
|
|
|
/**
 * Converts the raw attribute record into `DequantizeLinerAttributes`.
 *
 * `axis` and `blockSize` are read from the record and cast to numbers without further checks, then
 * wrapped via `createAttributeWithCacheKey`.
 *
 * @param attributes raw attribute record.
 * @returns the typed attributes.
 */
export const parseDequantizeLinearAttributes = (attributes: Record<string, unknown>): DequantizeLinerAttributes => {
  const axis = attributes.axis as number;
  const blockSize = attributes.blockSize as number;
  return createAttributeWithCacheKey({ axis, blockSize });
};
|