onnxruntime/js/web/test/training/e2e/common.js
Yulong Wang abdc31de40
[js] change default formatter for JavaScript/TypeScript from clang-format to Prettier (#21728)
### Description

See
454996d496
for manual changes (excluded auto-generated formatting changes)

### Why

Because the tooling for clang-format is out of date, which reduces
development efficiency.

- The NPM package `clang-format` is already in maintenance mode and has
not been updated in 2 years.
- The VS Code extension for clang-format has not been maintained for a while,
and a recent Node.js security update broke it entirely on
Windows.

No one in the community seems interested in fixing these issues.

We chose Prettier because it is the most popular TS/JS formatter.

### How to merge

It's easy to break the build:
- Be careful of any new commits on main not included in this PR.
- Be careful that, after this PR is merged, other PRs that already passed
CI can still be merged without being reformatted, which would break the build.

So, make sure there are no new commits before merging this one, and
invalidate JS PRs that already passed CI, forcing them to merge with the latest main.
2024-08-14 16:51:22 -07:00

248 lines
8.8 KiB
JavaScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
'use strict';
// Relative folder holding every model/checkpoint file referenced below.
const DATA_FOLDER = 'data/';
// Individual training artifacts, each located inside DATA_FOLDER.
const TRAININGDATA_TRAIN_MODEL = `${DATA_FOLDER}training_model.onnx`;
const TRAININGDATA_OPTIMIZER_MODEL = `${DATA_FOLDER}adamw.onnx`;
const TRAININGDATA_EVAL_MODEL = `${DATA_FOLDER}eval_model.onnx`;
const TRAININGDATA_CKPT = `${DATA_FOLDER}checkpoint.ckpt`;

// Create options for ort.TrainingSession.create() that load every artifact:
// checkpoint, training model, eval model and optimizer model.
const trainingSessionAllOptions = {
  checkpointState: TRAININGDATA_CKPT,
  trainModel: TRAININGDATA_TRAIN_MODEL,
  evalModel: TRAININGDATA_EVAL_MODEL,
  optimizerModel: TRAININGDATA_OPTIMIZER_MODEL,
};

// Minimal create options: only the checkpoint and the training model (no eval or optimizer model).
const trainingSessionMinOptions = {
  checkpointState: TRAININGDATA_CKPT,
  trainModel: TRAININGDATA_TRAIN_MODEL,
};
// ASSERT METHODS
/**
 * Throws an Error when `cond` is falsy.
 * @param {*} cond - value that must be truthy for the check to pass.
 * @param {string} [message='assertion failed'] - text of the thrown Error, so a failing test reports
 *     something more useful than an empty message.
 * @throws {Error} if `cond` is falsy.
 */
function assert(cond, message = 'assertion failed') {
  if (!cond) {
    throw new Error(message);
  }
}
/**
 * Throws unless `actual === expected`.
 * The error message has the form `expected: <expected>; got: <actual>`. Non-null objects on either side are
 * JSON-serialized so the message shows their contents instead of "[object Object]".
 * @param {*} actual - value produced by the code under test.
 * @param {*} expected - value it must be strictly equal to.
 * @throws {Error} if the two values are not strictly equal.
 */
function assertStrictEquals(actual, expected) {
  if (actual === expected) {
    return;
  }
  const format = (value) => (typeof value === 'object' && value !== null ? JSON.stringify(value) : value);
  throw new Error(`expected: ${format(expected)}; got: ${format(actual)}`);
}
/**
 * Passes when the two lists differ in length or in at least one element (compared with `===`).
 * @param {ArrayLike<*>} list1 - first list (plain array or typed array).
 * @param {ArrayLike<*>} list2 - second list.
 * @throws {Error} if both lists have the same length and identical elements.
 */
function assertTwoListsUnequal(list1, list2) {
  // Different lengths already make the lists unequal.
  if (list1.length !== list2.length) {
    return;
  }
  // Scan forward while elements match; stopping before the end means a mismatch was found.
  let index = 0;
  while (index < list1.length && list1[index] === list2[index]) {
    index += 1;
  }
  if (index < list1.length) {
    return;
  }
  throw new Error(`expected ${list1} and ${list2} to be unequal; got two equal lists`);
}
// HELPER METHODS FOR TESTS
/**
 * Draws one sample from a normal distribution using the Box–Muller transform.
 * @param {number} [mean=0] - value added to the scaled standard-normal sample.
 * @param {number} [scale=1] - factor applied to the standard-normal sample (the standard deviation).
 * @returns {number} the generated sample.
 */
function generateGaussianRandom(mean = 0, scale = 1) {
  // 1 - Math.random() lies in (0, 1], so Math.log never sees 0.
  const radius = Math.sqrt(-2.0 * Math.log(1 - Math.random()));
  const angle = 2.0 * Math.PI * Math.random();
  const standardNormal = radius * Math.cos(angle);
  return standardNormal * scale + mean;
}
/**
 * Builds a Float32Array of `length` independent samples from generateGaussianRandom() with its
 * default mean (0) and scale (1).
 * @param {number} length - number of elements to generate.
 * @returns {Float32Array} the generated samples.
 */
function generateGaussianFloatArray(length) {
  // The arrow deliberately ignores map's (value, index) arguments so they are not taken as mean/scale.
  return new Float32Array(length).map(() => generateGaussianRandom());
}
/**
 * Creates a TrainingSession and verifies the input and output names of the training model it loaded:
 * inputs must be exactly ['input-0', 'labels'] and outputs exactly ['onnx::loss::21273'].
 * @param {*} ort - the onnxruntime module under test; must provide `TrainingSession.create`.
 * @param {*} createOptions - forwarded as the first argument of `ort.TrainingSession.create`.
 * @param {*} options - forwarded as the second argument of `ort.TrainingSession.create`.
 * @returns {Promise<*>} the created training session.
 */
async function createTrainingSessionAndCheckTrainingModel(ort, createOptions, options) {
  const expectedInputs = ['input-0', 'labels'];
  const expectedOutputs = ['onnx::loss::21273'];
  const session = await ort.TrainingSession.create(createOptions, options);

  // Each name is checked by position first, then the total count.
  expectedInputs.forEach((name, index) => assertStrictEquals(session.trainingInputNames[index], name));
  assertStrictEquals(session.trainingInputNames.length, expectedInputs.length);
  expectedOutputs.forEach((name, index) => assertStrictEquals(session.trainingOutputNames[index], name));
  assertStrictEquals(session.trainingOutputNames.length, expectedOutputs.length);

  return session;
}
/**
 * Verifies the eval model loaded into `trainingSession`: eval inputs must be exactly
 * ['input-0', 'labels'] and eval outputs exactly ['onnx::loss::21273'].
 * @param {*} trainingSession - session expected to have been created with an `evalModel`.
 */
function checkEvalModel(trainingSession) {
  const expectedInputs = ['input-0', 'labels'];
  const expectedOutputs = ['onnx::loss::21273'];

  // Each name is checked by position first, then the total count.
  expectedInputs.forEach((name, index) => assertStrictEquals(trainingSession.evalInputNames[index], name));
  assertStrictEquals(trainingSession.evalInputNames.length, expectedInputs.length);
  expectedOutputs.forEach((name, index) => assertStrictEquals(trainingSession.evalOutputNames[index], name));
  assertStrictEquals(trainingSession.evalOutputNames.length, expectedOutputs.length);
}
/**
 * Checks that reading `trainingSession.evalInputNames` and `trainingSession.evalOutputNames` each throws an
 * Error whose message is 'This training session has no evalModel loaded.'.
 * @param {*} trainingSession - session expected to have been created without an `evalModel`.
 * @throws {Error} if either read succeeds, or throws with a different message.
 */
function checkNoEvalModel(trainingSession) {
  const expectedMessage = 'This training session has no evalModel loaded.';
  // The "did not throw" failure is raised outside the try block. Otherwise it would be caught there and
  // reported as a confusing message mismatch.
  const expectGetterToThrow = (propertyName) => {
    let threw = false;
    let caught;
    try {
      // Only the property read matters; its value is intentionally discarded.
      void trainingSession[propertyName];
    } catch (error) {
      threw = true;
      caught = error;
    }
    if (!threw) {
      throw new Error(`expected accessing trainingSession.${propertyName} to throw; it returned without error`);
    }
    assertStrictEquals(caught.message, expectedMessage);
  };
  expectGetterToThrow('evalInputNames');
  expectGetterToThrow('evalOutputNames');
}
/**
 * Runs one training step and checks its outputs. There must be exactly one output,
 * 'onnx::loss::21273', holding a single float32 value.
 * @param {*} trainingSession - session whose `runTrainStep` is invoked.
 * @param {*} feeds - inputs passed unchanged to `runTrainStep`.
 * @returns {Promise<*>} the outputs returned by `runTrainStep`.
 */
async function runTrainStepAndCheck(trainingSession, feeds) {
  const lossName = 'onnx::loss::21273';
  const results = await trainingSession.runTrainStep(feeds);
  assertStrictEquals(Object.keys(results).length, 1);
  const loss = results[lossName];
  assertStrictEquals(loss.data.length, 1);
  assertStrictEquals(loss.type, 'float32');
  return results;
}
/**
 * Overwrites every trainable parameter with `constant` via `loadParametersBuffer`, then checks what the
 * session reports back from `getContiguousParameters`.
 * @param {*} trainingSession - session whose parameters are replaced.
 * @param {number} paramsLength - number of float32 parameters to write; also the expected `dims[0]` afterwards.
 * @param {number} constant - value assigned to every parameter.
 * @param {*} paramsBefore - parameters read earlier; the newly loaded `data` must differ from `paramsBefore.data`.
 * @returns {Promise<*>} the parameters returned by `getContiguousParameters()` after loading.
 */
async function loadParametersBufferAndCheck(trainingSession, paramsLength, constant, paramsBefore) {
  const newParams = new Float32Array(paramsLength).fill(constant);
  // loadParametersBuffer is handed a byte view over the float32 storage (no copy is made here).
  const newParamsBytes = new Uint8Array(newParams.buffer, newParams.byteOffset, newParams.byteLength);
  await trainingSession.loadParametersBuffer(newParamsBytes);

  const paramsAfterLoad = await trainingSession.getContiguousParameters();
  // the parameters must have changed, and their count must be preserved
  assertTwoListsUnequal(paramsAfterLoad.data, paramsBefore.data);
  assertStrictEquals(paramsAfterLoad.dims[0], paramsLength);

  // The constant was rounded to float32 on the way in, so compare at 4 decimal places rather than exactly.
  const expected = constant.toFixed(4);
  for (let i = 0; i < paramsLength; i++) {
    assertStrictEquals(paramsAfterLoad.data[i].toFixed(4), expected);
  }
  return paramsAfterLoad;
}
// TESTS
/**
 * Runs `model.onnx` with two fixed float32 inputs, a (3x4) and b (4x3), and checks the 3x3 output `c`.
 * The expected values equal the row-major matrix product a x b.
 * @param {*} ort - the onnxruntime module under test; must provide `InferenceSession` and `Tensor`.
 * @param {*} [options] - session options for `ort.InferenceSession.create`; a falsy value becomes `{}`.
 */
async function testInferenceFunction(ort, options) {
  const session = await ort.InferenceSession.create(`${DATA_FOLDER}model.onnx`, options || {});
  const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
  const dataB = Float32Array.from([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]);
  const fetches = await session.run({
    a: new ort.Tensor('float32', dataA, [3, 4]),
    b: new ort.Tensor('float32', dataB, [4, 3]),
  });

  const c = fetches.c;
  assert(c instanceof ort.Tensor);
  assertStrictEquals(c.dims.length, 2);
  assertStrictEquals(c.dims[0], 3);
  assertStrictEquals(c.dims[1], 3);
  const expected = [700, 800, 900, 1580, 1840, 2100, 2460, 2880, 3300];
  // assertStrictEquals (rather than bare assert) reports which value mismatched.
  expected.forEach((value, i) => assertStrictEquals(c.data[i], value));
}
/**
 * End-to-end check of a TrainingSession created from only the checkpoint and training model. It verifies:
 * - no eval model is exposed;
 * - parameter size and sample values match hard-coded references;
 * - one training step runs;
 * - loading a constant parameter buffer takes effect.
 * @param {*} ort - the onnxruntime module under test.
 * @param {*} options - session options forwarded to `ort.TrainingSession.create`.
 */
async function testTrainingFunctionMin(ort, options) {
  const expectedParamsSize = 397510;
  const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionMinOptions, options);
  checkNoEvalModel(trainingSession);

  // A batch of 2 random 784-element inputs plus their labels.
  const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]);
  const labels = new ort.Tensor('int32', [2, 1], [2]);
  const feeds = { 'input-0': input0, labels };

  // check getParametersSize
  const paramsSize = await trainingSession.getParametersSize();
  assertStrictEquals(paramsSize, expectedParamsSize);

  // check getContiguousParameters against hard-coded reference values
  const originalParams = await trainingSession.getContiguousParameters();
  assertStrictEquals(originalParams.dims.length, 1);
  assertStrictEquals(originalParams.dims[0], expectedParamsSize);
  assertStrictEquals(originalParams.data[0], -0.025190064683556557);
  assertStrictEquals(originalParams.data[2000], -0.034044936299324036);

  await runTrainStepAndCheck(trainingSession, feeds);
  await loadParametersBufferAndCheck(trainingSession, expectedParamsSize, -1.2, originalParams);
}
/**
 * End-to-end check of a TrainingSession created with every artifact (checkpoint, training, eval and
 * optimizer models). It verifies:
 * - the model names and the parameter size/values;
 * - an optimizer step changes the parameters;
 * - the loss of a second training step is lower than the first;
 * - loading a constant parameter buffer takes effect.
 * @param {*} ort - the onnxruntime module under test.
 * @param {*} options - session options forwarded to `ort.TrainingSession.create`.
 */
async function testTrainingFunctionAll(ort, options) {
  const lossName = 'onnx::loss::21273';
  const expectedParamsSize = 397510;
  const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionAllOptions, options);
  checkEvalModel(trainingSession);

  // A batch of 2 random 784-element inputs plus their labels, reused for both training steps.
  const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]);
  const labels = new ort.Tensor('int32', [2, 1], [2]);
  const feeds = { 'input-0': input0, labels };

  // check getParametersSize
  const paramsSize = await trainingSession.getParametersSize();
  assertStrictEquals(paramsSize, expectedParamsSize);

  // check getContiguousParameters against hard-coded reference values
  const originalParams = await trainingSession.getContiguousParameters();
  assertStrictEquals(originalParams.dims.length, 1);
  assertStrictEquals(originalParams.dims[0], expectedParamsSize);
  assertStrictEquals(originalParams.data[0], -0.025190064683556557);
  assertStrictEquals(originalParams.data[2000], -0.034044936299324036);

  const results = await runTrainStepAndCheck(trainingSession, feeds);
  // runOptimizerStep's optional argument is run options, not model inputs, so no feeds are passed.
  await trainingSession.runOptimizerStep();

  // check getContiguousParameters after optimizerStep -- that the parameters have been updated
  const optimizedParams = await trainingSession.getContiguousParameters();
  assertTwoListsUnequal(originalParams.data, optimizedParams.data);

  const results2 = await runTrainStepAndCheck(trainingSession, feeds);
  // Compare the scalar losses. Applying `<` to the `data` arrays themselves converts both to strings and
  // compares them lexicographically (e.g. "9.5" < "10" is false).
  const lossBefore = results[lossName].data[0];
  const lossAfter = results2[lossName].data[0];
  assert(lossAfter < lossBefore, `expected loss to decrease after the optimizer step; got ${lossBefore} -> ${lossAfter}`);

  await loadParametersBufferAndCheck(trainingSession, expectedParamsSize, -1.2, optimizedParams);
}
// Under CommonJS (`module` is an object), export the test entry points in declaration order.
// Otherwise nothing is exported here.
// Only names declared in this file may be listed: an undeclared identifier throws a ReferenceError
// as soon as this file is required.
if (typeof module === 'object') {
  module.exports = [testInferenceFunction, testTrainingFunctionMin, testTrainingFunctionAll];
}