diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index b31fbc6255..5dda7425bd 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -59,6 +59,14 @@ class TensorViewImpl implements TensorView { return elementCount === 0 ? new Int32Array() : new Int32Array(this.module.HEAP8.buffer, this.data, elementCount); } + getUint16Array(): Uint16Array { + if (this.dataType !== DataType.float16 && this.dataType !== DataType.uint16) { + throw new Error('Invalid data type'); + } + const elementCount = ShapeUtil.size(this.dims); + return elementCount === 0 ? new Uint16Array() : new Uint16Array(this.module.HEAP8.buffer, this.data, elementCount); + } + reshape(newDims: readonly number[]): TensorView { if (ShapeUtil.size(newDims) !== ShapeUtil.size(this.dims)) { throw new Error('Invalid new shape'); diff --git a/js/web/lib/wasm/jsep/tensor-view.ts b/js/web/lib/wasm/jsep/tensor-view.ts index 5f1fdfa453..027c6f5660 100644 --- a/js/web/lib/wasm/jsep/tensor-view.ts +++ b/js/web/lib/wasm/jsep/tensor-view.ts @@ -48,6 +48,11 @@ export interface TensorView { */ getInt32Array(): Int32Array; + /** + * get a Uint16Array data view of the tensor data. tensor data must be on CPU. + */ + getUint16Array(): Uint16Array; + /** * create a new tensor view with the same data but different dimensions. */ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 1fc2732f24..168d644fe0 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -3,11 +3,18 @@ import { DataType } from '../../../wasm-common'; import { TensorView } from '../../tensor-view'; -import { MAX_CLIP, MIN_CLIP, ShapeUtil } from '../../util'; +import { ShapeUtil } from '../../util'; import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key'; -import { ComputeContext, ProgramInfo } from '../types'; +import { ComputeContext, ProgramInfo, ProgramUniform } from '../types'; -import { inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType } from './common'; +import { + inputVariable, + outputVariable, + ShaderHelper, + tensorTypeToWsglValueType, + UniformDataElementType, + UniformsArrayType, +} from './common'; type BuiltinFunctionName = string; type ElementwiseCustomExpression = (expression: string) => string; @@ -20,6 +27,7 @@ const createElementwiseProgramShader = ( outputDataType: number, funcCall: ElementwiseFunctionCall, additionalImplementation?: string, + additionalUniformsType?: UniformsArrayType, ): string => { const vecSize = Math.ceil(datasize / 4); @@ -32,9 +40,13 @@ const createElementwiseProgramShader = ( const input = inputVariable('inputData', inputDataType, [vecSize], 4); const output = outputVariable('outputData', outputDataType, [vecSize], 4); + const uniforms: UniformsArrayType = [{ name: 'vec_size', type: 'u32' }]; + if (additionalUniformsType) { + uniforms.push(...additionalUniformsType); + } return ` - ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)} ${additionalImplementation ?? ''} @@ -53,24 +65,38 @@ const createElementwiseProgramInfo = ( additionalImplementation?: string, cacheKey?: string, outputDataType: number = input.dataType, -): ProgramInfo => ({ - name, - shaderCache: { hint: cacheKey, inputDependencies: ['type'] }, - getShaderSource: (shaderHelper) => - createElementwiseProgramShader( - shaderHelper, - ShapeUtil.size(input.dims), - input.dataType, - outputDataType, - funcCall, - additionalImplementation, - ), - getRunData: (inputTensors) => ({ - outputs: [{ dims: input.dims, dataType: outputDataType }], - dispatchGroup: { x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */) }, - programUniforms: [{ type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4) }], - }), -}); + additionalUniforms?: ProgramUniform[], + additionalUniformsType?: UniformsArrayType, +): ProgramInfo => { + const programUniforms: ProgramUniform[] = [ + { type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4) }, + ]; + if (additionalUniforms) { + programUniforms.push(...additionalUniforms); + } + + return { + name, + shaderCache: { hint: cacheKey, inputDependencies: ['type'] }, + getShaderSource: (shaderHelper) => + createElementwiseProgramShader( + shaderHelper, + ShapeUtil.size(input.dims), + input.dataType, + outputDataType, + funcCall, + additionalImplementation, + additionalUniformsType, + ), + getRunData: (inputTensors) => ({ + outputs: [{ dims: input.dims, dataType: outputDataType }], + dispatchGroup: { + x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */), + }, + programUniforms, + }), + }; +}; export const abs = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Abs', 'abs')); @@ -139,24 +165,46 @@ export interface ClipAttributes extends AttributeWithCacheKey { } const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => { - const min = inputs.length >= 2 && inputs[1].data !== 0 ? inputs[1].getFloat32Array()[0] : MIN_CLIP; - const max = inputs.length >= 3 && inputs[2].data !== 0 ? inputs[2].getFloat32Array()[0] : MAX_CLIP; + let min: number; + let max: number; + const hasMin = inputs.length >= 2 && inputs[1].data !== 0; + const hasMax = inputs.length >= 3 && inputs[2].data !== 0; + + switch (inputs[0].dataType) { + case DataType.float: + min = hasMin ? inputs[1].getFloat32Array()[0] : -3.4028234663852886e38; + max = hasMax ? inputs[2].getFloat32Array()[0] : 3.4028234663852886e38; + break; + case DataType.float16: + min = hasMin ? inputs[1].getUint16Array()[0] : 64511; // uint16(64511) <-> float16(-65504.0) + max = hasMax ? inputs[2].getUint16Array()[0] : 31743; // uint16(31743) <-> float16(65504.0) + break; + default: + throw new Error('Unsupport data type'); + } + return createAttributeWithCacheKey({ min, max }); }; export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => { - const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs); + const attributes = clipAttributes ? clipAttributes : generateClipAttributesFromInputs(context.inputs); const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute( createElementwiseProgramInfo( context.inputs[0], 'Clip', - (a) => `clamp(${a}, clip_min_, clip_max_)`, - ` - const clip_min_: vec4<${dataType}> = vec4(${dataType}(${attributes.min})); - const clip_max_: vec4<${dataType}> = vec4(${dataType}(${attributes.max})); -`, + (a) => `clamp(${a}, vec4<${dataType}>(uniforms.min), vec4<${dataType}>(uniforms.max))`, + undefined, attributes.cacheKey, + undefined, + [ + { type: context.inputs[0].dataType, data: attributes.min }, + { type: context.inputs[0].dataType, data: attributes.max }, + ], + [ + { name: 'min', type: dataType as UniformDataElementType }, + { name: 'max', type: dataType as UniformDataElementType }, + ], ), { inputs: [0] }, ); @@ -302,9 +350,7 @@ export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttr context.inputs[0], 'HardSigmoid', (a) => - `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${ - attributes.beta - })))`, + `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${attributes.beta})))`, undefined, attributes.cacheKey, ), diff --git a/js/web/test/data/ops/clip.jsonc b/js/web/test/data/ops/clip.jsonc new file mode 100644 index 0000000000..f2bcc2fd58 --- /dev/null +++ b/js/web/test/data/ops/clip.jsonc @@ -0,0 +1,248 @@ +[ + { + "name": "clip float32 type with min and max attributes", + "operator": "Clip", + "opset": { "domain": "", "version": 10 }, + "attributes": [ + { "name": "min", "type": "float", "data": 1.0 }, + { "name": "max", "type": "float", "data": 5.0 } + ], + "cases": [ + { + "name": "T[2, 3]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [2, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.0], + "dims": [2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "clip float32 type with min attribute but no max attribute", + "operator": "Clip", + "opset": { "domain": "", "version": 10 }, + "attributes": [{ "name": "min", "type": "float", "data": 1.0 }], + "cases": [ + { + "name": "T[2, 3]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [2, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "clip float32 type without min and max attributes", + "operator": "Clip", + "opset": { "domain": "", "version": 10 }, + "attributes": [], + "cases": [ + { + "name": "T[2, 3]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [2, 3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "clip float32 type with min and max inputs", + "operator": "Clip", + "cases": [ + { + "name": "T[2, 3]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [2, 3], + "type": "float32" + }, + { + "data": [1.0], + "dims": [], + "type": "float32" + }, + { + "data": [5.0], + "dims": [], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.0], + "dims": [2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "clip float32 type with min input but no max input", + "operator": "Clip", + "cases": [ + { + "name": "T[3, 2]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [3, 2], + "type": "float32" + }, + { + "data": [1.0], + "dims": [], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [3, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "clip float32 type without min and max inputs", + "operator": "Clip", + "cases": [ + { + "name": "T[3, 2]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [3, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [3, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "clip float16 type with min and max inputs", + "operator": "Clip", + "cases": [ + { + "name": "T[2, 3]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [2, 3], + "type": "float16" + }, + { + "data": [1.0], + "dims": [], + "type": "float16" + }, + { + "data": [5.0], + "dims": [], + "type": "float16" + } + ], + "outputs": [ + { + "data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.0], + "dims": [2, 3], + "type": "float16" + } + ] + } + ] + }, + { + "name": "clip float16 type with min input but no max input", + "operator": "Clip", + "cases": [ + { + "name": "T[3, 2]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [3, 2], + "type": "float16" + }, + { + "data": [1.0], + "dims": [], + "type": "float16" + } + ], + "outputs": [ + { + "data": [1.0, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [3, 2], + "type": "float16" + } + ] + } + ] + }, + { + "name": "clip float16 type without min and max inputs", + "operator": "Clip", + "cases": [ + { + "name": "T[3, 2]", + "inputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [3, 2], + "type": "float16" + } + ], + "outputs": [ + { + "data": [0.5, 1.4, 2.7, 3.3, 4.1, 5.8], + "dims": [3, 2], + "type": "float16" + } + ] + } + ] + } +]