onnxruntime/js/web/lib/wasm/jsep/init.ts
Yang Gu e803f8eb0f
[js/webgpu] Refactor timestamp-query and introduce timestamp-query-inside-passes (#18894)
We submit kernels in a batch (a fixed number 16 is used except for the
last batch) for better performance. However, timestamp query support is
at pass level so we disable the batch execution in profiling mode in
previous implementation. Actually we can have multiple passes in a batch
so that we don't have to disable batch execution, which is the first
enhancement of this PR.
Furthermore, WebGPU has an extension to support timestamp query inside
passes, which isn't supported by all the platforms (e.g., Windows
supports it, while macOS doesn't). This is expected to have lower cost
compared with multiple passes solution. So this PR also introduce this
support when available.
This PR also refactors some implementation related to kernelInfo, and
try to unify the related kernel names.
2024-01-13 00:23:17 -08:00

205 lines
8 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {Env} from 'onnxruntime-common';
import {OrtWasmModule} from '../binding/ort-wasm';
import {DataType, getTensorElementSize} from '../wasm-common';
import {WebGpuBackend} from './backend-webgpu';
import {LOG_DEBUG} from './log';
import {TensorView} from './tensor-view';
import {ShapeUtil} from './util';
import {ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types';
/* eslint-disable no-bitwise */
class TensorViewImpl implements TensorView {
constructor(
private module: OrtWasmModule, public readonly dataType: number, public readonly data: number,
public readonly dims: readonly number[]) {}
getFloat32Array(): Float32Array {
if (this.dataType !== DataType.float) {
throw new Error('Invalid data type');
}
const elementCount = ShapeUtil.size(this.dims);
return elementCount === 0 ? new Float32Array() :
new Float32Array(this.module.HEAP8.buffer, this.data, elementCount);
}
getBigInt64Array(): BigInt64Array {
if (this.dataType !== DataType.int64) {
throw new Error('Invalid data type');
}
const elementCount = ShapeUtil.size(this.dims);
return elementCount === 0 ? new BigInt64Array() :
new BigInt64Array(this.module.HEAP8.buffer, this.data, elementCount);
}
getInt32Array(): Int32Array {
if (this.dataType !== DataType.int32) {
throw new Error('Invalid data type');
}
const elementCount = ShapeUtil.size(this.dims);
return elementCount === 0 ? new Int32Array() : new Int32Array(this.module.HEAP8.buffer, this.data, elementCount);
}
reshape(newDims: readonly number[]): TensorView {
if (ShapeUtil.size(newDims) !== ShapeUtil.size(this.dims)) {
throw new Error('Invalid new shape');
}
return new TensorViewImpl(this.module, this.dataType, this.data, newDims);
}
}
class ComputeContextImpl implements ComputeContext {
readonly opKernelContext: number;
readonly inputs: readonly TensorView[];
readonly outputCount: number;
get kernelCustomData(): {[key: string]: unknown} {
return this.backend.currentKernelCustomData;
}
get customDataBuffer(): Uint8Array {
return this.module.HEAPU8.subarray(this.customDataOffset, this.customDataOffset + this.customDataSize);
}
private customDataOffset = 0;
private customDataSize = 0;
constructor(private module: OrtWasmModule, private backend: WebGpuBackend, contextDataOffset: number) {
const heapU32 = module.HEAPU32;
// extract context data
let dataIndex = (contextDataOffset >>> 2);
this.opKernelContext = heapU32[dataIndex++];
const inputCount = heapU32[dataIndex++];
this.outputCount = heapU32[dataIndex++];
this.customDataOffset = heapU32[dataIndex++];
this.customDataSize = heapU32[dataIndex++];
const inputs: TensorView[] = [];
for (let i = 0; i < inputCount; i++) {
const dataType = heapU32[dataIndex++];
const data = heapU32[dataIndex++];
const dim = heapU32[dataIndex++];
const dims: number[] = [];
for (let d = 0; d < dim; d++) {
dims.push(heapU32[dataIndex++]);
}
inputs.push(new TensorViewImpl(module, dataType, data, dims));
}
this.inputs = inputs;
}
compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] {
// prepare inputs. inputs should always be valid data.
const mappedInputs =
inputsOutputsMapping?.inputs?.map(i => typeof i === 'number' ? this.inputs[i] : i) ?? this.inputs;
// prepare outputs.
const outputIndices = inputsOutputsMapping?.outputs ?? [];
const createKernelOutput = (index: number, dataType: number, dims: readonly number[]): TensorView =>
new TensorViewImpl(this.module, dataType, this.output(index, dims), dims);
const createTemporaryOutput = (dataType: number, dims: readonly number[]): TensorView => {
const elementSize = getTensorElementSize(dataType);
if (!elementSize) {
throw new Error(`Unsupported data type: ${dataType}`);
}
const bufferSize = elementSize * ShapeUtil.size(dims);
return new TensorViewImpl(this.module, dataType, this.backend.gpuDataManager.create(bufferSize).id, dims);
};
return this.backend.run(program, mappedInputs, outputIndices, createKernelOutput, createTemporaryOutput);
}
output(index: number, dims: readonly number[]): number {
const stack = this.module.stackSave();
try {
const data = this.module.stackAlloc((1 + dims.length) * 4 /* sizeof(size_t) */);
let offset = data >> 2;
this.module.HEAPU32[offset++] = dims.length;
for (let i = 0; i < dims.length; i++) {
this.module.HEAPU32[offset++] = dims[i];
}
return this.module._JsepOutput(this.opKernelContext, index, data);
} catch (e) {
throw new Error(
`Failed to generate kernel's output[${index}] with dims [${dims}]. ` +
'If you are running with pre-allocated output, please make sure the output type/dims are correct. ' +
`Error: ${e}`);
} finally {
this.module.stackRestore(stack);
}
}
}
/**
* Initialize JSEP with WebGPU backend.
*
* This function will be called only once after the WebAssembly module is loaded and initialized ("_OrtInit" is called).
* This function expects:
* - WebGPU is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false).
* - WebGPU is available in current environment. (a valid GPUAdapter is passed in)
* If the WebAssembly module is not built with JSEP support, this function will throw an error. This will invalidate
* 'webgpu' backend.
*
* @param module - the ORT WebAssembly module
* @param env - the ORT environment variable (ort.env)
* @param gpuAdapter - the pre-created GPU adapter
*/
export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapter): Promise<void> => {
const jsepInit = module.jsepInit;
if (!jsepInit) {
throw new Error('Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.');
}
const backend = new WebGpuBackend();
await backend.initialize(env, gpuAdapter);
jsepInit(
// backend
backend,
// jsepAlloc()
(size: number) => backend.alloc(size),
// jsepFree()
(ptr: number) => backend.free(ptr),
// jsepCopy(src, dst, size, isSourceGpu)
(src: number, dst: number, size: number, isSourceGpu = false) => {
if (isSourceGpu) {
LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`);
backend.memcpy(src, dst);
} else {
LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`);
const data = module.HEAPU8.subarray(src >>> 0, (src >>> 0) + size);
backend.upload(dst, data);
}
},
// jsepCopyAsync(src, dst, size)
async(gpuDataId: number, dataOffset: number, size: number):
Promise<void> => {
LOG_DEBUG(
'verbose',
() => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`);
await backend.download(
gpuDataId, () => module.HEAPU8.subarray(dataOffset >>> 0, (dataOffset >>> 0) + size));
},
// jsepCreateKernel
(kernelType: string, kernelId: number, attribute: unknown) =>
backend.createKernel(kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName(kernelId))),
// jsepReleaseKernel
(kernel: number) => backend.releaseKernel(kernel),
// jsepRun
(kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string|null>>) => {
LOG_DEBUG(
'verbose',
() => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${
contextDataOffset}`);
const context = new ComputeContextImpl(module, backend, contextDataOffset);
return backend.computeKernel(kernel, context, errors);
});
};