diff --git a/CMakeLists.txt b/CMakeLists.txt index 9b6d7d93438..ea88454d227 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -423,7 +423,7 @@ if(USE_PYTORCH_QNNPACK) endif() if(USE_XNNPACK) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_XNNPACK") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_XNNPACK -DUSE_INTERNAL_THREADPOOL_IMPL") endif() # ---[ Whitelist file if whitelist is specified diff --git a/aten/src/ATen/native/xnnpack/Common.h b/aten/src/ATen/native/xnnpack/Common.h index 09472d31368..9484d99c1f6 100644 --- a/aten/src/ATen/native/xnnpack/Common.h +++ b/aten/src/ATen/native/xnnpack/Common.h @@ -5,6 +5,7 @@ #ifdef USE_XNNPACK #include <xnnpack.h> +#include "caffe2/utils/threadpool/ThreadPoolXNNPACK.h" namespace at { namespace native { diff --git a/aten/src/ATen/native/xnnpack/Convolution.cpp b/aten/src/ATen/native/xnnpack/Convolution.cpp index a8239ea8738..82c8d9b1ad4 100644 --- a/aten/src/ATen/native/xnnpack/Convolution.cpp +++ b/aten/src/ATen/native/xnnpack/Convolution.cpp @@ -110,15 +110,15 @@ Tensor run( padded_input_nhwc.size(Layout::Activation4D::width), // input_width padded_input_nhwc.data_ptr(), // input output.data_ptr(), // output - nullptr); // threadpool + caffe2::xnnpack_threadpool()); // threadpool TORCH_CHECK( xnn_status_success == setup_status, "xnn_setup_convolution2d_nhwc_f32 failed!"); const xnn_status run_status = xnn_run_operator( - context.op.get(), // operator - nullptr); // threadpool + context.op.get(), // operator + caffe2::xnnpack_threadpool()); // threadpool TORCH_INTERNAL_ASSERT( xnn_status_success == run_status, diff --git a/aten/src/ATen/native/xnnpack/Linear.cpp b/aten/src/ATen/native/xnnpack/Linear.cpp index ff34c7ff897..b62746f1d43 100644 --- a/aten/src/ATen/native/xnnpack/Linear.cpp +++ b/aten/src/ATen/native/xnnpack/Linear.cpp @@ -72,15 +72,15 @@ Tensor run( Layout::ActivationND::batch(padded_input.sizes()), // Batch, padded_input.data_ptr(), // input output.data_ptr(), // output - nullptr); // threadpool + 
caffe2::xnnpack_threadpool()); // threadpool TORCH_CHECK( xnn_status_success == setup_status, "xnn_setup_fully_connected_nc_f32 failed!"); const xnn_status run_status = xnn_run_operator( - context.op.get(), // operator - nullptr); // threadpool + context.op.get(), // operator + caffe2::xnnpack_threadpool()); // threadpool TORCH_INTERNAL_ASSERT( xnn_status_success == run_status, diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 7d13acea9ae..320bba10d8e 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -434,6 +434,15 @@ if(USE_XNNPACK) "${CONFU_DEPENDENCIES_BINARY_DIR}/XNNPACK") set_property(TARGET XNNPACK PROPERTY POSITION_INDEPENDENT_CODE ON) + # Context: pthreadpool_get_threads_count implementation that is built in pytorch, uses + implementation defined in caffe2/utils/threadpool/pthreadpool_impl.cc. This implementation + assumes that the pthreadpool* passed is of type caffe2::ThreadPool and thus does reinterpret cast. + This is not valid when we create pthreadpool via caffe2::xnnpack_threadpool, which is of type + compatible with new pthreadpool interface and is used in PT's XNNPACK integration. + Thus all the calls for pthreadpool_get_threads_count originating from XNNPACK must be routed + appropriately to pthreadpool_get_threads_count_xnnpack, which does not do the aforementioned + casting to caffe2::ThreadPool. Once the threadpools are unified, we will not need this. + target_compile_definitions(XNNPACK PRIVATE -Dpthreadpool_get_threads_count=pthreadpool_get_threads_count_xnnpack) endif() include_directories(SYSTEM ${XNNPACK_INCLUDE_DIR})