mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[js/webgpu] support float16 for Clip (#21584)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
59114227fd
commit
3bfb5e4f62
4 changed files with 340 additions and 33 deletions
|
|
@ -59,6 +59,14 @@ class TensorViewImpl implements TensorView {
|
|||
return elementCount === 0 ? new Int32Array() : new Int32Array(this.module.HEAP8.buffer, this.data, elementCount);
|
||||
}
|
||||
|
||||
getUint16Array(): Uint16Array {
|
||||
if (this.dataType !== DataType.float16 && this.dataType !== DataType.uint16) {
|
||||
throw new Error('Invalid data type');
|
||||
}
|
||||
const elementCount = ShapeUtil.size(this.dims);
|
||||
return elementCount === 0 ? new Uint16Array() : new Uint16Array(this.module.HEAP8.buffer, this.data, elementCount);
|
||||
}
|
||||
|
||||
reshape(newDims: readonly number[]): TensorView {
|
||||
if (ShapeUtil.size(newDims) !== ShapeUtil.size(this.dims)) {
|
||||
throw new Error('Invalid new shape');
|
||||
|
|
|
|||
|
|
@ -48,6 +48,11 @@ export interface TensorView {
|
|||
*/
|
||||
getInt32Array(): Int32Array;
|
||||
|
||||
/**
|
||||
* get a Uint16Array data view of the tensor data. tensor data must be on CPU.
|
||||
*/
|
||||
getUint16Array(): Uint16Array;
|
||||
|
||||
/**
|
||||
* create a new tensor view with the same data but different dimensions.
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -3,11 +3,18 @@
|
|||
|
||||
import { DataType } from '../../../wasm-common';
|
||||
import { TensorView } from '../../tensor-view';
|
||||
import { MAX_CLIP, MIN_CLIP, ShapeUtil } from '../../util';
|
||||
import { ShapeUtil } from '../../util';
|
||||
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
|
||||
import { ComputeContext, ProgramInfo } from '../types';
|
||||
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
|
||||
|
||||
import { inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType } from './common';
|
||||
import {
|
||||
inputVariable,
|
||||
outputVariable,
|
||||
ShaderHelper,
|
||||
tensorTypeToWsglValueType,
|
||||
UniformDataElementType,
|
||||
UniformsArrayType,
|
||||
} from './common';
|
||||
|
||||
type BuiltinFunctionName = string;
|
||||
type ElementwiseCustomExpression = (expression: string) => string;
|
||||
|
|
@ -20,6 +27,7 @@ const createElementwiseProgramShader = (
|
|||
outputDataType: number,
|
||||
funcCall: ElementwiseFunctionCall,
|
||||
additionalImplementation?: string,
|
||||
additionalUniformsType?: UniformsArrayType,
|
||||
): string => {
|
||||
const vecSize = Math.ceil(datasize / 4);
|
||||
|
||||
|
|
@ -32,9 +40,13 @@ const createElementwiseProgramShader = (
|
|||
|
||||
const input = inputVariable('inputData', inputDataType, [vecSize], 4);
|
||||
const output = outputVariable('outputData', outputDataType, [vecSize], 4);
|
||||
const uniforms: UniformsArrayType = [{ name: 'vec_size', type: 'u32' }];
|
||||
if (additionalUniformsType) {
|
||||
uniforms.push(...additionalUniformsType);
|
||||
}
|
||||
|
||||
return `
|
||||
${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)}
|
||||
${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
|
||||
|
||||
${additionalImplementation ?? ''}
|
||||
|
||||
|
|
@ -53,24 +65,38 @@ const createElementwiseProgramInfo = (
|
|||
additionalImplementation?: string,
|
||||
cacheKey?: string,
|
||||
outputDataType: number = input.dataType,
|
||||
): ProgramInfo => ({
|
||||
name,
|
||||
shaderCache: { hint: cacheKey, inputDependencies: ['type'] },
|
||||
getShaderSource: (shaderHelper) =>
|
||||
createElementwiseProgramShader(
|
||||
shaderHelper,
|
||||
ShapeUtil.size(input.dims),
|
||||
input.dataType,
|
||||
outputDataType,
|
||||
funcCall,
|
||||
additionalImplementation,
|
||||
),
|
||||
getRunData: (inputTensors) => ({
|
||||
outputs: [{ dims: input.dims, dataType: outputDataType }],
|
||||
dispatchGroup: { x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */) },
|
||||
programUniforms: [{ type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4) }],
|
||||
}),
|
||||
});
|
||||
additionalUniforms?: ProgramUniform[],
|
||||
additionalUniformsType?: UniformsArrayType,
|
||||
): ProgramInfo => {
|
||||
const programUniforms: ProgramUniform[] = [
|
||||
{ type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4) },
|
||||
];
|
||||
if (additionalUniforms) {
|
||||
programUniforms.push(...additionalUniforms);
|
||||
}
|
||||
|
||||
return {
|
||||
name,
|
||||
shaderCache: { hint: cacheKey, inputDependencies: ['type'] },
|
||||
getShaderSource: (shaderHelper) =>
|
||||
createElementwiseProgramShader(
|
||||
shaderHelper,
|
||||
ShapeUtil.size(input.dims),
|
||||
input.dataType,
|
||||
outputDataType,
|
||||
funcCall,
|
||||
additionalImplementation,
|
||||
additionalUniformsType,
|
||||
),
|
||||
getRunData: (inputTensors) => ({
|
||||
outputs: [{ dims: input.dims, dataType: outputDataType }],
|
||||
dispatchGroup: {
|
||||
x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */),
|
||||
},
|
||||
programUniforms,
|
||||
}),
|
||||
};
|
||||
};
|
||||
|
||||
export const abs = (context: ComputeContext): void => {
|
||||
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Abs', 'abs'));
|
||||
|
|
@ -139,24 +165,46 @@ export interface ClipAttributes extends AttributeWithCacheKey {
|
|||
}
|
||||
|
||||
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
|
||||
const min = inputs.length >= 2 && inputs[1].data !== 0 ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
|
||||
const max = inputs.length >= 3 && inputs[2].data !== 0 ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
|
||||
let min: number;
|
||||
let max: number;
|
||||
const hasMin = inputs.length >= 2 && inputs[1].data !== 0;
|
||||
const hasMax = inputs.length >= 3 && inputs[2].data !== 0;
|
||||
|
||||
switch (inputs[0].dataType) {
|
||||
case DataType.float:
|
||||
min = hasMin ? inputs[1].getFloat32Array()[0] : -3.4028234663852886e38;
|
||||
max = hasMax ? inputs[2].getFloat32Array()[0] : 3.4028234663852886e38;
|
||||
break;
|
||||
case DataType.float16:
|
||||
min = hasMin ? inputs[1].getUint16Array()[0] : 64511; // uint16(64511) <-> float16(-65504.0)
|
||||
max = hasMax ? inputs[2].getUint16Array()[0] : 31743; // uint16(31743) <-> float16(65504.0)
|
||||
break;
|
||||
default:
|
||||
throw new Error('Unsupport data type');
|
||||
}
|
||||
|
||||
return createAttributeWithCacheKey({ min, max });
|
||||
};
|
||||
|
||||
export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => {
|
||||
const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
|
||||
const attributes = clipAttributes ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
|
||||
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
|
||||
context.compute(
|
||||
createElementwiseProgramInfo(
|
||||
context.inputs[0],
|
||||
'Clip',
|
||||
(a) => `clamp(${a}, clip_min_, clip_max_)`,
|
||||
`
|
||||
const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min}));
|
||||
const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max}));
|
||||
`,
|
||||
(a) => `clamp(${a}, vec4<${dataType}>(uniforms.min), vec4<${dataType}>(uniforms.max))`,
|
||||
undefined,
|
||||
attributes.cacheKey,
|
||||
undefined,
|
||||
[
|
||||
{ type: context.inputs[0].dataType, data: attributes.min },
|
||||
{ type: context.inputs[0].dataType, data: attributes.max },
|
||||
],
|
||||
[
|
||||
{ name: 'min', type: dataType as UniformDataElementType },
|
||||
{ name: 'max', type: dataType as UniformDataElementType },
|
||||
],
|
||||
),
|
||||
{ inputs: [0] },
|
||||
);
|
||||
|
|
@ -302,9 +350,7 @@ export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttr
|
|||
context.inputs[0],
|
||||
'HardSigmoid',
|
||||
(a) =>
|
||||
`max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${
|
||||
attributes.beta
|
||||
})))`,
|
||||
`max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${attributes.beta})))`,
|
||||
undefined,
|
||||
attributes.cacheKey,
|
||||
),
|
||||
|
|
|
|||
248
js/web/test/data/ops/clip.jsonc
Normal file
248
js/web/test/data/ops/clip.jsonc
Normal file
|
|
@ -0,0 +1,248 @@
|
|||
[
|
||||
{
|
||||
"name": "clip float32 type with min and max attributes",
|
||||
"operator": "Clip",
|
||||
"opset": { "domain": "", "version": 10 },
|
||||
"attributes": [
|
||||
{ "name": "min", "type": "float", "data": 1.0 },
|
||||
{ "name": "max", "type": "float", "data": 5.0 }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[2, 3]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
|
||||
"dims": [2, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.0],
|
||||
"dims": [2, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "clip float32 type with min attribute but no max attribute",
|
||||
"operator": "Clip",
|
||||
"opset": { "domain": "", "version": 10 },
|
||||
"attributes": [{ "name": "min", "type": "float", "data": 1.0 }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[2, 3]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
|
||||
"dims": [2, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.8],
|
||||
"dims": [2, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "clip float32 type without min and max attributes",
|
||||
"operator": "Clip",
|
||||
"opset": { "domain": "", "version": 10 },
|
||||
"attributes": [],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[2, 3]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
|
||||
"dims": [2, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
|
||||
"dims": [2, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "clip float32 type with min and max inputs",
|
||||
"operator": "Clip",
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[2, 3]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
|
||||
"dims": [2, 3],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1.0],
|
||||
"dims": [],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [5.0],
|
||||
"dims": [],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.0],
|
||||
"dims": [2, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "clip float32 type with min input but no max input",
|
||||
"operator": "Clip",
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[3, 2]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
|
||||
"dims": [3, 2],
|
||||
"type": "float32"
|
||||
},
|
||||
{
|
||||
"data": [1.0],
|
||||
"dims": [],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.8],
|
||||
"dims": [3, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "clip float32 type without min and max inputs",
|
||||
"operator": "Clip",
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[3, 2]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
|
||||
"dims": [3, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
|
||||
"dims": [3, 2],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "clip float16 type with min and max inputs",
|
||||
"operator": "Clip",
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[2, 3]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
|
||||
"dims": [2, 3],
|
||||
"type": "float16"
|
||||
},
|
||||
{
|
||||
"data": [1.0],
|
||||
"dims": [],
|
||||
"type": "float16"
|
||||
},
|
||||
{
|
||||
"data": [5.0],
|
||||
"dims": [],
|
||||
"type": "float16"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.0],
|
||||
"dims": [2, 3],
|
||||
"type": "float16"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "clip float16 type with min input but no max input",
|
||||
"operator": "Clip",
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[3, 2]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
|
||||
"dims": [3, 2],
|
||||
"type": "float16"
|
||||
},
|
||||
{
|
||||
"data": [1.0],
|
||||
"dims": [],
|
||||
"type": "float16"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.8],
|
||||
"dims": [3, 2],
|
||||
"type": "float16"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "clip float16 type without min and max inputs",
|
||||
"operator": "Clip",
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[3, 2]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
|
||||
"dims": [3, 2],
|
||||
"type": "float16"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8],
|
||||
"dims": [3, 2],
|
||||
"type": "float16"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
Loading…
Reference in a new issue