onnxruntime/js/web/lib/wasm/jsep/webgpu/program-manager.ts
Yulong Wang d9b9c5a537
[js/webgpu] support using uniform buffer (#17803)
### Description
support using uniform buffer.

This PR allows to use uniform buffer in shader program, so that some
runtime information (eg. input/output shape) is no longer need to be
hardcoded into shader code.

There are 2 commits in this PR:
-
[667f31c](667f31c83d):
framework changes to support uniform buffer, as well as updates in
program manager, gpu data manager and indices helper.
-
[09e1d2a](09e1d2ad1d):
an example change for operator `Transpose` to use input's rank-only
instead of dims as shader key. With this change, model mobilenetv2-12
shader compile times dropped from 71 to 52.
2023-10-10 00:31:12 -07:00

171 lines
7.2 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {tensorDataTypeEnumToString} from '../../wasm-common';
import {WebGpuBackend} from '../backend-webgpu';
import {LOG_DEBUG} from '../log';
import {TensorView} from '../tensor-view';
import {createShaderHelper} from './ops/common';
import {Artifact, GpuData, ProgramInfo} from './types';
/**
* ProgramManager is the main class behind running computations
* It builds ProgramInfo's into Artifacts
* It compiles given ProgramInfo's into WebGL Prorams (cached as Artifacts)
* Uses the artifact to run the computation by calling Draw on
* the WebGL drawing buffer
* ProgramManager automatically maps (binds) input variables to their
* corresponding Location's in the binary program
*/
export class ProgramManager {
repo: Map<unknown, Artifact>; // this should be per-session object
attributesBound: boolean;
constructor(private backend: WebGpuBackend) {
this.repo = new Map();
this.attributesBound = false;
}
getArtifact(key: unknown): Artifact|undefined {
return this.repo.get(key);
}
setArtifact(key: unknown, artifact: Artifact): void {
this.repo.set(key, artifact);
}
run(buildArtifact: Artifact, inputTensorViews: readonly TensorView[], outputTensorViews: readonly TensorView[],
inputs: GpuData[], outputs: GpuData[], dispatchGroup: [number, number, number],
uniformBufferBinding: GPUBindingResource|undefined): void {
const device = this.backend.device;
const computePassEncoder = this.backend.getComputePassEncoder();
const profilingEnabled = this.backend.supportTimestampQuery && this.backend.env.webgpu.profilingMode === 'default';
if (profilingEnabled) {
// profiling write start timestamp
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(computePassEncoder as any).writeTimestamp(this.backend.profilingQuerySet, 0);
}
computePassEncoder.setPipeline(buildArtifact.computePipeline);
const entries = [];
for (const input of inputs) {
entries.push({binding: entries.length, resource: {buffer: input.buffer}});
}
for (const output of outputs) {
entries.push({binding: entries.length, resource: {buffer: output.buffer}});
}
if (uniformBufferBinding) {
entries.push({binding: entries.length, resource: uniformBufferBinding});
}
const bindGroup = device.createBindGroup(
{layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name});
computePassEncoder.setBindGroup(0, bindGroup);
computePassEncoder.dispatchWorkgroups(...dispatchGroup);
this.backend.pendingDispatchNumber++;
if (profilingEnabled) {
// profiling write end timestamp
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(computePassEncoder as any).writeTimestamp(this.backend.profilingQuerySet, 1);
if (this.backend.profilingQueryData == null) {
this.backend.profilingQueryData =
// eslint-disable-next-line no-bitwise
this.backend.gpuDataManager.create(16, GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE);
}
// eslint-disable-next-line no-bitwise
const syncData = this.backend.gpuDataManager.create(16, GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST);
this.backend.endComputePass();
this.backend.getCommandEncoder().resolveQuerySet(
this.backend.profilingQuerySet, 0, 2, this.backend.profilingQueryData.buffer, 0);
this.backend.getCommandEncoder().copyBufferToBuffer(
this.backend.profilingQueryData.buffer, 0, syncData.buffer, 0, 16);
this.backend.flush();
const kernelId = this.backend.currentKernelId!;
const kernelInfo = this.backend.kernels.get(kernelId)!;
const kernelName = `[${kernelInfo[0]}] ${kernelInfo[1]}`;
syncData.buffer.mapAsync(GPUMapMode.READ).then(() => {
const mappedData = new BigUint64Array(syncData.buffer.getMappedRange());
const startTimeU64 = mappedData[0];
const endTimeU64 = mappedData[1];
syncData.buffer.unmap();
if (typeof this.backend.profilingTimeBase === 'undefined') {
this.backend.profilingTimeBase = startTimeU64;
}
const startTime = Number(startTimeU64 - this.backend.profilingTimeBase);
const endTime = Number(endTimeU64 - this.backend.profilingTimeBase);
if (!Number.isSafeInteger(startTime) || !Number.isSafeInteger(endTime)) {
throw new RangeError('incorrect timestamp range');
}
this.backend.gpuDataManager.release(syncData.id);
let inputShapes = '';
inputTensorViews.forEach((value, i) => {
inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
});
let outputShapes = '';
outputTensorViews.forEach((value, i) => {
outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
});
// eslint-disable-next-line no-console
console.log(`[profiling] kernel "${kernelId}|${kernelName}" ${inputShapes}${outputShapes}execution time: ${
endTime - startTime} ns`);
});
}
if (this.backend.pendingDispatchNumber >= 16) {
this.backend.flush();
}
}
dispose(): void {
// this.repo.forEach(a => this.glContext.deleteProgram(a.program));
}
build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact {
const device = this.backend.device;
const extensions: string[] = [];
if (device.features.has('shader-f16')) {
extensions.push('enable f16;');
}
const shaderHelper = createShaderHelper(normalizedDispatchGroupSize);
const userCode = programInfo.getShaderSource(shaderHelper);
const code = `${extensions.join('\n')}\n${shaderHelper.additionalImplementations}\n${userCode}`;
const shaderModule = device.createShaderModule({code, label: programInfo.name});
LOG_DEBUG('verbose', () => `[WebGPU] shader code: ${code}`);
const computePipeline = device.createComputePipeline(
{compute: {module: shaderModule, entryPoint: 'main'}, layout: 'auto', label: programInfo.name});
return {programInfo, computePipeline};
}
normalizeDispatchGroupSize(dispatchGroup: ReturnType<ProgramInfo['getRunData']>['dispatchGroup']):
[number, number, number] {
const x = typeof dispatchGroup === 'number' ? dispatchGroup : dispatchGroup.x;
const y = typeof dispatchGroup === 'number' ? 1 : (dispatchGroup.y || 1);
const z = typeof dispatchGroup === 'number' ? 1 : (dispatchGroup.z || 1);
const limitPerDimension = this.backend.device.limits.maxComputeWorkgroupsPerDimension;
if (x <= limitPerDimension && y <= limitPerDimension && z <= limitPerDimension) {
return [x, y, z];
}
const size = x * y * z;
let dispatchAverage = Math.ceil(Math.sqrt(size));
if (dispatchAverage > limitPerDimension) {
dispatchAverage = Math.ceil(Math.cbrt(size));
if (dispatchAverage > limitPerDimension) {
throw new Error('Total dispatch size exceeds WebGPU maximum.');
}
return [dispatchAverage, dispatchAverage, dispatchAverage];
} else {
return [dispatchAverage, dispatchAverage, 1];
}
}
}