mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
### Description

* Adds TrainingSession.create() functionality following the web bindings for training design doc
* Adds 2 new training APIs to wasm/api.h:
  * OrtTrainingGetInputOutputName
  * OrtTrainingGetInputOutputCount
* Moves the isOrtEnvInitialized boolean to wasm-core-impl and adds a method that references it

### Motivation and Context

* Adding web bindings for training

#### Related work

* #16521 allowed for training artifacts to be built
* #17333 added interfaces for training
* #17474 allows for training package to be built + adds training backend to web package **[MUST BE MERGED IN BEFORE THIS ONE]**

---------

Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Co-authored-by: Ashwini Khade <askhade@microsoft.com>
162 lines
6.2 KiB
TypeScript
162 lines
6.2 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {InferenceSession} from 'onnxruntime-common';
|
|
|
|
import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages';
|
|
import {setSessionOptions} from './session-options';
|
|
import {getInstance} from './wasm-factory';
|
|
import {checkLastError} from './wasm-utils';
|
|
|
|
/**
 * Error text thrown by the helpers in this file when the loaded wasm module does not expose the
 * training entry point they need.
 */
const NO_TRAIN_FUNCS_MSG = [
  'Built without training API\'s enabled. Use the onnxruntime-web/training import for training',
  'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if',
  'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.',
].join(' ');
|
|
|
|
/**
 * Loads a checkpoint from a buffer that already lives on the wasm heap and returns its handle.
 *
 * The buffer referenced by `checkpointData` is always released with `_OrtFree` before this function
 * returns or throws, so the caller must not reuse it afterwards.
 *
 * @param checkpointData - tuple destructured here as `[offset, length]` of the checkpoint bytes.
 * @returns the non-zero checkpoint handle reported by `_OrtTrainingLoadCheckpoint`.
 * @throws Error with `NO_TRAIN_FUNCS_MSG` when the wasm module has no `_OrtTrainingLoadCheckpoint`;
 *     otherwise whatever `checkLastError` raises when the returned handle is 0.
 */
export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => {
  const wasm = getInstance();
  const [bufferOffset, bufferLength] = checkpointData;
  let handle = 0;

  try {
    // Guard first: a build without the training API has no loader function at all.
    if (!wasm._OrtTrainingLoadCheckpoint) {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
    handle = wasm._OrtTrainingLoadCheckpoint(bufferOffset, bufferLength);

    // A zero handle signals failure; details are fetched (and reported) by checkLastError.
    if (handle === 0) {
      checkLastError('Error occurred when trying to create a CheckpointState.');
    }
    return handle;
  } catch (e) {
    // Do not leave a half-created checkpoint behind when reporting the failure.
    if (handle !== 0 && wasm._OrtTrainingReleaseCheckpoint) {
      wasm._OrtTrainingReleaseCheckpoint(handle);
    }
    throw e;
  } finally {
    // The source buffer on the wasm heap is no longer needed, success or not.
    wasm._OrtFree(bufferOffset);
  }
};
|
|
|
|
/**
 * Queries how many inputs and outputs a model of the given training session has.
 *
 * @param trainingSessionId - handle of the training session to query.
 * @param isEvalModel - forwarded unchanged to `_OrtTrainingGetModelInputOutputCount`.
 * @returns a pair read from two adjacent 32-bit slots written by the wasm call; the caller in this
 *     file destructures it as `[inputCount, outputCount]`.
 * @throws Error with `NO_TRAIN_FUNCS_MSG` when the wasm module lacks the training entry point;
 *     otherwise whatever `checkLastError` raises on a non-zero error code.
 */
const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => {
  const wasm = getInstance();
  const savedStack = wasm.stackSave();

  try {
    // 8 bytes of wasm stack: two consecutive 32-bit out-parameters.
    const countsOffset = wasm.stackAlloc(8);

    if (!wasm._OrtTrainingGetModelInputOutputCount) {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    const status =
        wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, countsOffset, countsOffset + 4, isEvalModel);
    if (status !== 0) {
      checkLastError('Can\'t get session input/output count.');
    }

    // HEAP32 is indexed in 4-byte units, so convert the byte offset once.
    const firstSlot = countsOffset / 4;
    return [wasm.HEAP32[firstSlot], wasm.HEAP32[firstSlot + 1]];
  } finally {
    // Always unwind the temporary stack allocation.
    wasm.stackRestore(savedStack);
  }
};
|
|
|
|
/**
 * Fetches `count` input or output names of a training-session model, one wasm call per name.
 *
 * Each name comes back as a pointer to a UTF-8 string on the wasm heap. The pointers are returned
 * alongside the decoded strings; elsewhere in this file they are released with `_OrtFree`, so on
 * success the caller owns them. If any step fails, every pointer collected so far is freed here
 * before the error is rethrown, because the caller never receives them.
 *
 * @param trainingSessionId - handle of the training session to query.
 * @param count - number of names to fetch (indices `0 .. count - 1`).
 * @param isInput - forwarded unchanged to `_OrtTrainingGetModelInputOutputName`.
 * @param isEvalModel - forwarded unchanged to `_OrtTrainingGetModelInputOutputName`.
 * @returns `[names, namePointers]`, index-aligned.
 * @throws Error with `NO_TRAIN_FUNCS_MSG` when the wasm module lacks the training entry point;
 *     otherwise whatever `checkLastError` raises when a name pointer is 0.
 */
const getModelInputOutputNamesLoop =
    (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => {
      const wasm = getInstance();
      const names: string[] = [];
      const namesUTF8Encoded: number[] = [];

      try {
        for (let i = 0; i < count; i++) {
          if (!wasm._OrtTrainingGetModelInputOutputName) {
            throw new Error(NO_TRAIN_FUNCS_MSG);
          }

          const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel);
          if (name === 0) {
            // NOTE(review): checkLastError is expected to throw here - confirm against wasm-utils.
            checkLastError('Can\'t get input or output name');
          }

          namesUTF8Encoded.push(name);
          names.push(wasm.UTF8ToString(name));
        }
        return [names, namesUTF8Encoded];
      } catch (e) {
        // The partial result is discarded, so release the name buffers already obtained.
        namesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
        throw e;
      }
    };
|
|
|
|
/**
 * Collects the input and output names of the training model (queried with `isEvalModel = false`)
 * of a session.
 *
 * @param trainingSessionId - handle of the training session to query.
 * @returns `[inputNames, inputNamePointers, outputNames, outputNamePointers]`. The pointer arrays
 *     reference UTF-8 strings on the wasm heap; elsewhere in this file they are released with
 *     `_OrtFree`.
 * @throws whatever `getModelInputOutputCount` or `getModelInputOutputNamesLoop` throw. If fetching
 *     the output names fails, the already-fetched input name buffers are freed before rethrowing,
 *     since the caller never receives them.
 */
const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => {
  const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false);

  const [inputNames, inputNamesUTF8Encoded] = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false);

  try {
    const [outputNames, outputNamesUTF8Encoded] =
        getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false);

    return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded];
  } catch (e) {
    // Input names were fetched successfully but will never reach the caller: release them.
    const wasm = getInstance();
    inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
    throw e;
  }
};
|
|
|
|
/**
 * Creates a training session in the wasm runtime and gathers its training-model input/output names.
 *
 * The three model buffers, the session-options handle and the option allocations are temporary and
 * are always released before this function returns or throws.
 *
 * The input/output name buffers are different: their pointers are part of the return value and are
 * later released by `releaseTrainingSessionAndCheckpoint`. They are therefore freed here ONLY on the
 * failure path. Freeing them unconditionally would hand dangling pointers to the caller and cause a
 * second free when the session is released.
 *
 * @param checkpointHandle - checkpoint handle forwarded to `_OrtTrainingCreateSession`.
 * @param trainModelData - training model buffer; elements 0 and 1 are forwarded to the wasm call and
 *     element 0 is freed afterwards (presumably `[offset, length]` on the wasm heap).
 * @param evalModelData - eval model buffer, handled the same way.
 * @param optimizerModelData - optimizer model buffer, handled the same way.
 * @param options - session options converted via `setSessionOptions`.
 * @returns `[[sessionHandle, inputNames, outputNames], inputNamePointers, outputNamePointers]`.
 * @throws Error with `NO_TRAIN_FUNCS_MSG` when the wasm module lacks `_OrtTrainingCreateSession`;
 *     otherwise whatever `setSessionOptions`, `checkLastError` or the name lookup raise. A session
 *     that was already created is released before the error propagates.
 */
export const createTrainingSessionHandle =
    (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata,
     optimizerModelData: SerializableModeldata,
     options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => {
      const wasm = getInstance();

      let trainingSessionHandle = 0;
      let sessionOptionsHandle = 0;
      let allocs: number[] = [];
      let inputNamesUTF8Encoded: number[] = [];
      let outputNamesUTF8Encoded: number[] = [];

      let inputNames: string[] = [];
      let outputNames: string[] = [];

      try {
        [sessionOptionsHandle, allocs] = setSessionOptions(options);
        if (wasm._OrtTrainingCreateSession) {
          trainingSessionHandle = wasm._OrtTrainingCreateSession(
              sessionOptionsHandle, checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0],
              evalModelData[1], optimizerModelData[0], optimizerModelData[1]);
        } else {
          throw new Error(NO_TRAIN_FUNCS_MSG);
        }

        // A zero handle signals failure; details are fetched (and reported) by checkLastError.
        if (trainingSessionHandle === 0) {
          checkLastError('Error occurred when trying to create a TrainingSession.');
        }

        [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] =
            getTrainingModelInputOutputNames(trainingSessionHandle);
        return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded];
      } catch (e) {
        // Failure path only: nothing is returned, so the name buffers (if any) and the session
        // would otherwise be orphaned.
        inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
        outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));

        if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) {
          wasm._OrtTrainingReleaseSession(trainingSessionHandle);
        }
        throw e;
      } finally {
        // Temporaries that are never needed past session creation.
        wasm._free(trainModelData[0]);
        wasm._free(evalModelData[0]);
        wasm._free(optimizerModelData[0]);

        if (sessionOptionsHandle !== 0) {
          wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
        }
        allocs.forEach(alloc => wasm._free(alloc));
      }
    };
|
|
|
|
/**
 * Tears down a training session: frees the input/output name buffers, then releases the session
 * and finally the checkpoint.
 *
 * The session and checkpoint release calls are skipped silently when the wasm module does not
 * expose the corresponding training function.
 *
 * @param checkpointId - checkpoint handle passed to `_OrtTrainingReleaseCheckpoint`.
 * @param sessionId - session handle passed to `_OrtTrainingReleaseSession`.
 * @param inputNamesUTF8Encoded - wasm-heap pointers of the input names, each passed to `_OrtFree`.
 * @param outputNamesUTF8Encoded - wasm-heap pointers of the output names, each passed to `_OrtFree`.
 */
export const releaseTrainingSessionAndCheckpoint =
    (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]):
        void => {
          const wasm = getInstance();

          // Small local helper so both pointer lists are released the same way.
          const freeAll = (pointers: number[]): void => {
            pointers.forEach(pointer => wasm._OrtFree(pointer));
          };
          freeAll(inputNamesUTF8Encoded);
          freeAll(outputNamesUTF8Encoded);

          // Session first, then the checkpoint.
          if (wasm._OrtTrainingReleaseSession) {
            wasm._OrtTrainingReleaseSession(sessionId);
          }
          if (wasm._OrtTrainingReleaseCheckpoint) {
            wasm._OrtTrainingReleaseCheckpoint(checkpointId);
          }
        };
|