onnxruntime/js/web/lib/wasm/wasm-training-core-impl.ts
Caroline Zhu 64de71c5e2
[js/web/training] Add CreateTrainingSession (#17891)
### Description
* Adds TrainingSession.create() functionality, following the design doc for
the training web bindings
* Added 2 new training APIs to wasm/api.h:
   * OrtTrainingGetModelInputOutputName
   * OrtTrainingGetModelInputOutputCount
* Moved isOrtEnvInitialized boolean to the wasm-core-impl and added a
method that references it

### Motivation and Context
* Adding web bindings for training

#### Related work
* #16521 allowed for training artifacts to be built
* #17333 added interfaces for training
* #17474 allows the training package to be built and adds the training backend
to the web package **[MUST BE MERGED IN BEFORE THIS ONE]**

---------

Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Co-authored-by: Ashwini Khade <askhade@microsoft.com>
2023-10-26 09:22:10 -07:00

162 lines
6.2 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {InferenceSession} from 'onnxruntime-common';
import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages';
import {setSessionOptions} from './session-options';
import {getInstance} from './wasm-factory';
import {checkLastError} from './wasm-utils';
/**
 * Message of the error thrown whenever a training entry point (`_OrtTraining*`) is missing from the loaded wasm
 * module, i.e. the module was built without the training APIs.
 */
const NO_TRAIN_FUNCS_MSG = [
  'Built without training API\'s enabled. Use the onnxruntime-web/training import for training ',
  'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ',
  'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.',
].join('');
/**
 * Loads a checkpoint state from checkpoint bytes that were previously copied onto the wasm heap.
 *
 * @param checkpointData - `[offset, length]` of the checkpoint bytes on the wasm heap. This function takes
 *     ownership of the buffer: it is released before returning, whether or not loading succeeds.
 * @returns the non-zero checkpoint handle returned by `_OrtTrainingLoadCheckpoint`.
 * @throws if the wasm module was built without the training APIs. On a zero handle, `checkLastError` is called
 *     (expected to throw the last ORT error). Any handle obtained before a failure is released.
 */
export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => {
  const wasm = getInstance();
  const [checkpointDataOffset, checkpointDataLength] = checkpointData;

  let checkpointHandle = 0;
  try {
    if (!wasm._OrtTrainingLoadCheckpoint) {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
    checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength);
    if (checkpointHandle === 0) {
      checkLastError('Error occurred when trying to create a CheckpointState.');
    }
    return checkpointHandle;
  } catch (e) {
    if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) {
      wasm._OrtTrainingReleaseCheckpoint(checkpointHandle);
    }
    throw e;
  } finally {
    // The checkpoint bytes arrive as the same SerializableModeldata [offset, length] pair as the model buffers in
    // createTrainingSessionHandle, which are released with wasm._free; release this buffer the same way rather
    // than through the ORT allocator (_OrtFree).
    wasm._free(checkpointDataOffset);
  }
};
/**
 * Asks the wasm module how many inputs and outputs one model of a training session has.
 *
 * Eight bytes of wasm stack serve as the two 32-bit out-parameters of `_OrtTrainingGetModelInputOutputCount`.
 * The stack is restored before returning.
 *
 * @param trainingSessionId - handle of the training session.
 * @param isEvalModel - forwarded unchanged to the wasm call to select which of the session's models is queried.
 * @returns `[inputCount, outputCount]`.
 */
const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => {
  const wasm = getInstance();
  const savedStack = wasm.stackSave();
  try {
    // Two consecutive int32 slots: the input count first, then the output count.
    const inputCountOffset = wasm.stackAlloc(8);
    const outputCountOffset = inputCountOffset + 4;

    if (!wasm._OrtTrainingGetModelInputOutputCount) {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    const status = wasm._OrtTrainingGetModelInputOutputCount(
        trainingSessionId, inputCountOffset, outputCountOffset, isEvalModel);
    if (status !== 0) {
      checkLastError('Can\'t get session input/output count.');
    }

    const inputCount = wasm.HEAP32[inputCountOffset / 4];
    const outputCount = wasm.HEAP32[outputCountOffset / 4];
    return [inputCount, outputCount];
  } finally {
    wasm.stackRestore(savedStack);
  }
};
/**
 * Reads `count` input or output names of one model of a training session.
 *
 * Each name is returned twice, index-aligned: decoded as a JS string, and as the raw pointer of the UTF-8 buffer
 * returned by `_OrtTrainingGetModelInputOutputName`. On success the caller owns those buffers and must release
 * them with `_OrtFree`. If anything throws part-way through, the buffers already obtained by this call are
 * released before the error propagates, because the caller never receives them on that path.
 *
 * @param trainingSessionId - handle of the training session.
 * @param count - number of names to read (indices `0 .. count - 1`).
 * @param isInput - true to read input names, false to read output names.
 * @param isEvalModel - forwarded unchanged to the wasm call to select which model is queried.
 * @returns `[names, namesUTF8Encoded]`.
 */
const getModelInputOutputNamesLoop =
    (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => {
      const wasm = getInstance();
      const names: string[] = [];
      const namesUTF8Encoded: number[] = [];
      try {
        for (let i = 0; i < count; i++) {
          if (!wasm._OrtTrainingGetModelInputOutputName) {
            throw new Error(NO_TRAIN_FUNCS_MSG);
          }
          const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel);
          if (name === 0) {
            checkLastError('Can\'t get input or output name');
          }
          namesUTF8Encoded.push(name);
          names.push(wasm.UTF8ToString(name));
        }
        return [names, namesUTF8Encoded];
      } catch (e) {
        // Without this, the buffers gathered so far would leak: the partial result is never returned.
        namesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
        throw e;
      }
    };
/**
 * Reads all input and output names of a training session's model, querying with `isEvalModel = false`.
 *
 * @param trainingSessionId - handle of the training session.
 * @returns `[inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]`. On success the caller
 *     owns the encoded buffers and must release them with `_OrtFree`. If reading the output names fails, the
 *     input-name buffers obtained here are released before the error propagates, so they do not leak.
 */
const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => {
  const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false);
  const [inputNames, inputNamesUTF8Encoded] =
      getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false);
  try {
    const [outputNames, outputNamesUTF8Encoded] =
        getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false);
    return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded];
  } catch (e) {
    const wasm = getInstance();
    inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
    throw e;
  }
};
/**
 * Creates a training session from a loaded checkpoint and the train / eval / optimizer model bytes.
 *
 * Ownership: the three model buffers (`[offset, length]` on the wasm heap) are released with `wasm._free` before
 * returning, on success and on failure. The session options built from `options` and their allocations are
 * released as well.
 *
 * @returns `[[sessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded]`. The encoded
 *     name buffers are NOT freed here on success: they belong to the caller, who releases them later (for example
 *     via releaseTrainingSessionAndCheckpoint, which frees such buffers with `_OrtFree`). Freeing them here as well
 *     would hand the caller dangling pointers and lead to a double free.
 * @throws if the wasm module was built without the training APIs. If session creation returns 0,
 *     `checkLastError` is called (expected to throw). On any failure, a session that was already created is
 *     released.
 */
export const createTrainingSessionHandle =
    (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata,
     optimizerModelData: SerializableModeldata,
     options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => {
      const wasm = getInstance();

      let trainingSessionHandle = 0;
      let sessionOptionsHandle = 0;
      let allocs: number[] = [];

      try {
        [sessionOptionsHandle, allocs] = setSessionOptions(options);

        if (!wasm._OrtTrainingCreateSession) {
          throw new Error(NO_TRAIN_FUNCS_MSG);
        }
        trainingSessionHandle = wasm._OrtTrainingCreateSession(
            sessionOptionsHandle, checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0],
            evalModelData[1], optimizerModelData[0], optimizerModelData[1]);
        if (trainingSessionHandle === 0) {
          checkLastError('Error occurred when trying to create a TrainingSession.');
        }

        const [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] =
            getTrainingModelInputOutputNames(trainingSessionHandle);
        // Ownership of the encoded name buffers passes to the caller here; see the doc comment above.
        return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded];
      } catch (e) {
        if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) {
          wasm._OrtTrainingReleaseSession(trainingSessionHandle);
        }
        throw e;
      } finally {
        wasm._free(trainModelData[0]);
        wasm._free(evalModelData[0]);
        wasm._free(optimizerModelData[0]);

        if (sessionOptionsHandle !== 0) {
          wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
        }
        allocs.forEach(alloc => wasm._free(alloc));
      }
    };
/**
 * Releases a training session together with its checkpoint.
 *
 * The encoded input names are freed first, then the output names, all with `_OrtFree`. Then the training session
 * is released, and finally the checkpoint. Each of the last two steps is skipped if the wasm module does not
 * export the corresponding `_OrtTrainingRelease*` function.
 */
export const releaseTrainingSessionAndCheckpoint =
    (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]):
        void => {
          const wasm = getInstance();

          const encodedNameBuffers = inputNamesUTF8Encoded.concat(outputNamesUTF8Encoded);
          encodedNameBuffers.forEach(nameBuffer => wasm._OrtFree(nameBuffer));

          if (wasm._OrtTrainingReleaseSession) {
            wasm._OrtTrainingReleaseSession(sessionId);
          }

          if (wasm._OrtTrainingReleaseCheckpoint) {
            wasm._OrtTrainingReleaseCheckpoint(checkpointId);
          }
        };