[js] remove more unused training types (#22753)

### Description

remove more unused training types
This commit is contained in:
Yulong Wang 2024-12-04 16:44:09 -08:00 committed by GitHub
parent 3975e79303
commit 3234487385
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 0 additions and 531 deletions

View file

@ -198,19 +198,6 @@ module.exports = {
'_OrtReleaseTensor',
'_OrtRun',
'_OrtRunWithBinding',
'_OrtTrainingCopyParametersFromBuffer',
'_OrtTrainingCopyParametersToBuffer',
'_OrtTrainingCreateSession',
'_OrtTrainingEvalStep',
'_OrtTrainingGetModelInputOutputCount',
'_OrtTrainingGetModelInputOutputName',
'_OrtTrainingGetParametersSize',
'_OrtTrainingLazyResetGrad',
'_OrtTrainingLoadCheckpoint',
'_OrtTrainingOptimizerStep',
'_OrtTrainingReleaseCheckpoint',
'_OrtTrainingReleaseSession',
'_OrtTrainingRunTrainStep',
],
},
],

View file

@ -3,7 +3,6 @@
import { InferenceSession } from './inference-session.js';
import { OnnxValue } from './onnx-value.js';
import { TrainingSession } from './training-session.js';
/**
* @ignore
@ -42,33 +41,6 @@ export interface InferenceSessionHandler extends SessionHandler {
): Promise<SessionHandler.ReturnType>;
}
/**
 * Represent a handler instance of a training inference session.
 *
 * @ignore
 */
export interface TrainingSessionHandler extends SessionHandler {
  // Input/output names of the optional eval model (separate from the inherited
  // inputNames/outputNames, which belong to the training model).
  readonly evalInputNames: readonly string[];
  readonly evalOutputNames: readonly string[];

  // Resets the gradients of all trainable parameters to zero (lazily, on next train step).
  lazyResetGrad(): Promise<void>;
  // Runs a single training step with the given feeds/fetches and run options.
  runTrainStep(
    feeds: SessionHandler.FeedsType,
    fetches: SessionHandler.FetchesType,
    options: InferenceSession.RunOptions,
  ): Promise<SessionHandler.ReturnType>;
  // Performs one weight update using the optimizer model.
  runOptimizerStep(options: InferenceSession.RunOptions): Promise<void>;
  // Runs a single evaluation step with the eval model.
  runEvalStep(
    feeds: SessionHandler.FeedsType,
    fetches: SessionHandler.FetchesType,
    options: InferenceSession.RunOptions,
  ): Promise<SessionHandler.ReturnType>;

  // Number of parameter elements in the training state (optionally trainable-only).
  getParametersSize(trainableOnly: boolean): Promise<number>;
  // Copies parameter values from the given buffer into the training state.
  loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise<void>;
  // Copies the model parameters out into a single contiguous OnnxValue.
  getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
}
/**
* Represent a backend that provides implementation of model inferencing.
*
@ -84,14 +56,6 @@ export interface Backend {
uriOrBuffer: string | Uint8Array,
options?: InferenceSession.SessionOptions,
): Promise<InferenceSessionHandler>;
createTrainingSessionHandler?(
checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer,
trainModelUriOrBuffer: TrainingSession.UriOrBuffer,
evalModelUriOrBuffer: TrainingSession.UriOrBuffer,
optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer,
options: InferenceSession.SessionOptions,
): Promise<TrainingSessionHandler>;
}
export { registerBackend } from './backend-impl.js';

View file

@ -14,7 +14,6 @@ export declare namespace Env {
* If not modified, the filename of the .wasm file is:
* - `ort-wasm-simd-threaded.wasm` for default build
* - `ort-wasm-simd-threaded.jsep.wasm` for JSEP build (with WebGPU and WebNN)
* - `ort-training-wasm-simd-threaded.wasm` for training build
*/
wasm?: URL | string;
/**
@ -25,7 +24,6 @@ export declare namespace Env {
* If not modified, the filename of the .mjs file is:
* - `ort-wasm-simd-threaded.mjs` for default build
* - `ort-wasm-simd-threaded.jsep.mjs` for JSEP build (with WebGPU and WebNN)
* - `ort-training-wasm-simd-threaded.mjs` for training build
*/
mjs?: URL | string;
}

View file

@ -26,4 +26,3 @@ export * from './tensor-factory.js';
export * from './trace.js';
export * from './onnx-model.js';
export * from './onnx-value.js';
export * from './training-session.js';

View file

@ -1,273 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { resolveBackendAndExecutionProviders } from './backend-impl.js';
import { SessionHandler, TrainingSessionHandler } from './backend.js';
import { InferenceSession as InferenceSession } from './inference-session.js';
import { OnnxValue } from './onnx-value.js';
import { Tensor } from './tensor.js';
import { TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions } from './training-session.js';
// Local shorthands for the InferenceSession option/IO types used throughout this file.
type SessionOptions = InferenceSession.SessionOptions;
type FeedsType = InferenceSession.FeedsType;
type FetchesType = InferenceSession.FetchesType;
type ReturnType = InferenceSession.ReturnType;
type RunOptions = InferenceSession.RunOptions;

// Error message thrown by `TrainingSession.create` when the resolved backend does not
// implement `createTrainingSessionHandler` (i.e. it is not a training-capable build).
const noBackendErrMsg: string =
  'Training backend could not be resolved. ' + "Make sure you're using the correct configuration & WebAssembly files.";
/**
 * Implementation of the `TrainingSession` interface. Wraps a backend-provided
 * `TrainingSessionHandler` and adds argument validation / type narrowing on top of it.
 */
export class TrainingSession implements TrainingSessionInterface {
  // Instances are created only through the static `create` factory below.
  private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) {
    this.handler = handler;
    this.hasOptimizerModel = hasOptimizerModel;
    this.hasEvalModel = hasEvalModel;
  }
  // Backend-specific handler that performs the actual session operations.
  private handler: TrainingSessionHandler;
  // True when an optimizer model was supplied at creation time.
  private hasOptimizerModel: boolean;
  // True when an eval model was supplied at creation time.
  private hasEvalModel: boolean;

  // Input names of the training model, as reported by the backend handler.
  get trainingInputNames(): readonly string[] {
    return this.handler.inputNames;
  }
  // Output names of the training model, as reported by the backend handler.
  get trainingOutputNames(): readonly string[] {
    return this.handler.outputNames;
  }

  // Input names of the eval model; throws if no eval model was loaded.
  get evalInputNames(): readonly string[] {
    if (this.hasEvalModel) {
      return this.handler.evalInputNames;
    } else {
      throw new Error('This training session has no evalModel loaded.');
    }
  }
  // Output names of the eval model; throws if no eval model was loaded.
  get evalOutputNames(): readonly string[] {
    if (this.hasEvalModel) {
      return this.handler.evalOutputNames;
    } else {
      throw new Error('This training session has no evalModel loaded.');
    }
  }

  /**
   * Creates a TrainingSession by resolving a training-capable backend and delegating
   * handler creation to it. The eval and optimizer models are optional; empty strings
   * are passed to the backend when they are absent.
   *
   * @param trainingOptions - models and checkpoint to load into the session.
   * @param sessionOptions - optional session configuration.
   * @throws Error with `noBackendErrMsg` if the resolved backend has no
   *   `createTrainingSessionHandler` (non-training build).
   */
  static async create(
    trainingOptions: TrainingSessionCreateOptions,
    sessionOptions?: SessionOptions,
  ): Promise<TrainingSession> {
    const evalModel: string | Uint8Array = trainingOptions.evalModel || '';
    const optimizerModel: string | Uint8Array = trainingOptions.optimizerModel || '';
    const options: SessionOptions = sessionOptions || {};

    // resolve backend, update session options with validated EPs, and create session handler
    const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options);
    if (backend.createTrainingSessionHandler) {
      const handler = await backend.createTrainingSessionHandler(
        trainingOptions.checkpointState,
        trainingOptions.trainModel,
        evalModel,
        optimizerModel,
        optionsWithValidatedEPs,
      );
      // hasOptimizerModel/hasEvalModel reflect whether the caller actually supplied them,
      // not the (possibly empty-string) values passed to the backend.
      return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel);
    } else {
      throw new Error(noBackendErrMsg);
    }
  }

  /**
   * Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from
   * the given parameters to SessionHandler.FetchesType and RunOptions.
   *
   * @param inputNames the feeds object is checked that they contain all input names in the provided list of input
   * names.
   * @param outputNames the fetches object is checked that their keys match up with valid names in the list of output
   * names.
   * @param feeds the required input
   * @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object
   * @param arg2 optional RunOptions object.
   * @returns a [fetches, options] pair ready to pass to the handler.
   */
  typeNarrowingForRunStep(
    inputNames: readonly string[],
    outputNames: readonly string[],
    feeds: FeedsType,
    arg1?: FetchesType | RunOptions,
    arg2?: RunOptions,
  ): [SessionHandler.FetchesType, RunOptions] {
    const fetches: { [name: string]: OnnxValue | null } = {};
    let options: RunOptions = {};
    // check inputs
    if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) {
      throw new TypeError(
        "'feeds' must be an object that use input names as keys and OnnxValue as corresponding values.",
      );
    }

    let isFetchesEmpty = true;
    // determine which override is being used
    if (typeof arg1 === 'object') {
      if (arg1 === null) {
        throw new TypeError('Unexpected argument[1]: cannot be null.');
      }
      if (arg1 instanceof Tensor) {
        throw new TypeError("'fetches' cannot be a Tensor");
      }

      if (Array.isArray(arg1)) {
        // arg1 is an array of output names
        if (arg1.length === 0) {
          throw new TypeError("'fetches' cannot be an empty array.");
        }
        isFetchesEmpty = false;
        // output names
        for (const name of arg1) {
          if (typeof name !== 'string') {
            throw new TypeError("'fetches' must be a string array or an object.");
          }
          if (outputNames.indexOf(name) === -1) {
            throw new RangeError(`'fetches' contains invalid output name: ${name}.`);
          }
          fetches[name] = null;
        }

        if (typeof arg2 === 'object' && arg2 !== null) {
          options = arg2;
        } else if (typeof arg2 !== 'undefined') {
          throw new TypeError("'options' must be an object.");
        }
      } else {
        // decide whether arg1 is fetches or options
        // if any output name is present and its value is valid OnnxValue, we consider it fetches
        let isFetches = false;
        const arg1Keys = Object.getOwnPropertyNames(arg1);
        for (const name of outputNames) {
          if (arg1Keys.indexOf(name) !== -1) {
            const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name];
            if (v === null || v instanceof Tensor) {
              isFetches = true;
              isFetchesEmpty = false;
              fetches[name] = v;
            }
          }
        }

        if (isFetches) {
          if (typeof arg2 === 'object' && arg2 !== null) {
            options = arg2;
          } else if (typeof arg2 !== 'undefined') {
            throw new TypeError("'options' must be an object.");
          }
        } else {
          options = arg1 as RunOptions;
        }
      }
    } else if (typeof arg1 !== 'undefined') {
      throw new TypeError("Unexpected argument[1]: must be 'fetches' or 'options'.");
    }

    // check if all inputs are in feed
    for (const name of inputNames) {
      if (typeof feeds[name] === 'undefined') {
        throw new Error(`input '${name}' is missing in 'feeds'.`);
      }
    }

    // if no fetches is specified, we use the full output names list
    if (isFetchesEmpty) {
      for (const name of outputNames) {
        fetches[name] = null;
      }
    }

    return [fetches, options];
  }

  /**
   * Helper method for runTrainStep and any other runStep methods. Takes the ReturnType result from the SessionHandler
   * and changes it into a map of Tensors.
   *
   * @param results - raw name-to-value map returned by the handler.
   * @returns the same map with every value guaranteed to be a `Tensor` instance.
   */
  convertHandlerReturnTypeToMapOfTensors(results: SessionHandler.ReturnType): ReturnType {
    const returnValue: { [name: string]: OnnxValue } = {};
    for (const key in results) {
      if (Object.hasOwnProperty.call(results, key)) {
        const result = results[key];
        if (result instanceof Tensor) {
          returnValue[key] = result;
        } else {
          // wrap plain tensor-like data (type/data/dims) in a real Tensor
          returnValue[key] = new Tensor(result.type, result.data, result.dims);
        }
      }
    }
    return returnValue;
  }

  // Resets gradients of all trainable parameters to zero (see interface docs).
  async lazyResetGrad(): Promise<void> {
    await this.handler.lazyResetGrad();
  }

  runTrainStep(feeds: FeedsType, options?: RunOptions): Promise<ReturnType>;
  runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise<ReturnType>;
  async runTrainStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise<ReturnType> {
    // narrow the overloaded arguments against the training model's IO names
    const [fetches, options] = this.typeNarrowingForRunStep(
      this.trainingInputNames,
      this.trainingOutputNames,
      feeds,
      arg1,
      arg2,
    );
    const results = await this.handler.runTrainStep(feeds, fetches, options);
    return this.convertHandlerReturnTypeToMapOfTensors(results);
  }

  // Performs one optimizer (weight-update) step; requires an optimizer model.
  async runOptimizerStep(options?: InferenceSession.RunOptions | undefined): Promise<void> {
    if (this.hasOptimizerModel) {
      await this.handler.runOptimizerStep(options || {});
    } else {
      throw new Error('This TrainingSession has no OptimizerModel loaded.');
    }
  }

  runEvalStep(feeds: FeedsType, options?: RunOptions | undefined): Promise<ReturnType>;
  runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions | undefined): Promise<ReturnType>;
  async runEvalStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise<ReturnType> {
    if (this.hasEvalModel) {
      // narrow the overloaded arguments against the eval model's IO names
      const [fetches, options] = this.typeNarrowingForRunStep(
        this.evalInputNames,
        this.evalOutputNames,
        feeds,
        arg1,
        arg2,
      );
      const results = await this.handler.runEvalStep(feeds, fetches, options);
      return this.convertHandlerReturnTypeToMapOfTensors(results);
    } else {
      throw new Error('This TrainingSession has no EvalModel loaded.');
    }
  }

  // Number of parameter elements in the training state (trainable-only by default).
  async getParametersSize(trainableOnly = true): Promise<number> {
    return this.handler.getParametersSize(trainableOnly);
  }

  // Loads parameter values from a Uint8Array; only Float32 parameters are supported,
  // hence the 4-bytes-per-element size check below.
  async loadParametersBuffer(array: Uint8Array, trainableOnly = true): Promise<void> {
    const paramsSize = await this.getParametersSize(trainableOnly);
    // checking that the size of the Uint8Array is equivalent to the byte length of a Float32Array of the number
    // of parameters
    if (array.length !== 4 * paramsSize) {
      throw new Error(
        'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' +
          'the model. Please use getParametersSize method to check.',
      );
    }
    return this.handler.loadParametersBuffer(array, trainableOnly);
  }

  // Copies the model parameters to a contiguous OnnxValue buffer.
  async getContiguousParameters(trainableOnly = true): Promise<OnnxValue> {
    return this.handler.getContiguousParameters(trainableOnly);
  }

  // Releases the underlying handler and its resources.
  async release(): Promise<void> {
    return this.handler.dispose();
  }
}

View file

@ -1,206 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { InferenceSession } from './inference-session.js';
import { OnnxValue } from './onnx-value.js';
import { TrainingSession as TrainingSessionImpl } from './training-session-impl.js';
/* eslint-disable @typescript-eslint/no-redeclare */

// Declaration-merged namespace: holds types associated with the `TrainingSession`
// interface/const pair declared below in this file.
export declare namespace TrainingSession {
  /**
   * Either URI file path (string) or Uint8Array containing model or checkpoint information.
   */
  type UriOrBuffer = string | Uint8Array;
}
/**
 * Represent a runtime instance of an ONNX training session,
 * which contains a model that can be trained, and, optionally,
 * an eval and optimizer model.
 */
export interface TrainingSession {
  // #region run()

  /**
   * Lazily resets the gradients of all trainable parameters to zero. Should happen after the invocation of
   * runOptimizerStep.
   */
  lazyResetGrad(): Promise<void>;

  /**
   * Run TrainStep asynchronously with the given feeds and options.
   *
   * @param feeds - Representation of the model input. See type description of `InferenceSession.InputType` for
   * detail.
   * @param options - Optional. A set of options that controls the behavior of model training.
   * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values.
   */
  runTrainStep(
    feeds: InferenceSession.FeedsType,
    options?: InferenceSession.RunOptions,
  ): Promise<InferenceSession.ReturnType>;

  /**
   * Run a single train step with the given inputs and options.
   *
   * @param feeds - Representation of the model input.
   * @param fetches - Representation of the model output. See type description of `InferenceSession.OutputType` for
   * detail.
   * @param options - Optional. A set of options that controls the behavior of model training.
   * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
   * values.
   */
  runTrainStep(
    feeds: InferenceSession.FeedsType,
    fetches: InferenceSession.FetchesType,
    options?: InferenceSession.RunOptions,
  ): Promise<InferenceSession.ReturnType>;

  /**
   * Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model.
   *
   * @param options - Optional. A set of options that controls the behavior of model optimizing.
   */
  runOptimizerStep(options?: InferenceSession.RunOptions): Promise<void>;

  /**
   * Run a single eval step with the given inputs and options using the eval model.
   *
   * @param feeds - Representation of the model input.
   * @param options - Optional. A set of options that controls the behavior of model eval step.
   * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
   * values.
   */
  runEvalStep(
    feeds: InferenceSession.FeedsType,
    options?: InferenceSession.RunOptions,
  ): Promise<InferenceSession.ReturnType>;

  /**
   * Run a single eval step with the given inputs and options using the eval model.
   *
   * @param feeds - Representation of the model input.
   * @param fetches - Representation of the model output. See type description of `InferenceSession.OutputType` for
   * detail.
   * @param options - Optional. A set of options that controls the behavior of model eval step.
   * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
   * values.
   */
  runEvalStep(
    feeds: InferenceSession.FeedsType,
    fetches: InferenceSession.FetchesType,
    options?: InferenceSession.RunOptions,
  ): Promise<InferenceSession.ReturnType>;

  // #endregion

  // #region copy parameters

  /**
   * Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of
   * the parameters) elements of all the parameters in the training state.
   *
   * @param trainableOnly - When set to true, the size is calculated for trainable params only. Default value is true.
   */
  getParametersSize(trainableOnly: boolean): Promise<number>;

  /**
   * Copies parameter values from the given buffer to the training state. Currently, only supporting models with
   * parameters of type Float32.
   *
   * @param buffer - A Uint8Array representation of Float32 parameters.
   * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true.
   */
  loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise<void>;

  /**
   * Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning.
   * Currently, only supporting models with parameters of type Float32.
   *
   * @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters
   * for which requires_grad is set to true. Default value is true.
   * @returns A promise that resolves to a Float32 OnnxValue of the requested parameters.
   */
  getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
  // #endregion

  // #region release()

  /**
   * Release the inference session and the underlying resources.
   */
  release(): Promise<void>;
  // #endregion

  // #region metadata

  /**
   * Get input names of the loaded training model.
   */
  readonly trainingInputNames: readonly string[];

  /**
   * Get output names of the loaded training model.
   */
  readonly trainingOutputNames: readonly string[];

  /**
   * Get input names of the loaded eval model. Is an empty array if no eval model is loaded.
   */
  readonly evalInputNames: readonly string[];

  /**
   * Get output names of the loaded eval model. Is an empty array if no eval model is loaded.
   */
  readonly evalOutputNames: readonly string[];

  // #endregion
}
/**
 * Represents the optional parameters that can be passed into the TrainingSessionFactory.
 * Only the checkpoint and the training model are required; the optimizer and eval models
 * are optional.
 */
export interface TrainingSessionCreateOptions {
  /**
   * URI or buffer for a .ckpt file that contains the checkpoint for the training model.
   */
  checkpointState: TrainingSession.UriOrBuffer;
  /**
   * URI or buffer for the .onnx training file.
   */
  trainModel: TrainingSession.UriOrBuffer;
  /**
   * Optional. URI or buffer for the .onnx optimizer model file.
   */
  optimizerModel?: TrainingSession.UriOrBuffer;
  /**
   * Optional. URI or buffer for the .onnx eval model file.
   */
  evalModel?: TrainingSession.UriOrBuffer;
}
/**
 * Defines method overload possibilities for creating a TrainingSession.
 */
export interface TrainingSessionFactory {
  // #region create()

  /**
   * Creates a new TrainingSession and asynchronously loads any models passed in through trainingOptions
   *
   * @param trainingOptions specify models and checkpoints to load into the Training Session
   * @param sessionOptions specify configuration for training session behavior
   *
   * @returns Promise that resolves to a TrainingSession object
   */
  create(
    trainingOptions: TrainingSessionCreateOptions,
    sessionOptions?: InferenceSession.SessionOptions,
  ): Promise<TrainingSession>;

  // #endregion
}
// Export the implementation class as the runtime value for the `TrainingSession`
// interface/namespace declared above, so callers can invoke `TrainingSession.create(...)`.
// eslint-disable-next-line @typescript-eslint/naming-convention
export const TrainingSession: TrainingSessionFactory = TrainingSessionImpl;