[js/web] support override wasm file path (#8610)

This commit is contained in:
Yulong Wang 2021-08-05 18:01:03 -07:00 committed by GitHub
parent eab6c51413
commit f3a1aebb33
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 137 additions and 28 deletions

View file

@ -3,6 +3,12 @@
import {EnvImpl} from './env-impl';
export declare namespace Env {
export type WasmPrefixOrFilePaths = string|{
'ort-wasm.wasm'?: string;
'ort-wasm-threaded.wasm'?: string;
'ort-wasm-simd.wasm'?: string;
'ort-wasm-simd-threaded.wasm'?: string;
};
export interface WebAssemblyFlags {
/**
* set or get number of thread(s). If omitted or set to 0, number of thread(s) will be determined by system. If set
@ -24,6 +30,12 @@ export declare namespace Env {
* value indicates no timeout is set. (default is 0)
*/
initTimeout?: number;
/**
* Set a custom URL prefix to the .wasm files or a set of overrides for each .wasm file. The override path should be
* an absolute path.
*/
wasmPaths?: WasmPrefixOrFilePaths;
}
export interface WebGLFlags {

View file

@ -44,6 +44,14 @@ const isSimdSupported = (): boolean => {
}
};
const getWasmFileName = (useSimd: boolean, useThreads: boolean) => {
if (useThreads) {
return useSimd ? 'ort-wasm-simd-threaded.wasm' : 'ort-wasm-threaded.wasm';
} else {
return useSimd ? 'ort-wasm-simd.wasm' : 'ort-wasm.wasm';
}
};
export const initializeWebAssembly = async(): Promise<void> => {
if (initialized) {
return Promise.resolve();
@ -64,6 +72,13 @@ export const initializeWebAssembly = async(): Promise<void> => {
const useThreads = numThreads > 1 && isMultiThreadSupported();
const useSimd = simd && isSimdSupported();
const wasmPrefixOverride = typeof env.wasm.wasmPaths === 'string' ? env.wasm.wasmPaths : undefined;
const wasmFileName = getWasmFileName(false, useThreads);
const wasmOverrideFileName = getWasmFileName(useSimd, useThreads);
const wasmPathOverride =
typeof env.wasm.wasmPaths === 'object' ? env.wasm.wasmPaths[wasmOverrideFileName] : undefined;
let isTimeout = false;
const tasks: Array<Promise<void>> = [];
@ -81,38 +96,34 @@ export const initializeWebAssembly = async(): Promise<void> => {
// promise for module initialization
tasks.push(new Promise((resolve, reject) => {
const factory = useThreads ? ortWasmFactoryThreaded : ortWasmFactory;
const config: Partial<OrtWasmModule> = {};
if (!useThreads) {
config.locateFile = (fileName: string, scriptDirectory: string) => {
if (useSimd && fileName === 'ort-wasm.wasm') {
return scriptDirectory + 'ort-wasm-simd.wasm';
const config: Partial<OrtWasmModule> = {
locateFile: (fileName: string, scriptDirectory: string) => {
if (fileName.endsWith('.worker.js') && typeof Blob !== 'undefined') {
return URL.createObjectURL(new Blob(
[
// This require() function is handled by webpack to load file content of the corresponding .worker.js
// eslint-disable-next-line @typescript-eslint/no-require-imports
require('./binding/ort-wasm-threaded.worker.js')
],
{type: 'text/javascript'}));
}
if (fileName === wasmFileName) {
const prefix: string = wasmPrefixOverride ?? scriptDirectory;
return wasmPathOverride ?? prefix + wasmOverrideFileName;
}
return scriptDirectory + fileName;
};
} else {
}
};
if (useThreads) {
if (typeof Blob === 'undefined') {
config.mainScriptUrlOrBlob = path.join(__dirname, 'ort-wasm-threaded.js');
} else {
const scriptSourceCode =
`var ortWasmThreaded=(function(){var _scriptDir;return ${ortWasmFactoryThreaded.toString()}})();`;
config.mainScriptUrlOrBlob = new Blob([scriptSourceCode], {type: 'text/javascript'});
config.locateFile = (fileName: string, scriptDirectory: string) => {
if (fileName.endsWith('.worker.js')) {
return URL.createObjectURL(new Blob(
[
// This require() function is handled by webpack to load file content of the corresponding .worker.js
// eslint-disable-next-line @typescript-eslint/no-require-imports
require('./binding/ort-wasm-threaded.worker.js')
],
{type: 'text/javascript'}));
}
if (useSimd && fileName === 'ort-wasm-threaded.wasm') {
return scriptDirectory + 'ort-wasm-simd-threaded.wasm';
}
return scriptDirectory + fileName;
};
}
}

View file

@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
it('Browser E2E testing - WebAssembly backend (path override filename)', async function () {
// disable SIMD and multi-thread
ort.env.wasm.numThreads = 1;
ort.env.wasm.simd = false;
// override .wasm file path for 'ort-wasm.wasm'
ort.env.wasm.wasmPaths = {
'ort-wasm.wasm': new URL('./test-wasm-path-override/renamed.wasm', document.baseURI).href
};
await testFunction(ort, { executionProviders: ['wasm'] });
});

View file

@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
it('Browser E2E testing - WebAssembly backend (path override prefix)', async function () {
// disable SIMD and multi-thread
ort.env.wasm.numThreads = 1;
ort.env.wasm.simd = false;
// override .wasm file path prefix
ort.env.wasm.wasmPaths = new URL('./test-wasm-path-override/', document.baseURI).href;
await testFunction(ort, { executionProviders: ['wasm'] });
});

View file

@ -20,11 +20,13 @@ module.exports = function (config) {
{ pattern: distPrefix + 'ort.js' },
{ pattern: './common.js' },
{ pattern: TEST_MAIN },
{ pattern: './node_modules/onnxruntime-web/dist/**/*', included: false, nocache: true },
{ pattern: './node_modules/onnxruntime-web/dist/*.wasm', included: false, nocache: true },
{ pattern: './model.onnx', included: false }
],
proxies: {
'/model.onnx': '/base/model.onnx',
'/test-wasm-path-override/ort-wasm.wasm': '/base/node_modules/onnxruntime-web/dist/ort-wasm.wasm',
'/test-wasm-path-override/renamed.wasm': '/base/node_modules/onnxruntime-web/dist/ort-wasm.wasm',
},
client: { captureConsole: true, mocha: { expose: ['body'], timeout: 60000 } },
reporters: ['mocha'],

View file

@ -4,7 +4,7 @@
const ort = require('onnxruntime-web');
const testFunction = require('./common');
it('Browser E2E testing - WebAssembly backend', async function () {
it('Node.js E2E testing - WebAssembly backend (no threads)', async function () {
ort.env.wasm.numThreads = 1;
await testFunction(ort, { executionProviders: ['wasm'] });
});

View file

@ -4,7 +4,7 @@
const ort = require('onnxruntime-web');
const testFunction = require('./common');
it('Browser E2E testing - WebAssembly backend', async function () {
it('Node.js E2E testing - WebAssembly backend', async function () {
await testFunction(ort, { executionProviders: ['wasm'] });
process.exit();

View file

@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
const path = require('path');
const ort = require('onnxruntime-web');
const testFunction = require('./common');
it('Node.js E2E testing - WebAssembly backend (path override filename)', async function () {
// disable SIMD and multi-thread
ort.env.wasm.numThreads = 1;
ort.env.wasm.simd = false;
// override .wasm file path for 'ort-wasm.wasm'
ort.env.wasm.wasmPaths = {
'ort-wasm.wasm': path.join(__dirname, 'test-wasm-path-override/renamed.wasm')
};
await testFunction(ort, { executionProviders: ['wasm'] });
});

View file

@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
const path = require('path');
const ort = require('onnxruntime-web');
const testFunction = require('./common');
it('Node.js E2E testing - WebAssembly backend (path override prefix)', async function () {
// disable SIMD and multi-thread
ort.env.wasm.numThreads = 1;
ort.env.wasm.simd = false;
// override .wasm file path prefix
ort.env.wasm.wasmPaths = path.join(__dirname, 'test-wasm-path-override/');
await testFunction(ort, { executionProviders: ['wasm'] });
});

View file

@ -46,6 +46,9 @@ async function main() {
// npm install with "--cache" to install packed packages with an empty cache folder
await runInShell(`npm install --cache "${NPM_CACHE_FOLDER}" "${ORT_COMMON_PACKED_FILEPATH}" "${ORT_WEB_PACKED_FILEPATH}"`);
// prepare .wasm files for path override testing
prepareWasmPathOverrideFiles();
// test case run in Node.js
await testAllNodejsCases();
@ -60,17 +63,29 @@ async function main() {
process.exit(0);
}
function prepareWasmPathOverrideFiles() {
const folder = path.join(TEST_E2E_RUN_FOLDER, 'test-wasm-path-override');
const sourceFile = path.join(TEST_E2E_RUN_FOLDER, 'node_modules', 'onnxruntime-web', 'dist', 'ort-wasm.wasm');
fs.emptyDirSync(folder);
fs.copyFileSync(sourceFile, path.join(folder, 'ort-wasm.wasm'));
fs.copyFileSync(sourceFile, path.join(folder, 'renamed.wasm'));
}
async function testAllNodejsCases() {
await runInShell('node ./node_modules/mocha/bin/mocha ./node-test-main-no-threads.js');
await runInShell('node ./node_modules/mocha/bin/mocha ./node-test-main.js');
await runInShell('node --experimental-wasm-threads --experimental-wasm-bulk-memory ./node_modules/mocha/bin/mocha ./node-test-main-no-threads.js');
await runInShell('node --experimental-wasm-threads --experimental-wasm-bulk-memory ./node_modules/mocha/bin/mocha ./node-test-main.js');
await runInShell('node ./node_modules/mocha/bin/mocha ./node-test-wasm-path-override-filename.js');
await runInShell('node ./node_modules/mocha/bin/mocha ./node-test-wasm-path-override-prefix.js');
}
async function testAllBrowserCases({ hostInKarma }) {
await runKarma({ hostInKarma, main: './browser-test-webgl.js', browser: 'Chrome_default' });
await runKarma({ hostInKarma, main: './browser-test-wasm.js', browser: 'Chrome_default' });
await runKarma({ hostInKarma, main: './browser-test-wasm-no-threads.js', browser: 'Chrome_default' });
await runKarma({ hostInKarma, main: './browser-test-wasm-path-override-filename.js', browser: 'Chrome_default' });
await runKarma({ hostInKarma, main: './browser-test-wasm-path-override-prefix.js', browser: 'Chrome_default' });
}
async function runKarma({ hostInKarma, main, browser }) {

View file

@ -8,11 +8,15 @@ var http = require('http');
var fs = require('fs');
var path = require('path');
var simpleProxies = {
'./ort-wasm.wasm': './ort-wasm.wasm'
};
module.exports = function (dir) {
http.createServer(function (request, response) {
console.log('request ', request.url);
var filePath = '.' + request.url;
var filePath = '.' + (simpleProxies[request.url] ?? request.url);
var extname = String(path.extname(filePath)).toLowerCase();
var mimeTypes = {