onnxruntime/js/web/lib/wasm/session-handler-training.ts
Yulong Wang abdc31de40
[js] change default formatter for JavaScript/TypeScript from clang-format to Prettier (#21728)
### Description

See
454996d496
for manual changes (excluded auto-generated formatting changes)

### Why

Because the tooling around the old clang-format setup is out of date, which
reduces development efficiency.

- The NPM package `clang-format` is already in maintenance mode; it has not
been updated for 2 years.
- The VSCode extension for clang-format has not been maintained for a while,
and a recent Node.js security update made it stop working entirely on
Windows.

No one in the community seems interested in fixing those.

Choose Prettier as it is the most popular TS/JS formatter.

### How to merge

It's easy to break the build:
- Be careful of any new commits on main not included in this PR.
- Be careful that after this PR is merged, other PRs that already passed
CI can merge.

So, make sure there are no new commits before merging this one, and
invalidate JS PRs that already passed CI, forcing them to merge with the latest main.
2024-08-14 16:51:22 -07:00

198 lines
7.5 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import { InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler } from 'onnxruntime-common';
import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages';
import { decodeTensorMetadata, encodeTensorMetadata } from './session-handler-inference';
import { copyFromExternalBuffer } from './wasm-core-impl';
import {
createCheckpointHandle,
createTrainingSessionHandle,
getContiguousParameters,
getModelInputOutputNames,
getParametersSize,
lazyResetGrad,
loadParametersBuffer,
releaseTrainingSessionAndCheckpoint,
runEvalStep,
runOptimizerStep,
runTrainStep,
} from './wasm-training-core-impl';
/**
 * WebAssembly-backed implementation of `TrainingSessionHandler`.
 *
 * `createTrainingSession()` creates a checkpoint handle and a training session handle and stores their ids;
 * `dispose()` releases both. The remaining methods forward to the functions imported from
 * `./wasm-training-core-impl`, converting between name-keyed tensor maps and the index/metadata arrays those
 * functions take.
 */
export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
  private sessionId: number;
  private checkpointId: number;

  inputNames: string[];
  outputNames: string[];

  evalInputNames: string[] = [];
  evalOutputNames: string[] = [];

  /**
   * Obtains the bytes of a model or checkpoint and hands them to `copyFromExternalBuffer`.
   *
   * @param uriOrBuffer a URI that is downloaded with `fetch`, or the bytes themselves
   * @returns whatever `copyFromExternalBuffer` returns for those bytes
   * @throws Error when a URI is fetched and the response status is not OK
   */
  async uriOrBufferToHeap(uriOrBuffer: string | Uint8Array): Promise<SerializableInternalBuffer> {
    let buffer: Uint8Array;
    if (typeof uriOrBuffer === 'string') {
      const response = await fetch(uriOrBuffer);
      if (!response.ok) {
        // Without this check the body of an HTTP error page would be passed on as if it were model data.
        throw new Error(`failed to load '${uriOrBuffer}': ${response.status} ${response.statusText}`);
      }
      buffer = new Uint8Array(await response.arrayBuffer());
    } else {
      buffer = uriOrBuffer;
    }
    return copyFromExternalBuffer(buffer);
  }

  /**
   * Loads the checkpoint and the models, creates the checkpoint and training session handles, and records the
   * input/output names of the train model (and of the eval model, when one is given).
   *
   * @param checkpointStateUriOrBuffer URI or bytes of the checkpoint state
   * @param trainModelUriOrBuffer URI or bytes of the train model
   * @param evalModelUriOrBuffer URI or bytes of the eval model; the empty string means "no eval model"
   * @param optimizerModelUriOrBuffer URI or bytes of the optimizer model; the empty string means "no optimizer model"
   * @param options session options forwarded to `createTrainingSessionHandle`
   */
  async createTrainingSession(
    checkpointStateUriOrBuffer: string | Uint8Array,
    trainModelUriOrBuffer: string | Uint8Array,
    evalModelUriOrBuffer: string | Uint8Array,
    optimizerModelUriOrBuffer: string | Uint8Array,
    options: InferenceSession.SessionOptions,
  ): Promise<void> {
    const checkpointData: SerializableInternalBuffer = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer);
    const trainModelData: SerializableInternalBuffer = await this.uriOrBufferToHeap(trainModelUriOrBuffer);
    // 0 is supposed to be the nullptr
    let evalModelData: SerializableInternalBuffer = [0, 0];
    let optimizerModelData: SerializableInternalBuffer = [0, 0];

    if (evalModelUriOrBuffer !== '') {
      evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer);
    }
    if (optimizerModelUriOrBuffer !== '') {
      optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer);
    }

    this.checkpointId = createCheckpointHandle(checkpointData);
    this.sessionId = createTrainingSessionHandle(
      this.checkpointId,
      trainModelData,
      evalModelData,
      optimizerModelData,
      options,
    );
    [this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false);
    if (evalModelUriOrBuffer !== '') {
      [this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true);
    }
  }

  /**
   * Helper method that converts a feeds or fetches datatype to parallel arrays: the values, the index of each
   * value's name in the list of names provided, and the values after being passed through `mapFunc`.
   *
   * All names are validated before `mapFunc` is called for the first time. The arrays follow the
   * `Object.entries()` order of `feeds`.
   *
   * @param feeds meant to match either SessionHandler.FeedsType or SessionHandler.FetchesType
   * @param names either inputNames or outputNames
   * @param mapFunc called with each value and its position in the returned values array
   * @returns a tuple of a list of values, a list of indices and a list of mapped values.
   * @throws Error if a key of `feeds` does not occur in `names`
   */
  convertMapIntoValuesArrayAndIndicesArray<T, U>(
    feeds: { [name: string]: T },
    names: string[],
    mapFunc: (val: T, index: number) => U,
  ): [T[], number[], U[]] {
    const values: T[] = [];
    const indices: number[] = [];
    for (const [name, value] of Object.entries(feeds)) {
      const index = names.indexOf(name);
      if (index === -1) {
        throw new Error(`invalid input '${name}'`);
      }
      values.push(value);
      indices.push(index);
    }

    const uList = values.map(mapFunc);
    return [values, indices, uList];
  }

  /**
   * Helper method that converts the TensorMetadata that the wasm-core functions return to the
   * SessionHandler.ReturnType. Any entry of the provided outputArray that is null or undefined is replaced by the
   * decoded corresponding result.
   *
   * @param results used to populate the resultMap if there is no value for that output already
   * @param outputArray used to populate the resultMap. If null or undefined, use the corresponding result from results
   * @param outputIndices specifies which output name the corresponding value for outputArray refers to.
   * @param outputNames the list of names that `outputIndices` indexes into. Defaults to the train model's
   *     `outputNames`; pass `evalOutputNames` when the indices were computed against the eval model.
   * @returns a map of output names and OnnxValues.
   */
  convertTensorMetadataToReturnType(
    results: TensorMetadata[],
    outputArray: Array<Tensor | null>,
    outputIndices: number[],
    outputNames: string[] = this.outputNames,
  ): SessionHandler.ReturnType {
    const resultMap: SessionHandler.ReturnType = {};
    for (let i = 0; i < results.length; i++) {
      resultMap[outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]);
    }
    return resultMap;
  }

  /** Forwards to `lazyResetGrad` for this session. */
  async lazyResetGrad(): Promise<void> {
    await lazyResetGrad(this.sessionId);
  }

  /**
   * Encodes the feeds and fetches against the train model's names, calls `runTrainStep` and converts the results
   * into a name-keyed map.
   *
   * @param feeds tensors keyed by names from `inputNames`
   * @param fetches pre-allocated tensors (or null) keyed by names from `outputNames`
   * @param options run options forwarded to `runTrainStep`
   * @returns a map from output name to output value
   * @throws Error if a feed or fetch name is not known to the train model
   */
  async runTrainStep(
    feeds: SessionHandler.FeedsType,
    fetches: SessionHandler.FetchesType,
    options: InferenceSession.RunOptions,
  ): Promise<SessionHandler.ReturnType> {
    // The name callbacks must not read `inputIndices`/`outputIndices`: those bindings are still uninitialized
    // while the mapping runs, so a callback invoked during the mapping would throw a ReferenceError instead of
    // producing the name. The i-th mapped value always belongs to the i-th key, so the keys are used directly.
    const feedNames = Object.keys(feeds);
    const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray<Tensor, TensorMetadata>(
      feeds,
      this.inputNames,
      (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${feedNames[i]}"`),
    );

    const fetchNames = Object.keys(fetches);
    const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray<
      Tensor | null,
      TensorMetadata | null
    >(fetches, this.outputNames, (t, i): TensorMetadata | null =>
      t ? encodeTensorMetadata(t, () => `output "${fetchNames[i]}"`) : null,
    );

    const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
    return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices, this.outputNames);
  }

  /** Forwards to `runOptimizerStep` for this session with the given run options. */
  async runOptimizerStep(options: InferenceSession.RunOptions): Promise<void> {
    await runOptimizerStep(this.sessionId, options);
  }

  /**
   * Encodes the feeds and fetches against the eval model's names, calls `runEvalStep` and converts the results
   * into a map keyed by the eval model's output names.
   *
   * @param feeds tensors keyed by names from `evalInputNames`
   * @param fetches pre-allocated tensors (or null) keyed by names from `evalOutputNames`
   * @param options run options forwarded to `runEvalStep`
   * @returns a map from eval output name to output value
   * @throws Error if a feed or fetch name is not known to the eval model
   */
  async runEvalStep(
    feeds: SessionHandler.FeedsType,
    fetches: SessionHandler.FetchesType,
    options: InferenceSession.RunOptions,
  ): Promise<SessionHandler.ReturnType> {
    // See runTrainStep for why the names come from the keys rather than from the index arrays.
    const feedNames = Object.keys(feeds);
    const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray<Tensor, TensorMetadata>(
      feeds,
      this.evalInputNames,
      (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${feedNames[i]}"`),
    );

    const fetchNames = Object.keys(fetches);
    const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray<
      Tensor | null,
      TensorMetadata | null
    >(fetches, this.evalOutputNames, (t, i): TensorMetadata | null =>
      t ? encodeTensorMetadata(t, () => `output "${fetchNames[i]}"`) : null,
    );

    const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
    // outputIndices index into evalOutputNames here, so the results must be keyed by that list.
    return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices, this.evalOutputNames);
  }

  /** Forwards to `getParametersSize` for this session. */
  async getParametersSize(trainableOnly: boolean): Promise<number> {
    return getParametersSize(this.sessionId, trainableOnly);
  }

  /** Forwards to `loadParametersBuffer` for this session. */
  async loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void> {
    await loadParametersBuffer(this.sessionId, array, trainableOnly);
  }

  /** Calls `getContiguousParameters` for this session and decodes the returned tensor metadata. */
  async getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue> {
    const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly);
    return decodeTensorMetadata(tensorResult);
  }

  /** Releases the training session and the checkpoint created by `createTrainingSession`. */
  async dispose(): Promise<void> {
    return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId);
  }
}