[js/web/training] Add CreateTrainingSession (#17891)

### Description
* Adds TrainingSession.create() functionality following the web bindings
for training design doc
* Added 2 new training APIs to wasm/api.h:
   * OrtTrainingGetInputOutputName
   * OrtTrainingGetInputOutputCount
* Moved isOrtEnvInitialized boolean to the wasm-core-impl and added a
method that references it

### Motivation and Context
* Adding web bindings for training

#### Related work
* #16521 allowed for training artifacts to be built
* #17333 added interfaces for training
* #17474 allows for training package to be built + adds training backend
to web package **[MUST BE MERGED IN BEFORE THIS ONE]**

---------

Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Co-authored-by: Ashwini Khade <askhade@microsoft.com>
This commit is contained in:
Caroline Zhu 2023-10-26 09:22:10 -07:00 committed by GitHub
parent 0f72739b6d
commit 64de71c5e2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 399 additions and 12 deletions

View file

@ -1,11 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {resolveBackend} from './backend-impl.js';
import {TrainingSessionHandler} from './backend.js';
import {InferenceSession as InferenceSession} from './inference-session.js';
import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js';
type SessionOptions = InferenceSession.SessionOptions;
const noBackendErrMsg: string = 'Training backend could not be resolved. ' +
'Make sure you\'re using the correct configuration & WebAssembly files.';
export class TrainingSession implements TrainingSessionInterface {
private constructor(handler: TrainingSessionHandler) {
@ -20,9 +23,23 @@ export class TrainingSession implements TrainingSessionInterface {
return this.handler.outputNames;
}
static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions):
static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions):
Promise<TrainingSession> {
throw new Error('Method not implemented');
const evalModel: string|Uint8Array = trainingOptions.evalModel || '';
const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || '';
const options: SessionOptions = sessionOptions || {};
// get backend hints
const eps = options.executionProviders || [];
const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
const backend = await resolveBackend(backendHints);
if (backend.createTrainingSessionHandler) {
const handler = await backend.createTrainingSessionHandler(
trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options);
return new TrainingSession(handler);
} else {
throw new Error(noBackendErrMsg);
}
}
async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {

View file

@ -4,13 +4,17 @@
import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common';
import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';
import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training';
class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
async createTrainingSessionHandler(
_checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array,
_evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array,
_options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
throw new Error('Method not implemented yet.');
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler();
await handler.createTrainingSession(
checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options);
return Promise.resolve(handler);
}
}

View file

@ -102,6 +102,11 @@ export interface OrtWasmModule extends EmscriptenModule {
_OrtTrainingCopyParametersFromBuffer?
(trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
_OrtTrainingGetModelInputOutputCount?
(trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number;
_OrtTrainingGetModelInputOutputName?
(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number;
_OrtTrainingReleaseSession?(trainingHandle: number): void;
// #endregion

View file

@ -73,5 +73,10 @@ interface MesssageEndProfiling extends MessageError {
in ?: number;
}
interface MessageIsOrtEnvInitialized extends MessageError {
type: 'is-ort-env-initialized';
out?: boolean;
}
export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize|
MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling;
MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized;

View file

@ -4,7 +4,7 @@
/// <reference lib="webworker" />
import {OrtWasmMessage} from '../proxy-messages';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, releaseSession, run} from '../wasm-core-impl';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl';
import {initializeWebAssembly} from '../wasm-factory';
self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
@ -89,6 +89,14 @@ self.onmessage = (ev: MessageEvent<OrtWasmMessage>): void => {
postMessage({type: 'end-profiling', err} as OrtWasmMessage);
}
break;
case 'is-ort-env-initialized':
try {
const ortEnvInitialized = isOrtEnvInitialized();
postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage);
} catch (err) {
postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage);
}
break;
default:
}
};

View file

@ -24,6 +24,7 @@ const createSessionCallbacks: Array<PromiseCallbacks<SerializableSessionMetadata
const releaseSessionCallbacks: Array<PromiseCallbacks<void>> = [];
const runCallbacks: Array<PromiseCallbacks<SerializableTensorMetadata[]>> = [];
const endProfilingCallbacks: Array<PromiseCallbacks<void>> = [];
const isOrtEnvInitializedCallbacks: Array<PromiseCallbacks<boolean>> = [];
const ensureWorker = (): void => {
if (initializing || !initialized || aborted || !proxyWorker) {
@ -92,6 +93,13 @@ const onProxyWorkerMessage = (ev: MessageEvent<OrtWasmMessage>): void => {
endProfilingCallbacks.shift()![0]();
}
break;
case 'is-ort-env-initialized':
if (ev.data.err) {
isOrtEnvInitializedCallbacks.shift()![1](ev.data.err);
} else {
isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!);
}
break;
default:
}
};
@ -251,3 +259,16 @@ export const endProfiling = async(sessionId: number): Promise<void> => {
core.endProfiling(sessionId);
}
};
export const isOrtEnvInitialized = async(): Promise<boolean> => {
if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
ensureWorker();
return new Promise<boolean>((resolve, reject) => {
isOrtEnvInitializedCallbacks.push([resolve, reject]);
const message: OrtWasmMessage = {type: 'is-ort-env-initialized'};
proxyWorker!.postMessage(message);
});
} else {
return core.isOrtEnvInitialized();
}
};

View file

@ -0,0 +1,73 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {env, InferenceSession, SessionHandler, TrainingSessionHandler} from 'onnxruntime-common';
import {SerializableModeldata} from './proxy-messages';
import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl';
import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl';
export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise<void> {
throw new Error('Method not implemented.');
}
async getContiguousParameters(_trainableOnly: boolean): Promise<Uint8Array> {
throw new Error('Method not implemented.');
}
private sessionId: number;
private checkpointId: number;
inputNames: string[];
outputNames: string[];
inputEncodedNames: number[];
outputEncodedNames: number[];
async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise<SerializableModeldata> {
let buffer: Uint8Array;
if (typeof uriOrBuffer === 'string') {
const response = await fetch(uriOrBuffer);
const arrayBuffer = await response.arrayBuffer();
buffer = new Uint8Array(arrayBuffer);
} else {
buffer = uriOrBuffer;
}
return createSessionAllocate(buffer);
}
async createTrainingSession(
checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
options: InferenceSession.SessionOptions) {
if (!isOrtEnvInitialized()) {
await initRuntime(env);
}
const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer);
const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer);
// 0 is supposed to be the nullptr
let evalModelData: SerializableModeldata = [0, 0];
let optimizerModelData: SerializableModeldata = [0, 0];
if (evalModelUriOrBuffer !== '') {
evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer);
}
if (optimizerModelUriOrBuffer !== '') {
optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer);
}
this.checkpointId = createCheckpointHandle(checkpointData);
[[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] =
createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options);
}
async dispose(): Promise<void> {
return releaseTrainingSessionAndCheckpoint(
this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames);
}
async runTrainStep(
_feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType,
_options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType> {
throw new Error('Method not implemented yet.');
}
}

View file

@ -5,10 +5,9 @@ import {readFile} from 'node:fs/promises';
import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common';
import {SerializableModeldata, TensorMetadata} from './proxy-messages';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper';
import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper';
import {isGpuBufferSupportedType} from './wasm-common';
let runtimeInitialized: boolean;
let runtimeInitializationPromise: Promise<void>|undefined;
const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
@ -57,13 +56,12 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
}
async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
if (!runtimeInitialized) {
if (!(await isOrtEnvInitialized())) {
if (!runtimeInitializationPromise) {
runtimeInitializationPromise = initializeRuntime(env);
}
await runtimeInitializationPromise;
runtimeInitializationPromise = undefined;
runtimeInitialized = true;
}
if (typeof pathOrBuffer === 'string') {

View file

@ -10,6 +10,8 @@ import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType
import {getInstance} from './wasm-factory';
import {allocWasmString, checkLastError} from './wasm-utils';
let ortEnvInitialized = false;
/**
* get the input/output count of the session.
* @param sessionHandle the handle representing the session. should be non-zero.
@ -57,6 +59,8 @@ export const initRuntime = async(env: Env): Promise<void> => {
const initJsep = require('./jsep/init').init;
await initJsep(getInstance(), env);
}
ortEnvInitialized = true;
};
/**
@ -93,6 +97,8 @@ type SessionMetadata = [
const activeSessions = new Map<number, SessionMetadata>();
export const isOrtEnvInitialized = (): boolean => ortEnvInitialized;
/**
* allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession.
* @returns a 2-elements tuple - the pointer and size of the allocated buffer

View file

@ -0,0 +1,162 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {InferenceSession} from 'onnxruntime-common';
import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages';
import {setSessionOptions} from './session-options';
import {getInstance} from './wasm-factory';
import {checkLastError} from './wasm-utils';
const NO_TRAIN_FUNCS_MSG =
'Built without training API\'s enabled. Use the onnxruntime-web/training import for training ' +
'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' +
'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.';
export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => {
const wasm = getInstance();
const [checkpointDataOffset, checkpointDataLength] = checkpointData;
let checkpointHandle = 0;
try {
if (wasm._OrtTrainingLoadCheckpoint) {
checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength);
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
if (checkpointHandle === 0) {
checkLastError('Error occurred when trying to create a CheckpointState.');
}
return checkpointHandle;
} catch (e) {
if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) {
wasm._OrtTrainingReleaseCheckpoint(checkpointHandle);
}
throw e;
} finally {
// free buffer from wasm heap
wasm._OrtFree(checkpointData[0]);
}
};
const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => {
const wasm = getInstance();
const stack = wasm.stackSave();
try {
const dataOffset = wasm.stackAlloc(8);
if (wasm._OrtTrainingGetModelInputOutputCount) {
const errorCode =
wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel);
if (errorCode !== 0) {
checkLastError('Can\'t get session input/output count.');
}
return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]];
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
} finally {
wasm.stackRestore(stack);
}
};
const getModelInputOutputNamesLoop =
(trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => {
const names = [];
const wasm = getInstance();
const namesUTF8Encoded = [];
for (let i = 0; i < count; i++) {
if (wasm._OrtTrainingGetModelInputOutputName) {
const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel);
if (name === 0) {
checkLastError('Can\'t get input or output name');
}
namesUTF8Encoded.push(name);
names.push(wasm.UTF8ToString(name));
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
}
return [names, namesUTF8Encoded];
};
const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => {
const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false);
const [inputNames, inputNamesUTF8Encoded] = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false);
const [outputNames, outputNamesUTF8Encoded] =
getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false);
return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded];
};
export const createTrainingSessionHandle =
(checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata,
optimizerModelData: SerializableModeldata,
options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => {
const wasm = getInstance();
let trainingSessionHandle = 0;
let sessionOptionsHandle = 0;
let allocs: number[] = [];
let inputNamesUTF8Encoded: number[] = [];
let outputNamesUTF8Encoded: number[] = [];
let inputNames: string[] = [];
let outputNames: string[] = [];
try {
[sessionOptionsHandle, allocs] = setSessionOptions(options);
if (wasm._OrtTrainingCreateSession) {
trainingSessionHandle = wasm._OrtTrainingCreateSession(
sessionOptionsHandle, checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0],
evalModelData[1], optimizerModelData[0], optimizerModelData[1]);
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
if (trainingSessionHandle === 0) {
checkLastError('Error occurred when trying to create a TrainingSession.');
}
[inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] =
getTrainingModelInputOutputNames(trainingSessionHandle);
return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded];
} catch (e) {
if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) {
wasm._OrtTrainingReleaseSession(trainingSessionHandle);
}
throw e;
} finally {
wasm._free(trainModelData[0]);
wasm._free(evalModelData[0]);
wasm._free(optimizerModelData[0]);
if (sessionOptionsHandle !== 0) {
wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
}
allocs.forEach(alloc => wasm._free(alloc));
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
}
};
export const releaseTrainingSessionAndCheckpoint =
(checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]):
void => {
const wasm = getInstance();
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
if (wasm._OrtTrainingReleaseSession) {
wasm._OrtTrainingReleaseSession(sessionId);
}
if (wasm._OrtTrainingReleaseCheckpoint) {
wasm._OrtTrainingReleaseCheckpoint(checkpointId);
}
};

View file

@ -493,6 +493,14 @@ char* OrtEndProfiling(ort_session_handle_t session) {
#define CHECK_TRAINING_STATUS(ORT_API_NAME, ...) \
CheckStatus(Ort::GetTrainingApi().ORT_API_NAME(__VA_ARGS__))
#define RETURN_TRAINING_ERROR_CODE_IF_ERROR(ORT_API_NAME, ...) \
do { \
int error_code = CHECK_TRAINING_STATUS(ORT_API_NAME, __VA_ARGS__); \
if (error_code != ORT_OK) { \
return error_code; \
} \
} while (false)
ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint(void* checkpoint_data_buffer,
size_t checkpoint_size) {
OrtCheckpointState* checkpoint_state = nullptr;
@ -571,6 +579,57 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio
return CHECK_TRAINING_STATUS(CopyBufferToParameters, training_handle, parameters_buffer, trainable_only);
}
int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_session_handle_t training_handle,
size_t* input_count,
size_t* output_count,
bool isEvalModel) {
if (isEvalModel) {
RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelInputCount, training_handle, input_count);
RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelOutputCount, training_handle, output_count);
return ORT_OK;
} else {
RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelInputCount, training_handle, input_count);
RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelOutputCount, training_handle, output_count);
return ORT_OK;
}
}
char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_session_handle_t training_handle,
size_t index,
bool isInput,
bool isEvalModel) {
OrtAllocator* allocator = nullptr;
RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator);
char* name = nullptr;
if (isEvalModel) {
if (isInput) {
return (CHECK_TRAINING_STATUS(TrainingSessionGetEvalModelInputName, training_handle, index,
allocator, &name) == ORT_OK)
? name
: nullptr;
} else {
return (CHECK_TRAINING_STATUS(TrainingSessionGetEvalModelOutputName, training_handle, index,
allocator, &name) == ORT_OK)
? name
: nullptr;
}
} else {
if (isInput) {
return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelInputName, training_handle, index,
allocator, &name) == ORT_OK)
? name
: nullptr;
} else {
return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelOutputName, training_handle, index,
allocator, &name) == ORT_OK)
? name
: nullptr;
}
}
}
void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_handle) {
Ort::GetTrainingApi().ReleaseTrainingSession(training_handle);
}

View file

@ -432,6 +432,35 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio
size_t parameter_count,
bool trainable_only);
/**
* Gets the input count and output count of the training or eval model associated with the given training handle.
* @param traning_handle handle of the traning session
* @param input_count [out] a pointer to a size_t variable to accept input_count
* @param output_count [out] a pointer to a size_t variable to accept output_count
* @param isEvalModel when false, returns input & output count of the training model. When true, returns input & output
* count of the eval model.
* @returns ORT error code. If not zero, call OrtGetLastError() to get a detailed error message.
*/
int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_session_handle_t training_handle,
size_t* input_count,
size_t* output_count,
bool isEvalModel);
/**
* Gets the input or output name at the specified index associated with the training or eval model from the
* given training session.
* @param traning_handle handle of the traning session
* @param index the input or output index
* @param isInput if true, this method retrieves an input name. If false, this method retrieves an output name.
* @param isEvalModel when false, returns input & output names of the training model. When true, returns input & output
* names of the eval model.
* @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by
*/
char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_session_handle_t training_handle,
size_t index,
bool isInput,
bool isEvalModel);
/**
* @brief Release the specified ORT training session.
*