mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
### Description
See
454996d496
for the manual changes (auto-generated formatting changes are excluded)
### Why
Because the tooling for the old clang-format setup is out of date, which
reduces development efficiency.
- The NPM package `clang-format` is already in maintenance mode and has
not been updated in 2 years.
- The VSCode extension for clang-format has not been maintained for a while,
and a recent Node.js security update stopped it from working at all on
Windows.
No one in the community seems interested in fixing those.
Prettier was chosen because it is the most popular TS/JS formatter.
### How to merge
It's easy to break the build:
- Be careful of any new commits on main not included in this PR.
- Be careful that after this PR is merged, other PRs that already passed
CI can merge.
So, make sure there are no new commits before merging this one, and
invalidate JS PRs that already passed CI, forcing them to merge with the latest main.
248 lines
8.8 KiB
JavaScript
248 lines
8.8 KiB
JavaScript
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
'use strict';
|
|
|
|
// Paths of the ONNX artifacts used by the training tests, all located under DATA_FOLDER.
const DATA_FOLDER = 'data/';
const TRAININGDATA_TRAIN_MODEL = `${DATA_FOLDER}training_model.onnx`;
const TRAININGDATA_OPTIMIZER_MODEL = `${DATA_FOLDER}adamw.onnx`;
const TRAININGDATA_EVAL_MODEL = `${DATA_FOLDER}eval_model.onnx`;
const TRAININGDATA_CKPT = `${DATA_FOLDER}checkpoint.ckpt`;

// Create-options naming every artifact: checkpoint, training model, eval model and optimizer model.
const trainingSessionAllOptions = {
  checkpointState: TRAININGDATA_CKPT,
  trainModel: TRAININGDATA_TRAIN_MODEL,
  evalModel: TRAININGDATA_EVAL_MODEL,
  optimizerModel: TRAININGDATA_OPTIMIZER_MODEL,
};

// Create-options naming only the checkpoint and the training model.
const trainingSessionMinOptions = {
  checkpointState: TRAININGDATA_CKPT,
  trainModel: TRAININGDATA_TRAIN_MODEL,
};
|
|
|
|
// ASSERT METHODS
|
|
|
|
/**
 * Minimal assertion helper: does nothing when `cond` is truthy and throws a
 * message-less Error otherwise.
 *
 * @param {*} cond - value tested for truthiness
 */
function assert(cond) {
  if (cond) {
    return;
  }
  throw new Error();
}
|
|
|
|
/**
 * Throws when `actual` and `expected` are not identical under `===`.
 *
 * In the error message, `actual` is rendered with JSON.stringify when its
 * `typeof` is 'object'; otherwise it is interpolated as-is. `expected` is
 * always interpolated as-is.
 *
 * @param {*} actual - value produced by the code under test
 * @param {*} expected - value it must be strictly equal to
 */
function assertStrictEquals(actual, expected) {
  if (actual === expected) {
    return;
  }
  const shown = typeof actual === 'object' ? JSON.stringify(actual) : actual;
  throw new Error(`expected: ${expected}; got: ${shown}`);
}
|
|
|
|
/**
 * Throws when the two lists have the same length and every pair of elements at
 * the same index is identical under `===`; returns normally as soon as a
 * difference (in length or in any element) is found.
 *
 * @param {*} list1 - first indexable list with a `length`
 * @param {*} list2 - second indexable list with a `length`
 */
function assertTwoListsUnequal(list1, list2) {
  let differs = list1.length !== list2.length;
  for (let idx = 0; !differs && idx < list1.length; idx += 1) {
    differs = list1[idx] !== list2[idx];
  }
  if (!differs) {
    throw new Error(`expected ${list1} and ${list2} to be unequal; got two equal lists`);
  }
}
|
|
|
|
// HELPER METHODS FOR TESTS
|
|
|
|
/**
 * Draws one sample by combining two Math.random() values as
 * sqrt(-2 * ln(u)) * cos(2 * PI * v), then multiplying by `scale` and adding `mean`.
 *
 * @param {number} mean - value added to the scaled sample (default 0)
 * @param {number} scale - factor applied to the sample before adding `mean` (default 1)
 * @returns {number} the shifted and scaled sample
 */
function generateGaussianRandom(mean = 0, scale = 1) {
  // 1 - Math.random() lies in (0, 1], which keeps Math.log away from log(0)
  const radiusSource = 1 - Math.random();
  const angleSource = Math.random();
  const radius = Math.sqrt(-2.0 * Math.log(radiusSource));
  const angle = 2.0 * Math.PI * angleSource;
  return radius * Math.cos(angle) * scale + mean;
}
|
|
|
|
/**
 * Builds a Float32Array with `length` entries, each filled with a separate
 * call to generateGaussianRandom() using its default arguments.
 *
 * @param {number} length - number of elements to generate
 * @returns {Float32Array} the filled array
 */
function generateGaussianFloatArray(length) {
  return new Float32Array(length).map(() => generateGaussianRandom());
}
|
|
|
|
/**
 * Creates a training session with `ort.TrainingSession.create(createOptions, options)` and asserts that the
 * loaded training model reports exactly the input names ['input-0', 'labels'] and the single output name
 * 'onnx::loss::21273'.
 *
 * @param {*} ort - object exposing `TrainingSession.create`
 * @param {*} createOptions - first argument forwarded to `TrainingSession.create`
 * @param {*} options - second argument forwarded to `TrainingSession.create`
 * @returns a promise resolving to the created training session
 */
async function createTrainingSessionAndCheckTrainingModel(ort, createOptions, options) {
  const session = await ort.TrainingSession.create(createOptions, options);

  // [reader, expected] pairs; readers run lazily so the properties are accessed in assertion order
  const checks = [
    [() => session.trainingInputNames[0], 'input-0'],
    [() => session.trainingInputNames[1], 'labels'],
    [() => session.trainingInputNames.length, 2],
    [() => session.trainingOutputNames[0], 'onnx::loss::21273'],
    [() => session.trainingOutputNames.length, 1],
  ];
  for (const [read, expected] of checks) {
    assertStrictEquals(read(), expected);
  }

  return session;
}
|
|
|
|
/**
 * Asserts that the given training session reports exactly the eval input names ['input-0', 'labels'] and the
 * single eval output name 'onnx::loss::21273'.
 *
 * @param {*} trainingSession - session exposing `evalInputNames` and `evalOutputNames`
 */
function checkEvalModel(trainingSession) {
  // [reader, expected] pairs; readers run lazily so the properties are accessed in assertion order
  const checks = [
    [() => trainingSession.evalInputNames[0], 'input-0'],
    [() => trainingSession.evalInputNames[1], 'labels'],
    [() => trainingSession.evalInputNames.length, 2],
    [() => trainingSession.evalOutputNames[0], 'onnx::loss::21273'],
    [() => trainingSession.evalOutputNames.length, 1],
  ];
  for (const [read, expected] of checks) {
    assertStrictEquals(read(), expected);
  }
}
|
|
|
|
/**
 * Verifies that reading `evalInputNames` and then `evalOutputNames` on the given training session fails with
 * the message 'This training session has no evalModel loaded.'.
 *
 * If a property read returns instead of throwing, the inner assertion throws its own error; that error is
 * caught as well and its message fails the comparison in the catch block, so the check still fails.
 *
 * @param {*} trainingSession - session expected to have no eval model
 */
function checkNoEvalModel(trainingSession) {
  const propertyReaders = [() => trainingSession.evalInputNames, () => trainingSession.evalOutputNames];

  for (const readProperty of propertyReaders) {
    try {
      assertStrictEquals(readProperty(), 'should have thrown an error upon accessing');
    } catch (error) {
      assertStrictEquals(error.message, 'This training session has no evalModel loaded.');
    }
  }
}
|
|
|
|
/**
 * Runs one train step with `feeds` and asserts that the result has exactly one key and that the
 * 'onnx::loss::21273' entry has `data` of length 1 and `type` 'float32'.
 *
 * @param {*} trainingSession - session whose `runTrainStep` is invoked
 * @param {*} feeds - argument forwarded to `runTrainStep`
 * @returns a promise resolving to the object returned by `runTrainStep`
 */
var runTrainStepAndCheck = async function (trainingSession, feeds) {
  const stepOutputs = await trainingSession.runTrainStep(feeds);
  const outputNames = Object.keys(stepOutputs);
  assertStrictEquals(outputNames.length, 1);

  const loss = stepOutputs['onnx::loss::21273'];
  assertStrictEquals(loss.data.length, 1);
  assertStrictEquals(loss.type, 'float32');

  return stepOutputs;
};
|
|
|
|
/**
 * Loads a parameter buffer in which every float32 value equals `constant`, then reads the parameters back and
 * asserts that they differ from `paramsBefore`, that the first dimension equals `paramsLength`, and that each
 * of the first `paramsLength` values matches `constant` after rounding to 4 decimal places.
 *
 * @param {*} trainingSession - session exposing `loadParametersBuffer` and `getContiguousParameters`
 * @param {number} paramsLength - number of float32 parameters to write and verify
 * @param {number} constant - value written to every parameter
 * @param {*} paramsBefore - earlier parameters; its `data` must differ from the reloaded data
 * @returns a promise resolving to the parameters read back after loading
 */
var loadParametersBufferAndCheck = async function (trainingSession, paramsLength, constant, paramsBefore) {
  // float32 values all set to the constant, handed over as a byte view of the same memory
  const replacement = new Float32Array(paramsLength).fill(constant);
  const replacementBytes = new Uint8Array(replacement.buffer, replacement.byteOffset, replacement.byteLength);

  await trainingSession.loadParametersBuffer(replacementBytes);
  const reloaded = await trainingSession.getContiguousParameters();

  // the parameters must no longer equal the earlier snapshot
  assertTwoListsUnequal(reloaded.data, paramsBefore.data);
  assertStrictEquals(reloaded.dims[0], paramsLength);

  // and every value must equal the constant, compared as strings rounded to 4 decimal places
  let idx = 0;
  while (idx < paramsLength) {
    assertStrictEquals(reloaded.data[idx].toFixed(4), constant.toFixed(4));
    idx += 1;
  }

  return reloaded;
};
|
|
|
|
// TESTS
|
|
|
|
/**
 * Creates an inference session from 'data/model.onnx', runs it with a [3, 4] float32 tensor `a` and a [4, 3]
 * float32 tensor `b`, and asserts that output `c` is an `ort.Tensor` with dims [3, 3] holding the nine
 * expected values.
 *
 * @param {*} ort - object exposing `InferenceSession.create` and `Tensor`
 * @param {*} options - session options; a falsy value is replaced by `{}`
 */
var testInferenceFunction = async function (ort, options) {
  const session = await ort.InferenceSession.create('data/model.onnx', options || {});

  const feeds = {
    a: new ort.Tensor('float32', Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), [3, 4]),
    b: new ort.Tensor('float32', Float32Array.from([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]), [4, 3]),
  };
  const fetches = await session.run(feeds);
  const c = fetches.c;

  assert(c instanceof ort.Tensor);

  const dims = c.dims;
  assert(dims.length === 2 && dims[0] === 3 && dims[1] === 3);

  // expected contents of `c`, checked element by element in index order
  const expectedValues = [700, 800, 900, 1580, 1840, 2100, 2460, 2880, 3300];
  expectedValues.forEach((value, index) => {
    assert(c.data[index] === value);
  });
};
|
|
|
|
/**
 * Training test using the minimal create-options (trainingSessionMinOptions): checks the training model names,
 * checks that eval names are unavailable, verifies the parameter size and two known parameter values, runs one
 * train step, and finally loads a constant parameter buffer and verifies it.
 *
 * @param {*} ort - object exposing `TrainingSession` and `Tensor`
 * @param {*} options - second argument forwarded to `TrainingSession.create`
 */
var testTrainingFunctionMin = async function (ort, options) {
  const PARAMS_SIZE = 397510;

  const session = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionMinOptions, options);
  checkNoEvalModel(session);

  const feeds = {
    'input-0': new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]),
    labels: new ort.Tensor('int32', [2, 1], [2]),
  };

  // getParametersSize
  assertStrictEquals(await session.getParametersSize(), PARAMS_SIZE);

  // getContiguousParameters
  const initialParams = await session.getContiguousParameters();
  assertStrictEquals(initialParams.dims.length, 1);
  assertStrictEquals(initialParams.dims[0], PARAMS_SIZE);
  assertStrictEquals(initialParams.data[0], -0.025190064683556557);
  assertStrictEquals(initialParams.data[2000], -0.034044936299324036);

  await runTrainStepAndCheck(session, feeds);

  await loadParametersBufferAndCheck(session, PARAMS_SIZE, -1.2, initialParams);
};
|
|
|
|
/**
 * Training test using the full create-options (trainingSessionAllOptions): checks the training and eval model
 * names, verifies the parameter size and two known parameter values, runs a train step, an optimizer step and
 * a second train step, asserts that the parameters changed and that the loss decreased, and finally loads a
 * constant parameter buffer and verifies it.
 *
 * @param {*} ort - object exposing `TrainingSession` and `Tensor`
 * @param {*} options - second argument forwarded to `TrainingSession.create`
 */
var testTrainingFunctionAll = async function (ort, options) {
  const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionAllOptions, options);
  checkEvalModel(trainingSession);

  const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]);
  const labels = new ort.Tensor('int32', [2, 1], [2]);
  let feeds = { 'input-0': input0, labels: labels };

  // check getParametersSize
  const paramsSize = await trainingSession.getParametersSize();
  assertStrictEquals(paramsSize, 397510);

  // check getContiguousParameters
  const originalParams = await trainingSession.getContiguousParameters();
  assertStrictEquals(originalParams.dims.length, 1);
  assertStrictEquals(originalParams.dims[0], 397510);
  assertStrictEquals(originalParams.data[0], -0.025190064683556557);
  assertStrictEquals(originalParams.data[2000], -0.034044936299324036);

  const results = await runTrainStepAndCheck(trainingSession, feeds);

  // NOTE(review): `feeds` is passed as the only argument of runOptimizerStep; confirm against the
  // TrainingSession API that this is the intended argument.
  await trainingSession.runOptimizerStep(feeds);
  feeds = { 'input-0': input0, labels: labels };

  // check getContiguousParameters after optimizerStep -- that the parameters have been updated
  const optimizedParams = await trainingSession.getContiguousParameters();
  assertTwoListsUnequal(originalParams.data, optimizedParams.data);

  const results2 = await runTrainStepAndCheck(trainingSession, feeds);

  // check that loss decreased after optimizer step and training again.
  // Compare the scalar loss values: applying `<` to the `data` arrays themselves would convert both to
  // strings and compare them lexicographically (e.g. "10.2" < "9.8"), which is not a numeric comparison.
  const lossBefore = results['onnx::loss::21273'].data[0];
  const lossAfter = results2['onnx::loss::21273'].data[0];
  assert(lossAfter < lossBefore);

  await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, optimizedParams);
};
|
|
|
|
// Export the test entry points when a CommonJS-style `module` object is present.
// The array previously also listed `testTest`, an identifier with no definition in this file as visible,
// so evaluating it threw a ReferenceError whenever `module` existed; only the defined functions are exported.
if (typeof module === 'object') {
  module.exports = [testInferenceFunction, testTrainingFunctionMin, testTrainingFunctionAll];
}
|