mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
This CL makes the WebGPU backend support subgroup features, allowing subgroup optimizations to be used in the future. ### Description With this CL, WebGPU backends will create devices with the subgroups and subgroups-f16 features (both are under origin trial in Chrome) or the chromium-experimental-subgroups feature enabled whenever available. ### Motivation and Context This CL would allow WebGPU operator shaders to use subgroup optimizations in the future, and might yield significant speedups from these optimizations.
300 lines
11 KiB
TypeScript
300 lines
11 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import { Env } from 'onnxruntime-common';
|
|
|
|
import { calculateTensorSizeInBytes, DataType } from '../wasm-common';
|
|
|
|
import type { OrtWasmModule } from '../wasm-types';
|
|
|
|
import { WebGpuBackend } from './backend-webgpu';
|
|
import { LOG_DEBUG } from './log';
|
|
import { TensorView } from './tensor-view';
|
|
import { ShapeUtil } from './util';
|
|
import {
|
|
AdapterInfo,
|
|
ComputeContext,
|
|
ComputeContextInputsOutputsMapping,
|
|
DeviceInfo,
|
|
ProgramInfo,
|
|
} from './webgpu/types';
|
|
import { WebNNBackend } from './backend-webnn';
|
|
|
|
/* eslint-disable no-bitwise */
|
|
|
|
/**
 * TensorView implementation backed by the ORT WebAssembly module.
 *
 * The typed-array getters build views directly on the module's heap buffer, using `data` as the byte offset and the
 * product of `dims` as the element count. No data is copied.
 */
class TensorViewImpl implements TensorView {
  constructor(
    private module: OrtWasmModule,
    public readonly dataType: number,
    public readonly data: number,
    public readonly dims: readonly number[],
  ) {}

  /**
   * @returns a Float32Array view of the tensor data (an empty array for a zero-sized tensor).
   * @throws Error if the tensor's data type is not float.
   */
  getFloat32Array(): Float32Array {
    if (this.dataType !== DataType.float) {
      throw new Error('Invalid data type');
    }
    const length = ShapeUtil.size(this.dims);
    if (length === 0) {
      return new Float32Array();
    }
    return new Float32Array(this.module.HEAP8.buffer, this.data, length);
  }

  /**
   * @returns a BigInt64Array view of the tensor data (an empty array for a zero-sized tensor).
   * @throws Error if the tensor's data type is not int64.
   */
  getBigInt64Array(): BigInt64Array {
    if (this.dataType !== DataType.int64) {
      throw new Error('Invalid data type');
    }
    const length = ShapeUtil.size(this.dims);
    if (length === 0) {
      return new BigInt64Array();
    }
    return new BigInt64Array(this.module.HEAP8.buffer, this.data, length);
  }

  /**
   * @returns an Int32Array view of the tensor data (an empty array for a zero-sized tensor).
   * @throws Error if the tensor's data type is not int32.
   */
  getInt32Array(): Int32Array {
    if (this.dataType !== DataType.int32) {
      throw new Error('Invalid data type');
    }
    const length = ShapeUtil.size(this.dims);
    if (length === 0) {
      return new Int32Array();
    }
    return new Int32Array(this.module.HEAP8.buffer, this.data, length);
  }

  /**
   * @returns a Uint16Array view of the tensor data (an empty array for a zero-sized tensor).
   * @throws Error if the tensor's data type is neither float16 nor uint16.
   */
  getUint16Array(): Uint16Array {
    // Both float16 and uint16 tensors are exposed through the same 16-bit view.
    const isSixteenBitType = this.dataType === DataType.float16 || this.dataType === DataType.uint16;
    if (!isSixteenBitType) {
      throw new Error('Invalid data type');
    }
    const length = ShapeUtil.size(this.dims);
    if (length === 0) {
      return new Uint16Array();
    }
    return new Uint16Array(this.module.HEAP8.buffer, this.data, length);
  }

  /**
   * Create a new view over the same data with a different shape.
   *
   * @param newDims - the new dimensions; their total element count must equal the current one.
   * @throws Error if the element counts differ.
   */
  reshape(newDims: readonly number[]): TensorView {
    const requestedSize = ShapeUtil.size(newDims);
    const currentSize = ShapeUtil.size(this.dims);
    if (requestedSize !== currentSize) {
      throw new Error('Invalid new shape');
    }
    return new TensorViewImpl(this.module, this.dataType, this.data, newDims);
  }
}
|
|
|
|
/**
 * ComputeContext implementation handed to WebGPU kernels.
 *
 * The constructor decodes a block of context data from the WebAssembly heap. The block is read as consecutive
 * pointer-sized slots, in this order:
 *   opKernelContext, inputCount, outputCount, customDataOffset, customDataSize,
 *   then for each input: dataType, data, rank, followed by `rank` dimension values.
 */
class ComputeContextImpl implements ComputeContext {
  readonly adapterInfo: AdapterInfo;
  readonly deviceInfo: DeviceInfo;
  readonly opKernelContext: number;
  readonly inputs: readonly TensorView[];
  readonly outputCount: number;

  /** Custom data of the kernel currently tracked by the backend. */
  get kernelCustomData(): { [key: string]: unknown } {
    return this.backend.currentKernelCustomData;
  }

  /** A view on the module heap covering `customDataSize` bytes starting at `customDataOffset`. */
  get customDataBuffer(): Uint8Array {
    return this.module.HEAPU8.subarray(this.customDataOffset, this.customDataOffset + this.customDataSize);
  }

  private customDataOffset = 0;
  private customDataSize = 0;

  /**
   * @param module - the ORT WebAssembly module
   * @param backend - the WebGPU backend that executes programs for this context
   * @param contextDataOffset - byte offset of the context data block in the module heap
   */
  constructor(
    private module: OrtWasmModule,
    private backend: WebGpuBackend,
    contextDataOffset: number,
  ) {
    this.adapterInfo = backend.adapterInfo;
    this.deviceInfo = backend.deviceInfo;

    // extract context data
    const ptrSize = module.PTR_SIZE;
    const intType = ptrSize === 4 ? 'i32' : 'i64';
    let dataIndex = contextDataOffset / ptrSize;
    // Each read consumes one pointer-sized slot and advances the cursor.
    const readInt = (): number => Number(module.getValue(ptrSize * dataIndex++, intType));
    const readPtr = (): number => Number(module.getValue(ptrSize * dataIndex++, '*'));

    this.opKernelContext = readInt();
    const inputCount = readInt();
    this.outputCount = readInt();
    this.customDataOffset = readPtr();
    this.customDataSize = readInt();

    const inputs: TensorView[] = [];
    for (let i = 0; i < inputCount; i++) {
      const dataType = readInt();
      const data = readPtr();
      const rank = readInt();
      const dims: number[] = [];
      for (let d = 0; d < rank; d++) {
        dims.push(readInt());
      }
      inputs.push(new TensorViewImpl(module, dataType, data, dims));
    }
    this.inputs = inputs;
  }

  /**
   * Run a program on the backend.
   *
   * @param program - the program to run
   * @param inputsOutputsMapping - optional mapping; numeric entries in `inputs` are resolved as indices into
   *     `this.inputs`, other entries are passed through as-is. `outputs` is forwarded to the backend.
   * @returns the tensor views returned by the backend
   * @throws Error if a temporary output has a data type for which no byte size can be computed
   */
  compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] {
    // prepare inputs. inputs should always be valid data.
    const mappedInputs =
      inputsOutputsMapping?.inputs?.map((i) => (typeof i === 'number' ? this.inputs[i] : i)) ?? this.inputs;
    // prepare outputs.
    const outputIndices = inputsOutputsMapping?.outputs ?? [];
    const createKernelOutput = (index: number, dataType: number, dims: readonly number[]): TensorView =>
      new TensorViewImpl(this.module, dataType, this.output(index, dims), dims);
    const createTemporaryOutput = (dataType: number, dims: readonly number[]): TensorView => {
      const bufferSize = calculateTensorSizeInBytes(dataType, dims);
      // Only a missing size means the data type is unsupported. A size of 0 is a valid zero-element tensor and
      // must fall through to the "no GPU data" branch below instead of being rejected.
      // NOTE(review): assumes calculateTensorSizeInBytes returns undefined for unsupported types - confirm.
      if (bufferSize === undefined) {
        throw new Error(`Unsupported data type: ${dataType}`);
      }
      // Zero-sized tensors get no GPU buffer; their data id is 0.
      const gpuDataId = bufferSize > 0 ? this.backend.gpuDataManager.create(bufferSize).id : 0;
      return new TensorViewImpl(this.module, dataType, gpuDataId, dims);
    };
    return this.backend.run(
      program,
      mappedInputs,
      outputIndices,
      createKernelOutput,
      createTemporaryOutput,
      this.outputCount,
    );
  }

  /**
   * Request a kernel output from the WebAssembly side.
   *
   * The dims are written to a temporary stack allocation as `[rank, dim0, dim1, ...]` (one pointer-sized slot
   * each) and passed to `_JsepOutput`. The stack is always restored afterwards.
   *
   * @param index - the output index
   * @param dims - the output dimensions
   * @returns the value returned by `_JsepOutput`
   * @throws Error wrapping any failure raised while preparing or requesting the output
   */
  output(index: number, dims: readonly number[]): number {
    const stack = this.module.stackSave();
    try {
      const ptrSize = this.module.PTR_SIZE;
      const type = ptrSize === 4 ? 'i32' : 'i64';
      const data = this.module.stackAlloc((1 + dims.length) * ptrSize /* sizeof(size_t) */);
      this.module.setValue(data, dims.length, type);
      for (let i = 0; i < dims.length; i++) {
        this.module.setValue(data + ptrSize * (i + 1), dims[i], type);
      }
      return this.module._JsepOutput!(this.opKernelContext, index, data);
    } catch (e) {
      throw new Error(
        `Failed to generate kernel's output[${index}] with dims [${dims}]. ` +
          'If you are running with pre-allocated output, please make sure the output type/dims are correct. ' +
          `Error: ${e}`,
      );
    } finally {
      this.module.stackRestore(stack);
    }
  }
}
|
|
|
|
/**
 * Initialize JSEP with WebGPU backend.
 *
 * This function will be called after the WebAssembly module is loaded and initialized ("_OrtInit" is called), once for
 * each of the following EPs if they are specified:
 * - "webgpu"
 * - "webnn"
 *
 * For WebGPU, this function expects:
 *  - WebGPU is enabled in build (BUILD_DEFS.DISABLE_JSEP === false).
 *  - WebGPU is available in current environment. (a valid GPUAdapter is passed in)
 *
 * For WebNN, this function expects:
 * - WebNN is enabled in build (BUILD_DEFS.DISABLE_JSEP === false).
 * - WebNN is available in current environment. (navigator.ml is not undefined)
 *
 * If the WebAssembly module is not built with JSEP support, this function will throw an error. This will invalidate
 * 'webgpu'/'webnn' backend.
 *
 * @param name - the name of the EP, either "webgpu" or "webnn"
 * @param module - the ORT WebAssembly module
 * @param env - the ORT environment variable (ort.env)
 * @param gpuAdapter - the pre-created GPU adapter
 * @throws Error if `module.jsepInit` is not available.
 */
export const init = async (
  name: 'webgpu' | 'webnn',
  module: OrtWasmModule,
  env: Env,
  gpuAdapter?: GPUAdapter,
): Promise<void> => {
  const jsepInit = module.jsepInit;
  if (!jsepInit) {
    throw new Error('Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.');
  }

  if (name === 'webgpu') {
    const backend = new WebGpuBackend();
    await backend.initialize(env, gpuAdapter!);

    // NOTE(review): numeric arguments are wrapped in Number() throughout, presumably because they can arrive as
    // BigInt when the module uses 64-bit pointers - confirm against the wasm glue code.
    jsepInit('webgpu', [
      // backend
      backend,

      // jsepAlloc()
      (size: number) => backend.alloc(Number(size)),

      // jsepFree()
      (ptr: number) => backend.free(ptr),

      // jsepCopy(src, dst, size, isSourceGpu)
      (src: number, dst: number, size: number, isSourceGpu = false) => {
        if (isSourceGpu) {
          LOG_DEBUG(
            'verbose',
            () => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`,
          );
          backend.memcpy(Number(src), Number(dst));
        } else {
          LOG_DEBUG(
            'verbose',
            () =>
              `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${Number(size)}`,
          );
          // Convert to number BEFORE applying `>>> 0`: the unsigned shift throws a TypeError on BigInt operands.
          // For plain numbers the result is unchanged.
          const srcOffset = Number(src) >>> 0;
          const data = module.HEAPU8.subarray(srcOffset, srcOffset + Number(size));
          backend.upload(Number(dst), data);
        }
      },

      // jsepCopyAsync(src, dst, size)
      async (gpuDataId: number, dataOffset: number, size: number): Promise<void> => {
        LOG_DEBUG(
          'verbose',
          () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`,
        );

        // The destination view is produced by a callback, so HEAPU8 is only read when the backend asks for it.
        // Operands are converted individually before adding so that mixed BigInt/number values cannot throw.
        await backend.download(Number(gpuDataId), () =>
          module.HEAPU8.subarray(Number(dataOffset) >>> 0, (Number(dataOffset) + Number(size)) >>> 0),
        );
      },

      // jsepCreateKernel
      (kernelType: string, kernelId: number, attribute: unknown) =>
        backend.createKernel(
          kernelType,
          Number(kernelId),
          attribute,
          module.UTF8ToString(module._JsepGetNodeName!(Number(kernelId))),
        ),

      // jsepReleaseKernel
      (kernel: number) => backend.releaseKernel(kernel),

      // jsepRun
      (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string | null>>) => {
        LOG_DEBUG(
          'verbose',
          () =>
            `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`,
        );
        const context = new ComputeContextImpl(module, backend, Number(contextDataOffset));
        return backend.computeKernel(Number(kernel), context, errors);
      },
      // jsepCaptureBegin
      () => backend.captureBegin(),
      // jsepCaptureEnd
      () => backend.captureEnd(),
      // jsepReplay
      () => backend.replay(),
    ]);
  } else {
    const backend = new WebNNBackend(env);
    jsepInit('webnn', [
      backend,
      // jsepReserveTensorId
      () => backend.reserveTensorId(),
      // jsepReleaseTensorId,
      (tensorId: number) => backend.releaseTensorId(tensorId),
      // jsepEnsureTensor
      async (tensorId: number, onnxDataType: number, shape: number[], copyOld) =>
        backend.ensureTensor(tensorId, onnxDataType, shape, copyOld),
      // jsepUploadTensor
      (tensorId: number, data: Uint8Array) => {
        backend.uploadTensor(tensorId, data);
      },
      // jsepDownloadTensor
      async (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => backend.downloadTensor(tensorId, dstBuffer),
    ]);
  }
};
|