[js/web/training] Implemented runEvalStep & runOptimizerStep (#18259)

### Description
* implemented runEvalStep and runOptimizerStep
* added hasEvalModel and hasOptimizerModel boolean fields in
TrainingSession representation
* added evalInputNames and evalOutputNames fields to
TrainingSessionHandler & TrainingSession
* removed the inputEncodedNames and outputEncodedNames fields from
TrainingSessionHandler -- since none of the training methods require the
input names and output names as parameters, there's no need to store
them.

### Motivation and Context
* part of the work for implementing web bindings for training
* previous PR: #18250

---------

Co-authored-by: Ashwini Khade <askhade@microsoft.com>
This commit is contained in:
Caroline Zhu 2023-12-04 13:37:14 -08:00 committed by GitHub
parent 5353adcde3
commit c02a386145
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 242 additions and 61 deletions

View file

@ -45,9 +45,16 @@ export interface InferenceSessionHandler extends SessionHandler {
* @ignore
*/
export interface TrainingSessionHandler extends SessionHandler {
readonly evalInputNames: readonly string[];
readonly evalOutputNames: readonly string[];
runTrainStep(
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;
runOptimizerStep(options: InferenceSession.RunOptions): Promise<void>;
runEvalStep(
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;
getParametersSize(trainableOnly: boolean): Promise<number>;
loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;

View file

@ -18,18 +18,37 @@ const noBackendErrMsg: string = 'Training backend could not be resolved. ' +
'Make sure you\'re using the correct configuration & WebAssembly files.';
export class TrainingSession implements TrainingSessionInterface {
private constructor(handler: TrainingSessionHandler) {
private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) {
this.handler = handler;
this.hasOptimizerModel = hasOptimizerModel;
this.hasEvalModel = hasEvalModel;
}
private handler: TrainingSessionHandler;
private hasOptimizerModel: boolean;
private hasEvalModel: boolean;
get inputNames(): readonly string[] {
get trainingInputNames(): readonly string[] {
return this.handler.inputNames;
}
get outputNames(): readonly string[] {
get trainingOutputNames(): readonly string[] {
return this.handler.outputNames;
}
/**
 * Input names of the eval model, read from the underlying session handler.
 *
 * @throws Error when this session was constructed with `hasEvalModel` set to false.
 */
get evalInputNames(): readonly string[] {
if (this.hasEvalModel) {
return this.handler.evalInputNames;
} else {
throw new Error('This training session has no evalModel loaded.');
}
}
/**
 * Output names of the eval model, read from the underlying session handler.
 *
 * @throws Error when this session was constructed with `hasEvalModel` set to false.
 */
get evalOutputNames(): readonly string[] {
if (this.hasEvalModel) {
return this.handler.evalOutputNames;
} else {
throw new Error('This training session has no evalModel loaded.');
}
}
static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions):
Promise<TrainingSession> {
const evalModel: string|Uint8Array = trainingOptions.evalModel || '';
@ -43,7 +62,7 @@ export class TrainingSession implements TrainingSessionInterface {
if (backend.createTrainingSessionHandler) {
const handler = await backend.createTrainingSessionHandler(
trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options);
return new TrainingSession(handler);
return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel);
} else {
throw new Error(noBackendErrMsg);
}
@ -53,13 +72,18 @@ export class TrainingSession implements TrainingSessionInterface {
* Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from
* the given parameters to SessionHandler.FetchesType and RunOptions.
*
* @param inputNames the list of required input names; 'feeds' is checked to contain an entry for every name in
* this list.
* @param outputNames the list of valid output names; the names given in 'fetches' are checked against this
* list.
* @param feeds the required input
* @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object
* @param arg2 optional RunOptions object.
* @returns a tuple of the narrowed SessionHandler.FetchesType object and the RunOptions object.
*/
typeNarrowingForRunStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions):
[SessionHandler.FetchesType, RunOptions] {
typeNarrowingForRunStep(
inputNames: readonly string[], outputNames: readonly string[], feeds: FeedsType, arg1?: FetchesType|RunOptions,
arg2?: RunOptions): [SessionHandler.FetchesType, RunOptions] {
const fetches: {[name: string]: OnnxValue|null} = {};
let options: RunOptions = {};
// check inputs
@ -88,7 +112,7 @@ export class TrainingSession implements TrainingSessionInterface {
if (typeof name !== 'string') {
throw new TypeError('\'fetches\' must be a string array or an object.');
}
if (this.outputNames.indexOf(name) === -1) {
if (outputNames.indexOf(name) === -1) {
throw new RangeError(`'fetches' contains invalid output name: ${name}.`);
}
fetches[name] = null;
@ -104,7 +128,7 @@ export class TrainingSession implements TrainingSessionInterface {
// if any output name is present and its value is valid OnnxValue, we consider it fetches
let isFetches = false;
const arg1Keys = Object.getOwnPropertyNames(arg1);
for (const name of this.outputNames) {
for (const name of outputNames) {
if (arg1Keys.indexOf(name) !== -1) {
const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name];
if (v === null || v instanceof Tensor) {
@ -130,7 +154,7 @@ export class TrainingSession implements TrainingSessionInterface {
}
// check if all inputs are in feed
for (const name of this.inputNames) {
for (const name of inputNames) {
if (typeof feeds[name] === 'undefined') {
throw new Error(`input '${name}' is missing in 'feeds'.`);
}
@ -138,7 +162,7 @@ export class TrainingSession implements TrainingSessionInterface {
// if no fetches is specified, we use the full output names list
if (isFetchesEmpty) {
for (const name of this.outputNames) {
for (const name of outputNames) {
fetches[name] = null;
}
}
@ -171,11 +195,33 @@ export class TrainingSession implements TrainingSessionInterface {
runTrainStep(feeds: FeedsType, options?: RunOptions): Promise<ReturnType>;
runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise<ReturnType>;
async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise<ReturnType> {
const [fetches, options] = this.typeNarrowingForRunStep(feeds, arg1, arg2);
const [fetches, options] =
this.typeNarrowingForRunStep(this.trainingInputNames, this.trainingOutputNames, feeds, arg1, arg2);
const results = await this.handler.runTrainStep(feeds, fetches, options);
return this.convertHandlerReturnTypeToMapOfTensors(results);
}
/**
 * Runs a single optimizer step by delegating to the session handler.
 *
 * @param options - Optional run options. When omitted (or otherwise falsy), an empty options object is passed to
 * the handler.
 * @throws Error when this session was constructed with `hasOptimizerModel` set to false.
 */
async runOptimizerStep(options?: InferenceSession.RunOptions|undefined): Promise<void> {
if (this.hasOptimizerModel) {
await this.handler.runOptimizerStep(options || {});
} else {
throw new Error('This TrainingSession has no OptimizerModel loaded.');
}
}
/**
 * Runs a single eval step by delegating to the session handler.
 *
 * The second argument may be either the fetches or the run options; `typeNarrowingForRunStep` decides which and
 * validates `feeds`/`fetches` against the eval model's input and output names.
 *
 * @param feeds - the model inputs.
 * @param arg1 - either the fetches or the run options.
 * @param arg2 - the run options when `arg1` holds the fetches.
 * @returns the handler's results converted by `convertHandlerReturnTypeToMapOfTensors`.
 * @throws Error when this session was constructed with `hasEvalModel` set to false.
 */
runEvalStep(feeds: FeedsType, options?: RunOptions|undefined): Promise<ReturnType>;
runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions|undefined): Promise<ReturnType>;
async runEvalStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise<ReturnType> {
if (this.hasEvalModel) {
// validate against the eval model's names, not the training model's
const [fetches, options] =
this.typeNarrowingForRunStep(this.evalInputNames, this.evalOutputNames, feeds, arg1, arg2);
const results = await this.handler.runEvalStep(feeds, fetches, options);
return this.convertHandlerReturnTypeToMapOfTensors(results);
} else {
throw new Error('This TrainingSession has no EvalModel loaded.');
}
}
async getParametersSize(trainableOnly = true): Promise<number> {
return this.handler.getParametersSize(trainableOnly);
}

View file

@ -39,7 +39,7 @@ export interface TrainingSession {
* @param feeds - Representation of the model input.
* @param fetches - Representation of the model output. See type description of `InferenceSession.FetchesType` for
* detail.
* @param options - Optional. A set of options that controls the behavior of model inference.
* @param options - Optional. A set of options that controls the behavior of model training.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
* values.
*/
@ -47,6 +47,38 @@ export interface TrainingSession {
feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType,
options?: InferenceSession.RunOptions): Promise<InferenceSession.ReturnType>;
/**
* Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model.
*
* @param options - Optional. A set of options that controls the behavior of model optimizing.
*/
runOptimizerStep(options?: InferenceSession.RunOptions): Promise<void>;
/**
* Run a single eval step with the given inputs and options using the eval model.
*
* @param feeds - Representation of the model input.
* @param options - Optional. A set of options that controls the behavior of model eval step.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
* values.
*/
runEvalStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions):
Promise<InferenceSession.ReturnType>;
/**
* Run a single eval step with the given inputs and options using the eval model.
*
* @param feeds - Representation of the model input.
* @param fetches - Representation of the model output. See type description of `InferenceSession.FetchesType` for
* detail.
* @param options - Optional. A set of options that controls the behavior of model eval step.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
* values.
*/
runEvalStep(
feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType,
options?: InferenceSession.RunOptions): Promise<InferenceSession.ReturnType>;
// #endregion
// #region copy parameters
@ -90,14 +122,25 @@ export interface TrainingSession {
// #region metadata
/**
* Get input names of the loaded model.
* Get input names of the loaded training model.
*/
readonly inputNames: readonly string[];
readonly trainingInputNames: readonly string[];
/**
* Get output names of the loaded model.
* Get output names of the loaded training model.
*/
readonly outputNames: readonly string[];
readonly trainingOutputNames: readonly string[];
/**
* Get input names of the loaded eval model. The TrainingSession implementation throws an error if no eval model is
* loaded.
*/
readonly evalInputNames: readonly string[];
/**
* Get output names of the loaded eval model. The TrainingSession implementation throws an error if no eval model is
* loaded.
*/
readonly evalOutputNames: readonly string[];
// #endregion
}

View file

@ -6,7 +6,7 @@ import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessio
import {SerializableModeldata, TensorMetadata} from './proxy-messages';
import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference';
import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getModelInputOutputNames, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runEvalStep, runOptimizerStep, runTrainStep} from './wasm-training-core-impl';
export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
private sessionId: number;
@ -15,8 +15,8 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
inputNames: string[];
outputNames: string[];
inputEncodedNames: number[];
outputEncodedNames: number[];
evalInputNames: string[] = [];
evalOutputNames: string[] = [];
async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise<SerializableModeldata> {
let buffer: Uint8Array;
@ -51,8 +51,12 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
}
this.checkpointId = createCheckpointHandle(checkpointData);
[[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] =
this.sessionId =
createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options);
[this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false);
if (evalModelUriOrBuffer !== '') {
[this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true);
}
}
/**
@ -118,6 +122,27 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices);
}
/**
 * Runs one optimizer step for this handler's training session by calling `runOptimizerStep` from
 * './wasm-training-core-impl' with the stored session id.
 *
 * @param options - run options forwarded unchanged.
 */
async runOptimizerStep(options: InferenceSession.RunOptions): Promise<void> {
await runOptimizerStep(this.sessionId, options);
}
/**
 * Runs one eval step for this handler's training session.
 *
 * Feeds and fetches are converted into value/index arrays keyed by `evalInputNames` / `evalOutputNames`, with each
 * tensor encoded as TensorMetadata. The arrays are passed to `runEvalStep` from './wasm-training-core-impl', and
 * the results are converted back with `convertTensorMetadataToReturnType`.
 *
 * @param feeds - input tensors, looked up by eval input name.
 * @param fetches - output tensors or null placeholders, looked up by eval output name.
 * @param options - run options forwarded unchanged.
 * @returns the converted eval step results.
 */
async runEvalStep(
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType> {
// the third callback argument only builds the tensor label passed to encodeTensorMetadata
const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray<Tensor, TensorMetadata>(
feeds, this.evalInputNames,
(t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`));
// a null fetch entry stays null instead of being encoded
const [outputArray, outputIndices, outputs] =
this.convertMapIntoValuesArrayAndIndicesArray<Tensor|null, TensorMetadata|null>(
fetches, this.evalOutputNames,
(t, i): TensorMetadata|null =>
t ? encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null);
const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices);
}
async getParametersSize(trainableOnly: boolean): Promise<number> {
return getParametersSize(this.sessionId, trainableOnly);
}
@ -131,7 +156,6 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
}
async dispose(): Promise<void> {
return releaseTrainingSessionAndCheckpoint(
this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames);
return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId);
}
}

View file

@ -3,7 +3,7 @@
import {InferenceSession, Tensor} from 'onnxruntime-common';
import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages';
import {SerializableModeldata, TensorMetadata} from './proxy-messages';
import {setRunOptions} from './run-options';
import {setSessionOptions} from './session-options';
import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
@ -77,50 +77,44 @@ const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolea
};
const getModelInputOutputNamesLoop =
(trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => {
(trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): string[] => {
const names = [];
const wasm = getInstance();
const namesUTF8Encoded = [];
for (let i = 0; i < count; i++) {
if (wasm._OrtTrainingGetModelInputOutputName) {
const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel);
ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false);
namesUTF8Encoded.push(name);
names.push(wasm.UTF8ToString(name));
wasm._free(name);
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
}
return [names, namesUTF8Encoded];
return names;
};
const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => {
const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false);
export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => {
let inputNames: string[] = [];
let outputNames: string[] = [];
const [inputNames, inputNamesUTF8Encoded] = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false);
const [outputNames, outputNamesUTF8Encoded] =
getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false);
const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel);
return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded];
inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel);
outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel);
return [inputNames, outputNames];
};
export const createTrainingSessionHandle =
(checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata,
optimizerModelData: SerializableModeldata,
options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => {
optimizerModelData: SerializableModeldata, options: InferenceSession.SessionOptions): number => {
const wasm = getInstance();
let trainingSessionHandle = 0;
let sessionOptionsHandle = 0;
let allocs: number[] = [];
let inputNamesUTF8Encoded: number[] = [];
let outputNamesUTF8Encoded: number[] = [];
let inputNames: string[] = [];
let outputNames: string[] = [];
try {
[sessionOptionsHandle, allocs] = setSessionOptions(options);
@ -133,11 +127,7 @@ export const createTrainingSessionHandle =
}
ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false);
[inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] =
getTrainingModelInputOutputNames(trainingSessionHandle);
return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded];
return trainingSessionHandle;
} catch (e) {
if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) {
wasm._OrtTrainingReleaseSession(trainingSessionHandle);
@ -152,8 +142,6 @@ export const createTrainingSessionHandle =
wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
}
allocs.forEach(alloc => wasm._free(alloc));
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
}
};
@ -317,6 +305,83 @@ export const runTrainStep = async(
}
};
/**
 * Calls `_OrtTrainingOptimizerStep` in the WebAssembly module for the given training session.
 *
 * @param trainingSessionId - handle of the training session, passed straight to the wasm call.
 * @param options - run options; converted to a native handle by `setRunOptions`.
 * @throws Error with NO_TRAIN_FUNCS_MSG when the wasm module does not expose `_OrtTrainingOptimizerStep`.
 * NOTE(review): `ifErrCodeCheckLastError` presumably throws when the returned error code signals failure -- confirm
 * against its definition.
 */
export const runOptimizerStep =
async(trainingSessionId: number, options: InferenceSession.RunOptions): Promise<void> => {
const wasm = getInstance();
let runOptionsHandle = 0;
let runOptionsAllocs: number[] = [];
try {
[runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
if (wasm._OrtTrainingOptimizerStep) {
const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle);
ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer');
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
} finally {
// always release the run options handle and its allocations, whether or not the step succeeded
if (runOptionsHandle !== 0) {
wasm._OrtReleaseRunOptions(runOptionsHandle);
}
runOptionsAllocs.forEach(p => wasm._free(p));
}
};
/**
 * Calls `_OrtTrainingEvalStep` in the WebAssembly module for the given training session.
 *
 * @param trainingSessionId - handle of the training session, passed straight to the wasm calls.
 * @param inputIndices - indices of the supplied inputs; its length is used as the input count.
 * @param inputTensors - metadata of the input tensors.
 * @param outputIndices - indices of the requested outputs; its length is used as the output count.
 * @param outputTensors - metadata of the output tensors; entries may be null.
 * @param options - run options; converted to a native handle by `setRunOptions`.
 * @returns the result of `moveOutputToTensorMetadataArr` for the requested outputs.
 * @throws Error with NO_TRAIN_FUNCS_MSG when the wasm module does not expose `_OrtTrainingEvalStep`.
 * NOTE(review): `ifErrCodeCheckLastError` presumably throws when the returned error code signals failure -- confirm
 * against its definition.
 */
export const runEvalStep = async(
trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[],
outputTensors: Array<TensorMetadata|null>, options: InferenceSession.RunOptions): Promise<TensorMetadata[]> => {
const wasm = getInstance();
const inputCount = inputIndices.length;
const outputCount = outputIndices.length;
let runOptionsHandle = 0;
let runOptionsAllocs: number[] = [];
const inputTensorHandles: number[] = [];
const outputTensorHandles: number[] = [];
const inputOutputAllocs: number[] = [];
// capture the stack position so it can be restored in the finally block
const beforeRunStack = wasm.stackSave();
try {
// prepare parameters by moving them to heap
[runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
// handle inputs -- you don't want anything added to the index
const inputValuesOffset = createAndAllocateTensors(
trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0);
// handle outputs
// you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
const outputValuesOffset = createAndAllocateTensors(
trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount);
if (wasm._OrtTrainingEvalStep) {
const errorCode = wasm._OrtTrainingEvalStep(
trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle);
ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer');
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
} finally {
// cleanup runs on both success and failure: stack, tensor handles, heap allocations, then run options
wasm.stackRestore(beforeRunStack);
inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
inputOutputAllocs.forEach(p => wasm._free(p));
if (runOptionsHandle !== 0) {
wasm._OrtReleaseRunOptions(runOptionsHandle);
}
runOptionsAllocs.forEach(p => wasm._free(p));
}
};
export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => {
const wasm = getInstance();
const stack = wasm.stackSave();
@ -439,17 +504,13 @@ export const loadParametersBuffer =
}
};
export const releaseTrainingSessionAndCheckpoint =
(checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]):
void => {
const wasm = getInstance();
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => {
const wasm = getInstance();
if (wasm._OrtTrainingReleaseSession) {
wasm._OrtTrainingReleaseSession(sessionId);
}
if (wasm._OrtTrainingReleaseCheckpoint) {
wasm._OrtTrainingReleaseCheckpoint(checkpointId);
}
};
if (wasm._OrtTrainingReleaseSession) {
wasm._OrtTrainingReleaseSession(sessionId);
}
if (wasm._OrtTrainingReleaseCheckpoint) {
wasm._OrtTrainingReleaseCheckpoint(checkpointId);
}
};