mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-12 00:59:23 +00:00
[js/webgpu] Support f16 uniform (#19098)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
8b4517218b
commit
a3f0e2422b
5 changed files with 56 additions and 26 deletions
|
|
@ -428,13 +428,26 @@ export class WebGpuBackend {
|
|||
return;
|
||||
}
|
||||
// https://www.w3.org/TR/WGSL/#alignof
|
||||
const baseAlignment = data.length <= 2 ? data.length * 4 : 16;
|
||||
const sizeOfElement = v.type === 'float16' ? 2 : 4;
|
||||
let sizeOfVecOrMat;
|
||||
let baseAlignment;
|
||||
if (v.type === 'float16') {
|
||||
baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement);
|
||||
sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length;
|
||||
} else {
|
||||
baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16;
|
||||
sizeOfVecOrMat = 16;
|
||||
}
|
||||
currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment;
|
||||
offsets.push(currentOffset);
|
||||
// When data.length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where N =
|
||||
// Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N *
|
||||
// SizeOf(vec4<i32|u32|f32>).
|
||||
currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4;
|
||||
// For non-float16 type, when data.length > 4, the uniform variable is of type array<vec4<i32|u32|f32>,N>, where
|
||||
// N = Math.ceil(data.length / 4) and SizeOf(vec4<i32|u32|f32>) = 16. The total byte length is N *
|
||||
// SizeOf(vec4<i32|u32|f32>). For float16 type, when data.length > 4, the uniform variable is of type
|
||||
// array<mat2x4<f16>,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4<f16>) = 16. The total byte
|
||||
// length is N * SizeOf(mat2x4<f16>).
|
||||
const elementPerVecOrMat = v.type === 'float16' ? 8 : 4;
|
||||
currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat :
|
||||
data.length * sizeOfElement;
|
||||
});
|
||||
|
||||
// Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set
|
||||
|
|
@ -449,6 +462,9 @@ export class WebGpuBackend {
|
|||
new Int32Array(arrayBuffer, offset, data.length).set(data);
|
||||
} else if (v.type === 'uint32') {
|
||||
new Uint32Array(arrayBuffer, offset, data.length).set(data);
|
||||
} else if (v.type === 'float16') {
|
||||
// TODO: use Float16Array.
|
||||
new Uint16Array(arrayBuffer, offset, data.length).set(data);
|
||||
} else {
|
||||
new Float32Array(arrayBuffer, offset, data.length).set(data);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -330,18 +330,28 @@ export const sumVector = (name: string, components: number) => {
|
|||
* @param name - the name of variable.
|
||||
* @param index - the index of variable element.
|
||||
* @param length - the length of variable.
|
||||
* @param type - the type of variable, optional.
|
||||
*/
|
||||
export const getElementAt = (name: string, index: number|string, length: number): string => {
|
||||
// Names under `uniforms.` with more than 4 elements are addressed with two subscripts:
// the outer one selects a group of 4 elements, the inner one the element inside that group.
if (name.startsWith('uniforms.') && length > 4) {
|
||||
// A string index is an expression, so the division/modulo is emitted into the returned text.
if (typeof (index) === 'string') {
|
||||
return `${name}[(${index}) / 4][(${index}) % 4]`;
|
||||
} else {
|
||||
// A numeric index is known now, so both subscripts are folded to literals here.
return `${name}[${Math.floor(index / 4)}][${index % 4}]`;
|
||||
}
|
||||
} else {
|
||||
// Everything else: plain single subscript, or the bare name when length is 1 or less.
return length > 1 ? `${name}[${index}]` : name;
|
||||
}
|
||||
};
|
||||
/**
 * Builds the expression text that reads element `index` of the variable `name`.
 *
 * Three shapes are produced:
 * - `name` starts with `uniforms.` and `length > 4`, with `type === 'f16'`: three subscripts,
 *   `[index / 8][index % 8 / 4][index % 8 % 4]`, i.e. 8 elements per array item and 4 per inner group.
 *   This lines up with the `array<mat2x4<f16>, N>` declaration generated for f16 uniforms of that length.
 * - `name` starts with `uniforms.` and `length > 4`, any other (or omitted) `type`: two subscripts,
 *   `[index / 4][index % 4]`, lining up with the `array<vec4<type>, N>` declaration.
 * - otherwise: `name[index]` when `length > 1`, or just `name`.
 *
 * @param name - the name of variable.
 * @param index - the index of variable element; a number is folded into literal subscripts, a string is treated
 *     as an expression and wrapped in parentheses inside the emitted arithmetic.
 * @param length - the length of variable.
 * @param type - the element type of variable, optional; only the value 'f16' changes the result.
 * @returns the element access expression as a string.
 */
export const getElementAt =
|
||||
(name: string, index: number|string, length: number, type?: UniformDataElementType): string => {
|
||||
// Packed addressing applies only to `uniforms.` members holding more than 4 elements.
if (name.startsWith('uniforms.') && length > 4) {
|
||||
// String index: the arithmetic is emitted into the returned text rather than evaluated here.
if (typeof (index) === 'string') {
|
||||
if (type === 'f16') {
|
||||
// f16: array item = index / 8, then group of 4 inside the item, then element inside the group.
return `${name}[(${index}) / 8][(${index}) % 8 / 4][(${index}) % 8 % 4]`;
|
||||
} else {
|
||||
// Other types: array item = index / 4, then element inside the item.
return `${name}[(${index}) / 4][(${index}) % 4]`;
|
||||
}
|
||||
} else {
|
||||
// Numeric index: the same decomposition, but computed now so the subscripts are literals.
if (type === 'f16') {
|
||||
return `${name}[${Math.floor(index / 8)}][${Math.floor(index % 8 / 4)}][${index % 8 % 4}]`;
|
||||
} else {
|
||||
return `${name}[${Math.floor(index / 4)}][${index % 4}]`;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Not a packed uniform: single subscript, or the bare name when length is 1 or less.
return length > 1 ? `${name}[${index}]` : name;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* A helper function to get a IndicesHelper for a given input or output.
|
||||
|
|
@ -688,7 +698,7 @@ export const internalVariable =
|
|||
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
|
||||
createIndicesHelper(name, type, shapeOrRank, 'internal', components);
|
||||
|
||||
// Element type names accepted for a uniform; the value is spliced verbatim into generated
// declarations such as `vec${length}<${type}>`.
export type UniformDataElementType = 'u32'|'f32'|'i32';
|
||||
// Same alias with 'f16' included. 'f16' is the value that switches uniforms longer than 4 elements
// to the `array<mat2x4<f16>, N>` layout and to three-level indexing in getElementAt.
export type UniformDataElementType = 'u32'|'f16'|'f32'|'i32';
|
||||
// List of uniform descriptions: a name, an element type and an optional element count.
// A missing `length` or a `length` of 1 yields a scalar declaration, 2 to 4 a `vecN`, and more than 4 an array.
export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>;
|
||||
|
||||
/**
|
||||
|
|
@ -861,7 +871,11 @@ class ShaderHelperImpl implements ShaderHelper {
|
|||
const uniformSnippets: string[] = [];
|
||||
for (const {name, type, length} of this.uniforms) {
|
||||
if (length && length > 4) {
|
||||
uniformSnippets.push(`${name}:array<vec4<${type}>, ${Math.ceil(length / 4)}>`);
|
||||
if (type === 'f16') {
|
||||
uniformSnippets.push(`@align(16) ${name}:array<mat2x4<${type}>, ${Math.ceil(length / 8)}>`);
|
||||
} else {
|
||||
uniformSnippets.push(`${name}:array<vec4<${type}>, ${Math.ceil(length / 4)}>`);
|
||||
}
|
||||
} else {
|
||||
const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`;
|
||||
uniformSnippets.push(`${name}:${typeTemp}`);
|
||||
|
|
|
|||
|
|
@ -19,8 +19,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
|
|||
if (!inputs || inputs.length < 1) {
|
||||
throw new Error('Too few inputs');
|
||||
}
|
||||
if (inputs[0].dataType !== DataType.float) {
|
||||
throw new Error('Input type must be float.');
|
||||
if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.float16) {
|
||||
throw new Error('Input type must be float or float16.');
|
||||
}
|
||||
|
||||
if (inputs.length >= 2) {
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ export interface TensorInfo {
|
|||
}
|
||||
|
||||
// Describes one uniform value handed to the backend: an element type tag plus the value(s).
export interface ProgramUniform {
|
||||
// Element type tag without float16.
type: 'int32'|'float32'|'uint32';
|
||||
// Element type tag including 'float16'. The backend picks the typed array used to write `data`
// from this tag; 'float16' is written through a Uint16Array.
// NOTE(review): that suggests float16 values must already be raw 16-bit patterns — confirm with callers.
type: 'int32'|'float16'|'float32'|'uint32';
|
||||
// A single number or a read-only list of numbers.
data: number|readonly number[];
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
2,
|
||||
10,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
(*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()),
|
||||
Pad);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
|
|
@ -24,7 +24,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
12,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.TypeConstraint("T", JsepSupportedFloatTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1)
|
||||
.InputMemoryType(OrtMemTypeCPU, 2)
|
||||
.InputMemoryType(OrtMemTypeCPU, 3),
|
||||
|
|
@ -37,7 +37,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
17,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.TypeConstraint("T", JsepSupportedFloatTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1)
|
||||
.InputMemoryType(OrtMemTypeCPU, 2)
|
||||
.InputMemoryType(OrtMemTypeCPU, 3),
|
||||
|
|
@ -50,7 +50,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
18,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.TypeConstraint("T", JsepSupportedFloatTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1)
|
||||
.InputMemoryType(OrtMemTypeCPU, 2)
|
||||
.InputMemoryType(OrtMemTypeCPU, 3),
|
||||
|
|
@ -62,7 +62,7 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
19,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
|
||||
.TypeConstraint("T", JsepSupportedFloatTypes())
|
||||
.InputMemoryType(OrtMemTypeCPU, 1)
|
||||
.InputMemoryType(OrtMemTypeCPU, 2)
|
||||
.InputMemoryType(OrtMemTypeCPU, 3),
|
||||
|
|
|
|||
Loading…
Reference in a new issue