mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[js] remove more unused training types (#22753)
### Description remove more unused training types
This commit is contained in:
parent
3975e79303
commit
3234487385
6 changed files with 0 additions and 531 deletions
|
|
@ -198,19 +198,6 @@ module.exports = {
|
|||
'_OrtReleaseTensor',
|
||||
'_OrtRun',
|
||||
'_OrtRunWithBinding',
|
||||
'_OrtTrainingCopyParametersFromBuffer',
|
||||
'_OrtTrainingCopyParametersToBuffer',
|
||||
'_OrtTrainingCreateSession',
|
||||
'_OrtTrainingEvalStep',
|
||||
'_OrtTrainingGetModelInputOutputCount',
|
||||
'_OrtTrainingGetModelInputOutputName',
|
||||
'_OrtTrainingGetParametersSize',
|
||||
'_OrtTrainingLazyResetGrad',
|
||||
'_OrtTrainingLoadCheckpoint',
|
||||
'_OrtTrainingOptimizerStep',
|
||||
'_OrtTrainingReleaseCheckpoint',
|
||||
'_OrtTrainingReleaseSession',
|
||||
'_OrtTrainingRunTrainStep',
|
||||
],
|
||||
},
|
||||
],
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
import { InferenceSession } from './inference-session.js';
|
||||
import { OnnxValue } from './onnx-value.js';
|
||||
import { TrainingSession } from './training-session.js';
|
||||
|
||||
/**
|
||||
* @ignore
|
||||
|
|
@ -42,33 +41,6 @@ export interface InferenceSessionHandler extends SessionHandler {
|
|||
): Promise<SessionHandler.ReturnType>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a handler instance of a training inference session.
|
||||
*
|
||||
* @ignore
|
||||
*/
|
||||
export interface TrainingSessionHandler extends SessionHandler {
|
||||
readonly evalInputNames: readonly string[];
|
||||
readonly evalOutputNames: readonly string[];
|
||||
|
||||
lazyResetGrad(): Promise<void>;
|
||||
runTrainStep(
|
||||
feeds: SessionHandler.FeedsType,
|
||||
fetches: SessionHandler.FetchesType,
|
||||
options: InferenceSession.RunOptions,
|
||||
): Promise<SessionHandler.ReturnType>;
|
||||
runOptimizerStep(options: InferenceSession.RunOptions): Promise<void>;
|
||||
runEvalStep(
|
||||
feeds: SessionHandler.FeedsType,
|
||||
fetches: SessionHandler.FetchesType,
|
||||
options: InferenceSession.RunOptions,
|
||||
): Promise<SessionHandler.ReturnType>;
|
||||
|
||||
getParametersSize(trainableOnly: boolean): Promise<number>;
|
||||
loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise<void>;
|
||||
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a backend that provides implementation of model inferencing.
|
||||
*
|
||||
|
|
@ -84,14 +56,6 @@ export interface Backend {
|
|||
uriOrBuffer: string | Uint8Array,
|
||||
options?: InferenceSession.SessionOptions,
|
||||
): Promise<InferenceSessionHandler>;
|
||||
|
||||
createTrainingSessionHandler?(
|
||||
checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer,
|
||||
trainModelUriOrBuffer: TrainingSession.UriOrBuffer,
|
||||
evalModelUriOrBuffer: TrainingSession.UriOrBuffer,
|
||||
optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer,
|
||||
options: InferenceSession.SessionOptions,
|
||||
): Promise<TrainingSessionHandler>;
|
||||
}
|
||||
|
||||
export { registerBackend } from './backend-impl.js';
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ export declare namespace Env {
|
|||
* If not modified, the filename of the .wasm file is:
|
||||
* - `ort-wasm-simd-threaded.wasm` for default build
|
||||
* - `ort-wasm-simd-threaded.jsep.wasm` for JSEP build (with WebGPU and WebNN)
|
||||
* - `ort-training-wasm-simd-threaded.wasm` for training build
|
||||
*/
|
||||
wasm?: URL | string;
|
||||
/**
|
||||
|
|
@ -25,7 +24,6 @@ export declare namespace Env {
|
|||
* If not modified, the filename of the .mjs file is:
|
||||
* - `ort-wasm-simd-threaded.mjs` for default build
|
||||
* - `ort-wasm-simd-threaded.jsep.mjs` for JSEP build (with WebGPU and WebNN)
|
||||
* - `ort-training-wasm-simd-threaded.mjs` for training build
|
||||
*/
|
||||
mjs?: URL | string;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,4 +26,3 @@ export * from './tensor-factory.js';
|
|||
export * from './trace.js';
|
||||
export * from './onnx-model.js';
|
||||
export * from './onnx-value.js';
|
||||
export * from './training-session.js';
|
||||
|
|
|
|||
|
|
@ -1,273 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { resolveBackendAndExecutionProviders } from './backend-impl.js';
|
||||
import { SessionHandler, TrainingSessionHandler } from './backend.js';
|
||||
import { InferenceSession as InferenceSession } from './inference-session.js';
|
||||
import { OnnxValue } from './onnx-value.js';
|
||||
import { Tensor } from './tensor.js';
|
||||
import { TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions } from './training-session.js';
|
||||
|
||||
type SessionOptions = InferenceSession.SessionOptions;
|
||||
type FeedsType = InferenceSession.FeedsType;
|
||||
type FetchesType = InferenceSession.FetchesType;
|
||||
type ReturnType = InferenceSession.ReturnType;
|
||||
type RunOptions = InferenceSession.RunOptions;
|
||||
|
||||
const noBackendErrMsg: string =
|
||||
'Training backend could not be resolved. ' + "Make sure you're using the correct configuration & WebAssembly files.";
|
||||
|
||||
export class TrainingSession implements TrainingSessionInterface {
|
||||
private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) {
|
||||
this.handler = handler;
|
||||
this.hasOptimizerModel = hasOptimizerModel;
|
||||
this.hasEvalModel = hasEvalModel;
|
||||
}
|
||||
private handler: TrainingSessionHandler;
|
||||
private hasOptimizerModel: boolean;
|
||||
private hasEvalModel: boolean;
|
||||
|
||||
get trainingInputNames(): readonly string[] {
|
||||
return this.handler.inputNames;
|
||||
}
|
||||
get trainingOutputNames(): readonly string[] {
|
||||
return this.handler.outputNames;
|
||||
}
|
||||
|
||||
get evalInputNames(): readonly string[] {
|
||||
if (this.hasEvalModel) {
|
||||
return this.handler.evalInputNames;
|
||||
} else {
|
||||
throw new Error('This training session has no evalModel loaded.');
|
||||
}
|
||||
}
|
||||
get evalOutputNames(): readonly string[] {
|
||||
if (this.hasEvalModel) {
|
||||
return this.handler.evalOutputNames;
|
||||
} else {
|
||||
throw new Error('This training session has no evalModel loaded.');
|
||||
}
|
||||
}
|
||||
|
||||
static async create(
|
||||
trainingOptions: TrainingSessionCreateOptions,
|
||||
sessionOptions?: SessionOptions,
|
||||
): Promise<TrainingSession> {
|
||||
const evalModel: string | Uint8Array = trainingOptions.evalModel || '';
|
||||
const optimizerModel: string | Uint8Array = trainingOptions.optimizerModel || '';
|
||||
const options: SessionOptions = sessionOptions || {};
|
||||
|
||||
// resolve backend, update session options with validated EPs, and create session handler
|
||||
const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options);
|
||||
if (backend.createTrainingSessionHandler) {
|
||||
const handler = await backend.createTrainingSessionHandler(
|
||||
trainingOptions.checkpointState,
|
||||
trainingOptions.trainModel,
|
||||
evalModel,
|
||||
optimizerModel,
|
||||
optionsWithValidatedEPs,
|
||||
);
|
||||
return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel);
|
||||
} else {
|
||||
throw new Error(noBackendErrMsg);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from
|
||||
* the given parameters to SessionHandler.FetchesType and RunOptions.
|
||||
*
|
||||
* @param inputNames the feeds object is checked that they contain all input names in the provided list of input
|
||||
* names.
|
||||
* @param outputNames the fetches object is checked that their keys match up with valid names in the list of output
|
||||
* names.
|
||||
* @param feeds the required input
|
||||
* @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object
|
||||
* @param arg2 optional RunOptions object.
|
||||
* @returns
|
||||
*/
|
||||
typeNarrowingForRunStep(
|
||||
inputNames: readonly string[],
|
||||
outputNames: readonly string[],
|
||||
feeds: FeedsType,
|
||||
arg1?: FetchesType | RunOptions,
|
||||
arg2?: RunOptions,
|
||||
): [SessionHandler.FetchesType, RunOptions] {
|
||||
const fetches: { [name: string]: OnnxValue | null } = {};
|
||||
let options: RunOptions = {};
|
||||
// check inputs
|
||||
if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) {
|
||||
throw new TypeError(
|
||||
"'feeds' must be an object that use input names as keys and OnnxValue as corresponding values.",
|
||||
);
|
||||
}
|
||||
|
||||
let isFetchesEmpty = true;
|
||||
// determine which override is being used
|
||||
if (typeof arg1 === 'object') {
|
||||
if (arg1 === null) {
|
||||
throw new TypeError('Unexpected argument[1]: cannot be null.');
|
||||
}
|
||||
if (arg1 instanceof Tensor) {
|
||||
throw new TypeError("'fetches' cannot be a Tensor");
|
||||
}
|
||||
|
||||
if (Array.isArray(arg1)) {
|
||||
if (arg1.length === 0) {
|
||||
throw new TypeError("'fetches' cannot be an empty array.");
|
||||
}
|
||||
isFetchesEmpty = false;
|
||||
// output names
|
||||
for (const name of arg1) {
|
||||
if (typeof name !== 'string') {
|
||||
throw new TypeError("'fetches' must be a string array or an object.");
|
||||
}
|
||||
if (outputNames.indexOf(name) === -1) {
|
||||
throw new RangeError(`'fetches' contains invalid output name: ${name}.`);
|
||||
}
|
||||
fetches[name] = null;
|
||||
}
|
||||
|
||||
if (typeof arg2 === 'object' && arg2 !== null) {
|
||||
options = arg2;
|
||||
} else if (typeof arg2 !== 'undefined') {
|
||||
throw new TypeError("'options' must be an object.");
|
||||
}
|
||||
} else {
|
||||
// decide whether arg1 is fetches or options
|
||||
// if any output name is present and its value is valid OnnxValue, we consider it fetches
|
||||
let isFetches = false;
|
||||
const arg1Keys = Object.getOwnPropertyNames(arg1);
|
||||
for (const name of outputNames) {
|
||||
if (arg1Keys.indexOf(name) !== -1) {
|
||||
const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name];
|
||||
if (v === null || v instanceof Tensor) {
|
||||
isFetches = true;
|
||||
isFetchesEmpty = false;
|
||||
fetches[name] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isFetches) {
|
||||
if (typeof arg2 === 'object' && arg2 !== null) {
|
||||
options = arg2;
|
||||
} else if (typeof arg2 !== 'undefined') {
|
||||
throw new TypeError("'options' must be an object.");
|
||||
}
|
||||
} else {
|
||||
options = arg1 as RunOptions;
|
||||
}
|
||||
}
|
||||
} else if (typeof arg1 !== 'undefined') {
|
||||
throw new TypeError("Unexpected argument[1]: must be 'fetches' or 'options'.");
|
||||
}
|
||||
|
||||
// check if all inputs are in feed
|
||||
for (const name of inputNames) {
|
||||
if (typeof feeds[name] === 'undefined') {
|
||||
throw new Error(`input '${name}' is missing in 'feeds'.`);
|
||||
}
|
||||
}
|
||||
|
||||
// if no fetches is specified, we use the full output names list
|
||||
if (isFetchesEmpty) {
|
||||
for (const name of outputNames) {
|
||||
fetches[name] = null;
|
||||
}
|
||||
}
|
||||
|
||||
return [fetches, options];
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper method for runTrainStep and any other runStep methods. Takes the ReturnType result from the SessionHandler
|
||||
* and changes it into a map of Tensors.
|
||||
*
|
||||
* @param results
|
||||
* @returns
|
||||
*/
|
||||
convertHandlerReturnTypeToMapOfTensors(results: SessionHandler.ReturnType): ReturnType {
|
||||
const returnValue: { [name: string]: OnnxValue } = {};
|
||||
for (const key in results) {
|
||||
if (Object.hasOwnProperty.call(results, key)) {
|
||||
const result = results[key];
|
||||
if (result instanceof Tensor) {
|
||||
returnValue[key] = result;
|
||||
} else {
|
||||
returnValue[key] = new Tensor(result.type, result.data, result.dims);
|
||||
}
|
||||
}
|
||||
}
|
||||
return returnValue;
|
||||
}
|
||||
|
||||
async lazyResetGrad(): Promise<void> {
|
||||
await this.handler.lazyResetGrad();
|
||||
}
|
||||
|
||||
runTrainStep(feeds: FeedsType, options?: RunOptions): Promise<ReturnType>;
|
||||
runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise<ReturnType>;
|
||||
async runTrainStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise<ReturnType> {
|
||||
const [fetches, options] = this.typeNarrowingForRunStep(
|
||||
this.trainingInputNames,
|
||||
this.trainingOutputNames,
|
||||
feeds,
|
||||
arg1,
|
||||
arg2,
|
||||
);
|
||||
const results = await this.handler.runTrainStep(feeds, fetches, options);
|
||||
return this.convertHandlerReturnTypeToMapOfTensors(results);
|
||||
}
|
||||
|
||||
async runOptimizerStep(options?: InferenceSession.RunOptions | undefined): Promise<void> {
|
||||
if (this.hasOptimizerModel) {
|
||||
await this.handler.runOptimizerStep(options || {});
|
||||
} else {
|
||||
throw new Error('This TrainingSession has no OptimizerModel loaded.');
|
||||
}
|
||||
}
|
||||
|
||||
runEvalStep(feeds: FeedsType, options?: RunOptions | undefined): Promise<ReturnType>;
|
||||
runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions | undefined): Promise<ReturnType>;
|
||||
async runEvalStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise<ReturnType> {
|
||||
if (this.hasEvalModel) {
|
||||
const [fetches, options] = this.typeNarrowingForRunStep(
|
||||
this.evalInputNames,
|
||||
this.evalOutputNames,
|
||||
feeds,
|
||||
arg1,
|
||||
arg2,
|
||||
);
|
||||
const results = await this.handler.runEvalStep(feeds, fetches, options);
|
||||
return this.convertHandlerReturnTypeToMapOfTensors(results);
|
||||
} else {
|
||||
throw new Error('This TrainingSession has no EvalModel loaded.');
|
||||
}
|
||||
}
|
||||
|
||||
async getParametersSize(trainableOnly = true): Promise<number> {
|
||||
return this.handler.getParametersSize(trainableOnly);
|
||||
}
|
||||
|
||||
async loadParametersBuffer(array: Uint8Array, trainableOnly = true): Promise<void> {
|
||||
const paramsSize = await this.getParametersSize(trainableOnly);
|
||||
// checking that the size of the Uint8Array is equivalent to the byte length of a Float32Array of the number
|
||||
// of parameters
|
||||
if (array.length !== 4 * paramsSize) {
|
||||
throw new Error(
|
||||
'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' +
|
||||
'the model. Please use getParametersSize method to check.',
|
||||
);
|
||||
}
|
||||
return this.handler.loadParametersBuffer(array, trainableOnly);
|
||||
}
|
||||
|
||||
async getContiguousParameters(trainableOnly = true): Promise<OnnxValue> {
|
||||
return this.handler.getContiguousParameters(trainableOnly);
|
||||
}
|
||||
|
||||
async release(): Promise<void> {
|
||||
return this.handler.dispose();
|
||||
}
|
||||
}
|
||||
|
|
@ -1,206 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { InferenceSession } from './inference-session.js';
|
||||
import { OnnxValue } from './onnx-value.js';
|
||||
import { TrainingSession as TrainingSessionImpl } from './training-session-impl.js';
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-redeclare */
|
||||
|
||||
export declare namespace TrainingSession {
|
||||
/**
|
||||
* Either URI file path (string) or Uint8Array containing model or checkpoint information.
|
||||
*/
|
||||
type UriOrBuffer = string | Uint8Array;
|
||||
}
|
||||
|
||||
/**
|
||||
* Represent a runtime instance of an ONNX training session,
|
||||
* which contains a model that can be trained, and, optionally,
|
||||
* an eval and optimizer model.
|
||||
*/
|
||||
export interface TrainingSession {
|
||||
// #region run()
|
||||
|
||||
/**
|
||||
* Lazily resets the gradients of all trainable parameters to zero. Should happen after the invocation of
|
||||
* runOptimizerStep.
|
||||
*/
|
||||
lazyResetGrad(): Promise<void>;
|
||||
|
||||
/**
|
||||
* Run TrainStep asynchronously with the given feeds and options.
|
||||
*
|
||||
* @param feeds - Representation of the model input. See type description of `InferenceSession.InputType` for
|
||||
detail.
|
||||
* @param options - Optional. A set of options that controls the behavior of model training.
|
||||
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values.
|
||||
*/
|
||||
runTrainStep(
|
||||
feeds: InferenceSession.FeedsType,
|
||||
options?: InferenceSession.RunOptions,
|
||||
): Promise<InferenceSession.ReturnType>;
|
||||
|
||||
/**
|
||||
* Run a single train step with the given inputs and options.
|
||||
*
|
||||
* @param feeds - Representation of the model input.
|
||||
* @param fetches - Representation of the model output.
|
||||
* detail.
|
||||
* @param options - Optional. A set of options that controls the behavior of model training.
|
||||
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
|
||||
values.
|
||||
*/
|
||||
runTrainStep(
|
||||
feeds: InferenceSession.FeedsType,
|
||||
fetches: InferenceSession.FetchesType,
|
||||
options?: InferenceSession.RunOptions,
|
||||
): Promise<InferenceSession.ReturnType>;
|
||||
|
||||
/**
|
||||
* Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model.
|
||||
*
|
||||
* @param options - Optional. A set of options that controls the behavior of model optimizing.
|
||||
*/
|
||||
runOptimizerStep(options?: InferenceSession.RunOptions): Promise<void>;
|
||||
|
||||
/**
|
||||
* Run a single eval step with the given inputs and options using the eval model.
|
||||
*
|
||||
* @param feeds - Representation of the model input.
|
||||
* @param options - Optional. A set of options that controls the behavior of model eval step.
|
||||
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
|
||||
values.
|
||||
*/
|
||||
runEvalStep(
|
||||
feeds: InferenceSession.FeedsType,
|
||||
options?: InferenceSession.RunOptions,
|
||||
): Promise<InferenceSession.ReturnType>;
|
||||
|
||||
/**
|
||||
* Run a single eval step with the given inputs and options using the eval model.
|
||||
*
|
||||
* @param feeds - Representation of the model input.
|
||||
* @param fetches - Representation of the model output.
|
||||
* detail.
|
||||
* @param options - Optional. A set of options that controls the behavior of model eval step.
|
||||
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
|
||||
values.
|
||||
*/
|
||||
runEvalStep(
|
||||
feeds: InferenceSession.FeedsType,
|
||||
fetches: InferenceSession.FetchesType,
|
||||
options?: InferenceSession.RunOptions,
|
||||
): Promise<InferenceSession.ReturnType>;
|
||||
|
||||
// #endregion
|
||||
|
||||
// #region copy parameters
|
||||
|
||||
/**
|
||||
* Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of
|
||||
* the parameters) elements of all the parameters in the training state.
|
||||
*
|
||||
* @param trainableOnly - When set to true, the size is calculated for trainable params only. Default value is true.
|
||||
*/
|
||||
getParametersSize(trainableOnly: boolean): Promise<number>;
|
||||
|
||||
/**
|
||||
* Copies parameter values from the given buffer to the training state. Currently, only supporting models with
|
||||
* parameters of type Float32.
|
||||
*
|
||||
* @param buffer - A Uint8Array representation of Float32 parameters.
|
||||
* @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true.
|
||||
*/
|
||||
loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise<void>;
|
||||
|
||||
/**
|
||||
* Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning.
|
||||
* Currently, only supporting models with parameters of type Float32.
|
||||
*
|
||||
* @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters
|
||||
* for which requires_grad is set to true. Default value is true.
|
||||
* @returns A promise that resolves to a Float32 OnnxValue of the requested parameters.
|
||||
*/
|
||||
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
|
||||
// #endregion
|
||||
|
||||
// #region release()
|
||||
|
||||
/**
|
||||
* Release the inference session and the underlying resources.
|
||||
*/
|
||||
release(): Promise<void>;
|
||||
// #endregion
|
||||
|
||||
// #region metadata
|
||||
|
||||
/**
|
||||
* Get input names of the loaded training model.
|
||||
*/
|
||||
readonly trainingInputNames: readonly string[];
|
||||
|
||||
/**
|
||||
* Get output names of the loaded training model.
|
||||
*/
|
||||
readonly trainingOutputNames: readonly string[];
|
||||
|
||||
/**
|
||||
* Get input names of the loaded eval model. Is an empty array if no eval model is loaded.
|
||||
*/
|
||||
readonly evalInputNames: readonly string[];
|
||||
|
||||
/**
|
||||
* Get output names of the loaded eval model. Is an empty array if no eval model is loaded.
|
||||
*/
|
||||
readonly evalOutputNames: readonly string[];
|
||||
|
||||
// #endregion
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents the optional parameters that can be passed into the TrainingSessionFactory.
|
||||
*/
|
||||
export interface TrainingSessionCreateOptions {
|
||||
/**
|
||||
* URI or buffer for a .ckpt file that contains the checkpoint for the training model.
|
||||
*/
|
||||
checkpointState: TrainingSession.UriOrBuffer;
|
||||
/**
|
||||
* URI or buffer for the .onnx training file.
|
||||
*/
|
||||
trainModel: TrainingSession.UriOrBuffer;
|
||||
/**
|
||||
* Optional. URI or buffer for the .onnx optimizer model file.
|
||||
*/
|
||||
optimizerModel?: TrainingSession.UriOrBuffer;
|
||||
/**
|
||||
* Optional. URI or buffer for the .onnx eval model file.
|
||||
*/
|
||||
evalModel?: TrainingSession.UriOrBuffer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Defines method overload possibilities for creating a TrainingSession.
|
||||
*/
|
||||
export interface TrainingSessionFactory {
|
||||
// #region create()
|
||||
|
||||
/**
|
||||
* Creates a new TrainingSession and asynchronously loads any models passed in through trainingOptions
|
||||
*
|
||||
* @param trainingOptions specify models and checkpoints to load into the Training Session
|
||||
* @param sessionOptions specify configuration for training session behavior
|
||||
*
|
||||
* @returns Promise that resolves to a TrainingSession object
|
||||
*/
|
||||
create(
|
||||
trainingOptions: TrainingSessionCreateOptions,
|
||||
sessionOptions?: InferenceSession.SessionOptions,
|
||||
): Promise<TrainingSession>;
|
||||
|
||||
// #endregion
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/naming-convention
|
||||
export const TrainingSession: TrainingSessionFactory = TrainingSessionImpl;
|
||||
Loading…
Reference in a new issue