mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
[js/web/training] Add CreateTrainingSession (#17891)
### Description * Adds TrainingSession.create() functionality following the web bindings for training design doc * Added 2 new training APIs to wasm/api.h: * OrtTrainingGetInputOutputName * OrtTrainingGetInputOutputCount * Moved isOrtEnvInitialized boolean to the wasm-core-impl and added a method that references it ### Motivation and Context * Adding web bindings for training #### Related work * #16521 allowed for training artifacts to be built * #17333 added interfaces for training * #17474 allows for training package to be built + adds training backend to web package **[MUST BE MERGED IN BEFORE THIS ONE]** --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Co-authored-by: Ashwini Khade <askhade@microsoft.com>
This commit is contained in:
parent
0f72739b6d
commit
64de71c5e2
12 changed files with 399 additions and 12 deletions
|
|
@ -1,11 +1,14 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {resolveBackend} from './backend-impl.js';
|
||||
import {TrainingSessionHandler} from './backend.js';
|
||||
import {InferenceSession as InferenceSession} from './inference-session.js';
|
||||
import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js';
|
||||
|
||||
type SessionOptions = InferenceSession.SessionOptions;
|
||||
const noBackendErrMsg: string = 'Training backend could not be resolved. ' +
|
||||
'Make sure you\'re using the correct configuration & WebAssembly files.';
|
||||
|
||||
export class TrainingSession implements TrainingSessionInterface {
|
||||
private constructor(handler: TrainingSessionHandler) {
|
||||
|
|
@ -20,9 +23,23 @@ export class TrainingSession implements TrainingSessionInterface {
|
|||
return this.handler.outputNames;
|
||||
}
|
||||
|
||||
static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions):
|
||||
static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions):
|
||||
Promise<TrainingSession> {
|
||||
throw new Error('Method not implemented');
|
||||
const evalModel: string|Uint8Array = trainingOptions.evalModel || '';
|
||||
const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || '';
|
||||
const options: SessionOptions = sessionOptions || {};
|
||||
|
||||
// get backend hints
|
||||
const eps = options.executionProviders || [];
|
||||
const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
|
||||
const backend = await resolveBackend(backendHints);
|
||||
if (backend.createTrainingSessionHandler) {
|
||||
const handler = await backend.createTrainingSessionHandler(
|
||||
trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options);
|
||||
return new TrainingSession(handler);
|
||||
} else {
|
||||
throw new Error(noBackendErrMsg);
|
||||
}
|
||||
}
|
||||
|
||||
async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
|
||||
|
|
|
|||
|
|
@ -4,13 +4,17 @@
|
|||
import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common';
|
||||
|
||||
import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';
|
||||
import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training';
|
||||
|
||||
class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
|
||||
async createTrainingSessionHandler(
|
||||
_checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array,
|
||||
_evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array,
|
||||
_options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
|
||||
throw new Error('Method not implemented yet.');
|
||||
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
|
||||
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
|
||||
options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
|
||||
const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler();
|
||||
await handler.createTrainingSession(
|
||||
checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options);
|
||||
return Promise.resolve(handler);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
5
js/web/lib/wasm/binding/ort-wasm.d.ts
vendored
5
js/web/lib/wasm/binding/ort-wasm.d.ts
vendored
|
|
@ -102,6 +102,11 @@ export interface OrtWasmModule extends EmscriptenModule {
|
|||
_OrtTrainingCopyParametersFromBuffer?
|
||||
(trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
|
||||
|
||||
_OrtTrainingGetModelInputOutputCount?
|
||||
(trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number;
|
||||
_OrtTrainingGetModelInputOutputName?
|
||||
(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number;
|
||||
|
||||
_OrtTrainingReleaseSession?(trainingHandle: number): void;
|
||||
// #endregion
|
||||
|
||||
|
|
|
|||
|
|
@ -73,5 +73,10 @@ interface MesssageEndProfiling extends MessageError {
|
|||
in ?: number;
|
||||
}
|
||||
|
||||
interface MessageIsOrtEnvInitialized extends MessageError {
|
||||
type: 'is-ort-env-initialized';
|
||||
out?: boolean;
|
||||
}
|
||||
|
||||
export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize|
|
||||
MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling;
|
||||
MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized;
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
/// <reference lib="webworker" />
|
||||
|
||||
import {OrtWasmMessage} from '../proxy-messages';
|
||||
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, releaseSession, run} from '../wasm-core-impl';
|
||||
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl';
|
||||
import {initializeWebAssembly} from '../wasm-factory';
|
||||
|
||||
self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
|
||||
|
|
@ -89,6 +89,14 @@ self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
|
|||
postMessage({type: 'end-profiling', err} as OrtWasmMessage);
|
||||
}
|
||||
break;
|
||||
case 'is-ort-env-initialized':
|
||||
try {
|
||||
const ortEnvInitialized = isOrtEnvInitialized();
|
||||
postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage);
|
||||
} catch (err) {
|
||||
postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ const createSessionCallbacks: Array<PromiseCallbacks<SerializableSessionMetadata
|
|||
const releaseSessionCallbacks: Array<PromiseCallbacks<void>> = [];
|
||||
const runCallbacks: Array<PromiseCallbacks<SerializableTensorMetadata[]>> = [];
|
||||
const endProfilingCallbacks: Array<PromiseCallbacks<void>> = [];
|
||||
const isOrtEnvInitializedCallbacks: Array<PromiseCallbacks<boolean>> = [];
|
||||
|
||||
const ensureWorker = (): void => {
|
||||
if (initializing || !initialized || aborted || !proxyWorker) {
|
||||
|
|
@ -92,6 +93,13 @@ const onProxyWorkerMessage = (ev: MessageEvent<OrtWasmMessage>): void => {
|
|||
endProfilingCallbacks.shift()![0]();
|
||||
}
|
||||
break;
|
||||
case 'is-ort-env-initialized':
|
||||
if (ev.data.err) {
|
||||
isOrtEnvInitializedCallbacks.shift();
|
||||
} else {
|
||||
isOrtEnvInitializedCallbacks.shift();
|
||||
}
|
||||
break;
|
||||
default:
|
||||
}
|
||||
};
|
||||
|
|
@ -251,3 +259,16 @@ export const endProfiling = async(sessionId: number): Promise<void> => {
|
|||
core.endProfiling(sessionId);
|
||||
}
|
||||
};
|
||||
|
||||
export const isOrtEnvInitialized = async(): Promise<boolean> => {
|
||||
if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
|
||||
ensureWorker();
|
||||
return new Promise<boolean>((resolve, reject) => {
|
||||
isOrtEnvInitializedCallbacks.push([resolve, reject]);
|
||||
const message: OrtWasmMessage = {type: 'is-ort-env-initialized'};
|
||||
proxyWorker!.postMessage(message);
|
||||
});
|
||||
} else {
|
||||
return core.isOrtEnvInitialized();
|
||||
}
|
||||
};
|
||||
|
|
|
|||
73
js/web/lib/wasm/session-handler-for-training.ts
Normal file
73
js/web/lib/wasm/session-handler-for-training.ts
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {env, InferenceSession, SessionHandler, TrainingSessionHandler} from 'onnxruntime-common';
|
||||
|
||||
import {SerializableModeldata} from './proxy-messages';
|
||||
import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl';
|
||||
import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl';
|
||||
|
||||
export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
|
||||
async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
|
||||
throw new Error('Method not implemented.');
|
||||
}
|
||||
async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
|
||||
throw new Error('Method not implemented.');
|
||||
}
|
||||
private sessionId: number;
|
||||
private checkpointId: number;
|
||||
|
||||
inputNames: string[];
|
||||
outputNames: string[];
|
||||
|
||||
inputEncodedNames: number[];
|
||||
outputEncodedNames: number[];
|
||||
|
||||
async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise<SerializableModeldata> {
|
||||
let buffer: Uint8Array;
|
||||
if (typeof uriOrBuffer === 'string') {
|
||||
const response = await fetch(uriOrBuffer);
|
||||
const arrayBuffer = await response.arrayBuffer();
|
||||
buffer = new Uint8Array(arrayBuffer);
|
||||
} else {
|
||||
buffer = uriOrBuffer;
|
||||
}
|
||||
return createSessionAllocate(buffer);
|
||||
}
|
||||
|
||||
async createTrainingSession(
|
||||
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
|
||||
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
|
||||
options: InferenceSession.SessionOptions) {
|
||||
if (!isOrtEnvInitialized()) {
|
||||
await initRuntime(env);
|
||||
}
|
||||
const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer);
|
||||
const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer);
|
||||
// 0 is supposed to be the nullptr
|
||||
let evalModelData: SerializableModeldata = [0, 0];
|
||||
let optimizerModelData: SerializableModeldata = [0, 0];
|
||||
|
||||
if (evalModelUriOrBuffer !== '') {
|
||||
evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer);
|
||||
}
|
||||
if (optimizerModelUriOrBuffer !== '') {
|
||||
optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer);
|
||||
}
|
||||
|
||||
this.checkpointId = createCheckpointHandle(checkpointData);
|
||||
[[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] =
|
||||
createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options);
|
||||
}
|
||||
|
||||
async dispose(): Promise<void> {
|
||||
return releaseTrainingSessionAndCheckpoint(
|
||||
this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames);
|
||||
}
|
||||
|
||||
async runTrainStep(
|
||||
_feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType,
|
||||
_options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType> {
|
||||
throw new Error('Method not implemented yet.');
|
||||
}
|
||||
}
|
||||
|
|
@ -5,10 +5,9 @@ import {readFile} from 'node:fs/promises';
|
|||
import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common';
|
||||
|
||||
import {SerializableModeldata, TensorMetadata} from './proxy-messages';
|
||||
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper';
|
||||
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper';
|
||||
import {isGpuBufferSupportedType} from './wasm-common';
|
||||
|
||||
let runtimeInitialized: boolean;
|
||||
let runtimeInitializationPromise: Promise<void>|undefined;
|
||||
|
||||
const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
|
||||
|
|
@ -57,13 +56,12 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
|
|||
}
|
||||
|
||||
async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
|
||||
if (!runtimeInitialized) {
|
||||
if (!(await isOrtEnvInitialized())) {
|
||||
if (!runtimeInitializationPromise) {
|
||||
runtimeInitializationPromise = initializeRuntime(env);
|
||||
}
|
||||
await runtimeInitializationPromise;
|
||||
runtimeInitializationPromise = undefined;
|
||||
runtimeInitialized = true;
|
||||
}
|
||||
|
||||
if (typeof pathOrBuffer === 'string') {
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType
|
|||
import {getInstance} from './wasm-factory';
|
||||
import {allocWasmString, checkLastError} from './wasm-utils';
|
||||
|
||||
let ortEnvInitialized = false;
|
||||
|
||||
/**
|
||||
* get the input/output count of the session.
|
||||
* @param sessionHandle the handle representing the session. should be non-zero.
|
||||
|
|
@ -57,6 +59,8 @@ export const initRuntime = async(env: Env): Promise<void> => {
|
|||
const initJsep = require('./jsep/init').init;
|
||||
await initJsep(getInstance(), env);
|
||||
}
|
||||
|
||||
ortEnvInitialized = true;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -93,6 +97,8 @@ type SessionMetadata = [
|
|||
|
||||
const activeSessions = new Map<number, SessionMetadata>();
|
||||
|
||||
export const isOrtEnvInitialized = (): boolean => ortEnvInitialized;
|
||||
|
||||
/**
|
||||
* allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession.
|
||||
* @returns a 2-elements tuple - the pointer and size of the allocated buffer
|
||||
|
|
|
|||
162
js/web/lib/wasm/wasm-training-core-impl.ts
Normal file
162
js/web/lib/wasm/wasm-training-core-impl.ts
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {InferenceSession} from 'onnxruntime-common';
|
||||
|
||||
import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages';
|
||||
import {setSessionOptions} from './session-options';
|
||||
import {getInstance} from './wasm-factory';
|
||||
import {checkLastError} from './wasm-utils';
|
||||
|
||||
const NO_TRAIN_FUNCS_MSG =
|
||||
'Built without training API\'s enabled. Use the onnxruntime-web/training import for training ' +
|
||||
'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' +
|
||||
'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.';
|
||||
|
||||
export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => {
|
||||
const wasm = getInstance();
|
||||
|
||||
const [checkpointDataOffset, checkpointDataLength] = checkpointData;
|
||||
let checkpointHandle = 0;
|
||||
|
||||
try {
|
||||
if (wasm._OrtTrainingLoadCheckpoint) {
|
||||
checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength);
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
|
||||
if (checkpointHandle === 0) {
|
||||
checkLastError('Error occurred when trying to create a CheckpointState.');
|
||||
}
|
||||
return checkpointHandle;
|
||||
} catch (e) {
|
||||
if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) {
|
||||
wasm._OrtTrainingReleaseCheckpoint(checkpointHandle);
|
||||
}
|
||||
throw e;
|
||||
} finally {
|
||||
// free buffer from wasm heap
|
||||
wasm._OrtFree(checkpointData[0]);
|
||||
}
|
||||
};
|
||||
|
||||
const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => {
|
||||
const wasm = getInstance();
|
||||
const stack = wasm.stackSave();
|
||||
try {
|
||||
const dataOffset = wasm.stackAlloc(8);
|
||||
if (wasm._OrtTrainingGetModelInputOutputCount) {
|
||||
const errorCode =
|
||||
wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel);
|
||||
if (errorCode !== 0) {
|
||||
checkLastError('Can\'t get session input/output count.');
|
||||
}
|
||||
return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]];
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
} finally {
|
||||
wasm.stackRestore(stack);
|
||||
}
|
||||
};
|
||||
|
||||
const getModelInputOutputNamesLoop =
|
||||
(trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => {
|
||||
const names = [];
|
||||
const wasm = getInstance();
|
||||
|
||||
const namesUTF8Encoded = [];
|
||||
|
||||
for (let i = 0; i < count; i++) {
|
||||
if (wasm._OrtTrainingGetModelInputOutputName) {
|
||||
const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel);
|
||||
if (name === 0) {
|
||||
checkLastError('Can\'t get input or output name');
|
||||
}
|
||||
|
||||
namesUTF8Encoded.push(name);
|
||||
names.push(wasm.UTF8ToString(name));
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
}
|
||||
return [names, namesUTF8Encoded];
|
||||
};
|
||||
|
||||
const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => {
|
||||
const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false);
|
||||
|
||||
const [inputNames, inputNamesUTF8Encoded] = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false);
|
||||
const [outputNames, outputNamesUTF8Encoded] =
|
||||
getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false);
|
||||
|
||||
return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded];
|
||||
};
|
||||
|
||||
export const createTrainingSessionHandle =
|
||||
(checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata,
|
||||
optimizerModelData: SerializableModeldata,
|
||||
options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => {
|
||||
const wasm = getInstance();
|
||||
|
||||
let trainingSessionHandle = 0;
|
||||
let sessionOptionsHandle = 0;
|
||||
let allocs: number[] = [];
|
||||
let inputNamesUTF8Encoded: number[] = [];
|
||||
let outputNamesUTF8Encoded: number[] = [];
|
||||
|
||||
let inputNames: string[] = [];
|
||||
let outputNames: string[] = [];
|
||||
|
||||
try {
|
||||
[sessionOptionsHandle, allocs] = setSessionOptions(options);
|
||||
if (wasm._OrtTrainingCreateSession) {
|
||||
trainingSessionHandle = wasm._OrtTrainingCreateSession(
|
||||
sessionOptionsHandle, checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0],
|
||||
evalModelData[1], optimizerModelData[0], optimizerModelData[1]);
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
|
||||
if (trainingSessionHandle === 0) {
|
||||
checkLastError('Error occurred when trying to create a TrainingSession.');
|
||||
}
|
||||
|
||||
[inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] =
|
||||
getTrainingModelInputOutputNames(trainingSessionHandle);
|
||||
return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded];
|
||||
|
||||
} catch (e) {
|
||||
if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) {
|
||||
wasm._OrtTrainingReleaseSession(trainingSessionHandle);
|
||||
}
|
||||
throw e;
|
||||
} finally {
|
||||
wasm._free(trainModelData[0]);
|
||||
wasm._free(evalModelData[0]);
|
||||
wasm._free(optimizerModelData[0]);
|
||||
|
||||
if (sessionOptionsHandle !== 0) {
|
||||
wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
|
||||
}
|
||||
allocs.forEach(alloc => wasm._free(alloc));
|
||||
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
|
||||
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
|
||||
}
|
||||
};
|
||||
|
||||
export const releaseTrainingSessionAndCheckpoint =
|
||||
(checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]):
|
||||
void => {
|
||||
const wasm = getInstance();
|
||||
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
|
||||
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
|
||||
|
||||
if (wasm._OrtTrainingReleaseSession) {
|
||||
wasm._OrtTrainingReleaseSession(sessionId);
|
||||
}
|
||||
if (wasm._OrtTrainingReleaseCheckpoint) {
|
||||
wasm._OrtTrainingReleaseCheckpoint(checkpointId);
|
||||
}
|
||||
};
|
||||
|
|
@ -493,6 +493,14 @@ char* OrtEndProfiling(ort_session_handle_t session) {
|
|||
#define CHECK_TRAINING_STATUS(ORT_API_NAME, ...) \
|
||||
CheckStatus(Ort::GetTrainingApi().ORT_API_NAME(__VA_ARGS__))
|
||||
|
||||
#define RETURN_TRAINING_ERROR_CODE_IF_ERROR(ORT_API_NAME, ...) \
|
||||
do { \
|
||||
int error_code = CHECK_TRAINING_STATUS(ORT_API_NAME, __VA_ARGS__); \
|
||||
if (error_code != ORT_OK) { \
|
||||
return error_code; \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint(void* checkpoint_data_buffer,
|
||||
size_t checkpoint_size) {
|
||||
OrtCheckpointState* checkpoint_state = nullptr;
|
||||
|
|
@ -571,6 +579,57 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio
|
|||
return CHECK_TRAINING_STATUS(CopyBufferToParameters, training_handle, parameters_buffer, trainable_only);
|
||||
}
|
||||
|
||||
int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_session_handle_t training_handle,
|
||||
size_t* input_count,
|
||||
size_t* output_count,
|
||||
bool isEvalModel) {
|
||||
if (isEvalModel) {
|
||||
RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelInputCount, training_handle, input_count);
|
||||
RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelOutputCount, training_handle, output_count);
|
||||
return ORT_OK;
|
||||
} else {
|
||||
RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelInputCount, training_handle, input_count);
|
||||
RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelOutputCount, training_handle, output_count);
|
||||
return ORT_OK;
|
||||
}
|
||||
}
|
||||
|
||||
char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_session_handle_t training_handle,
|
||||
size_t index,
|
||||
bool isInput,
|
||||
bool isEvalModel) {
|
||||
OrtAllocator* allocator = nullptr;
|
||||
RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator);
|
||||
|
||||
char* name = nullptr;
|
||||
|
||||
if (isEvalModel) {
|
||||
if (isInput) {
|
||||
return (CHECK_TRAINING_STATUS(TrainingSessionGetEvalModelInputName, training_handle, index,
|
||||
allocator, &name) == ORT_OK)
|
||||
? name
|
||||
: nullptr;
|
||||
} else {
|
||||
return (CHECK_TRAINING_STATUS(TrainingSessionGetEvalModelOutputName, training_handle, index,
|
||||
allocator, &name) == ORT_OK)
|
||||
? name
|
||||
: nullptr;
|
||||
}
|
||||
} else {
|
||||
if (isInput) {
|
||||
return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelInputName, training_handle, index,
|
||||
allocator, &name) == ORT_OK)
|
||||
? name
|
||||
: nullptr;
|
||||
} else {
|
||||
return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelOutputName, training_handle, index,
|
||||
allocator, &name) == ORT_OK)
|
||||
? name
|
||||
: nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_handle) {
|
||||
Ort::GetTrainingApi().ReleaseTrainingSession(training_handle);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -432,6 +432,35 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio
|
|||
size_t parameter_count,
|
||||
bool trainable_only);
|
||||
|
||||
/**
|
||||
* Gets the input count and output count of the training or eval model associated with the given training handle.
|
||||
* @param traning_handle handle of the traning session
|
||||
* @param input_count [out] a pointer to a size_t variable to accept input_count
|
||||
* @param output_count [out] a pointer to a size_t variable to accept output_count
|
||||
* @param isEvalModel when false, returns input & output count of the training model. When true, returns input & output
|
||||
* count of the eval model.
|
||||
* @returns ORT error code. If not zero, call OrtGetLastError() to get a detailed error message.
|
||||
*/
|
||||
int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_session_handle_t training_handle,
|
||||
size_t* input_count,
|
||||
size_t* output_count,
|
||||
bool isEvalModel);
|
||||
|
||||
/**
|
||||
* Gets the input or output name at the specified index associated with the training or eval model from the
|
||||
* given training session.
|
||||
* @param traning_handle handle of the traning session
|
||||
* @param index the input or output index
|
||||
* @param isInput if true, this method retrieves an input name. If false, this method retrieves an output name.
|
||||
* @param isEvalModel when false, returns input & output names of the training model. When true, returns input & output
|
||||
* names of the eval model.
|
||||
* @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by
|
||||
*/
|
||||
char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_session_handle_t training_handle,
|
||||
size_t index,
|
||||
bool isInput,
|
||||
bool isEvalModel);
|
||||
|
||||
/**
|
||||
* @brief Release the specified ORT training session.
|
||||
*
|
||||
|
|
|
|||
Loading…
Reference in a new issue