[WebGPU] Support PIX Capture for WebGPU EP (#23192)

PIX Capture tool requires 'present' to end a frame capture. ORT doesn't
have rendering work so no 'present' happens.

To keep the PIX capture tool from waiting forever, this PR adds a blank
surface and calls 'present' on it at the end of each session run.

The surface is created in WebGPU ep constructor and closed in WebGPU ep
destructor.

### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
shaoboyan091 2025-02-08 18:05:15 +08:00 committed by GitHub
parent 01145511b1
commit 5ef18328bf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 225 additions and 10 deletions

View file

@ -139,6 +139,7 @@ option(onnxruntime_USE_WEBGPU "Build with WebGPU support. Enable WebGPU via C/C+
option(onnxruntime_USE_EXTERNAL_DAWN "Build with treating Dawn as external dependency. Will not link Dawn at build time." OFF)
option(onnxruntime_CUSTOM_DAWN_SRC_PATH "Path to custom Dawn src dir.")
option(onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY "Build Dawn as a monolithic library" OFF)
option(onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP "Adding frame present for PIX to capture a frame" OFF)
# The following 2 options are only for Windows
option(onnxruntime_ENABLE_DAWN_BACKEND_VULKAN "Enable Vulkan backend for Dawn (on Windows)" OFF)
option(onnxruntime_ENABLE_DAWN_BACKEND_D3D12 "Enable D3D12 backend for Dawn (on Windows)" ON)
@ -1038,6 +1039,14 @@ if (onnxruntime_USE_WEBGPU)
if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12)
list(APPEND ORT_PROVIDER_FLAGS -DDAWN_ENABLE_D3D12=1)
endif()
# PIX capture support for the WebGPU EP: reject the option unless the build targets
# Windows and the Dawn D3D12 backend is enabled, then expose it to the C++ sources.
if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP)
if (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12 OR NOT WIN32)
message(
FATAL_ERROR
"Option onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP can only be set on windows with onnxruntime_ENABLE_DAWN_BACKEND_D3D12 is enabled.")
endif()
# Defines the ENABLE_PIX_FOR_WEBGPU_EP preprocessor symbol that the WebGPU EP sources
# test with `#if defined(ENABLE_PIX_FOR_WEBGPU_EP)`.
add_compile_definitions(ENABLE_PIX_FOR_WEBGPU_EP)
endif()
endif()
if (onnxruntime_USE_CANN)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1)

View file

@ -685,6 +685,24 @@ if (onnxruntime_USE_WEBGPU)
set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE)
endif()
# The PIX frame generator creates its window and surface through GLFW (webgpu_glfw), so the
# GLFW/windowing related Dawn options are forced ON for PIX-enabled builds and forced OFF otherwise.
# NOTE(review): the GL / GLSL options are toggled together with GLFW here; presumably they are
# required by Dawn's GLFW integration - confirm.
# NOTE(review): the `set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE)` that follows this block
# under "disable things we don't use" looks like it overrides the ON value set here - confirm.
if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP)
set(DAWN_ENABLE_DESKTOP_GL ON CACHE BOOL "" FORCE)
set(DAWN_ENABLE_OPENGLES ON CACHE BOOL "" FORCE)
set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING ON CACHE BOOL "" FORCE)
set(DAWN_USE_GLFW ON CACHE BOOL "" FORCE)
set(DAWN_USE_WINDOWS_UI ON CACHE BOOL "" FORCE)
set(TINT_BUILD_GLSL_WRITER ON CACHE BOOL "" FORCE)
set(TINT_BUILD_GLSL_VALIDATOR ON CACHE BOOL "" FORCE)
else()
# Regular builds: keep every one of the options above disabled.
set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE)
set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE)
set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE)
set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE)
set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE)
set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE)
set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE)
endif()
# disable things we don't use
set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF)
set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE)
@ -741,6 +759,10 @@ if (onnxruntime_USE_WEBGPU)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_proc)
endif()
endif()
# PIX-enabled builds additionally link GLFW and Dawn's GLFW helper library, which the
# PIX frame generator uses to create its window and surface.
if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES glfw webgpu_glfw)
endif()
endif()
if(onnxruntime_USE_COREML)

View file

@ -35,8 +35,8 @@
namespace onnxruntime {
namespace webgpu {
void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type) {
std::call_once(init_flag_, [this, &buffer_cache_config, backend_type]() {
void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture) {
std::call_once(init_flag_, [this, &buffer_cache_config, backend_type, enable_pix_capture]() {
// Create wgpu::Adapter
if (adapter_ == nullptr) {
#if !defined(__wasm__) && defined(_MSC_VER) && defined(DAWN_ENABLE_D3D12) && !defined(USE_EXTERNAL_DAWN)
@ -162,6 +162,16 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
} else {
query_type_ = TimestampQueryType::None;
}
if (enable_pix_capture) {
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
// set pix frame generator
pix_frame_generator_ = std::make_unique<WebGpuPIXFrameGenerator>(instance_,
Adapter(),
Device());
#else
ORT_THROW("Support PIX capture requires extra build flags (--enable_pix_capture)");
#endif // ENABLE_PIX_FOR_WEBGPU_EP
}
});
}
@ -680,6 +690,14 @@ void WebGpuContext::Flush() {
num_pending_dispatches_ = 0;
}
// Per-run epilogue of the context. In builds with ENABLE_PIX_FOR_WEBGPU_EP defined, and only when
// a PIX frame generator exists, one PIX frame is generated; in every other case this is a no-op.
void WebGpuContext::OnRunEnd() {
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
  if (pix_frame_generator_ == nullptr) {
    // PIX capture was not requested for this context.
    return;
  }
  pix_frame_generator_->GeneratePIXFrame();
#endif  // ENABLE_PIX_FOR_WEBGPU_EP
}
std::unordered_map<int32_t, WebGpuContextFactory::WebGpuContextInfo> WebGpuContextFactory::contexts_;
std::mutex WebGpuContextFactory::mutex_;
std::once_flag WebGpuContextFactory::init_default_flag_;

View file

@ -14,6 +14,10 @@
#include "core/providers/webgpu/buffer_manager.h"
#include "core/providers/webgpu/program_manager.h"
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
#include "core/providers/webgpu/webgpu_pix_frame_generator.h"
#endif // ENABLE_PIX_FOR_WEBGPU_EP
namespace onnxruntime {
class Tensor;
@ -68,7 +72,7 @@ class WebGpuContextFactory {
// Class WebGpuContext includes all necessary resources for the context.
class WebGpuContext final {
public:
void Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type);
void Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type, bool enable_pix_capture);
Status Wait(wgpu::Future f);
@ -136,6 +140,7 @@ class WebGpuContext final {
Status PopErrorScope();
Status Run(ComputeContext& context, const ProgramBase& program);
void OnRunEnd();
private:
enum class TimestampQueryType {
@ -222,6 +227,10 @@ class WebGpuContext final {
uint64_t gpu_timestamp_offset_ = 0;
bool is_profiling_ = false;
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
std::unique_ptr<WebGpuPIXFrameGenerator> pix_frame_generator_ = nullptr;
#endif // ENABLE_PIX_FOR_WEBGPU_EP
};
} // namespace webgpu

View file

@ -747,8 +747,7 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
context_{context},
preferred_data_layout_{config.data_layout},
force_cpu_node_names_{std::move(config.force_cpu_node_names)},
enable_graph_capture_{config.enable_graph_capture} {
}
enable_graph_capture_{config.enable_graph_capture} {}
std::vector<AllocatorPtr> WebGpuExecutionProvider::CreatePreferredAllocators() {
AllocatorCreationInfo gpuBufferAllocatorCreationInfo([&](int) {
@ -862,11 +861,13 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
context_.CollectProfilingData(profiler_->Events());
}
context_.OnRunEnd();
if (context_.ValidationMode() >= ValidationMode::Basic) {
return context_.PopErrorScope();
} else {
return Status::OK();
}
return Status::OK();
}
bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const {

View file

@ -23,15 +23,17 @@ class WebGpuProfiler;
} // namespace webgpu
// Configuration values handed to the WebGPU execution provider when it is constructed.
// The type is move-only: moves are defaulted while copy construction/assignment is disallowed.
struct WebGpuExecutionProviderConfig {
WebGpuExecutionProviderConfig(DataLayout data_layout, bool enable_graph_capture)
WebGpuExecutionProviderConfig(DataLayout data_layout, bool enable_graph_capture, bool enable_pix_capture)
: data_layout{data_layout},
enable_graph_capture{enable_graph_capture} {}
enable_graph_capture{enable_graph_capture},
enable_pix_capture{enable_pix_capture} {}
WebGpuExecutionProviderConfig(WebGpuExecutionProviderConfig&&) = default;
WebGpuExecutionProviderConfig& operator=(WebGpuExecutionProviderConfig&&) = default;
ORT_DISALLOW_COPY_AND_ASSIGNMENT(WebGpuExecutionProviderConfig);
// Preferred data layout; copied into the execution provider on construction.
DataLayout data_layout;
// Whether the graph capture feature is enabled.
bool enable_graph_capture;
// Whether PIX capture is enabled.
// NOTE(review): the provider factory initializes this field with a literal `false` and passes the
// parsed option to WebGpuContext::Initialize instead - confirm whether this field is ever read.
bool enable_pix_capture;
// Moved into the execution provider on construction.
std::vector<std::string> force_cpu_node_names;
};

View file

@ -0,0 +1,70 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
#include <webgpu/webgpu_glfw.h>
#include "core/common/common.h"
#include "core/providers/webgpu/webgpu_pix_frame_generator.h"
namespace onnxruntime {
namespace webgpu {
// Creates a small GLFW window, wraps it in a WebGPU surface and configures that surface for
// `device`, so that GeneratePIXFrame() can later present on it.
//
// instance - WebGPU instance used to create the surface for the window.
// adapter  - adapter queried for the surface capabilities (supported texture formats).
// device   - device the surface is configured with.
//
// Throws (via ORT_ENFORCE) when GLFW cannot be initialized, when the window or the surface cannot
// be created, or when the surface reports no supported texture format.
WebGpuPIXFrameGenerator::WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu::Adapter adapter, wgpu::Device device) {
  // Trivial window size for surface texture creation and provide frame concept for PIX.
  static constexpr uint32_t kWidth = 512u;
  static constexpr uint32_t kHeight = 512u;

  // ORT_ENFORCE expects the condition as its first argument. Passing only the message string
  // makes the check always succeed, so the result of glfwInit() has to be the condition itself.
  ORT_ENFORCE(glfwInit() != GLFW_FALSE, "Failed to init glfw for PIX capture");

  // The window is created without a client API; it is only used as the target of a WebGPU surface.
  glfwWindowHint(GLFW_CLIENT_API, GLFW_NO_API);
  window_ = glfwCreateWindow(kWidth, kHeight, "WebGPU window", nullptr, nullptr);
  ORT_ENFORCE(window_ != nullptr, "PIX Capture: Failed to create Window for capturing frames.");

  surface_ = wgpu::glfw::CreateSurfaceForWindow(instance, window_);
  ORT_ENFORCE(surface_.Get() != nullptr, "PIX Capture: Failed to create surface for capturing frames.");

  wgpu::SurfaceCapabilities capabilities;
  surface_.GetCapabilities(adapter, &capabilities);
  // Guard the formats[0] access below: an empty capability list must not be indexed.
  ORT_ENFORCE(capabilities.formatCount > 0 && capabilities.formats != nullptr,
              "PIX Capture: Surface reports no supported texture format.");

  // Configure the surface with the first supported format and the trivial window size.
  wgpu::SurfaceConfiguration config;
  config.device = device;
  config.format = capabilities.formats[0];
  config.width = kWidth;
  config.height = kHeight;
  surface_.Configure(&config);
}
void WebGpuPIXFrameGenerator::GeneratePIXFrame() {
ORT_ENFORCE(surface_.Get() != nullptr, "PIX Capture: Cannot do present on null surface for capturing frames");
wgpu::SurfaceTexture surfaceTexture;
surface_.GetCurrentTexture(&surfaceTexture);
// Call present to trigger dxgi_swapchain present. PIX
// take this as a frame boundary.
surface_.Present();
}
// Tears down what the constructor set up: unconfigures the surface (when present) first and
// then destroys the GLFW window (when present).
WebGpuPIXFrameGenerator::~WebGpuPIXFrameGenerator() {
  if (surface_.Get() != nullptr) {
    surface_.Unconfigure();
  }

  if (window_ != nullptr) {
    glfwDestroyWindow(window_);
    window_ = nullptr;
  }
}
} // namespace webgpu
} // namespace onnxruntime
#endif // ENABLE_PIX_FOR_WEBGPU_EP

View file

@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
#pragma once
#ifdef __EMSCRIPTEN__
#include <emscripten/emscripten.h>
#endif
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
#include <GLFW/glfw3.h>
#endif // ENABLE_PIX_FOR_WEBGPU_EP
#include <memory>
#include <webgpu/webgpu_cpp.h>
namespace onnxruntime {
namespace webgpu {
// PIX (https://devblogs.microsoft.com/pix/introduction/) is a profiling tool
// provided by Microsoft. It can take GPU captures to profile GPU
// behavior across different GPU vendors. It works on Windows only.
//
// GPU capture (present-to-present) provided by PIX uses present as a frame boundary to
// capture and generate valid frame information. But ORT WebGPU EP doesn't have any present logic,
// which makes a PIX GPU Capture hang forever.
//
// To make PIX work with ORT WebGPU EP on Windows, the WebGpuPIXFrameGenerator class includes code
// to create a trivial window through glfw, configure a surface with the Dawn device and call present
// at the proper place to trigger a frame boundary for PIX GPU Capture.
//
// WebGpuPIXFrameGenerator is a friend class because:
// - It should only be used in WebGpuContext class implementation.
// - It requires instance and device from WebGpuContext.
//
// The lifecycle of WebGpuPIXFrameGenerator instance should be nested into WebGpuContext lifecycle.
// WebGpuPIXFrameGenerator instance should be created during WebGpuContext creation and be destroyed during
// WebGpuContext destruction.
class WebGpuPIXFrameGenerator {
 public:
  // Creates a GLFW window and a WebGPU surface for it, and configures the surface with `device`.
  // `adapter` is used to query the surface capabilities. Throws when any of these steps fails.
  WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu::Adapter adapter, wgpu::Device device);

  // Unconfigures the surface and destroys the window.
  ~WebGpuPIXFrameGenerator();

  // Acquires the current surface texture and presents it, which PIX takes as a frame boundary.
  void GeneratePIXFrame();

 private:
  // Surface created for `window_`; the target of every present call.
  wgpu::Surface surface_;

  // Window backing `surface_`. Initialized to nullptr so the member never holds an
  // indeterminate value before the constructor assigns it.
  GLFWwindow* window_ = nullptr;

  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuPIXFrameGenerator);
};
} // namespace webgpu
} // namespace onnxruntime
#endif // ENABLE_PIX_FOR_WEBGPU_EP

View file

@ -40,6 +40,8 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
DataLayout::NHWC,
// graph capture feature is disabled by default
false,
// PIX capture feature is disabled by default
false,
};
std::string preferred_layout_str;
@ -219,6 +221,19 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
buffer_cache_config.default_entry.mode = parse_buffer_cache_mode(kDefaultBufferCacheMode, webgpu::BufferCacheMode::Disabled);
LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << buffer_cache_config.default_entry.mode;
bool enable_pix_capture = false;
std::string enable_pix_capture_str;
if (config_options.TryGetConfigEntry(kEnablePIXCapture, enable_pix_capture_str)) {
if (enable_pix_capture_str == kEnablePIXCapture_ON) {
enable_pix_capture = true;
} else if (enable_pix_capture_str == kEnablePIXCapture_OFF) {
enable_pix_capture = false;
} else {
ORT_THROW("Invalid enable pix capture: ", enable_pix_capture_str);
}
}
LOGS_DEFAULT(VERBOSE) << "WebGPU EP pix capture enable: " << enable_pix_capture;
//
// STEP.4 - start initialization.
//
@ -227,7 +242,7 @@ std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(
auto& context = webgpu::WebGpuContextFactory::CreateContext(context_config);
// Create WebGPU device and initialize the context.
context.Initialize(buffer_cache_config, backend_type);
context.Initialize(buffer_cache_config, backend_type, enable_pix_capture);
// Create WebGPU EP factory.
return std::make_shared<WebGpuProviderFactory>(context_id, context, std::move(webgpu_ep_config));

View file

@ -29,6 +29,7 @@ constexpr const char* kDefaultBufferCacheMode = "WebGPU:defaultBufferCacheMode";
constexpr const char* kValidationMode = "WebGPU:validationMode";
constexpr const char* kForceCpuNodeNames = "WebGPU:forceCpuNodeNames";
// Provider option key that switches PIX capture support on or off.
constexpr const char* kEnablePIXCapture = "WebGPU:enablePIXCapture";
// The following are the possible values for the provider options.
@ -41,6 +42,9 @@ constexpr const char* kPreferredLayout_NHWC = "NHWC";
constexpr const char* kEnableGraphCapture_ON = "1";
constexpr const char* kEnableGraphCapture_OFF = "0";
// Accepted values for the kEnablePIXCapture option; the provider factory throws on any other value.
constexpr const char* kEnablePIXCapture_ON = "1";
constexpr const char* kEnablePIXCapture_OFF = "0";
constexpr const char* kBufferCacheMode_Disabled = "disabled";
constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease";
constexpr const char* kBufferCacheMode_Simple = "simple";

View file

@ -613,6 +613,7 @@ def parse_arguments():
parser.add_argument("--use_migraphx", action="store_true", help="Build with MIGraphX")
parser.add_argument("--migraphx_home", help="Path to MIGraphX installation dir")
parser.add_argument("--use_full_protobuf", action="store_true", help="Use the full protobuf library")
parser.add_argument("--enable_pix_capture", action="store_true", help="Enable Pix Support.")
parser.add_argument(
"--skip_onnx_tests",
@ -1077,6 +1078,7 @@ def generate_build_tree(
"-Donnxruntime_ARMNN_BN_USE_CPU=" + ("OFF" if args.armnn_bn else "ON"),
"-Donnxruntime_USE_JSEP=" + ("ON" if args.use_jsep else "OFF"),
"-Donnxruntime_USE_WEBGPU=" + ("ON" if args.use_webgpu else "OFF"),
"-Donnxruntime_ENABLE_PIX_FOR_WEBGPU_EP=" + ("ON" if args.enable_pix_capture else "OFF"),
"-Donnxruntime_USE_EXTERNAL_DAWN=" + ("ON" if args.use_external_dawn else "OFF"),
# Training related flags
"-Donnxruntime_ENABLE_NVTX_PROFILE=" + ("ON" if args.enable_nvtx_profile else "OFF"),
@ -1457,6 +1459,11 @@ def generate_build_tree(
if args.use_external_dawn and not args.use_webgpu:
raise BuildError("External Dawn (--use_external_dawn) must be enabled with WebGPU (--use_webgpu).")
if args.enable_pix_capture and (not args.use_webgpu or not is_windows()):
raise BuildError(
"Enable PIX Capture (--enable_pix_capture) must be enabled with WebGPU (--use_webgpu) on Windows"
)
if args.use_snpe:
cmake_args += ["-Donnxruntime_USE_SNPE=ON"]