mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Add "glue" between training WASM artifacts and training web (#17474)
### Description
* follows the packaging approach according to the design document
* adds `ENABLE_TRAINING` boolean flag to `BUILD_DEFS`
* modifies `package.json` to include training submodule
* modifies build script to handle, validate, and minimize training WASM
artifacts
* adds the binding for the new backend with training enabled & the new
training artifacts
* adds training backend
* edits `index.ts` to use training backend depending on `BUILD_DEFS`
* edits `wasm-factory.ts` to use the training artifacts if necessary
### Motivation and Context
* we are in the process of adding web bindings to enable training.
* Adding the "glue" to allow onnxruntime-web to use the training WASM
artifacts is required for this work.
* Since BUILD_DEFS is defined and used at build time, I thought that it
made sense to bundle the changes to building in the same PR.
#### Related work
* #16521 allowed for training artifacts to be built
* #17333 must be merged in before this one
---------
Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
This commit is contained in:
parent
809c8905fe
commit
c373a808a2
9 changed files with 77 additions and 12 deletions
5
js/web/lib/backend-wasm-inference.ts
Normal file
5
js/web/lib/backend-wasm-inference.ts
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';
|
||||
export const wasmBackend = new OnnxruntimeWebAssemblyBackend();
|
||||
17
js/web/lib/backend-wasm-training.ts
Normal file
17
js/web/lib/backend-wasm-training.ts
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common';
|
||||
|
||||
import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';
|
||||
|
||||
class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
|
||||
async createTrainingSessionHandler(
|
||||
_checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array,
|
||||
_evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array,
|
||||
_options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
|
||||
throw new Error('Method not implemented yet.');
|
||||
}
|
||||
}
|
||||
|
||||
export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend();
|
||||
|
|
@ -32,7 +32,7 @@ export const initializeFlags = (): void => {
|
|||
}
|
||||
};
|
||||
|
||||
class OnnxruntimeWebAssemblyBackend implements Backend {
|
||||
export class OnnxruntimeWebAssemblyBackend implements Backend {
|
||||
async init(): Promise<void> {
|
||||
// populate wasm flags
|
||||
initializeFlags();
|
||||
|
|
@ -51,5 +51,3 @@ class OnnxruntimeWebAssemblyBackend implements Backend {
|
|||
return Promise.resolve(handler);
|
||||
}
|
||||
}
|
||||
|
||||
export const wasmBackend = new OnnxruntimeWebAssemblyBackend();
|
||||
|
|
|
|||
4
js/web/lib/build-def.d.ts
vendored
4
js/web/lib/build-def.d.ts
vendored
|
|
@ -30,6 +30,10 @@ interface BuildDefinitions {
|
|||
* defines whether to disable multi-threading feature in WebAssembly backend in the build.
|
||||
*/
|
||||
readonly DISABLE_WASM_THREAD: boolean;
|
||||
/**
|
||||
* defines whether to disable training APIs in WebAssembly backend.
|
||||
*/
|
||||
readonly DISABLE_TRAINING: boolean;
|
||||
}
|
||||
|
||||
declare const BUILD_DEFS: BuildDefinitions;
|
||||
|
|
|
|||
|
|
@ -16,14 +16,17 @@ if (!BUILD_DEFS.DISABLE_WEBGL) {
|
|||
}
|
||||
|
||||
if (!BUILD_DEFS.DISABLE_WASM) {
|
||||
const wasmBackend = require('./backend-wasm').wasmBackend;
|
||||
const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend :
|
||||
require('./backend-wasm-training').wasmBackend;
|
||||
if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) {
|
||||
registerBackend('webgpu', wasmBackend, 5);
|
||||
}
|
||||
registerBackend('cpu', wasmBackend, 10);
|
||||
registerBackend('wasm', wasmBackend, 10);
|
||||
registerBackend('xnnpack', wasmBackend, 9);
|
||||
registerBackend('webnn', wasmBackend, 9);
|
||||
if (BUILD_DEFS.DISABLE_TRAINING) {
|
||||
registerBackend('xnnpack', wasmBackend, 9);
|
||||
registerBackend('webnn', wasmBackend, 9);
|
||||
}
|
||||
}
|
||||
|
||||
Object.defineProperty(env.versions, 'web', {value: version, enumerable: true});
|
||||
|
|
|
|||
|
|
@ -8,8 +8,14 @@ import {OrtWasmModule} from './binding/ort-wasm';
|
|||
import {OrtWasmThreadedModule} from './binding/ort-wasm-threaded';
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-require-imports */
|
||||
const ortWasmFactory: EmscriptenModuleFactory<OrtWasmModule> =
|
||||
BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js');
|
||||
let ortWasmFactory: EmscriptenModuleFactory<OrtWasmModule>;
|
||||
|
||||
if (!BUILD_DEFS.DISABLE_TRAINING) {
|
||||
ortWasmFactory = require('./binding/ort-training-wasm-simd.js');
|
||||
} else {
|
||||
ortWasmFactory =
|
||||
BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js');
|
||||
}
|
||||
|
||||
const ortWasmFactoryThreaded: EmscriptenModuleFactory<OrtWasmModule> = !BUILD_DEFS.DISABLE_WASM_THREAD ?
|
||||
(BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm-threaded.js') :
|
||||
|
|
@ -72,10 +78,13 @@ const isSimdSupported = (): boolean => {
|
|||
};
|
||||
|
||||
const getWasmFileName = (useSimd: boolean, useThreads: boolean) => {
|
||||
if (useThreads) {
|
||||
return useSimd ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-threaded.wasm';
|
||||
if (useSimd) {
|
||||
if (!BUILD_DEFS.DISABLE_TRAINING) {
|
||||
return 'ort-training-wasm-simd.wasm';
|
||||
}
|
||||
return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm';
|
||||
} else {
|
||||
return useSimd ? 'ort-wasm-simd.wasm' : 'ort-wasm.wasm';
|
||||
return useThreads ? 'ort-wasm-threaded.wasm' : 'ort-wasm.wasm';
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -151,6 +151,20 @@
|
|||
"development": "./dist/ort.webgpu.js",
|
||||
"default": "./dist/ort.webgpu.min.js"
|
||||
}
|
||||
},
|
||||
"./training": {
|
||||
"import": {
|
||||
"development": "./dist/esm/ort.training.wasm.js",
|
||||
"default": "./dist/esm/ort.training.wasm.min.js"
|
||||
},
|
||||
"require": {
|
||||
"development": "./dist/cjs/ort.training.wasm.js",
|
||||
"default": "./dist/cjs/ort.training.wasm.min.js"
|
||||
},
|
||||
"default": {
|
||||
"development": "./dist/ort.training.wasm.js",
|
||||
"default": "./dist/ort.training.wasm.min.js"
|
||||
}
|
||||
}
|
||||
},
|
||||
"types": "./types.d.ts",
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ const DEFAULT_DEFINE = {
|
|||
'BUILD_DEFS.DISABLE_WASM': 'false',
|
||||
'BUILD_DEFS.DISABLE_WASM_PROXY': 'false',
|
||||
'BUILD_DEFS.DISABLE_WASM_THREAD': 'false',
|
||||
'BUILD_DEFS.DISABLE_TRAINING': 'true',
|
||||
};
|
||||
|
||||
const COPYRIGHT_HEADER = `/*!
|
||||
|
|
@ -407,7 +408,7 @@ async function main() {
|
|||
});
|
||||
// ort.wasm-core[.min].js
|
||||
await addAllWebBuildTasks({
|
||||
outputBundleName: 'ort.wasm-core.min',
|
||||
outputBundleName: 'ort.wasm-core',
|
||||
define: {
|
||||
...DEFAULT_DEFINE,
|
||||
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
|
||||
|
|
@ -416,6 +417,16 @@ async function main() {
|
|||
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
|
||||
},
|
||||
});
|
||||
// ort.training.wasm[.min].js
|
||||
await addAllWebBuildTasks({
|
||||
outputBundleName: 'ort.training.wasm',
|
||||
define: {
|
||||
...DEFAULT_DEFINE,
|
||||
'BUILD_DEFS.DISABLE_TRAINING': 'false',
|
||||
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
|
||||
'BUILD_DEFS.DISABLE_WEBGL': 'true',
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
if (BUNDLE_MODE === 'dev' || BUNDLE_MODE === 'perf') {
|
||||
|
|
|
|||
4
js/web/types.d.ts
vendored
4
js/web/types.d.ts
vendored
|
|
@ -24,3 +24,7 @@ declare module 'onnxruntime-web/webgl' {
|
|||
declare module 'onnxruntime-web/webgpu' {
|
||||
export * from 'onnxruntime-web';
|
||||
}
|
||||
|
||||
declare module 'onnxruntime-web/training' {
|
||||
export * from 'onnxruntime-web';
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue