mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
Fix wasm build with WebGPU EP
This commit is contained in:
parent
a74817ab10
commit
22bcd9a985
22 changed files with 701 additions and 146 deletions
124
cmake/external/onnxruntime_external_deps.cmake
vendored
124
cmake/external/onnxruntime_external_deps.cmake
vendored
|
|
@ -631,78 +631,92 @@ if (onnxruntime_USE_WEBGPU)
|
|||
URL_HASH SHA1=${DEP_SHA1_dawn}
|
||||
# All previous patches are merged into the upstream dawn project. We don't need to apply any patches right now.
|
||||
# if we need to apply patches in the future, we can uncomment the following line.
|
||||
# PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch
|
||||
PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/dawn/dawn.patch
|
||||
)
|
||||
endif()
|
||||
|
||||
if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY)
|
||||
set(DAWN_BUILD_MONOLITHIC_LIBRARY ON CACHE BOOL "" FORCE)
|
||||
set(DAWN_ENABLE_INSTALL ON CACHE BOOL "" FORCE)
|
||||
|
||||
if (onnxruntime_USE_EXTERNAL_DAWN)
|
||||
message(FATAL_ERROR "onnxruntime_USE_EXTERNAL_DAWN and onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY cannot be enabled at the same time.")
|
||||
endif()
|
||||
else()
|
||||
# use dawn::dawn_native and dawn::dawn_proc instead of the monolithic dawn::webgpu_dawn to minimize binary size
|
||||
set(DAWN_BUILD_MONOLITHIC_LIBRARY OFF CACHE BOOL "" FORCE)
|
||||
set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE)
|
||||
endif()
|
||||
set(DAWN_BUILD_SAMPLES OFF CACHE BOOL "" FORCE)
|
||||
set(DAWN_ENABLE_NULL OFF CACHE BOOL "" FORCE)
|
||||
set(DAWN_FETCH_DEPENDENCIES ON CACHE BOOL "" FORCE)
|
||||
|
||||
# disable things we don't use
|
||||
set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF)
|
||||
set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE)
|
||||
set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE)
|
||||
set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE)
|
||||
set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE)
|
||||
set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE)
|
||||
set(DAWN_USE_X11 OFF CACHE BOOL "" FORCE)
|
||||
if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
|
||||
set(DAWN_EMSCRIPTEN_TOOLCHAIN "${REPO_ROOT}/cmake/external/emsdk/upstream/emscripten" CACHE STRING "" FORCE)
|
||||
|
||||
set(TINT_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(TINT_BUILD_CMD_TOOLS OFF CACHE BOOL "" FORCE)
|
||||
set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE)
|
||||
set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE)
|
||||
set(TINT_BUILD_IR_BINARY OFF CACHE BOOL "" FORCE)
|
||||
set(TINT_BUILD_SPV_READER OFF CACHE BOOL "" FORCE) # don't need. disabling is a large binary size saving
|
||||
set(TINT_BUILD_WGSL_WRITER ON CACHE BOOL "" FORCE) # needed to create cache key. runtime error if not enabled.
|
||||
# file "${DAWN_EMSCRIPTEN_TOOLCHAIN}/tools/maint/gen_struct_info.py" is missing from emsdk installation.
|
||||
# we should copy it from ${PROJECT_SOURCE_DIR}/patches/
|
||||
file(COPY_FILE
|
||||
"${PROJECT_SOURCE_DIR}/patches/emscripten/tools/maint/gen_struct_info.py"
|
||||
"${DAWN_EMSCRIPTEN_TOOLCHAIN}/tools/maint/gen_struct_info.py"
|
||||
ONLY_IF_DIFFERENT)
|
||||
else()
|
||||
if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY)
|
||||
set(DAWN_BUILD_MONOLITHIC_LIBRARY ON CACHE BOOL "" FORCE)
|
||||
set(DAWN_ENABLE_INSTALL ON CACHE BOOL "" FORCE)
|
||||
|
||||
# SPIR-V validation shouldn't be required given we're using Tint to create the SPIR-V.
|
||||
set(DAWN_ENABLE_SPIRV_VALIDATION OFF CACHE BOOL "" FORCE)
|
||||
|
||||
if (WIN32)
|
||||
# building this requires the HLSL writer to be enabled in Tint. TBD whether we need either of these to be ON.
|
||||
set(DAWN_USE_BUILT_DXC ON CACHE BOOL "" FORCE)
|
||||
set(TINT_BUILD_HLSL_WRITER ON CACHE BOOL "" FORCE)
|
||||
|
||||
if ((NOT onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) AND (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12))
|
||||
message(FATAL_ERROR "At least one of onnxruntime_ENABLE_DAWN_BACKEND_VULKAN or onnxruntime_ENABLE_DAWN_BACKEND_D3D12 must be enabled when using Dawn on Windows.")
|
||||
endif()
|
||||
if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN)
|
||||
set(DAWN_ENABLE_VULKAN ON CACHE BOOL "" FORCE)
|
||||
set(TINT_BUILD_SPV_WRITER ON CACHE BOOL "" FORCE)
|
||||
if (onnxruntime_USE_EXTERNAL_DAWN)
|
||||
message(FATAL_ERROR "onnxruntime_USE_EXTERNAL_DAWN and onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY cannot be enabled at the same time.")
|
||||
endif()
|
||||
else()
|
||||
set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE)
|
||||
# use dawn::dawn_native and dawn::dawn_proc instead of the monolithic dawn::webgpu_dawn to minimize binary size
|
||||
set(DAWN_BUILD_MONOLITHIC_LIBRARY OFF CACHE BOOL "" FORCE)
|
||||
set(DAWN_ENABLE_INSTALL OFF CACHE BOOL "" FORCE)
|
||||
endif()
|
||||
if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12)
|
||||
set(DAWN_ENABLE_D3D12 ON CACHE BOOL "" FORCE)
|
||||
else()
|
||||
set(DAWN_ENABLE_D3D12 OFF CACHE BOOL "" FORCE)
|
||||
|
||||
# disable things we don't use
|
||||
set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF)
|
||||
set(DAWN_ENABLE_DESKTOP_GL OFF CACHE BOOL "" FORCE)
|
||||
set(DAWN_ENABLE_OPENGLES OFF CACHE BOOL "" FORCE)
|
||||
set(DAWN_SUPPORTS_GLFW_FOR_WINDOWING OFF CACHE BOOL "" FORCE)
|
||||
set(DAWN_USE_GLFW OFF CACHE BOOL "" FORCE)
|
||||
set(DAWN_USE_WINDOWS_UI OFF CACHE BOOL "" FORCE)
|
||||
set(DAWN_USE_X11 OFF CACHE BOOL "" FORCE)
|
||||
|
||||
set(TINT_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(TINT_BUILD_CMD_TOOLS OFF CACHE BOOL "" FORCE)
|
||||
set(TINT_BUILD_GLSL_WRITER OFF CACHE BOOL "" FORCE)
|
||||
set(TINT_BUILD_GLSL_VALIDATOR OFF CACHE BOOL "" FORCE)
|
||||
set(TINT_BUILD_IR_BINARY OFF CACHE BOOL "" FORCE)
|
||||
set(TINT_BUILD_SPV_READER OFF CACHE BOOL "" FORCE) # don't need. disabling is a large binary size saving
|
||||
set(TINT_BUILD_WGSL_WRITER ON CACHE BOOL "" FORCE) # needed to create cache key. runtime error if not enabled.
|
||||
|
||||
# SPIR-V validation shouldn't be required given we're using Tint to create the SPIR-V.
|
||||
set(DAWN_ENABLE_SPIRV_VALIDATION OFF CACHE BOOL "" FORCE)
|
||||
|
||||
if (WIN32)
|
||||
# building this requires the HLSL writer to be enabled in Tint. TBD if that we need either of these to be ON.
|
||||
set(DAWN_USE_BUILT_DXC ON CACHE BOOL "" FORCE)
|
||||
set(TINT_BUILD_HLSL_WRITER ON CACHE BOOL "" FORCE)
|
||||
|
||||
if ((NOT onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) AND (NOT onnxruntime_ENABLE_DAWN_BACKEND_D3D12))
|
||||
message(FATAL_ERROR "At least one of onnxruntime_ENABLE_DAWN_BACKEND_VULKAN or onnxruntime_ENABLE_DAWN_BACKEND_D3D12 must be enabled when using Dawn on Windows.")
|
||||
endif()
|
||||
if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN)
|
||||
set(DAWN_ENABLE_VULKAN ON CACHE BOOL "" FORCE)
|
||||
set(TINT_BUILD_SPV_WRITER ON CACHE BOOL "" FORCE)
|
||||
else()
|
||||
set(DAWN_ENABLE_VULKAN OFF CACHE BOOL "" FORCE)
|
||||
endif()
|
||||
if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12)
|
||||
set(DAWN_ENABLE_D3D12 ON CACHE BOOL "" FORCE)
|
||||
else()
|
||||
set(DAWN_ENABLE_D3D12 OFF CACHE BOOL "" FORCE)
|
||||
endif()
|
||||
# We are currently always using the D3D12 backend.
|
||||
set(DAWN_ENABLE_D3D11 OFF CACHE BOOL "" FORCE)
|
||||
endif()
|
||||
# We are currently always using the D3D12 backend.
|
||||
set(DAWN_ENABLE_D3D11 OFF CACHE BOOL "" FORCE)
|
||||
endif()
|
||||
|
||||
onnxruntime_fetchcontent_makeavailable(dawn)
|
||||
|
||||
if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY)
|
||||
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::webgpu_dawn)
|
||||
else()
|
||||
if (NOT onnxruntime_USE_EXTERNAL_DAWN)
|
||||
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native)
|
||||
if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
|
||||
if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY)
|
||||
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::webgpu_dawn)
|
||||
else()
|
||||
if (NOT onnxruntime_USE_EXTERNAL_DAWN)
|
||||
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_native)
|
||||
endif()
|
||||
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_proc)
|
||||
endif()
|
||||
list(APPEND onnxruntime_EXTERNAL_LIBRARIES dawn::dawn_proc)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
|
|
|||
|
|
@ -21,44 +21,57 @@
|
|||
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_webgpu_cc_srcs})
|
||||
onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs})
|
||||
onnxruntime_add_include_to_target(onnxruntime_providers_webgpu
|
||||
onnxruntime_common dawn::dawncpp_headers dawn::dawn_headers onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface)
|
||||
onnxruntime_common onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface)
|
||||
|
||||
set(onnxruntime_providers_webgpu_dll_deps)
|
||||
|
||||
if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY)
|
||||
target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn)
|
||||
|
||||
if (WIN32)
|
||||
if (onnxruntime_ENABLE_DELAY_LOADING_WIN_DLLS)
|
||||
list(APPEND onnxruntime_DELAYLOAD_FLAGS "/DELAYLOAD:webgpu_dawn.dll")
|
||||
endif()
|
||||
|
||||
list(APPEND onnxruntime_providers_webgpu_dll_deps "$<TARGET_FILE:dawn::webgpu_dawn>")
|
||||
if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
|
||||
# if "-s DISABLE_EXCEPTION_CATCHING=0" is globally set, we need to remove "-fno-exceptions" from emdawnwebgpu_c
|
||||
if (CMAKE_CXX_FLAGS MATCHES "DISABLE_EXCEPTION_CATCHING=0")
|
||||
get_property(EM_DAWN_WEBGPU_C_COMPILE_OPTIONS TARGET emdawnwebgpu_c PROPERTY COMPILE_OPTIONS)
|
||||
list(REMOVE_ITEM EM_DAWN_WEBGPU_C_COMPILE_OPTIONS "-fno-exceptions")
|
||||
set_property(TARGET emdawnwebgpu_c PROPERTY COMPILE_OPTIONS ${EM_DAWN_WEBGPU_C_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
target_link_libraries(onnxruntime_providers_webgpu PUBLIC emdawnwebgpu_cpp)
|
||||
else()
|
||||
if (NOT onnxruntime_USE_EXTERNAL_DAWN)
|
||||
target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native)
|
||||
onnxruntime_add_include_to_target(onnxruntime_providers_webgpu dawn::dawncpp_headers dawn::dawn_headers)
|
||||
|
||||
set(onnxruntime_providers_webgpu_dll_deps)
|
||||
|
||||
if (onnxruntime_BUILD_DAWN_MONOLITHIC_LIBRARY)
|
||||
target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn)
|
||||
|
||||
if (WIN32)
|
||||
if (onnxruntime_ENABLE_DELAY_LOADING_WIN_DLLS)
|
||||
list(APPEND onnxruntime_DELAYLOAD_FLAGS "/DELAYLOAD:webgpu_dawn.dll")
|
||||
endif()
|
||||
|
||||
list(APPEND onnxruntime_providers_webgpu_dll_deps "$<TARGET_FILE:dawn::webgpu_dawn>")
|
||||
endif()
|
||||
else()
|
||||
if (NOT onnxruntime_USE_EXTERNAL_DAWN)
|
||||
target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_native)
|
||||
endif()
|
||||
target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_proc)
|
||||
endif()
|
||||
target_link_libraries(onnxruntime_providers_webgpu dawn::dawn_proc)
|
||||
endif()
|
||||
|
||||
if (WIN32 AND onnxruntime_ENABLE_DAWN_BACKEND_D3D12)
|
||||
# Ensure dxil.dll and dxcompiler.dll exist in the output directory $<TARGET_FILE_DIR:dxcompiler>
|
||||
add_dependencies(onnxruntime_providers_webgpu copy_dxil_dll)
|
||||
add_dependencies(onnxruntime_providers_webgpu dxcompiler)
|
||||
if (WIN32 AND onnxruntime_ENABLE_DAWN_BACKEND_D3D12)
|
||||
# Ensure dxil.dll and dxcompiler.dll exist in the output directory $<TARGET_FILE_DIR:dxcompiler>
|
||||
add_dependencies(onnxruntime_providers_webgpu copy_dxil_dll)
|
||||
add_dependencies(onnxruntime_providers_webgpu dxcompiler)
|
||||
|
||||
list(APPEND onnxruntime_providers_webgpu_dll_deps "$<TARGET_FILE_DIR:dxcompiler>/dxil.dll")
|
||||
list(APPEND onnxruntime_providers_webgpu_dll_deps "$<TARGET_FILE_DIR:dxcompiler>/dxcompiler.dll")
|
||||
endif()
|
||||
list(APPEND onnxruntime_providers_webgpu_dll_deps "$<TARGET_FILE_DIR:dxcompiler>/dxil.dll")
|
||||
list(APPEND onnxruntime_providers_webgpu_dll_deps "$<TARGET_FILE_DIR:dxcompiler>/dxcompiler.dll")
|
||||
endif()
|
||||
|
||||
if (onnxruntime_providers_webgpu_dll_deps)
|
||||
# Copy dependency DLLs to the output directory
|
||||
add_custom_command(
|
||||
TARGET onnxruntime_providers_webgpu
|
||||
POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_if_different "${onnxruntime_providers_webgpu_dll_deps}" "$<TARGET_FILE_DIR:onnxruntime_providers_webgpu>"
|
||||
COMMAND_EXPAND_LISTS
|
||||
VERBATIM )
|
||||
if (onnxruntime_providers_webgpu_dll_deps)
|
||||
# Copy dependency DLLs to the output directory
|
||||
add_custom_command(
|
||||
TARGET onnxruntime_providers_webgpu
|
||||
POST_BUILD
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_if_different "${onnxruntime_providers_webgpu_dll_deps}" "$<TARGET_FILE_DIR:onnxruntime_providers_webgpu>"
|
||||
COMMAND_EXPAND_LISTS
|
||||
VERBATIM )
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_dependencies(onnxruntime_providers_webgpu ${onnxruntime_EXTERNAL_DEPENDENCIES})
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ if (onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB)
|
|||
${PROVIDERS_JS}
|
||||
${PROVIDERS_XNNPACK}
|
||||
${PROVIDERS_WEBNN}
|
||||
${PROVIDERS_WEBGPU}
|
||||
onnxruntime_session
|
||||
onnxruntime_util
|
||||
re2::re2
|
||||
|
|
@ -188,6 +189,7 @@ else()
|
|||
${PROVIDERS_JS}
|
||||
${PROVIDERS_XNNPACK}
|
||||
${PROVIDERS_WEBNN}
|
||||
${PROVIDERS_WEBGPU}
|
||||
onnxruntime_session
|
||||
onnxruntime_util
|
||||
re2::re2
|
||||
|
|
@ -386,6 +388,11 @@ jsepDownload:_pp_")
|
|||
set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js)
|
||||
endif()
|
||||
|
||||
if (onnxruntime_USE_WEBGPU)
|
||||
target_compile_definitions(onnxruntime_webassembly PRIVATE USE_WEBGPU=1)
|
||||
# target_link_libraries(onnxruntime_webassembly PUBLIC emdawnwebgpu_cpp)
|
||||
endif()
|
||||
|
||||
if (onnxruntime_EMSCRIPTEN_SETTINGS)
|
||||
foreach(setting IN LISTS onnxruntime_EMSCRIPTEN_SETTINGS)
|
||||
target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s ${setting}")
|
||||
|
|
|
|||
12
cmake/patches/dawn/dawn.patch
Normal file
12
cmake/patches/dawn/dawn.patch
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
diff --git a/src/emdawnwebgpu/CMakeLists.txt b/src/emdawnwebgpu/CMakeLists.txt
|
||||
index de673537d3..c98dc46de7 100644
|
||||
--- a/src/emdawnwebgpu/CMakeLists.txt
|
||||
+++ b/src/emdawnwebgpu/CMakeLists.txt
|
||||
@@ -78,6 +78,7 @@ if (${DAWN_ENABLE_EMSCRIPTEN})
|
||||
endif()
|
||||
|
||||
set(ARGS
|
||||
+ ${Python3_EXECUTABLE}
|
||||
"${DAWN_EMSCRIPTEN_TOOLCHAIN}/tools/maint/gen_struct_info.py"
|
||||
-q
|
||||
"${EM_BUILD_GEN_DIR}/struct_info_webgpu.json"
|
||||
448
cmake/patches/emscripten/tools/maint/gen_struct_info.py
Normal file
448
cmake/patches/emscripten/tools/maint/gen_struct_info.py
Normal file
|
|
@ -0,0 +1,448 @@
|
|||
#!/usr/bin/env python3
|
||||
# coding=utf-8
|
||||
# Copyright 2013 The Emscripten Authors. All rights reserved.
|
||||
# Emscripten is available under two separate licenses, the MIT license and the
|
||||
# University of Illinois/NCSA Open Source License. Both these licenses can be
|
||||
# found in the LICENSE file.
|
||||
|
||||
"""This tool extracts information about structs and defines from the C headers.
|
||||
|
||||
The JSON input format is as follows:
|
||||
[
|
||||
{
|
||||
'file': 'some/header.h',
|
||||
'structs': {
|
||||
'struct_name': [
|
||||
'field1',
|
||||
'field2',
|
||||
'field3',
|
||||
{
|
||||
'field4': [
|
||||
'nested1',
|
||||
'nested2',
|
||||
{
|
||||
'nested3': [
|
||||
'deep_nested1',
|
||||
...
|
||||
]
|
||||
}
|
||||
...
|
||||
]
|
||||
},
|
||||
'field5'
|
||||
],
|
||||
'other_struct': [
|
||||
'field1',
|
||||
'field2',
|
||||
...
|
||||
]
|
||||
},
|
||||
'defines': [
|
||||
'DEFINE_1',
|
||||
'DEFINE_2',
|
||||
['f', 'FLOAT_DEFINE'],
|
||||
'DEFINE_3',
|
||||
...
|
||||
]
|
||||
},
|
||||
{
|
||||
'file': 'some/other/header.h',
|
||||
...
|
||||
}
|
||||
]
|
||||
|
||||
Please note that the 'f' for 'FLOAT_DEFINE' is just the format passed to printf(), you can put
|
||||
anything printf() understands.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import argparse
|
||||
import tempfile
|
||||
import subprocess
|
||||
|
||||
__scriptdir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
__rootdir__ = os.path.dirname(os.path.dirname(__scriptdir__))
|
||||
sys.path.insert(0, __rootdir__)
|
||||
|
||||
from tools import building
|
||||
from tools import config
|
||||
from tools import shared
|
||||
from tools import system_libs
|
||||
from tools import utils
|
||||
|
||||
QUIET = (__name__ != '__main__')
|
||||
DEBUG = False
|
||||
|
||||
CFLAGS = [
|
||||
# Avoid parsing problems due to gcc specific syntax.
|
||||
'-D_GNU_SOURCE',
|
||||
]
|
||||
|
||||
INTERNAL_CFLAGS = [
|
||||
'-I' + utils.path_from_root('system/lib/libc/musl/src/internal'),
|
||||
'-I' + utils.path_from_root('system/lib/libc/musl/src/include'),
|
||||
'-I' + utils.path_from_root('system/lib/pthread/'),
|
||||
]
|
||||
|
||||
CXXFLAGS = [
|
||||
'-I' + utils.path_from_root('system/lib/libcxxabi/src'),
|
||||
'-D__EMSCRIPTEN_EXCEPTIONS__',
|
||||
'-I' + utils.path_from_root('system/lib/wasmfs/'),
|
||||
'-std=c++17',
|
||||
]
|
||||
|
||||
DEFAULT_JSON_FILES = [
|
||||
utils.path_from_root('src/struct_info.json'),
|
||||
utils.path_from_root('src/struct_info_internal.json'),
|
||||
utils.path_from_root('src/struct_info_cxx.json'),
|
||||
utils.path_from_root('src/struct_info_webgpu.json'),
|
||||
]
|
||||
|
||||
|
||||
def show(msg):
  """Write a progress message to stderr, prefixed with the tool name.

  Nothing is printed when the module-level QUIET flag is set, unless
  shared.DEBUG is truthy.
  """
  if QUIET and not shared.DEBUG:
    return
  sys.stderr.write('gen_struct_info: %s\n' % msg)
|
||||
|
||||
|
||||
# The following three functions generate C code. The output of the compiled code will be
|
||||
# parsed later on and then put back together into a dict structure by parse_c_output().
|
||||
#
|
||||
# Example:
|
||||
# c_descent('test1', code)
|
||||
# c_set('item', 'i%i', '111', code)
|
||||
# c_set('item2', 'i%i', '9', code)
|
||||
# c_set('item3', 's%s', '"Hello"', code)
|
||||
# c_ascent(code)
|
||||
# c_set('outer', 'f%f', '0.999', code)
|
||||
#
|
||||
# Will result in:
|
||||
# {
|
||||
# 'test1': {
|
||||
# 'item': 111,
|
||||
# 'item2': 9,
|
||||
# 'item3': 'Hello',
|
||||
# },
|
||||
# 'outer': 0.999
|
||||
# }
|
||||
def c_set(name, type_, value, code):
  """Append the C statements that print one key/value record.

  Two printf() calls are appended to `code`: a 'K' line carrying `name`,
  followed by a 'V' line whose printf format is `type_` and whose argument is
  the C expression `value`.
  """
  code.append(''.join(('printf("K', name, '\\n");')))
  code.append(''.join(('printf("V', type_, '\\n", ', value, ');')))
|
||||
|
||||
|
||||
def c_descent(name, code):
  """Append a printf() that emits a 'D' (descend) record named `name`."""
  descend_stmt = ''.join(('printf("D', name, '\\n");'))
  code.append(descend_stmt)
|
||||
|
||||
|
||||
def c_ascent(code):
  """Append a printf() that emits an 'A' (ascend) record."""
  ascend_stmt = 'printf("A\\n");'
  code.append(ascend_stmt)
|
||||
|
||||
|
||||
def parse_c_output(lines):
  """Rebuild a nested dict from the record lines printed by the generated C code.

  The first character of every line is a tag:
    'K' - remember the rest of the line as the key for the next value
    'V' - store a value under the remembered key; a leading 'i', 'f' or 's'
          selects int, float or string decoding, anything else is stored as is
    'D' - create a nested dict under the given name and descend into it
    'A' - return to the enclosing dict
  If the text after the tag contains '::', everything up to and including the
  first '::' is dropped before the line is interpreted.
  """
  decoders = {'i': int, 'f': float, 's': str}

  root = {}
  current = root
  ancestors = []
  pending_key = None

  for record in lines:
    tag = record[0]
    payload = record[1:].strip()
    _, separator, tail = payload.partition('::')
    if separator:
      payload = tail

    if tag == 'K':
      pending_key = payload
    elif tag == 'V':
      decode = decoders.get(payload[0])
      current[pending_key] = decode(payload[1:]) if decode else payload
    elif tag == 'D':
      # Remember where we came from, then step into a fresh child dict.
      ancestors.append(current)
      child = {}
      current[payload] = child
      current = child
    elif tag == 'A':
      # One level up.
      current = ancestors.pop()

  return root
|
||||
|
||||
|
||||
def gen_inspect_code(path, struct, code):
  """Append C code that reports the size and field offsets of one struct.

  `path` is a list: path[0] is the C type name and any further entries name
  the chain of nested members being inspected. If path[0] ends with '#', the
  '#' characters are stripped (path is modified in place) and the type is
  referenced without the 'struct ' keyword. `struct` is a list of field names;
  a dict entry describes a nested member and is handled recursively using its
  first key.
  """
  type_name = path[0]
  if type_name[-1] == '#':
    type_name = type_name.rstrip('#')
    path[0] = type_name
    prefix = ''
  else:
    prefix = 'struct '
  c_type = prefix + type_name

  c_descent(path[-1], code)

  members = path[1:]
  if members:
    size_expr = 'sizeof ((' + c_type + ' *)0)->' + '.'.join(members)
  else:
    size_expr = 'sizeof (' + c_type + ')'
  c_set('__size__', 'i%zu', size_expr, code)

  for field in struct:
    if not isinstance(field, dict):
      offset_expr = 'offsetof(' + c_type + ', ' + '.'.join(members + [field]) + ')'
      c_set(field, 'i%zu', offset_expr, code)
      continue
    # A dict describes a nested member: recurse into it.
    fname = list(field)[0]
    gen_inspect_code(path + [fname], field[fname], code)

  c_ascent(code)
|
||||
|
||||
|
||||
def generate_c_code(headers):
  """Build the lines of a C program that prints struct and define information.

  `headers` is a list of dicts with 'name', 'structs' and 'defines' entries.
  The program includes every header, then prints a 'structs' section followed
  by a 'defines' section in the record format produced by c_descent(),
  c_set() and c_ascent().
  """
  # printf conversion characters -> tag prepended to the format when the
  # define's type does not already contain a '%'.
  conversion_tags = (
      ('diu', 'i'),     # integer
      ('fFeEgG', 'f'),  # float
      ('xXaAcs', 's'),  # hexadecimal or string
  )

  code = ['#include <stdio.h>', '#include <stddef.h>']
  for header in headers:
    code.append('#include "{}"'.format(header['name']))

  code.append('int main() {')

  c_descent('structs', code)
  for header in headers:
    for struct_name, fields in header['structs'].items():
      gen_inspect_code([struct_name], fields, code)
  c_ascent(code)

  c_descent('defines', code)
  for header in headers:
    for define, type_ in header['defines'].items():
      if '%' not in type_:
        last = type_[-1]
        for conversions, tag in conversion_tags:
          if last in conversions:
            type_ = tag + '%' + type_
            break
      c_set(define, type_, define, code)

  code.append('return 0;')
  code.append('}')
  return code
|
||||
|
||||
|
||||
def generate_cmd(js_file_path, src_file_path, cflags):
  """Return the compiler command line that builds the generated C source.

  shared.EMXX is used when any entry of `cflags` mentions 'libcxxabi',
  otherwise shared.EMCC. The returned list compiles `src_file_path` into
  `js_file_path`.
  """
  show('Compiling generated code...')

  wants_cxx = any('libcxxabi' in flag for flag in cflags)
  compiler = shared.EMXX if wants_cxx else shared.EMCC

  node_flags = building.get_emcc_node_flags(shared.check_node_version())

  fixed_flags = [
      # -O1+ produces calls to iprintf, which libcompiler_rt doesn't support
      '-O0',
      '-Werror',
      '-Wno-format',
      '-nolibc',
      '-sBOOTSTRAPPING_STRUCT_INFO',
      '-sINCOMING_MODULE_JS_API=',
      '-sSTRICT',
      '-sSUPPORT_LONGJMP=0',
      '-sASSERTIONS=0',
  ]
  cmd = [compiler] + cflags + ['-o', js_file_path, src_file_path] + fixed_flags + node_flags

  # Default behavior for emcc is to warn for binaryen version check mismatches
  # so we should try to match that behavior.
  cmd.append('-Wno-error=version-check')

  # TODO(sbc): Remove this once we remove the test_em_config_env_var test
  cmd.append('-Wno-deprecated')

  show(shared.shlex_join(cmd))
  return cmd
|
||||
|
||||
|
||||
def inspect_headers(headers, cflags):
  """Compile and run a generated C program, returning the parsed result dict.

  The C source produced by generate_c_code() is written to a temporary file,
  compiled with the command from generate_cmd(), and the resulting JS file is
  run through shared.run_js_tool(). Its output lines are decoded by
  parse_c_output(). The process exits with status 1 if compilation fails.
  Temporary files are removed unless the module-level DEBUG flag is set.
  """
  # Write the generated program to a temporary .c file.
  c_fd, c_path = tempfile.mkstemp('.c', text=True)
  show('Generating C code... ' + c_path)
  source_text = '\n'.join(generate_c_code(headers))
  os.write(c_fd, source_text.encode())
  os.close(c_fd)

  # Only the output path is needed, not the descriptor.
  out_fd, out_path = tempfile.mkstemp('.js')
  os.close(out_fd)

  cmd = generate_cmd(out_path, c_path, cflags)

  try:
    subprocess.check_call(cmd, env=system_libs.clean_env())
  except subprocess.CalledProcessError as err:
    sys.stderr.write('FAIL: Compilation failed!: %s\n' % err.cmd)
    sys.exit(1)

  # Run the compiled program and collect its output lines.
  show('Calling generated program... ' + out_path)
  node_args = shared.node_bigint_flags(config.NODE_JS)
  output = shared.run_js_tool(out_path, node_args=node_args, stdout=shared.PIPE)
  info = output.splitlines()

  if not DEBUG:
    # Remove all temporary files.
    os.unlink(c_path)

    if os.path.exists(out_path):
      os.unlink(out_path)
      os.unlink(shared.replace_suffix(out_path, '.wasm'))

  # Parse the output of the program into a dict.
  return parse_c_output(info)
|
||||
|
||||
|
||||
def merge_info(target, src):
  """Copy the 'defines' and then the 'structs' entries of `src` into `target`.

  `target` is updated in place. An Exception is raised as soon as a key is
  found that already exists in the corresponding section of `target`; entries
  copied before that point remain in `target`.
  """
  sections = (
      ('defines', 'duplicate define: %s'),
      ('structs', 'duplicate struct: %s'),
  )
  for section, duplicate_msg in sections:
    for key, value in src[section].items():
      bucket = target[section]
      if key in bucket:
        raise Exception(duplicate_msg % key)
      bucket[key] = value
|
||||
|
||||
|
||||
def inspect_code(headers, cflags):
  """Inspect `headers` and return the combined defines/structs dict.

  Normally all headers are handled by one inspect_headers() call. When the
  module-level DEBUG flag is set, each header is inspected on its own and the
  results are combined with merge_info().
  """
  if DEBUG:
    combined = {'defines': {}, 'structs': {}}
    for header in headers:
      merge_info(combined, inspect_headers([header], cflags))
    return combined

  return inspect_headers(headers, cflags)
|
||||
|
||||
|
||||
def parse_json(path):
  """Load a struct description file and normalize it into header records.

  The file at `path` is JSON in which text from '//' to the end of a line is
  removed before parsing. The top level may be a single object or a list of
  objects; each object may only contain the keys 'file', 'structs' and
  'defines'.

  Returns a list with one dict per object:
    {'name': <value of 'file'>,
     'structs': {struct name: field list},
     'defines': {define name: printf type}}
  A define given as a bare name gets the type 'i' (integer); a define given as
  a [type, name] list keeps its type.

  Raises ValueError if an object contains any other key.
  """
  with open(path, 'r') as stream:
    # Remove comments before loading the JSON.
    data = json.loads(re.sub(r'//.*\n', '', stream.read()))

  if not isinstance(data, list):
    data = [data]

  header_files = []
  for item in data:
    for key in item:
      if key not in ('file', 'defines', 'structs'):
        # Raising a plain string is itself a TypeError in Python 3, so raise
        # a real exception that carries the message.
        raise ValueError('Unexpected key in json file: %s' % key)

    header = {'name': item['file'], 'structs': {}, 'defines': {}}

    for name, fields in item.get('structs', {}).items():
      if name in header['structs']:
        show('WARN: Description of struct "' + name + '" in file "' + item['file'] + '" replaces an existing description!')

      header['structs'][name] = fields

    for part in item.get('defines', []):
      if not isinstance(part, list):
        # If no type is specified, assume integer.
        part = ['i', part]

      if part[1] in header['defines']:
        show('WARN: Description of define "' + part[1] + '" in file "' + item['file'] + '" replaces an existing description!')

      header['defines'][part[1]] = part[0]

    header_files.append(header)

  return header_files
|
||||
|
||||
|
||||
def output_json(obj, stream):
  """Serialize `obj` to `stream` as sorted, 4-space-indented JSON.

  A trailing newline is written after the JSON text and the stream is then
  closed.
  """
  json.dump(obj, stream, sort_keys=True, indent=4)
  print(file=stream)
  stream.close()
|
||||
|
||||
|
||||
def main(args):
  """Command-line entry point: inspect the given JSON files and write the result.

  `args` is the argument list without the program name. Each JSON file is
  parsed, its headers are inspected, and the merged information is written as
  JSON to the path given with -o, or to a default file under `src` otherwise.
  Sets the module-level QUIET flag from -q. Returns 0.
  """
  global QUIET

  parser = argparse.ArgumentParser(description='Generate JSON infos for structs.')
  parser.add_argument(
      'json', nargs='*',
      help='JSON file with a list of structs and their fields (defaults to src/struct_info.json)',
      default=DEFAULT_JSON_FILES)
  parser.add_argument(
      '-q', dest='quiet', action='store_true', default=False,
      help='Don\'t output anything besides error messages.')
  parser.add_argument(
      '-o', dest='output', metavar='path', default=None,
      help='Path to the JSON file that will be written. If omitted, the default location under `src` will be used.')
  parser.add_argument(
      '-I', dest='includes', metavar='dir', action='append', default=[],
      help='Add directory to include search path')
  parser.add_argument(
      '-D', dest='defines', metavar='define', action='append', default=[],
      help='Pass a define to the preprocessor')
  parser.add_argument(
      '-U', dest='undefines', metavar='undefine', action='append', default=[],
      help='Pass an undefine to the preprocessor')
  parser.add_argument(
      '--wasm64', action='store_true',
      help='use wasm64 architecture')
  opts = parser.parse_args(args)

  QUIET = opts.quiet

  extra_cflags = []
  if opts.wasm64:
    # Always use =2 here so that we don't generate a binary that actually requires
    # memory64 to run. All we care about is that the output is correct.
    extra_cflags += ['-sMEMORY64=2', '-Wno-experimental']

  # Forward the user's include/define/undefine options to the compiler.
  extra_cflags += ['-I' + directory for directory in opts.includes]
  extra_cflags += ['-D' + define for define in opts.defines]
  extra_cflags += ['-U' + undefine for undefine in opts.undefines]

  # Look for structs in all passed headers.
  info = {'defines': {}, 'structs': {}}
  for json_path in opts.json:
    header_files = parse_json(json_path)

    # The extra flag set is chosen by substring match on the file path.
    use_cflags = CFLAGS + extra_cflags
    if 'internal' in json_path:
      use_cflags = use_cflags + INTERNAL_CFLAGS
    elif 'cxx' in json_path:
      use_cflags = use_cflags + CXXFLAGS

    merge_info(info, inspect_code(header_files, use_cflags))

  if opts.output:
    output_file = opts.output
  else:
    default_name = 'src/struct_info_generated_wasm64.json' if opts.wasm64 else 'src/struct_info_generated.json'
    output_file = utils.path_from_root(default_name)

  with open(output_file, 'w') as out:
    output_json(info, out)

  return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main(sys.argv[1:]))
|
||||
78
js/build_webgpu.bat
Normal file
78
js/build_webgpu.bat
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
@echo off

rem build_webgpu.bat --- build onnxruntime-web with WebGPU EP
rem
rem Usage:
rem   build_webgpu.bat config [clean]
rem
rem Options:
rem   config      Build configuration, "d" or "r"
rem   clean       Perform a clean build, "clean" or empty
rem
rem Paths derived from this script's location are quoted wherever they are
rem used, so the script also works from a checkout whose path contains spaces.

setlocal enabledelayedexpansion

set ROOT=%~dp0..\
set BUILD_DIR=%ROOT%build_webgpu

:arg1
if ["%~1"]==["d"] (
    set CONFIG=Debug
    set CONFIG_EXTRA_FLAG=--enable_wasm_profiling --cmake_extra_defines onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL=1
    @rem --enable_wasm_debug_info
    goto :arg2
)
if ["%~1"]==["r"] (
    set CONFIG=Release
    set CONFIG_EXTRA_FLAG=--enable_wasm_api_exception_catching --disable_rtti
    goto :arg2
)
echo Invalid configuration "%~1", must be "d"(Debug) or "r"(Release)
exit /b 1

:arg2
if ["%~2"]==["clean"] (
    goto :clean
)
rem Without an existing js\web\dist directory, run the npm steps first.
if not exist "%ROOT%js\web\dist" (
    goto :npm_ci
)

goto :build_wasm

:clean
if exist "%BUILD_DIR%" (
    rd /s /q "%BUILD_DIR%"
)

pushd "%ROOT%"
git submodule sync --recursive
git submodule update --init --recursive
popd

:npm_ci
pushd "%ROOT%js"
call npm ci
popd
pushd "%ROOT%js\common"
call npm ci
popd
pushd "%ROOT%js\web"
call npm ci
call npm run pull:wasm
popd

:build_wasm

rem Put the Git for Windows tools directory first on PATH for the build.
set PATH=C:\Program Files\Git\usr\bin;%PATH%

call "%ROOT%build.bat" --config %CONFIG% %CONFIG_EXTRA_FLAG% --skip_submodule_sync --build_wasm ^
    --enable_wasm_simd --enable_wasm_threads --use_webgpu --build_dir "%BUILD_DIR%"

@REM --target onnxruntime_webassembly

IF NOT "%ERRORLEVEL%" == "0" (
    exit /b %ERRORLEVEL%
)

@REM copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.jsep.wasm %ROOT%js\web\dist\
@REM copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.jsep.mjs %ROOT%js\web\dist\
|
||||
|
|
@ -116,7 +116,7 @@ Status SkipLayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeCo
|
|||
auto* output = context.Output(0, x_shape);
|
||||
auto* input_skip_bias_sum = context.Output(3, x_shape);
|
||||
|
||||
size_t data_size = x_shape.Size();
|
||||
int64_t data_size = x_shape.Size();
|
||||
if (data_size == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
#include <emscripten.h>
|
||||
#endif
|
||||
|
||||
#include "core/framework/session_state.h"
|
||||
#include "core/providers/webgpu/allocator.h"
|
||||
#include "core/providers/webgpu/webgpu_context.h"
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ class SimpleCacheManager : public IBufferCacheManager {
|
|||
|
||||
void OnRefresh() override {
|
||||
for (auto& buffer : pending_buffers_) {
|
||||
buffers_[wgpuBufferGetSize(buffer)].push_back(buffer);
|
||||
buffers_[static_cast<size_t>(wgpuBufferGetSize(buffer))].push_back(buffer);
|
||||
}
|
||||
pending_buffers_.clear();
|
||||
}
|
||||
|
|
@ -167,7 +167,7 @@ class BucketCacheManager : public IBufferCacheManager {
|
|||
// TODO: consider graph capture. currently not supported
|
||||
|
||||
for (auto& buffer : pending_buffers_) {
|
||||
auto buffer_size = wgpuBufferGetSize(buffer);
|
||||
auto buffer_size = static_cast<size_t>(wgpuBufferGetSize(buffer));
|
||||
|
||||
auto it = buckets_.find(buffer_size);
|
||||
if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) {
|
||||
|
|
|
|||
|
|
@ -5,10 +5,6 @@
|
|||
|
||||
#include <iosfwd>
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
#include <emscripten/emscripten.h>
|
||||
#endif
|
||||
|
||||
#include <webgpu/webgpu_cpp.h>
|
||||
|
||||
#include "core/framework/execution_provider.h"
|
||||
|
|
|
|||
|
|
@ -3,10 +3,6 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
#include <emscripten/emscripten.h>
|
||||
#endif
|
||||
|
||||
#include <webgpu/webgpu_cpp.h>
|
||||
|
||||
#include <utility>
|
||||
|
|
|
|||
|
|
@ -1,10 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
#include <emscripten.h>
|
||||
#endif
|
||||
|
||||
#include "core/providers/webgpu/data_transfer.h"
|
||||
#include "core/providers/webgpu/webgpu_context.h"
|
||||
|
||||
|
|
|
|||
|
|
@ -123,10 +123,10 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
|
|||
if (a_last_dim_divisible_by_4 || b_last_dim_divisible_by_4) {
|
||||
vectorize = true;
|
||||
} else {
|
||||
size_t shared_dimension = 1;
|
||||
int64_t shared_dimension = 1;
|
||||
for (size_t i = 1; i < output_shape.NumDimensions(); i++) {
|
||||
size_t dimA = lhs_shape.NumDimensions() >= i ? lhs_shape[lhs_shape.NumDimensions() - i] : 1;
|
||||
size_t dimB = rhs_shape.NumDimensions() >= i ? rhs_shape[rhs_shape.NumDimensions() - i] : 1;
|
||||
int64_t dimA = lhs_shape.NumDimensions() >= i ? lhs_shape[lhs_shape.NumDimensions() - i] : 1;
|
||||
int64_t dimB = rhs_shape.NumDimensions() >= i ? rhs_shape[rhs_shape.NumDimensions() - i] : 1;
|
||||
if (dimA == dimB) {
|
||||
shared_dimension *= dimA;
|
||||
num_shared_dimension++;
|
||||
|
|
|
|||
|
|
@ -85,8 +85,7 @@ Status LayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeContex
|
|||
|
||||
auto* output = context.Output(0, x_shape);
|
||||
|
||||
size_t data_size = x_shape.Size();
|
||||
if (data_size == 0) {
|
||||
if (x_shape.Size() == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,10 +3,6 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
#include <emscripten/emscripten.h>
|
||||
#endif
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
|
|
|
|||
|
|
@ -5,10 +5,6 @@
|
|||
|
||||
#include <sstream>
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
#include <emscripten/emscripten.h>
|
||||
#endif
|
||||
|
||||
#include <webgpu/webgpu_cpp.h>
|
||||
|
||||
#include "core/framework/tensor_shape.h"
|
||||
|
|
|
|||
|
|
@ -106,7 +106,8 @@ Status Concat::ComputeInternal(ComputeContext& context) const {
|
|||
|
||||
uint32_t output_size = gsl::narrow_cast<int32_t>(prepare.output_tensor->Shape().Size());
|
||||
|
||||
ConcatProgram program{prepare.axis};
|
||||
size_t axis = static_cast<size_t>(prepare.axis);
|
||||
ConcatProgram program{axis};
|
||||
|
||||
std::vector<uint32_t> sizes_in_concat_axis;
|
||||
sizes_in_concat_axis.reserve(input_count);
|
||||
|
|
@ -118,7 +119,7 @@ Status Concat::ComputeInternal(ComputeContext& context) const {
|
|||
}
|
||||
program.AddInput({input.tensor, ProgramTensorMetadataDependency::TypeAndRank});
|
||||
|
||||
auto axis_size = input.tensor->Shape()[prepare.axis];
|
||||
auto axis_size = input.tensor->Shape()[axis];
|
||||
sum += static_cast<uint32_t>(axis_size);
|
||||
sizes_in_concat_axis.push_back(sum);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,15 +31,8 @@ class Flatten final : public OpKernel {
|
|||
return Status(common::ONNXRUNTIME, common::FAIL, "Invalid value for axis, must be less than or equal to input_rank");
|
||||
}
|
||||
|
||||
int64_t first_dim = 1;
|
||||
for (int64_t i = 0; i < axis; i++) {
|
||||
first_dim *= input_shape[i];
|
||||
}
|
||||
|
||||
int64_t second_dim = 1;
|
||||
for (int64_t i = axis; i < input_rank; i++) {
|
||||
second_dim *= input_shape[i];
|
||||
}
|
||||
int64_t first_dim = input_shape.SizeToDimension(static_cast<size_t>(axis));
|
||||
int64_t second_dim = input_shape.SizeFromDimension(static_cast<size_t>(axis));
|
||||
|
||||
TensorShape output_shape({first_dim, second_dim});
|
||||
Tensor* output_tensor = context->Output(0, output_shape);
|
||||
|
|
@ -59,4 +52,4 @@ class Flatten final : public OpKernel {
|
|||
};
|
||||
|
||||
} // namespace webgpu
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ Status GatherElements::ComputeInternal(ComputeContext& context) const {
|
|||
axis += input_rank;
|
||||
}
|
||||
|
||||
auto axis_dim_limit = input_shape[axis];
|
||||
auto axis_dim_limit = input_shape[static_cast<size_t>(axis)];
|
||||
|
||||
auto output_dims = indices_shape.AsShapeVector();
|
||||
TensorShape output_shape(output_dims);
|
||||
|
|
@ -83,4 +83,4 @@ Status GatherElements::ComputeInternal(ComputeContext& context) const {
|
|||
}
|
||||
|
||||
} // namespace webgpu
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@
|
|||
#include <memory>
|
||||
#include <cmath>
|
||||
|
||||
#if !defined(__wasm__)
|
||||
#include "dawn/dawn_proc.h"
|
||||
#if !defined(USE_EXTERNAL_DAWN)
|
||||
#include "dawn/native/DawnNative.h"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/common/path_string.h"
|
||||
|
|
@ -29,7 +31,7 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
|
|||
std::call_once(init_flag_, [this, &buffer_cache_config, backend_type]() {
|
||||
// Create wgpu::Adapter
|
||||
if (adapter_ == nullptr) {
|
||||
#if !defined(__EMSCRIPTEN__) && defined(_MSC_VER) && defined(DAWN_ENABLE_D3D12) && !defined(USE_EXTERNAL_DAWN)
|
||||
#if !defined(__wasm__) && defined(_MSC_VER) && defined(DAWN_ENABLE_D3D12) && !defined(USE_EXTERNAL_DAWN)
|
||||
// If we are using the D3D12 backend on Windows and the build does not use external Dawn, dxil.dll and dxcompiler.dll are required.
|
||||
//
|
||||
// Dawn will try to load them later, but if they are in a different directory from the executable, it may fail to find them.
|
||||
|
|
@ -54,15 +56,19 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
|
|||
#endif
|
||||
|
||||
wgpu::RequestAdapterOptions req_adapter_options = {};
|
||||
wgpu::DawnTogglesDescriptor adapter_toggles_desc = {};
|
||||
req_adapter_options.nextInChain = &adapter_toggles_desc;
|
||||
req_adapter_options.backendType = static_cast<wgpu::BackendType>(backend_type);
|
||||
req_adapter_options.powerPreference = wgpu::PowerPreference::HighPerformance;
|
||||
|
||||
#if !defined(__wasm__)
|
||||
auto enabled_adapter_toggles = GetEnabledAdapterToggles();
|
||||
|
||||
wgpu::DawnTogglesDescriptor adapter_toggles_desc = {};
|
||||
adapter_toggles_desc.enabledToggleCount = enabled_adapter_toggles.size();
|
||||
adapter_toggles_desc.enabledToggles = enabled_adapter_toggles.data();
|
||||
|
||||
req_adapter_options.nextInChain = &adapter_toggles_desc;
|
||||
#endif
|
||||
|
||||
ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(instance_.RequestAdapter(
|
||||
&req_adapter_options,
|
||||
wgpu::CallbackMode::WaitAnyOnly,
|
||||
|
|
@ -78,6 +84,8 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
|
|||
// Create wgpu::Device
|
||||
if (device_ == nullptr) {
|
||||
wgpu::DeviceDescriptor device_desc = {};
|
||||
|
||||
#if !defined(__wasm__)
|
||||
wgpu::DawnTogglesDescriptor device_toggles_desc = {};
|
||||
device_desc.nextInChain = &device_toggles_desc;
|
||||
|
||||
|
|
@ -88,6 +96,7 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
|
|||
auto disabled_device_toggles = GetDisabledDeviceToggles();
|
||||
device_toggles_desc.disabledToggleCount = disabled_device_toggles.size();
|
||||
device_toggles_desc.disabledToggles = disabled_device_toggles.data();
|
||||
#endif
|
||||
|
||||
std::vector<wgpu::FeatureName> required_features = GetAvailableRequiredFeatures(adapter_);
|
||||
if (required_features.size() > 0) {
|
||||
|
|
@ -136,9 +145,12 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi
|
|||
program_mgr_ = std::make_unique<ProgramManager>(Device(), DeviceLimits());
|
||||
|
||||
// set query type
|
||||
#if !defined(__wasm__)
|
||||
if (device_.HasFeature(wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses)) {
|
||||
query_type_ = TimestampQueryType::InsidePasses;
|
||||
} else if (device_.HasFeature(wgpu::FeatureName::TimestampQuery)) {
|
||||
} else
|
||||
#endif
|
||||
if (device_.HasFeature(wgpu::FeatureName::TimestampQuery)) {
|
||||
query_type_ = TimestampQueryType::AtPasses;
|
||||
} else {
|
||||
query_type_ = TimestampQueryType::None;
|
||||
|
|
@ -456,7 +468,9 @@ std::vector<const char*> WebGpuContext::GetDisabledDeviceToggles() const {
|
|||
std::vector<wgpu::FeatureName> WebGpuContext::GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const {
|
||||
std::vector<wgpu::FeatureName> required_features;
|
||||
constexpr wgpu::FeatureName features[]{
|
||||
#if !defined(__wasm__)
|
||||
wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses,
|
||||
#endif
|
||||
wgpu::FeatureName::TimestampQuery,
|
||||
wgpu::FeatureName::ShaderF16,
|
||||
wgpu::FeatureName::Subgroups,
|
||||
|
|
@ -531,7 +545,7 @@ void WebGpuContext::CollectProfilingData(profiling::Events& events) {
|
|||
|
||||
ORT_ENFORCE(Wait(query_read_buffer.MapAsync(wgpu::MapMode::Read,
|
||||
0,
|
||||
query_read_buffer.GetSize(),
|
||||
static_cast<size_t>(query_read_buffer.GetSize()),
|
||||
wgpu::CallbackMode::WaitAnyOnly,
|
||||
[](wgpu::MapAsyncStatus status, wgpu::StringView message) {
|
||||
ORT_ENFORCE(status == wgpu::MapAsyncStatus::Success, "Failed to download data from buffer: ", std::string_view{message});
|
||||
|
|
@ -658,8 +672,14 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
|
|||
ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr,
|
||||
"WebGPU EP default context (contextId=0) must not have custom WebGPU instance, adapter or device.");
|
||||
|
||||
std::call_once(init_default_flag_, [dawn_proc_table = config.dawn_proc_table]() {
|
||||
// Step.1 - setup dawn proc table
|
||||
std::call_once(init_default_flag_, [
|
||||
#if !defined(__wasm__)
|
||||
dawn_proc_table = config.dawn_proc_table
|
||||
#endif
|
||||
]() {
|
||||
// Step.1 - setup dawn proc table (only for non-WASM build)
|
||||
|
||||
#if !defined(__wasm__)
|
||||
const DawnProcTable* dawn_procs = reinterpret_cast<const DawnProcTable*>(dawn_proc_table);
|
||||
#if defined(BUILD_DAWN_MONOLITHIC_LIBRARY)
|
||||
ORT_ENFORCE(dawn_procs == nullptr, "setting DawnProcTable is not allowed when dynamically linked to webgpu_dawn.");
|
||||
|
|
@ -672,6 +692,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co
|
|||
ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided.");
|
||||
#endif
|
||||
dawnProcSetProcs(dawn_procs);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Step.2 - Create wgpu::Instance
|
||||
|
|
|
|||
|
|
@ -3,10 +3,6 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
#include <emscripten/emscripten.h>
|
||||
#endif
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,6 @@
|
|||
|
||||
#include "core/providers/webgpu/webgpu_execution_provider.h"
|
||||
|
||||
#ifdef __EMSCRIPTEN__
|
||||
#include <emscripten.h>
|
||||
#endif
|
||||
#include <string_view>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
|
|
|||
Loading…
Reference in a new issue