mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[js/web] remove training release (#22103)
### Description Remove training from onnxruntime-web Following up of #22082
This commit is contained in:
parent
e93f14e00d
commit
291a5352b2
20 changed files with 15 additions and 1544 deletions
|
|
@ -1,5 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { OnnxruntimeWebAssemblyBackend } from './backend-wasm';
|
||||
export const wasmBackend = new OnnxruntimeWebAssemblyBackend();
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { InferenceSession, TrainingSessionHandler } from 'onnxruntime-common';
|
||||
|
||||
import { OnnxruntimeWebAssemblyBackend } from './backend-wasm';
|
||||
import { OnnxruntimeWebAssemblyTrainingSessionHandler } from './wasm/session-handler-training';
|
||||
|
||||
class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
|
||||
async createTrainingSessionHandler(
|
||||
checkpointStateUriOrBuffer: string | Uint8Array,
|
||||
trainModelUriOrBuffer: string | Uint8Array,
|
||||
evalModelUriOrBuffer: string | Uint8Array,
|
||||
optimizerModelUriOrBuffer: string | Uint8Array,
|
||||
options: InferenceSession.SessionOptions,
|
||||
): Promise<TrainingSessionHandler> {
|
||||
const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler();
|
||||
await handler.createTrainingSession(
|
||||
checkpointStateUriOrBuffer,
|
||||
trainModelUriOrBuffer,
|
||||
evalModelUriOrBuffer,
|
||||
optimizerModelUriOrBuffer,
|
||||
options,
|
||||
);
|
||||
return Promise.resolve(handler);
|
||||
}
|
||||
}
|
||||
|
||||
export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend();
|
||||
|
|
@ -99,3 +99,5 @@ export class OnnxruntimeWebAssemblyBackend implements Backend {
|
|||
return Promise.resolve(handler);
|
||||
}
|
||||
}
|
||||
|
||||
export const wasmBackend = new OnnxruntimeWebAssemblyBackend();
|
||||
|
|
|
|||
|
|
@ -20,9 +20,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) {
|
|||
}
|
||||
|
||||
if (!BUILD_DEFS.DISABLE_WASM) {
|
||||
const wasmBackend = BUILD_DEFS.DISABLE_TRAINING
|
||||
? require('./backend-wasm-inference').wasmBackend
|
||||
: require('./backend-wasm-training').wasmBackend;
|
||||
const wasmBackend = require('./backend-wasm').wasmBackend;
|
||||
if (!BUILD_DEFS.DISABLE_JSEP) {
|
||||
registerBackend('webgpu', wasmBackend, 5);
|
||||
registerBackend('webnn', wasmBackend, 5);
|
||||
|
|
|
|||
|
|
@ -1,198 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler } from 'onnxruntime-common';
|
||||
|
||||
import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages';
|
||||
import { decodeTensorMetadata, encodeTensorMetadata } from './session-handler-inference';
|
||||
import { copyFromExternalBuffer } from './wasm-core-impl';
|
||||
import {
|
||||
createCheckpointHandle,
|
||||
createTrainingSessionHandle,
|
||||
getContiguousParameters,
|
||||
getModelInputOutputNames,
|
||||
getParametersSize,
|
||||
lazyResetGrad,
|
||||
loadParametersBuffer,
|
||||
releaseTrainingSessionAndCheckpoint,
|
||||
runEvalStep,
|
||||
runOptimizerStep,
|
||||
runTrainStep,
|
||||
} from './wasm-training-core-impl';
|
||||
|
||||
export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
|
||||
private sessionId: number;
|
||||
private checkpointId: number;
|
||||
|
||||
inputNames: string[];
|
||||
outputNames: string[];
|
||||
|
||||
evalInputNames: string[] = [];
|
||||
evalOutputNames: string[] = [];
|
||||
|
||||
async uriOrBufferToHeap(uriOrBuffer: string | Uint8Array): Promise<SerializableInternalBuffer> {
|
||||
let buffer: Uint8Array;
|
||||
if (typeof uriOrBuffer === 'string') {
|
||||
const response = await fetch(uriOrBuffer);
|
||||
const arrayBuffer = await response.arrayBuffer();
|
||||
buffer = new Uint8Array(arrayBuffer);
|
||||
} else {
|
||||
buffer = uriOrBuffer;
|
||||
}
|
||||
return copyFromExternalBuffer(buffer);
|
||||
}
|
||||
|
||||
async createTrainingSession(
|
||||
checkpointStateUriOrBuffer: string | Uint8Array,
|
||||
trainModelUriOrBuffer: string | Uint8Array,
|
||||
evalModelUriOrBuffer: string | Uint8Array,
|
||||
optimizerModelUriOrBuffer: string | Uint8Array,
|
||||
options: InferenceSession.SessionOptions,
|
||||
) {
|
||||
const checkpointData: SerializableInternalBuffer = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer);
|
||||
const trainModelData: SerializableInternalBuffer = await this.uriOrBufferToHeap(trainModelUriOrBuffer);
|
||||
// 0 is supposed to be the nullptr
|
||||
let evalModelData: SerializableInternalBuffer = [0, 0];
|
||||
let optimizerModelData: SerializableInternalBuffer = [0, 0];
|
||||
|
||||
if (evalModelUriOrBuffer !== '') {
|
||||
evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer);
|
||||
}
|
||||
if (optimizerModelUriOrBuffer !== '') {
|
||||
optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer);
|
||||
}
|
||||
|
||||
this.checkpointId = createCheckpointHandle(checkpointData);
|
||||
this.sessionId = createTrainingSessionHandle(
|
||||
this.checkpointId,
|
||||
trainModelData,
|
||||
evalModelData,
|
||||
optimizerModelData,
|
||||
options,
|
||||
);
|
||||
[this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false);
|
||||
if (evalModelUriOrBuffer !== '') {
|
||||
[this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper method that converts a feeds or fetches datatype to two arrays, one of values and one that stores the
|
||||
* corresponding name as a number referring to the index in the list of names provided.
|
||||
*
|
||||
* @param feeds meant to match either SessionHandler.FeedsType or SessionHandler.FetchesType
|
||||
* @param names either inputNames or outputNames
|
||||
* @returns a tuple of a list of values and a list of indices.
|
||||
*/
|
||||
convertMapIntoValuesArrayAndIndicesArray<T, U>(
|
||||
feeds: { [name: string]: T },
|
||||
names: string[],
|
||||
mapFunc: (val: T, index: number) => U,
|
||||
): [T[], number[], U[]] {
|
||||
const values: T[] = [];
|
||||
const indices: number[] = [];
|
||||
Object.entries(feeds).forEach((kvp) => {
|
||||
const name = kvp[0];
|
||||
const tensor = kvp[1];
|
||||
const index = names.indexOf(name);
|
||||
if (index === -1) {
|
||||
throw new Error(`invalid input '${name}`);
|
||||
}
|
||||
values.push(tensor);
|
||||
indices.push(index);
|
||||
});
|
||||
|
||||
const uList = values.map(mapFunc);
|
||||
return [values, indices, uList];
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper method that converts the TensorMetadata that the wasm-core functions return to the
|
||||
* SessionHandler.ReturnType. Any outputs in the provided outputArray that are falsy will be populated with the
|
||||
* corresponding result.
|
||||
*
|
||||
* @param results used to populate the resultMap if there is no value for that outputName already
|
||||
* @param outputArray used to populate the resultMap. If null or undefined, use the corresponding result from results
|
||||
* @param outputIndices specifies which outputName the corresponding value for outputArray refers to.
|
||||
* @returns a map of output names and OnnxValues.
|
||||
*/
|
||||
convertTensorMetadataToReturnType(
|
||||
results: TensorMetadata[],
|
||||
outputArray: Array<Tensor | null>,
|
||||
outputIndices: number[],
|
||||
): SessionHandler.ReturnType {
|
||||
const resultMap: SessionHandler.ReturnType = {};
|
||||
for (let i = 0; i < results.length; i++) {
|
||||
resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]);
|
||||
}
|
||||
return resultMap;
|
||||
}
|
||||
|
||||
async lazyResetGrad(): Promise<void> {
|
||||
await lazyResetGrad(this.sessionId);
|
||||
}
|
||||
|
||||
async runTrainStep(
|
||||
feeds: SessionHandler.FeedsType,
|
||||
fetches: SessionHandler.FetchesType,
|
||||
options: InferenceSession.RunOptions,
|
||||
): Promise<SessionHandler.ReturnType> {
|
||||
const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray<Tensor, TensorMetadata>(
|
||||
feeds,
|
||||
this.inputNames,
|
||||
(t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`),
|
||||
);
|
||||
|
||||
const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray<
|
||||
Tensor | null,
|
||||
TensorMetadata | null
|
||||
>(fetches, this.outputNames, (t, i): TensorMetadata | null =>
|
||||
t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null,
|
||||
);
|
||||
|
||||
const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
|
||||
return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices);
|
||||
}
|
||||
|
||||
async runOptimizerStep(options: InferenceSession.RunOptions): Promise<void> {
|
||||
await runOptimizerStep(this.sessionId, options);
|
||||
}
|
||||
|
||||
async runEvalStep(
|
||||
feeds: SessionHandler.FeedsType,
|
||||
fetches: SessionHandler.FetchesType,
|
||||
options: InferenceSession.RunOptions,
|
||||
): Promise<SessionHandler.ReturnType> {
|
||||
const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray<Tensor, TensorMetadata>(
|
||||
feeds,
|
||||
this.evalInputNames,
|
||||
(t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`),
|
||||
);
|
||||
|
||||
const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray<
|
||||
Tensor | null,
|
||||
TensorMetadata | null
|
||||
>(fetches, this.evalOutputNames, (t, i): TensorMetadata | null =>
|
||||
t ? encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null,
|
||||
);
|
||||
|
||||
const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
|
||||
return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices);
|
||||
}
|
||||
|
||||
async getParametersSize(trainableOnly: boolean): Promise<number> {
|
||||
return getParametersSize(this.sessionId, trainableOnly);
|
||||
}
|
||||
|
||||
async loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void> {
|
||||
await loadParametersBuffer(this.sessionId, array, trainableOnly);
|
||||
}
|
||||
async getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue> {
|
||||
const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly);
|
||||
return decodeTensorMetadata(tensorResult);
|
||||
}
|
||||
|
||||
async dispose(): Promise<void> {
|
||||
return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId);
|
||||
}
|
||||
}
|
||||
|
|
@ -41,8 +41,8 @@ import { loadFile } from './wasm-utils-load-file';
|
|||
* Refer to web/lib/index.ts for the backend registration.
|
||||
*
|
||||
* 2. WebAssembly artifact initialization.
|
||||
* This happens when any registered wasm backend is used for the first time (ie. `ort.InferenceSession.create()` or
|
||||
* `ort.TrainingSession.create()` is called). In this step, onnxruntime-web does the followings:
|
||||
* This happens when any registered wasm backend is used for the first time (ie. `ort.InferenceSession.create()` is
|
||||
* called). In this step, onnxruntime-web does the followings:
|
||||
* - create a proxy worker and make sure the proxy worker is ready to receive messages, if proxy is enabled.
|
||||
* - perform feature detection, locate correct WebAssembly artifact path and call the Emscripten generated
|
||||
* JavaScript code to initialize the WebAssembly runtime.
|
||||
|
|
@ -57,9 +57,8 @@ import { loadFile } from './wasm-utils-load-file';
|
|||
* - logging level (ort.env.logLevel) and thread number (ort.env.wasm.numThreads) are set in this step.
|
||||
*
|
||||
* 4. Session initialization.
|
||||
* This happens when `ort.InferenceSession.create()` or `ort.TrainingSession.create()` is called. Unlike the first 3
|
||||
* steps (they only called once), this step will be done for each session. In this step, onnxruntime-web does the
|
||||
* followings:
|
||||
* This happens when `ort.InferenceSession.create()` is called. Unlike the first 3 steps (they only called once),
|
||||
* this step will be done for each session. In this step, onnxruntime-web does the followings:
|
||||
* If the parameter is a URL:
|
||||
* - download the model data from the URL.
|
||||
* - copy the model data to the WASM heap. (proxy: 'copy-from')
|
||||
|
|
|
|||
|
|
@ -1,631 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import { InferenceSession, Tensor } from 'onnxruntime-common';
|
||||
|
||||
import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages';
|
||||
import { setRunOptions } from './run-options';
|
||||
import { setSessionOptions } from './session-options';
|
||||
import {
|
||||
dataLocationStringToEnum,
|
||||
tensorDataTypeEnumToString,
|
||||
tensorDataTypeStringToEnum,
|
||||
tensorTypeToTypedArrayConstructor,
|
||||
} from './wasm-common';
|
||||
import { prepareInputOutputTensor } from './wasm-core-impl';
|
||||
import { getInstance } from './wasm-factory';
|
||||
import { checkLastError } from './wasm-utils';
|
||||
|
||||
const NO_TRAIN_FUNCS_MSG =
|
||||
"Built without training API's enabled. Use the onnxruntime-web/training import for training " +
|
||||
'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' +
|
||||
'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.';
|
||||
|
||||
/**
|
||||
* Runs the checkLastError function which will throw an error, if the provided error code matches the specified
|
||||
* pattern for an error code.
|
||||
* @param errCode number to evaluated for if it's an error
|
||||
* @param message message to pass into checkLastError
|
||||
* @param checkNeqZero when true, treats not equal to zero as an error.
|
||||
* When false, treats equal to zero as an error.
|
||||
*/
|
||||
const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => {
|
||||
if (checkNeqZero && errCode !== 0) {
|
||||
checkLastError(message);
|
||||
} else if (!checkNeqZero && errCode === 0) {
|
||||
checkLastError(message);
|
||||
}
|
||||
};
|
||||
|
||||
export const createCheckpointHandle = (checkpointData: SerializableInternalBuffer): number => {
|
||||
const wasm = getInstance();
|
||||
|
||||
const [checkpointDataOffset, checkpointDataLength] = checkpointData;
|
||||
let checkpointHandle = 0;
|
||||
|
||||
try {
|
||||
if (wasm._OrtTrainingLoadCheckpoint) {
|
||||
checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength);
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
|
||||
ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false);
|
||||
return checkpointHandle;
|
||||
} catch (e) {
|
||||
if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) {
|
||||
wasm._OrtTrainingReleaseCheckpoint(checkpointHandle);
|
||||
}
|
||||
throw e;
|
||||
} finally {
|
||||
// free buffer from wasm heap
|
||||
wasm._OrtFree(checkpointData[0]);
|
||||
}
|
||||
};
|
||||
|
||||
const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => {
|
||||
const wasm = getInstance();
|
||||
const stack = wasm.stackSave();
|
||||
try {
|
||||
const dataOffset = wasm.stackAlloc(8);
|
||||
if (wasm._OrtTrainingGetModelInputOutputCount) {
|
||||
const errorCode = wasm._OrtTrainingGetModelInputOutputCount(
|
||||
trainingSessionId,
|
||||
dataOffset,
|
||||
dataOffset + 4,
|
||||
isEvalModel,
|
||||
);
|
||||
ifErrCodeCheckLastError(errorCode, "Can't get session input/output count.");
|
||||
return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]];
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
} finally {
|
||||
wasm.stackRestore(stack);
|
||||
}
|
||||
};
|
||||
|
||||
const getModelInputOutputNamesLoop = (
|
||||
trainingSessionId: number,
|
||||
count: number,
|
||||
isInput: boolean,
|
||||
isEvalModel: boolean,
|
||||
): string[] => {
|
||||
const names = [];
|
||||
const wasm = getInstance();
|
||||
|
||||
for (let i = 0; i < count; i++) {
|
||||
if (wasm._OrtTrainingGetModelInputOutputName) {
|
||||
const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel);
|
||||
ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false);
|
||||
|
||||
names.push(wasm.UTF8ToString(name));
|
||||
wasm._free(name);
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
}
|
||||
return names;
|
||||
};
|
||||
|
||||
export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => {
|
||||
let inputNames: string[] = [];
|
||||
let outputNames: string[] = [];
|
||||
|
||||
const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel);
|
||||
|
||||
inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel);
|
||||
outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel);
|
||||
|
||||
return [inputNames, outputNames];
|
||||
};
|
||||
|
||||
export const createTrainingSessionHandle = (
|
||||
checkpointHandle: number,
|
||||
trainModelData: SerializableInternalBuffer,
|
||||
evalModelData: SerializableInternalBuffer,
|
||||
optimizerModelData: SerializableInternalBuffer,
|
||||
options: InferenceSession.SessionOptions,
|
||||
): number => {
|
||||
const wasm = getInstance();
|
||||
|
||||
let trainingSessionHandle = 0;
|
||||
let sessionOptionsHandle = 0;
|
||||
let allocs: number[] = [];
|
||||
|
||||
try {
|
||||
[sessionOptionsHandle, allocs] = setSessionOptions(options);
|
||||
if (wasm._OrtTrainingCreateSession) {
|
||||
trainingSessionHandle = wasm._OrtTrainingCreateSession(
|
||||
sessionOptionsHandle,
|
||||
checkpointHandle,
|
||||
trainModelData[0],
|
||||
trainModelData[1],
|
||||
evalModelData[0],
|
||||
evalModelData[1],
|
||||
optimizerModelData[0],
|
||||
optimizerModelData[1],
|
||||
);
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
|
||||
ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false);
|
||||
return trainingSessionHandle;
|
||||
} catch (e) {
|
||||
if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) {
|
||||
wasm._OrtTrainingReleaseSession(trainingSessionHandle);
|
||||
}
|
||||
throw e;
|
||||
} finally {
|
||||
wasm._free(trainModelData[0]);
|
||||
wasm._free(evalModelData[0]);
|
||||
wasm._free(optimizerModelData[0]);
|
||||
|
||||
if (sessionOptionsHandle !== 0) {
|
||||
wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
|
||||
}
|
||||
allocs.forEach((alloc) => wasm._free(alloc));
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Prepares input and output tensors by creating the tensors in the WASM side then creates a list of the handles of the
|
||||
* WASM tensors.
|
||||
*
|
||||
* @param trainingSessionId
|
||||
* @param indices for each tensor, the index of the input or output name that the tensor corresponds with
|
||||
* @param tensors list of TensorMetaData
|
||||
* @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting
|
||||
* handles of the allocated tensors on the heap
|
||||
* @param inputOutputAllocs modified in-place by this method
|
||||
* @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor
|
||||
*/
|
||||
const createAndAllocateTensors = (
|
||||
trainingSessionId: number,
|
||||
indices: number[],
|
||||
tensors: Array<TensorMetadata | null>,
|
||||
tensorHandles: number[],
|
||||
inputOutputAllocs: number[],
|
||||
indexAdd: number,
|
||||
) => {
|
||||
const count = indices.length;
|
||||
|
||||
// creates the tensors
|
||||
for (let i = 0; i < count; i++) {
|
||||
prepareInputOutputTensor(tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]);
|
||||
}
|
||||
|
||||
// moves to heap
|
||||
const wasm = getInstance();
|
||||
const valuesOffset = wasm.stackAlloc(count * 4);
|
||||
let valuesIndex = valuesOffset / 4;
|
||||
for (let i = 0; i < count; i++) {
|
||||
wasm.HEAPU32[valuesIndex++] = tensorHandles[i];
|
||||
}
|
||||
|
||||
return valuesOffset;
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieves the information from the output tensor handles, copies to an array, and frees the WASM information
|
||||
* associated with the tensor handle.
|
||||
*
|
||||
* @param outputValuesOffset
|
||||
* @param outputCount
|
||||
* @returns list of TensorMetadata retrieved from the output handles.
|
||||
*/
|
||||
const moveOutputToTensorMetadataArr = (
|
||||
outputValuesOffset: number,
|
||||
outputCount: number,
|
||||
outputTensorHandles: number[],
|
||||
outputTensors: Array<TensorMetadata | null>,
|
||||
) => {
|
||||
const wasm = getInstance();
|
||||
const output: TensorMetadata[] = [];
|
||||
|
||||
for (let i = 0; i < outputCount; i++) {
|
||||
const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];
|
||||
if (tensor === outputTensorHandles[i]) {
|
||||
// output tensor is pre-allocated. no need to copy data.
|
||||
output.push(outputTensors[i]!);
|
||||
continue;
|
||||
}
|
||||
|
||||
const beforeGetTensorDataStack = wasm.stackSave();
|
||||
// stack allocate 4 pointer value
|
||||
const tensorDataOffset = wasm.stackAlloc(4 * 4);
|
||||
|
||||
let type: Tensor.Type | undefined,
|
||||
dataOffset = 0;
|
||||
try {
|
||||
const errorCode = wasm._OrtGetTensorData(
|
||||
tensor,
|
||||
tensorDataOffset,
|
||||
tensorDataOffset + 4,
|
||||
tensorDataOffset + 8,
|
||||
tensorDataOffset + 12,
|
||||
);
|
||||
ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`);
|
||||
|
||||
let tensorDataIndex = tensorDataOffset / 4;
|
||||
const dataType = wasm.HEAPU32[tensorDataIndex++];
|
||||
dataOffset = wasm.HEAPU32[tensorDataIndex++];
|
||||
const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
|
||||
const dimsLength = wasm.HEAPU32[tensorDataIndex++];
|
||||
const dims = [];
|
||||
for (let i = 0; i < dimsLength; i++) {
|
||||
dims.push(wasm.HEAPU32[dimsOffset / 4 + i]);
|
||||
}
|
||||
wasm._OrtFree(dimsOffset);
|
||||
|
||||
const size = dims.reduce((a, b) => a * b, 1);
|
||||
type = tensorDataTypeEnumToString(dataType);
|
||||
|
||||
if (type === 'string') {
|
||||
const stringData: string[] = [];
|
||||
let dataIndex = dataOffset / 4;
|
||||
for (let i = 0; i < size; i++) {
|
||||
const offset = wasm.HEAPU32[dataIndex++];
|
||||
const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset;
|
||||
stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
|
||||
}
|
||||
output.push([type, dims, stringData, 'cpu']);
|
||||
} else {
|
||||
const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
|
||||
const data = new typedArrayConstructor(size);
|
||||
new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(
|
||||
wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength),
|
||||
);
|
||||
output.push([type, dims, data, 'cpu']);
|
||||
}
|
||||
} finally {
|
||||
wasm.stackRestore(beforeGetTensorDataStack);
|
||||
if (type === 'string' && dataOffset) {
|
||||
wasm._free(dataOffset);
|
||||
}
|
||||
wasm._OrtReleaseTensor(tensor);
|
||||
}
|
||||
}
|
||||
|
||||
return output;
|
||||
};
|
||||
|
||||
export const lazyResetGrad = async (trainingSessionId: number): Promise<void> => {
|
||||
const wasm = getInstance();
|
||||
|
||||
if (wasm._OrtTrainingLazyResetGrad) {
|
||||
const errorCode = wasm._OrtTrainingLazyResetGrad(trainingSessionId);
|
||||
ifErrCodeCheckLastError(errorCode, "Can't call lazyResetGrad.");
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
};
|
||||
|
||||
export const runTrainStep = async (
|
||||
trainingSessionId: number,
|
||||
inputIndices: number[],
|
||||
inputTensors: TensorMetadata[],
|
||||
outputIndices: number[],
|
||||
outputTensors: Array<TensorMetadata | null>,
|
||||
options: InferenceSession.RunOptions,
|
||||
): Promise<TensorMetadata[]> => {
|
||||
const wasm = getInstance();
|
||||
|
||||
const inputCount = inputIndices.length;
|
||||
const outputCount = outputIndices.length;
|
||||
|
||||
let runOptionsHandle = 0;
|
||||
let runOptionsAllocs: number[] = [];
|
||||
|
||||
const inputTensorHandles: number[] = [];
|
||||
const outputTensorHandles: number[] = [];
|
||||
const inputOutputAllocs: number[] = [];
|
||||
|
||||
const beforeRunStack = wasm.stackSave();
|
||||
|
||||
try {
|
||||
// prepare parameters by moving them to heap
|
||||
[runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
|
||||
|
||||
// handle inputs -- you don't want anything added to the index
|
||||
const inputValuesOffset = createAndAllocateTensors(
|
||||
trainingSessionId,
|
||||
inputIndices,
|
||||
inputTensors,
|
||||
inputTensorHandles,
|
||||
inputOutputAllocs,
|
||||
0,
|
||||
);
|
||||
// handle outputs
|
||||
// you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
|
||||
const outputValuesOffset = createAndAllocateTensors(
|
||||
trainingSessionId,
|
||||
outputIndices,
|
||||
outputTensors,
|
||||
outputTensorHandles,
|
||||
inputOutputAllocs,
|
||||
inputCount,
|
||||
);
|
||||
|
||||
if (wasm._OrtTrainingRunTrainStep) {
|
||||
const errorCode = wasm._OrtTrainingRunTrainStep(
|
||||
trainingSessionId,
|
||||
inputValuesOffset,
|
||||
inputCount,
|
||||
outputValuesOffset,
|
||||
outputCount,
|
||||
runOptionsHandle,
|
||||
);
|
||||
ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer');
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
|
||||
return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
|
||||
} finally {
|
||||
wasm.stackRestore(beforeRunStack);
|
||||
|
||||
inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
|
||||
outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
|
||||
inputOutputAllocs.forEach((p) => wasm._free(p));
|
||||
|
||||
if (runOptionsHandle !== 0) {
|
||||
wasm._OrtReleaseRunOptions(runOptionsHandle);
|
||||
}
|
||||
runOptionsAllocs.forEach((p) => wasm._free(p));
|
||||
}
|
||||
};
|
||||
|
||||
export const runOptimizerStep = async (
|
||||
trainingSessionId: number,
|
||||
options: InferenceSession.RunOptions,
|
||||
): Promise<void> => {
|
||||
const wasm = getInstance();
|
||||
|
||||
let runOptionsHandle = 0;
|
||||
let runOptionsAllocs: number[] = [];
|
||||
|
||||
try {
|
||||
[runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
|
||||
|
||||
if (wasm._OrtTrainingOptimizerStep) {
|
||||
const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle);
|
||||
ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer');
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
} finally {
|
||||
if (runOptionsHandle !== 0) {
|
||||
wasm._OrtReleaseRunOptions(runOptionsHandle);
|
||||
}
|
||||
runOptionsAllocs.forEach((p) => wasm._free(p));
|
||||
}
|
||||
};
|
||||
|
||||
export const runEvalStep = async (
|
||||
trainingSessionId: number,
|
||||
inputIndices: number[],
|
||||
inputTensors: TensorMetadata[],
|
||||
outputIndices: number[],
|
||||
outputTensors: Array<TensorMetadata | null>,
|
||||
options: InferenceSession.RunOptions,
|
||||
): Promise<TensorMetadata[]> => {
|
||||
const wasm = getInstance();
|
||||
|
||||
const inputCount = inputIndices.length;
|
||||
const outputCount = outputIndices.length;
|
||||
|
||||
let runOptionsHandle = 0;
|
||||
let runOptionsAllocs: number[] = [];
|
||||
|
||||
const inputTensorHandles: number[] = [];
|
||||
const outputTensorHandles: number[] = [];
|
||||
const inputOutputAllocs: number[] = [];
|
||||
|
||||
const beforeRunStack = wasm.stackSave();
|
||||
|
||||
try {
|
||||
// prepare parameters by moving them to heap
|
||||
[runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
|
||||
|
||||
// handle inputs -- you don't want anything added to the index
|
||||
const inputValuesOffset = createAndAllocateTensors(
|
||||
trainingSessionId,
|
||||
inputIndices,
|
||||
inputTensors,
|
||||
inputTensorHandles,
|
||||
inputOutputAllocs,
|
||||
0,
|
||||
);
|
||||
// handle outputs
|
||||
// you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
|
||||
const outputValuesOffset = createAndAllocateTensors(
|
||||
trainingSessionId,
|
||||
outputIndices,
|
||||
outputTensors,
|
||||
outputTensorHandles,
|
||||
inputOutputAllocs,
|
||||
inputCount,
|
||||
);
|
||||
|
||||
if (wasm._OrtTrainingEvalStep) {
|
||||
const errorCode = wasm._OrtTrainingEvalStep(
|
||||
trainingSessionId,
|
||||
inputValuesOffset,
|
||||
inputCount,
|
||||
outputValuesOffset,
|
||||
outputCount,
|
||||
runOptionsHandle,
|
||||
);
|
||||
|
||||
ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer');
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
|
||||
return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
|
||||
} finally {
|
||||
wasm.stackRestore(beforeRunStack);
|
||||
|
||||
inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
|
||||
outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v));
|
||||
inputOutputAllocs.forEach((p) => wasm._free(p));
|
||||
|
||||
if (runOptionsHandle !== 0) {
|
||||
wasm._OrtReleaseRunOptions(runOptionsHandle);
|
||||
}
|
||||
runOptionsAllocs.forEach((p) => wasm._free(p));
|
||||
}
|
||||
};
|
||||
|
||||
export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => {
|
||||
const wasm = getInstance();
|
||||
const stack = wasm.stackSave();
|
||||
|
||||
try {
|
||||
const sizeOffset = wasm.stackAlloc(4);
|
||||
if (wasm._OrtTrainingGetParametersSize) {
|
||||
const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly);
|
||||
ifErrCodeCheckLastError(errorCode, "Can't get parameters size");
|
||||
|
||||
return wasm.HEAP32[sizeOffset / 4];
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
} finally {
|
||||
wasm.stackRestore(stack);
|
||||
}
|
||||
};
|
||||
|
||||
export const getContiguousParameters = async (
|
||||
trainingSessionId: number,
|
||||
trainableOnly: boolean,
|
||||
): Promise<TensorMetadata> => {
|
||||
const wasm = getInstance();
|
||||
const stack = wasm.stackSave();
|
||||
|
||||
const tensorTypeAsString = 'float32';
|
||||
const locationAsString = 'cpu';
|
||||
|
||||
const parametersSize = getParametersSize(trainingSessionId, trainableOnly);
|
||||
let tensor = 0;
|
||||
|
||||
// allocates a buffer of the correct size on the WASM heap
|
||||
const paramsByteLength = 4 * parametersSize;
|
||||
const paramsOffset = wasm._malloc(paramsByteLength);
|
||||
|
||||
// handles the dimensions-related createTensor parameters
|
||||
const dims = [parametersSize];
|
||||
|
||||
const dimsOffset = wasm.stackAlloc(4);
|
||||
const dimsIndex = dimsOffset / 4;
|
||||
wasm.HEAP32[dimsIndex] = parametersSize;
|
||||
|
||||
try {
|
||||
// wraps allocated array in a tensor
|
||||
tensor = wasm._OrtCreateTensor(
|
||||
tensorDataTypeStringToEnum(tensorTypeAsString),
|
||||
paramsOffset,
|
||||
paramsByteLength,
|
||||
dimsOffset,
|
||||
dims.length,
|
||||
dataLocationStringToEnum(locationAsString),
|
||||
);
|
||||
ifErrCodeCheckLastError(
|
||||
tensor,
|
||||
`Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`,
|
||||
false,
|
||||
);
|
||||
|
||||
if (wasm._OrtTrainingCopyParametersToBuffer) {
|
||||
const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly);
|
||||
ifErrCodeCheckLastError(errCode, "Can't get contiguous parameters.");
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
|
||||
// copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object
|
||||
const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString);
|
||||
const data = new typedArrayConstructor(parametersSize);
|
||||
const output: TensorMetadata[] = [];
|
||||
new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set(
|
||||
wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength),
|
||||
);
|
||||
output.push([tensorTypeAsString, dims, data, locationAsString]);
|
||||
if (output.length !== 1) {
|
||||
throw new Error(`something unexpected happened in the getContiguousParameters function. Expected output length of
|
||||
one, got ${output.length}`);
|
||||
} else {
|
||||
return output[0];
|
||||
}
|
||||
} finally {
|
||||
if (tensor !== 0) {
|
||||
wasm._OrtReleaseTensor(tensor);
|
||||
}
|
||||
wasm._free(paramsOffset);
|
||||
wasm._free(dimsOffset);
|
||||
wasm.stackRestore(stack);
|
||||
}
|
||||
};
|
||||
|
||||
export const loadParametersBuffer = async (
|
||||
trainingSessionId: number,
|
||||
buffer: Uint8Array,
|
||||
trainableOnly: boolean,
|
||||
): Promise<void> => {
|
||||
const wasm = getInstance();
|
||||
const stack = wasm.stackSave();
|
||||
|
||||
const tensorTypeAsString = 'float32';
|
||||
const locationAsString = 'cpu';
|
||||
|
||||
// allocates & copies JavaScript buffer to WASM heap
|
||||
const bufferByteLength = buffer.length;
|
||||
const bufferCount = bufferByteLength / 4;
|
||||
const bufferOffset = wasm._malloc(bufferByteLength);
|
||||
wasm.HEAPU8.set(buffer, bufferOffset);
|
||||
|
||||
// allocates and handles moving dimensions information to WASM memory
|
||||
const dimsOffset = wasm.stackAlloc(4);
|
||||
wasm.HEAP32[dimsOffset / 4] = bufferCount;
|
||||
const dimsLength = 1;
|
||||
let tensor = 0;
|
||||
|
||||
try {
|
||||
tensor = wasm._OrtCreateTensor(
|
||||
tensorDataTypeStringToEnum(tensorTypeAsString),
|
||||
bufferOffset,
|
||||
bufferByteLength,
|
||||
dimsOffset,
|
||||
dimsLength,
|
||||
dataLocationStringToEnum(locationAsString),
|
||||
);
|
||||
ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false);
|
||||
|
||||
if (wasm._OrtTrainingCopyParametersFromBuffer) {
|
||||
const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly);
|
||||
ifErrCodeCheckLastError(errCode, "Can't copy buffer to parameters.");
|
||||
} else {
|
||||
throw new Error(NO_TRAIN_FUNCS_MSG);
|
||||
}
|
||||
} finally {
|
||||
if (tensor !== 0) {
|
||||
wasm._OrtReleaseTensor(tensor);
|
||||
}
|
||||
wasm.stackRestore(stack);
|
||||
wasm._free(bufferOffset);
|
||||
wasm._free(dimsOffset);
|
||||
}
|
||||
};
|
||||
|
||||
export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => {
|
||||
const wasm = getInstance();
|
||||
|
||||
if (wasm._OrtTrainingReleaseSession) {
|
||||
wasm._OrtTrainingReleaseSession(sessionId);
|
||||
}
|
||||
if (wasm._OrtTrainingReleaseCheckpoint) {
|
||||
wasm._OrtTrainingReleaseCheckpoint(checkpointId);
|
||||
}
|
||||
};
|
||||
|
|
@ -213,84 +213,10 @@ export interface OrtInferenceAPIs {
|
|||
_OrtEndProfiling(sessionHandle: number): number;
|
||||
}
|
||||
|
||||
export interface OrtTrainingAPIs {
|
||||
_OrtTrainingLoadCheckpoint(dataOffset: number, dataLength: number): number;
|
||||
|
||||
_OrtTrainingReleaseCheckpoint(checkpointHandle: number): void;
|
||||
|
||||
_OrtTrainingCreateSession(
|
||||
sessionOptionsHandle: number,
|
||||
checkpointHandle: number,
|
||||
trainOffset: number,
|
||||
trainLength: number,
|
||||
evalOffset: number,
|
||||
evalLength: number,
|
||||
optimizerOffset: number,
|
||||
optimizerLength: number,
|
||||
): number;
|
||||
|
||||
_OrtTrainingLazyResetGrad(trainingHandle: number): number;
|
||||
|
||||
_OrtTrainingRunTrainStep(
|
||||
trainingHandle: number,
|
||||
inputsOffset: number,
|
||||
inputCount: number,
|
||||
outputsOffset: number,
|
||||
outputCount: number,
|
||||
runOptionsHandle: number,
|
||||
): number;
|
||||
|
||||
_OrtTrainingOptimizerStep(trainingHandle: number, runOptionsHandle: number): number;
|
||||
|
||||
_OrtTrainingEvalStep(
|
||||
trainingHandle: number,
|
||||
inputsOffset: number,
|
||||
inputCount: number,
|
||||
outputsOffset: number,
|
||||
outputCount: number,
|
||||
runOptionsHandle: number,
|
||||
): number;
|
||||
|
||||
_OrtTrainingGetParametersSize(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number;
|
||||
|
||||
_OrtTrainingCopyParametersToBuffer(
|
||||
trainingHandle: number,
|
||||
parametersBuffer: number,
|
||||
parameterCount: number,
|
||||
trainableOnly: boolean,
|
||||
): number;
|
||||
|
||||
_OrtTrainingCopyParametersFromBuffer(
|
||||
trainingHandle: number,
|
||||
parametersBuffer: number,
|
||||
parameterCount: number,
|
||||
trainableOnly: boolean,
|
||||
): number;
|
||||
|
||||
_OrtTrainingGetModelInputOutputCount(
|
||||
trainingHandle: number,
|
||||
inputCount: number,
|
||||
outputCount: number,
|
||||
isEvalModel: boolean,
|
||||
): number;
|
||||
_OrtTrainingGetModelInputOutputName(
|
||||
trainingHandle: number,
|
||||
index: number,
|
||||
isInput: boolean,
|
||||
isEvalModel: boolean,
|
||||
): number;
|
||||
|
||||
_OrtTrainingReleaseSession(trainingHandle: number): void;
|
||||
}
|
||||
|
||||
/**
|
||||
* The interface of the WebAssembly module for ONNX Runtime, compiled from C++ source code by Emscripten.
|
||||
*/
|
||||
export interface OrtWasmModule
|
||||
extends EmscriptenModule,
|
||||
OrtInferenceAPIs,
|
||||
Partial<OrtTrainingAPIs>,
|
||||
Partial<JSEP.Module> {
|
||||
export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial<JSEP.Module> {
|
||||
// #region emscripten functions
|
||||
stackSave(): number;
|
||||
stackRestore(stack: number): void;
|
||||
|
|
|
|||
|
|
@ -135,11 +135,9 @@ const embeddedWasmModule: EmscriptenModuleFactory<OrtWasmModule> | undefined =
|
|||
BUILD_DEFS.IS_ESM && BUILD_DEFS.DISABLE_DYNAMIC_IMPORT
|
||||
? // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
|
||||
require(
|
||||
!BUILD_DEFS.DISABLE_TRAINING
|
||||
? '../../dist/ort-training-wasm-simd-threaded.mjs'
|
||||
: !BUILD_DEFS.DISABLE_JSEP
|
||||
? '../../dist/ort-wasm-simd-threaded.jsep.mjs'
|
||||
: '../../dist/ort-wasm-simd-threaded.mjs',
|
||||
!BUILD_DEFS.DISABLE_JSEP
|
||||
? '../../dist/ort-wasm-simd-threaded.jsep.mjs'
|
||||
: '../../dist/ort-wasm-simd-threaded.mjs',
|
||||
).default
|
||||
: undefined;
|
||||
|
||||
|
|
@ -163,11 +161,9 @@ export const importWasmModule = async (
|
|||
if (BUILD_DEFS.DISABLE_DYNAMIC_IMPORT) {
|
||||
return [undefined, embeddedWasmModule!];
|
||||
} else {
|
||||
const wasmModuleFilename = !BUILD_DEFS.DISABLE_TRAINING
|
||||
? 'ort-training-wasm-simd-threaded.mjs'
|
||||
: !BUILD_DEFS.DISABLE_JSEP
|
||||
? 'ort-wasm-simd-threaded.jsep.mjs'
|
||||
: 'ort-wasm-simd-threaded.mjs';
|
||||
const wasmModuleFilename = !BUILD_DEFS.DISABLE_JSEP
|
||||
? 'ort-wasm-simd-threaded.jsep.mjs'
|
||||
: 'ort-wasm-simd-threaded.mjs';
|
||||
const wasmModuleUrl = urlOverride ?? normalizeUrl(wasmModuleFilename, prefixOverride);
|
||||
// need to preload if all of the following conditions are met:
|
||||
// 1. not in Node.js.
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@
|
|||
"build:doc": "node ./script/generate-webgl-operator-md && node ./script/generate-webgpu-operator-md",
|
||||
"pull:wasm": "node ./script/pull-prebuilt-wasm-artifacts",
|
||||
"test:e2e": "node ./test/e2e/run",
|
||||
"test:training:e2e": "node ./test/training/e2e/run",
|
||||
"prebuild": "tsc -p . --noEmit && tsc -p lib/wasm/proxy-worker --noEmit",
|
||||
"build": "node ./script/build",
|
||||
"test": "tsc --build ../scripts && node ../scripts/prepare-onnx-node-tests && node ./script/test-runner-cli",
|
||||
|
|
@ -101,12 +100,6 @@
|
|||
"import": "./dist/ort.webgpu.bundle.min.mjs",
|
||||
"require": "./dist/ort.webgpu.min.js",
|
||||
"types": "./types.d.ts"
|
||||
},
|
||||
"./training": {
|
||||
"node": null,
|
||||
"import": "./dist/ort.training.wasm.min.mjs",
|
||||
"require": "./dist/ort.training.wasm.min.js",
|
||||
"types": "./types.d.ts"
|
||||
}
|
||||
},
|
||||
"types": "./types.d.ts",
|
||||
|
|
|
|||
|
|
@ -56,7 +56,6 @@ const DEFAULT_DEFINE = {
|
|||
'BUILD_DEFS.DISABLE_JSEP': 'false',
|
||||
'BUILD_DEFS.DISABLE_WASM': 'false',
|
||||
'BUILD_DEFS.DISABLE_WASM_PROXY': 'false',
|
||||
'BUILD_DEFS.DISABLE_TRAINING': 'true',
|
||||
'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'false',
|
||||
|
||||
'BUILD_DEFS.IS_ESM': 'false',
|
||||
|
|
@ -253,7 +252,7 @@ async function buildBundle(options: esbuild.BuildOptions) {
|
|||
*
|
||||
* The distribution code is split into multiple files:
|
||||
* - [output-name][.min].[m]js
|
||||
* - ort[-training]-wasm-simd-threaded[.jsep].mjs
|
||||
* - ort-wasm-simd-threaded[.jsep].mjs
|
||||
*/
|
||||
async function buildOrt({
|
||||
isProduction = false,
|
||||
|
|
@ -630,16 +629,6 @@ async function main() {
|
|||
'BUILD_DEFS.DISABLE_WASM_PROXY': 'true',
|
||||
},
|
||||
});
|
||||
// ort.training.wasm[.min].[m]js
|
||||
await addAllWebBuildTasks({
|
||||
outputName: 'ort.training.wasm',
|
||||
define: {
|
||||
...DEFAULT_DEFINE,
|
||||
'BUILD_DEFS.DISABLE_TRAINING': 'false',
|
||||
'BUILD_DEFS.DISABLE_JSEP': 'true',
|
||||
'BUILD_DEFS.DISABLE_WEBGL': 'true',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
if (BUNDLE_MODE === 'dev' || BUNDLE_MODE === 'perf') {
|
||||
|
|
|
|||
|
|
@ -149,11 +149,9 @@ downloadJson(
|
|||
void jszip.loadAsync(buffer).then((zip) => {
|
||||
extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.wasm', folderName);
|
||||
extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm', folderName);
|
||||
extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.wasm', folderName);
|
||||
|
||||
extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.mjs', folderName);
|
||||
extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.mjs', folderName);
|
||||
extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.mjs', folderName);
|
||||
});
|
||||
});
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,21 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
'use strict';
|
||||
|
||||
describe('Browser E2E testing for training package', function () {
|
||||
it('Check that training package encompasses inference', async function () {
|
||||
ort.env.wasm.numThreads = 1;
|
||||
await testInferenceFunction(ort, { executionProviders: ['wasm'] });
|
||||
});
|
||||
|
||||
it('Check training functionality, all options', async function () {
|
||||
ort.env.wasm.numThreads = 1;
|
||||
await testTrainingFunctionAll(ort, { executionProviders: ['wasm'] });
|
||||
});
|
||||
|
||||
it('Check training functionality, minimum options', async function () {
|
||||
ort.env.wasm.numThreads = 1;
|
||||
await testTrainingFunctionMin(ort, { executionProviders: ['wasm'] });
|
||||
});
|
||||
});
|
||||
|
|
@ -1,248 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
'use strict';
|
||||
|
||||
const DATA_FOLDER = 'data/';
|
||||
const TRAININGDATA_TRAIN_MODEL = DATA_FOLDER + 'training_model.onnx';
|
||||
const TRAININGDATA_OPTIMIZER_MODEL = DATA_FOLDER + 'adamw.onnx';
|
||||
const TRAININGDATA_EVAL_MODEL = DATA_FOLDER + 'eval_model.onnx';
|
||||
const TRAININGDATA_CKPT = DATA_FOLDER + 'checkpoint.ckpt';
|
||||
|
||||
const trainingSessionAllOptions = {
|
||||
checkpointState: TRAININGDATA_CKPT,
|
||||
trainModel: TRAININGDATA_TRAIN_MODEL,
|
||||
evalModel: TRAININGDATA_EVAL_MODEL,
|
||||
optimizerModel: TRAININGDATA_OPTIMIZER_MODEL,
|
||||
};
|
||||
|
||||
const trainingSessionMinOptions = {
|
||||
checkpointState: TRAININGDATA_CKPT,
|
||||
trainModel: TRAININGDATA_TRAIN_MODEL,
|
||||
};
|
||||
|
||||
// ASSERT METHODS
|
||||
|
||||
function assert(cond) {
|
||||
if (!cond) throw new Error();
|
||||
}
|
||||
|
||||
function assertStrictEquals(actual, expected) {
|
||||
if (actual !== expected) {
|
||||
let strRep = actual;
|
||||
if (typeof actual === 'object') {
|
||||
strRep = JSON.stringify(actual);
|
||||
}
|
||||
throw new Error(`expected: ${expected}; got: ${strRep}`);
|
||||
}
|
||||
}
|
||||
|
||||
function assertTwoListsUnequal(list1, list2) {
|
||||
if (list1.length !== list2.length) {
|
||||
return;
|
||||
}
|
||||
for (let i = 0; i < list1.length; i++) {
|
||||
if (list1[i] !== list2[i]) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
throw new Error(`expected ${list1} and ${list2} to be unequal; got two equal lists`);
|
||||
}
|
||||
|
||||
// HELPER METHODS FOR TESTS
|
||||
|
||||
function generateGaussianRandom(mean = 0, scale = 1) {
|
||||
const u = 1 - Math.random();
|
||||
const v = Math.random();
|
||||
const z = Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v);
|
||||
return z * scale + mean;
|
||||
}
|
||||
|
||||
function generateGaussianFloatArray(length) {
|
||||
const array = new Float32Array(length);
|
||||
|
||||
for (let i = 0; i < length; i++) {
|
||||
array[i] = generateGaussianRandom();
|
||||
}
|
||||
|
||||
return array;
|
||||
}
|
||||
|
||||
/**
|
||||
* creates the TrainingSession and verifies that the input and output names of the training model loaded into the
|
||||
* training session are correct.
|
||||
* @param {} ort
|
||||
* @param {*} createOptions
|
||||
* @param {*} options
|
||||
* @returns
|
||||
*/
|
||||
async function createTrainingSessionAndCheckTrainingModel(ort, createOptions, options) {
|
||||
const trainingSession = await ort.TrainingSession.create(createOptions, options);
|
||||
|
||||
assertStrictEquals(trainingSession.trainingInputNames[0], 'input-0');
|
||||
assertStrictEquals(trainingSession.trainingInputNames[1], 'labels');
|
||||
assertStrictEquals(trainingSession.trainingInputNames.length, 2);
|
||||
assertStrictEquals(trainingSession.trainingOutputNames[0], 'onnx::loss::21273');
|
||||
assertStrictEquals(trainingSession.trainingOutputNames.length, 1);
|
||||
return trainingSession;
|
||||
}
|
||||
|
||||
/**
|
||||
* verifies that the eval input and output names associated with the eval model loaded into the given training session
|
||||
* are correct.
|
||||
*/
|
||||
function checkEvalModel(trainingSession) {
|
||||
assertStrictEquals(trainingSession.evalInputNames[0], 'input-0');
|
||||
assertStrictEquals(trainingSession.evalInputNames[1], 'labels');
|
||||
assertStrictEquals(trainingSession.evalInputNames.length, 2);
|
||||
assertStrictEquals(trainingSession.evalOutputNames[0], 'onnx::loss::21273');
|
||||
assertStrictEquals(trainingSession.evalOutputNames.length, 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks that accessing trainingSession.evalInputNames or trainingSession.evalOutputNames will throw an error if
|
||||
* accessed
|
||||
* @param {} trainingSession
|
||||
*/
|
||||
function checkNoEvalModel(trainingSession) {
|
||||
try {
|
||||
assertStrictEquals(trainingSession.evalInputNames, 'should have thrown an error upon accessing');
|
||||
} catch (error) {
|
||||
assertStrictEquals(error.message, 'This training session has no evalModel loaded.');
|
||||
}
|
||||
try {
|
||||
assertStrictEquals(trainingSession.evalOutputNames, 'should have thrown an error upon accessing');
|
||||
} catch (error) {
|
||||
assertStrictEquals(error.message, 'This training session has no evalModel loaded.');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* runs the train step with the given inputs and checks that the tensor returned is of type float32 and has a length
|
||||
* of 1 for the loss.
|
||||
* @param {} trainingSession
|
||||
* @param {*} feeds
|
||||
* @returns
|
||||
*/
|
||||
var runTrainStepAndCheck = async function (trainingSession, feeds) {
|
||||
const results = await trainingSession.runTrainStep(feeds);
|
||||
assertStrictEquals(Object.keys(results).length, 1);
|
||||
assertStrictEquals(results['onnx::loss::21273'].data.length, 1);
|
||||
assertStrictEquals(results['onnx::loss::21273'].type, 'float32');
|
||||
return results;
|
||||
};
|
||||
|
||||
var loadParametersBufferAndCheck = async function (trainingSession, paramsLength, constant, paramsBefore) {
|
||||
// make a float32 array that is filled with the constant
|
||||
const newParams = new Float32Array(paramsLength);
|
||||
for (let i = 0; i < paramsLength; i++) {
|
||||
newParams[i] = constant;
|
||||
}
|
||||
|
||||
const newParamsUint8 = new Uint8Array(newParams.buffer, newParams.byteOffset, newParams.byteLength);
|
||||
|
||||
await trainingSession.loadParametersBuffer(newParamsUint8);
|
||||
const paramsAfterLoad = await trainingSession.getContiguousParameters();
|
||||
|
||||
// check that the parameters have changed
|
||||
assertTwoListsUnequal(paramsAfterLoad.data, paramsBefore.data);
|
||||
assertStrictEquals(paramsAfterLoad.dims[0], paramsLength);
|
||||
|
||||
// check that the parameters have changed to what they should be
|
||||
for (let i = 0; i < paramsLength; i++) {
|
||||
// round to the same number of digits (4 decimal places)
|
||||
assertStrictEquals(paramsAfterLoad.data[i].toFixed(4), constant.toFixed(4));
|
||||
}
|
||||
|
||||
return paramsAfterLoad;
|
||||
};
|
||||
|
||||
// TESTS
|
||||
|
||||
var testInferenceFunction = async function (ort, options) {
|
||||
const session = await ort.InferenceSession.create('data/model.onnx', options || {});
|
||||
|
||||
const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
|
||||
const dataB = Float32Array.from([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]);
|
||||
|
||||
const fetches = await session.run({
|
||||
a: new ort.Tensor('float32', dataA, [3, 4]),
|
||||
b: new ort.Tensor('float32', dataB, [4, 3]),
|
||||
});
|
||||
|
||||
const c = fetches.c;
|
||||
|
||||
assert(c instanceof ort.Tensor);
|
||||
assert(c.dims.length === 2 && c.dims[0] === 3 && c.dims[1] === 3);
|
||||
assert(c.data[0] === 700);
|
||||
assert(c.data[1] === 800);
|
||||
assert(c.data[2] === 900);
|
||||
assert(c.data[3] === 1580);
|
||||
assert(c.data[4] === 1840);
|
||||
assert(c.data[5] === 2100);
|
||||
assert(c.data[6] === 2460);
|
||||
assert(c.data[7] === 2880);
|
||||
assert(c.data[8] === 3300);
|
||||
};
|
||||
|
||||
var testTrainingFunctionMin = async function (ort, options) {
|
||||
const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionMinOptions, options);
|
||||
checkNoEvalModel(trainingSession);
|
||||
const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]);
|
||||
const labels = new ort.Tensor('int32', [2, 1], [2]);
|
||||
const feeds = { 'input-0': input0, labels: labels };
|
||||
|
||||
// check getParametersSize
|
||||
const paramsSize = await trainingSession.getParametersSize();
|
||||
assertStrictEquals(paramsSize, 397510);
|
||||
|
||||
// check getContiguousParameters
|
||||
const originalParams = await trainingSession.getContiguousParameters();
|
||||
assertStrictEquals(originalParams.dims.length, 1);
|
||||
assertStrictEquals(originalParams.dims[0], 397510);
|
||||
assertStrictEquals(originalParams.data[0], -0.025190064683556557);
|
||||
assertStrictEquals(originalParams.data[2000], -0.034044936299324036);
|
||||
|
||||
await runTrainStepAndCheck(trainingSession, feeds);
|
||||
|
||||
await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, originalParams);
|
||||
};
|
||||
|
||||
var testTrainingFunctionAll = async function (ort, options) {
|
||||
const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionAllOptions, options);
|
||||
checkEvalModel(trainingSession);
|
||||
|
||||
const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]);
|
||||
const labels = new ort.Tensor('int32', [2, 1], [2]);
|
||||
let feeds = { 'input-0': input0, labels: labels };
|
||||
|
||||
// check getParametersSize
|
||||
const paramsSize = await trainingSession.getParametersSize();
|
||||
assertStrictEquals(paramsSize, 397510);
|
||||
|
||||
// check getContiguousParameters
|
||||
const originalParams = await trainingSession.getContiguousParameters();
|
||||
assertStrictEquals(originalParams.dims.length, 1);
|
||||
assertStrictEquals(originalParams.dims[0], 397510);
|
||||
assertStrictEquals(originalParams.data[0], -0.025190064683556557);
|
||||
assertStrictEquals(originalParams.data[2000], -0.034044936299324036);
|
||||
|
||||
const results = await runTrainStepAndCheck(trainingSession, feeds);
|
||||
|
||||
await trainingSession.runOptimizerStep(feeds);
|
||||
feeds = { 'input-0': input0, labels: labels };
|
||||
// check getContiguousParameters after optimizerStep -- that the parameters have been updated
|
||||
const optimizedParams = await trainingSession.getContiguousParameters();
|
||||
assertTwoListsUnequal(originalParams.data, optimizedParams.data);
|
||||
|
||||
const results2 = await runTrainStepAndCheck(trainingSession, feeds);
|
||||
|
||||
// check that loss decreased after optimizer step and training again
|
||||
assert(results2['onnx::loss::21273'].data < results['onnx::loss::21273'].data);
|
||||
|
||||
await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, optimizedParams);
|
||||
};
|
||||
|
||||
if (typeof module === 'object') {
|
||||
module.exports = [testInferenceFunction, testTrainingFunctionMin, testTrainingFunctionAll, testTest];
|
||||
}
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
backend-test:b
|
||||
|
||||
a
|
||||
bc"MatMultest_matmul_2dZ
|
||||
a
|
||||
|
||||
|
||||
Z
|
||||
b
|
||||
|
||||
|
||||
b
|
||||
c
|
||||
|
||||
|
||||
B
|
||||
|
|
@ -1,54 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
'use strict';
|
||||
|
||||
const args = require('minimist')(process.argv.slice(2));
|
||||
const SELF_HOST = !!args['self-host'];
|
||||
const ORT_MAIN = args['ort-main'];
|
||||
const TEST_MAIN = args['test-main'];
|
||||
if (typeof TEST_MAIN !== 'string') {
|
||||
throw new Error('flag --test-main=<TEST_MAIN_JS_FILE> is required');
|
||||
}
|
||||
const USER_DATA = args['user-data'];
|
||||
if (typeof USER_DATA !== 'string') {
|
||||
throw new Error('flag --user-data=<CHROME_USER_DATA_FOLDER> is required');
|
||||
}
|
||||
|
||||
module.exports = function (config) {
|
||||
const distPrefix = SELF_HOST ? './node_modules/onnxruntime-web/dist/' : 'http://localhost:8081/dist/';
|
||||
config.set({
|
||||
frameworks: ['mocha'],
|
||||
files: [
|
||||
{ pattern: distPrefix + ORT_MAIN },
|
||||
{ pattern: './common.js' },
|
||||
{ pattern: TEST_MAIN },
|
||||
{ pattern: './node_modules/onnxruntime-web/dist/*.*', included: false, nocache: true },
|
||||
{ pattern: './data/*', included: false },
|
||||
],
|
||||
plugins: [require('@chiragrupani/karma-chromium-edge-launcher'), ...config.plugins],
|
||||
proxies: {
|
||||
'/model.onnx': '/base/model.onnx',
|
||||
'/data/': '/base/data/',
|
||||
},
|
||||
client: { captureConsole: true, mocha: { expose: ['body'], timeout: 60000 } },
|
||||
reporters: ['mocha'],
|
||||
captureTimeout: 120000,
|
||||
reportSlowerThan: 100,
|
||||
browserDisconnectTimeout: 600000,
|
||||
browserNoActivityTimeout: 300000,
|
||||
browserDisconnectTolerance: 0,
|
||||
browserSocketTimeout: 60000,
|
||||
hostname: 'localhost',
|
||||
browsers: [],
|
||||
customLaunchers: {
|
||||
Chrome_default: { base: 'ChromeHeadless', chromeDataDir: USER_DATA },
|
||||
Chrome_no_threads: {
|
||||
base: 'ChromeHeadless',
|
||||
chromeDataDir: USER_DATA,
|
||||
// TODO: no-thread flags
|
||||
},
|
||||
Edge_default: { base: 'Edge', edgeDataDir: USER_DATA },
|
||||
},
|
||||
});
|
||||
};
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
{
|
||||
"devDependencies": {
|
||||
"@chiragrupani/karma-chromium-edge-launcher": "^2.2.2",
|
||||
"fs-extra": "^11.1.0",
|
||||
"globby": "^13.1.3",
|
||||
"karma": "^6.4.1",
|
||||
"karma-chrome-launcher": "^3.1.1",
|
||||
"karma-mocha": "^2.0.1",
|
||||
"karma-mocha-reporter": "^2.2.5",
|
||||
"light-server": "^2.9.1",
|
||||
"minimist": "^1.2.7",
|
||||
"mocha": "^10.2.0"
|
||||
}
|
||||
}
|
||||
|
|
@ -1,143 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
'use strict';
|
||||
|
||||
const path = require('path');
|
||||
const fs = require('fs-extra');
|
||||
const { spawn } = require('child_process');
|
||||
const startServer = require('./simple-http-server');
|
||||
const minimist = require('minimist');
|
||||
|
||||
// copy whole folder to out-side of <ORT_ROOT>/js/ because we need to test in a folder that no `package.json` file
|
||||
// exists in its parent folder.
|
||||
// here we use <ORT_ROOT>/build/js/e2e-training/ for the test
|
||||
|
||||
const TEST_E2E_SRC_FOLDER = __dirname;
|
||||
const JS_ROOT_FOLDER = path.resolve(__dirname, '../../../..');
|
||||
const TEST_E2E_RUN_FOLDER = path.resolve(JS_ROOT_FOLDER, '../build/js/e2e-training');
|
||||
const NPM_CACHE_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../npm_cache');
|
||||
const CHROME_USER_DATA_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../user_data');
|
||||
fs.emptyDirSync(TEST_E2E_RUN_FOLDER);
|
||||
fs.emptyDirSync(NPM_CACHE_FOLDER);
|
||||
fs.emptyDirSync(CHROME_USER_DATA_FOLDER);
|
||||
fs.copySync(TEST_E2E_SRC_FOLDER, TEST_E2E_RUN_FOLDER);
|
||||
|
||||
// training data to copy
|
||||
const ORT_ROOT_FOLDER = path.resolve(JS_ROOT_FOLDER, '..');
|
||||
const TRAINING_DATA_FOLDER = path.resolve(ORT_ROOT_FOLDER, 'onnxruntime/test/testdata/training_api');
|
||||
const TRAININGDATA_DEST = path.resolve(TEST_E2E_RUN_FOLDER, 'data');
|
||||
|
||||
// always use a new folder as user-data-dir
|
||||
let nextUserDataDirId = 0;
|
||||
function getNextUserDataDir() {
|
||||
const dir = path.resolve(CHROME_USER_DATA_FOLDER, nextUserDataDirId.toString());
|
||||
nextUserDataDirId++;
|
||||
fs.emptyDirSync(dir);
|
||||
return dir;
|
||||
}
|
||||
|
||||
// commandline arguments
|
||||
const BROWSER = minimist(process.argv.slice(2)).browser || 'Chrome_default';
|
||||
|
||||
async function main() {
|
||||
// find packed package
|
||||
const { globbySync } = await import('globby');
|
||||
|
||||
const ORT_COMMON_FOLDER = path.resolve(JS_ROOT_FOLDER, 'common');
|
||||
const ORT_COMMON_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-common-*.tgz', { cwd: ORT_COMMON_FOLDER });
|
||||
|
||||
const PACKAGES_TO_INSTALL = [];
|
||||
|
||||
if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length === 1) {
|
||||
PACKAGES_TO_INSTALL.push(path.resolve(ORT_COMMON_FOLDER, ORT_COMMON_PACKED_FILEPATH_CANDIDATES[0]));
|
||||
} else if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length > 1) {
|
||||
throw new Error('multiple packages found for onnxruntime-common.');
|
||||
}
|
||||
|
||||
const ORT_WEB_FOLDER = path.resolve(JS_ROOT_FOLDER, 'web');
|
||||
const ORT_WEB_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-web-*.tgz', { cwd: ORT_WEB_FOLDER });
|
||||
if (ORT_WEB_PACKED_FILEPATH_CANDIDATES.length !== 1) {
|
||||
throw new Error('cannot find exactly single package for onnxruntime-web.');
|
||||
}
|
||||
PACKAGES_TO_INSTALL.push(path.resolve(ORT_WEB_FOLDER, ORT_WEB_PACKED_FILEPATH_CANDIDATES[0]));
|
||||
|
||||
// we start here:
|
||||
|
||||
// install dev dependencies
|
||||
await runInShell(`npm install`);
|
||||
|
||||
// npm install with "--cache" to install packed packages with an empty cache folder
|
||||
await runInShell(`npm install --cache "${NPM_CACHE_FOLDER}" ${PACKAGES_TO_INSTALL.map((i) => `"${i}"`).join(' ')}`);
|
||||
|
||||
// prepare training data
|
||||
prepareTrainingDataByCopying();
|
||||
|
||||
console.log('===============================================================');
|
||||
console.log('Running self-hosted tests');
|
||||
console.log('===============================================================');
|
||||
// test cases with self-host (ort hosted in same origin)
|
||||
await testAllBrowserCases({ hostInKarma: true });
|
||||
|
||||
console.log('===============================================================');
|
||||
console.log('Running not self-hosted tests');
|
||||
console.log('===============================================================');
|
||||
// test cases without self-host (ort hosted in cross origin)
|
||||
const server = startServer(path.join(TEST_E2E_RUN_FOLDER, 'node_modules', 'onnxruntime-web'), 8081);
|
||||
try {
|
||||
await testAllBrowserCases({ hostInKarma: false });
|
||||
} finally {
|
||||
// close the server after all tests
|
||||
await server.close();
|
||||
}
|
||||
}
|
||||
|
||||
async function testAllBrowserCases({ hostInKarma }) {
|
||||
await runKarma({ hostInKarma, main: './browser-test-wasm.js' });
|
||||
}
|
||||
|
||||
async function runKarma({ hostInKarma, main, browser = BROWSER, ortMain = 'ort.training.wasm.min.js' }) {
|
||||
console.log('===============================================================');
|
||||
console.log(`Running karma with the following binary: ${ortMain}`);
|
||||
console.log('===============================================================');
|
||||
const selfHostFlag = hostInKarma ? '--self-host' : '';
|
||||
await runInShell(
|
||||
`npx karma start --single-run --browsers ${browser} ${selfHostFlag} --ort-main=${
|
||||
ortMain
|
||||
} --test-main=${main} --user-data=${getNextUserDataDir()}`,
|
||||
);
|
||||
}
|
||||
|
||||
async function runInShell(cmd) {
|
||||
console.log('===============================================================');
|
||||
console.log(' Running command in shell:');
|
||||
console.log(' > ' + cmd);
|
||||
console.log('===============================================================');
|
||||
let complete = false;
|
||||
const childProcess = spawn(cmd, { shell: true, stdio: 'inherit', cwd: TEST_E2E_RUN_FOLDER });
|
||||
childProcess.on('close', function (code) {
|
||||
if (code !== 0) {
|
||||
process.exit(code);
|
||||
} else {
|
||||
complete = true;
|
||||
}
|
||||
});
|
||||
while (!complete) {
|
||||
await delay(100);
|
||||
}
|
||||
}
|
||||
|
||||
async function delay(ms) {
|
||||
return new Promise(function (resolve) {
|
||||
setTimeout(function () {
|
||||
resolve();
|
||||
}, ms);
|
||||
});
|
||||
}
|
||||
|
||||
function prepareTrainingDataByCopying() {
|
||||
fs.copySync(TRAINING_DATA_FOLDER, TRAININGDATA_DEST);
|
||||
console.log(`Copied ${TRAINING_DATA_FOLDER} to ${TRAININGDATA_DEST}`);
|
||||
}
|
||||
|
||||
main();
|
||||
|
|
@ -1,67 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
'use strict';
|
||||
|
||||
// this is a simple HTTP server that enables CORS.
|
||||
// following code is based on https://developer.mozilla.org/en-US/docs/Learn/Server-side/Node_server_without_framework
|
||||
|
||||
const http = require('http');
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
|
||||
const getRequestData = (url, dir) => {
|
||||
const pathname = new URL(url, 'http://localhost').pathname;
|
||||
|
||||
let filepath;
|
||||
let mimeType;
|
||||
if (pathname.startsWith('/test-wasm-path-override/') || pathname.startsWith('/dist/')) {
|
||||
filepath = path.resolve(dir, pathname.substring(1));
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (filepath.endsWith('.wasm')) {
|
||||
mimeType = 'application/wasm';
|
||||
} else if (filepath.endsWith('.js') || filepath.endsWith('.mjs')) {
|
||||
mimeType = 'text/javascript';
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
|
||||
return [filepath, mimeType];
|
||||
};
|
||||
|
||||
module.exports = function (dir, port) {
|
||||
const server = http
|
||||
.createServer(function (request, response) {
|
||||
const url = request.url.replace(/\n|\r/g, '');
|
||||
console.log(`request ${url}`);
|
||||
|
||||
const requestData = getRequestData(url, dir);
|
||||
if (!request || !requestData) {
|
||||
response.writeHead(404);
|
||||
response.end('404');
|
||||
} else {
|
||||
const [filePath, contentType] = requestData;
|
||||
fs.readFile(path.resolve(dir, filePath), function (error, content) {
|
||||
if (error) {
|
||||
if (error.code == 'ENOENT') {
|
||||
response.writeHead(404);
|
||||
response.end('404');
|
||||
} else {
|
||||
response.writeHead(500);
|
||||
response.end('500');
|
||||
}
|
||||
} else {
|
||||
response.setHeader('access-control-allow-origin', '*');
|
||||
response.writeHead(200, { 'Content-Type': contentType });
|
||||
response.end(content, 'utf-8');
|
||||
}
|
||||
});
|
||||
}
|
||||
})
|
||||
.listen(port);
|
||||
console.log(`Server running at http://localhost:${port}/`);
|
||||
return server;
|
||||
};
|
||||
4
js/web/types.d.ts
vendored
4
js/web/types.d.ts
vendored
|
|
@ -20,7 +20,3 @@ declare module 'onnxruntime-web/webgl' {
|
|||
declare module 'onnxruntime-web/webgpu' {
|
||||
export * from 'onnxruntime-web';
|
||||
}
|
||||
|
||||
declare module 'onnxruntime-web/training' {
|
||||
export * from 'onnxruntime-web';
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue