mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
[webgpu] no longer need pass-in gpu adapter for custom context
This commit is contained in:
parent
0274b7b82f
commit
1dbc22d5b1
4 changed files with 22 additions and 38 deletions
|
|
@ -37,8 +37,8 @@ namespace webgpu {
|
|||
|
||||
void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture) {
|
||||
std::call_once(init_flag_, [this, &buffer_cache_config, backend_type, enable_pix_capture]() {
|
||||
// Create wgpu::Adapter
|
||||
if (adapter_ == nullptr) {
|
||||
if (device_ == nullptr) {
|
||||
// Create wgpu::Adapter
|
||||
#if !defined(__wasm__) && defined(_MSC_VER) && defined(DAWN_ENABLE_D3D12) && !defined(USE_EXTERNAL_DAWN)
|
||||
// If we are using the D3D12 backend on Windows and the build does not use external Dawn, dxil.dll and dxcompiler.dll are required.
|
||||
//
|
||||
|
|
@ -77,20 +77,19 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
|
|||
req_adapter_options.nextInChain = &adapter_toggles_desc;
|
||||
#endif
|
||||
|
||||
wgpu::Adapter adapter;
|
||||
ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(instance_.RequestAdapter(
|
||||
&req_adapter_options,
|
||||
wgpu::CallbackMode::WaitAnyOnly,
|
||||
[](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, wgpu::StringView message, wgpu::Adapter* ptr) {
|
||||
ORT_ENFORCE(status == wgpu::RequestAdapterStatus::Success, "Failed to get a WebGPU adapter: ", std::string_view{message});
|
||||
*ptr = adapter;
|
||||
*ptr = std::move(adapter);
|
||||
},
|
||||
&adapter_),
|
||||
&adapter),
|
||||
UINT64_MAX));
|
||||
ORT_ENFORCE(adapter_ != nullptr, "Failed to get a WebGPU adapter.");
|
||||
}
|
||||
ORT_ENFORCE(adapter != nullptr, "Failed to get a WebGPU adapter.");
|
||||
|
||||
// Create wgpu::Device
|
||||
if (device_ == nullptr) {
|
||||
// Create wgpu::Device
|
||||
wgpu::DeviceDescriptor device_desc = {};
|
||||
|
||||
#if !defined(__wasm__)
|
||||
|
|
@ -106,12 +105,12 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
|
|||
device_toggles_desc.disabledToggles = disabled_device_toggles.data();
|
||||
#endif
|
||||
|
||||
std::vector<wgpu::FeatureName> required_features = GetAvailableRequiredFeatures(adapter_);
|
||||
std::vector<wgpu::FeatureName> required_features = GetAvailableRequiredFeatures(adapter);
|
||||
if (required_features.size() > 0) {
|
||||
device_desc.requiredFeatures = required_features.data();
|
||||
device_desc.requiredFeatureCount = required_features.size();
|
||||
}
|
||||
wgpu::RequiredLimits required_limits = GetRequiredLimits(adapter_);
|
||||
wgpu::RequiredLimits required_limits = GetRequiredLimits(adapter);
|
||||
device_desc.requiredLimits = &required_limits;
|
||||
|
||||
// TODO: revise temporary error handling
|
||||
|
|
@ -123,12 +122,12 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
|
|||
LOGS_DEFAULT(INFO) << "WebGPU device lost (" << int(reason) << "): " << std::string_view{message};
|
||||
});
|
||||
|
||||
ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(adapter_.RequestDevice(
|
||||
ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(adapter.RequestDevice(
|
||||
&device_desc,
|
||||
wgpu::CallbackMode::WaitAnyOnly,
|
||||
[](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message, wgpu::Device* ptr) {
|
||||
ORT_ENFORCE(status == wgpu::RequestDeviceStatus::Success, "Failed to get a WebGPU device: ", std::string_view{message});
|
||||
*ptr = device;
|
||||
*ptr = std::move(device);
|
||||
},
|
||||
&device_),
|
||||
UINT64_MAX));
|
||||
|
|
@ -136,7 +135,7 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
|
|||
}
|
||||
|
||||
// cache adapter info
|
||||
ORT_ENFORCE(Adapter().GetInfo(&adapter_info_));
|
||||
ORT_ENFORCE(Device().GetAdapterInfo(&adapter_info_));
|
||||
// cache device limits
|
||||
wgpu::SupportedLimits device_supported_limits;
|
||||
ORT_ENFORCE(Device().GetLimits(&device_supported_limits));
|
||||
|
|
@ -706,13 +705,12 @@ wgpu::Instance WebGpuContextFactory::default_instance_;
|
|||
WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& config) {
|
||||
const int context_id = config.context_id;
|
||||
WGPUInstance instance = config.instance;
|
||||
WGPUAdapter adapter = config.adapter;
|
||||
WGPUDevice device = config.device;
|
||||
|
||||
if (context_id == 0) {
|
||||
// context ID is preserved for the default context. User cannot use context ID 0 as a custom context.
|
||||
ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr,
|
||||
"WebGPU EP default context (contextId=0) must not have custom WebGPU instance, adapter or device.");
|
||||
ORT_ENFORCE(instance == nullptr && device == nullptr,
|
||||
"WebGPU EP default context (contextId=0) must not have custom WebGPU instance or device.");
|
||||
|
||||
std::call_once(init_default_flag_, [
|
||||
#if !defined(__wasm__)
|
||||
|
|
@ -750,9 +748,9 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
|
|||
});
|
||||
instance = default_instance_.Get();
|
||||
} else {
|
||||
// for context ID > 0, user must provide custom WebGPU instance, adapter and device.
|
||||
ORT_ENFORCE(instance != nullptr && adapter != nullptr && device != nullptr,
|
||||
"WebGPU EP custom context (contextId>0) must have custom WebGPU instance, adapter and device.");
|
||||
// for context ID > 0, user must provide custom WebGPU instance and device.
|
||||
ORT_ENFORCE(instance != nullptr && device != nullptr,
|
||||
"WebGPU EP custom context (contextId>0) must have custom WebGPU instance and device.");
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
|
@ -760,13 +758,12 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
|
|||
auto it = contexts_.find(context_id);
|
||||
if (it == contexts_.end()) {
|
||||
GSL_SUPPRESS(r.11)
|
||||
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, adapter, device, config.validation_mode));
|
||||
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, device, config.validation_mode));
|
||||
it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first;
|
||||
} else if (context_id != 0) {
|
||||
ORT_ENFORCE(it->second.context->instance_.Get() == instance &&
|
||||
it->second.context->adapter_.Get() == adapter &&
|
||||
it->second.context->device_.Get() == device,
|
||||
"WebGPU EP context ID ", context_id, " is already created with different WebGPU instance, adapter or device.");
|
||||
"WebGPU EP context ID ", context_id, " is already created with different WebGPU instance or device.");
|
||||
}
|
||||
it->second.ref_count++;
|
||||
return *it->second.context;
|
||||
|
|
|
|||
|
|
@ -29,7 +29,6 @@ class ProgramBase;
|
|||
struct WebGpuContextConfig {
|
||||
int context_id;
|
||||
WGPUInstance instance;
|
||||
WGPUAdapter adapter;
|
||||
WGPUDevice device;
|
||||
const void* dawn_proc_table;
|
||||
ValidationMode validation_mode;
|
||||
|
|
@ -76,7 +75,6 @@ class WebGpuContext final {
|
|||
|
||||
Status Wait(wgpu::Future f);
|
||||
|
||||
const wgpu::Adapter& Adapter() const { return adapter_; }
|
||||
const wgpu::Device& Device() const { return device_; }
|
||||
|
||||
const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; }
|
||||
|
|
@ -149,8 +147,8 @@ class WebGpuContext final {
|
|||
AtPasses
|
||||
};
|
||||
|
||||
WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device, webgpu::ValidationMode validation_mode)
|
||||
: instance_{instance}, adapter_{adapter}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None} {}
|
||||
WebGpuContext(WGPUInstance instance, WGPUDevice device, webgpu::ValidationMode validation_mode)
|
||||
: instance_{instance}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None} {}
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext);
|
||||
|
||||
std::vector<const char*> GetEnabledAdapterToggles() const;
|
||||
|
|
@ -198,7 +196,6 @@ class WebGpuContext final {
|
|||
LibraryHandles modules_;
|
||||
|
||||
wgpu::Instance instance_;
|
||||
wgpu::Adapter adapter_;
|
||||
wgpu::Device device_;
|
||||
|
||||
webgpu::ValidationMode validation_mode_;
|
||||
|
|
|
|||
|
|
@ -106,14 +106,6 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
|
|||
std::from_chars(webgpu_instance_str.data(), webgpu_instance_str.data() + webgpu_instance_str.size(), webgpu_instance).ec);
|
||||
}
|
||||
|
||||
size_t webgpu_adapter = 0;
|
||||
std::string webgpu_adapter_str;
|
||||
if (config_options.TryGetConfigEntry(kWebGpuAdapter, webgpu_adapter_str)) {
|
||||
static_assert(sizeof(WGPUAdapter) == sizeof(size_t), "WGPUAdapter size mismatch");
|
||||
ORT_ENFORCE(std::errc{} ==
|
||||
std::from_chars(webgpu_adapter_str.data(), webgpu_adapter_str.data() + webgpu_adapter_str.size(), webgpu_adapter).ec);
|
||||
}
|
||||
|
||||
size_t webgpu_device = 0;
|
||||
std::string webgpu_device_str;
|
||||
if (config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) {
|
||||
|
|
@ -154,7 +146,6 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
|
|||
webgpu::WebGpuContextConfig context_config{
|
||||
context_id,
|
||||
reinterpret_cast<WGPUInstance>(webgpu_instance),
|
||||
reinterpret_cast<WGPUAdapter>(webgpu_adapter),
|
||||
reinterpret_cast<WGPUDevice>(webgpu_device),
|
||||
reinterpret_cast<const void*>(dawn_proc_table),
|
||||
validation_mode,
|
||||
|
|
@ -238,7 +229,7 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
|
|||
// STEP.4 - start initialization.
|
||||
//
|
||||
|
||||
// Load the Dawn library and create the WebGPU instance and adapter.
|
||||
// Load the Dawn library and create the WebGPU instance.
|
||||
auto& context = webgpu::WebGpuContextFactory::CreateContext(context_config);
|
||||
|
||||
// Create WebGPU device and initialize the context.
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ constexpr const char* kDawnBackendType = "WebGPU:dawnBackendType";
|
|||
|
||||
constexpr const char* kDeviceId = "WebGPU:deviceId";
|
||||
constexpr const char* kWebGpuInstance = "WebGPU:webgpuInstance";
|
||||
constexpr const char* kWebGpuAdapter = "WebGPU:webgpuAdapter";
|
||||
constexpr const char* kWebGpuDevice = "WebGPU:webgpuDevice";
|
||||
|
||||
constexpr const char* kStorageBufferCacheMode = "WebGPU:storageBufferCacheMode";
|
||||
|
|
|
|||
Loading…
Reference in a new issue