From 22bcd9a9853014d644008d25dad86ef9920e9be4 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 10 Jan 2025 20:41:59 -0800 Subject: [PATCH] Fix wasm build with WebGPU EP --- .../external/onnxruntime_external_deps.cmake | 124 ++--- cmake/onnxruntime_providers_webgpu.cmake | 75 +-- cmake/onnxruntime_webassembly.cmake | 7 + cmake/patches/dawn/dawn.patch | 12 + .../emscripten/tools/maint/gen_struct_info.py | 448 ++++++++++++++++++ js/build_webgpu.bat | 78 +++ .../webgpu/bert/skip_layer_norm.cc | 2 +- .../core/providers/webgpu/allocator.cc | 4 - .../core/providers/webgpu/buffer_manager.cc | 4 +- .../core/providers/webgpu/buffer_manager.h | 4 - .../core/providers/webgpu/compute_context.h | 4 - .../core/providers/webgpu/data_transfer.cc | 4 - .../webgpu/math/binary_elementwise_ops.cc | 6 +- .../core/providers/webgpu/nn/layer_norm.cc | 3 +- .../core/providers/webgpu/program_manager.h | 4 - .../core/providers/webgpu/shader_helper.h | 4 - .../core/providers/webgpu/tensor/concat.cc | 5 +- .../core/providers/webgpu/tensor/flatten.h | 13 +- .../webgpu/tensor/gather_elements.cc | 4 +- .../core/providers/webgpu/webgpu_context.cc | 35 +- .../core/providers/webgpu/webgpu_context.h | 4 - .../webgpu/webgpu_execution_provider.cc | 3 - 22 files changed, 701 insertions(+), 146 deletions(-) create mode 100644 cmake/patches/dawn/dawn.patch create mode 100644 cmake/patches/emscripten/tools/maint/gen_struct_info.py create mode 100644 js/build_webgpu.bat diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index aeaaa7b51d..387f88bbfa 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -631,78 +631,92 @@ if (onnxruntime_USE_WEBGPU) URL_HASH SHA1=${DEP_SHA1_dawn} # All previous patches are merged into the upstream dawn project. We don't need to apply any patches right now. # if we need to apply patches in the future, we can uncomment the following line. - # PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch ) endif() - if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) - set(DAWN_BUILD_MONOLITHIC_LIBRARY ON CACHE BOOL "" FORCE) - set(DAWN_ENABLE_INSTALL ON CACHE BOOL "" FORCE) - - if (onnxruntime_USE_EXTERNAL_DAWN) - message(FATAL_ERROR "onnxruntime_USE_EXTERNAL_DAWN and onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY cannot be enabled at the same time.") - endif() - else() - # use dawn::dawn_native and dawn::dawn_proc instead of the monolithic dawn::webgpu_dawn to minimize binary size - set(DAWN_BUILD_MONOLITHIC_LIBRARY OFF CACHE BOOL "" FORCE) - set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE) - endif() set(DAWN_BUILD_SAMPLES OFF CACHE BOOL "" FORCE) set(DAWN_ENABLE_NULL OFF CACHE BOOL "" FORCE) set(DAWN_FETCH_DEPENDENCIES ON CACHE BOOL "" FORCE) - # disable things we don't use - set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF) - set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE) - set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE) - set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE) - set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE) - set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE) - set(DAWN_USE_X11 OFF CACHE BOOL "" FORCE) + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + set(DAWN_EMSCRIPTEN_TOOLCHAIN "${REPO_ROOT}/cmake/external/emsdk/upstream/emscripten" CACHE STRING "" FORCE) - set(TINT_BUILD_TESTS OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_CMD_TOOLS OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_IR_BINARY OFF CACHE BOOL "" FORCE) - set(TINT_BUILD_SPV_READER OFF CACHE BOOL "" FORCE) # don't need. disabling is a large binary size saving - set(TINT_BUILD_WGSL_WRITER ON CACHE BOOL "" FORCE) # needed to create cache key. runtime error if not enabled. + # file "${DAWN_EMSCRIPTEN_TOOLCHAIN}/tools/maint/gen_struct_info.py" is missing from emsdk installation. + # we should copy if from ${PROJECT_SOURCE_DIR}/patches/ + file(COPY_FILE + "${PROJECT_SOURCE_DIR}/patches/emscripten/tools/maint/gen_struct_info.py" + "${DAWN_EMSCRIPTEN_TOOLCHAIN}/tools/maint/gen_struct_info.py" + ONLY_IF_DIFFERENT) + else() + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + set(DAWN_BUILD_MONOLITHIC_LIBRARY ON CACHE BOOL "" FORCE) + set(DAWN_ENABLE_INSTALL ON CACHE BOOL "" FORCE) - # SPIR-V validation shouldn't be required given we're using Tint to create the SPIR-V. - set(DAWN_ENABLE_SPIRV_VALIDATION OFF CACHE BOOL "" FORCE) - - if (WIN32) - # building this requires the HLSL writer to be enabled in Tint. TBD if that we need either of these to be ON. - set(DAWN_USE_BUILT_DXC ON CACHE BOOL "" FORCE) - set(TINT_BUILD_HLSL_WRITER ON CACHE BOOL "" FORCE) - - if ((NOT onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) AND (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12)) - message(FATAL_ERROR "At least one of onnxruntime_ENABLE_DAWN_BACKEND_VULKAN or onnxruntime_ENABLE_DAWN_BACKEND_D3D12 must be enabled when using Dawn on Windows.") - endif() - if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) - set(DAWN_ENABLE_VULKAN ON CACHE BOOL "" FORCE) - set(TINT_BUILD_SPV_WRITER ON CACHE BOOL "" FORCE) + if (onnxruntime_USE_EXTERNAL_DAWN) + message(FATAL_ERROR "onnxruntime_USE_EXTERNAL_DAWN and onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY cannot be enabled at the same time.") + endif() else() - set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE) + # use dawn::dawn_native and dawn::dawn_proc instead of the monolithic dawn::webgpu_dawn to minimize binary size + set(DAWN_BUILD_MONOLITHIC_LIBRARY OFF CACHE BOOL "" FORCE) + set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE) endif() - if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12) - set(DAWN_ENABLE_D3D12 ON CACHE BOOL "" FORCE) - else() - set(DAWN_ENABLE_D3D12 OFF CACHE BOOL "" FORCE) + + # disable things we don't use + set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF) + set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE) + set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE) + set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE) + set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE) + set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE) + set(DAWN_USE_X11 OFF CACHE BOOL "" FORCE) + + set(TINT_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_CMD_TOOLS OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_IR_BINARY OFF CACHE BOOL "" FORCE) + set(TINT_BUILD_SPV_READER OFF CACHE BOOL "" FORCE) # don't need. disabling is a large binary size saving + set(TINT_BUILD_WGSL_WRITER ON CACHE BOOL "" FORCE) # needed to create cache key. runtime error if not enabled. + + # SPIR-V validation shouldn't be required given we're using Tint to create the SPIR-V. + set(DAWN_ENABLE_SPIRV_VALIDATION OFF CACHE BOOL "" FORCE) + + if (WIN32) + # building this requires the HLSL writer to be enabled in Tint. TBD if that we need either of these to be ON. + set(DAWN_USE_BUILT_DXC ON CACHE BOOL "" FORCE) + set(TINT_BUILD_HLSL_WRITER ON CACHE BOOL "" FORCE) + + if ((NOT onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) AND (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12)) + message(FATAL_ERROR "At least one of onnxruntime_ENABLE_DAWN_BACKEND_VULKAN or onnxruntime_ENABLE_DAWN_BACKEND_D3D12 must be enabled when using Dawn on Windows.") + endif() + if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) + set(DAWN_ENABLE_VULKAN ON CACHE BOOL "" FORCE) + set(TINT_BUILD_SPV_WRITER ON CACHE BOOL "" FORCE) + else() + set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE) + endif() + if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + set(DAWN_ENABLE_D3D12 ON CACHE BOOL "" FORCE) + else() + set(DAWN_ENABLE_D3D12 OFF CACHE BOOL "" FORCE) + endif() + # We are currently always using the D3D12 backend. + set(DAWN_ENABLE_D3D11 OFF CACHE BOOL "" FORCE) endif() - # We are currently always using the D3D12 backend. - set(DAWN_ENABLE_D3D11 OFF CACHE BOOL "" FORCE) endif() onnxruntime_fetchcontent_makeavailable(dawn) - if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::webgpu_dawn) - else() - if (NOT onnxruntime_USE_EXTERNAL_DAWN) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native) + if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::webgpu_dawn) + else() + if (NOT onnxruntime_USE_EXTERNAL_DAWN) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native) + endif() + list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_proc) endif() - list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_proc) endif() endif() diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index fb3dc43abc..40729e6009 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -21,44 +21,57 @@ source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_webgpu_cc_srcs}) onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_webgpu - onnxruntime_common dawn::dawncpp_headers dawn::dawn_headers onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface) + onnxruntime_common onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface) - set(onnxruntime_providers_webgpu_dll_deps) - - if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) - target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn) - - if (WIN32) - if (onnxruntime_ENABLE_DELAY_LOADING_WIN_DLLS) - list(APPEND onnxruntime_DELAYLOAD_FLAGS "/DELAYLOAD:webgpu_dawn.dll") - endif() - - list(APPEND onnxruntime_providers_webgpu_dll_deps "$") + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + # if "-s DISABLE_EXCEPTION_CATCHING=0" is globally set, we need to remove "-fno-exceptions" from emdawnwebgpu_c + if (CMAKE_CXX_FLAGS MATCHES "DISABLE_EXCEPTION_CATCHING=0") + get_property(EM_DAWN_WEBGPU_C_COMPILE_OPTIONS TARGET emdawnwebgpu_c PROPERTY COMPILE_OPTIONS) + list(REMOVE_ITEM EM_DAWN_WEBGPU_C_COMPILE_OPTIONS "-fno-exceptions") + set_property(TARGET emdawnwebgpu_c PROPERTY COMPILE_OPTIONS ${EM_DAWN_WEBGPU_C_COMPILE_OPTIONS}) endif() + + target_link_libraries(onnxruntime_providers_webgpu PUBLIC emdawnwebgpu_cpp) else() - if (NOT onnxruntime_USE_EXTERNAL_DAWN) - target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native) + onnxruntime_add_include_to_target(onnxruntime_providers_webgpu dawn::dawncpp_headers dawn::dawn_headers) + + set(onnxruntime_providers_webgpu_dll_deps) + + if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY) + target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn) + + if (WIN32) + if (onnxruntime_ENABLE_DELAY_LOADING_WIN_DLLS) + list(APPEND onnxruntime_DELAYLOAD_FLAGS "/DELAYLOAD:webgpu_dawn.dll") + endif() + + list(APPEND onnxruntime_providers_webgpu_dll_deps "$") + endif() + else() + if (NOT onnxruntime_USE_EXTERNAL_DAWN) + target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native) + endif() + target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_proc) endif() - target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_proc) - endif() - if (WIN32 AND onnxruntime_ENABLE_DAWN_BACKEND_D3D12) - # Ensure dxil.dll and dxcompiler.dll exist in the output directory $ - add_dependencies(onnxruntime_providers_webgpu copy_dxil_dll) - add_dependencies(onnxruntime_providers_webgpu dxcompiler) + if (WIN32 AND onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + # Ensure dxil.dll and dxcompiler.dll exist in the output directory $ + add_dependencies(onnxruntime_providers_webgpu copy_dxil_dll) + add_dependencies(onnxruntime_providers_webgpu dxcompiler) - list(APPEND onnxruntime_providers_webgpu_dll_deps "$/dxil.dll") - list(APPEND onnxruntime_providers_webgpu_dll_deps "$/dxcompiler.dll") - endif() + list(APPEND onnxruntime_providers_webgpu_dll_deps "$/dxil.dll") + list(APPEND onnxruntime_providers_webgpu_dll_deps "$/dxcompiler.dll") + endif() - if (onnxruntime_providers_webgpu_dll_deps) - # Copy dependency DLLs to the output directory - add_custom_command( - TARGET onnxruntime_providers_webgpu - POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different "${onnxruntime_providers_webgpu_dll_deps}" "$" - COMMAND_EXPAND_LISTS - VERBATIM ) + if (onnxruntime_providers_webgpu_dll_deps) + # Copy dependency DLLs to the output directory + add_custom_command( + TARGET onnxruntime_providers_webgpu + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${onnxruntime_providers_webgpu_dll_deps}" "$" + COMMAND_EXPAND_LISTS + VERBATIM ) + endif() endif() add_dependencies(onnxruntime_providers_webgpu ${onnxruntime_EXTERNAL_DEPENDENCIES}) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 66268cefac..51541c914b 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -111,6 +111,7 @@ if (onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB) ${PROVIDERS_JS} ${PROVIDERS_XNNPACK} ${PROVIDERS_WEBNN} + ${PROVIDERS_WEBGPU} onnxruntime_session onnxruntime_util re2::re2 @@ -188,6 +189,7 @@ else() ${PROVIDERS_JS} ${PROVIDERS_XNNPACK} ${PROVIDERS_WEBNN} + ${PROVIDERS_WEBGPU} onnxruntime_session onnxruntime_util re2::re2 @@ -386,6 +388,11 @@ jsepDownload:_pp_") set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) endif() + if (onnxruntime_USE_WEBGPU) + target_compile_definitions(onnxruntime_webassembly PRIVATE USE_WEBGPU=1) + # target_link_libraries(onnxruntime_webassembly PUBLIC emdawnwebgpu_cpp) + endif() + if (onnxruntime_EMSCRIPTEN_SETTINGS) foreach(setting IN LISTS onnxruntime_EMSCRIPTEN_SETTINGS) target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s ${setting}") diff --git a/cmake/patches/dawn/dawn.patch b/cmake/patches/dawn/dawn.patch new file mode 100644 index 0000000000..9897dbcec6 --- /dev/null +++ b/cmake/patches/dawn/dawn.patch @@ -0,0 +1,12 @@ +diff --git a/src/emdawnwebgpu/CMakeLists.txt b/src/emdawnwebgpu/CMakeLists.txt +index de673537d3..c98dc46de7 100644 +--- a/src/emdawnwebgpu/CMakeLists.txt ++++ b/src/emdawnwebgpu/CMakeLists.txt +@@ -78,6 +78,7 @@ if (${DAWN_ENABLE_EMSCRIPTEN}) + endif() + + set(ARGS ++ ${Python3_EXECUTABLE} + "${DAWN_EMSCRIPTEN_TOOLCHAIN}/tools/maint/gen_struct_info.py" + -q + "${EM_BUILD_GEN_DIR}/struct_info_webgpu.json" diff --git a/cmake/patches/emscripten/tools/maint/gen_struct_info.py b/cmake/patches/emscripten/tools/maint/gen_struct_info.py new file mode 100644 index 0000000000..70754f35d1 --- /dev/null +++ b/cmake/patches/emscripten/tools/maint/gen_struct_info.py @@ -0,0 +1,448 @@ +#!/usr/bin/env python3 +# coding=utf-8 +# Copyright 2013 The Emscripten Authors. All rights reserved. +# Emscripten is available under two separate licenses, the MIT license and the +# University of Illinois/NCSA Open Source License. Both these licenses can be +# found in the LICENSE file. + +"""This tool extracts information about structs and defines from the C headers. + +The JSON input format is as follows: +[ + { + 'file': 'some/header.h', + 'structs': { + 'struct_name': [ + 'field1', + 'field2', + 'field3', + { + 'field4': [ + 'nested1', + 'nested2', + { + 'nested3': [ + 'deep_nested1', + ... + ] + } + ... + ] + }, + 'field5' + ], + 'other_struct': [ + 'field1', + 'field2', + ... + ] + }, + 'defines': [ + 'DEFINE_1', + 'DEFINE_2', + ['f', 'FLOAT_DEFINE'], + 'DEFINE_3', + ... + ] + }, + { + 'file': 'some/other/header.h', + ... + } +] + +Please note that the 'f' for 'FLOAT_DEFINE' is just the format passed to printf(), you can put +anything printf() understands. +""" + +import sys +import os +import re +import json +import argparse +import tempfile +import subprocess + +__scriptdir__ = os.path.dirname(os.path.abspath(__file__)) +__rootdir__ = os.path.dirname(os.path.dirname(__scriptdir__)) +sys.path.insert(0, __rootdir__) + +from tools import building +from tools import config +from tools import shared +from tools import system_libs +from tools import utils + +QUIET = (__name__ != '__main__') +DEBUG = False + +CFLAGS = [ + # Avoid parsing problems due to gcc specific syntax. + '-D_GNU_SOURCE', +] + +INTERNAL_CFLAGS = [ + '-I' + utils.path_from_root('system/lib/libc/musl/src/internal'), + '-I' + utils.path_from_root('system/lib/libc/musl/src/include'), + '-I' + utils.path_from_root('system/lib/pthread/'), +] + +CXXFLAGS = [ + '-I' + utils.path_from_root('system/lib/libcxxabi/src'), + '-D__EMSCRIPTEN_EXCEPTIONS__', + '-I' + utils.path_from_root('system/lib/wasmfs/'), + '-std=c++17', +] + +DEFAULT_JSON_FILES = [ + utils.path_from_root('src/struct_info.json'), + utils.path_from_root('src/struct_info_internal.json'), + utils.path_from_root('src/struct_info_cxx.json'), + utils.path_from_root('src/struct_info_webgpu.json'), +] + + +def show(msg): + if shared.DEBUG or not QUIET: + sys.stderr.write('gen_struct_info: %s\n' % msg) + + +# The following three functions generate C code. The output of the compiled code will be +# parsed later on and then put back together into a dict structure by parse_c_output(). +# +# Example: +# c_descent('test1', code) +# c_set('item', 'i%i', '111', code) +# c_set('item2', 'i%i', '9', code) +# c_set('item3', 's%s', '"Hello"', code) +# c_ascent(code) +# c_set('outer', 'f%f', '0.999', code) +# +# Will result in: +# { +# 'test1': { +# 'item': 111, +# 'item2': 9, +# 'item3': 'Hello', +# }, +# 'outer': 0.999 +# } +def c_set(name, type_, value, code): + code.append('printf("K' + name + '\\n");') + code.append('printf("V' + type_ + '\\n", ' + value + ');') + + +def c_descent(name, code): + code.append('printf("D' + name + '\\n");') + + +def c_ascent(code): + code.append('printf("A\\n");') + + +def parse_c_output(lines): + result = {} + cur_level = result + parent = [] + key = None + + for line in lines: + arg = line[1:].strip() + if '::' in arg: + arg = arg.split('::', 1)[1] + if line[0] == 'K': + # This is a key + key = arg + elif line[0] == 'V': + # A value + if arg[0] == 'i': + arg = int(arg[1:]) + elif arg[0] == 'f': + arg = float(arg[1:]) + elif arg[0] == 's': + arg = arg[1:] + + cur_level[key] = arg + elif line[0] == 'D': + # Remember the current level as the last parent. + parent.append(cur_level) + + # We descend one level. + cur_level[arg] = {} + cur_level = cur_level[arg] + elif line[0] == 'A': + # We return to the parent dict. (One level up.) + cur_level = parent.pop() + + return result + + +def gen_inspect_code(path, struct, code): + if path[0][-1] == '#': + path[0] = path[0].rstrip('#') + prefix = '' + else: + prefix = 'struct ' + + c_descent(path[-1], code) + + if len(path) == 1: + c_set('__size__', 'i%zu', 'sizeof (' + prefix + path[0] + ')', code) + else: + c_set('__size__', 'i%zu', 'sizeof ((' + prefix + path[0] + ' *)0)->' + '.'.join(path[1:]), code) + # c_set('__offset__', 'i%zu', 'offsetof(' + prefix + path[0] + ', ' + '.'.join(path[1:]) + ')', code) + + for field in struct: + if isinstance(field, dict): + # We have to recurse to inspect the nested dict. + fname = list(field.keys())[0] + gen_inspect_code(path + [fname], field[fname], code) + else: + c_set(field, 'i%zu', 'offsetof(' + prefix + path[0] + ', ' + '.'.join(path[1:] + [field]) + ')', code) + + c_ascent(code) + + +def generate_c_code(headers): + code = ['#include ', '#include '] + + code.extend(f'''#include "{header['name']}"''' for header in headers) + + code.append('int main() {') + c_descent('structs', code) + for header in headers: + for name, struct in header['structs'].items(): + gen_inspect_code([name], struct, code) + + c_ascent(code) + c_descent('defines', code) + for header in headers: + for name, type_ in header['defines'].items(): + # Add the necessary python type, if missing. + if '%' not in type_: + if type_[-1] in {'d', 'i', 'u'}: + # integer + type_ = 'i%' + type_ + elif type_[-1] in {'f', 'F', 'e', 'E', 'g', 'G'}: + # float + type_ = 'f%' + type_ + elif type_[-1] in {'x', 'X', 'a', 'A', 'c', 's'}: + # hexadecimal or string + type_ = 's%' + type_ + + c_set(name, type_, name, code) + + code.append('return 0;') + code.append('}') + + return code + + +def generate_cmd(js_file_path, src_file_path, cflags): + # Compile the program. + show('Compiling generated code...') + + if any('libcxxabi' in f for f in cflags): + compiler = shared.EMXX + else: + compiler = shared.EMCC + + node_flags = building.get_emcc_node_flags(shared.check_node_version()) + + # -O1+ produces calls to iprintf, which libcompiler_rt doesn't support + cmd = [compiler] + cflags + ['-o', js_file_path, src_file_path, + '-O0', + '-Werror', + '-Wno-format', + '-nolibc', + '-sBOOTSTRAPPING_STRUCT_INFO', + '-sINCOMING_MODULE_JS_API=', + '-sSTRICT', + '-sSUPPORT_LONGJMP=0', + '-sASSERTIONS=0'] + node_flags + + # Default behavior for emcc is to warn for binaryen version check mismatches + # so we should try to match that behavior. + cmd += ['-Wno-error=version-check'] + + # TODO(sbc): Remove this one we remove the test_em_config_env_var test + cmd += ['-Wno-deprecated'] + + show(shared.shlex_join(cmd)) + return cmd + + +def inspect_headers(headers, cflags): + # Write the source code to a temporary file. + src_file_fd, src_file_path = tempfile.mkstemp('.c', text=True) + show('Generating C code... ' + src_file_path) + code = generate_c_code(headers) + os.write(src_file_fd, '\n'.join(code).encode()) + os.close(src_file_fd) + + js_file_fd, js_file_path = tempfile.mkstemp('.js') + # Close the unneeded FD. + os.close(js_file_fd) + + cmd = generate_cmd(js_file_path, src_file_path, cflags) + + try: + subprocess.check_call(cmd, env=system_libs.clean_env()) + except subprocess.CalledProcessError as e: + sys.stderr.write('FAIL: Compilation failed!: %s\n' % e.cmd) + sys.exit(1) + + # Run the compiled program. + show('Calling generated program... ' + js_file_path) + node_args = shared.node_bigint_flags(config.NODE_JS) + info = shared.run_js_tool(js_file_path, node_args=node_args, stdout=shared.PIPE).splitlines() + + if not DEBUG: + # Remove all temporary files. + os.unlink(src_file_path) + + if os.path.exists(js_file_path): + os.unlink(js_file_path) + wasm_file_path = shared.replace_suffix(js_file_path, '.wasm') + os.unlink(wasm_file_path) + + # Parse the output of the program into a dict. + return parse_c_output(info) + + +def merge_info(target, src): + for key, value in src['defines'].items(): + if key in target['defines']: + raise Exception('duplicate define: %s' % key) + target['defines'][key] = value + + for key, value in src['structs'].items(): + if key in target['structs']: + raise Exception('duplicate struct: %s' % key) + target['structs'][key] = value + + +def inspect_code(headers, cflags): + if not DEBUG: + info = inspect_headers(headers, cflags) + else: + info = {'defines': {}, 'structs': {}} + for header in headers: + merge_info(info, inspect_headers([header], cflags)) + return info + + +def parse_json(path): + header_files = [] + + with open(path, 'r') as stream: + # Remove comments before loading the JSON. + data = json.loads(re.sub(r'//.*\n', '', stream.read())) + + if not isinstance(data, list): + data = [data] + + for item in data: + for key in item: + if key not in ['file', 'defines', 'structs']: + raise 'Unexpected key in json file: %s' % key + + header = {'name': item['file'], 'structs': {}, 'defines': {}} + for name, data in item.get('structs', {}).items(): + if name in header['structs']: + show('WARN: Description of struct "' + name + '" in file "' + item['file'] + '" replaces an existing description!') + + header['structs'][name] = data + + for part in item.get('defines', []): + if not isinstance(part, list): + # If no type is specified, assume integer. + part = ['i', part] + + if part[1] in header['defines']: + show('WARN: Description of define "' + part[1] + '" in file "' + item['file'] + '" replaces an existing description!') + + header['defines'][part[1]] = part[0] + + header_files.append(header) + + return header_files + + +def output_json(obj, stream): + json.dump(obj, stream, indent=4, sort_keys=True) + stream.write('\n') + stream.close() + + +def main(args): + global QUIET + + parser = argparse.ArgumentParser(description='Generate JSON infos for structs.') + parser.add_argument('json', nargs='*', + help='JSON file with a list of structs and their fields (defaults to src/struct_info.json)', + default=DEFAULT_JSON_FILES) + parser.add_argument('-q', dest='quiet', action='store_true', default=False, + help='Don\'t output anything besides error messages.') + parser.add_argument('-o', dest='output', metavar='path', default=None, + help='Path to the JSON file that will be written. If omitted, the default location under `src` will be used.') + parser.add_argument('-I', dest='includes', metavar='dir', action='append', default=[], + help='Add directory to include search path') + parser.add_argument('-D', dest='defines', metavar='define', action='append', default=[], + help='Pass a define to the preprocessor') + parser.add_argument('-U', dest='undefines', metavar='undefine', action='append', default=[], + help='Pass an undefine to the preprocessor') + parser.add_argument('--wasm64', action='store_true', + help='use wasm64 architecture') + args = parser.parse_args(args) + + QUIET = args.quiet + + extra_cflags = [] + + if args.wasm64: + # Always use =2 here so that we don't generate a binary that actually requires + # memory64 to run. All we care about is that the output is correct. + extra_cflags += ['-sMEMORY64=2', '-Wno-experimental'] + + # Add the user options to the list as well. + for path in args.includes: + extra_cflags.append('-I' + path) + + for arg in args.defines: + extra_cflags.append('-D' + arg) + + for arg in args.undefines: + extra_cflags.append('-U' + arg) + + # Look for structs in all passed headers. + info = {'defines': {}, 'structs': {}} + + for f in args.json: + # This is a JSON file, parse it. + header_files = parse_json(f) + # Inspect all collected structs. + if 'internal' in f: + use_cflags = CFLAGS + extra_cflags + INTERNAL_CFLAGS + elif 'cxx' in f: + use_cflags = CFLAGS + extra_cflags + CXXFLAGS + else: + use_cflags = CFLAGS + extra_cflags + info_fragment = inspect_code(header_files, use_cflags) + merge_info(info, info_fragment) + + if args.output: + output_file = args.output + elif args.wasm64: + output_file = utils.path_from_root('src/struct_info_generated_wasm64.json') + else: + output_file = utils.path_from_root('src/struct_info_generated.json') + + with open(output_file, 'w') as f: + output_json(info, f) + + return 0 + + +if __name__ == '__main__': + sys.exit(main(sys.argv[1:])) diff --git a/js/build_webgpu.bat b/js/build_webgpu.bat new file mode 100644 index 0000000000..767f6b3444 --- /dev/null +++ b/js/build_webgpu.bat @@ -0,0 +1,78 @@ +@echo off + +rem build_webgpu.bat --- build onnxruntime-web with WebGPU EP +rem +rem Usage: +rem build_webgpu.bat config [clean] +rem +rem Options: +rem config Build configuration, "d" or "r" +rem clean Perform a clean build, "clean" or empty + +setlocal enabledelayedexpansion + +set ROOT=%~dp0..\ +set BUILD_DIR=%ROOT%build_webgpu + +:arg1 +if ["%~1"]==["d"] ( + set CONFIG=Debug + set CONFIG_EXTRA_FLAG=--enable_wasm_profiling --cmake_extra_defines onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL=1 + @rem --enable_wasm_debug_info + goto :arg2 +) +if ["%~1"]==["r"] ( + set CONFIG=Release + set CONFIG_EXTRA_FLAG=--enable_wasm_api_exception_catching --disable_rtti + goto :arg2 +) +echo Invalid configuration "%~1", must be "d"(Debug) or "r"(Release) +exit /b 1 + +:arg2 +if ["%~2"]==["clean"] ( + goto :clean +) +if not exist "%ROOT%js\web\dist" ( + goto :npm_ci +) + +goto :build_wasm + +:clean +if exist "%BUILD_DIR%" ( + rd /s /q %BUILD_DIR% +) + +pushd %ROOT% +git submodule sync --recursive +git submodule update --init --recursive +popd + +:npm_ci +pushd %ROOT%js +call npm ci +popd +pushd %ROOT%js\common +call npm ci +popd +pushd %ROOT%js\web +call npm ci +call npm run pull:wasm +popd + +:build_wasm + +set PATH=C:\Program Files\Git\usr\bin;%PATH% + +call %ROOT%build.bat --config %CONFIG% %CONFIG_EXTRA_FLAG% --skip_submodule_sync --build_wasm^ + --enable_wasm_simd --enable_wasm_threads --use_webgpu --build_dir %BUILD_DIR% + +@REM --target onnxruntime_webassembly + +IF NOT "%ERRORLEVEL%" == "0" ( + exit /b %ERRORLEVEL% +) + +@REM copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.jsep.wasm %ROOT%js\web\dist\ +@REM copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.jsep.mjs %ROOT%js\web\dist\ diff --git a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc index fe541f58d3..a1840257d7 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc @@ -116,7 +116,7 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo auto* output = context.Output(0, x_shape); auto* input_skip_bias_sum = context.Output(3, x_shape); - size_t data_size = x_shape.Size(); + int64_t data_size = x_shape.Size(); if (data_size == 0) { return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc index 8e27acdc28..51a06ad764 100644 --- a/onnxruntime/core/providers/webgpu/allocator.cc +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -1,10 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef __EMSCRIPTEN__ -#include -#endif - #include "core/framework/session_state.h" #include "core/providers/webgpu/allocator.h" #include "core/providers/webgpu/webgpu_context.h" diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index 233bb24083..1c2ff3f91a 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -85,7 +85,7 @@ class SimpleCacheManager : public IBufferCacheManager { void OnRefresh() override { for (auto& buffer : pending_buffers_) { - buffers_[wgpuBufferGetSize(buffer)].push_back(buffer); + buffers_[static_cast(wgpuBufferGetSize(buffer))].push_back(buffer); } pending_buffers_.clear(); } @@ -167,7 +167,7 @@ class BucketCacheManager : public IBufferCacheManager { // TODO: consider graph capture. currently not supported for (auto& buffer : pending_buffers_) { - auto buffer_size = wgpuBufferGetSize(buffer); + auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); auto it = buckets_.find(buffer_size); if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.h b/onnxruntime/core/providers/webgpu/buffer_manager.h index 00febfbc29..20bee52835 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.h +++ b/onnxruntime/core/providers/webgpu/buffer_manager.h @@ -5,10 +5,6 @@ #include -#ifdef __EMSCRIPTEN__ -#include -#endif - #include #include "core/framework/execution_provider.h" diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index b7ea8a58e2..991cfce2fc 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -3,10 +3,6 @@ #pragma once -#ifdef __EMSCRIPTEN__ -#include -#endif - #include #include diff --git a/onnxruntime/core/providers/webgpu/data_transfer.cc b/onnxruntime/core/providers/webgpu/data_transfer.cc index 615ae11175..ac376b4fce 100644 --- a/onnxruntime/core/providers/webgpu/data_transfer.cc +++ b/onnxruntime/core/providers/webgpu/data_transfer.cc @@ -1,10 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef __EMSCRIPTEN__ -#include -#endif - #include "core/providers/webgpu/data_transfer.h" #include "core/providers/webgpu/webgpu_context.h" diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index ec7fed5a94..75866513e2 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -123,10 +123,10 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const { if (a_last_dim_divisible_by_4 || b_last_dim_divisible_by_4) { vectorize = true; } else { - size_t shared_dimension = 1; + int64_t shared_dimension = 1; for (size_t i = 1; i < output_shape.NumDimensions(); i++) { - size_t dimA = lhs_shape.NumDimensions() >= i ? lhs_shape[lhs_shape.NumDimensions() - i] : 1; - size_t dimB = rhs_shape.NumDimensions() >= i ? rhs_shape[rhs_shape.NumDimensions() - i] : 1; + int64_t dimA = lhs_shape.NumDimensions() >= i ? lhs_shape[lhs_shape.NumDimensions() - i] : 1; + int64_t dimB = rhs_shape.NumDimensions() >= i ? rhs_shape[rhs_shape.NumDimensions() - i] : 1; if (dimA == dimB) { shared_dimension *= dimA; num_shared_dimension++; diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc index 1ee771e945..64172021e8 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -85,8 +85,7 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex auto* output = context.Output(0, x_shape); - size_t data_size = x_shape.Size(); - if (data_size == 0) { + if (x_shape.Size() == 0) { return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/program_manager.h b/onnxruntime/core/providers/webgpu/program_manager.h index eded1cfa17..5572177001 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.h +++ b/onnxruntime/core/providers/webgpu/program_manager.h @@ -3,10 +3,6 @@ #pragma once -#ifdef __EMSCRIPTEN__ -#include -#endif - #include #include diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index a4b96edc63..dac08f3bd9 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -5,10 +5,6 @@ #include -#ifdef __EMSCRIPTEN__ -#include -#endif - #include #include "core/framework/tensor_shape.h" diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index c708f24dcc..5ed8099fde 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -106,7 +106,8 @@ Status Concat::ComputeInternal(ComputeContext& context) const { uint32_t output_size = gsl::narrow_cast(prepare.output_tensor->Shape().Size()); - ConcatProgram program{prepare.axis}; + size_t axis = static_cast(prepare.axis); + ConcatProgram program{axis}; std::vector sizes_in_concat_axis; sizes_in_concat_axis.reserve(input_count); @@ -118,7 +119,7 @@ Status Concat::ComputeInternal(ComputeContext& context) const { } program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank}); - auto axis_size = input.tensor->Shape()[prepare.axis]; + auto axis_size = input.tensor->Shape()[axis]; sum += static_cast(axis_size); sizes_in_concat_axis.push_back(sum); } diff --git a/onnxruntime/core/providers/webgpu/tensor/flatten.h b/onnxruntime/core/providers/webgpu/tensor/flatten.h index 5fc49a844b..68f47eacd8 100644 --- a/onnxruntime/core/providers/webgpu/tensor/flatten.h +++ b/onnxruntime/core/providers/webgpu/tensor/flatten.h @@ -31,15 +31,8 @@ class Flatten final : public OpKernel { return Status(common::ONNXRUNTIME, common::FAIL, "Invalid value for axis, must be less than or equal to input_rank"); } - int64_t first_dim = 1; - for (int64_t i = 0; i < axis; i++) { - first_dim *= input_shape[i]; - } - - int64_t second_dim = 1; - for (int64_t i = axis; i < input_rank; i++) { - second_dim *= input_shape[i]; - } + int64_t first_dim = input_shape.SizeToDimension(static_cast(axis)); + int64_t second_dim = input_shape.SizeFromDimension(static_cast(axis)); TensorShape output_shape({first_dim, second_dim}); Tensor* output_tensor = context->Output(0, output_shape); @@ -59,4 +52,4 @@ class Flatten final : public OpKernel { }; } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/gather_elements.cc b/onnxruntime/core/providers/webgpu/tensor/gather_elements.cc index 00d8caf262..4938cd7acf 100644 --- a/onnxruntime/core/providers/webgpu/tensor/gather_elements.cc +++ b/onnxruntime/core/providers/webgpu/tensor/gather_elements.cc @@ -59,7 +59,7 @@ Status GatherElements::ComputeInternal(ComputeContext& context) const { axis += input_rank; } - auto axis_dim_limit = input_shape[axis]; + auto axis_dim_limit = input_shape[static_cast(axis)]; auto output_dims = indices_shape.AsShapeVector(); TensorShape output_shape(output_dims); @@ -83,4 +83,4 @@ Status GatherElements::ComputeInternal(ComputeContext& context) const { } } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 7e7ff7dbde..c6f84a8099 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -4,10 +4,12 @@ #include #include +#if !defined(__wasm__) #include "dawn/dawn_proc.h" #if !defined(USE_EXTERNAL_DAWN) #include "dawn/native/DawnNative.h" #endif +#endif #include "core/common/common.h" #include "core/common/path_string.h" @@ -29,7 +31,7 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi std::call_once(init_flag_, [this, &buffer_cache_config, backend_type]() { // Create wgpu::Adapter if (adapter_ == nullptr) { -#if !defined(__EMSCRIPTEN__) && defined(_MSC_VER) && defined(DAWN_ENABLE_D3D12) && !defined(USE_EXTERNAL_DAWN) +#if !defined(__wasm__) && defined(_MSC_VER) && defined(DAWN_ENABLE_D3D12) && !defined(USE_EXTERNAL_DAWN) // If we are using the D3D12 backend on Windows and the build does not use external Dawn, dxil.dll and dxcompiler.dll are required. // // Dawn will try to load them later, but if they are in the different directory to the executable, it may fail to find them. @@ -54,15 +56,19 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi #endif wgpu::RequestAdapterOptions req_adapter_options = {}; - wgpu::DawnTogglesDescriptor adapter_toggles_desc = {}; - req_adapter_options.nextInChain = &adapter_toggles_desc; req_adapter_options.backendType = static_cast(backend_type); req_adapter_options.powerPreference = wgpu::PowerPreference::HighPerformance; +#if !defined(__wasm__) auto enabled_adapter_toggles = GetEnabledAdapterToggles(); + + wgpu::DawnTogglesDescriptor adapter_toggles_desc = {}; adapter_toggles_desc.enabledToggleCount = enabled_adapter_toggles.size(); adapter_toggles_desc.enabledToggles = enabled_adapter_toggles.data(); + req_adapter_options.nextInChain = &adapter_toggles_desc; +#endif + ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(instance_.RequestAdapter( &req_adapter_options, wgpu::CallbackMode::WaitAnyOnly, @@ -78,6 +84,8 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi // Create wgpu::Device if (device_ == nullptr) { wgpu::DeviceDescriptor device_desc = {}; + +#if !defined(__wasm__) wgpu::DawnTogglesDescriptor device_toggles_desc = {}; device_desc.nextInChain = &device_toggles_desc; @@ -88,6 +96,7 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi auto disabled_device_toggles = GetDisabledDeviceToggles(); device_toggles_desc.disabledToggleCount = disabled_device_toggles.size(); device_toggles_desc.disabledToggles = disabled_device_toggles.data(); +#endif std::vector required_features = GetAvailableRequiredFeatures(adapter_); if (required_features.size() > 0) { @@ -136,9 +145,12 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi program_mgr_ = std::make_unique(Device(), DeviceLimits()); // set query type +#if !defined(__wasm__) if (device_.HasFeature(wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses)) { query_type_ = TimestampQueryType::InsidePasses; - } else if (device_.HasFeature(wgpu::FeatureName::TimestampQuery)) { + } else +#endif + if (device_.HasFeature(wgpu::FeatureName::TimestampQuery)) { query_type_ = TimestampQueryType::AtPasses; } else { query_type_ = TimestampQueryType::None; @@ -456,7 +468,9 @@ std::vector WebGpuContext::GetDisabledDeviceToggles() const { std::vector WebGpuContext::GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const { std::vector required_features; constexpr wgpu::FeatureName features[]{ +#if !defined(__wasm__) wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses, +#endif wgpu::FeatureName::TimestampQuery, wgpu::FeatureName::ShaderF16, wgpu::FeatureName::Subgroups, @@ -531,7 +545,7 @@ void WebGpuContext::CollectProfilingData(profiling::Events& events) { ORT_ENFORCE(Wait(query_read_buffer.MapAsync(wgpu::MapMode::Read, 0, - query_read_buffer.GetSize(), + static_cast(query_read_buffer.GetSize()), wgpu::CallbackMode::WaitAnyOnly, [](wgpu::MapAsyncStatus status, wgpu::StringView message) { ORT_ENFORCE(status == wgpu::MapAsyncStatus::Success, "Failed to download data from buffer: ", std::string_view{message}); @@ -658,8 +672,14 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr, "WebGPU EP default context (contextId=0) must not have custom WebGPU instance, adapter or device."); - std::call_once(init_default_flag_, [dawn_proc_table = config.dawn_proc_table]() { - // Step.1 - setup dawn proc table + std::call_once(init_default_flag_, [ +#if !defined(__wasm__) + dawn_proc_table = config.dawn_proc_table +#endif + ]() { + // Step.1 - setup dawn proc table (only for non-WASM build) + +#if !defined(__wasm__) const DawnProcTable* dawn_procs = reinterpret_cast(dawn_proc_table); #if defined(BUILD_DAWN_MONOLITHIC_LIBRARY) ORT_ENFORCE(dawn_procs == nullptr, "setting DawnProcTable is not allowed when dynamically linked to webgpu_dawn."); @@ -672,6 +692,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided."); #endif dawnProcSetProcs(dawn_procs); +#endif #endif // Step.2 - Create wgpu::Instance diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 4724118a29..80c8c64ce7 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -3,10 +3,6 @@ #pragma once -#ifdef __EMSCRIPTEN__ -#include -#endif - #include #include diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 76a55b7ce4..d12964c9a4 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -3,9 +3,6 @@ #include "core/providers/webgpu/webgpu_execution_provider.h" -#ifdef __EMSCRIPTEN__ -#include -#endif #include #include #include