mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-16 21:00:14 +00:00
### Description
This PR revises the backend registration.
The following describes the expected behavior after this change:
(**bolded are changed behavior**)
- (ort.min.js - built without webgpu support)
- loading: do not register 'webgpu' backend
- creating session without EP list: use default EP list ['webnn', 'cpu',
'wasm']
- creating session with ['webgpu'] as EP list: should fail with backend
not available
- (ort.webgpu.min.js - built with webgpu support)
- loading: **always register 'webgpu' backend**
( previous behavior: only register 'webgpu' backend when `navigator.gpu`
is available)
- creating session without EP list: use default EP list ['webgpu',
'webnn', 'cpu', 'wasm']
- when WebGPU is available (win): use WebGPU backend
- when WebGPU is unavailable (android): **should fail backend init,**
and try to use next backend in the list, 'webnn'
(previous behavior: does not fail backend init, but fail in JSEP init,
which was too late to switch to next backend)
- creating session with ['webgpu'] as EP list
- when WebGPU is available (win): use WebGPU backend
  - when WebGPU is unavailable (android): **should fail backend init, and
because no more EPs are listed, fail.**
related PRs: #18190 #18144
527 lines
20 KiB
TypeScript
527 lines
20 KiB
TypeScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
import {InferenceSession, Tensor} from 'onnxruntime-common';
|
|
|
|
import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages';
|
|
import {setRunOptions} from './run-options';
|
|
import {setSessionOptions} from './session-options';
|
|
import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
|
|
import {prepareInputOutputTensor} from './wasm-core-impl';
|
|
import {getInstance} from './wasm-factory';
|
|
import {checkLastError} from './wasm-utils';
|
|
|
|
// Error message thrown when a training-specific WASM export is missing, i.e. the binary was built
// without the training API. Points users at the training build/import instructions.
const NO_TRAIN_FUNCS_MSG =
    'Built without training API\'s enabled. Use the onnxruntime-web/training import for training ' +
    'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' +
    'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.';
|
|
|
|
/**
|
|
* Runs the checkLastError function which will throw an error, if the provided error code matches the specified
|
|
* pattern for an error code.
|
|
* @param errCode number to evaluated for if it's an error
|
|
* @param message message to pass into checkLastError
|
|
* @param checkNeqZero when true, treats not equal to zero as an error.
|
|
* When false, treats equal to zero as an error.
|
|
*/
|
|
const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => {
|
|
if (checkNeqZero && errCode !== 0) {
|
|
checkLastError(message);
|
|
} else if (!checkNeqZero && errCode === 0) {
|
|
checkLastError(message);
|
|
}
|
|
};
|
|
|
|
/**
 * Creates a CheckpointState on the WASM side from serialized checkpoint data.
 *
 * @param checkpointData [offset, length] of the checkpoint bytes, already copied onto the WASM heap.
 *     The buffer is always freed by this function.
 * @returns the handle of the created CheckpointState.
 * @throws when the training API is not built in, or when checkpoint creation fails.
 */
export const createCheckpointHandle = (checkpointData: SerializableInternalBuffer): number => {
  const wasm = getInstance();

  const [checkpointDataOffset, checkpointDataLength] = checkpointData;
  let checkpointHandle = 0;

  try {
    if (wasm._OrtTrainingLoadCheckpoint) {
      checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength);
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    // a zero handle means creation failed (checkNeqZero=false)
    ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false);
    return checkpointHandle;
  } catch (e) {
    // release a partially-created checkpoint before propagating the error
    if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) {
      wasm._OrtTrainingReleaseCheckpoint(checkpointHandle);
    }
    throw e;
  } finally {
    // free buffer from wasm heap
    wasm._OrtFree(checkpointData[0]);
  }
};
|
|
|
|
/**
 * Queries the number of inputs and outputs of the train (or eval) model of a training session.
 *
 * @param trainingSessionId handle of the training session
 * @param isEvalModel when true, queries the eval model instead of the train model
 * @returns [inputCount, outputCount]
 * @throws when the training API is not built in, or when the query fails.
 */
const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => {
  const wasm = getInstance();
  const stack = wasm.stackSave();
  try {
    // 8 bytes: two 32-bit integers written by the WASM side (input count, output count)
    const dataOffset = wasm.stackAlloc(8);
    if (wasm._OrtTrainingGetModelInputOutputCount) {
      const errorCode =
          wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel);
      ifErrCodeCheckLastError(errorCode, 'Can\'t get session input/output count.');
      return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]];
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }
  } finally {
    wasm.stackRestore(stack);
  }
};
|
|
|
|
const getModelInputOutputNamesLoop =
|
|
(trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): string[] => {
|
|
const names = [];
|
|
const wasm = getInstance();
|
|
|
|
for (let i = 0; i < count; i++) {
|
|
if (wasm._OrtTrainingGetModelInputOutputName) {
|
|
const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel);
|
|
ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false);
|
|
|
|
names.push(wasm.UTF8ToString(name));
|
|
wasm._free(name);
|
|
} else {
|
|
throw new Error(NO_TRAIN_FUNCS_MSG);
|
|
}
|
|
}
|
|
return names;
|
|
};
|
|
|
|
export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => {
|
|
let inputNames: string[] = [];
|
|
let outputNames: string[] = [];
|
|
|
|
const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel);
|
|
|
|
inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel);
|
|
outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel);
|
|
|
|
return [inputNames, outputNames];
|
|
};
|
|
|
|
/**
 * Creates a TrainingSession on the WASM side from a checkpoint handle, the serialized train/eval/optimizer
 * model bytes, and session options.
 *
 * @param checkpointHandle handle of a previously created CheckpointState
 * @param trainModelData [offset, length] of the train model bytes on the WASM heap (always freed here)
 * @param evalModelData [offset, length] of the eval model bytes on the WASM heap (always freed here)
 * @param optimizerModelData [offset, length] of the optimizer model bytes on the WASM heap (always freed here)
 * @param options session options
 * @returns the handle of the created TrainingSession.
 * @throws when the training API is not built in, or when session creation fails.
 */
export const createTrainingSessionHandle =
    (checkpointHandle: number, trainModelData: SerializableInternalBuffer, evalModelData: SerializableInternalBuffer,
     optimizerModelData: SerializableInternalBuffer, options: InferenceSession.SessionOptions): number => {
      const wasm = getInstance();

      let trainingSessionHandle = 0;
      let sessionOptionsHandle = 0;
      let allocs: number[] = [];

      try {
        [sessionOptionsHandle, allocs] = setSessionOptions(options);
        if (wasm._OrtTrainingCreateSession) {
          trainingSessionHandle = wasm._OrtTrainingCreateSession(
              sessionOptionsHandle, checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0],
              evalModelData[1], optimizerModelData[0], optimizerModelData[1]);
        } else {
          throw new Error(NO_TRAIN_FUNCS_MSG);
        }

        // a zero handle means creation failed (checkNeqZero=false)
        ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false);
        return trainingSessionHandle;
      } catch (e) {
        // release a partially-created session before propagating the error
        if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) {
          wasm._OrtTrainingReleaseSession(trainingSessionHandle);
        }
        throw e;
      } finally {
        // the model buffers and session-option allocations are no longer needed once the session exists (or failed)
        wasm._free(trainModelData[0]);
        wasm._free(evalModelData[0]);
        wasm._free(optimizerModelData[0]);

        if (sessionOptionsHandle !== 0) {
          wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
        }
        allocs.forEach(alloc => wasm._free(alloc));
      }
    };
|
|
|
|
/**
|
|
* Prepares input and output tensors by creating the tensors in the WASM side then creates a list of the handles of the
|
|
* WASM tensors.
|
|
*
|
|
* @param trainingSessionId
|
|
* @param indices for each tensor, the index of the input or output name that the tensor corresponds with
|
|
* @param tensors list of TensorMetaData
|
|
* @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting
|
|
* handles of the allocated tensors on the heap
|
|
* @param inputOutputAllocs modified in-place by this method
|
|
* @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor
|
|
*/
|
|
const createAndAllocateTensors =
|
|
(trainingSessionId: number, indices: number[], tensors: Array<TensorMetadata|null>, tensorHandles: number[],
|
|
inputOutputAllocs: number[], indexAdd: number) => {
|
|
const count = indices.length;
|
|
|
|
// creates the tensors
|
|
for (let i = 0; i < count; i++) {
|
|
prepareInputOutputTensor(
|
|
tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]);
|
|
}
|
|
|
|
// moves to heap
|
|
const wasm = getInstance();
|
|
const valuesOffset = wasm.stackAlloc(count * 4);
|
|
let valuesIndex = valuesOffset / 4;
|
|
for (let i = 0; i < count; i++) {
|
|
wasm.HEAPU32[valuesIndex++] = tensorHandles[i];
|
|
}
|
|
|
|
return valuesOffset;
|
|
};
|
|
|
|
/**
|
|
* Retrieves the information from the output tensor handles, copies to an array, and frees the WASM information
|
|
* associated with the tensor handle.
|
|
*
|
|
* @param outputValuesOffset
|
|
* @param outputCount
|
|
* @returns list of TensorMetadata retrieved from the output handles.
|
|
*/
|
|
const moveOutputToTensorMetadataArr =
|
|
(outputValuesOffset: number, outputCount: number, outputTensorHandles: number[],
|
|
outputTensors: Array<TensorMetadata|null>) => {
|
|
const wasm = getInstance();
|
|
const output: TensorMetadata[] = [];
|
|
|
|
for (let i = 0; i < outputCount; i++) {
|
|
const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i];
|
|
if (tensor === outputTensorHandles[i]) {
|
|
// output tensor is pre-allocated. no need to copy data.
|
|
output.push(outputTensors[i]!);
|
|
continue;
|
|
}
|
|
|
|
const beforeGetTensorDataStack = wasm.stackSave();
|
|
// stack allocate 4 pointer value
|
|
const tensorDataOffset = wasm.stackAlloc(4 * 4);
|
|
|
|
let type: Tensor.Type|undefined, dataOffset = 0;
|
|
try {
|
|
const errorCode = wasm._OrtGetTensorData(
|
|
tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12);
|
|
ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`);
|
|
|
|
let tensorDataIndex = tensorDataOffset / 4;
|
|
const dataType = wasm.HEAPU32[tensorDataIndex++];
|
|
dataOffset = wasm.HEAPU32[tensorDataIndex++];
|
|
const dimsOffset = wasm.HEAPU32[tensorDataIndex++];
|
|
const dimsLength = wasm.HEAPU32[tensorDataIndex++];
|
|
const dims = [];
|
|
for (let i = 0; i < dimsLength; i++) {
|
|
dims.push(wasm.HEAPU32[dimsOffset / 4 + i]);
|
|
}
|
|
wasm._OrtFree(dimsOffset);
|
|
|
|
const size = dims.reduce((a, b) => a * b, 1);
|
|
type = tensorDataTypeEnumToString(dataType);
|
|
|
|
if (type === 'string') {
|
|
const stringData: string[] = [];
|
|
let dataIndex = dataOffset / 4;
|
|
for (let i = 0; i < size; i++) {
|
|
const offset = wasm.HEAPU32[dataIndex++];
|
|
const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset;
|
|
stringData.push(wasm.UTF8ToString(offset, maxBytesToRead));
|
|
}
|
|
output.push([type, dims, stringData, 'cpu']);
|
|
} else {
|
|
const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
|
|
const data = new typedArrayConstructor(size);
|
|
new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
|
|
.set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength));
|
|
output.push([type, dims, data, 'cpu']);
|
|
}
|
|
} finally {
|
|
wasm.stackRestore(beforeGetTensorDataStack);
|
|
if (type === 'string' && dataOffset) {
|
|
wasm._free(dataOffset);
|
|
}
|
|
wasm._OrtReleaseTensor(tensor);
|
|
}
|
|
}
|
|
|
|
return output;
|
|
};
|
|
|
|
export const lazyResetGrad = async(trainingSessionId: number): Promise<void> => {
|
|
const wasm = getInstance();
|
|
|
|
if (wasm._OrtTrainingLazyResetGrad) {
|
|
const errorCode = wasm._OrtTrainingLazyResetGrad(trainingSessionId);
|
|
ifErrCodeCheckLastError(errorCode, 'Can\'t call lazyResetGrad.');
|
|
} else {
|
|
throw new Error(NO_TRAIN_FUNCS_MSG);
|
|
}
|
|
};
|
|
|
|
/**
 * Runs a single train step: executes the training model with the given inputs and returns the
 * requested outputs.
 *
 * @param trainingSessionId handle of the training session
 * @param inputIndices for each input tensor, the index of the corresponding input name
 * @param inputTensors the input tensors
 * @param outputIndices for each output tensor, the index of the corresponding output name
 * @param outputTensors pre-allocated output tensors, or null entries to let ORT allocate them
 * @param options run options
 * @returns the output tensors as TensorMetadata
 * @throws when the training API is not built in, or when the step fails.
 */
export const runTrainStep = async(
    trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[],
    outputTensors: Array<TensorMetadata|null>, options: InferenceSession.RunOptions): Promise<TensorMetadata[]> => {
  const wasm = getInstance();

  const inputCount = inputIndices.length;
  const outputCount = outputIndices.length;

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  const inputTensorHandles: number[] = [];
  const outputTensorHandles: number[] = [];
  const inputOutputAllocs: number[] = [];

  const beforeRunStack = wasm.stackSave();

  try {
    // prepare parameters by moving them to heap
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    // handle inputs -- you don't want anything added to the index
    const inputValuesOffset = createAndAllocateTensors(
        trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0);
    // handle outputs
    // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
    const outputValuesOffset = createAndAllocateTensors(
        trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount);

    if (wasm._OrtTrainingRunTrainStep) {
      const errorCode = wasm._OrtTrainingRunTrainStep(
          trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle);
      ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer');
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
  } finally {
    wasm.stackRestore(beforeRunStack);

    // release native tensors and any heap allocations made while preparing them
    inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
    outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
    inputOutputAllocs.forEach(p => wasm._free(p));

    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach(p => wasm._free(p));
  }
};
|
|
|
|
export const runOptimizerStep =
|
|
async(trainingSessionId: number, options: InferenceSession.RunOptions): Promise<void> => {
|
|
const wasm = getInstance();
|
|
|
|
let runOptionsHandle = 0;
|
|
let runOptionsAllocs: number[] = [];
|
|
|
|
try {
|
|
[runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
|
|
|
|
if (wasm._OrtTrainingOptimizerStep) {
|
|
const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle);
|
|
ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer');
|
|
} else {
|
|
throw new Error(NO_TRAIN_FUNCS_MSG);
|
|
}
|
|
} finally {
|
|
if (runOptionsHandle !== 0) {
|
|
wasm._OrtReleaseRunOptions(runOptionsHandle);
|
|
}
|
|
runOptionsAllocs.forEach(p => wasm._free(p));
|
|
}
|
|
};
|
|
|
|
/**
 * Runs a single eval step: executes the eval model with the given inputs and returns the
 * requested outputs.
 *
 * @param trainingSessionId handle of the training session
 * @param inputIndices for each input tensor, the index of the corresponding input name
 * @param inputTensors the input tensors
 * @param outputIndices for each output tensor, the index of the corresponding output name
 * @param outputTensors pre-allocated output tensors, or null entries to let ORT allocate them
 * @param options run options
 * @returns the output tensors as TensorMetadata
 * @throws when the training API is not built in, or when the step fails.
 */
export const runEvalStep = async(
    trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[],
    outputTensors: Array<TensorMetadata|null>, options: InferenceSession.RunOptions): Promise<TensorMetadata[]> => {
  const wasm = getInstance();

  const inputCount = inputIndices.length;
  const outputCount = outputIndices.length;

  let runOptionsHandle = 0;
  let runOptionsAllocs: number[] = [];

  const inputTensorHandles: number[] = [];
  const outputTensorHandles: number[] = [];
  const inputOutputAllocs: number[] = [];

  const beforeRunStack = wasm.stackSave();

  try {
    // prepare parameters by moving them to heap
    [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);

    // handle inputs -- you don't want anything added to the index
    const inputValuesOffset = createAndAllocateTensors(
        trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0);
    // handle outputs
    // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
    const outputValuesOffset = createAndAllocateTensors(
        trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount);

    if (wasm._OrtTrainingEvalStep) {
      const errorCode = wasm._OrtTrainingEvalStep(
          trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle);

      ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer');
    } else {
      throw new Error(NO_TRAIN_FUNCS_MSG);
    }

    return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
  } finally {
    wasm.stackRestore(beforeRunStack);

    // release native tensors and any heap allocations made while preparing them
    inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
    outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
    inputOutputAllocs.forEach(p => wasm._free(p));

    if (runOptionsHandle !== 0) {
      wasm._OrtReleaseRunOptions(runOptionsHandle);
    }
    runOptionsAllocs.forEach(p => wasm._free(p));
  }
};
|
|
|
|
export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => {
|
|
const wasm = getInstance();
|
|
const stack = wasm.stackSave();
|
|
|
|
try {
|
|
const sizeOffset = wasm.stackAlloc(4);
|
|
if (wasm._OrtTrainingGetParametersSize) {
|
|
const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly);
|
|
ifErrCodeCheckLastError(errorCode, 'Can\'t get parameters size');
|
|
|
|
return wasm.HEAP32[sizeOffset / 4];
|
|
} else {
|
|
throw new Error(NO_TRAIN_FUNCS_MSG);
|
|
}
|
|
} finally {
|
|
wasm.stackRestore(stack);
|
|
}
|
|
};
|
|
|
|
export const getContiguousParameters =
|
|
async(trainingSessionId: number, trainableOnly: boolean): Promise<TensorMetadata> => {
|
|
const wasm = getInstance();
|
|
const stack = wasm.stackSave();
|
|
|
|
const tensorTypeAsString = 'float32';
|
|
const locationAsString = 'cpu';
|
|
|
|
const parametersSize = getParametersSize(trainingSessionId, trainableOnly);
|
|
let tensor = 0;
|
|
|
|
// allocates a buffer of the correct size on the WASM heap
|
|
const paramsByteLength = 4 * parametersSize;
|
|
const paramsOffset = wasm._malloc(paramsByteLength);
|
|
|
|
// handles the dimensions-related createTensor parameters
|
|
const dims = [parametersSize];
|
|
|
|
const dimsOffset = wasm.stackAlloc(4);
|
|
const dimsIndex = dimsOffset / 4;
|
|
wasm.HEAP32[dimsIndex] = parametersSize;
|
|
|
|
try {
|
|
// wraps allocated array in a tensor
|
|
tensor = wasm._OrtCreateTensor(
|
|
tensorDataTypeStringToEnum(tensorTypeAsString), paramsOffset, paramsByteLength, dimsOffset, dims.length,
|
|
dataLocationStringToEnum(locationAsString));
|
|
ifErrCodeCheckLastError(
|
|
tensor, `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, false);
|
|
|
|
if (wasm._OrtTrainingCopyParametersToBuffer) {
|
|
const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly);
|
|
ifErrCodeCheckLastError(errCode, 'Can\'t get contiguous parameters.');
|
|
|
|
} else {
|
|
throw new Error(NO_TRAIN_FUNCS_MSG);
|
|
}
|
|
|
|
// copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object
|
|
const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString);
|
|
const data = new typedArrayConstructor(parametersSize);
|
|
const output: TensorMetadata[] = [];
|
|
new Uint8Array(data.buffer, data.byteOffset, data.byteLength)
|
|
.set(wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength));
|
|
output.push([tensorTypeAsString, dims, data, locationAsString]);
|
|
if (output.length !== 1) {
|
|
throw new Error(`something unexpected happened in the getContiguousParameters function. Expected output length of
|
|
one, got ${output.length}`);
|
|
} else {
|
|
return output[0];
|
|
}
|
|
} finally {
|
|
if (tensor !== 0) {
|
|
wasm._OrtReleaseTensor(tensor);
|
|
}
|
|
wasm._free(paramsOffset);
|
|
wasm._free(dimsOffset);
|
|
wasm.stackRestore(stack);
|
|
}
|
|
};
|
|
|
|
export const loadParametersBuffer =
|
|
async(trainingSessionId: number, buffer: Uint8Array, trainableOnly: boolean): Promise<void> => {
|
|
const wasm = getInstance();
|
|
const stack = wasm.stackSave();
|
|
|
|
const tensorTypeAsString = 'float32';
|
|
const locationAsString = 'cpu';
|
|
|
|
// allocates & copies JavaScript buffer to WASM heap
|
|
const bufferByteLength = buffer.length;
|
|
const bufferCount = bufferByteLength / 4;
|
|
const bufferOffset = wasm._malloc(bufferByteLength);
|
|
wasm.HEAPU8.set(buffer, bufferOffset);
|
|
|
|
// allocates and handles moving dimensions information to WASM memory
|
|
const dimsOffset = wasm.stackAlloc(4);
|
|
wasm.HEAP32[dimsOffset / 4] = bufferCount;
|
|
const dimsLength = 1;
|
|
let tensor = 0;
|
|
|
|
try {
|
|
tensor = wasm._OrtCreateTensor(
|
|
tensorDataTypeStringToEnum(tensorTypeAsString), bufferOffset, bufferByteLength, dimsOffset, dimsLength,
|
|
dataLocationStringToEnum(locationAsString));
|
|
ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false);
|
|
|
|
if (wasm._OrtTrainingCopyParametersFromBuffer) {
|
|
const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly);
|
|
ifErrCodeCheckLastError(errCode, 'Can\'t copy buffer to parameters.');
|
|
} else {
|
|
throw new Error(NO_TRAIN_FUNCS_MSG);
|
|
}
|
|
} finally {
|
|
if (tensor !== 0) {
|
|
wasm._OrtReleaseTensor(tensor);
|
|
}
|
|
wasm.stackRestore(stack);
|
|
wasm._free(bufferOffset);
|
|
wasm._free(dimsOffset);
|
|
}
|
|
};
|
|
|
|
export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => {
|
|
const wasm = getInstance();
|
|
|
|
if (wasm._OrtTrainingReleaseSession) {
|
|
wasm._OrtTrainingReleaseSession(sessionId);
|
|
}
|
|
if (wasm._OrtTrainingReleaseCheckpoint) {
|
|
wasm._OrtTrainingReleaseCheckpoint(checkpointId);
|
|
}
|
|
};
|