mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
[//]: # (## Work In Progress. Feedback is welcome!)

### Description

This PR adds a few properties, methods and factories to the `Tensor` type to support the IO-binding feature. This allows users to create tensors from GPU/CPU-bound data without forcing a data transfer between CPU and GPU.

This change is a way to resolve #15312.

### Change Summary

1. Add properties to the `Tensor` type:
   a. `location`: indicates where the data resides. Valid values are `cpu`, `cpu-pinned`, `texture` and `gpu-buffer`.
   b. `texture`: sits alongside `data`; a read-only property of type `WebGLTexture`. Available only when `location === 'texture'`.
   c. `gpuBuffer`: sits alongside `data`; a read-only property of type `GPUBuffer`. Available only when `location === 'gpu-buffer'`.

2. Add methods to the `Tensor` type (usually used with inference outputs):
   - async function `getData()` allows users to download data from GPU to CPU manually.
   - function `dispose()` allows users to release GPU resources manually.

3. Add factories for creating `Tensor` instances:
   a. `fromTexture()` creates a tensor bound to a WebGL texture.
   b. `fromGpuBuffer()` creates a tensor bound to a WebGPU buffer.
   c. `fromPinnedBuffer()` creates a tensor using a CPU pinned buffer.

### Examples

Create tensors from a texture and pass them to an inference session as inputs:

```js
// when creating the session, specify that we prefer 'image_output:0' to be stored on GPU as a texture
const session = await InferenceSession.create('./my_model.onnx', {
  executionProviders: [ 'webgl' ],
  preferredOutputLocation: { 'image_output:0': 'texture' }
});

...

const myImageTexture = getTexture(); // user's function to get a texture
const myFeeds = { input0: Tensor.fromTexture(myImageTexture, { width: 224, height: 224 }) }; // shape [1, 224, 224, 4], RGBA format.
const results = await session.run(myFeeds);
const myOutputTexture = results['image_output:0'].texture;
```
157 lines
4 KiB
TypeScript
157 lines
4 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {env as envImpl} from './env-impl.js';
export declare namespace Env {
  /**
   * Either a single string used as a URL prefix for the .wasm files, or an object that overrides the path of
   * individual .wasm files by file name.
   */
  export type WasmPrefixOrFilePaths = string|{
    /* eslint-disable @typescript-eslint/naming-convention */
    'ort-wasm.wasm'?: string;
    'ort-wasm-threaded.wasm'?: string;
    'ort-wasm-simd.wasm'?: string;
    'ort-wasm-simd-threaded.wasm'?: string;
    /* eslint-enable @typescript-eslint/naming-convention */
  };

  /**
   * Flags for the WebAssembly backend.
   */
  export interface WebAssemblyFlags {
    /**
     * Set or get the number of thread(s). If omitted or set to 0, the number of thread(s) will be determined by the
     * system. If set to 1, no worker thread will be spawned.
     *
     * This setting is available only when the WebAssembly multithread feature is available in the current context.
     *
     * @defaultValue `0`
     */
    numThreads?: number;

    /**
     * Set or get a boolean value indicating whether to enable SIMD. If set to false, SIMD will be forcibly disabled.
     *
     * This setting is available only when the WebAssembly SIMD feature is available in the current context.
     *
     * @defaultValue `true`
     */
    simd?: boolean;

    /**
     * Set or get a number specifying the timeout for initialization of the WebAssembly backend, in milliseconds. A
     * zero value indicates no timeout is set.
     *
     * @defaultValue `0`
     */
    initTimeout?: number;

    /**
     * Set a custom URL prefix to the .wasm files or a set of overrides for each .wasm file. The override path should
     * be an absolute path.
     */
    wasmPaths?: WasmPrefixOrFilePaths;

    /**
     * Set or get a boolean value indicating whether to proxy the execution of the main thread to a worker thread.
     *
     * @defaultValue `false`
     */
    proxy?: boolean;
  }

  /**
   * Flags for the WebGL backend.
   */
  export interface WebGLFlags {
    /**
     * Set or get the WebGL Context ID (webgl or webgl2).
     *
     * @defaultValue `'webgl2'`
     */
    contextId?: 'webgl'|'webgl2';
    /**
     * Get the WebGL rendering context.
     */
    readonly context: WebGLRenderingContext;
    /**
     * Set or get the maximum batch size for matmul. 0 means to disable batching.
     *
     * @deprecated
     */
    matmulMaxBatchSize?: number;
    /**
     * Set or get the texture cache mode.
     *
     * @defaultValue `'full'`
     */
    textureCacheMode?: 'initializerOnly'|'full';
    /**
     * Set or get the packed texture mode.
     *
     * @defaultValue `false`
     */
    pack?: boolean;
    /**
     * Set or get whether to enable async download.
     *
     * @defaultValue `false`
     */
    async?: boolean;
  }

  /**
   * Flags for the WebGPU backend.
   */
  export interface WebGpuFlags {
    /**
     * Set or get the profiling mode.
     */
    profilingMode?: 'off'|'default';
    /**
     * Get the device for WebGPU.
     *
     * When used with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types".
     * Use `const device = env.webgpu.device as GPUDevice;` in TypeScript to access this property with the correct
     * type.
     *
     * See comments on {@link GpuBufferType} for more details about why the types defined in "@webgpu/types" are not
     * used here.
     */
    readonly device: unknown;
  }
}

/**
 * The shape of the global environment object: logging and debug settings, package version information, and the
 * per-backend flag sets.
 */
export interface Env {
  /**
   * Set the severity level for logging.
   *
   * @defaultValue `'warning'`
   */
  logLevel?: 'verbose'|'info'|'warning'|'error'|'fatal';
  /**
   * Indicate whether to run in debug mode.
   *
   * @defaultValue `false`
   */
  debug?: boolean;

  /**
   * Get the version of the current package.
   */
  readonly versions: {
    readonly common: string;
    readonly web?: string;
    readonly node?: string;
    // eslint-disable-next-line @typescript-eslint/naming-convention
    readonly 'react-native'?: string;
  };

  /**
   * Represent a set of flags for WebAssembly.
   */
  readonly wasm: Env.WebAssemblyFlags;

  /**
   * Represent a set of flags for WebGL.
   */
  readonly webgl: Env.WebGLFlags;

  /**
   * Represent a set of flags for WebGPU.
   */
  readonly webgpu: Env.WebGpuFlags;

  // Index signature: any other string-keyed property is accepted and typed as `unknown`.
  [name: string]: unknown;
}

/**
 * Represent a set of flags as a global singleton.
 *
 * This is the implementation object imported from './env-impl.js', exported under the public {@link Env} type.
 */
export const env: Env = envImpl;