[js/web] a few optimizations for test runner (#17174)

### Description
1. allows passing session options to operator test (eg. graph
optimization level)
2. add a short flag '-x' for '--wasm-number-threads' as it is frequently
used.
This commit is contained in:
Yulong Wang 2023-08-15 21:00:23 -07:00 committed by GitHub
parent 2575b9aaa1
commit 35363dd9a5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 6 deletions

View file

@ -66,7 +66,7 @@ Options:
*** Backend Options ***
--wasm-number-threads Set the WebAssembly number of threads
-x, --wasm-number-threads Set the WebAssembly number of threads
--wasm-init-timeout Set the timeout for WebAssembly backend initialization, in milliseconds
--wasm-enable-simd Set whether to enable SIMD
--wasm-enable-proxy Set whether to enable proxy worker
@ -264,9 +264,9 @@ function parseWasmOptions(_args: minimist.ParsedArgs): InferenceSession.WebAssem
}
function parseWasmFlags(args: minimist.ParsedArgs): Env.WebAssemblyFlags {
const numThreads = args['wasm-number-threads'];
const numThreads = args.x || args['wasm-number-threads'];
if (typeof numThreads !== 'undefined' && typeof numThreads !== 'number') {
throw new Error('Flag "wasm-number-threads" must be a number value');
throw new Error('Flag "x"/"wasm-number-threads" must be a number value');
}
const initTimeout = args['wasm-init-timeout'];
if (typeof initTimeout !== 'undefined' && typeof initTimeout !== 'number') {

View file

@ -137,7 +137,8 @@ for (const group of ORT_WEB_TEST_CONFIG.op) {
let context: ProtoOpTestContext|OpTestContext;
before('Initialize Context', async () => {
context = useProtoOpTest ? new ProtoOpTestContext(test) : new OpTestContext(test);
context = useProtoOpTest ? new ProtoOpTestContext(test, ORT_WEB_TEST_CONFIG.options.sessionOptions) :
new OpTestContext(test);
await context.init();
if (ORT_WEB_TEST_CONFIG.profile) {
if (context instanceof ProtoOpTestContext) {

View file

@ -574,7 +574,7 @@ export class ProtoOpTestContext {
private readonly loadedData: Uint8Array; // model data, inputs, outputs
session: ort.InferenceSession;
readonly backendHint: string;
constructor(test: Test.OperatorTest) {
constructor(test: Test.OperatorTest, private readonly sessionOptions: ort.InferenceSession.SessionOptions = {}) {
const opsetImport = onnx.OperatorSetIdProto.create(test.opset);
const operator = test.operator;
const attribute = (test.attributes || []).map(attr => {
@ -714,7 +714,8 @@ export class ProtoOpTestContext {
}
}
async init(): Promise<void> {
this.session = await ort.InferenceSession.create(this.loadedData, {executionProviders: [this.backendHint]});
this.session = await ort.InferenceSession.create(
this.loadedData, {executionProviders: [this.backendHint], ...this.sessionOptions});
}
async dispose(): Promise<void> {