onnxruntime/js/common/lib/backend.ts
Yulong Wang 9a61388f0a
[js/web] revise backend registration (#18715)
### Description
This PR revises the backend registration.

The following describes the expected behavior after this change:
(**bolded are changed behavior**)

- (ort.min.js - built without webgpu support)
    - loading: do not register 'webgpu' backend
- creating session without EP list: use default EP list ['webnn', 'cpu',
'wasm']
- creating session with ['webgpu'] as EP list: should fail with backend
not available
- (ort.webgpu.min.js - built with webgpu support)
    - loading: **always register 'webgpu' backend**
( previous behavior: only register 'webgpu' backend when `navigator.gpu`
is available)
- creating session without EP list: use default EP list ['webgpu',
'webnn', 'cpu', 'wasm']
        - when WebGPU is available (win): use WebGPU backend
- when WebGPU is unavailable (android): **should fail backend init,**
and try to use next backend in the list, 'webnn'
(previous behavior: does not fail backend init, but fail in JSEP init,
which was too late to switch to next backend)
    - creating session with ['webgpu'] as EP list
        - when WebGPU is available (win): use WebGPU backend
- when WebGPU is unavailable (android): **should fail backend init,** and
because no more EPs are listed, session creation fails.


related PRs: #18190 #18144
2023-12-20 14:45:55 -08:00

85 lines
2.7 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {InferenceSession} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
import {TrainingSession} from './training-session.js';
/**
 * Type aliases used by the session handler interfaces: string-keyed maps of values.
 *
 * @ignore
 */
export declare namespace SessionHandler {
  /** Map from a name to the value fed in under that name. */
  type FeedsType = Record<string, OnnxValue>;
  /** Map from a name to a value for that fetch; an entry may be `null`. */
  type FetchesType = Record<string, OnnxValue|null>;
  /** Map from a name to the value returned under that name. */
  type ReturnType = Record<string, OnnxValue>;
}
/**
 * Members that every kind of session handler has in common.
 *
 * @ignore
 */
interface SessionHandler {
  /** Read-only list of input names (presumably those of the loaded model; confirm against implementations). */
  readonly inputNames: ReadonlyArray<string>;
  /** Read-only list of output names (presumably those of the loaded model; confirm against implementations). */
  readonly outputNames: ReadonlyArray<string>;

  /** Disposes the handler asynchronously; the returned promise carries no value. */
  dispose(): Promise<void>;
}
/**
 * The handler instance behind an inference session.
 *
 * @ignore
 */
export interface InferenceSessionHandler extends SessionHandler {
  /**
   * Runs the session once.
   *
   * @param feeds - values to feed in, keyed by name.
   * @param fetches - fetch entries keyed by name; an entry may be `null`.
   * @param options - run options for this call.
   * @returns a promise that resolves to the resulting values, keyed by name.
   */
  run(
      feeds: SessionHandler.FeedsType,
      fetches: SessionHandler.FetchesType,
      options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;

  /** Starts profiling. Synchronous, returns nothing. */
  startProfiling(): void;

  /** Ends profiling. Synchronous, returns nothing. */
  endProfiling(): void;
}
/**
 * The handler instance behind a training session.
 *
 * @ignore
 */
export interface TrainingSessionHandler extends SessionHandler {
  /** Read-only list of input names for evaluation (presumably those of the eval model; confirm against implementations). */
  readonly evalInputNames: ReadonlyArray<string>;
  /** Read-only list of output names for evaluation (presumably those of the eval model; confirm against implementations). */
  readonly evalOutputNames: ReadonlyArray<string>;

  /** Asynchronously requests a lazy gradient reset; the returned promise carries no value. */
  lazyResetGrad(): Promise<void>;

  /**
   * Runs one training step.
   *
   * @param feeds - values to feed in, keyed by name.
   * @param fetches - fetch entries keyed by name; an entry may be `null`.
   * @param options - run options for this call.
   * @returns a promise that resolves to the resulting values, keyed by name.
   */
  runTrainStep(
      feeds: SessionHandler.FeedsType,
      fetches: SessionHandler.FetchesType,
      options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;

  /**
   * Runs one optimizer step.
   *
   * @param options - run options for this call.
   */
  runOptimizerStep(options: InferenceSession.RunOptions): Promise<void>;

  /**
   * Runs one evaluation step.
   *
   * @param feeds - values to feed in, keyed by name.
   * @param fetches - fetch entries keyed by name; an entry may be `null`.
   * @param options - run options for this call.
   * @returns a promise that resolves to the resulting values, keyed by name.
   */
  runEvalStep(
      feeds: SessionHandler.FeedsType,
      fetches: SessionHandler.FetchesType,
      options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;

  /**
   * Resolves to the size of the parameters as a number.
   *
   * @param trainableOnly - presumably restricts the count to trainable parameters; confirm against implementations.
   */
  getParametersSize(trainableOnly: boolean): Promise<number>;

  /**
   * Loads parameters from a byte buffer.
   *
   * @param array - the bytes to load.
   * @param trainableOnly - presumably indicates the buffer holds trainable parameters only; confirm against
   *     implementations.
   */
  loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;

  /**
   * Resolves to the parameters as a single `OnnxValue`.
   *
   * @param trainableOnly - presumably restricts the result to trainable parameters; confirm against implementations.
   */
  getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
}
/**
 * Contract for a backend, i.e. a provider of a model inferencing implementation.
 *
 * @ignore
 */
export interface Backend {
  /**
   * Asynchronously initializes the backend. Implementations should throw when initialization fails.
   *
   * @param backendName - name identifying the backend (presumably the name it was registered under; confirm in
   *     backend-impl).
   */
  init(backendName: string): Promise<void>;

  /**
   * Creates a handler for an inference session.
   *
   * @param uriOrBuffer - the model, given either as a string URI or as raw bytes.
   * @param options - optional session options.
   * @returns a promise that resolves to the inference session handler.
   */
  createInferenceSessionHandler(
      uriOrBuffer: string|Uint8Array,
      options?: InferenceSession.SessionOptions): Promise<InferenceSessionHandler>;

  /**
   * Creates a handler for a training session. This member is optional, so a backend may omit it.
   *
   * @param checkpointStateUriOrBuffer - URI or buffer for the checkpoint state.
   * @param trainModelUriOrBuffer - URI or buffer for the train model.
   * @param evalModelUriOrBuffer - URI or buffer for the eval model.
   * @param optimizerModelUriOrBuffer - URI or buffer for the optimizer model.
   * @param options - session options.
   * @returns a promise that resolves to the training session handler.
   */
  createTrainingSessionHandler?(
      checkpointStateUriOrBuffer: TrainingSession.URIorBuffer,
      trainModelUriOrBuffer: TrainingSession.URIorBuffer,
      evalModelUriOrBuffer: TrainingSession.URIorBuffer,
      optimizerModelUriOrBuffer: TrainingSession.URIorBuffer,
      options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler>;
}
export {registerBackend} from './backend-impl.js';