[js/webgpu] support float16 for Clip (#21584)

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
xhcao 2024-08-29 04:19:20 +08:00 committed by GitHub
parent 59114227fd
commit 3bfb5e4f62
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 340 additions and 33 deletions

View file

@ -59,6 +59,14 @@ class TensorViewImpl implements TensorView {
return elementCount === 0 ? new Int32Array() : new Int32Array(this.module.HEAP8.buffer, this.data, elementCount);
}
getUint16Array(): Uint16Array {
if (this.dataType !== DataType.float16 && this.dataType !== DataType.uint16) {
throw new Error('Invalid data type');
}
const elementCount = ShapeUtil.size(this.dims);
return elementCount === 0 ? new Uint16Array() : new Uint16Array(this.module.HEAP8.buffer, this.data, elementCount);
}
reshape(newDims: readonly number[]): TensorView {
if (ShapeUtil.size(newDims) !== ShapeUtil.size(this.dims)) {
throw new Error('Invalid new shape');

View file

@ -48,6 +48,11 @@ export interface TensorView {
*/
getInt32Array(): Int32Array;
/**
* get a Uint16Array data view of the tensor data. tensor data must be on CPU.
*/
getUint16Array(): Uint16Array;
/**
* create a new tensor view with the same data but different dimensions.
*/

View file

@ -3,11 +3,18 @@
import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { MAX_CLIP, MIN_CLIP, ShapeUtil } from '../../util';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInfo } from '../types';
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
import { inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType } from './common';
import {
inputVariable,
outputVariable,
ShaderHelper,
tensorTypeToWsglValueType,
UniformDataElementType,
UniformsArrayType,
} from './common';
type BuiltinFunctionName = string;
type ElementwiseCustomExpression = (expression: string) => string;
@ -20,6 +27,7 @@ const createElementwiseProgramShader = (
outputDataType: number,
funcCall: ElementwiseFunctionCall,
additionalImplementation?: string,
additionalUniformsType?: UniformsArrayType,
): string => {
const vecSize = Math.ceil(datasize / 4);
@ -32,9 +40,13 @@ const createElementwiseProgramShader = (
const input = inputVariable('inputData', inputDataType, [vecSize], 4);
const output = outputVariable('outputData', outputDataType, [vecSize], 4);
const uniforms: UniformsArrayType = [{ name: 'vec_size', type: 'u32' }];
if (additionalUniformsType) {
uniforms.push(...additionalUniformsType);
}
return `
${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)}
${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
${additionalImplementation ?? ''}
@ -53,24 +65,38 @@ const createElementwiseProgramInfo = (
additionalImplementation?: string,
cacheKey?: string,
outputDataType: number = input.dataType,
): ProgramInfo => ({
name,
shaderCache: { hint: cacheKey, inputDependencies: ['type'] },
getShaderSource: (shaderHelper) =>
createElementwiseProgramShader(
shaderHelper,
ShapeUtil.size(input.dims),
input.dataType,
outputDataType,
funcCall,
additionalImplementation,
),
getRunData: (inputTensors) => ({
outputs: [{ dims: input.dims, dataType: outputDataType }],
dispatchGroup: { x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */) },
programUniforms: [{ type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4) }],
}),
});
additionalUniforms?: ProgramUniform[],
additionalUniformsType?: UniformsArrayType,
): ProgramInfo => {
const programUniforms: ProgramUniform[] = [
{ type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4) },
];
if (additionalUniforms) {
programUniforms.push(...additionalUniforms);
}
return {
name,
shaderCache: { hint: cacheKey, inputDependencies: ['type'] },
getShaderSource: (shaderHelper) =>
createElementwiseProgramShader(
shaderHelper,
ShapeUtil.size(input.dims),
input.dataType,
outputDataType,
funcCall,
additionalImplementation,
additionalUniformsType,
),
getRunData: (inputTensors) => ({
outputs: [{ dims: input.dims, dataType: outputDataType }],
dispatchGroup: {
x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */),
},
programUniforms,
}),
};
};
export const abs = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Abs', 'abs'));
@ -139,24 +165,46 @@ export interface ClipAttributes extends AttributeWithCacheKey {
}
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
const min = inputs.length >= 2 && inputs[1].data !== 0 ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
const max = inputs.length >= 3 && inputs[2].data !== 0 ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
let min: number;
let max: number;
const hasMin = inputs.length >= 2 && inputs[1].data !== 0;
const hasMax = inputs.length >= 3 && inputs[2].data !== 0;
switch (inputs[0].dataType) {
case DataType.float:
min = hasMin ? inputs[1].getFloat32Array()[0] : -3.4028234663852886e38;
max = hasMax ? inputs[2].getFloat32Array()[0] : 3.4028234663852886e38;
break;
case DataType.float16:
min = hasMin ? inputs[1].getUint16Array()[0] : 64511; // uint16(64511) <-> float16(-65504.0)
max = hasMax ? inputs[2].getUint16Array()[0] : 31743; // uint16(31743) <-> float16(65504.0)
break;
default:
throw new Error('Unsupport data type');
}
return createAttributeWithCacheKey({ min, max });
};
export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => {
const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
const attributes = clipAttributes ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(
createElementwiseProgramInfo(
context.inputs[0],
'Clip',
(a) => `clamp(${a}, clip_min_, clip_max_)`,
`
const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min}));
const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max}));
`,
(a) => `clamp(${a}, vec4<${dataType}>(uniforms.min), vec4<${dataType}>(uniforms.max))`,
undefined,
attributes.cacheKey,
undefined,
[
{ type: context.inputs[0].dataType, data: attributes.min },
{ type: context.inputs[0].dataType, data: attributes.max },
],
[
{ name: 'min', type: dataType as UniformDataElementType },
{ name: 'max', type: dataType as UniformDataElementType },
],
),
{ inputs: [0] },
);
@ -302,9 +350,7 @@ export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttr
context.inputs[0],
'HardSigmoid',
(a) =>
`max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${
attributes.beta
})))`,
`max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${attributes.beta})))`,
undefined,
attributes.cacheKey,
),

View file

@ -0,0 +1,248 @@
[
{
"name": "clip float32 type with min and max attributes",
"operator": "Clip",
"opset": { "domain": "", "version": 10 },
"attributes": [
{ "name": "min", "type": "float", "data": 1.0 },
{ "name": "max", "type": "float", "data": 5.0 }
],
"cases": [
{
"name": "T[2, 3]",
"inputs": [
{
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
"dims": [2, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.0],
"dims": [2, 3],
"type": "float32"
}
]
}
]
},
{
"name": "clip float32 type with min attribute but no max attribute",
"operator": "Clip",
"opset": { "domain": "", "version": 10 },
"attributes": [{ "name": "min", "type": "float", "data": 1.0 }],
"cases": [
{
"name": "T[2, 3]",
"inputs": [
{
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
"dims": [2, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.8],
"dims": [2, 3],
"type": "float32"
}
]
}
]
},
{
"name": "clip float32 type without min and max attributes",
"operator": "Clip",
"opset": { "domain": "", "version": 10 },
"attributes": [],
"cases": [
{
"name": "T[2, 3]",
"inputs": [
{
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
"dims": [2, 3],
"type": "float32"
}
],
"outputs": [
{
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
"dims": [2, 3],
"type": "float32"
}
]
}
]
},
{
"name": "clip float32 type with min and max inputs",
"operator": "Clip",
"cases": [
{
"name": "T[2, 3]",
"inputs": [
{
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
"dims": [2, 3],
"type": "float32"
},
{
"data": [1.0],
"dims": [],
"type": "float32"
},
{
"data": [5.0],
"dims": [],
"type": "float32"
}
],
"outputs": [
{
"data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.0],
"dims": [2, 3],
"type": "float32"
}
]
}
]
},
{
"name": "clip float32 type with min input but no max input",
"operator": "Clip",
"cases": [
{
"name": "T[3, 2]",
"inputs": [
{
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
"dims": [3, 2],
"type": "float32"
},
{
"data": [1.0],
"dims": [],
"type": "float32"
}
],
"outputs": [
{
"data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.8],
"dims": [3, 2],
"type": "float32"
}
]
}
]
},
{
"name": "clip float32 type without min and max inputs",
"operator": "Clip",
"cases": [
{
"name": "T[3, 2]",
"inputs": [
{
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
"dims": [3, 2],
"type": "float32"
}
],
"outputs": [
{
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
"dims": [3, 2],
"type": "float32"
}
]
}
]
},
{
"name": "clip float16 type with min and max inputs",
"operator": "Clip",
"cases": [
{
"name": "T[2, 3]",
"inputs": [
{
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
"dims": [2, 3],
"type": "float16"
},
{
"data": [1.0],
"dims": [],
"type": "float16"
},
{
"data": [5.0],
"dims": [],
"type": "float16"
}
],
"outputs": [
{
"data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.0],
"dims": [2, 3],
"type": "float16"
}
]
}
]
},
{
"name": "clip float16 type with min input but no max input",
"operator": "Clip",
"cases": [
{
"name": "T[3, 2]",
"inputs": [
{
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
"dims": [3, 2],
"type": "float16"
},
{
"data": [1.0],
"dims": [],
"type": "float16"
}
],
"outputs": [
{
"data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.8],
"dims": [3, 2],
"type": "float16"
}
]
}
]
},
{
"name": "clip float16 type without min and max inputs",
"operator": "Clip",
"cases": [
{
"name": "T[3, 2]",
"inputs": [
{
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
"dims": [3, 2],
"type": "float16"
}
],
"outputs": [
{
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
"dims": [3, 2],
"type": "float16"
}
]
}
]
}
]