[js/webgpu] Support f16 uniform (#19098)

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
Xu Xing 2024-01-26 08:58:22 +08:00 committed by GitHub
parent 8b4517218b
commit a3f0e2422b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 56 additions and 26 deletions

View file

@ -428,13 +428,26 @@ export class WebGpuBackend {
return;
}
// https://www.w3.org/TR/WGSL/#alignof
const baseAlignment = data.length <= 2 ? data.length * 4 : 16;
const sizeOfElement = v.type === 'float16' ? 2 : 4;
let sizeOfVecOrMat;
let baseAlignment;
if (v.type === 'float16') {
baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement);
sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length;
} else {
baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16;
sizeOfVecOrMat = 16;
}
currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment;
offsets.push(currentOffset);
// When data.length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where N =
// Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N *
// SizeOf(vec4<i32|u32|f32>).
currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4;
// For non-float16 type, when data.length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where
// N = Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N *
// SizeOf(vec4<i32|u32|f32>). For float16 type, when data.length > 4, the uniform variable is of type
// array<mat2x4<f16>,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4<f16>) = 16. The total byte
// length is N * SizeOf(mat2x4<f16>).
const elementPerVecOrMat = v.type === 'float16' ? 8 : 4;
currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat :
data.length * sizeOfElement;
});
// Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set
@ -449,6 +462,9 @@ export class WebGpuBackend {
new Int32Array(arrayBuffer, offset, data.length).set(data);
} else if (v.type === 'uint32') {
new Uint32Array(arrayBuffer, offset, data.length).set(data);
} else if (v.type === 'float16') {
// TODO: use Float16Array.
new Uint16Array(arrayBuffer, offset, data.length).set(data);
} else {
new Float32Array(arrayBuffer, offset, data.length).set(data);
}

View file

@ -330,18 +330,28 @@ export const sumVector = (name: string, components: number) => {
* @param name - the name of variable.
* @param index - the index of variable element.
* @param length - the length of variable.
* @param type - the type of variable, optional.
*/
/**
 * Build the WGSL expression that reads one element of a variable.
 *
 * A name starting with `uniforms.` whose length exceeds 4 is addressed as an array of 4-component
 * vectors, so the flat index is split into a vector index and a component index. Every other
 * variable is indexed directly, or returned as-is when its length is not greater than 1.
 *
 * @param name - the name of variable.
 * @param index - the index of variable element; a number, or a WGSL expression given as a string.
 * @param length - the length of variable.
 * @returns the WGSL snippet that accesses the element.
 */
export const getElementAt = (name: string, index: number|string, length: number): string => {
  const isVec4Array = name.startsWith('uniforms.') && length > 4;
  if (!isVec4Array) {
    return length > 1 ? `${name}[${index}]` : name;
  }
  if (typeof index === 'number') {
    // A constant index is resolved while generating the shader source.
    const vectorIndex = Math.floor(index / 4);
    const componentIndex = index % 4;
    return `${name}[${vectorIndex}][${componentIndex}]`;
  }
  // A string index is a WGSL expression; parenthesize it so the arithmetic applies to the whole expression.
  const wrapped = `(${index})`;
  return `${name}[${wrapped} / 4][${wrapped} % 4]`;
};
/**
 * Build the WGSL expression that reads one element of a variable.
 *
 * A name starting with `uniforms.` whose length exceeds 4 is addressed through a packed layout:
 *  - type 'f16': three subscripts, `[index / 8][index % 8 / 4][index % 8 % 4]`;
 *  - any other type (or none given): two subscripts, `[index / 4][index % 4]`.
 * Every other variable is indexed directly, or returned as-is when its length is not greater than 1.
 *
 * @param name - the name of variable.
 * @param index - the index of variable element; a number, or a WGSL expression given as a string.
 * @param length - the length of variable.
 * @param type - the type of variable, optional.
 * @returns the WGSL snippet that accesses the element.
 */
export const getElementAt =
    (name: string, index: number|string, length: number, type?: UniformDataElementType): string => {
      const isPackedUniform = name.startsWith('uniforms.') && length > 4;
      if (!isPackedUniform) {
        return length > 1 ? `${name}[${index}]` : name;
      }

      const isHalf = type === 'f16';
      if (typeof index === 'number') {
        // A constant index is resolved while generating the shader source.
        if (isHalf) {
          const withinGroup = index % 8;
          return `${name}[${Math.floor(index / 8)}][${Math.floor(withinGroup / 4)}][${withinGroup % 4}]`;
        }
        return `${name}[${Math.floor(index / 4)}][${index % 4}]`;
      }

      // A string index is a WGSL expression; parenthesize it so the arithmetic applies to the whole expression.
      const wrapped = `(${index})`;
      return isHalf ? `${name}[${wrapped} / 8][${wrapped} % 8 / 4][${wrapped} % 8 % 4]` :
                      `${name}[${wrapped} / 4][${wrapped} % 4]`;
    };
/**
* A helper function to get an IndicesHelper for a given input or output.
@ -688,7 +698,7 @@ export const internalVariable =
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
createIndicesHelper(name, type, shapeOrRank, 'internal', components);
export type UniformDataElementType = 'u32'|'f32'|'i32';
export type UniformDataElementType = 'u32'|'f16'|'f32'|'i32';
export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>;
/**
@ -861,7 +871,11 @@ class ShaderHelperImpl implements ShaderHelper {
const uniformSnippets: string[] = [];
for (const {name, type, length} of this.uniforms) {
if (length && length > 4) {
uniformSnippets.push(`${name}:array<vec4<${type}>, ${Math.ceil(length / 4)}>`);
if (type === 'f16') {
uniformSnippets.push(`@align(16) ${name}:array<mat2x4<${type}>, ${Math.ceil(length / 8)}>`);
} else {
uniformSnippets.push(`${name}:array<vec4<${type}>, ${Math.ceil(length / 4)}>`);
}
} else {
const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`;
uniformSnippets.push(`${name}:${typeTemp}`);

View file

@ -19,8 +19,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length < 1) {
throw new Error('Too few inputs');
}
if (inputs[0].dataType !== DataType.float) {
throw new Error('Input type must be float.');
if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.float16) {
throw new Error('Input type must be float or float16.');
}
if (inputs.length >= 2) {

View file

@ -24,7 +24,7 @@ export interface TensorInfo {
}
export interface ProgramUniform {
type: 'int32'|'float32'|'uint32';
type: 'int32'|'float16'|'float32'|'uint32';
data: number|readonly number[];
}

View file

@ -14,7 +14,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
2,
10,
kJsExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
Pad);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@ -24,7 +24,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
12,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T", JsepSupportedFloatTypes())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3),
@ -37,7 +37,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
17,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T", JsepSupportedFloatTypes())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3),
@ -50,7 +50,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
18,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T", JsepSupportedFloatTypes())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3),
@ -62,7 +62,7 @@ ONNX_OPERATOR_KERNEL_EX(
19,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T", JsepSupportedFloatTypes())
.InputMemoryType(OrtMemTypeCPU, 1)
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3),