onnxruntime/js/common/lib/training-session.ts
Yulong Wang abdc31de40
[js] change default formatter for JavaScript/TypeScript from clang-format to Prettier (#21728)
### Description

See
454996d496
for manual changes (excluding auto-generated formatting changes)

### Why

Because the tooling for the old clang-format is out of date, which reduces
development efficiency.

- The NPM package `clang-format` is already in maintenance mode and has not
been updated in 2 years.
- The VSCode extension for clang-format has not been maintained for a while,
and a recent Node.js security update broke it completely on
Windows.

No one in the community seems interested in fixing these issues.

Choose Prettier as it is the most popular TS/JS formatter.

### How to merge

It's easy to break the build:
- Be careful of any new commits on main not included in this PR.
- Be careful that, after this PR is merged, other PRs that already passed
CI can still be merged without being reformatted, which can break the build.

So, make sure there are no new commits before merging this one, and
invalidate JS PRs that already passed CI, forcing them to merge with the latest main.
2024-08-14 16:51:22 -07:00

206 lines
7 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { InferenceSession } from './inference-session.js';
import { OnnxValue } from './onnx-value.js';
import { TrainingSession as TrainingSessionImpl } from './training-session-impl.js';
/* eslint-disable @typescript-eslint/no-redeclare */
export declare namespace TrainingSession {
  /**
   * Source of a model or checkpoint: a string holding a URI or file path, or a Uint8Array holding the raw
   * model or checkpoint information.
   */
  export type UriOrBuffer = string | Uint8Array;
}
/**
 * Represents a runtime instance of an ONNX training session, which contains a model that can be trained and,
 * optionally, an eval model and an optimizer model.
 */
export interface TrainingSession {
  // #region run()

  /**
   * Lazily resets the gradients of all trainable parameters to zero. Should happen after the invocation of
   * runOptimizerStep.
   */
  lazyResetGrad(): Promise<void>;

  /**
   * Runs a single train step asynchronously with the given feeds and options.
   *
   * @param feeds - Representation of the model input. See type description of `InferenceSession.FeedsType` for
   * detail.
   * @param options - Optional. A set of options that controls the behavior of model training.
   * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values.
   */
  runTrainStep(
    feeds: InferenceSession.FeedsType,
    options?: InferenceSession.RunOptions,
  ): Promise<InferenceSession.ReturnType>;

  /**
   * Runs a single train step asynchronously with the given feeds, fetches, and options.
   *
   * @param feeds - Representation of the model input. See type description of `InferenceSession.FeedsType` for
   * detail.
   * @param fetches - Representation of the model output. See type description of `InferenceSession.FetchesType` for
   * detail.
   * @param options - Optional. A set of options that controls the behavior of model training.
   * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values.
   */
  runTrainStep(
    feeds: InferenceSession.FeedsType,
    fetches: InferenceSession.FetchesType,
    options?: InferenceSession.RunOptions,
  ): Promise<InferenceSession.ReturnType>;

  /**
   * Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model.
   *
   * @param options - Optional. A set of options that controls the behavior of model optimizing.
   */
  runOptimizerStep(options?: InferenceSession.RunOptions): Promise<void>;

  /**
   * Runs a single eval step asynchronously with the given feeds and options, using the eval model.
   *
   * @param feeds - Representation of the model input. See type description of `InferenceSession.FeedsType` for
   * detail.
   * @param options - Optional. A set of options that controls the behavior of model eval step.
   * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values.
   */
  runEvalStep(
    feeds: InferenceSession.FeedsType,
    options?: InferenceSession.RunOptions,
  ): Promise<InferenceSession.ReturnType>;

  /**
   * Runs a single eval step asynchronously with the given feeds, fetches, and options, using the eval model.
   *
   * @param feeds - Representation of the model input. See type description of `InferenceSession.FeedsType` for
   * detail.
   * @param fetches - Representation of the model output. See type description of `InferenceSession.FetchesType` for
   * detail.
   * @param options - Optional. A set of options that controls the behavior of model eval step.
   * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values.
   */
  runEvalStep(
    feeds: InferenceSession.FeedsType,
    fetches: InferenceSession.FetchesType,
    options?: InferenceSession.RunOptions,
  ): Promise<InferenceSession.ReturnType>;

  // #endregion

  // #region copy parameters

  // NOTE(review): `trainableOnly` is a required parameter in the signatures below even though their docs state a
  // default of true; that default presumably only applies when untyped (JavaScript) callers omit the argument —
  // confirm against the implementation.

  /**
   * Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of
   * the parameters) elements of all the parameters in the training state.
   *
   * @param trainableOnly - When set to true, the size is calculated for trainable params only. Default value is true.
   */
  getParametersSize(trainableOnly: boolean): Promise<number>;

  /**
   * Copies parameter values from the given buffer to the training state. Currently, only supporting models with
   * parameters of type Float32.
   *
   * @param buffer - A Uint8Array representation of Float32 parameters.
   * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true.
   */
  loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise<void>;

  /**
   * Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning.
   * Currently, only supporting models with parameters of type Float32.
   *
   * @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters
   * for which requires_grad is set to true. Default value is true.
   * @returns A promise that resolves to a Float32 OnnxValue of the requested parameters.
   */
  getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;

  // #endregion

  // #region release()

  /**
   * Releases the training session and the underlying resources.
   */
  release(): Promise<void>;

  // #endregion

  // #region metadata

  /**
   * Get input names of the loaded training model.
   */
  readonly trainingInputNames: readonly string[];

  /**
   * Get output names of the loaded training model.
   */
  readonly trainingOutputNames: readonly string[];

  /**
   * Get input names of the loaded eval model. Is an empty array if no eval model is loaded.
   */
  readonly evalInputNames: readonly string[];

  /**
   * Get output names of the loaded eval model. Is an empty array if no eval model is loaded.
   */
  readonly evalOutputNames: readonly string[];

  // #endregion
}
/**
 * Represents the models and checkpoint passed as `trainingOptions` to `TrainingSessionFactory.create()`.
 * The checkpoint state and the training model are required; the optimizer model and the eval model are optional.
 */
export interface TrainingSessionCreateOptions {
  /**
   * URI or buffer for a .ckpt file that contains the checkpoint for the training model.
   */
  checkpointState: TrainingSession.UriOrBuffer;

  /**
   * URI or buffer for the .onnx training model file.
   */
  trainModel: TrainingSession.UriOrBuffer;

  /**
   * Optional. URI or buffer for the .onnx optimizer model file.
   */
  optimizerModel?: TrainingSession.UriOrBuffer;

  /**
   * Optional. URI or buffer for the .onnx eval model file.
   */
  evalModel?: TrainingSession.UriOrBuffer;
}
/**
 * Declares the overloads available for constructing a TrainingSession.
 */
export interface TrainingSessionFactory {
  // #region create()

  /**
   * Creates a TrainingSession, asynchronously loading the models and checkpoint supplied in `trainingOptions`.
   *
   * @param trainingOptions - The models and checkpoint to load into the training session.
   * @param sessionOptions - Optional. Configuration that controls the behavior of the training session.
   *
   * @returns A promise that resolves to a TrainingSession object.
   */
  create(
    trainingOptions: TrainingSessionCreateOptions,
    sessionOptions?: InferenceSession.SessionOptions,
  ): Promise<TrainingSession>;

  // #endregion
}
/**
 * Runtime value used to create training sessions. It is the implementation imported from
 * './training-session-impl.js', exposed under the `TrainingSessionFactory` type. It shares its name with the
 * `TrainingSession` interface and namespace declared in this file, so the single identifier `TrainingSession`
 * refers to the type when used in a type position and to this factory when used as a value.
 */
// The directive below suppresses the naming-convention lint rule for this PascalCase constant.
// eslint-disable-next-line @typescript-eslint/naming-convention
export const TrainingSession: TrainingSessionFactory = TrainingSessionImpl;