[WebNN EP] Support NPU deviceType (#20278)

This commit is contained in:
Wanming Lin 2024-04-16 09:43:46 +08:00 committed by GitHub
parent 287ecea2f1
commit fe1c3a45c1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 11 additions and 4 deletions

View file

@ -243,7 +243,7 @@ export declare namespace InferenceSession {
}
export interface WebNNExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'webnn';
deviceType?: 'cpu'|'gpu';
deviceType?: 'cpu'|'gpu'|'npu';
numThreads?: number;
powerPreference?: 'default'|'low-power'|'high-performance';
}

View file

@ -78,7 +78,7 @@ Options:
--webgpu.<...>=<...> --webgpu.profiling.mode=default --wasm.numThreads=1 --wasm.simd=false
--webnn.<...>=<...>
--webnn-device-type Set the WebNN device type (cpu/gpu)
--webnn-device-type Set the WebNN device type (cpu/gpu/npu)
-x, --wasm-number-threads Set the WebAssembly number of threads
("--wasm-number-threads" is deprecated. use "--wasm.numThreads" or "-x" instead)
@ -351,7 +351,7 @@ function parseWebgpuFlags(args: minimist.ParsedArgs): Partial<Env.WebGpuFlags> {
function parseWebNNOptions(args: minimist.ParsedArgs): InferenceSession.WebNNExecutionProviderOption {
const deviceType = args['webnn-device-type'];
if (deviceType !== undefined && deviceType !== 'cpu' && deviceType !== 'gpu') {
if (deviceType !== undefined && !['cpu', 'gpu', 'npu'].includes(deviceType)) {
throw new Error('Flag "webnn-device-type" is invalid');
}
return {name: 'webnn', deviceType};

View file

@ -28,6 +28,7 @@ namespace webnn {
enum class WebnnDeviceType {
CPU,
GPU,
NPU,
};
typedef struct {

View file

@ -37,7 +37,13 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
}
} else {
preferred_layout_ = DataLayout::NCHW;
wnn_device_type_ = webnn::WebnnDeviceType::GPU;
if (webnn_device_flags.compare("gpu") == 0) {
wnn_device_type_ = webnn::WebnnDeviceType::GPU;
} else if (webnn_device_flags.compare("npu") == 0) {
wnn_device_type_ = webnn::WebnnDeviceType::NPU;
} else {
ORT_THROW("Unknown WebNN deviceType.");
}
}
if (webnn_power_flags.compare("default") != 0) {
context_options.set("powerPreference", emscripten::val(webnn_power_flags));