update xnnpack to newer version and update API usage in pytorch (#94330)

Summary:
Update XNNPACK to 51a987591a6fc9f0fc0707077f53d763ac132cbf (51a987591a)

Update the corresponding CMake and BUCK rules, as well as the generate_wrapper.py for the new version.

XNNPACK has changed substantially upstream, so we need to update it now, for several reasons. First, the upstream developer community has refactored the code, including API changes; their CMakeLists.txt alone shows how much has changed. To keep following upstream we need to update XNNPACK at this time, which is crucial for our future development. Many projects also rely on newer versions of XNNPACK, so the third-party library should be updated now. Because some XNNPACK APIs changed, we update the corresponding call sites as well. We also update the build target files and the generate-wrapper.py file to make this process more automatic. The original target files were missing some source files, so we add them to the buck2 build files so that XNNPACK builds and tests successfully.

Test Plan:
buck2 build //xplat/third-party/XNNPACK:operators
buck2 build //xplat/third-party/XNNPACK:XNNPACK
buck2 test fbcode//caffe2/test:xnnpack_integration

Reviewed By: digantdesai

Differential Revision: D43092938

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94330
Approved by: https://github.com/digantdesai, https://github.com/albanD
This commit is contained in:
Cuiqing Li 2023-02-11 08:59:35 +00:00 committed by PyTorch MergeBot
parent e7a8af9376
commit 9dd7e83676
8 changed files with 13798 additions and 1602 deletions

View file

@ -99,6 +99,7 @@ enum xnn_status xnnp_create_convolution2d_nhwc(
op_min, /* int8_t output_min */
op_max, /* int8_t output_max */
flags, /* uint32_t flags */
nullptr, /* xnn_caches_t caches */
op); /* xnn_operator_t* deconvolution_op_out */
}
@ -130,6 +131,7 @@ enum xnn_status xnnp_create_convolution2d_nhwc(
op_min, /* int8_t output_min */
op_max, /* int8_t output_max */
flags, /* uint32_t flags */
nullptr, /* xnn_caches_t caches */
op); /* xnn_operator_t* convolution_op_out */
} else { /* per_channel */
return xnn_create_convolution2d_nhwc_qc8(
@ -158,6 +160,7 @@ enum xnn_status xnnp_create_convolution2d_nhwc(
op_min, /* int8_t output_min */
op_max, /* int8_t output_max */
flags, /* uint32_t flags */
nullptr, /* xnn_caches_t caches */
op); /* xnn_operator_t* convolution_op_out */
}
}
@ -254,6 +257,7 @@ enum xnn_status xnnp_create_fully_connected_nc(
output_min, /* int8_t output_min */
output_max, /* int8_t output_max */
flags, /* uint32_t flags */
nullptr, /* xnn_caches_t caches */
fully_connected_op_out); /* xnn_operator_t* fully_connected_op_out */
}

View file

@ -236,6 +236,7 @@ ContextConv2D create(
output_min, // output_min
output_max, // output_max
0u, // flags
nullptr, // xnn_caches_t
&convolution_op); // operator
} else {
for (const auto i : c10::irange(4)) {
@ -264,6 +265,7 @@ ContextConv2D create(
output_min, // output_min
output_max, // output_max
0u, // flags
nullptr, // xnn_caches_t
&convolution_op); // operator
}

View file

@ -97,6 +97,7 @@ ContextLinear create(
output_min, // output_min
output_max, // output_max
0u, // flags
nullptr, // xnn_caches_t
&linear_op); // operator
TORCH_CHECK(

2
third_party/XNNPACK vendored

@ -1 +1 @@
Subproject commit ae108ef49aa5623b896fc93d4298c49d1750d9ba
Subproject commit 51a987591a6fc9f0fc0707077f53d763ac132cbf

View file

@ -4,6 +4,7 @@ from __future__ import print_function
import collections
import os
import sys
import logging
BANNER = "Auto-generated by generate-wrappers.py script. Do not modify"
WRAPPER_SRC_NAMES = {
@ -11,6 +12,7 @@ WRAPPER_SRC_NAMES = {
"PROD_SCALAR_AARCH32_MICROKERNEL_SRCS" : "defined(__arm__)",
"PROD_NEON_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
"PROD_NEONFP16_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
"PROD_NEON_AARCH64_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
"PROD_NEONFMA_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
"PROD_AARCH64_NEON_MICROKERNEL_SRCS": "defined(__aarch64__)",
"PROD_NEONV8_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
@ -27,14 +29,50 @@ WRAPPER_SRC_NAMES = {
"PROD_AVX2_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
"PROD_AVX512F_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
"PROD_AVX512SKX_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
"PROD_AVX512VBMI_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
"AARCH32_ASM_MICROKERNEL_SRCS": "defined(__arm__)",
"AARCH64_ASM_MICROKERNEL_SRCS": "defined(__aarch64__)",
# additional source lists:
"PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
"ALL_ARMSIMD32_MICROKERNEL_SRCS": "defined(__arm__)",
"ALL_AVX_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
"ALL_AVX2_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
"ALL_AVX512F_MICROKERNEL_SRCS": "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
'ALL_AVX512SKX_MICROKERNEL_SRCS': "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
'ALL_AVX512VBMI_MICROKERNEL_SRCS': "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
'ALL_F16C_MICROKERNEL_SRCS': "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
'ALL_FMA3_MICROKERNEL_SRCS': "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
'ALL_FP16ARITH_MICROKERNEL_SRCS': "defined(__arm__) || defined(__aarch64__)",
'ALL_NEON_MICROKERNEL_SRCS': "defined(__arm__) || defined(__aarch64__)",
'ALL_NEON_AARCH64_MICROKERNEL_SRCS': "defined(__aarch64__)",
'ALL_NEONBF16_MICROKERNEL_SRCS': "defined(__arm__) || defined(__aarch64__)",
'ALL_NEONDOT_MICROKERNEL_SRCS': "defined(__arm__) || defined(__aarch64__)",
'ALL_NEONFMA_MICROKERNEL_SRCS': "defined(__arm__) || defined(__aarch64__)",
'ALL_NEONFMA_AARCH64_MICROKERNEL_SRCS': "defined(__aarch64__)",
'ALL_NEONFP16_MICROKERNEL_SRCS':"defined(__arm__) || defined(__aarch64__)",
'ALL_NEONFP16ARITH_MICROKERNEL_SRCS': "defined(__arm__) || defined(__aarch64__)",
'ALL_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS': "defined(__aarch64__)",
'ALL_NEONV8_MICROKERNEL_SRCS': "defined(__aarch64__)",
'ALL_SCALAR_MICROKERNEL_SRCS': "defined(__arm__)",
'ALL_SSE_MICROKERNEL_SRCS': "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
'ALL_SSE2_MICROKERNEL_SRCS': "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
'ALL_SSE41_MICROKERNEL_SRCS': "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
'ALL_SSSE3_MICROKERNEL_SRCS': "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
'ALL_XOP_MICROKERNEL_SRCS': "defined(__i386__) || defined(__i686__) || defined(__x86_64__)",
'AARCH32_ASM_MICROKERNEL_SRCS': "defined(__arm__)",
"PROD_FP16ARITH_MICROKERNEL_SRCS": "defined(__aarch64__)",
"PROD_NEONFP16ARITH_MICROKERNEL_SRCS": "defined(__arm__) || defined(__aarch64__)",
"PROD_SCALAR_MICROKERNEL_SRCS": "defined(__arm__)",
}
SRC_NAMES = [
SRC_NAMES = set([
"OPERATOR_SRCS",
"SUBGRAPH_SRCS",
"LOGGING_SRCS",
"XNNPACK_SRCS",
"HOT_SRCS",
"TABLE_SRCS",
"JIT_SRCS",
@ -52,15 +90,83 @@ SRC_NAMES = [
"PROD_AVX2_MICROKERNEL_SRCS",
"PROD_AVX512F_MICROKERNEL_SRCS",
"PROD_AVX512SKX_MICROKERNEL_SRCS",
]
"PROD_SCALAR_MICROKERNEL_SRCS",
"PROD_SCALAR_AARCH32_MICROKERNEL_SRCS",
"PROD_SCALAR_RISCV_MICROKERNEL_SRCS",
"PROD_ARMSIMD32_MICROKERNEL_SRCS",
"PROD_FP16ARITH_MICROKERNEL_SRCS",
"PROD_NEON_MICROKERNEL_SRCS",
"PROD_NEONFP16_MICROKERNEL_SRCS",
"PROD_NEONFMA_MICROKERNEL_SRCS",
"PROD_NEON_AARCH64_MICROKERNEL_SRCS",
"PROD_NEONV8_MICROKERNEL_SRCS",
"PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS",
"PROD_NEONDOT_MICROKERNEL_SRCS",
"PROD_SSE2_MICROKERNEL_SRCS",
"PROD_SSSE3_MICROKERNEL_SRCS",
"PROD_SSE41_MICROKERNEL_SRCS",
"PROD_AVX_MICROKERNEL_SRCS",
"PROD_F16C_MICROKERNEL_SRCS",
"PROD_AVX512VBMI_MICROKERNEL_SRCS",
"PROD_NEONFP16ARITH_MICROKERNEL_SRCS",
def update_sources(xnnpack_path):
# newly added source lists:
'ALL_ARMSIMD32_MICROKERNEL_SRCS',
'ALL_AVX_MICROKERNEL_SRCS',
'ALL_AVX2_MICROKERNEL_SRCS',
'ALL_AVX512F_MICROKERNEL_SRCS',
'ALL_AVX512SKX_MICROKERNEL_SRCS',
'ALL_AVX512VBMI_MICROKERNEL_SRCS',
'ALL_F16C_MICROKERNEL_SRCS',
'ALL_FMA3_MICROKERNEL_SRCS',
'ALL_FP16ARITH_MICROKERNEL_SRCS',
'ALL_HEXAGON_MICROKERNEL_SRCS',
'ALL_NEON_MICROKERNEL_SRCS',
'ALL_NEON_AARCH64_MICROKERNEL_SRCS',
'ALL_NEONBF16_MICROKERNEL_SRCS',
'ALL_NEONBF16_AARCH64_MICROKERNEL_SRCS',
'ALL_NEONDOT_MICROKERNEL_SRCS',
'ALL_NEONFMA_MICROKERNEL_SRCS',
'ALL_NEONFMA_AARCH64_MICROKERNEL_SRCS',
'ALL_NEONFP16_MICROKERNEL_SRCS',
'ALL_NEONFP16ARITH_MICROKERNEL_SRCS',
'ALL_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS',
'ALL_NEONV8_MICROKERNEL_SRCS',
'ALL_SCALAR_MICROKERNEL_SRCS',
'ALL_SSE_MICROKERNEL_SRCS',
'ALL_SSE2_MICROKERNEL_SRCS',
'ALL_SSE41_MICROKERNEL_SRCS',
'ALL_SSSE3_MICROKERNEL_SRCS',
'ALL_WASM_MICROKERNEL_SRCS',
'ALL_WASMRELAXEDSIMD_MICROKERNEL_SRCS',
'ALL_WASMSIMD_MICROKERNEL_SRCS',
'ALL_XOP_MICROKERNEL_SRCS',
'AARCH32_ASM_MICROKERNEL_SRCS',
'AARCH64_ASM_MICROKERNEL_SRCS',
])
def handle_singleline_parse(line):
    """Parse a single-line CMake ``SET(<name> <value>)`` statement.

    Args:
        line: One line of CMake text, e.g. ``"SET(JIT_SRCS src/jit/a.cc)"``.

    Returns:
        A ``(name, value)`` tuple taken from the first two whitespace
        separated tokens between the parentheses. A leading ``src/`` on the
        value is removed.

    Raises:
        IndexError: If fewer than two tokens follow the opening parenthesis.
    """
    start_index = line.find("(")
    end_index = line.find(")")
    if end_index == -1:
        # No closing parenthesis on this line: read to the end of the line
        # instead of letting the -1 index silently drop the last character.
        end_index = len(line)
    # split() without an argument collapses runs of spaces/tabs and ignores a
    # trailing newline, so repeated separators never produce empty tokens.
    key_val = line[start_index + 1:end_index].split()
    name, value = key_val[0], key_val[1]
    # Only strip the prefix when it is really there, rather than blindly
    # discarding the first four characters of the value.
    if value.startswith("src/"):
        value = value[len("src/"):]
    return name, value
def update_sources(xnnpack_path, cmakefile = "XNNPACK/CMakeLists.txt"):
sources = collections.defaultdict(list)
with open(os.path.join(xnnpack_path, "XNNPACK/CMakeLists.txt")) as cmake:
count = 0
with open(os.path.join(xnnpack_path, cmakefile)) as cmake:
lines = cmake.readlines()
i = 0
while i < len(lines):
line = lines[i]
if lines[i].startswith("SET") and "src/" in lines[i]:
name, val = handle_singleline_parse(line)
sources[name].append(val)
i+=1
continue
if line.startswith("SET") and line.split('(')[1].strip(' \t\n\r') in set(WRAPPER_SRC_NAMES.keys()) | set(SRC_NAMES):
name = line.split('(')[1].strip(' \t\n\r')
i += 1
@ -80,11 +186,19 @@ def update_sources(xnnpack_path):
def gen_wrappers(xnnpack_path):
xnnpack_sources = collections.defaultdict(list)
sources = update_sources(xnnpack_path)
microkernels_sources = update_sources(xnnpack_path, "XNNPACK/cmake/microkernels.cmake")
for key in microkernels_sources:
sources[key] = microkernels_sources[key]
for name in WRAPPER_SRC_NAMES:
xnnpack_sources[WRAPPER_SRC_NAMES[name]].extend(sources[name])
for condition, filenames in xnnpack_sources.items():
print(condition)
for filename in filenames:
filepath = os.path.join(xnnpack_path, "xnnpack_wrappers", filename)
if not os.path.isdir(os.path.dirname(filepath)):
os.makedirs(os.path.dirname(filepath))
with open(filepath, "w") as wrapper:

View file

@ -35,6 +35,10 @@ load(
"PROD_SSE_MICROKERNEL_SRCS",
"PROD_SSSE3_MICROKERNEL_SRCS",
"PROD_XOP_MICROKERNEL_SRCS",
"ALL_NEONFMA_AARCH64_MICROKERNEL_SRCS",
"ALL_NEON_AARCH64_MICROKERNEL_SRCS",
"PROD_AVX512VBMI_MICROKERNEL_SRCS",
"ALL_AVX512VBMI_MICROKERNEL_SRCS",
)
# This defines XNNPACK targets for both fbsource BUCK and OSS BUCK
@ -99,6 +103,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F
preferred_linkage = "static",
preprocessor_flags = [
"-DXNN_LOG_LEVEL=0",
"-DXNN_ENABLE_GEMM_M_SPECIALIZATION=0",
],
visibility = ["PUBLIC"],
windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS,
@ -131,6 +136,9 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F
preferred_linkage = "static",
preprocessor_flags = [
"-DXNN_LOG_LEVEL=0",
"-DXNN_ENABLE_JIT=0",
"-DXNN_ENABLE_SPARSE=0",
"-DXNN_ENABLE_MEMOPT",
],
visibility = ["PUBLIC"],
windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS,
@ -1088,6 +1096,78 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F
],
)
# Static library target wrapping the PROD_AVX512VBMI_MICROKERNEL_SRCS list.
# The sources are attached in one of two mutually exclusive ways, depending on
# is_arvr_mode(): via select() in `srcs`, or via `platform_srcs`.
fb_xplat_cxx_library(
name = "ukernels_avx512vbmi",
# arvr mode: only the macos-x86_64 configuration gets the sources; every
# other configuration (and non-arvr mode) compiles nothing through `srcs`.
srcs = (select({
"DEFAULT": [],
"ovr_config//os:macos-x86_64": PROD_AVX512VBMI_MICROKERNEL_SRCS,
}) if is_arvr_mode() else []),
# NOTE(review): .c files are globbed as headers too, presumably because some
# sources #include other .c files — confirm.
headers = subdir_glob([
("XNNPACK/src", "**/*.c"),
("XNNPACK/src", "**/*.h"),
]),
header_namespace = "",
apple_sdks = (IOS, MACOSX, APPLETVOS),
compiler_flags = [
"-O2",
"-mavx512f",
"-mavx512cd",
"-mavx512bw",
"-mavx512dq",
"-mavx512vl",
"-mavx512vbmi",
],
fbobjc_preprocessor_flags = [
"-DXNN_PRIVATE=",
"-DXNN_INTERNAL=",
],
labels = labels,
# The same AVX512 flag set is repeated for platforms matching the x86 regex.
platform_compiler_flags = [
(
"^(i[3-6]86|x86|x86_64|AMD64)$",
[
"-mavx512f",
"-mavx512cd",
"-mavx512bw",
"-mavx512dq",
"-mavx512vl",
"-mavx512vbmi",
],
),
],
# Non-arvr mode: sources are attached only for platforms matching this regex.
platform_srcs = ([
(
"x86|x86_64|platform009|platform010",
PROD_AVX512VBMI_MICROKERNEL_SRCS,
),
] if not is_arvr_mode() else []),
preferred_linkage = "static",
preprocessor_flags = [
"-DXNN_LOG_LEVEL=0",
],
visibility = ["PUBLIC"],
# Windows overrides replace the default flags, so the AVX512 flag set is
# appended again for both the clang and the non-clang override.
windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS + [
"-mavx512f",
"-mavx512cd",
"-mavx512bw",
"-mavx512dq",
"-mavx512vl",
"-mavx512vbmi",
],
windows_compiler_flags_override = WINDOWS_FLAGS + [
"-mavx512f",
"-mavx512cd",
"-mavx512bw",
"-mavx512dq",
"-mavx512vl",
"-mavx512vbmi",
],
deps = [
":interface",
],
)
fb_xplat_cxx_library(
name = "ukernels_avx512_ovr_win32",
headers = subdir_glob([
@ -1474,7 +1554,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F
fb_xplat_cxx_library(
name = "ukernels_neon_aarch64",
srcs = PROD_AARCH64_NEON_MICROKERNEL_SRCS,
srcs = ALL_NEON_AARCH64_MICROKERNEL_SRCS,
headers = subdir_glob([
("XNNPACK/src", "**/*.c"),
("XNNPACK/src", "**/*.h"),
@ -1589,6 +1669,47 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F
],
)
# Static library target that compiles ALL_NEONFMA_AARCH64_MICROKERNEL_SRCS
# unconditionally (no select()/platform_srcs gating on the source list).
fb_xplat_cxx_library(
name = "ukernels_neonfma_aarch64",
srcs = ALL_NEONFMA_AARCH64_MICROKERNEL_SRCS,
# NOTE(review): .c files are globbed as headers too, presumably because some
# sources #include other .c files — confirm.
headers = subdir_glob([
("XNNPACK/src", "**/*.h"),
("XNNPACK/src", "**/*.c"),
]),
header_namespace = "",
apple_sdks = (IOS, MACOSX, APPLETVOS),
compiler_flags = [
"-O2",
],
fbobjc_preprocessor_flags = [
"-DXNN_PRIVATE=",
"-DXNN_INTERNAL=",
],
labels = labels,
# Architecture flags are added only for platforms matching the armv8 regex.
# NOTE(review): -mfpu / -mfloat-abi are normally 32-bit ARM options — confirm
# they are intended for this aarch64 target.
platform_compiler_flags = [
(
"^(android-armv8|iphoneos-armv8)$",
[
"-march=armv8-a",
"-mfpu=neon-fp-armv8",
"-mfloat-abi=softfp",
],
),
],
platforms = (APPLE, ANDROID, CXX, WINDOWS),
preferred_linkage = "static",
preprocessor_flags = [
"-DXNN_LOG_LEVEL=0",
],
visibility = ["PUBLIC"],
windows_clang_compiler_flags_override = WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS,
windows_compiler_flags_override = WINDOWS_FLAGS,
deps = [
":interface",
third_party("FP16"),
],
)
fb_xplat_cxx_library(
name = "ukernels_asm_aarch32",
srcs = AARCH32_ASM_MICROKERNEL_SRCS,
@ -1686,6 +1807,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F
":ukernels_neon_fp16",
":ukernels_neon_fp16arith_aarch64",
":ukernels_neon_v8",
":ukernels_neonfma_aarch64",
],
)
@ -1707,6 +1829,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F
":ukernels_sse41",
":ukernels_ssse3",
":ukernels_xop",
":ukernels_avx512vbmi",
],
)
@ -1728,6 +1851,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F
":ukernels_sse_ovr_win32",
":ukernels_ssse3_ovr_win32",
":ukernels_xop_ovr_win32",
":ukernels_avx512vbmi",
],
)
@ -1749,6 +1873,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F
":ukernels_neon_fp16arith_aarch64",
":ukernels_neon_v8",
":ukernels_scalar_aarch32",
":ukernels_neonfma_aarch64",
],
)
@ -1820,15 +1945,30 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F
"-DXNN_NO_X8_OPERATORS",
"-DXNN_NO_XX_OPERATORS",
"-DXNN_ENABLE_MEMOPT",
"-DXNN_ENABLE_SPARSE=0",
"-DXNN_ENABLE_JIT=0",
"-DXNN_ENABLE_ASSEMBLY",
"-DXNN_ENABLE_GEMM_M_SPECIALIZATION",
"-DXNN_ENABLE_ARM_DOTPROD",
],
srcs = [
"XNNPACK/src/allocator.c",
"XNNPACK/src/init.c",
"XNNPACK/src/memory-planner.c",
"XNNPACK/src/operator-delete.c",
"XNNPACK/src/runtime.c",
"XNNPACK/src/subgraph.c",
"XNNPACK/src/tensor.c",
"XNNPACK/src/params.c",
"XNNPACK/src/operator-run.c",
"XNNPACK/src/microparams-init.c",
"XNNPACK/src/binary-elementwise-config.c",
"XNNPACK/src/packing.c",
"XNNPACK/src/indirection.c",
"XNNPACK/src/cache.c",
"XNNPACK/src/mutex.c",
"XNNPACK/src/operator-utils.c",
"XNNPACK/src/memory.c",
"XNNPACK/src/hardware-config.c",
"XNNPACK/src/x8-lut-config.c",
"XNNPACK/src/normalization.c",
"XNNPACK/src/transpose-config.c",
"XNNPACK/src/amalgam/scalar.c",
] + LOGGING_SRCS,
visibility = ["PUBLIC"],
windows_clang_compiler_flags_override = (WINDOWS_FLAGS + WINDOWS_CLANG_COMPILER_FLAGS) if XNNPACK_WINDOWS_AVX512F_ENABLED else WINDOWS_FLAGS,

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff