Add "glue" between training WASM artifacts and training web (#17474)

### Description
* follows the packaging approach according to the design document
    * adds `ENABLE_TRAINING` boolean flag to `BUILD_DEFS`
    * modifies `package.json` to include training submodule
* modifies build script to handle, validate, and minimize training WASM
artifacts
* adds the binding for the new backend with training enabled & the new
training artifacts
    * adds training backend
    * edits `index.ts` to use training backend depending on `BUILD_DEFS`
    * edits `wasm-factory.ts` to use the training artifacts if necessary

### Motivation and Context
* we are in the process of adding web bindings to enable training. 
* Adding the "glue" to allow onnxruntime-web to use the training WASM
artifacts is required for this work.
* Since BUILD_DEFS is defined and used at build time, I thought that it
made sense to bundle the changes to building in the same PR.
#### Related work
* #16521 allowed for training artifacts to be built
* #17333 must be merged in before this one

---------

Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
This commit is contained in:
Caroline Zhu 2023-10-12 11:16:56 -07:00 committed by GitHub
parent 809c8905fe
commit c373a808a2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 77 additions and 12 deletions

View file

@ -0,0 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';
export const wasmBackend = new OnnxruntimeWebAssemblyBackend();

View file

@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common';
import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';
class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
async createTrainingSessionHandler(
_checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array,
_evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array,
_options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler> {
throw new Error('Method not implemented yet.');
}
}
export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend();

View file

@ -32,7 +32,7 @@ export const initializeFlags = (): void => {
}
};
class OnnxruntimeWebAssemblyBackend implements Backend {
export class OnnxruntimeWebAssemblyBackend implements Backend {
async init(): Promise<void> {
// populate wasm flags
initializeFlags();
@ -51,5 +51,3 @@ class OnnxruntimeWebAssemblyBackend implements Backend {
return Promise.resolve(handler);
}
}
export const wasmBackend = new OnnxruntimeWebAssemblyBackend();

View file

@ -30,6 +30,10 @@ interface BuildDefinitions {
* defines whether to disable multi-threading feature in WebAssembly backend in the build.
*/
readonly DISABLE_WASM_THREAD: boolean;
/**
* defines whether to disable training APIs in WebAssembly backend.
*/
readonly DISABLE_TRAINING: boolean;
}
declare const BUILD_DEFS: BuildDefinitions;

View file

@ -16,14 +16,17 @@ if (!BUILD_DEFS.DISABLE_WEBGL) {
}
if (!BUILD_DEFS.DISABLE_WASM) {
const wasmBackend = require('./backend-wasm').wasmBackend;
const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend :
require('./backend-wasm-training').wasmBackend;
if (!BUILD_DEFS.DISABLE_WEBGPU && typeof navigator !== 'undefined' && navigator.gpu) {
registerBackend('webgpu', wasmBackend, 5);
}
registerBackend('cpu', wasmBackend, 10);
registerBackend('wasm', wasmBackend, 10);
registerBackend('xnnpack', wasmBackend, 9);
registerBackend('webnn', wasmBackend, 9);
if (BUILD_DEFS.DISABLE_TRAINING) {
registerBackend('xnnpack', wasmBackend, 9);
registerBackend('webnn', wasmBackend, 9);
}
}
Object.defineProperty(env.versions, 'web', {value: version, enumerable: true});

View file

@ -8,8 +8,14 @@ import {OrtWasmModule} from './binding/ort-wasm';
import {OrtWasmThreadedModule} from './binding/ort-wasm-threaded';
/* eslint-disable @typescript-eslint/no-require-imports */
const ortWasmFactory: EmscriptenModuleFactory<OrtWasmModule> =
BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js');
let ortWasmFactory: EmscriptenModuleFactory<OrtWasmModule>;
if (!BUILD_DEFS.DISABLE_TRAINING) {
ortWasmFactory = require('./binding/ort-training-wasm-simd.js');
} else {
ortWasmFactory =
BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm.js') : require('./binding/ort-wasm-simd.jsep.js');
}
const ortWasmFactoryThreaded: EmscriptenModuleFactory<OrtWasmModule> = !BUILD_DEFS.DISABLE_WASM_THREAD ?
(BUILD_DEFS.DISABLE_WEBGPU ? require('./binding/ort-wasm-threaded.js') :
@ -72,10 +78,13 @@ const isSimdSupported = (): boolean => {
};
const getWasmFileName = (useSimd: boolean, useThreads: boolean) => {
if (useThreads) {
return useSimd ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-threaded.wasm';
if (useSimd) {
if (!BUILD_DEFS.DISABLE_TRAINING) {
return 'ort-training-wasm-simd.wasm';
}
return useThreads ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-simd.wasm';
} else {
return useSimd ? 'ort-wasm-simd.wasm' : 'ort-wasm.wasm';
return useThreads ? 'ort-wasm-threaded.wasm' : 'ort-wasm.wasm';
}
};

View file

@ -151,6 +151,20 @@
"development": "./dist/ort.webgpu.js",
"default": "./dist/ort.webgpu.min.js"
}
},
"./training": {
"import": {
"development": "./dist/esm/ort.training.wasm.js",
"default": "./dist/esm/ort.training.wasm.min.js"
},
"require": {
"development": "./dist/cjs/ort.training.wasm.js",
"default": "./dist/cjs/ort.training.wasm.min.js"
},
"default": {
"development": "./dist/ort.training.wasm.js",
"default": "./dist/ort.training.wasm.min.js"
}
}
},
"types": "./types.d.ts",

View file

@ -47,6 +47,7 @@ const DEFAULT_DEFINE = {
'BUILD_DEFS.DISABLE_WASM': 'false',
'BUILD_DEFS.DISABLE_WASM_PROXY': 'false',
'BUILD_DEFS.DISABLE_WASM_THREAD': 'false',
'BUILD_DEFS.DISABLE_TRAINING': 'true',
};
const COPYRIGHT_HEADER = `/*!
@ -407,7 +408,7 @@ async function main() {
});
// ort.wasm-core[.min].js
await addAllWebBuildTasks({
outputBundleName: 'ort.wasm-core.min',
outputBundleName: 'ort.wasm-core',
define: {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
@ -416,6 +417,16 @@ async function main() {
'BUILD_DEFS.DISABLE_WASM_THREAD': 'true',
},
});
// ort.training.wasm[.min].js
await addAllWebBuildTasks({
outputBundleName: 'ort.training.wasm',
define: {
...DEFAULT_DEFINE,
'BUILD_DEFS.DISABLE_TRAINING': 'false',
'BUILD_DEFS.DISABLE_WEBGPU': 'true',
'BUILD_DEFS.DISABLE_WEBGL': 'true',
},
});
}
if (BUNDLE_MODE === 'dev' || BUNDLE_MODE === 'perf') {

4
js/web/types.d.ts vendored
View file

@ -24,3 +24,7 @@ declare module 'onnxruntime-web/webgl' {
declare module 'onnxruntime-web/webgpu' {
export * from 'onnxruntime-web';
}
declare module 'onnxruntime-web/training' {
export * from 'onnxruntime-web';
}