onnxruntime/js/common/lib/training-session-impl.ts
Caroline Zhu 6a5f469d44
Add training interfaces to js/common (#17333)
### Description
Following the design document:
* Added createTrainingSessionHandler to the Backend interface
* All existing Backend implementations throw an error for the new method
createTrainingSessionHandler
* Created TrainingSession namespace, interface, and
TrainingSessionFactory interface
* Created TrainingSessionImpl class implementation 

As methods are implemented, the TrainingSession interface will be extended
or modified accordingly.

### Motivation and Context
Adding the public-facing interfaces to the onnxruntime-common package is
one of the first steps to support ORT training for web bindings.

---------

Co-authored-by: Caroline Zhu <carolinezhu@microsoft.com>
2023-09-29 19:05:10 -07:00

49 lines
1.8 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {TrainingSessionHandler} from './backend.js';
import {InferenceSession as InferenceSession} from './inference-session.js';
import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js';
// Local shorthand for the inference session options type; used as the optional `_sessionOptions`
// parameter type of `TrainingSession.create` below.
type SessionOptions = InferenceSession.SessionOptions;
/**
 * Implementation of the public `TrainingSession` interface.
 *
 * Name lookups and `release()` are forwarded to a `TrainingSessionHandler` (from the backend module).
 * The remaining training operations are still stubs. Each one is `async` and rejects with
 * `Error('Method not implemented.')`.
 */
export class TrainingSession implements TrainingSessionInterface {
  /**
   * Private, so instances can only be constructed from inside this class (for example by the static
   * `create` factory, which is currently a stub).
   *
   * @param handler - handler that the getters and `release()` delegate to; never reassigned.
   */
  private constructor(private readonly handler: TrainingSessionHandler) {}

  /** Input names, as reported by the handler. */
  get inputNames(): readonly string[] {
    return this.handler.inputNames;
  }

  /** Output names, as reported by the handler. */
  get outputNames(): readonly string[] {
    return this.handler.outputNames;
  }

  /**
   * Factory for creating a training session. Not implemented yet.
   *
   * @param _trainingOptions - training-specific creation options (currently unused).
   * @param _sessionOptions - optional session options (currently unused).
   * @throws Error - always. Because the method is `async`, this surfaces as a rejected promise.
   */
  static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions):
      Promise<TrainingSession> {
    // Same message as the other stubs in this class so callers see one consistent error.
    throw new Error('Method not implemented.');
  }

  /**
   * Loads parameters from a byte buffer. Not implemented yet.
   *
   * @param _array - serialized parameter bytes (currently unused).
   * @param _trainableOnly - flag passed by the caller (currently unused).
   * @throws Error - always, as a rejected promise.
   */
  async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
    throw new Error('Method not implemented.');
  }

  /**
   * Returns parameters as a single byte buffer. Not implemented yet.
   *
   * @param _trainableOnly - flag passed by the caller (currently unused).
   * @throws Error - always, as a rejected promise.
   */
  async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
    throw new Error('Method not implemented.');
  }

  /**
   * Runs a single training step. Not implemented yet.
   *
   * Overloads mirror the inference `run` shapes: `(feeds, options?)` or `(feeds, fetches, options?)`.
   */
  runTrainStep(feeds: InferenceSession.OnnxValueMapType, options?: InferenceSession.RunOptions):
      Promise<InferenceSession.OnnxValueMapType>;
  runTrainStep(
      feeds: InferenceSession.OnnxValueMapType, fetches: InferenceSession.FetchesType,
      options?: InferenceSession.RunOptions): Promise<InferenceSession.OnnxValueMapType>;
  async runTrainStep(_feeds: unknown, _fetches?: unknown, _options?: unknown):
      Promise<InferenceSession.OnnxValueMapType> {
    throw new Error('Method not implemented.');
  }

  /**
   * Releases the session by delegating to the handler's `dispose()`.
   *
   * @returns a promise that settles once `handler.dispose()` completes; it rejects if `dispose()` throws
   * or rejects.
   */
  async release(): Promise<void> {
    return this.handler.dispose();
  }
}