[js/webgpu] Enable pad f16 uniform (#21691)

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
This commit is contained in:
Xu Xing 2024-08-26 22:58:48 +08:00 committed by GitHub
parent 2877de73e1
commit d9c57ac7db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 29 additions and 8 deletions

View file

@ -593,7 +593,6 @@ export class WebGpuBackend {
} else if (v.type === DataType.uint32) {
new Uint32Array(arrayBuffer, offset, data.length).set(data);
} else if (v.type === DataType.float16) {
// TODO: use Float16Array.
new Uint16Array(arrayBuffer, offset, data.length).set(data);
} else if (v.type === DataType.float) {
new Float32Array(arrayBuffer, offset, data.length).set(data);

View file

@ -3,8 +3,9 @@
import { Env } from 'onnxruntime-common';
import { calculateTensorSizeInBytes, DataType } from '../wasm-common';
import type { OrtWasmModule } from '../wasm-types';
import { DataType, calculateTensorSizeInBytes } from '../wasm-common';
import { WebGpuBackend } from './backend-webgpu';
import { LOG_DEBUG } from './log';
@ -22,6 +23,14 @@ class TensorViewImpl implements TensorView {
public readonly dims: readonly number[],
) {}
getUint16Array(): Uint16Array {
if (this.dataType !== DataType.float16 && this.dataType !== DataType.uint16) {
throw new Error('Invalid data type');
}
const elementCount = ShapeUtil.size(this.dims);
return elementCount === 0 ? new Uint16Array() : new Uint16Array(this.module.HEAP8.buffer, this.data, elementCount);
}
getFloat32Array(): Float32Array {
if (this.dataType !== DataType.float) {
throw new Error('Invalid data type');

View file

@ -28,6 +28,11 @@ export interface TensorView {
readonly dataType: number;
readonly dims: readonly number[];
/**
* get a Float16Array data view of the tensor data. tensor data must be on CPU.
*/
getUint16Array(): Uint16Array;
/**
* get a Float32Array data view of the tensor data. tensor data must be on CPU.
*/

View file

@ -165,8 +165,10 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr
{ type: DataType.uint32, data: outputSize },
{ type: DataType.int32, data: attributes.pads },
];
const isValueFromInput = inputs.length >= 3 && inputs[2].data;
if (attributes.mode === 0) {
programUniforms.push({ type: inputs[0].dataType, data: attributes.value });
programUniforms.push({ type: isValueFromInput ? inputs[2].dataType : DataType.float, data: attributes.value });
}
programUniforms.push(...createTensorShapeVariables(inputs[0].dims, outputShape));
@ -182,7 +184,7 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr
{ name: 'pads', type: 'i32', length: attributes.pads.length },
];
if (attributes.mode === 0) {
uniforms.push({ name: 'constant_value', type: dataType as UniformDataElementType });
uniforms.push({ name: 'constant_value', type: (isValueFromInput ? dataType : 'f32') as UniformDataElementType });
}
return `
@ -200,7 +202,7 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr
return {
name: 'Pad',
shaderCache: { hint: `${attributes.mode}`, inputDependencies },
shaderCache: { hint: `${attributes.mode}${isValueFromInput}`, inputDependencies },
getRunData: () => ({
outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
dispatchGroup: { x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */) },
@ -213,7 +215,12 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr
const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes: PadAttributes): PadAttributes => {
if (inputs.length > 1) {
const bigInt64Pads = inputs[1].getBigInt64Array();
const value = inputs.length >= 3 && inputs[2].data ? inputs[2].getFloat32Array()[0] : 0.0;
const value =
inputs.length >= 3 && inputs[2].data
? inputs[2].dataType === DataType.float16
? inputs[2].getUint16Array()[0]
: inputs[2].getFloat32Array()[0]
: 0.0;
const inputRank = inputs[0].dims.length;
const updatePads = new Int32Array(2 * inputRank).fill(0);

View file

@ -1,6 +1,6 @@
[
{
"name": "constant 2D float16",
"name": "constant 2D float16 v10",
"operator": "Pad",
"opset": { "domain": "", "version": 10 },
"attributes": [
@ -33,7 +33,7 @@
]
},
{
"name": "constant 2D float16",
"name": "constant 2D float16 v19",
"operator": "Pad",
"opset": { "domain": "", "version": 19 },
"attributes": [{ "name": "mode", "data": "constant", "type": "string" }],

View file

@ -1385,6 +1385,7 @@
"reduce-min.jsonc",
"relu.jsonc",
"gelu.jsonc",
"pad_f16.jsonc",
"pad.jsonc",
"pad-big.jsonc",
"pow.jsonc",