### Description
Enable external data loading for ort-web.
### Why
The ORT external data design depends heavily on the file system, especially on synchronous file I/O APIs, which are not available on the web platform. Extra code is needed to make external data work on the web.
### How
Since there is no file system on the web, external data is supported by pre-loading the data. Assume model file
a.onnx includes initializers linked to ./b.bin; we require users to
pass a full list of data files when creating the session. The user code
looks like:
```js
const mySess = await ort.InferenceSession.create('./path/model/a.onnx', {
// session options
externalData: [
{
// relative or absolute path/URL of the file,
// or a pre-loaded Uint8Array containing the data of the external data file
data: './path/data/b.bin',
// the relative path of the external data. Should match initializers' "location" value defined in the model file
path: './b.bin'
},
// { ... } add one entry per additional external data file
]
});
```
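
If the external data is already available in memory (for example, fetched ahead of time), the `data` field can be a pre-loaded `Uint8Array` instead of a path or URL. A minimal sketch, assuming the same `a.onnx`/`b.bin` layout as above; the helper function name is illustrative:

```ts
// Sketch: pre-fetch the external data file ourselves, then hand the bytes to ORT.
async function createSessionWithPreloadedData(): Promise<ort.InferenceSession> {
  const response = await fetch('./path/data/b.bin');
  const buffer = new Uint8Array(await response.arrayBuffer());

  return ort.InferenceSession.create('./path/model/a.onnx', {
    externalData: [{
      data: buffer,     // pre-loaded Uint8Array instead of a path/URL
      path: './b.bin'   // must still match the initializers' "location" value
    }]
  });
}
```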
Currently, this feature only works with the JSEP build enabled.
126 lines · 4.4 KiB · TypeScript
```ts
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';

import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages';
import {copyFromExternalBuffer, createSession, endProfiling, releaseSession, run} from './proxy-wrapper';
import {isGpuBufferSupportedType} from './wasm-common';
import {loadFile} from './wasm-utils-load-file';

// Encode a Tensor into the [type, dims, data, location] tuple used for proxy messages.
export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
  switch (tensor.location) {
    case 'cpu':
      return [tensor.type, tensor.dims, tensor.data, 'cpu'];
    case 'gpu-buffer':
      return [tensor.type, tensor.dims, {gpuBuffer: tensor.gpuBuffer}, 'gpu-buffer'];
    default:
      throw new Error(`invalid data location: ${tensor.location} for ${getName()}`);
  }
};

// Decode a [type, dims, data, location] tuple back into a Tensor.
export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => {
  switch (tensor[3]) {
    case 'cpu':
      return new Tensor(tensor[0], tensor[2], tensor[1]);
    case 'gpu-buffer': {
      const dataType = tensor[0];
      if (!isGpuBufferSupportedType(dataType)) {
        throw new Error(`not supported data type: ${dataType} for deserializing GPU tensor`);
      }
      const {gpuBuffer, download, dispose} = tensor[2];
      return Tensor.fromGpuBuffer(gpuBuffer, {dataType, dims: tensor[1], download, dispose});
    }
    default:
      throw new Error(`invalid data location: ${tensor[3]}`);
  }
};

export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHandler {
  private sessionId: number;

  inputNames: string[];
  outputNames: string[];

  async fetchModelAndCopyToWasmMemory(path: string): Promise<SerializableInternalBuffer> {
    // fetch model from url and move to wasm heap.
    return copyFromExternalBuffer(await loadFile(path));
  }

  async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
    TRACE_FUNC_BEGIN();
    let model: Parameters<typeof createSession>[0];

    if (typeof pathOrBuffer === 'string') {
      if (typeof process !== 'undefined' && process.versions && process.versions.node) {
        // node: read the model file directly from the file system.
        model = await loadFile(pathOrBuffer);
      } else {
        // browser: fetch model and copy to wasm heap.
        model = await this.fetchModelAndCopyToWasmMemory(pathOrBuffer);
      }
    } else {
      model = pathOrBuffer;
    }

    [this.sessionId, this.inputNames, this.outputNames] = await createSession(model, options);
    TRACE_FUNC_END();
  }

  async dispose(): Promise<void> {
    return releaseSession(this.sessionId);
  }

  async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions):
      Promise<SessionHandler.ReturnType> {
    TRACE_FUNC_BEGIN();
    // Map input names to indices; reject names the model does not declare.
    const inputArray: Tensor[] = [];
    const inputIndices: number[] = [];
    Object.entries(feeds).forEach(kvp => {
      const name = kvp[0];
      const tensor = kvp[1];
      const index = this.inputNames.indexOf(name);
      if (index === -1) {
        throw new Error(`invalid input '${name}'`);
      }
      inputArray.push(tensor);
      inputIndices.push(index);
    });

    // Map output names to indices; a null entry means the output is not pre-allocated.
    const outputArray: Array<Tensor|null> = [];
    const outputIndices: number[] = [];
    Object.entries(fetches).forEach(kvp => {
      const name = kvp[0];
      const tensor = kvp[1];
      const index = this.outputNames.indexOf(name);
      if (index === -1) {
        throw new Error(`invalid output '${name}'`);
      }
      outputArray.push(tensor);
      outputIndices.push(index);
    });

    const inputs =
        inputArray.map((t, i) => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`));
    const outputs = outputArray.map(
        (t, i) => t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null);

    const results = await run(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);

    // Prefer the caller's pre-allocated output tensor; otherwise decode the returned metadata.
    const resultMap: SessionHandler.ReturnType = {};
    for (let i = 0; i < results.length; i++) {
      resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]);
    }
    TRACE_FUNC_END();
    return resultMap;
  }

  startProfiling(): void {
    // TODO: implement profiling
  }

  endProfiling(): void {
    void endProfiling(this.sessionId);
  }
}
```
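
For illustration, a minimal sketch (not part of the file above) of how the two helpers round-trip a CPU tensor; the relative import path is an assumption:

```ts
import {Tensor} from 'onnxruntime-common';
import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference'; // hypothetical path

// Encode a CPU tensor into the [type, dims, data, location] tuple used for
// proxy messages, then decode it back into an equivalent Tensor.
const input = new Tensor('float32', new Float32Array([1, 2, 3, 4]), [2, 2]);
const encoded = encodeTensorMetadata(input, () => 'input "x"');
const roundTripped = decodeTensorMetadata(encoded);

console.log(roundTripped.type);  // 'float32'
console.log(roundTripped.dims);  // [2, 2]
```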