diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 4f7fbdcdcf..14db5c59d9 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -243,7 +243,7 @@ export declare namespace InferenceSession { } export interface WebNNExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webnn'; - deviceType?: 'cpu'|'gpu'; + deviceType?: 'cpu'|'gpu'|'npu'; numThreads?: number; powerPreference?: 'default'|'low-power'|'high-performance'; } diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index 745f504b04..adcd940178 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -78,7 +78,7 @@ Options: --webgpu.<...>=<...> --webgpu.profiling.mode=default --wasm.numThreads=1 --wasm.simd=false --webnn.<...>=<...> - --webnn-device-type Set the WebNN device type (cpu/gpu) + --webnn-device-type Set the WebNN device type (cpu/gpu/npu) -x, --wasm-number-threads Set the WebAssembly number of threads ("--wasm-number-threads" is deprecated. use "--wasm.numThreads" or "-x" instead) @@ -351,7 +351,7 @@ function parseWebgpuFlags(args: minimist.ParsedArgs): Partial { function parseWebNNOptions(args: minimist.ParsedArgs): InferenceSession.WebNNExecutionProviderOption { const deviceType = args['webnn-device-type']; - if (deviceType !== undefined && deviceType !== 'cpu' && deviceType !== 'gpu') { + if (deviceType !== undefined && !['cpu', 'gpu', 'npu'].includes(deviceType)) { throw new Error('Flag "webnn-device-type" is invalid'); } return {name: 'webnn', deviceType}; diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index d1d7804332..d35a2ae17f 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -28,6 +28,7 @@ namespace webnn { enum class WebnnDeviceType { CPU, GPU, + NPU, }; typedef struct { diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 29c8ca91fe..d72abf1a72 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -37,7 +37,13 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f } } else { preferred_layout_ = DataLayout::NCHW; - wnn_device_type_ = webnn::WebnnDeviceType::GPU; + if (webnn_device_flags.compare("gpu") == 0) { + wnn_device_type_ = webnn::WebnnDeviceType::GPU; + } else if (webnn_device_flags.compare("npu") == 0) { + wnn_device_type_ = webnn::WebnnDeviceType::NPU; + } else { + ORT_THROW("Unknown WebNN deviceType."); + } } if (webnn_power_flags.compare("default") != 0) { context_options.set("powerPreference", emscripten::val(webnn_power_flags));