mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-24 22:17:32 +00:00
[WebNN EP] Support numThreads option for WebNN CPU device (#18054)
This commit is contained in:
parent
cbf0cf06db
commit
73ed34ac4b
6 changed files with 30 additions and 6 deletions
|
|
@ -241,6 +241,7 @@ export declare namespace InferenceSession {
|
|||
export interface WebNNExecutionProviderOption extends ExecutionProviderOption {
|
||||
readonly name: 'webnn';
|
||||
deviceType?: 'cpu'|'gpu';
|
||||
numThreads?: number;
|
||||
powerPreference?: 'default'|'low-power'|'high-performance';
|
||||
}
|
||||
export interface CoreMLExecutionProviderOption extends ExecutionProviderOption {
|
||||
|
|
|
|||
|
|
@ -75,6 +75,19 @@ const setExecutionProviders =
|
|||
checkLastError(`Can't set a session config entry: 'deviceType' - ${webnnOptions.deviceType}.`);
|
||||
}
|
||||
}
|
||||
if (webnnOptions?.numThreads) {
|
||||
let numThreads = webnnOptions.numThreads;
|
||||
// Just ignore invalid webnnOptions.numThreads.
|
||||
if (typeof numThreads != 'number' || !Number.isInteger(numThreads) || numThreads < 0) {
|
||||
numThreads = 0;
|
||||
}
|
||||
const keyDataOffset = allocWasmString('numThreads', allocs);
|
||||
const valueDataOffset = allocWasmString(numThreads.toString(), allocs);
|
||||
if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
|
||||
0) {
|
||||
checkLastError(`Can't set a session config entry: 'numThreads' - ${webnnOptions.numThreads}.`);
|
||||
}
|
||||
}
|
||||
if (webnnOptions?.powerPreference) {
|
||||
const keyDataOffset = allocWasmString('powerPreference', allocs);
|
||||
const valueDataOffset = allocWasmString(webnnOptions.powerPreference, allocs);
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
WebNNExecutionProvider::WebNNExecutionProvider(
|
||||
const std::string& webnn_device_flags, const std::string& webnn_power_flags)
|
||||
WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags,
|
||||
const std::string& webnn_threads_number, const std::string& webnn_power_flags)
|
||||
: IExecutionProvider{onnxruntime::kWebNNExecutionProvider, true} {
|
||||
// Create WebNN context and graph builder.
|
||||
const emscripten::val ml = emscripten::val::global("navigator")["ml"];
|
||||
|
|
@ -31,6 +31,10 @@ WebNNExecutionProvider::WebNNExecutionProvider(
|
|||
if (webnn_device_flags.compare("cpu") == 0) {
|
||||
preferred_layout_ = DataLayout::NHWC;
|
||||
wnn_device_type_ = webnn::WebnnDeviceType::CPU;
|
||||
// Set "numThreads" if it's not default 0.
|
||||
if (webnn_threads_number.compare("0") != 0) {
|
||||
context_options.set("numThreads", stoi(webnn_threads_number));
|
||||
}
|
||||
} else {
|
||||
preferred_layout_ = DataLayout::NCHW;
|
||||
wnn_device_type_ = webnn::WebnnDeviceType::GPU;
|
||||
|
|
|
|||
|
|
@ -18,7 +18,8 @@ class Model;
|
|||
|
||||
class WebNNExecutionProvider : public IExecutionProvider {
|
||||
public:
|
||||
WebNNExecutionProvider(const std::string& webnn_device_flags, const std::string& webnn_power_flags);
|
||||
WebNNExecutionProvider(const std::string& webnn_device_flags, const std::string& webnn_threads_number,
|
||||
const std::string& webnn_power_flags);
|
||||
virtual ~WebNNExecutionProvider();
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
|
|
|
|||
|
|
@ -10,23 +10,26 @@ using namespace onnxruntime;
|
|||
|
||||
namespace onnxruntime {
|
||||
struct WebNNProviderFactory : IExecutionProviderFactory {
|
||||
WebNNProviderFactory(const std::string& webnn_device_flags, const std::string& webnn_power_flags)
|
||||
: webnn_device_flags_(webnn_device_flags), webnn_power_flags_(webnn_power_flags) {}
|
||||
WebNNProviderFactory(const std::string& webnn_device_flags, const std::string& webnn_threads_number,
|
||||
const std::string& webnn_power_flags)
|
||||
: webnn_device_flags_(webnn_device_flags), webnn_threads_number_(webnn_threads_number), webnn_power_flags_(webnn_power_flags) {}
|
||||
~WebNNProviderFactory() override {}
|
||||
|
||||
std::unique_ptr<IExecutionProvider> CreateProvider() override;
|
||||
|
||||
std::string webnn_device_flags_;
|
||||
std::string webnn_threads_number_;
|
||||
std::string webnn_power_flags_;
|
||||
};
|
||||
|
||||
std::unique_ptr<IExecutionProvider> WebNNProviderFactory::CreateProvider() {
|
||||
return std::make_unique<WebNNExecutionProvider>(webnn_device_flags_, webnn_power_flags_);
|
||||
return std::make_unique<WebNNExecutionProvider>(webnn_device_flags_, webnn_threads_number_, webnn_power_flags_);
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> WebNNProviderFactoryCreator::Create(
|
||||
const ProviderOptions& provider_options) {
|
||||
return std::make_shared<onnxruntime::WebNNProviderFactory>(provider_options.at("deviceType"),
|
||||
provider_options.at("numThreads"),
|
||||
provider_options.at("powerPreference"));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -104,8 +104,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
|
|||
} else if (strcmp(provider_name, "WEBNN") == 0) {
|
||||
#if defined(USE_WEBNN)
|
||||
std::string deviceType = options->value.config_options.GetConfigOrDefault("deviceType", "cpu");
|
||||
std::string numThreads = options->value.config_options.GetConfigOrDefault("numThreads", "0");
|
||||
std::string powerPreference = options->value.config_options.GetConfigOrDefault("powerPreference", "default");
|
||||
provider_options["deviceType"] = deviceType;
|
||||
provider_options["numThreads"] = numThreads;
|
||||
provider_options["powerPreference"] = powerPreference;
|
||||
options->provider_factories.push_back(WebNNProviderFactoryCreator::Create(provider_options));
|
||||
#else
|
||||
|
|
|
|||
Loading…
Reference in a new issue