mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
### Description Fix a few warnings in typedoc (for generating JS API): ``` [warning] The signature TrainingSession.loadParametersBuffer has an @param with name "buffer", which was not used. [warning] NonTensorType, defined in ./lib/onnx-value.ts, is referenced by OnnxValue but not included in the documentation. [warning] TensorFactory, defined in ./lib/tensor-factory.ts, is referenced by Tensor but not included in the documentation. [warning] ExternalDataFileType, defined in ./lib/onnx-model.ts, is referenced by InferenceSession.SessionOptions.externalData but not included in the documentation. [warning] TensorToDataUrlOptions, defined in ./lib/tensor-conversion.ts, is referenced by Tensor.toDataURL.toDataURL.options but not included in the documentation. [warning] TensorToImageDataOptions, defined in ./lib/tensor-conversion.ts, is referenced by Tensor.toImageData.toImageData.options but not included in the documentation. [warning] Failed to resolve link to "GpuBufferType" in comment for Env.WebGpuFlags.adapter. [warning] Failed to resolve link to "GpuBufferType" in comment for Env.WebGpuFlags.device. ``` Changes highlighted: - Merge `CoreMlExecutionProviderOption` and `CoreMLExecutionProviderOption`. They expose 2 set of different options for React-native and ORT nodejs binding. This should be fixed in future. - Fix a few inconsistency of names between JSDoc and parameters - Fix broken type links - Exclude trace functions
196 lines
7 KiB
TypeScript
196 lines
7 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {InferenceSession} from './inference-session.js';
|
|
import {OnnxValue} from './onnx-value.js';
|
|
import {TrainingSession as TrainingSessionImpl} from './training-session-impl.js';
|
|
|
|
/* eslint-disable @typescript-eslint/no-redeclare */
|
|
|
|
export declare namespace TrainingSession {
|
|
/**
|
|
* Either URI file path (string) or Uint8Array containing model or checkpoint information.
|
|
*/
|
|
type UriOrBuffer = string|Uint8Array;
|
|
}
|
|
|
|
/**
|
|
* Represent a runtime instance of an ONNX training session,
|
|
* which contains a model that can be trained, and, optionally,
|
|
* an eval and optimizer model.
|
|
*/
|
|
export interface TrainingSession {
|
|
// #region run()
|
|
|
|
/**
|
|
* Lazily resets the gradients of all trainable parameters to zero. Should happen after the invocation of
|
|
* runOptimizerStep.
|
|
*/
|
|
lazyResetGrad(): Promise<void>;
|
|
|
|
/**
|
|
* Run TrainStep asynchronously with the given feeds and options.
|
|
*
|
|
* @param feeds - Representation of the model input. See type description of `InferenceSession.InputType` for
|
|
detail.
|
|
* @param options - Optional. A set of options that controls the behavior of model training.
|
|
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values.
|
|
*/
|
|
runTrainStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions):
|
|
Promise<InferenceSession.ReturnType>;
|
|
|
|
/**
|
|
* Run a single train step with the given inputs and options.
|
|
*
|
|
* @param feeds - Representation of the model input.
|
|
* @param fetches - Representation of the model output.
|
|
* detail.
|
|
* @param options - Optional. A set of options that controls the behavior of model training.
|
|
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
|
|
values.
|
|
*/
|
|
runTrainStep(
|
|
feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType,
|
|
options?: InferenceSession.RunOptions): Promise<InferenceSession.ReturnType>;
|
|
|
|
/**
|
|
* Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model.
|
|
*
|
|
* @param options - Optional. A set of options that controls the behavior of model optimizing.
|
|
*/
|
|
runOptimizerStep(options?: InferenceSession.RunOptions): Promise<void>;
|
|
|
|
/**
|
|
* Run a single eval step with the given inputs and options using the eval model.
|
|
*
|
|
* @param feeds - Representation of the model input.
|
|
* @param options - Optional. A set of options that controls the behavior of model eval step.
|
|
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
|
|
values.
|
|
*/
|
|
runEvalStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions):
|
|
Promise<InferenceSession.ReturnType>;
|
|
|
|
/**
|
|
* Run a single eval step with the given inputs and options using the eval model.
|
|
*
|
|
* @param feeds - Representation of the model input.
|
|
* @param fetches - Representation of the model output.
|
|
* detail.
|
|
* @param options - Optional. A set of options that controls the behavior of model eval step.
|
|
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
|
|
values.
|
|
*/
|
|
runEvalStep(
|
|
feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType,
|
|
options?: InferenceSession.RunOptions): Promise<InferenceSession.ReturnType>;
|
|
|
|
// #endregion
|
|
|
|
// #region copy parameters
|
|
|
|
/**
|
|
* Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of
|
|
* the parameters) elements of all the parameters in the training state.
|
|
*
|
|
* @param trainableOnly - When set to true, the size is calculated for trainable params only. Default value is true.
|
|
*/
|
|
getParametersSize(trainableOnly: boolean): Promise<number>;
|
|
|
|
/**
|
|
* Copies parameter values from the given buffer to the training state. Currently, only supporting models with
|
|
* parameters of type Float32.
|
|
*
|
|
* @param buffer - A Uint8Array representation of Float32 parameters.
|
|
* @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true.
|
|
*/
|
|
loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise<void>;
|
|
|
|
/**
|
|
* Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning.
|
|
* Currently, only supporting models with parameters of type Float32.
|
|
*
|
|
* @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters
|
|
* for which requires_grad is set to true. Default value is true.
|
|
* @returns A promise that resolves to a Float32 OnnxValue of the requested parameters.
|
|
*/
|
|
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
|
|
// #endregion
|
|
|
|
// #region release()
|
|
|
|
/**
|
|
* Release the inference session and the underlying resources.
|
|
*/
|
|
release(): Promise<void>;
|
|
// #endregion
|
|
|
|
// #region metadata
|
|
|
|
/**
|
|
* Get input names of the loaded training model.
|
|
*/
|
|
readonly trainingInputNames: readonly string[];
|
|
|
|
/**
|
|
* Get output names of the loaded training model.
|
|
*/
|
|
readonly trainingOutputNames: readonly string[];
|
|
|
|
/**
|
|
* Get input names of the loaded eval model. Is an empty array if no eval model is loaded.
|
|
*/
|
|
readonly evalInputNames: readonly string[];
|
|
|
|
/**
|
|
* Get output names of the loaded eval model. Is an empty array if no eval model is loaded.
|
|
*/
|
|
readonly evalOutputNames: readonly string[];
|
|
|
|
// #endregion
|
|
}
|
|
|
|
/**
|
|
* Represents the optional parameters that can be passed into the TrainingSessionFactory.
|
|
*/
|
|
export interface TrainingSessionCreateOptions {
|
|
/**
|
|
* URI or buffer for a .ckpt file that contains the checkpoint for the training model.
|
|
*/
|
|
checkpointState: TrainingSession.UriOrBuffer;
|
|
/**
|
|
* URI or buffer for the .onnx training file.
|
|
*/
|
|
trainModel: TrainingSession.UriOrBuffer;
|
|
/**
|
|
* Optional. URI or buffer for the .onnx optimizer model file.
|
|
*/
|
|
optimizerModel?: TrainingSession.UriOrBuffer;
|
|
/**
|
|
* Optional. URI or buffer for the .onnx eval model file.
|
|
*/
|
|
evalModel?: TrainingSession.UriOrBuffer;
|
|
}
|
|
|
|
/**
|
|
* Defines method overload possibilities for creating a TrainingSession.
|
|
*/
|
|
export interface TrainingSessionFactory {
|
|
// #region create()
|
|
|
|
/**
|
|
* Creates a new TrainingSession and asynchronously loads any models passed in through trainingOptions
|
|
*
|
|
* @param trainingOptions specify models and checkpoints to load into the Training Session
|
|
* @param sessionOptions specify configuration for training session behavior
|
|
*
|
|
* @returns Promise that resolves to a TrainingSession object
|
|
*/
|
|
create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: InferenceSession.SessionOptions):
|
|
Promise<TrainingSession>;
|
|
|
|
// #endregion
|
|
}
|
|
|
|
// eslint-disable-next-line @typescript-eslint/naming-convention
|
|
export const TrainingSession: TrainingSessionFactory = TrainingSessionImpl;
|