mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
### Description * implemented runEvalStep and runOptimizerStep * added hasEvalModel and hasOptimizerModel boolean fields to the TrainingSession representation * added evalInputNames and evalOutputNames fields to TrainingSessionHandler and TrainingSession * removed the inputNamesEncoded and outputNamesEncoded fields from TrainingSessionHandler: none of the training methods take input or output names as parameters, so there is no need to store them. ### Motivation and Context * part of the work for implementing web bindings for training * previous PR: #18250 --------- Co-authored-by: Ashwini Khade <askhade@microsoft.com>
84 lines
2.6 KiB
TypeScript
84 lines
2.6 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {InferenceSession} from './inference-session.js';
|
|
import {OnnxValue} from './onnx-value.js';
|
|
import {TrainingSession} from './training-session.js';
|
|
|
|
/**
 * Type aliases shared by the session handler interfaces declared in this module.
 *
 * @ignore
 */
export declare namespace SessionHandler {
  /** Input values passed to a run/step call, keyed by input name. */
  type FeedsType = Record<string, OnnxValue>;

  /**
   * Requested outputs, keyed by output name. An entry may be `null`; presumably that lets the
   * backend allocate the output itself — TODO confirm against the backend implementations.
   */
  type FetchesType = Record<string, OnnxValue|null>;

  /** Values produced by a run/step call, keyed by output name. */
  type ReturnType = Record<string, OnnxValue>;
}
|
|
|
|
/**
 * Members common to every session handler, both inference and training.
 *
 * @ignore
 */
interface SessionHandler {
  /** Names of the inputs exposed by this handler. */
  readonly inputNames: ReadonlyArray<string>;
  /** Names of the outputs exposed by this handler. */
  readonly outputNames: ReadonlyArray<string>;

  /**
   * Disposes the handler asynchronously. This presumably releases backend-held resources;
   * confirm with the concrete implementations.
   */
  dispose(): Promise<void>;
}
|
|
|
|
/**
 * Handler for one inference session instance, as returned by a backend's
 * `createInferenceSessionHandler`.
 *
 * @ignore
 */
export interface InferenceSessionHandler extends SessionHandler {
  /**
   * Runs the model once.
   *
   * @param feeds - input values keyed by input name
   * @param fetches - requested outputs keyed by output name
   * @param options - per-run options
   * @returns a promise resolving to the produced outputs keyed by output name
   */
  run(
      feeds: SessionHandler.FeedsType,
      fetches: SessionHandler.FetchesType,
      options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;

  /** Starts profiling for this session (synchronous). */
  startProfiling(): void;
  /** Ends profiling for this session (synchronous). */
  endProfiling(): void;
}
|
|
|
|
/**
 * Handler for one training session instance, as returned by a backend's optional
 * `createTrainingSessionHandler`. Besides the shared handler members, it exposes the
 * eval model's input/output names, the train/eval/optimizer step methods and
 * parameter-buffer accessors.
 *
 * @ignore
 */
export interface TrainingSessionHandler extends SessionHandler {
  /** Input names of the evaluation model (per the field name). */
  readonly evalInputNames: ReadonlyArray<string>;
  /** Output names of the evaluation model (per the field name). */
  readonly evalOutputNames: ReadonlyArray<string>;

  /**
   * Runs a single training step.
   *
   * @returns a promise resolving to the fetched outputs keyed by output name
   */
  runTrainStep(
      feeds: SessionHandler.FeedsType,
      fetches: SessionHandler.FetchesType,
      options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;

  /** Runs a single optimizer step; the promise resolves with no value. */
  runOptimizerStep(options: InferenceSession.RunOptions): Promise<void>;

  /**
   * Runs a single evaluation step.
   *
   * @returns a promise resolving to the fetched outputs keyed by output name
   */
  runEvalStep(
      feeds: SessionHandler.FeedsType,
      fetches: SessionHandler.FetchesType,
      options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;

  /**
   * Resolves to the size of the model parameters. Units (element count vs. bytes) are not
   * visible here — TODO confirm against the backend.
   *
   * @param trainableOnly - presumably restricts the query to trainable parameters
   */
  getParametersSize(trainableOnly: boolean): Promise<number>;

  /**
   * Loads model parameters from a raw byte buffer.
   *
   * @param array - the parameter bytes
   * @param trainableOnly - presumably restricts loading to trainable parameters
   */
  loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;

  /**
   * Resolves to the model parameters packed into a single `OnnxValue`.
   *
   * @param trainableOnly - presumably restricts the result to trainable parameters
   */
  getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
}
|
|
|
|
/**
 * A backend that creates session handlers for model inferencing and, optionally, training.
 *
 * @ignore
 */
export interface Backend {
  /**
   * Initializes the backend asynchronously. Implementations should throw (reject the promise)
   * when initialization fails.
   */
  init(): Promise<void>;

  /**
   * Creates an inference session handler.
   *
   * @param uriOrBuffer - the model, given as a URI string or as an in-memory byte buffer
   * @param options - optional session options
   */
  createInferenceSessionHandler(uriOrBuffer: Uint8Array|string, options?: InferenceSession.SessionOptions):
      Promise<InferenceSessionHandler>;

  /**
   * Creates a training session handler from a checkpoint state plus the train, eval and
   * optimizer models. This member is optional, so a backend without training support may omit it.
   */
  createTrainingSessionHandler?(
      checkpointStateUriOrBuffer: TrainingSession.URIorBuffer,
      trainModelUriOrBuffer: TrainingSession.URIorBuffer,
      evalModelUriOrBuffer: TrainingSession.URIorBuffer,
      optimizerModelUriOrBuffer: TrainingSession.URIorBuffer,
      options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler>;
}
|
|
|
|
export {registerBackend} from './backend-impl.js';
|