[webgpu] no longer need pass-in gpu adapter for custom context

This commit is contained in:
Yulong Wang 2025-02-08 23:01:16 -08:00
parent 0274b7b82f
commit 1dbc22d5b1
4 changed files with 22 additions and 38 deletions

View file

@ -37,8 +37,8 @@ namespace webgpu {
void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture) {
std::call_once(init_flag_, [this, &buffer_cache_config, backend_type, enable_pix_capture]() {
// Create wgpu::Adapter
if (adapter_ == nullptr) {
if (device_ == nullptr) {
// Create wgpu::Adapter
#if !defined(__wasm__) && defined(_MSC_VER) && defined(DAWN_ENABLE_D3D12) && !defined(USE_EXTERNAL_DAWN)
// If we are using the D3D12 backend on Windows and the build does not use external Dawn, dxil.dll and dxcompiler.dll are required.
//
@ -77,20 +77,19 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
req_adapter_options.nextInChain = &adapter_toggles_desc;
#endif
wgpu::Adapter adapter;
ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(instance_.RequestAdapter(
&req_adapter_options,
wgpu::CallbackMode::WaitAnyOnly,
[](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, wgpu::StringView message, wgpu::Adapter* ptr) {
ORT_ENFORCE(status == wgpu::RequestAdapterStatus::Success, "Failed to get a WebGPU adapter: ", std::string_view{message});
*ptr = adapter;
*ptr = std::move(adapter);
},
&adapter_),
&adapter),
UINT64_MAX));
ORT_ENFORCE(adapter_ != nullptr, "Failed to get a WebGPU adapter.");
}
ORT_ENFORCE(adapter != nullptr, "Failed to get a WebGPU adapter.");
// Create wgpu::Device
if (device_ == nullptr) {
// Create wgpu::Device
wgpu::DeviceDescriptor device_desc = {};
#if !defined(__wasm__)
@ -106,12 +105,12 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
device_toggles_desc.disabledToggles = disabled_device_toggles.data();
#endif
std::vector<wgpu::FeatureName> required_features = GetAvailableRequiredFeatures(adapter_);
std::vector<wgpu::FeatureName> required_features = GetAvailableRequiredFeatures(adapter);
if (required_features.size() > 0) {
device_desc.requiredFeatures = required_features.data();
device_desc.requiredFeatureCount = required_features.size();
}
wgpu::RequiredLimits required_limits = GetRequiredLimits(adapter_);
wgpu::RequiredLimits required_limits = GetRequiredLimits(adapter);
device_desc.requiredLimits = &required_limits;
// TODO: revise temporary error handling
@ -123,12 +122,12 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
LOGS_DEFAULT(INFO) << "WebGPU device lost (" << int(reason) << "): " << std::string_view{message};
});
ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(adapter_.RequestDevice(
ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(adapter.RequestDevice(
&device_desc,
wgpu::CallbackMode::WaitAnyOnly,
[](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message, wgpu::Device* ptr) {
ORT_ENFORCE(status == wgpu::RequestDeviceStatus::Success, "Failed to get a WebGPU device: ", std::string_view{message});
*ptr = device;
*ptr = std::move(device);
},
&device_),
UINT64_MAX));
@ -136,7 +135,7 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
}
// cache adapter info
ORT_ENFORCE(Adapter().GetInfo(&adapter_info_));
ORT_ENFORCE(Device().GetAdapterInfo(&adapter_info_));
// cache device limits
wgpu::SupportedLimits device_supported_limits;
ORT_ENFORCE(Device().GetLimits(&device_supported_limits));
@ -706,13 +705,12 @@ wgpu::Instance WebGpuContextFactory::default_instance_;
WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& config) {
const int context_id = config.context_id;
WGPUInstance instance = config.instance;
WGPUAdapter adapter = config.adapter;
WGPUDevice device = config.device;
if (context_id == 0) {
// context ID is preserved for the default context. User cannot use context ID 0 as a custom context.
ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr,
"WebGPU EP default context (contextId=0) must not have custom WebGPU instance, adapter or device.");
ORT_ENFORCE(instance == nullptr && device == nullptr,
"WebGPU EP default context (contextId=0) must not have custom WebGPU instance or device.");
std::call_once(init_default_flag_, [
#if !defined(__wasm__)
@ -750,9 +748,9 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
});
instance = default_instance_.Get();
} else {
// for context ID > 0, user must provide custom WebGPU instance, adapter and device.
ORT_ENFORCE(instance != nullptr && adapter != nullptr && device != nullptr,
"WebGPU EP custom context (contextId>0) must have custom WebGPU instance, adapter and device.");
// for context ID > 0, user must provide custom WebGPU instance and device.
ORT_ENFORCE(instance != nullptr && device != nullptr,
"WebGPU EP custom context (contextId>0) must have custom WebGPU instance and device.");
}
std::lock_guard<std::mutex> lock(mutex_);
@ -760,13 +758,12 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
auto it = contexts_.find(context_id);
if (it == contexts_.end()) {
GSL_SUPPRESS(r.11)
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, adapter, device, config.validation_mode));
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, device, config.validation_mode));
it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first;
} else if (context_id != 0) {
ORT_ENFORCE(it->second.context->instance_.Get() == instance &&
it->second.context->adapter_.Get() == adapter &&
it->second.context->device_.Get() == device,
"WebGPU EP context ID ", context_id, " is already created with different WebGPU instance, adapter or device.");
"WebGPU EP context ID ", context_id, " is already created with different WebGPU instance or device.");
}
it->second.ref_count++;
return *it->second.context;

View file

@ -29,7 +29,6 @@ class ProgramBase;
struct WebGpuContextConfig {
int context_id;
WGPUInstance instance;
WGPUAdapter adapter;
WGPUDevice device;
const void* dawn_proc_table;
ValidationMode validation_mode;
@ -76,7 +75,6 @@ class WebGpuContext final {
Status Wait(wgpu::Future f);
const wgpu::Adapter& Adapter() const { return adapter_; }
const wgpu::Device& Device() const { return device_; }
const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; }
@ -149,8 +147,8 @@ class WebGpuContext final {
AtPasses
};
WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device, webgpu::ValidationMode validation_mode)
: instance_{instance}, adapter_{adapter}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None} {}
WebGpuContext(WGPUInstance instance, WGPUDevice device, webgpu::ValidationMode validation_mode)
: instance_{instance}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None} {}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext);
std::vector<const char*> GetEnabledAdapterToggles() const;
@ -198,7 +196,6 @@ class WebGpuContext final {
LibraryHandles modules_;
wgpu::Instance instance_;
wgpu::Adapter adapter_;
wgpu::Device device_;
webgpu::ValidationMode validation_mode_;

View file

@ -106,14 +106,6 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
std::from_chars(webgpu_instance_str.data(), webgpu_instance_str.data() + webgpu_instance_str.size(), webgpu_instance).ec);
}
size_t webgpu_adapter = 0;
std::string webgpu_adapter_str;
if (config_options.TryGetConfigEntry(kWebGpuAdapter, webgpu_adapter_str)) {
static_assert(sizeof(WGPUAdapter) == sizeof(size_t), "WGPUAdapter size mismatch");
ORT_ENFORCE(std::errc{} ==
std::from_chars(webgpu_adapter_str.data(), webgpu_adapter_str.data() + webgpu_adapter_str.size(), webgpu_adapter).ec);
}
size_t webgpu_device = 0;
std::string webgpu_device_str;
if (config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) {
@ -154,7 +146,6 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
webgpu::WebGpuContextConfig context_config{
context_id,
reinterpret_cast<WGPUInstance>(webgpu_instance),
reinterpret_cast<WGPUAdapter>(webgpu_adapter),
reinterpret_cast<WGPUDevice>(webgpu_device),
reinterpret_cast<const void*>(dawn_proc_table),
validation_mode,
@ -238,7 +229,7 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
// STEP.4 - start initialization.
//
// Load the Dawn library and create the WebGPU instance and adapter.
// Load the Dawn library and create the WebGPU instance.
auto& context = webgpu::WebGpuContextFactory::CreateContext(context_config);
// Create WebGPU device and initialize the context.

View file

@ -18,7 +18,6 @@ constexpr const char* kDawnBackendType = "WebGPU:dawnBackendType";
constexpr const char* kDeviceId = "WebGPU:deviceId";
constexpr const char* kWebGpuInstance = "WebGPU:webgpuInstance";
constexpr const char* kWebGpuAdapter = "WebGPU:webgpuAdapter";
constexpr const char* kWebGpuDevice = "WebGPU:webgpuDevice";
constexpr const char* kStorageBufferCacheMode = "WebGPU:storageBufferCacheMode";