[js/web] remove training release (#22103)

### Description

Remove training from onnxruntime-web

Follow-up of #22082
This commit is contained in:
Yulong Wang 2024-09-16 10:56:22 -07:00 committed by GitHub
parent e93f14e00d
commit 291a5352b2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 15 additions and 1544 deletions

View file

@@ -1,5 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { OnnxruntimeWebAssemblyBackend } from './backend-wasm';
export const wasmBackend = new OnnxruntimeWebAssemblyBackend();

View file

@@ -1,29 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { InferenceSession, TrainingSessionHandler } from 'onnxruntime-common';
import { OnnxruntimeWebAssemblyBackend } from './backend-wasm';
import { OnnxruntimeWebAssemblyTrainingSessionHandler } from './wasm/session-handler-training';
class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
async createTrainingSessionHandler(
checkpointStateUriOrBuffer: string | Uint8Array,
trainModelUriOrBuffer: string | Uint8Array,
evalModelUriOrBuffer: string | Uint8Array,
optimizerModelUriOrBuffer: string | Uint8Array,
options: InferenceSession.SessionOptions,
): Promise<TrainingSessionHandler> {
const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler();
await handler.createTrainingSession(
checkpointStateUriOrBuffer,
trainModelUriOrBuffer,
evalModelUriOrBuffer,
optimizerModelUriOrBuffer,
options,
);
return Promise.resolve(handler);
}
}
export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend();

View file

@@ -99,3 +99,5 @@ export class OnnxruntimeWebAssemblyBackend implements Backend {
return Promise.resolve(handler);
}
}
export const wasmBackend = new OnnxruntimeWebAssemblyBackend();

View file

@@ -20,9 +20,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) {
}
if (!BUILD_DEFS.DISABLE_WASM) {
const wasmBackend = BUILD_DEFS.DISABLE_TRAINING
? require('./backend-wasm-inference').wasmBackend
: require('./backend-wasm-training').wasmBackend;
const wasmBackend = require('./backend-wasm').wasmBackend;
if (!BUILD_DEFS.DISABLE_JSEP) {
registerBackend('webgpu', wasmBackend, 5);
registerBackend('webnn', wasmBackend, 5);

View file

@@ -1,198 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler } from 'onnxruntime-common';
import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages';
import { decodeTensorMetadata, encodeTensorMetadata } from './session-handler-inference';
import { copyFromExternalBuffer } from './wasm-core-impl';
import {
createCheckpointHandle,
createTrainingSessionHandle,
getContiguousParameters,
getModelInputOutputNames,
getParametersSize,
lazyResetGrad,
loadParametersBuffer,
releaseTrainingSessionAndCheckpoint,
runEvalStep,
runOptimizerStep,
runTrainStep,
} from './wasm-training-core-impl';
export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
private sessionId: number;
private checkpointId: number;
inputNames: string[];
outputNames: string[];
evalInputNames: string[] = [];
evalOutputNames: string[] = [];
async uriOrBufferToHeap(uriOrBuffer: string | Uint8Array): Promise<SerializableInternalBuffer> {
let buffer: Uint8Array;
if (typeof uriOrBuffer === 'string') {
const response = await fetch(uriOrBuffer);
const arrayBuffer = await response.arrayBuffer();
buffer = new Uint8Array(arrayBuffer);
} else {
buffer = uriOrBuffer;
}
return copyFromExternalBuffer(buffer);
}
async createTrainingSession(
checkpointStateUriOrBuffer: string | Uint8Array,
trainModelUriOrBuffer: string | Uint8Array,
evalModelUriOrBuffer: string | Uint8Array,
optimizerModelUriOrBuffer: string | Uint8Array,
options: InferenceSession.SessionOptions,
) {
const checkpointData: SerializableInternalBuffer = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer);
const trainModelData: SerializableInternalBuffer = await this.uriOrBufferToHeap(trainModelUriOrBuffer);
// 0 is supposed to be the nullptr
let evalModelData: SerializableInternalBuffer = [0, 0];
let optimizerModelData: SerializableInternalBuffer = [0, 0];
if (evalModelUriOrBuffer !== '') {
evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer);
}
if (optimizerModelUriOrBuffer !== '') {
optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer);
}
this.checkpointId = createCheckpointHandle(checkpointData);
this.sessionId = createTrainingSessionHandle(
this.checkpointId,
trainModelData,
evalModelData,
optimizerModelData,
options,
);
[this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false);
if (evalModelUriOrBuffer !== '') {
[this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true);
}
}
/**
* Helper method that converts a feeds or fetches datatype to two arrays, one of values and one that stores the
* corresponding name as a number referring to the index in the list of names provided.
*
* @param feeds meant to match either SessionHandler.FeedsType or SessionHandler.FetchesType
* @param names either inputNames or outputNames
* @returns a tuple of a list of values and a list of indices.
*/
convertMapIntoValuesArrayAndIndicesArray<T, U>(
feeds: { [name: string]: T },
names: string[],
mapFunc: (val: T, index: number) => U,
): [T[], number[], U[]] {
const values: T[] = [];
const indices: number[] = [];
Object.entries(feeds).forEach((kvp) => {
const name = kvp[0];
const tensor = kvp[1];
const index = names.indexOf(name);
if (index === -1) {
throw new Error(`invalid input '${name}`);
}
values.push(tensor);
indices.push(index);
});
const uList = values.map(mapFunc);
return [values, indices, uList];
}
/**
* Helper method that converts the TensorMetadata that the wasm-core functions return to the
* SessionHandler.ReturnType. Any outputs in the provided outputArray that are falsy will be populated with the
* corresponding result.
*
* @param results used to populate the resultMap if there is no value for that outputName already
* @param outputArray used to populate the resultMap. If null or undefined, use the corresponding result from results
* @param outputIndices specifies which outputName the corresponding value for outputArray refers to.
* @returns a map of output names and OnnxValues.
*/
convertTensorMetadataToReturnType(
results: TensorMetadata[],
outputArray: Array<Tensor | null>,
outputIndices: number[],
): SessionHandler.ReturnType {
const resultMap: SessionHandler.ReturnType = {};
for (let i = 0; i < results.length; i++) {
resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]);
}
return resultMap;
}
async lazyResetGrad(): Promise<void> {
await lazyResetGrad(this.sessionId);
}
async runTrainStep(
feeds: SessionHandler.FeedsType,
fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions,
): Promise<SessionHandler.ReturnType> {
const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray<Tensor, TensorMetadata>(
feeds,
this.inputNames,
(t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`),
);
const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray<
Tensor | null,
TensorMetadata | null
>(fetches, this.outputNames, (t, i): TensorMetadata | null =>
t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null,
);
const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices);
}
async runOptimizerStep(options: InferenceSession.RunOptions): Promise<void> {
await runOptimizerStep(this.sessionId, options);
}
async runEvalStep(
feeds: SessionHandler.FeedsType,
fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions,
): Promise<SessionHandler.ReturnType> {
const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray<Tensor, TensorMetadata>(
feeds,
this.evalInputNames,
(t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`),
);
const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray<
Tensor | null,
TensorMetadata | null
>(fetches, this.evalOutputNames, (t, i): TensorMetadata | null =>
t ? encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null,
);
const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices);
}
async getParametersSize(trainableOnly: boolean): Promise<number> {
return getParametersSize(this.sessionId, trainableOnly);
}
async loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void> {
await loadParametersBuffer(this.sessionId, array, trainableOnly);
}
async getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue> {
const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly);
return decodeTensorMetadata(tensorResult);
}
async dispose(): Promise<void> {
return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId);
}
}

View file

@@ -41,8 +41,8 @@ import { loadFile } from './wasm-utils-load-file';
* Refer to web/lib/index.ts for the backend registration.
*
* 2. WebAssembly artifact initialization.
* This happens when any registered wasm backend is used for the first time (ie. `ort.InferenceSession.create()` or
* `ort.TrainingSession.create()` is called). In this step, onnxruntime-web does the followings:
* This happens when any registered wasm backend is used for the first time (ie. `ort.InferenceSession.create()` is
* called). In this step, onnxruntime-web does the followings:
* - create a proxy worker and make sure the proxy worker is ready to receive messages, if proxy is enabled.
* - perform feature detection, locate correct WebAssembly artifact path and call the Emscripten generated
* JavaScript code to initialize the WebAssembly runtime.
@@ -57,9 +57,8 @@ import { loadFile } from './wasm-utils-load-file';
* - logging level (ort.env.logLevel) and thread number (ort.env.wasm.numThreads) are set in this step.
*
* 4. Session initialization.
* This happens when `ort.InferenceSession.create()` or `ort.TrainingSession.create()` is called. Unlike the first 3
* steps (they only called once), this step will be done for each session. In this step, onnxruntime-web does the
* followings:
* This happens when `ort.InferenceSession.create()` is called. Unlike the first 3 steps (they only called once),
* this step will be done for each session. In this step, onnxruntime-web does the followings:
* If the parameter is a URL:
* - download the model data from the URL.
* - copy the model data to the WASM heap. (proxy: 'copy-from')

View file

@@ -1,631 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { InferenceSession, Tensor } from 'onnxruntime-common';
import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages';
import { setRunOptions } from './run-options';
import { setSessionOptions } from './session-options';
import {
dataLocationStringToEnum,
tensorDataTypeEnumToString,
tensorDataTypeStringToEnum,
tensorTypeToTypedArrayConstructor,
} from './wasm-common';
import { prepareInputOutputTensor } from './wasm-core-impl';
import { getInstance } from './wasm-factory';
import { checkLastError } from './wasm-utils';
const NO_TRAIN_FUNCS_MSG =
"Built without training API's enabled. Use the onnxruntime-web/training import for training " +
'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' +
'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.';
/**
* Runs the checkLastError function which will throw an error, if the provided error code matches the specified
* pattern for an error code.
* @param errCode number to evaluated for if it's an error
* @param message message to pass into checkLastError
* @param checkNeqZero when true, treats not equal to zero as an error.
* When false, treats equal to zero as an error.
*/
const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => {
if (checkNeqZero && errCode !== 0) {
checkLastError(message);
} else if (!checkNeqZero && errCode === 0) {
checkLastError(message);
}
};
export const createCheckpointHandle = (checkpointData: SerializableInternalBuffer): number => {
const wasm = getInstance();
const [checkpointDataOffset, checkpointDataLength] = checkpointData;
let checkpointHandle = 0;
try {
if (wasm._OrtTrainingLoadCheckpoint) {
checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength);
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false);
return checkpointHandle;
} catch (e) {
if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) {
wasm._OrtTrainingReleaseCheckpoint(checkpointHandle);
}
throw e;
} finally {
// free buffer from wasm heap
wasm._OrtFree(checkpointData[0]);
}
};
const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => {
const wasm = getInstance();
const stack = wasm.stackSave();
try {
const dataOffset = wasm.stackAlloc(8);
if (wasm._OrtTrainingGetModelInputOutputCount) {
const errorCode = wasm._OrtTrainingGetModelInputOutputCount(
trainingSessionId,
dataOffset,
dataOffset + 4,
isEvalModel,
);
ifErrCodeCheckLastError(errorCode, "Can't get session input/output count.");
return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]];
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
} finally {
wasm.stackRestore(stack);
}
};
const getModelInputOutputNamesLoop = (
trainingSessionId: number,
count: number,
isInput: boolean,
isEvalModel: boolean,
): string[] => {
const names = [];
const wasm = getInstance();
for (let i = 0; i < count; i++) {
if (wasm._OrtTrainingGetModelInputOutputName) {
const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel);
ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false);
names.push(wasm.UTF8ToString(name));
wasm._free(name);
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
}
return names;
};
export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => {
let inputNames: string[] = [];
let outputNames: string[] = [];
const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel);
inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel);
outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel);
return [inputNames, outputNames];
};
export const createTrainingSessionHandle = (
checkpointHandle: number,
trainModelData: SerializableInternalBuffer,
evalModelData: SerializableInternalBuffer,
optimizerModelData: SerializableInternalBuffer,
options: InferenceSession.SessionOptions,
): number => {
const wasm = getInstance();
let trainingSessionHandle = 0;
let sessionOptionsHandle = 0;
let allocs: number[] = [];
try {
[sessionOptionsHandle, allocs] = setSessionOptions(options);
if (wasm._OrtTrainingCreateSession) {
trainingSessionHandle = wasm._OrtTrainingCreateSession(
sessionOptionsHandle,
checkpointHandle,
trainModelData[0],
trainModelData[1],
evalModelData[0],
evalModelData[1],
optimizerModelData[0],
optimizerModelData[1],
);
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false);
return trainingSessionHandle;
} catch (e) {
if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) {
wasm._OrtTrainingReleaseSession(trainingSessionHandle);
}
throw e;
} finally {
wasm._free(trainModelData[0]);
wasm._free(evalModelData[0]);
wasm._free(optimizerModelData[0]);
if (sessionOptionsHandle !== 0) {
wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
}
allocs.forEach((alloc) => wasm._free(alloc));
}
};
/**
* Prepares input and output tensors by creating the tensors in the WASM side then creates a list of the handles of the
* WASM tensors.
*
* @param trainingSessionId
* @param indices for each tensor, the index of the input or output name that the tensor corresponds with
* @param tensors list of TensorMetaData
* @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting
* handles of the allocated tensors on the heap
* @param inputOutputAllocs modified in-place by this method
* @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor
*/
const createAndAllocateTensors = (
trainingSessionId: number,
indices: number[],
tensors: Array<TensorMetadata | null>,
tensorHandles: number[],
inputOutputAllocs: number[],
indexAdd: number,
) => {
const count = indices.length;
// creates the tensors
for (let i = 0; i < count; i++) {
prepareInputOutputTensor(tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]);
}
// moves to heap
const wasm = getInstance();
const valuesOffset = wasm.stackAlloc(count * 4);
let valuesIndex = valuesOffset / 4;
for (let i = 0; i < count; i++) {
wasm.HEAPU32[valuesIndex++] = tensorHandles[i];
}
return valuesOffset;
};
/**
* Retrieves the information from the output tensor handles, copies to an array, and frees the WASM information
* associated with the tensor handle.
*
* @param outputValuesOffset
* @param outputCount
* @returns list of TensorMetadata retrieved from the output handles.
*/
const moveOutputToTensorMetadataArr = (
outputValuesOffset: number,
outputCount: number,
outputTensorHandles: number[],
outputTensors: Array<TensorMetadata | null>,
) => {
const wasm = getInstance();
const output: TensorMetadata[] = [];
for (let i = 0; i < outputCount; i++) {
const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];
if (tensor === outputTensorHandles[i]) {
// output tensor is pre-allocated. no need to copy data.
output.push(outputTensors[i]!);
continue;
}
const beforeGetTensorDataStack = wasm.stackSave();
// stack allocate 4 pointer value
const tensorDataOffset = wasm.stackAlloc(4 * 4);
let type: Tensor.Type | undefined,
dataOffset = 0;
try {
const errorCode = wasm._OrtGetTensorData(
tensor,
tensorDataOffset,
tensorDataOffset + 4,
tensorDataOffset + 8,
tensorDataOffset + 12,
);
ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`);
let tensorDataIndex = tensorDataOffset / 4;
const dataType = wasm.HEAPU32[tensorDataIndex++];
dataOffset = wasm.HEAPU32[tensorDataIndex++];
const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
const dimsLength = wasm.HEAPU32[tensorDataIndex++];
const dims = [];
for (let i = 0; i < dimsLength; i++) {
dims.push(wasm.HEAPU32[dimsOffset / 4 + i]);
}
wasm._OrtFree(dimsOffset);
const size = dims.reduce((a, b) => a * b, 1);
type = tensorDataTypeEnumToString(dataType);
if (type === 'string') {
const stringData: string[] = [];
let dataIndex = dataOffset / 4;
for (let i = 0; i < size; i++) {
const offset = wasm.HEAPU32[dataIndex++];
const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset;
stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
}
output.push([type, dims, stringData, 'cpu']);
} else {
const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
const data = new typedArrayConstructor(size);
new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(
wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength),
);
output.push([type, dims, data, 'cpu']);
}
} finally {
wasm.stackRestore(beforeGetTensorDataStack);
if (type === 'string' && dataOffset) {
wasm._free(dataOffset);
}
wasm._OrtReleaseTensor(tensor);
}
}
return output;
};
export const lazyResetGrad = async (trainingSessionId: number): Promise<void> => {
const wasm = getInstance();
if (wasm._OrtTrainingLazyResetGrad) {
const errorCode = wasm._OrtTrainingLazyResetGrad(trainingSessionId);
ifErrCodeCheckLastError(errorCode, "Can't call lazyResetGrad.");
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
};
export const runTrainStep = async (
trainingSessionId: number,
inputIndices: number[],
inputTensors: TensorMetadata[],
outputIndices: number[],
outputTensors: Array<TensorMetadata | null>,
options: InferenceSession.RunOptions,
): Promise<TensorMetadata[]> => {
const wasm = getInstance();
const inputCount = inputIndices.length;
const outputCount = outputIndices.length;
let runOptionsHandle = 0;
let runOptionsAllocs: number[] = [];
const inputTensorHandles: number[] = [];
const outputTensorHandles: number[] = [];
const inputOutputAllocs: number[] = [];
const beforeRunStack = wasm.stackSave();
try {
// prepare parameters by moving them to heap
[runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
// handle inputs -- you don't want anything added to the index
const inputValuesOffset = createAndAllocateTensors(
trainingSessionId,
inputIndices,
inputTensors,
inputTensorHandles,
inputOutputAllocs,
0,
);
// handle outputs
// you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
const outputValuesOffset = createAndAllocateTensors(
trainingSessionId,
outputIndices,
outputTensors,
outputTensorHandles,
inputOutputAllocs,
inputCount,
);
if (wasm._OrtTrainingRunTrainStep) {
const errorCode = wasm._OrtTrainingRunTrainStep(
trainingSessionId,
inputValuesOffset,
inputCount,
outputValuesOffset,
outputCount,
runOptionsHandle,
);
ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer');
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
} finally {
wasm.stackRestore(beforeRunStack);
inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
inputOutputAllocs.forEach((p) => wasm._free(p));
if (runOptionsHandle !== 0) {
wasm._OrtReleaseRunOptions(runOptionsHandle);
}
runOptionsAllocs.forEach((p) => wasm._free(p));
}
};
export const runOptimizerStep = async (
trainingSessionId: number,
options: InferenceSession.RunOptions,
): Promise<void> => {
const wasm = getInstance();
let runOptionsHandle = 0;
let runOptionsAllocs: number[] = [];
try {
[runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
if (wasm._OrtTrainingOptimizerStep) {
const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle);
ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer');
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
} finally {
if (runOptionsHandle !== 0) {
wasm._OrtReleaseRunOptions(runOptionsHandle);
}
runOptionsAllocs.forEach((p) => wasm._free(p));
}
};
export const runEvalStep = async (
trainingSessionId: number,
inputIndices: number[],
inputTensors: TensorMetadata[],
outputIndices: number[],
outputTensors: Array<TensorMetadata | null>,
options: InferenceSession.RunOptions,
): Promise<TensorMetadata[]> => {
const wasm = getInstance();
const inputCount = inputIndices.length;
const outputCount = outputIndices.length;
let runOptionsHandle = 0;
let runOptionsAllocs: number[] = [];
const inputTensorHandles: number[] = [];
const outputTensorHandles: number[] = [];
const inputOutputAllocs: number[] = [];
const beforeRunStack = wasm.stackSave();
try {
// prepare parameters by moving them to heap
[runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
// handle inputs -- you don't want anything added to the index
const inputValuesOffset = createAndAllocateTensors(
trainingSessionId,
inputIndices,
inputTensors,
inputTensorHandles,
inputOutputAllocs,
0,
);
// handle outputs
// you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
const outputValuesOffset = createAndAllocateTensors(
trainingSessionId,
outputIndices,
outputTensors,
outputTensorHandles,
inputOutputAllocs,
inputCount,
);
if (wasm._OrtTrainingEvalStep) {
const errorCode = wasm._OrtTrainingEvalStep(
trainingSessionId,
inputValuesOffset,
inputCount,
outputValuesOffset,
outputCount,
runOptionsHandle,
);
ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer');
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
} finally {
wasm.stackRestore(beforeRunStack);
inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
inputOutputAllocs.forEach((p) => wasm._free(p));
if (runOptionsHandle !== 0) {
wasm._OrtReleaseRunOptions(runOptionsHandle);
}
runOptionsAllocs.forEach((p) => wasm._free(p));
}
};
export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => {
const wasm = getInstance();
const stack = wasm.stackSave();
try {
const sizeOffset = wasm.stackAlloc(4);
if (wasm._OrtTrainingGetParametersSize) {
const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly);
ifErrCodeCheckLastError(errorCode, "Can't get parameters size");
return wasm.HEAP32[sizeOffset / 4];
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
} finally {
wasm.stackRestore(stack);
}
};
export const getContiguousParameters = async (
trainingSessionId: number,
trainableOnly: boolean,
): Promise<TensorMetadata> => {
const wasm = getInstance();
const stack = wasm.stackSave();
const tensorTypeAsString = 'float32';
const locationAsString = 'cpu';
const parametersSize = getParametersSize(trainingSessionId, trainableOnly);
let tensor = 0;
// allocates a buffer of the correct size on the WASM heap
const paramsByteLength = 4 * parametersSize;
const paramsOffset = wasm._malloc(paramsByteLength);
// handles the dimensions-related createTensor parameters
const dims = [parametersSize];
const dimsOffset = wasm.stackAlloc(4);
const dimsIndex = dimsOffset / 4;
wasm.HEAP32[dimsIndex] = parametersSize;
try {
// wraps allocated array in a tensor
tensor = wasm._OrtCreateTensor(
tensorDataTypeStringToEnum(tensorTypeAsString),
paramsOffset,
paramsByteLength,
dimsOffset,
dims.length,
dataLocationStringToEnum(locationAsString),
);
ifErrCodeCheckLastError(
tensor,
`Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`,
false,
);
if (wasm._OrtTrainingCopyParametersToBuffer) {
const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly);
ifErrCodeCheckLastError(errCode, "Can't get contiguous parameters.");
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
// copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object
const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString);
const data = new typedArrayConstructor(parametersSize);
const output: TensorMetadata[] = [];
new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(
wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength),
);
output.push([tensorTypeAsString, dims, data, locationAsString]);
if (output.length !== 1) {
throw new Error(`something unexpected happened in the getContiguousParameters function. Expected output length of
one, got ${output.length}`);
} else {
return output[0];
}
} finally {
if (tensor !== 0) {
wasm._OrtReleaseTensor(tensor);
}
wasm._free(paramsOffset);
wasm._free(dimsOffset);
wasm.stackRestore(stack);
}
};
export const loadParametersBuffer = async (
trainingSessionId: number,
buffer: Uint8Array,
trainableOnly: boolean,
): Promise<void> => {
const wasm = getInstance();
const stack = wasm.stackSave();
const tensorTypeAsString = 'float32';
const locationAsString = 'cpu';
// allocates & copies JavaScript buffer to WASM heap
const bufferByteLength = buffer.length;
const bufferCount = bufferByteLength / 4;
const bufferOffset = wasm._malloc(bufferByteLength);
wasm.HEAPU8.set(buffer, bufferOffset);
// allocates and handles moving dimensions information to WASM memory
const dimsOffset = wasm.stackAlloc(4);
wasm.HEAP32[dimsOffset / 4] = bufferCount;
const dimsLength = 1;
let tensor = 0;
try {
tensor = wasm._OrtCreateTensor(
tensorDataTypeStringToEnum(tensorTypeAsString),
bufferOffset,
bufferByteLength,
dimsOffset,
dimsLength,
dataLocationStringToEnum(locationAsString),
);
ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false);
if (wasm._OrtTrainingCopyParametersFromBuffer) {
const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly);
ifErrCodeCheckLastError(errCode, "Can't copy buffer to parameters.");
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
} finally {
if (tensor !== 0) {
wasm._OrtReleaseTensor(tensor);
}
wasm.stackRestore(stack);
wasm._free(bufferOffset);
wasm._free(dimsOffset);
}
};
export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => {
const wasm = getInstance();
if (wasm._OrtTrainingReleaseSession) {
wasm._OrtTrainingReleaseSession(sessionId);
}
if (wasm._OrtTrainingReleaseCheckpoint) {
wasm._OrtTrainingReleaseCheckpoint(checkpointId);
}
};

View file

@@ -213,84 +213,10 @@ export interface OrtInferenceAPIs {
_OrtEndProfiling(sessionHandle: number): number;
}
export interface OrtTrainingAPIs {
_OrtTrainingLoadCheckpoint(dataOffset: number, dataLength: number): number;
_OrtTrainingReleaseCheckpoint(checkpointHandle: number): void;
_OrtTrainingCreateSession(
sessionOptionsHandle: number,
checkpointHandle: number,
trainOffset: number,
trainLength: number,
evalOffset: number,
evalLength: number,
optimizerOffset: number,
optimizerLength: number,
): number;
_OrtTrainingLazyResetGrad(trainingHandle: number): number;
_OrtTrainingRunTrainStep(
trainingHandle: number,
inputsOffset: number,
inputCount: number,
outputsOffset: number,
outputCount: number,
runOptionsHandle: number,
): number;
_OrtTrainingOptimizerStep(trainingHandle: number, runOptionsHandle: number): number;
_OrtTrainingEvalStep(
trainingHandle: number,
inputsOffset: number,
inputCount: number,
outputsOffset: number,
outputCount: number,
runOptionsHandle: number,
): number;
_OrtTrainingGetParametersSize(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number;
_OrtTrainingCopyParametersToBuffer(
trainingHandle: number,
parametersBuffer: number,
parameterCount: number,
trainableOnly: boolean,
): number;
_OrtTrainingCopyParametersFromBuffer(
trainingHandle: number,
parametersBuffer: number,
parameterCount: number,
trainableOnly: boolean,
): number;
_OrtTrainingGetModelInputOutputCount(
trainingHandle: number,
inputCount: number,
outputCount: number,
isEvalModel: boolean,
): number;
_OrtTrainingGetModelInputOutputName(
trainingHandle: number,
index: number,
isInput: boolean,
isEvalModel: boolean,
): number;
_OrtTrainingReleaseSession(trainingHandle: number): void;
}
/**
* The interface of the WebAssembly module for ONNX Runtime, compiled from C++ source code by Emscripten.
*/
export interface OrtWasmModule
extends EmscriptenModule,
OrtInferenceAPIs,
Partial<OrtTrainingAPIs>,
Partial<JSEP.Module> {
export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial<JSEP.Module> {
// #region emscripten functions
stackSave(): number;
stackRestore(stack: number): void;

View file

@ -135,11 +135,9 @@ const embeddedWasmModule: EmscriptenModuleFactory<OrtWasmModule> | undefined =
BUILD_DEFS.IS_ESM && BUILD_DEFS.DISABLE_DYNAMIC_IMPORT
? // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
require(
!BUILD_DEFS.DISABLE_TRAINING
? '../../dist/ort-training-wasm-simd-threaded.mjs'
: !BUILD_DEFS.DISABLE_JSEP
? '../../dist/ort-wasm-simd-threaded.jsep.mjs'
: '../../dist/ort-wasm-simd-threaded.mjs',
!BUILD_DEFS.DISABLE_JSEP
? '../../dist/ort-wasm-simd-threaded.jsep.mjs'
: '../../dist/ort-wasm-simd-threaded.mjs',
).default
: undefined;
@ -163,11 +161,9 @@ export const importWasmModule = async (
if (BUILD_DEFS.DISABLE_DYNAMIC_IMPORT) {
return [undefined, embeddedWasmModule!];
} else {
const wasmModuleFilename = !BUILD_DEFS.DISABLE_TRAINING
? 'ort-training-wasm-simd-threaded.mjs'
: !BUILD_DEFS.DISABLE_JSEP
? 'ort-wasm-simd-threaded.jsep.mjs'
: 'ort-wasm-simd-threaded.mjs';
const wasmModuleFilename = !BUILD_DEFS.DISABLE_JSEP
? 'ort-wasm-simd-threaded.jsep.mjs'
: 'ort-wasm-simd-threaded.mjs';
const wasmModuleUrl = urlOverride ?? normalizeUrl(wasmModuleFilename, prefixOverride);
// need to preload if all of the following conditions are met:
// 1. not in Node.js.

View file

@ -23,7 +23,6 @@
"build:doc": "node ./script/generate-webgl-operator-md && node ./script/generate-webgpu-operator-md",
"pull:wasm": "node ./script/pull-prebuilt-wasm-artifacts",
"test:e2e": "node ./test/e2e/run",
"test:training:e2e": "node ./test/training/e2e/run",
"prebuild": "tsc -p . --noEmit && tsc -p lib/wasm/proxy-worker --noEmit",
"build": "node ./script/build",
"test": "tsc --build ../scripts && node ../scripts/prepare-onnx-node-tests && node ./script/test-runner-cli",
@ -101,12 +100,6 @@
"import": "./dist/ort.webgpu.bundle.min.mjs",
"require": "./dist/ort.webgpu.min.js",
"types": "./types.d.ts"
},
"./training": {
"node": null,
"import": "./dist/ort.training.wasm.min.mjs",
"require": "./dist/ort.training.wasm.min.js",
"types": "./types.d.ts"
}
},
"types": "./types.d.ts",

View file

@ -56,7 +56,6 @@ const DEFAULT_DEFINE = {
'BUILD_DEFS.DISABLE_JSEP': 'false',
'BUILD_DEFS.DISABLE_WASM': 'false',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'false',
'BUILD_DEFS.DISABLE_TRAINING': 'true',
'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'false',
'BUILD_DEFS.IS_ESM': 'false',
@ -253,7 +252,7 @@ async function buildBundle(options: esbuild.BuildOptions) {
*
* The distribution code is split into multiple files:
* - [output-name][.min].[m]js
* - ort[-training]-wasm-simd-threaded[.jsep].mjs
* - ort-wasm-simd-threaded[.jsep].mjs
*/
async function buildOrt({
isProduction = false,
@ -630,16 +629,6 @@ async function main() {
'BUILD_DEFS.DISABLE_WASM_PROXY': 'true',
},
});
// ort.training.wasm[.min].[m]js
await addAllWebBuildTasks({
outputName: 'ort.training.wasm',
define: {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_TRAINING': 'false',
'BUILD_DEFS.DISABLE_JSEP': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
},
});
}
if (BUNDLE_MODE === 'dev' || BUNDLE_MODE === 'perf') {

View file

@ -149,11 +149,9 @@ downloadJson(
void jszip.loadAsync(buffer).then((zip) => {
extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.wasm', folderName);
extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm', folderName);
extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.wasm', folderName);
extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.mjs', folderName);
extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.mjs', folderName);
extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.mjs', folderName);
});
});
},

View file

@ -1,21 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
'use strict';
describe('Browser E2E testing for training package', function () {
  // Each case pins the wasm backend to one thread for deterministic runs,
  // then delegates to the shared test functions defined in common.js.
  const singleThreaded = (testFn) =>
    async function () {
      ort.env.wasm.numThreads = 1;
      await testFn(ort, { executionProviders: ['wasm'] });
    };
  it('Check that training package encompasses inference', singleThreaded(testInferenceFunction));
  it('Check training functionality, all options', singleThreaded(testTrainingFunctionAll));
  it('Check training functionality, minimum options', singleThreaded(testTrainingFunctionMin));
});

View file

@ -1,248 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
'use strict';
// Test fixtures live under data/ (copied there by the e2e runner); paths are
// resolved relative to the karma base URL.
const DATA_FOLDER = 'data/';
const TRAININGDATA_TRAIN_MODEL = DATA_FOLDER + 'training_model.onnx';
const TRAININGDATA_OPTIMIZER_MODEL = DATA_FOLDER + 'adamw.onnx';
const TRAININGDATA_EVAL_MODEL = DATA_FOLDER + 'eval_model.onnx';
const TRAININGDATA_CKPT = DATA_FOLDER + 'checkpoint.ckpt';
// Full option set: checkpoint plus train/eval/optimizer models.
const trainingSessionAllOptions = {
  checkpointState: TRAININGDATA_CKPT,
  trainModel: TRAININGDATA_TRAIN_MODEL,
  evalModel: TRAININGDATA_EVAL_MODEL,
  optimizerModel: TRAININGDATA_OPTIMIZER_MODEL,
};
// Minimal option set: checkpoint plus training model only (no eval/optimizer).
const trainingSessionMinOptions = {
  checkpointState: TRAININGDATA_CKPT,
  trainModel: TRAININGDATA_TRAIN_MODEL,
};
// ASSERT METHODS
// Minimal truthiness assertion; throws a message-less Error on failure.
function assert(cond) {
  if (!cond) throw new Error();
}
/**
 * Asserts `actual === expected`; objects are JSON-stringified in the error
 * message for readability.
 */
function assertStrictEquals(actual, expected) {
  if (actual === expected) {
    return;
  }
  const shown = typeof actual === 'object' ? JSON.stringify(actual) : actual;
  throw new Error(`expected: ${expected}; got: ${shown}`);
}
/**
 * Asserts that two array-likes differ in length or in at least one element;
 * throws when they are element-wise identical.
 */
function assertTwoListsUnequal(list1, list2) {
  // Different lengths means trivially unequal.
  if (list1.length !== list2.length) {
    return;
  }
  let foundDifference = false;
  for (let i = 0; i < list1.length && !foundDifference; i++) {
    foundDifference = list1[i] !== list2[i];
  }
  if (!foundDifference) {
    throw new Error(`expected ${list1} and ${list2} to be unequal; got two equal lists`);
  }
}
// HELPER METHODS FOR TESTS
/** Draws one sample from N(mean, scale^2) via the Box-Muller transform. */
function generateGaussianRandom(mean = 0, scale = 1) {
  const u = 1 - Math.random(); // in (0, 1]; avoids log(0)
  const v = Math.random();
  const radius = Math.sqrt(-2.0 * Math.log(u));
  const angle = 2.0 * Math.PI * v;
  return Math.cos(angle) * radius * scale + mean;
}
/** Returns a Float32Array of `length` independent Gaussian samples (mean 0, scale 1). */
function generateGaussianFloatArray(length) {
  return Float32Array.from({ length }, () => generateGaussianRandom());
}
/**
 * Creates a TrainingSession from `createOptions` and verifies the training
 * model's expected input/output names.
 * @param {*} ort the onnxruntime-web module under test
 * @param {*} createOptions options passed to ort.TrainingSession.create
 * @param {*} options session options (e.g. execution providers)
 * @returns the created TrainingSession
 */
async function createTrainingSessionAndCheckTrainingModel(ort, createOptions, options) {
  const trainingSession = await ort.TrainingSession.create(createOptions, options);
  const inputNames = trainingSession.trainingInputNames;
  assertStrictEquals(inputNames[0], 'input-0');
  assertStrictEquals(inputNames[1], 'labels');
  assertStrictEquals(inputNames.length, 2);
  const outputNames = trainingSession.trainingOutputNames;
  assertStrictEquals(outputNames[0], 'onnx::loss::21273');
  assertStrictEquals(outputNames.length, 1);
  return trainingSession;
}
/**
 * Verifies the eval model's input/output names on a session that was created
 * with an eval model loaded.
 */
function checkEvalModel(trainingSession) {
  const inputNames = trainingSession.evalInputNames;
  assertStrictEquals(inputNames[0], 'input-0');
  assertStrictEquals(inputNames[1], 'labels');
  assertStrictEquals(inputNames.length, 2);
  const outputNames = trainingSession.evalOutputNames;
  assertStrictEquals(outputNames[0], 'onnx::loss::21273');
  assertStrictEquals(outputNames.length, 1);
}
/**
 * Checks that accessing `evalInputNames` / `evalOutputNames` throws the
 * expected error when the session was created without an eval model.
 * @param {} trainingSession
 */
function checkNoEvalModel(trainingSession) {
  const expectAccessToThrow = (accessor) => {
    try {
      assertStrictEquals(accessor(), 'should have thrown an error upon accessing');
    } catch (error) {
      assertStrictEquals(error.message, 'This training session has no evalModel loaded.');
    }
  };
  expectAccessToThrow(() => trainingSession.evalInputNames);
  expectAccessToThrow(() => trainingSession.evalOutputNames);
}
/**
 * Runs one train step with `feeds` and checks that the single returned output
 * is the scalar float32 loss.
 * @param {} trainingSession
 * @param {*} feeds
 * @returns the run results keyed by output name
 */
var runTrainStepAndCheck = async function (trainingSession, feeds) {
  const results = await trainingSession.runTrainStep(feeds);
  assertStrictEquals(Object.keys(results).length, 1);
  const loss = results['onnx::loss::21273'];
  assertStrictEquals(loss.data.length, 1);
  assertStrictEquals(loss.type, 'float32');
  return results;
};
/**
 * Loads a parameter buffer filled with `constant` into the session, then
 * verifies the read-back parameters differ from `paramsBefore` and round to
 * `constant` at 4 decimal places.
 * @returns the parameters read back after the load
 */
var loadParametersBufferAndCheck = async function (trainingSession, paramsLength, constant, paramsBefore) {
  // Fill a float32 buffer with the constant and hand it over as raw bytes.
  const newParams = new Float32Array(paramsLength).fill(constant);
  const newParamsUint8 = new Uint8Array(newParams.buffer, newParams.byteOffset, newParams.byteLength);
  await trainingSession.loadParametersBuffer(newParamsUint8);
  const paramsAfterLoad = await trainingSession.getContiguousParameters();
  // The constant is chosen so the parameters must have changed.
  assertTwoListsUnequal(paramsAfterLoad.data, paramsBefore.data);
  assertStrictEquals(paramsAfterLoad.dims[0], paramsLength);
  // Compare at 4 decimal places since values round-trip through float32.
  const expected = constant.toFixed(4);
  for (let i = 0; i < paramsLength; i++) {
    assertStrictEquals(paramsAfterLoad.data[i].toFixed(4), expected);
  }
  return paramsAfterLoad;
};
// TESTS
/** Runs a 3x4 @ 4x3 matmul through an InferenceSession and checks the product. */
var testInferenceFunction = async function (ort, options) {
  const session = await ort.InferenceSession.create('data/model.onnx', options || {});
  const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
  const dataB = Float32Array.from([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]);
  const fetches = await session.run({
    a: new ort.Tensor('float32', dataA, [3, 4]),
    b: new ort.Tensor('float32', dataB, [4, 3]),
  });
  const c = fetches.c;
  assert(c instanceof ort.Tensor);
  assert(c.dims.length === 2 && c.dims[0] === 3 && c.dims[1] === 3);
  // Expected row-major values of A x B.
  const expected = [700, 800, 900, 1580, 1840, 2100, 2460, 2880, 3300];
  for (let i = 0; i < expected.length; i++) {
    assert(c.data[i] === expected[i]);
  }
};
/** Exercises a minimal training session (checkpoint + train model only). */
var testTrainingFunctionMin = async function (ort, options) {
  const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionMinOptions, options);
  checkNoEvalModel(trainingSession);
  const PARAM_COUNT = 397510;
  const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]);
  const labels = new ort.Tensor('int32', [2, 1], [2]);
  const feeds = { 'input-0': input0, labels: labels };
  // getParametersSize must report the full trainable parameter count.
  assertStrictEquals(await trainingSession.getParametersSize(), PARAM_COUNT);
  // getContiguousParameters returns a flat vector; spot-check two entries
  // against the known checkpoint values.
  const originalParams = await trainingSession.getContiguousParameters();
  assertStrictEquals(originalParams.dims.length, 1);
  assertStrictEquals(originalParams.dims[0], PARAM_COUNT);
  assertStrictEquals(originalParams.data[0], -0.025190064683556557);
  assertStrictEquals(originalParams.data[2000], -0.034044936299324036);
  await runTrainStepAndCheck(trainingSession, feeds);
  await loadParametersBufferAndCheck(trainingSession, PARAM_COUNT, -1.2, originalParams);
};
/** Exercises the full training flow: eval model, optimizer step, loss decrease. */
var testTrainingFunctionAll = async function (ort, options) {
  const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionAllOptions, options);
  checkEvalModel(trainingSession);
  const PARAM_COUNT = 397510;
  const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]);
  const labels = new ort.Tensor('int32', [2, 1], [2]);
  let feeds = { 'input-0': input0, labels: labels };
  // getParametersSize must report the full trainable parameter count.
  assertStrictEquals(await trainingSession.getParametersSize(), PARAM_COUNT);
  // Spot-check the initial parameters against the known checkpoint values.
  const originalParams = await trainingSession.getContiguousParameters();
  assertStrictEquals(originalParams.dims.length, 1);
  assertStrictEquals(originalParams.dims[0], PARAM_COUNT);
  assertStrictEquals(originalParams.data[0], -0.025190064683556557);
  assertStrictEquals(originalParams.data[2000], -0.034044936299324036);
  const results = await runTrainStepAndCheck(trainingSession, feeds);
  await trainingSession.runOptimizerStep(feeds);
  feeds = { 'input-0': input0, labels: labels };
  // After the optimizer step the parameters must have been updated.
  const optimizedParams = await trainingSession.getContiguousParameters();
  assertTwoListsUnequal(originalParams.data, optimizedParams.data);
  const results2 = await runTrainStepAndCheck(trainingSession, feeds);
  // Loss on the same batch must decrease after one optimizer step.
  assert(results2['onnx::loss::21273'].data < results['onnx::loss::21273'].data);
  await loadParametersBufferAndCheck(trainingSession, PARAM_COUNT, -1.2, optimizedParams);
};
// Export the test entry points when loaded as a CommonJS module (Node); in the
// browser these functions are consumed as globals instead.
// FIX: the original array also referenced `testTest`, which is not defined
// anywhere in this file — loading the module via require() would throw a
// ReferenceError before any test could run.
if (typeof module === 'object') {
  module.exports = [testInferenceFunction, testTrainingFunctionMin, testTrainingFunctionAll];
}

View file

@ -1,16 +0,0 @@
 backend-test:b

a
bc"MatMultest_matmul_2dZ
a


Z
b


b
c


B

View file

@ -1,54 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
'use strict';
// Commandline flags controlling the karma run.
const args = require('minimist')(process.argv.slice(2));
const SELF_HOST = !!args['self-host']; // serve ORT from node_modules instead of port 8081
const ORT_MAIN = args['ort-main']; // ORT bundle filename to load first
const TEST_MAIN = args['test-main']; // test entry file
const USER_DATA = args['user-data']; // browser user-data directory
if (typeof TEST_MAIN !== 'string') {
  throw new Error('flag --test-main=<TEST_MAIN_JS_FILE> is required');
}
if (typeof USER_DATA !== 'string') {
  throw new Error('flag --user-data=<CHROME_USER_DATA_FOLDER> is required');
}
// Programmatic karma configuration; the flags parsed above select the ORT
// build artifact under test, the test entry file, and the browser profile dir.
module.exports = function (config) {
  // Serve ORT either from local node_modules (self-host, same origin) or from
  // the helper server on port 8081 (cross origin).
  const distPrefix = SELF_HOST ? './node_modules/onnxruntime-web/dist/' : 'http://localhost:8081/dist/';
  config.set({
    frameworks: ['mocha'],
    files: [
      { pattern: distPrefix + ORT_MAIN },
      { pattern: './common.js' },
      { pattern: TEST_MAIN },
      // dist assets are fetched by the runtime at test time, not injected as <script> tags.
      { pattern: './node_modules/onnxruntime-web/dist/*.*', included: false, nocache: true },
      { pattern: './data/*', included: false },
    ],
    plugins: [require('@chiragrupani/karma-chromium-edge-launcher'), ...config.plugins],
    proxies: {
      '/model.onnx': '/base/model.onnx',
      '/data/': '/base/data/',
    },
    client: { captureConsole: true, mocha: { expose: ['body'], timeout: 60000 } },
    reporters: ['mocha'],
    captureTimeout: 120000,
    reportSlowerThan: 100,
    browserDisconnectTimeout: 600000,
    browserNoActivityTimeout: 300000,
    browserDisconnectTolerance: 0,
    browserSocketTimeout: 60000,
    hostname: 'localhost',
    browsers: [],
    customLaunchers: {
      // Each launcher uses the per-run user-data directory (see --user-data flag).
      Chrome_default: { base: 'ChromeHeadless', chromeDataDir: USER_DATA },
      Chrome_no_threads: {
        base: 'ChromeHeadless',
        chromeDataDir: USER_DATA,
        // TODO: no-thread flags
      },
      Edge_default: { base: 'Edge', edgeDataDir: USER_DATA },
    },
  });
};

View file

@ -1,14 +0,0 @@
{
"devDependencies": {
"@chiragrupani/karma-chromium-edge-launcher": "^2.2.2",
"fs-extra": "^11.1.0",
"globby": "^13.1.3",
"karma": "^6.4.1",
"karma-chrome-launcher": "^3.1.1",
"karma-mocha": "^2.0.1",
"karma-mocha-reporter": "^2.2.5",
"light-server": "^2.9.1",
"minimist": "^1.2.7",
"mocha": "^10.2.0"
}
}

View file

@ -1,143 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
'use strict';
const path = require('path');
const fs = require('fs-extra');
const { spawn } = require('child_process');
const startServer = require('./simple-http-server');
const minimist = require('minimist');
// copy whole folder to out-side of <ORT_ROOT>/js/ because we need to test in a folder that no `package.json` file
// exists in its parent folder.
// here we use <ORT_ROOT>/build/js/e2e-training/ for the test
// Folder layout for the e2e run (outside js/ so no parent package.json interferes).
const TEST_E2E_SRC_FOLDER = __dirname;
const JS_ROOT_FOLDER = path.resolve(__dirname, '../../../..');
const TEST_E2E_RUN_FOLDER = path.resolve(JS_ROOT_FOLDER, '../build/js/e2e-training');
// Dedicated npm cache so the packed tarballs install fresh, never from a stale cache.
const NPM_CACHE_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../npm_cache');
const CHROME_USER_DATA_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../user_data');
// Start from clean folders, then copy the test sources into the run folder.
fs.emptyDirSync(TEST_E2E_RUN_FOLDER);
fs.emptyDirSync(NPM_CACHE_FOLDER);
fs.emptyDirSync(CHROME_USER_DATA_FOLDER);
fs.copySync(TEST_E2E_SRC_FOLDER, TEST_E2E_RUN_FOLDER);
// training data to copy
const ORT_ROOT_FOLDER = path.resolve(JS_ROOT_FOLDER, '..');
const TRAINING_DATA_FOLDER = path.resolve(ORT_ROOT_FOLDER, 'onnxruntime/test/testdata/training_api');
const TRAININGDATA_DEST = path.resolve(TEST_E2E_RUN_FOLDER, 'data');
// always use a new folder as user-data-dir
let nextUserDataDirId = 0;
// Returns a fresh, emptied browser user-data directory for each karma launch.
function getNextUserDataDir() {
  const dir = path.resolve(CHROME_USER_DATA_FOLDER, nextUserDataDirId.toString());
  nextUserDataDirId++;
  fs.emptyDirSync(dir);
  return dir;
}
// commandline arguments
const BROWSER = minimist(process.argv.slice(2)).browser || 'Chrome_default';
// Orchestrates the whole e2e run: install locally packed packages, copy the
// training fixtures, then run the browser suites self-hosted and cross-origin.
async function main() {
  // Locate the locally packed tarballs so the run installs exactly the bits
  // built in this repo rather than registry versions.
  const { globbySync } = await import('globby');
  const ORT_COMMON_FOLDER = path.resolve(JS_ROOT_FOLDER, 'common');
  const ORT_COMMON_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-common-*.tgz', { cwd: ORT_COMMON_FOLDER });
  const PACKAGES_TO_INSTALL = [];
  // The onnxruntime-common tarball is optional (zero candidates is fine).
  if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length === 1) {
    PACKAGES_TO_INSTALL.push(path.resolve(ORT_COMMON_FOLDER, ORT_COMMON_PACKED_FILEPATH_CANDIDATES[0]));
  } else if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length > 1) {
    throw new Error('multiple packages found for onnxruntime-common.');
  }
  const ORT_WEB_FOLDER = path.resolve(JS_ROOT_FOLDER, 'web');
  const ORT_WEB_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-web-*.tgz', { cwd: ORT_WEB_FOLDER });
  // Exactly one onnxruntime-web tarball is required.
  if (ORT_WEB_PACKED_FILEPATH_CANDIDATES.length !== 1) {
    throw new Error('cannot find exactly single package for onnxruntime-web.');
  }
  PACKAGES_TO_INSTALL.push(path.resolve(ORT_WEB_FOLDER, ORT_WEB_PACKED_FILEPATH_CANDIDATES[0]));
  // we start here:
  // install dev dependencies
  await runInShell(`npm install`);
  // npm install with "--cache" to install packed packages with an empty cache folder
  await runInShell(`npm install --cache "${NPM_CACHE_FOLDER}" ${PACKAGES_TO_INSTALL.map((i) => `"${i}"`).join(' ')}`);
  // prepare training data
  prepareTrainingDataByCopying();
  console.log('===============================================================');
  console.log('Running self-hosted tests');
  console.log('===============================================================');
  // test cases with self-host (ort hosted in same origin)
  await testAllBrowserCases({ hostInKarma: true });
  console.log('===============================================================');
  console.log('Running not self-hosted tests');
  console.log('===============================================================');
  // test cases without self-host (ort hosted in cross origin)
  const server = startServer(path.join(TEST_E2E_RUN_FOLDER, 'node_modules', 'onnxruntime-web'), 8081);
  try {
    await testAllBrowserCases({ hostInKarma: false });
  } finally {
    // close the server after all tests
    await server.close();
  }
}
/** Runs the wasm browser test suite once for the given hosting mode. */
async function testAllBrowserCases({ hostInKarma }) {
  await runKarma({ main: './browser-test-wasm.js', hostInKarma });
}
/** Launches one single-run karma session against the selected ORT bundle. */
async function runKarma({ hostInKarma, main, browser = BROWSER, ortMain = 'ort.training.wasm.min.js' }) {
  console.log('===============================================================');
  console.log(`Running karma with the following binary: ${ortMain}`);
  console.log('===============================================================');
  // Note: an empty self-host flag leaves a double space, matching karma's
  // tolerant flag parsing.
  const flags = [
    '--single-run',
    `--browsers ${browser}`,
    hostInKarma ? '--self-host' : '',
    `--ort-main=${ortMain}`,
    `--test-main=${main}`,
    `--user-data=${getNextUserDataDir()}`,
  ];
  await runInShell(`npx karma start ${flags.join(' ')}`);
}
/**
 * Runs `cmd` in a shell with inherited stdio from the e2e run folder.
 * Resolves when the command exits with code 0; terminates the whole process
 * with the child's exit code otherwise (matching the original behavior).
 *
 * FIX: the original implementation busy-polled a `complete` flag every 100 ms
 * via `delay()`; resolving directly from the 'close' event removes up to
 * 100 ms of latency per command and the idle polling loop.
 */
async function runInShell(cmd) {
  console.log('===============================================================');
  console.log(' Running command in shell:');
  console.log(' > ' + cmd);
  console.log('===============================================================');
  return new Promise((resolve) => {
    const childProcess = spawn(cmd, { shell: true, stdio: 'inherit', cwd: TEST_E2E_RUN_FOLDER });
    childProcess.on('close', function (code) {
      if (code !== 0) {
        process.exit(code);
      } else {
        resolve();
      }
    });
  });
}
/** Resolves after `ms` milliseconds. */
async function delay(ms) {
  return new Promise((resolve) => setTimeout(resolve, ms));
}
// Copies the ONNX training fixtures (models + checkpoint) into the run folder's
// data/ directory, where karma serves them to the browser tests.
function prepareTrainingDataByCopying() {
  fs.copySync(TRAINING_DATA_FOLDER, TRAININGDATA_DEST);
  console.log(`Copied ${TRAINING_DATA_FOLDER} to ${TRAININGDATA_DEST}`);
}
// Script entry point; any rejection surfaces as an unhandled process error.
main();

View file

@ -1,67 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
'use strict';
// this is a simple HTTP server that enables CORS.
// following code is based on https://developer.mozilla.org/en-US/docs/Learn/Server-side/Node_server_without_framework
const http = require('http');
const fs = require('fs');
const path = require('path');
/**
 * Maps a request URL to a [filepath, mimeType] pair, or null when the path is
 * outside the served prefixes or has an unsupported extension.
 */
const getRequestData = (url, dir) => {
  const pathname = new URL(url, 'http://localhost').pathname;
  // Only the dist/ artifacts and the wasm-path-override fixtures are served.
  const isServable = pathname.startsWith('/test-wasm-path-override/') || pathname.startsWith('/dist/');
  if (!isServable) {
    return null;
  }
  const filepath = path.resolve(dir, pathname.substring(1));
  if (filepath.endsWith('.wasm')) {
    return [filepath, 'application/wasm'];
  }
  if (filepath.endsWith('.js') || filepath.endsWith('.mjs')) {
    return [filepath, 'text/javascript'];
  }
  return null;
};
// Creates and starts an HTTP server rooted at `dir` on `port`. Only paths
// accepted by getRequestData are served, with CORS enabled so a cross-origin
// page can fetch the ORT artifacts. Returns the http.Server instance.
module.exports = function (dir, port) {
  const server = http
    .createServer(function (request, response) {
      // Strip CR/LF so a crafted request URL cannot inject log lines.
      const url = request.url.replace(/\n|\r/g, '');
      console.log(`request ${url}`);
      const requestData = getRequestData(url, dir);
      if (!request || !requestData) {
        response.writeHead(404);
        response.end('404');
      } else {
        const [filePath, contentType] = requestData;
        fs.readFile(path.resolve(dir, filePath), function (error, content) {
          if (error) {
            if (error.code == 'ENOENT') {
              // Path was servable but the file does not exist.
              response.writeHead(404);
              response.end('404');
            } else {
              response.writeHead(500);
              response.end('500');
            }
          } else {
            // Allow any origin: this server exists to exercise cross-origin loading.
            response.setHeader('access-control-allow-origin', '*');
            response.writeHead(200, { 'Content-Type': contentType });
            response.end(content, 'utf-8');
          }
        });
      }
    })
    .listen(port);
  console.log(`Server running at http://localhost:${port}/`);
  return server;
};

4
js/web/types.d.ts vendored
View file

@ -20,7 +20,3 @@ declare module 'onnxruntime-web/webgl' {
// Re-export of the package's public API under the 'onnxruntime-web/webgpu' subpath.
declare module 'onnxruntime-web/webgpu' {
  export * from 'onnxruntime-web';
}
// Re-export under the 'onnxruntime-web/training' subpath.
// NOTE(review): the training build appears to be removed upstream — confirm
// this subpath is still published before relying on it.
declare module 'onnxruntime-web/training' {
  export * from 'onnxruntime-web';
}