mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
[js/web/training] Implemented runEvalStep & runOptimizerStep (#18259)
### Description * implemented runEvalStep and runOptimizerStep * added hasEvalModel and hasOptimizerModel boolean fields in TrainingSession representation * added evalInputNames and evalOutputNames fields to TrainingSessionHandler & TrainingSession * removed the inputNamesEncoded and outputNamesEncoded fields from TrainingSessionHandler -- since none of the training methods require the input names and output names as parameters, there's no need to store them. ### Motivation and Context * part of the work for implementing web bindings for training * previous PR: #18250 --------- Co-authored-by: Ashwini Khade <askhade@microsoft.com>
This commit is contained in:
parent
5353adcde3
commit
c02a386145
5 changed files with 242 additions and 61 deletions
|
|
@ -45,9 +45,16 @@ export interface InferenceSessionHandler extends SessionHandler {
|
|||
* @ignore
|
||||
*/
|
||||
export interface TrainingSessionHandler extends SessionHandler {
|
||||
readonly evalInputNames: readonly string[];
|
||||
readonly evalOutputNames: readonly string[];
|
||||
|
||||
runTrainStep(
|
||||
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
|
||||
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;
|
||||
runOptimizerStep(options: InferenceSession.RunOptions): Promise<void>;
|
||||
runEvalStep(
|
||||
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
|
||||
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;
|
||||
|
||||
getParametersSize(trainableOnly: boolean): Promise<number>;
|
||||
loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;
|
||||
|
|
|
|||
|
|
@ -18,18 +18,37 @@ const noBackendErrMsg: string = 'Training backend could not be resolved. ' +
|
|||
'Make sure you\'re using the correct configuration & WebAssembly files.';
|
||||
|
||||
export class TrainingSession implements TrainingSessionInterface {
|
||||
private constructor(handler: TrainingSessionHandler) {
|
||||
private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) {
|
||||
this.handler = handler;
|
||||
this.hasOptimizerModel = hasOptimizerModel;
|
||||
this.hasEvalModel = hasEvalModel;
|
||||
}
|
||||
private handler: TrainingSessionHandler;
|
||||
private hasOptimizerModel: boolean;
|
||||
private hasEvalModel: boolean;
|
||||
|
||||
get inputNames(): readonly string[] {
|
||||
get trainingInputNames(): readonly string[] {
|
||||
return this.handler.inputNames;
|
||||
}
|
||||
get outputNames(): readonly string[] {
|
||||
get trainingOutputNames(): readonly string[] {
|
||||
return this.handler.outputNames;
|
||||
}
|
||||
|
||||
get evalInputNames(): readonly string[] {
|
||||
if (this.hasEvalModel) {
|
||||
return this.handler.evalInputNames;
|
||||
} else {
|
||||
throw new Error('This training session has no evalModel loaded.');
|
||||
}
|
||||
}
|
||||
get evalOutputNames(): readonly string[] {
|
||||
if (this.hasEvalModel) {
|
||||
return this.handler.evalOutputNames;
|
||||
} else {
|
||||
throw new Error('This training session has no evalModel loaded.');
|
||||
}
|
||||
}
|
||||
|
||||
static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions):
|
||||
Promise<TrainingSession> {
|
||||
const evalModel: string|Uint8Array = trainingOptions.evalModel || '';
|
||||
|
|
@ -43,7 +62,7 @@ export class TrainingSession implements TrainingSessionInterface {
|
|||
if (backend.createTrainingSessionHandler) {
|
||||
const handler = await backend.createTrainingSessionHandler(
|
||||
trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options);
|
||||
return new TrainingSession(handler);
|
||||
return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel);
|
||||
} else {
|
||||
throw new Error(noBackendErrMsg);
|
||||
}
|
||||
|
|
@ -53,13 +72,18 @@ export class TrainingSession implements TrainingSessionInterface {
|
|||
* Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from
|
||||
* the given parameters to SessionHandler.FetchesType and RunOptions.
|
||||
*
|
||||
* @param inputNames the feeds object is checked that they contain all input names in the provided list of input
|
||||
* names.
|
||||
* @param outputNames the fetches object is checked that their keys match up with valid names in the list of output
|
||||
* names.
|
||||
* @param feeds the required input
|
||||
* @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object
|
||||
* @param arg2 optional RunOptions object.
|
||||
* @returns
|
||||
*/
|
||||
typeNarrowingForRunStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions):
|
||||
[SessionHandler.FetchesType, RunOptions] {
|
||||
typeNarrowingForRunStep(
|
||||
inputNames: readonly string[], outputNames: readonly string[], feeds: FeedsType, arg1?: FetchesType|RunOptions,
|
||||
arg2?: RunOptions): [SessionHandler.FetchesType, RunOptions] {
|
||||
const fetches: {[name: string]: OnnxValue|null} = {};
|
||||
let options: RunOptions = {};
|
||||
// check inputs
|
||||
|
|
@ -88,7 +112,7 @@ export class TrainingSession implements TrainingSessionInterface {
|
|||
if (typeof name !== 'string') {
|
||||
throw new TypeError('\'fetches\' must be a string array or an object.');
|
||||
}
|
||||
if (this.outputNames.indexOf(name) === -1) {
|
||||
if (outputNames.indexOf(name) === -1) {
|
||||
throw new RangeError(`'fetches' contains invalid output name: ${name}.`);
|
||||
}
|
||||
fetches[name] = null;
|
||||
|
|
@ -104,7 +128,7 @@ export class TrainingSession implements TrainingSessionInterface {
|
|||
// if any output name is present and its value is valid OnnxValue, we consider it fetches
|
||||
let isFetches = false;
|
||||
const arg1Keys = Object.getOwnPropertyNames(arg1);
|
||||
for (const name of this.outputNames) {
|
||||
for (const name of outputNames) {
|
||||
if (arg1Keys.indexOf(name) !== -1) {
|
||||
const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name];
|
||||
if (v === null || v instanceof Tensor) {
|
||||
|
|
@ -130,7 +154,7 @@ export class TrainingSession implements TrainingSessionInterface {
|
|||
}
|
||||
|
||||
// check if all inputs are in feed
|
||||
for (const name of this.inputNames) {
|
||||
for (const name of inputNames) {
|
||||
if (typeof feeds[name] === 'undefined') {
|
||||
throw new Error(`input '${name}' is missing in 'feeds'.`);
|
||||
}
|
||||
|
|
@ -138,7 +162,7 @@ export class TrainingSession implements TrainingSessionInterface {
|
|||
|
||||
// if no fetches is specified, we use the full output names list
|
||||
if (isFetchesEmpty) {
|
||||
for (const name of this.outputNames) {
|
||||
for (const name of outputNames) {
|
||||
fetches[name] = null;
|
||||
}
|
||||
}
|
||||
|
|
@ -171,11 +195,33 @@ export class TrainingSession implements TrainingSessionInterface {
|
|||
runTrainStep(feeds: FeedsType, options?: RunOptions): Promise<ReturnType>;
|
||||
runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise<ReturnType>;
|
||||
async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise<ReturnType> {
|
||||
const [fetches, options] = this.typeNarrowingForRunStep(feeds, arg1, arg2);
|
||||
const [fetches, options] =
|
||||
this.typeNarrowingForRunStep(this.trainingInputNames, this.trainingOutputNames, feeds, arg1, arg2);
|
||||
const results = await this.handler.runTrainStep(feeds, fetches, options);
|
||||
return this.convertHandlerReturnTypeToMapOfTensors(results);
|
||||
}
|
||||
|
||||
async runOptimizerStep(options?: InferenceSession.RunOptions|undefined): Promise<void> {
|
||||
if (this.hasOptimizerModel) {
|
||||
await this.handler.runOptimizerStep(options || {});
|
||||
} else {
|
||||
throw new Error('This TrainingSession has no OptimizerModel loaded.');
|
||||
}
|
||||
}
|
||||
|
||||
runEvalStep(feeds: FeedsType, options?: RunOptions|undefined): Promise<ReturnType>;
|
||||
runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions|undefined): Promise<ReturnType>;
|
||||
async runEvalStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise<ReturnType> {
|
||||
if (this.hasEvalModel) {
|
||||
const [fetches, options] =
|
||||
this.typeNarrowingForRunStep(this.evalInputNames, this.evalOutputNames, feeds, arg1, arg2);
|
||||
const results = await this.handler.runEvalStep(feeds, fetches, options);
|
||||
return this.convertHandlerReturnTypeToMapOfTensors(results);
|
||||
} else {
|
||||
throw new Error('This TrainingSession has no EvalModel loaded.');
|
||||
}
|
||||
}
|
||||
|
||||
async getParametersSize(trainableOnly = true): Promise<number> {
|
||||
return this.handler.getParametersSize(trainableOnly);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ export interface TrainingSession {
|
|||
* @param feeds - Representation of the model input.
|
||||
* @param fetches - Representation of the model output.
|
||||
* detail.
|
||||
* @param options - Optional. A set of options that controls the behavior of model inference.
|
||||
* @param options - Optional. A set of options that controls the behavior of model training.
|
||||
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
|
||||
values.
|
||||
*/
|
||||
|
|
@ -47,6 +47,38 @@ export interface TrainingSession {
|
|||
feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType,
|
||||
options?: InferenceSession.RunOptions): Promise<InferenceSession.ReturnType>;
|
||||
|
||||
/**
|
||||
* Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model.
|
||||
*
|
||||
* @param options - Optional. A set of options that controls the behavior of model optimizing.
|
||||
*/
|
||||
runOptimizerStep(options?: InferenceSession.RunOptions): Promise<void>;
|
||||
|
||||
/**
|
||||
* Run a single eval step with the given inputs and options using the eval model.
|
||||
*
|
||||
* @param feeds - Representation of the model input.
|
||||
* @param options - Optional. A set of options that controls the behavior of model eval step.
|
||||
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
|
||||
values.
|
||||
*/
|
||||
runEvalStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions):
|
||||
Promise<InferenceSession.ReturnType>;
|
||||
|
||||
/**
|
||||
* Run a single eval step with the given inputs and options using the eval model.
|
||||
*
|
||||
* @param feeds - Representation of the model input.
|
||||
* @param fetches - Representation of the model output.
|
||||
* detail.
|
||||
* @param options - Optional. A set of options that controls the behavior of model eval step.
|
||||
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
|
||||
values.
|
||||
*/
|
||||
runEvalStep(
|
||||
feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType,
|
||||
options?: InferenceSession.RunOptions): Promise<InferenceSession.ReturnType>;
|
||||
|
||||
// #endregion
|
||||
|
||||
// #region copy parameters
|
||||
|
|
@ -90,14 +122,25 @@ export interface TrainingSession {
|
|||
// #region metadata
|
||||
|
||||
/**
|
||||
* Get input names of the loaded model.
|
||||
* Get input names of the loaded training model.
|
||||
*/
|
||||
readonly inputNames: readonly string[];
|
||||
readonly trainingInputNames: readonly string[];
|
||||
|
||||
/**
|
||||
* Get output names of the loaded model.
|
||||
* Get output names of the loaded training model.
|
||||
*/
|
||||
readonly outputNames: readonly string[];
|
||||
readonly trainingOutputNames: readonly string[];
|
||||
|
||||
/**
|
||||
* Get input names of the loaded eval model. Is an empty array if no eval model is loaded.
|
||||
*/
|
||||
readonly evalInputNames: readonly string[];
|
||||
|
||||
/**
|
||||
* Get output names of the loaded eval model. Is an empty array if no eval model is loaded.
|
||||
*/
|
||||
readonly evalOutputNames: readonly string[];
|
||||
|
||||
// #endregion
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessio
|
|||
import {SerializableModeldata, TensorMetadata} from './proxy-messages';
|
||||
import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference';
|
||||
import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl';
|
||||
import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl';
|
||||
import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getModelInputOutputNames, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runEvalStep, runOptimizerStep, runTrainStep} from './wasm-training-core-impl';
|
||||
|
||||
export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
|
||||
private sessionId: number;
|
||||
|
|
@ -15,8 +15,8 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
|
|||
inputNames: string[];
|
||||
outputNames: string[];
|
||||
|
||||
inputEncodedNames: number[];
|
||||
outputEncodedNames: number[];
|
||||
evalInputNames: string[] = [];
|
||||
evalOutputNames: string[] = [];
|
||||
|
||||
async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise<SerializableModeldata> {
|
||||
let buffer: Uint8Array;
|
||||
|
|
@ -51,8 +51,12 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
|
|||
}
|
||||
|
||||
this.checkpointId = createCheckpointHandle(checkpointData);
|
||||
[[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] =
|
||||
this.sessionId =
|
||||
createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options);
|
||||
[this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false);
|
||||
if (evalModelUriOrBuffer !== '') {
|
||||
[this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -118,6 +122,27 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
|
|||
return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices);
|
||||
}
|
||||
|
||||
async runOptimizerStep(options: InferenceSession.RunOptions): Promise<void> {
|
||||
await runOptimizerStep(this.sessionId, options);
|
||||
}
|
||||
|
||||
async runEvalStep(
|
||||
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
|
||||
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType> {
|
||||
const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray<Tensor, TensorMetadata>(
|
||||
feeds, this.evalInputNames,
|
||||
(t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`));
|
||||
|
||||
const [outputArray, outputIndices, outputs] =
|
||||
this.convertMapIntoValuesArrayAndIndicesArray<Tensor|null, TensorMetadata|null>(
|
||||
fetches, this.evalOutputNames,
|
||||
(t, i): TensorMetadata|null =>
|
||||
t ? encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null);
|
||||
|
||||
const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
|
||||
return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices);
|
||||
}
|
||||
|
||||
async getParametersSize(trainableOnly: boolean): Promise<number> {
|
||||
return getParametersSize(this.sessionId, trainableOnly);
|
||||
}
|
||||
|
|
@ -131,7 +156,6 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
|
|||
}
|
||||
|
||||
async dispose(): Promise<void> {
|
||||
return releaseTrainingSessionAndCheckpoint(
|
||||
this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames);
|
||||
return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
import {InferenceSession, Tensor} from 'onnxruntime-common';
|
||||
|
||||
import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages';
|
||||
import {SerializableModeldata, TensorMetadata} from './proxy-messages';
|
||||
import {setRunOptions} from './run-options';
|
||||
import {setSessionOptions} from './session-options';
|
||||
import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
|
||||
|
|
@ -77,50 +77,44 @@ const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolea
|
|||
};
|
||||
|
||||
const getModelInputOutputNamesLoop =
|
||||
(trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => {
|
||||
(trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): string[] => {
|
||||
const names = [];
|
||||
const wasm = getInstance();
|
||||
|
||||
const namesUTF8Encoded = [];
|
||||
|
||||
for (let i = 0; i < count; i++) {
|
||||
if (wasm._OrtTrainingGetModelInputOutputName) {
|
||||
const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel);
|
||||
ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false);
|
||||
|
||||
namesUTF8Encoded.push(name);
|
||||
names.push(wasm.UTF8ToString(name));
|
||||
wasm._free(name);
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
}
|
||||
return [names, namesUTF8Encoded];
|
||||
return names;
|
||||
};
|
||||
|
||||
const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => {
|
||||
const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false);
|
||||
export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => {
|
||||
let inputNames: string[] = [];
|
||||
let outputNames: string[] = [];
|
||||
|
||||
const [inputNames, inputNamesUTF8Encoded] = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false);
|
||||
const [outputNames, outputNamesUTF8Encoded] =
|
||||
getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false);
|
||||
const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel);
|
||||
|
||||
return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded];
|
||||
inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel);
|
||||
outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel);
|
||||
|
||||
return [inputNames, outputNames];
|
||||
};
|
||||
|
||||
export const createTrainingSessionHandle =
|
||||
(checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata,
|
||||
optimizerModelData: SerializableModeldata,
|
||||
options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => {
|
||||
optimizerModelData: SerializableModeldata, options: InferenceSession.SessionOptions): number => {
|
||||
const wasm = getInstance();
|
||||
|
||||
let trainingSessionHandle = 0;
|
||||
let sessionOptionsHandle = 0;
|
||||
let allocs: number[] = [];
|
||||
let inputNamesUTF8Encoded: number[] = [];
|
||||
let outputNamesUTF8Encoded: number[] = [];
|
||||
|
||||
let inputNames: string[] = [];
|
||||
let outputNames: string[] = [];
|
||||
|
||||
try {
|
||||
[sessionOptionsHandle, allocs] = setSessionOptions(options);
|
||||
|
|
@ -133,11 +127,7 @@ export const createTrainingSessionHandle =
|
|||
}
|
||||
|
||||
ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false);
|
||||
|
||||
[inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] =
|
||||
getTrainingModelInputOutputNames(trainingSessionHandle);
|
||||
return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded];
|
||||
|
||||
return trainingSessionHandle;
|
||||
} catch (e) {
|
||||
if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) {
|
||||
wasm._OrtTrainingReleaseSession(trainingSessionHandle);
|
||||
|
|
@ -152,8 +142,6 @@ export const createTrainingSessionHandle =
|
|||
wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
|
||||
}
|
||||
allocs.forEach(alloc => wasm._free(alloc));
|
||||
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
|
||||
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -317,6 +305,83 @@ export const runTrainStep = async(
|
|||
}
|
||||
};
|
||||
|
||||
export const runOptimizerStep =
|
||||
async(trainingSessionId: number, options: InferenceSession.RunOptions): Promise<void> => {
|
||||
const wasm = getInstance();
|
||||
|
||||
let runOptionsHandle = 0;
|
||||
let runOptionsAllocs: number[] = [];
|
||||
|
||||
try {
|
||||
[runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
|
||||
|
||||
if (wasm._OrtTrainingOptimizerStep) {
|
||||
const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle);
|
||||
ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer');
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
} finally {
|
||||
if (runOptionsHandle !== 0) {
|
||||
wasm._OrtReleaseRunOptions(runOptionsHandle);
|
||||
}
|
||||
runOptionsAllocs.forEach(p => wasm._free(p));
|
||||
}
|
||||
};
|
||||
|
||||
export const runEvalStep = async(
|
||||
trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[],
|
||||
outputTensors: Array<TensorMetadata|null>, options: InferenceSession.RunOptions): Promise<TensorMetadata[]> => {
|
||||
const wasm = getInstance();
|
||||
|
||||
const inputCount = inputIndices.length;
|
||||
const outputCount = outputIndices.length;
|
||||
|
||||
let runOptionsHandle = 0;
|
||||
let runOptionsAllocs: number[] = [];
|
||||
|
||||
const inputTensorHandles: number[] = [];
|
||||
const outputTensorHandles: number[] = [];
|
||||
const inputOutputAllocs: number[] = [];
|
||||
|
||||
const beforeRunStack = wasm.stackSave();
|
||||
|
||||
try {
|
||||
// prepare parameters by moving them to heap
|
||||
[runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
|
||||
|
||||
// handle inputs -- you don't want anything added to the index
|
||||
const inputValuesOffset = createAndAllocateTensors(
|
||||
trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0);
|
||||
// handle outputs
|
||||
// you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
|
||||
const outputValuesOffset = createAndAllocateTensors(
|
||||
trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount);
|
||||
|
||||
if (wasm._OrtTrainingEvalStep) {
|
||||
const errorCode = wasm._OrtTrainingEvalStep(
|
||||
trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle);
|
||||
|
||||
ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer');
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
|
||||
return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
|
||||
} finally {
|
||||
wasm.stackRestore(beforeRunStack);
|
||||
|
||||
inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
|
||||
outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
|
||||
inputOutputAllocs.forEach(p => wasm._free(p));
|
||||
|
||||
if (runOptionsHandle !== 0) {
|
||||
wasm._OrtReleaseRunOptions(runOptionsHandle);
|
||||
}
|
||||
runOptionsAllocs.forEach(p => wasm._free(p));
|
||||
}
|
||||
};
|
||||
|
||||
export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => {
|
||||
const wasm = getInstance();
|
||||
const stack = wasm.stackSave();
|
||||
|
|
@ -439,17 +504,13 @@ export const loadParametersBuffer =
|
|||
}
|
||||
};
|
||||
|
||||
export const releaseTrainingSessionAndCheckpoint =
|
||||
(checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]):
|
||||
void => {
|
||||
const wasm = getInstance();
|
||||
inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
|
||||
outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
|
||||
export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => {
|
||||
const wasm = getInstance();
|
||||
|
||||
if (wasm._OrtTrainingReleaseSession) {
|
||||
wasm._OrtTrainingReleaseSession(sessionId);
|
||||
}
|
||||
if (wasm._OrtTrainingReleaseCheckpoint) {
|
||||
wasm._OrtTrainingReleaseCheckpoint(checkpointId);
|
||||
}
|
||||
};
|
||||
if (wasm._OrtTrainingReleaseSession) {
|
||||
wasm._OrtTrainingReleaseSession(sessionId);
|
||||
}
|
||||
if (wasm._OrtTrainingReleaseCheckpoint) {
|
||||
wasm._OrtTrainingReleaseCheckpoint(checkpointId);
|
||||
}
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in a new issue