From 35363dd9a5ebdd0d58ecdca6f4b18229f8ae5d99 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 15 Aug 2023 21:00:23 -0700 Subject: [PATCH] [js/web] a few optimizations for test runner (#17174) ### Description 1. allows passing session options to operator test (eg. graph optimization level) 2. add a short flag '-x' for '--wasm-number-threads' as it is frequently used. --- js/web/script/test-runner-cli-args.ts | 6 +++--- js/web/test/test-main.ts | 3 ++- js/web/test/test-runner.ts | 5 +++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index 3528c8d639..f2f44b795a 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -66,7 +66,7 @@ Options: *** Backend Options *** - --wasm-number-threads Set the WebAssembly number of threads + -x, --wasm-number-threads Set the WebAssembly number of threads --wasm-init-timeout Set the timeout for WebAssembly backend initialization, in milliseconds --wasm-enable-simd Set whether to enable SIMD --wasm-enable-proxy Set whether to enable proxy worker @@ -264,9 +264,9 @@ function parseWasmOptions(_args: minimist.ParsedArgs): InferenceSession.WebAssem } function parseWasmFlags(args: minimist.ParsedArgs): Env.WebAssemblyFlags { - const numThreads = args['wasm-number-threads']; + const numThreads = args.x || args['wasm-number-threads']; if (typeof numThreads !== 'undefined' && typeof numThreads !== 'number') { - throw new Error('Flag "wasm-number-threads" must be a number value'); + throw new Error('Flag "x"/"wasm-number-threads" must be a number value'); } const initTimeout = args['wasm-init-timeout']; if (typeof initTimeout !== 'undefined' && typeof initTimeout !== 'number') { diff --git a/js/web/test/test-main.ts b/js/web/test/test-main.ts index d19a4a7b0e..e614cc8e67 100644 --- a/js/web/test/test-main.ts +++ b/js/web/test/test-main.ts @@ -137,7 +137,8 @@ for (const group of ORT_WEB_TEST_CONFIG.op) { let context: ProtoOpTestContext|OpTestContext; before('Initialize Context', async () => { - context = useProtoOpTest ? new ProtoOpTestContext(test) : new OpTestContext(test); + context = useProtoOpTest ? new ProtoOpTestContext(test, ORT_WEB_TEST_CONFIG.options.sessionOptions) : + new OpTestContext(test); await context.init(); if (ORT_WEB_TEST_CONFIG.profile) { if (context instanceof ProtoOpTestContext) { diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 5552a8e299..916243e3d4 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -574,7 +574,7 @@ export class ProtoOpTestContext { private readonly loadedData: Uint8Array; // model data, inputs, outputs session: ort.InferenceSession; readonly backendHint: string; - constructor(test: Test.OperatorTest) { + constructor(test: Test.OperatorTest, private readonly sessionOptions: ort.InferenceSession.SessionOptions = {}) { const opsetImport = onnx.OperatorSetIdProto.create(test.opset); const operator = test.operator; const attribute = (test.attributes || []).map(attr => { @@ -714,7 +714,8 @@ export class ProtoOpTestContext { } } async init(): Promise { - this.session = await ort.InferenceSession.create(this.loadedData, {executionProviders: [this.backendHint]}); + this.session = await ort.InferenceSession.create( + this.loadedData, {executionProviders: [this.backendHint], ...this.sessionOptions}); } async dispose(): Promise {