mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
247 lines
8.8 KiB
JavaScript
247 lines
8.8 KiB
JavaScript
|
|
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||
|
|
// Licensed under the MIT License.
|
||
|
|
|
||
|
|
'use strict';

// Folder that every test asset below is resolved against.
const DATA_FOLDER = 'data/';

// Artifacts consumed by the training-session tests.
const TRAININGDATA_TRAIN_MODEL = `${DATA_FOLDER}training_model.onnx`;
const TRAININGDATA_OPTIMIZER_MODEL = `${DATA_FOLDER}adamw.onnx`;
const TRAININGDATA_EVAL_MODEL = `${DATA_FOLDER}eval_model.onnx`;
const TRAININGDATA_CKPT = `${DATA_FOLDER}checkpoint.ckpt`;
|
||
|
|
|
||
|
|
// Minimal TrainingSession artifacts: a checkpoint plus the training model (no eval / optimizer model).
const trainingSessionMinOptions = {
  checkpointState: TRAININGDATA_CKPT,
  trainModel: TRAININGDATA_TRAIN_MODEL,
};

// Full TrainingSession artifacts: the minimal set plus the eval and optimizer models.
const trainingSessionAllOptions = Object.assign({}, trainingSessionMinOptions, {
  evalModel: TRAININGDATA_EVAL_MODEL,
  optimizerModel: TRAININGDATA_OPTIMIZER_MODEL,
});
|
||
|
|
|
||
|
|
// ASSERT METHODS
|
||
|
|
|
||
|
|
/**
 * Throws if `cond` is falsy.
 *
 * @param {*} cond - value that must be truthy for the check to pass.
 * @param {string} [message] - description put into the thrown Error so a failing check can be identified
 *     (the previous version threw an Error with an empty message).
 * @throws {Error} when `cond` is falsy.
 */
function assert(cond, message = 'assertion failed') {
  if (!cond) {
    throw new Error(message);
  }
}
|
||
|
|
|
||
|
|
/**
 * Throws unless `actual === expected`.
 * When `actual` is an object (including null) it is JSON-serialized in the error message.
 *
 * @param {*} actual - value produced by the code under test.
 * @param {*} expected - value it must be strictly equal to.
 * @throws {Error} `expected: <expected>; got: <actual>` on mismatch.
 */
function assertStrictEquals(actual, expected) {
  if (actual === expected) {
    return;
  }
  const shown = typeof actual === 'object' ? JSON.stringify(actual) : actual;
  throw new Error(`expected: ${expected}; got: ${shown}`);
}
|
||
|
|
|
||
|
|
/**
 * Throws if `list1` and `list2` are element-wise identical: same `length` and `===` at every index.
 * Any indexable sequence with a `length` (plain arrays, typed arrays) can be passed.
 *
 * @param {ArrayLike<*>} list1
 * @param {ArrayLike<*>} list2
 * @throws {Error} when the two lists are equal.
 */
function assertTwoListsUnequal(list1, list2) {
  let differs = list1.length !== list2.length;
  for (let idx = 0; !differs && idx < list1.length; idx++) {
    differs = list1[idx] !== list2[idx];
  }
  if (!differs) {
    throw new Error(`expected ${list1} and ${list2} to be unequal; got two equal lists`);
  }
}
|
||
|
|
|
||
|
|
// HELPER METHODS FOR TESTS
|
||
|
|
|
||
|
|
/**
 * Draws one sample from a normal distribution with the given mean and scale (standard deviation).
 *
 * @param {number} [mean=0]
 * @param {number} [scale=1]
 * @returns {number}
 */
function generateGaussianRandom(mean = 0, scale = 1) {
  // Box–Muller transform: two uniform samples -> one standard-normal sample.
  // 1 - Math.random() lies in (0, 1], which keeps Math.log away from log(0).
  const radius = Math.sqrt(-2.0 * Math.log(1 - Math.random()));
  const angle = 2.0 * Math.PI * Math.random();
  const standardNormal = radius * Math.cos(angle);
  return standardNormal * scale + mean;
}
|
||
|
|
|
||
|
|
/**
 * Builds a Float32Array of `length` samples, each drawn by one call to generateGaussianRandom()
 * with its default parameters.
 *
 * @param {number} length - number of samples.
 * @returns {Float32Array}
 */
function generateGaussianFloatArray(length) {
  const samples = new Float32Array(length);
  let idx = 0;
  while (idx < length) {
    samples[idx++] = generateGaussianRandom();
  }
  return samples;
}
|
||
|
|
|
||
|
|
/**
 * Creates a TrainingSession and verifies the input/output names of the training model it loaded.
 *
 * @param {*} ort - the onnxruntime module under test; `ort.TrainingSession.create` is called.
 * @param {*} createOptions - artifacts passed as the first argument of `TrainingSession.create`
 *     (the tests pass `trainingSessionMinOptions` or `trainingSessionAllOptions`).
 * @param {*} options - forwarded unchanged as the second argument of `TrainingSession.create`.
 * @returns {Promise<*>} the created training session.
 * @throws {Error} if the training inputs are not exactly ['input-0', 'labels'] or the training outputs
 *     are not exactly ['onnx::loss::21273'].
 */
async function createTrainingSessionAndCheckTrainingModel(ort, createOptions, options) {
  const session = await ort.TrainingSession.create(createOptions, options);

  const expectedInputNames = ['input-0', 'labels'];
  expectedInputNames.forEach((name, idx) => {
    assertStrictEquals(session.trainingInputNames[idx], name);
  });
  assertStrictEquals(session.trainingInputNames.length, expectedInputNames.length);

  assertStrictEquals(session.trainingOutputNames[0], 'onnx::loss::21273');
  assertStrictEquals(session.trainingOutputNames.length, 1);

  return session;
}
|
||
|
|
|
||
|
|
/**
 * Verifies the eval-model input/output names exposed by `trainingSession`.
 *
 * @param {*} trainingSession - a TrainingSession created with an eval model.
 * @throws {Error} if the eval inputs are not exactly ['input-0', 'labels'] or the eval outputs are not
 *     exactly ['onnx::loss::21273'].
 */
function checkEvalModel(trainingSession) {
  const expectations = [
    ['evalInputNames', ['input-0', 'labels']],
    ['evalOutputNames', ['onnx::loss::21273']],
  ];
  for (const [property, expectedNames] of expectations) {
    expectedNames.forEach((name, idx) => assertStrictEquals(trainingSession[property][idx], name));
    assertStrictEquals(trainingSession[property].length, expectedNames.length);
  }
}
|
||
|
|
|
||
|
|
/**
 * Checks that reading `evalInputNames` and `evalOutputNames` on a session created without an eval model
 * throws an Error whose message is 'This training session has no evalModel loaded.'.
 *
 * @param {*} trainingSession - a TrainingSession created without `evalModel`.
 * @throws {Error} if either property can be read without throwing, or if it throws with a different
 *     message.
 */
function checkNoEvalModel(trainingSession) {
  const expectedMessage = 'This training session has no evalModel loaded.';
  for (const property of ['evalInputNames', 'evalOutputNames']) {
    // Record the outcome explicitly. A getter that returns normally must be reported as such, rather than
    // being mistaken for (or masked by) a mismatched error message.
    let didThrow = false;
    let caught;
    try {
      void trainingSession[property];
    } catch (error) {
      didThrow = true;
      caught = error;
    }
    if (!didThrow) {
      throw new Error(`expected reading trainingSession.${property} to throw; it returned normally`);
    }
    assertStrictEquals(caught.message, expectedMessage);
  }
}
|
||
|
|
|
||
|
|
/**
 * Runs one training step and checks the result holds exactly one output, the loss
 * 'onnx::loss::21273', as a float32 tensor with a single element.
 *
 * @param {*} trainingSession - the TrainingSession to step.
 * @param {Object<string, *>} feeds - input-name -> tensor map passed to `runTrainStep`.
 * @returns {Promise<Object<string, *>>} the result map returned by `runTrainStep`.
 * @throws {Error} if the result does not have that shape.
 */
async function runTrainStepAndCheck(trainingSession, feeds) {
  const results = await trainingSession.runTrainStep(feeds);
  assertStrictEquals(Object.keys(results).length, 1);

  const loss = results['onnx::loss::21273'];
  assertStrictEquals(loss.data.length, 1);
  assertStrictEquals(loss.type, 'float32');

  return results;
}
|
||
|
|
|
||
|
|
/**
 * Overwrites the session's parameters with `paramsLength` float32 copies of `constant` via
 * `loadParametersBuffer`. It then reads them back with `getContiguousParameters` and checks that:
 * the values differ from `paramsBefore.data`, `dims[0]` equals `paramsLength`, and every value equals
 * `constant` to 4 decimal places.
 *
 * @param {*} trainingSession - the TrainingSession whose parameters are replaced.
 * @param {number} paramsLength - number of float32 parameters to write.
 * @param {number} constant - value every parameter is set to.
 * @param {*} paramsBefore - result of an earlier `getContiguousParameters()` call; the reloaded
 *     parameters must differ from its `data`.
 * @returns {Promise<*>} the value returned by `getContiguousParameters()` after the load.
 */
async function loadParametersBufferAndCheck(trainingSession, paramsLength, constant, paramsBefore) {
  // One float32 per parameter, all equal to `constant`, handed over as raw bytes.
  const filled = new Float32Array(paramsLength).fill(constant);
  const bytes = new Uint8Array(filled.buffer, filled.byteOffset, filled.byteLength);

  await trainingSession.loadParametersBuffer(bytes);
  const reloaded = await trainingSession.getContiguousParameters();

  assertTwoListsUnequal(reloaded.data, paramsBefore.data);
  assertStrictEquals(reloaded.dims[0], paramsLength);

  // The values were written as float32, so compare at 4 decimal places rather than exactly.
  for (let idx = 0; idx < paramsLength; idx += 1) {
    assertStrictEquals(reloaded.data[idx].toFixed(4), constant.toFixed(4));
  }

  return reloaded;
}
|
||
|
|
|
||
|
|
// TESTS
|
||
|
|
|
||
|
|
/**
 * Smoke test for InferenceSession. It loads `model.onnx` from DATA_FOLDER (with `options`, or `{}` when
 * none are given) and feeds a 3x4 float32 tensor `a` and a 4x3 float32 tensor `b`. It then checks that
 * output `c` is a 3x3 Tensor holding the expected values, which equal the matrix product of the inputs.
 *
 * @param {*} ort - the onnxruntime module under test.
 * @param {*} [options] - session options forwarded to `InferenceSession.create`.
 * @throws {Error} naming the first mismatching dimension or element.
 */
async function testInferenceFunction(ort, options) {
  // Use DATA_FOLDER like the other test assets instead of a hard-coded 'data/' prefix.
  const session = await ort.InferenceSession.create(`${DATA_FOLDER}model.onnx`, options || {});

  const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
  const dataB = Float32Array.from([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]);
  const fetches = await session.run({
    a: new ort.Tensor('float32', dataA, [3, 4]),
    b: new ort.Tensor('float32', dataB, [4, 3]),
  });

  const c = fetches.c;
  assert(c instanceof ort.Tensor);
  assertStrictEquals(c.dims.length, 2);
  assertStrictEquals(c.dims[0], 3);
  assertStrictEquals(c.dims[1], 3);

  // Row-major 3x3 expectation. assertStrictEquals reports the offending value on mismatch, where the
  // bare assert() it replaces gave no message.
  const expected = [700, 800, 900, 1580, 1840, 2100, 2460, 2880, 3300];
  expected.forEach((value, idx) => assertStrictEquals(c.data[idx], value));
}
|
||
|
|
|
||
|
|
/**
 * End-to-end check of a TrainingSession created from `trainingSessionMinOptions` (checkpoint + training
 * model only). The steps are:
 * - verify the training model's names and that no eval model is exposed;
 * - check the parameter count and two known initial parameter values;
 * - run one training step on a random batch;
 * - reload every parameter from a constant buffer.
 *
 * @param {*} ort - the onnxruntime module under test.
 * @param {*} options - session options forwarded to `TrainingSession.create`.
 */
async function testTrainingFunctionMin(ort, options) {
  const session = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionMinOptions, options);
  checkNoEvalModel(session);

  // Batch of two 784-float samples with their int32 labels.
  const feeds = {
    'input-0': new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]),
    'labels': new ort.Tensor('int32', [2, 1], [2]),
  };

  const expectedParamCount = 397510;
  assertStrictEquals(await session.getParametersSize(), expectedParamCount);

  const initialParams = await session.getContiguousParameters();
  assertStrictEquals(initialParams.dims.length, 1);
  assertStrictEquals(initialParams.dims[0], expectedParamCount);
  assertStrictEquals(initialParams.data[0], -0.025190064683556557);
  assertStrictEquals(initialParams.data[2000], -0.034044936299324036);

  await runTrainStepAndCheck(session, feeds);

  await loadParametersBufferAndCheck(session, expectedParamCount, -1.2, initialParams);
}
|
||
|
|
|
||
|
|
/**
 * End-to-end check of a TrainingSession created from `trainingSessionAllOptions` (checkpoint, training,
 * eval and optimizer models). The steps are:
 * - verify the training/eval model names, the parameter count and two known initial parameter values;
 * - run a training step, an optimizer step and a second training step;
 * - check that the parameters changed and the loss decreased;
 * - reload every parameter from a constant buffer.
 *
 * @param {*} ort - the onnxruntime module under test.
 * @param {*} options - session options forwarded to `TrainingSession.create`.
 * @throws {Error} if any check fails, including when the loss does not decrease.
 */
async function testTrainingFunctionAll(ort, options) {
  const session = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionAllOptions, options);
  checkEvalModel(session);

  // Batch of two 784-float samples with their int32 labels; reused for both training steps.
  const feeds = {
    'input-0': new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]),
    'labels': new ort.Tensor('int32', [2, 1], [2]),
  };

  const expectedParamCount = 397510;
  assertStrictEquals(await session.getParametersSize(), expectedParamCount);

  const originalParams = await session.getContiguousParameters();
  assertStrictEquals(originalParams.dims.length, 1);
  assertStrictEquals(originalParams.dims[0], expectedParamCount);
  assertStrictEquals(originalParams.data[0], -0.025190064683556557);
  assertStrictEquals(originalParams.data[2000], -0.034044936299324036);

  const results = await runTrainStepAndCheck(session, feeds);

  // runOptimizerStep takes optional run options, not input feeds (per the TrainingSession API), so
  // nothing is passed here.
  await session.runOptimizerStep();

  // The optimizer step must have updated the parameters.
  const optimizedParams = await session.getContiguousParameters();
  assertTwoListsUnequal(originalParams.data, optimizedParams.data);

  const results2 = await runTrainStepAndCheck(session, feeds);

  // Compare the scalar losses. Applying `<` to the tensors' `data` arrays would coerce both to strings and
  // compare them lexicographically (e.g. "9.5" < "10.5" is false).
  const lossBefore = results['onnx::loss::21273'].data[0];
  const lossAfter = results2['onnx::loss::21273'].data[0];
  if (!(lossAfter < lossBefore)) {
    throw new Error(`expected loss to decrease after an optimizer step; before: ${lossBefore}; after: ${lossAfter}`);
  }

  await loadParametersBufferAndCheck(session, expectedParamCount, -1.2, optimizedParams);
}
|
||
|
|
|
||
|
|
// When a CommonJS `module` object is present, export the test entry points as an array.
// `testTest` was dropped from this list. It is not defined in this file, so building the array threw a
// ReferenceError.
if (typeof module === 'object') {
  module.exports = [testInferenceFunction, testTrainingFunctionMin, testTrainingFunctionAll];
}
|