diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 726972af2d..9377980eca 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -184,6 +184,9 @@ option(onnxruntime_ENABLE_CUDA_PROFILING "Enable CUDA kernel profiling" OFF) option(onnxruntime_ENABLE_CPUINFO "Enable cpuinfo" ON) +# ATen fallback support +option(onnxruntime_ENABLE_ATEN "Enable ATen fallback" OFF) + if (onnxruntime_USE_CUDA) set(onnxruntime_DISABLE_RTTI OFF) endif() @@ -906,7 +909,7 @@ if (onnxruntime_ENABLE_CPUINFO) else() # if xnnpack is enabled in a wasm build it needs clog from cpuinfo, but we won't internally use cpuinfo # so we don't set CPUINFO_SUPPORTED in the CXX flags below. - if (onnxruntime_BUILD_WEBASSEMBLY AND NOT onnxruntime_USE_XNNPACK) + if (onnxruntime_BUILD_WEBASSEMBLY AND NOT onnxruntime_USE_XNNPACK) set(CPUINFO_SUPPORTED FALSE) else() set(CPUINFO_SUPPORTED TRUE) @@ -937,8 +940,8 @@ if (CPUINFO_SUPPORTED) set(IOS ON CACHE INTERNAL "") set(IOS_ARCH "${CMAKE_OSX_ARCHITECTURES}" CACHE INTERNAL "") endif() - - # if this is a wasm build with xnnpack (only type of wasm build where cpuinfo is involved) + + # if this is a wasm build with xnnpack (only type of wasm build where cpuinfo is involved) # we do not use cpuinfo in ORT code, so don't define CPUINFO_SUPPORTED. if (NOT onnxruntime_BUILD_WEBASSEMBLY) string(APPEND CMAKE_CXX_FLAGS " -DCPUINFO_SUPPORTED") @@ -985,8 +988,8 @@ include(eigen) set(onnxruntime_EXTERNAL_LIBRARIES onnx onnx_proto ${PROTOBUF_LIB} re2::re2) if(NOT onnxruntime_DISABLE_ABSEIL) - set(ABSEIL_LIBS absl::inlined_vector absl::flat_hash_set - absl::flat_hash_map absl::node_hash_set absl::node_hash_map absl::base absl::throw_delegate absl::raw_hash_set + set(ABSEIL_LIBS absl::inlined_vector absl::flat_hash_set + absl::flat_hash_map absl::node_hash_set absl::node_hash_map absl::base absl::throw_delegate absl::raw_hash_set absl::hash absl::city absl::low_level_hash absl::raw_logging_internal) list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${ABSEIL_LIBS}) else() @@ -1301,7 +1304,7 @@ function(onnxruntime_configure_target target_name) set_target_properties(${target_name} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELWITHDEBINFO TRUE) set_target_properties(${target_name} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_MINSIZEREL TRUE) endif() - + # Keep BinSkim happy if(MSVC AND NOT onnxruntime_target_platform MATCHES "ARM") target_link_options(${target_name} PRIVATE "/CETCOMPAT") @@ -1542,7 +1545,7 @@ if (onnxruntime_USE_XNNPACK) set(XNNPACK_DIR external/XNNPACK) set(XNNPACK_INCLUDE_DIR ${XNNPACK_DIR}/include) - set(XNNPACK_USE_SYSTEM_LIBS ON CACHE INTERNAL "") + set(XNNPACK_USE_SYSTEM_LIBS ON CACHE INTERNAL "") set(XNNPACK_BUILD_TESTS OFF CACHE INTERNAL "") set(XNNPACK_BUILD_BENCHMARKS OFF CACHE INTERNAL "") set(FP16_BUILD_TESTS OFF CACHE INTERNAL "") @@ -1550,14 +1553,14 @@ if (onnxruntime_USE_XNNPACK) set(CLOG_SOURCE_DIR "${PYTORCH_CPUINFO_DIR}/deps/clog") set(CPUINFO_SOURCE_DIR ${PYTORCH_CPUINFO_DIR}) - add_subdirectory(external/FP16) + add_subdirectory(external/FP16) add_subdirectory(external/pthreadpool) add_subdirectory(external/XNNPACK) set_target_properties(fp16 PROPERTIES FOLDER "External/Xnnpack") set_target_properties(pthreadpool PROPERTIES FOLDER "External/Xnnpack") set_target_properties(XNNPACK PROPERTIES FOLDER "External/Xnnpack") - + set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK pthreadpool) list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${onnxruntime_EXTERNAL_LIBRARIES_XNNPACK}) @@ -1570,19 +1573,19 @@ if (onnxruntime_USE_XNNPACK) message("Adding WebAssembly Source Files to 
XNNPACK") set(wasm_src_patterns "${XNNPACK_DIR}/src/wasm-*.c" "${XNNPACK_DIR}/src/*-wasm-*.c" - "${XNNPACK_DIR}/src/*-wasm.c") + "${XNNPACK_DIR}/src/*-wasm.c") set(wasm32_asm_src_patterns "${XNNPACK_DIR}/src/wasm_shr_*.S") file(GLOB_RECURSE XNNPACK_WASM_MICROKERNEL_SRCS CONFIGURE_DEPENDS ${wasm_src_patterns}) file(GLOB_RECURSE XNNPACK_WASM32_ASM_MICROKERNEL_SRCS CONFIGURE_DEPENDS ${wasm32_asm_src_patterns}) - + message(DEBUG "XNNPACK_WASM_MICROKERNEL_SRCS:${XNNPACK_WASM_MICROKERNEL_SRCS}") message(DEBUG "XNNPACK_WASM32_ASM_MICROKERNEL_SRCS:${XNNPACK_WASM32_ASM_MICROKERNEL_SRCS}") target_sources(XNNPACK PRIVATE ${XNNPACK_WASM_MICROKERNEL_SRCS} ${XNNPACK_WASM32_ASM_MICROKERNEL_SRCS}) - if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD) + if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD) target_compile_options(XNNPACK PRIVATE "-msimd128") set(wasmsimd_src_patterns "${XNNPACK_DIR}/src/wasmsimd-*.c" @@ -1591,7 +1594,7 @@ if (onnxruntime_USE_XNNPACK) file(GLOB_RECURSE XNNPACK_WASMSIMD_MICROKERNEL_SRCS CONFIGURE_DEPENDS ${wasmsimd_src_patterns}) message(DEBUG "XNNPACK_WASMSIMD_MICROKERNEL_SRCS:${XNNPACK_WASMSIMD_MICROKERNEL_SRCS}") - + target_sources(XNNPACK PRIVATE ${XNNPACK_WASMSIMD_MICROKERNEL_SRCS}) endif() endif() @@ -2033,10 +2036,14 @@ if (onnxruntime_ENABLE_TRAINING) list(APPEND onnxruntime_EXTERNAL_LIBRARIES tensorboard) endif() +if (onnxruntime_ENABLE_TRAINING) + set(onnxruntime_ENABLE_ATEN ON) +endif() + set(ONNXRUNTIME_TARGETS onnxruntime_common onnxruntime_graph onnxruntime_framework onnxruntime_util onnxruntime_providers onnxruntime_optimizer onnxruntime_session onnxruntime_mlas onnxruntime_flatbuffers) if (onnxruntime_USE_NUPHAR_TVM) - list(APPEND ONNXRUNTIME_TARGETS onnxruntime_codegen_tvm) + list(APPEND ONNXRUNTIME_TARGETS onnxruntime_codegen_tvm) endif() if (onnxruntime_ENABLE_EAGER_MODE) if (NOT onnxruntime_ENABLE_TRAINING OR NOT onnxruntime_ENABLE_PYTHON) @@ -2205,4 +2212,3 @@ if (onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS) COMMENT "Installing protobuf" ) endif() - diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake index 98a1f74fdd..ca501dbeac 100644 --- a/cmake/onnxruntime_framework.cmake +++ b/cmake/onnxruntime_framework.cmake @@ -57,8 +57,9 @@ if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(onnxruntime_framework PUBLIC ${MPI_CXX_INCLUDE_DIRS}) endif() endif() -if (onnxruntime_ENABLE_TRAINING) +if (onnxruntime_ENABLE_ATEN) # DLPack is a header-only dependency + target_compile_definitions(onnxruntime_framework PRIVATE ENABLE_ATEN) set(DLPACK_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/external/dlpack/include) target_include_directories(onnxruntime_framework PRIVATE ${DLPACK_INCLUDE_DIR}) endif() diff --git a/cmake/onnxruntime_graph.cmake b/cmake/onnxruntime_graph.cmake index a3e8ffea2c..8afb54d2a1 100644 --- a/cmake/onnxruntime_graph.cmake +++ b/cmake/onnxruntime_graph.cmake @@ -126,11 +126,15 @@ if (WIN32) set_target_properties(onnxruntime_graph PROPERTIES STATIC_LIBRARY_FLAGS "${onnxruntime_graph_static_library_flags}") - if (NOT onnxruntime_DISABLE_EXCEPTIONS) + if (NOT onnxruntime_DISABLE_EXCEPTIONS) target_compile_options(onnxruntime_graph PRIVATE /EHsc # exception handling - C++ may throw, extern "C" will not ) - endif() + endif() +endif() + +if (onnxruntime_ENABLE_ATEN) + target_compile_definitions(onnxruntime_graph PRIVATE ENABLE_ATEN) endif() if (NOT onnxruntime_BUILD_SHARED_LIB) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 89f7326941..c272314c94 100644 --- 
a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -165,6 +165,13 @@ set(onnxruntime_providers_src ${onnxruntime_providers_common_srcs} ${onnxruntime # disable contrib ops conditionally if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + if (NOT onnxruntime_ENABLE_ATEN) + list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc" + ) + endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs}) list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs}) @@ -195,13 +202,21 @@ if (onnxruntime_ENABLE_TRAINING_OPS) "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/gist/*.h" "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/tensorboard/*.cc" "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/tensorboard/*.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/aten_ops/*.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/aten_ops/*.h" ) list(REMOVE_ITEM onnxruntime_providers_src ${onnxruntime_cpu_full_training_only_srcs}) endif() +if (onnxruntime_ENABLE_ATEN) + file(GLOB_RECURSE onnxruntime_providers_dlpack_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/dlpack/dlpack_converter.cc" + "${ONNXRUNTIME_ROOT}/core/dlpack/dlpack_converter.h" + ) + set(onnxruntime_providers_dlpack_srcs ${onnxruntime_providers_dlpack_srcs}) + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_dlpack_srcs}) + list(APPEND onnxruntime_providers_src ${onnxruntime_providers_dlpack_srcs}) +endif() + if (onnxruntime_ENABLE_TRAINING) file(GLOB_RECURSE onnxruntime_cpu_training_ops_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.h" @@ -222,15 +237,6 @@ if (onnxruntime_ENABLE_TRAINING) source_group(TREE ${ORTTRAINING_ROOT}/ FILES ${onnxruntime_cpu_training_ops_srcs}) list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_training_ops_srcs}) - - # todo: put this in core/framework and enabled only for training - file(GLOB_RECURSE onnxruntime_providers_dlpack_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/dlpack/dlpack_converter.cc" - "${ONNXRUNTIME_ROOT}/core/dlpack/dlpack_converter.h" - ) - set(onnxruntime_providers_dlpack_srcs ${onnxruntime_providers_dlpack_srcs}) - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_dlpack_srcs}) - list(APPEND onnxruntime_providers_src ${onnxruntime_providers_dlpack_srcs}) endif() if (onnxruntime_REDUCED_OPS_BUILD) @@ -277,6 +283,13 @@ if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(onnxruntime_providers PRIVATE ${ORTTRAINING_ROOT}) endif() +if (onnxruntime_ENABLE_ATEN) + target_compile_definitions(onnxruntime_providers PRIVATE ENABLE_ATEN) + # DLPack is a header-only dependency + set(DLPACK_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/external/dlpack/include) + target_include_directories(onnxruntime_providers PRIVATE ${DLPACK_INCLUDE_DIR}) +endif() + if (onnxruntime_ENABLE_TRAINING) add_dependencies(onnxruntime_providers tensorboard) onnxruntime_add_include_to_target(onnxruntime_providers tensorboard) @@ -287,10 +300,6 @@ if (onnxruntime_ENABLE_TRAINING) if (onnxruntime_USE_NCCL OR onnxruntime_USE_MPI) target_include_directories(onnxruntime_providers PUBLIC ${MPI_CXX_INCLUDE_DIRS}) endif() - - # DLPack is a header-only dependency - set(DLPACK_INCLUDE_DIR 
${PROJECT_SOURCE_DIR}/external/dlpack/include) - target_include_directories(onnxruntime_providers PRIVATE ${DLPACK_INCLUDE_DIR}) endif() install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/cpu DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers) @@ -342,7 +351,7 @@ if (onnxruntime_USE_CUDA) "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.h" "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.cc" ) - + # The shared_library files are in a separate list since they use precompiled headers, and the above files have them disabled. file(GLOB_RECURSE onnxruntime_providers_cuda_shared_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" @@ -358,6 +367,11 @@ if (onnxruntime_USE_CUDA) # disable contrib ops conditionally if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + if (NOT onnxruntime_ENABLE_ATEN) + list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs + "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/aten_ops/aten_op.cc" + ) + endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cuda_contrib_ops_cc_srcs} ${onnxruntime_cuda_contrib_ops_cu_srcs}) list(APPEND onnxruntime_providers_cuda_src ${onnxruntime_cuda_contrib_ops_cc_srcs} ${onnxruntime_cuda_contrib_ops_cu_srcs}) @@ -481,10 +495,10 @@ if (onnxruntime_USE_CUDA) "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.cc" ) - # minimize the Windows includes. + # minimize the Windows includes. # this avoids an issue with CUDA 11.6 where 'small' is defined in the windows and cuda headers. target_compile_definitions(onnxruntime_providers_cuda PRIVATE "WIN32_LEAN_AND_MEAN") - + # disable a warning from the CUDA headers about unreferenced local functions #target_compile_options(onnxruntime_providers_cuda PRIVATE /wd4505) if (onnxruntime_USE_NUPHAR_TVM) @@ -509,6 +523,10 @@ if (onnxruntime_USE_CUDA) message(FATAL_ERROR "onnxruntime_providers_cuda unknown platform, need to specify shared library exports for it") endif() + if (onnxruntime_ENABLE_ATEN) + target_compile_definitions(onnxruntime_providers_cuda PRIVATE ENABLE_ATEN) + endif() + install(TARGETS onnxruntime_providers_cuda ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} @@ -758,12 +776,12 @@ if (onnxruntime_USE_OPENVINO) # Header paths find_package(InferenceEngine REQUIRED) find_package(ngraph REQUIRED) - + if (OPENVINO_2022_1) find_package(OpenVINO REQUIRED COMPONENTS Runtime ONNX) list (OV_20_LIBS openvino::frontend::onnx openvino::runtime) endif() - + if (WIN32) unset(CMAKE_MAP_IMPORTED_CONFIG_RELWITHDEBINFO) endif() @@ -1297,6 +1315,11 @@ if (onnxruntime_USE_ROCM) # disable contrib ops conditionally if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + if (NOT onnxruntime_ENABLE_ATEN) + list(REMOVE_ITEM onnxruntime_rocm_contrib_ops_cc_srcs + "${ONNXRUNTIME_ROOT}/contrib_ops/rocm/aten_ops/aten_op.cc" + ) + endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_rocm_contrib_ops_cc_srcs} ${onnxruntime_rocm_contrib_ops_cu_srcs}) list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_rocm_contrib_ops_cc_srcs} ${onnxruntime_rocm_contrib_ops_cu_srcs}) @@ -1404,6 +1427,10 @@ if (onnxruntime_USE_ROCM) message(FATAL_ERROR "onnxruntime_providers_rocm unknown platform, need to specify shared library exports for it") endif() + if (onnxruntime_ENABLE_ATEN) + target_compile_definitions(onnxruntime_providers_rocm PRIVATE 
ENABLE_ATEN) + endif() + install(TARGETS onnxruntime_providers_rocm ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} @@ -1499,4 +1526,3 @@ if (NOT onnxruntime_BUILD_SHARED_LIB) RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() - diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 230f5ebb8b..fa69ec16e8 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -51,7 +51,7 @@ if(MSVC) target_compile_options(onnxruntime_pybind11_state PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") if(onnxruntime_ENABLE_TRAINING) target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj") - endif() + endif() endif() if(HAS_CAST_FUNCTION_TYPE) target_compile_options(onnxruntime_pybind11_state PRIVATE "-Wno-cast-function-type") @@ -98,9 +98,13 @@ else() set(ONNXRUNTIME_SO_LINK_FLAG "-DEF:${ONNXRUNTIME_ROOT}/python/pybind.def") endif() +if (onnxruntime_ENABLE_ATEN) + target_compile_definitions(onnxruntime_pybind11_state PRIVATE ENABLE_ATEN) + target_include_directories(onnxruntime_pybind11_state PRIVATE ${PROJECT_SOURCE_DIR}/external/dlpack/include) +endif() + if (onnxruntime_ENABLE_TRAINING) target_include_directories(onnxruntime_pybind11_state PRIVATE ${ORTTRAINING_ROOT}) - target_include_directories(onnxruntime_pybind11_state PRIVATE ${PROJECT_SOURCE_DIR}/external/dlpack/include) target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_training) endif() @@ -128,7 +132,7 @@ if (onnxruntime_ENABLE_EAGER_MODE) set_source_files_properties("${ORTTRAINING_ROOT}/orttraining/eager/ort_util.cpp" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter) set_source_files_properties("${ORTTRAINING_ROOT}/orttraining/python/orttraining_python_module.cc" PROPERTIES COMPILE_FLAGS -Wno-unused-parameter) endif() - if (MSVC) + if (MSVC) target_compile_options(onnxruntime_pybind11_state PRIVATE "/wd4100" "/wd4324" "/wd4458" "/wd4127" "/wd4193" "/wd4624" "/wd4702") target_compile_options(onnxruntime_pybind11_state PRIVATE "/bigobj" "/wd4275" "/wd4244" "/wd4267" "/wd4067") endif() @@ -306,7 +310,7 @@ if (onnxruntime_ENABLE_TRAINING) "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/*.py" ) file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_aten_op_executor_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/*" + "${ONNXRUNTIME_ROOT}/python/torch_cpp_extensions/aten_op_executor/*" ) file(GLOB onnxruntime_python_ortmodule_torch_cpp_ext_torch_interop_utils_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/*" diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 6540c2ba85..baf81e48aa 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -466,8 +466,8 @@ endif() if(onnxruntime_USE_NUPHAR) # the test case under nuphar_tvm is only to verify some basic tvm show case, which is already out of date # it doesn't have relationship to nuphar directly. consider we have an official tvm execution provider now, - # keep those test cases doesn't bring any value now. - + # keep those test cases doesn't bring any value now. 
+ list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/framework/nuphar/*) list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_nuphar) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_nuphar) @@ -699,11 +699,11 @@ endif() set(test_all_args) if (onnxruntime_USE_TENSORRT) # TRT EP CI takes much longer time when updating to TRT 8.2 - # So, we only run trt ep and exclude other eps to reduce CI test time. + # So, we only run trt ep and exclude other eps to reduce CI test time. # # The test names of model tests were using sequential number in the past. - # This PR https://github.com/microsoft/onnxruntime/pull/10220 (Please see ExpandModelName function in model_tests.cc for more details) - # made test name contain the "ep" and "model path" information, so we can easily filter the tests using cuda ep or other ep with *cpu__* or *xxx__*. + # This PR https://github.com/microsoft/onnxruntime/pull/10220 (Please see ExpandModelName function in model_tests.cc for more details) + # made test name contain the "ep" and "model path" information, so we can easily filter the tests using cuda ep or other ep with *cpu__* or *xxx__*. list(APPEND test_all_args "--gtest_filter=-*cpu__*:*cuda__*" ) endif () @@ -714,7 +714,7 @@ AddTest( onnx_test_runner_common ${onnxruntime_test_providers_libs} ${onnxruntime_test_common_libs} onnx_test_data_proto nlohmann_json::nlohmann_json DEPENDS ${all_dependencies} - TEST_ARGS ${test_all_args} + TEST_ARGS ${test_all_args} ) if (MSVC) # The warning means the type of two integral values around a binary operator is narrow than their result. @@ -755,6 +755,10 @@ if (onnxruntime_BUILD_WEBASSEMBLY) endif() endif() +if (onnxruntime_ENABLE_ATEN) + target_compile_definitions(onnxruntime_test_all PRIVATE ENABLE_ATEN) +endif() + set(test_data_target onnxruntime_test_all) onnxruntime_add_static_library(onnx_test_data_proto ${TEST_SRC_DIR}/proto/tml.proto) diff --git a/include/onnxruntime/core/framework/op_kernel_context.h b/include/onnxruntime/core/framework/op_kernel_context.h index cf9cd3ac68..073721988d 100644 --- a/include/onnxruntime/core/framework/op_kernel_context.h +++ b/include/onnxruntime/core/framework/op_kernel_context.h @@ -212,7 +212,7 @@ class OpKernelContext { const OrtValue* GetImplicitInputMLValue(int index) const; OrtValue* GetOutputMLValue(int index); -#ifdef ENABLE_TRAINING +#ifdef ENABLE_ATEN Status SetOutputMLValue(int index, const OrtValue& ort_value); #endif diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc new file mode 100644 index 0000000000..945c3aebce --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.cc @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
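+
+// ATen fallback kernel. Every input OrtValue is converted to a DLPack tensor and
+// handed to the ATenOperatorExecutor singleton, a function pointer registered from
+// the Python side (see register_aten_op_executor); the DLPack outputs produced by
+// PyTorch are then converted back into OrtValues. The DLPack exchange avoids
+// copies and carries device information.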
+ +#include "contrib_ops/cpu/aten_ops/aten_op.h" + +#include "core/dlpack/dlpack_converter.h" +#include "core/framework/op_kernel_context_internal.h" +#include "contrib_ops/cpu/aten_ops/aten_op_executor.h" + +namespace onnxruntime { +namespace contrib { + +ONNX_OPERATOR_KERNEL_EX(ATen, kPytorchAtenDomain, 1, kCpuExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorAndSequenceTensorTypes()), ATen); + +Status ATen::Compute(OpKernelContext* p_ctx) const { + auto* p_ctx_internal = static_cast(p_ctx); + size_t input_size = static_cast(p_ctx_internal->InputCount()); + size_t output_size = static_cast(p_ctx_internal->OutputCount()); + std::unique_ptr dlpack_inputs = std::make_unique(input_size); + std::unique_ptr dlpack_outputs = std::make_unique(output_size); + for (size_t i = 0; i < input_size; ++i) { + const OrtValue* p_ort_value = p_ctx_internal->GetInputMLValue(static_cast(i)); + if (!p_ort_value) { + dlpack_inputs[i] = nullptr; + } else { + OrtValue ort_value = *p_ort_value; + dlpack_inputs[i] = dlpack::OrtValueToDlpack(ort_value); + } + } + + aten_ops::ATenOperatorExecutor::Instance()(op_name_, overload_name_, input_size, dlpack_inputs.get(), output_size, + dlpack_outputs.get()); + for (size_t i = 0; i < output_size; ++i) { + ORT_RETURN_IF_ERROR( + p_ctx_internal->SetOutputMLValue(static_cast(i), dlpack::DlpackToOrtValue(dlpack_outputs[i]))); + } + + return Status::OK(); +} + +#ifdef ENABLE_TRAINING +bool IsATenOperatorExecutorInitialized() { return aten_ops::ATenOperatorExecutor::Instance().IsInitialized(); } + +Status ExecuteReduceSumATen(OpKernelContext* p_ctx, const gsl::span& axes, bool keepdims) { + ORT_ENFORCE(aten_ops::ATenOperatorExecutor::Instance().IsInitialized() && !axes.empty()); + size_t input_size = 4; + std::unique_ptr dlpack_inputs = std::make_unique(input_size); + auto* p_ctx_internal = static_cast(p_ctx); + OrtValue ort_value = *p_ctx_internal->GetInputMLValue(0); + dlpack_inputs[0] = dlpack::OrtValueToDlpack(ort_value); + OrtValue axes_tensor; + OrtValue keepdims_tensor; + TensorShapeVector axes_tensor_shape(1, static_cast(axes.size())); + TensorShapeVector keepdims_tensor_shape(1, 1); + auto ml_tensor = DataTypeImpl::GetType(); + OrtMemoryInfo info("Cpu", OrtDeviceAllocator); + auto axes_tensor_obj = std::make_unique(DataTypeImpl::GetType(), axes_tensor_shape, + const_cast(reinterpret_cast(&axes[0])), info); + axes_tensor.Init(axes_tensor_obj.release(), ml_tensor, ml_tensor->GetDeleteFunc()); + auto keepdims_tensor_obj = std::make_unique(DataTypeImpl::GetType(), keepdims_tensor_shape, + reinterpret_cast(&keepdims), info); + keepdims_tensor.Init(keepdims_tensor_obj.release(), ml_tensor, ml_tensor->GetDeleteFunc()); + dlpack_inputs[1] = dlpack::OrtValueToDlpack(axes_tensor); + dlpack_inputs[2] = dlpack::OrtValueToDlpack(keepdims_tensor); + dlpack_inputs[3] = nullptr; + DLManagedTensor* dlpack_output = nullptr; + aten_ops::ATenOperatorExecutor::Instance()("sum", "dim_IntList", input_size, dlpack_inputs.get(), 1, &dlpack_output); + ORT_ENFORCE(dlpack_output); + ORT_RETURN_IF_ERROR(p_ctx_internal->SetOutputMLValue(0, dlpack::DlpackToOrtValue(dlpack_output))); + return Status::OK(); +} +#endif + +} // namespace contrib +} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op.h b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.h similarity index 96% rename from orttraining/orttraining/training_ops/cpu/aten_ops/aten_op.h rename to onnxruntime/contrib_ops/cpu/aten_ops/aten_op.h index 1a80af45c5..54f4ed2d16 
+#ifdef ENABLE_TRAINING
+bool IsATenOperatorExecutorInitialized() { return aten_ops::ATenOperatorExecutor::Instance().IsInitialized(); }
+
+Status ExecuteReduceSumATen(OpKernelContext* p_ctx, const gsl::span<const int64_t>& axes, bool keepdims) {
+  ORT_ENFORCE(aten_ops::ATenOperatorExecutor::Instance().IsInitialized() && !axes.empty());
+  size_t input_size = 4;
+  std::unique_ptr<DLManagedTensor*[]> dlpack_inputs = std::make_unique<DLManagedTensor*[]>(input_size);
+  auto* p_ctx_internal = static_cast<OpKernelContextInternal*>(p_ctx);
+  OrtValue ort_value = *p_ctx_internal->GetInputMLValue(0);
+  dlpack_inputs[0] = dlpack::OrtValueToDlpack(ort_value);
+  OrtValue axes_tensor;
+  OrtValue keepdims_tensor;
+  TensorShapeVector axes_tensor_shape(1, static_cast<int64_t>(axes.size()));
+  TensorShapeVector keepdims_tensor_shape(1, 1);
+  auto ml_tensor = DataTypeImpl::GetType<Tensor>();
+  OrtMemoryInfo info("Cpu", OrtDeviceAllocator);
+  auto axes_tensor_obj = std::make_unique<Tensor>(DataTypeImpl::GetType<int64_t>(), axes_tensor_shape,
+                                                  const_cast<void*>(reinterpret_cast<const void*>(&axes[0])), info);
+  axes_tensor.Init(axes_tensor_obj.release(), ml_tensor, ml_tensor->GetDeleteFunc());
+  auto keepdims_tensor_obj = std::make_unique<Tensor>(DataTypeImpl::GetType<bool>(), keepdims_tensor_shape,
+                                                      reinterpret_cast<void*>(&keepdims), info);
+  keepdims_tensor.Init(keepdims_tensor_obj.release(), ml_tensor, ml_tensor->GetDeleteFunc());
+  dlpack_inputs[1] = dlpack::OrtValueToDlpack(axes_tensor);
+  dlpack_inputs[2] = dlpack::OrtValueToDlpack(keepdims_tensor);
+  dlpack_inputs[3] = nullptr;
+  DLManagedTensor* dlpack_output = nullptr;
+  aten_ops::ATenOperatorExecutor::Instance()("sum", "dim_IntList", input_size, dlpack_inputs.get(), 1, &dlpack_output);
+  ORT_ENFORCE(dlpack_output);
+  ORT_RETURN_IF_ERROR(p_ctx_internal->SetOutputMLValue(0, dlpack::DlpackToOrtValue(dlpack_output)));
+  return Status::OK();
+}
+#endif
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op.h b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.h
similarity index 96%
rename from orttraining/orttraining/training_ops/cpu/aten_ops/aten_op.h
rename to onnxruntime/contrib_ops/cpu/aten_ops/aten_op.h
index 1a80af45c5..54f4ed2d16 100644
--- a/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op.h
+++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op.h
@@ -22,8 +22,10 @@ class ATen : public OpKernel {
   std::string overload_name_;
 };
 
+#ifdef ENABLE_TRAINING
 bool IsATenOperatorExecutorInitialized();
 Status ExecuteReduceSumATen(OpKernelContext* p_ctx, const gsl::span<const int64_t>& axes, bool keepdims);
+#endif
 
 } // namespace contrib
 } // namespace onnxruntime
diff --git a/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op_executor.h b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h
similarity index 70%
rename from orttraining/orttraining/training_ops/cpu/aten_ops/aten_op_executor.h
rename to onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h
index ce518bfea6..be9650d96b 100644
--- a/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op_executor.h
+++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h
@@ -11,8 +11,9 @@ namespace contrib {
 namespace aten_ops {
 
 typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index);
-typedef std::vector<DLManagedTensor*> (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name,
-                                                                 const std::vector<DLManagedTensor*>& dlpacks);
+typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size,
+                                        DLManagedTensor** dlpack_inputs, size_t output_size,
+                                        DLManagedTensor** dlpack_outputs);
 
 class ATenOperatorExecutor {
  public:
@@ -34,10 +35,11 @@ class ATenOperatorExecutor {
     return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index);
   }
 
-  std::vector<DLManagedTensor*> operator()(const std::string& op_name, const std::string& overload_name,
-                                           const std::vector<DLManagedTensor*>& dlpacks) {
+  void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size,
+                  DLManagedTensor** dlpack_inputs, size_t output_size, DLManagedTensor** dlpack_outputs) {
     ORT_ENFORCE(p_execute_aten_op_func_, "ATenOperatorExecutor is not initialized.");
-    return p_execute_aten_op_func_(op_name.c_str(), overload_name.c_str(), dlpacks);
+    p_execute_aten_op_func_(op_name.c_str(), overload_name.c_str(), input_size, dlpack_inputs, output_size,
+                            dlpack_outputs);
   }
 
  private:
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 7b14d496d6..2068b3c3e3 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -112,6 +112,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu);
 
+#ifdef ENABLE_ATEN
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen);
+#endif
+
 template <>
 KernelCreateInfo BuildKernelCreateInfo() {
   KernelCreateInfo info;
@@ -249,6 +253,10 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
+
+#ifdef ENABLE_ATEN
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
+#endif
   };
 
   for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/contrib_ops/cuda/aten_ops/aten_op.cc b/onnxruntime/contrib_ops/cuda/aten_ops/aten_op.cc
new file mode 100644
index 0000000000..a39c587b42
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/aten_ops/aten_op.cc
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
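+
+// The CUDA build reuses the CPU ATen kernel implementation: DLPack tensors carry
+// their device, so the registered executor receives CUDA memory directly and no
+// device-specific kernel body is needed here; only the registration differs.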
+ +#include "core/providers/shared_library/provider_api.h" +#include "contrib_ops/cpu/aten_ops/aten_op.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + ATen, kPytorchAtenDomain, 1, kCudaExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::AllTensorAndSequenceTensorTypes()), + onnxruntime::contrib::ATen); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 2509953a8d..eb1e745d41 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -97,6 +97,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, LayerNormalization); +#ifdef ENABLE_ATEN +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen); +#endif + template <> KernelCreateInfo BuildKernelCreateInfo() { KernelCreateInfo info; @@ -195,6 +199,10 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + +#ifdef ENABLE_ATEN + BuildKernelCreateInfo, +#endif }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index a9cefec499..7efd5d8524 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -96,6 +96,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, LayerNormalization); +#ifdef ENABLE_ATEN +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain, 1, ATen); +#endif + template <> KernelCreateInfo BuildKernelCreateInfo() { KernelCreateInfo info; @@ -194,6 +198,10 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, // BuildKernelCreateInfo, + +#ifdef ENABLE_ATEN + BuildKernelCreateInfo, +#endif }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index aaf7e59434..ae50f19352 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -38,18 +38,23 @@ IExecutionFrame::IExecutionFrame(const OrtValueNameIdxMap& ort_value_idx_map, IExecutionFrame::~IExecutionFrame() = default; -#ifdef ENABLE_TRAINING +#ifdef ENABLE_ATEN Status IExecutionFrame::SetOutputMLValue(int index, const OrtValue& ort_value) { int ort_value_idx = GetNodeIdxToMLValueIdx(index); if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast(ort_value_idx) >= all_values_size_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx); } - ORT_ENFORCE(!all_values_[ort_value_idx].IsAllocated()); - all_values_[ort_value_idx] = ort_value; + if (all_values_[ort_value_idx].IsAllocated()) { + 
+    ORT_RETURN_IF_ERROR(CopyTensor(ort_value.Get<Tensor>(), *all_values_[ort_value_idx].GetMutable<Tensor>()));
+  } else {
+    all_values_[ort_value_idx] = ort_value;
+  }
 
   return Status::OK();
 }
+#endif
 
+#ifdef ENABLE_TRAINING
 void IExecutionFrame::UpdateFeeds(const std::vector<int>& feed_mlvalue_idxs, const std::vector<OrtValue>& feeds) {
   ORT_ENFORCE(feed_mlvalue_idxs.size() == feeds.size());
diff --git a/onnxruntime/core/framework/execution_frame.h b/onnxruntime/core/framework/execution_frame.h
index 98268ac1c7..58f094f426 100644
--- a/onnxruntime/core/framework/execution_frame.h
+++ b/onnxruntime/core/framework/execution_frame.h
@@ -51,9 +51,12 @@ class IExecutionFrame {
   const OrtValue* GetNodeInputOrOutputMLValue(int index) const;
   OrtValue* GetMutableNodeInputOrOutputMLValue(int index);
 
-#ifdef ENABLE_TRAINING
+#ifdef ENABLE_ATEN
   // Override the index-th output with ort_value
   Status SetOutputMLValue(int index, const OrtValue& ort_value);
+#endif
+
+#ifdef ENABLE_TRAINING
   void UpdateFeeds(const std::vector<int>& feed_mlvalue_idxs, const std::vector<OrtValue>& feeds);
   void UpdateFetches(const std::vector<int>& fetch_mlvalue_idxs, const std::vector<OrtValue>& fetches,
                      const std::unordered_map<int, OrtValue>& initializers);
@@ -73,7 +76,7 @@
 
   /**
    * write the output values to the 'fetches' vector
-   * Don't access the values after SessionState is destroyed
+   * Don't access the values after SessionState is destroyed
   */
   Status GetOutputs(std::vector<OrtValue>& fetches);
diff --git a/onnxruntime/core/framework/op_kernel.cc b/onnxruntime/core/framework/op_kernel.cc
index a1b0f2fc26..cccff45541 100644
--- a/onnxruntime/core/framework/op_kernel.cc
+++ b/onnxruntime/core/framework/op_kernel.cc
@@ -208,7 +208,7 @@ OrtValue* OpKernelContext::GetOutputMLValue(int index) {
   return execution_frame_->GetMutableNodeInputOrOutputMLValue(output_arg_index);
 }
 
-#ifdef ENABLE_TRAINING
+#ifdef ENABLE_ATEN
 Status OpKernelContext::SetOutputMLValue(int index, const OrtValue& ort_value) {
   if (index < 0 || index >= OutputCount()) {
     return Status(common::ONNXRUNTIME, common::FAIL,
diff --git a/onnxruntime/core/framework/op_kernel_context_internal.h b/onnxruntime/core/framework/op_kernel_context_internal.h
index 8e3a0747c0..f679b710d1 100644
--- a/onnxruntime/core/framework/op_kernel_context_internal.h
+++ b/onnxruntime/core/framework/op_kernel_context_internal.h
@@ -53,7 +53,7 @@ class OpKernelContextInternal : public OpKernelContext {
     return OpKernelContext::GetOutputMLValue(index);
   }
 
-#ifdef ENABLE_TRAINING
+#ifdef ENABLE_ATEN
   Status SetOutputMLValue(int index, const OrtValue& ort_value) {
     return OpKernelContext::SetOutputMLValue(index, ort_value);
   }
diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc
index 0a663688c8..58dc3d0510 100644
--- a/onnxruntime/core/framework/utils.cc
+++ b/onnxruntime/core/framework/utils.cc
@@ -21,7 +21,10 @@
 #include "core/framework/TensorSeq.h"
 #ifdef ENABLE_TRAINING
 #include "core/framework/orttraining_partial_executor.h"
-#include "orttraining/training_ops/cpu/aten_ops/aten_op_executor.h"
+#endif
+
+#ifdef ENABLE_ATEN
+#include "contrib_ops/cpu/aten_ops/aten_op_executor.h"
 #endif
 
 namespace ONNX_NAMESPACE {
@@ -787,8 +790,9 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index)
     return true;
   }
 
-#ifdef ENABLE_TRAINING
-  if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" && node.Domain() == kPytorchAtenDomain) {
+#ifdef ENABLE_ATEN
+  if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" &&
+      node.Domain() ==
kPytorchAtenDomain) { const auto& attrs = node.GetAttributes(); ORT_ENFORCE(utils::HasString(attrs.at("operator"))); std::string op_name = attrs.at("operator").s(); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 4e371f0691..0a107a02c6 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2510,6 +2510,24 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt t } }); +#ifdef ENABLE_ATEN + ONNX_CONTRIB_OPERATOR_SCHEMA(ATen) + .SetDomain(kPytorchAtenDomain) + .SinceVersion(1) + .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL) + .SetDoc("ATen") + .Input(0, "inputs", "ATen Op inputs.", "T", OpSchema::Variadic, + /*is_homogeneous*/ false, + /*min_arity*/ 1) + .Output(0, "outputs", "ATen Op outputs.", "T", OpSchema::Variadic, + /*is_homogeneous*/ false, + /*min_arity*/ 1) + .Attr("operator", "Name of ATen operator.", AttributeProto::STRING) + .Attr("overload_name", "Overload name of ATen operator.", AttributeProto::STRING, false) + .TypeConstraint("T", OpSchema::all_tensor_types_with_bfloat(), + "Allow inputs and outputs to be any kind of tensor."); +#endif + #ifndef _OPSCHEMA_LIB_ // Register the NCHWc schemas if supported by the platform. if (MlasNchwcGetBlockSize() > 1) { diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index 648dc2b2d9..e4dcd6d10c 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -32,10 +32,12 @@ #include "contrib_ops/cpu/bert/embed_layer_norm_helper.h" #include "contrib_ops/cpu/bert/longformer_attention_base.h" #include "contrib_ops/cpu/transformers/beam_search.h" +#ifdef ENABLE_ATEN +#include "contrib_ops/cpu/aten_ops/aten_op.h" +#endif #endif #ifdef ENABLE_TRAINING -#include "orttraining/training_ops/cpu/aten_ops/aten_op.h" #include "orttraining/training_ops/cpu/controlflow/group.h" #include "orttraining/training_ops/cpu/controlflow/record.h" #include "orttraining/training_ops/cpu/controlflow/wait.h" @@ -173,10 +175,13 @@ struct ProviderHostCPUImpl : ProviderHostCPU { void BeamSearch__Init(contrib::transformers::BeamSearch* p, const OpKernelInfo& info) override { p->contrib::transformers::BeamSearch::Init(info); } virtual Status BeamSearch__Compute(const contrib::transformers::BeamSearch* p, OpKernelContext* ctx) { return p->contrib::transformers::BeamSearch::Compute(ctx); } virtual Status BeamSearch__SetupSubgraphExecutionInfo(contrib::transformers::BeamSearch* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::BeamSearch::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); } + +#ifdef ENABLE_ATEN + Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); } +#endif #endif #ifdef ENABLE_TRAINING - Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); } void contrib__record_event_in_tensor(const Tensor& event_id_tensor) override { return contrib::record_event_in_tensor(event_id_tensor); } void contrib__wait_event_in_tensor(const Tensor& event_id_tensor) override { return contrib::wait_event_in_tensor(event_id_tensor); } Status contrib__Group__Compute(const contrib::Group* p, OpKernelContext* 
context) override { return p->Group::Compute(context); } diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index f3a7822f47..fd4e88189d 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -134,10 +134,13 @@ struct ProviderHostCPU { virtual void BeamSearch__Init(contrib::transformers::BeamSearch* p, const OpKernelInfo& info) = 0; virtual Status BeamSearch__Compute(const contrib::transformers::BeamSearch* p, OpKernelContext* ctx) = 0; virtual Status BeamSearch__SetupSubgraphExecutionInfo(contrib::transformers::BeamSearch* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0; + +#ifdef ENABLE_ATEN + virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0; +#endif #endif #ifdef ENABLE_TRAINING - virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0; virtual void contrib__record_event_in_tensor(const Tensor& event_id_tensor) = 0; virtual void contrib__wait_event_in_tensor(const Tensor& event_id_tensor) = 0; virtual Status contrib__Group__Compute(const contrib::Group* p, OpKernelContext* context) = 0; diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index 0806a0fd82..cd2862a36a 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -9,7 +9,7 @@ #include "core/providers/cuda/math/binary_elementwise_ops.h" #include "core/providers/cuda/math/unary_elementwise_ops_impl.h" #ifdef ENABLE_TRAINING -#include "orttraining/training_ops/cpu/aten_ops/aten_op.h" +#include "contrib_ops/cpu/aten_ops/aten_op.h" #endif using namespace onnxruntime::common; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index a60e510b27..bf5d808bd9 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -31,10 +31,12 @@ #include "contrib_ops/cpu/bert/embed_layer_norm_helper.h" #include "contrib_ops/cpu/bert/longformer_attention_base.h" #include "contrib_ops/cpu/transformers/beam_search.h" +#ifdef ENABLE_ATEN +#include "contrib_ops/cpu/aten_ops/aten_op.h" +#endif #endif #ifdef ENABLE_TRAINING -#include "orttraining/training_ops/cpu/aten_ops/aten_op.h" #include "orttraining/training_ops/cpu/controlflow/group.h" #include "orttraining/training_ops/cpu/controlflow/yield.h" @@ -541,6 +543,10 @@ void BeamSearch::Init(const OpKernelInfo& info) { g_host_cpu.BeamSearch__Init(th Status BeamSearch::Compute(OpKernelContext* ctx) const { return g_host_cpu.BeamSearch__Compute(this, ctx); } Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) { return g_host_cpu.BeamSearch__SetupSubgraphExecutionInfo(this, session_state, attribute_name, subgraph_session_state); } } // namespace transformers + +#ifdef ENABLE_ATEN +Status ATen::Compute(OpKernelContext* p_ctx) const { return g_host_cpu.ATen__Compute(this, p_ctx); } +#endif } // namespace contrib #endif @@ -567,7 +573,6 @@ Status Scan<9>::SetupSubgraphExecutionInfo(const SessionState& session_state, co #ifdef ENABLE_TRAINING namespace contrib { -Status 
ATen::Compute(OpKernelContext* p_ctx) const { return g_host_cpu.ATen__Compute(this, p_ctx); } Status Group::Compute(OpKernelContext* context) const { return g_host_cpu.contrib__Group__Compute(this, context); } Status PassThrough::Compute(OpKernelContext* context) const { return g_host_cpu.contrib__PassThrough__Compute(this, context); } Status YieldOp::Compute(OpKernelContext* context) const { return g_host_cpu.contrib__YieldOp__Compute(this, context); } diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index a93539f2aa..339b3a902a 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -30,6 +30,10 @@ #include "core/session/provider_bridge_ort.h" #include "core/providers/tensorrt/tensorrt_provider_options.h" +#ifdef ENABLE_ATEN +#include "contrib_ops/cpu/aten_ops/aten_op_executor.h" +#endif + // Explicitly provide a definition for the static const var 'GPU' in the OrtDevice struct, // GCC 4.x doesn't seem to define this and it breaks the pipelines based on CentOS as it uses // GCC 4.x. @@ -1010,6 +1014,19 @@ void addGlobalMethods(py::module& m, Environment& env) { arena_extend_strategy = strategy; }); #endif + +#ifdef ENABLE_ATEN + m.def("register_aten_op_executor", + [](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void { + size_t is_tensor_argument_address_int, aten_op_executor_address_int; + ORT_THROW_IF_ERROR( + ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address_int)); + ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int)); + void* p_is_tensor_argument = reinterpret_cast(is_tensor_argument_address_int); + void* p_aten_op_executor = reinterpret_cast(aten_op_executor_address_int); + contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor); + }); +#endif } void addObjectMethods(py::module& m, Environment& env, ExecutionProviderRegistrationFn ep_registration_fn) { @@ -1236,7 +1253,7 @@ Applies to session load, initialization, etc. 
Default is 0.)pbdoc") const OrtValue* ml_value = ml_value_pyobject.attr(PYTHON_ORTVALUE_NATIVE_OBJECT_ATTR).cast(); ORT_THROW_IF_ERROR(options->AddInitializer(name, ml_value)); }) - .def("add_external_initializers", [](PySessionOptions* options, py::list& names, + .def("add_external_initializers", [](PySessionOptions* options, py::list& names, const py::list& ort_values) -> void { #if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_EXTERNAL_INITIALIZERS) const auto init_num = ort_values.size(); diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index a01d6659f8..617b010ad6 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -192,19 +192,19 @@ class SymbolicShapeInference: "SkipLayerNormalization": self._infer_SkipLayerNormalization, } self.aten_op_dispatcher_ = { - "aten::embedding": self._infer_Gather, - "aten::bitwise_or": self._infer_aten_bitwise_or, - "aten::diagonal": self._infer_aten_diagonal, - "aten::max_pool2d_with_indices": self._infer_aten_pool2d, - "aten::max": self._infer_aten_minmax, - "aten::min": self._infer_aten_minmax, - "aten::multinomial": self._infer_aten_multinomial, - "aten::unfold": self._infer_aten_unfold, - "aten::argmax": self._infer_aten_argmax, - "aten::avg_pool2d": self._infer_aten_pool2d, - "aten::_adaptive_avg_pool2d": self._infer_aten_pool2d, - "aten::binary_cross_entropy_with_logits": self._infer_aten_bce, - "aten::numpy_T": self._infer_Transpose, + "embedding": self._infer_Gather, + "bitwise_or": self._infer_aten_bitwise_or, + "diagonal": self._infer_aten_diagonal, + "max_pool2d_with_indices": self._infer_aten_pool2d, + "max": self._infer_aten_minmax, + "min": self._infer_aten_minmax, + "multinomial": self._infer_aten_multinomial, + "unfold": self._infer_aten_unfold, + "argmax": self._infer_aten_argmax, + "avg_pool2d": self._infer_aten_pool2d, + "_adaptive_avg_pool2d": self._infer_aten_pool2d, + "binary_cross_entropy_with_logits": self._infer_aten_bce, + "numpy_T": self._infer_Transpose, } self.run_ = True self.suggested_merge_ = {} diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/__init__.py b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py similarity index 92% rename from orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/__init__.py rename to onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py index 7a76951674..9dee656450 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/__init__.py +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py @@ -1,8 +1,8 @@ -from onnxruntime.capi import _pybind_state as C - import threading from functools import wraps +from onnxruntime.capi import _pybind_state as _C + def run_once_aten_op_executor(f): """ @@ -28,6 +28,6 @@ def run_once_aten_op_executor(f): def load_aten_op_executor_cpp_extension(): from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor - C.register_aten_op_executor( + _C.register_aten_op_executor( str(aten_op_executor.is_tensor_argument_address()), str(aten_op_executor.execute_aten_operator_address()) ) diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/aten_op_executor.cc b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc similarity index 92% rename from 
orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/aten_op_executor.cc rename to onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc index e1dbc56814..e5b6dc8083 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/aten_op_executor.cc +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc @@ -175,13 +175,13 @@ bool IsTensorArgument(const char* op_name, const char* overload_name, size_t ind return aten_op.elem_kinds[index] == c10::TypeKind::TensorType; } -std::vector ExecuteATenOperator(const char* op_name, const char* overload_name, - const std::vector& dlpacks) { +void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t input_size, + DLManagedTensor** dlpack_inputs, size_t output_size, DLManagedTensor** dlpack_outputs) { const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name); - TORCH_INTERNAL_ASSERT(dlpacks.size() == aten_op.argument_size); + TORCH_INTERNAL_ASSERT(input_size == aten_op.argument_size); std::vector arguments; - for (size_t i = 0; i < dlpacks.size(); i++) { - arguments.emplace_back(aten_op.ToIValueArgument(dlpacks[i], i)); + for (size_t i = 0; i < input_size; ++i) { + arguments.emplace_back(aten_op.ToIValueArgument(dlpack_inputs[i], i)); } torch::jit::Stack stack; @@ -191,8 +191,7 @@ std::vector ExecuteATenOperator(const char* op_name, const cha #ifndef TORCH_VERSION_PREEQ #define TORCH_VERSION_PREEQ(x, y) \ - ((TORCH_VERSION_MAJOR == (x) && TORCH_VERSION_MINOR >= (y)) || \ - (TORCH_VERSION_MAJOR > (x))) + ((TORCH_VERSION_MAJOR == (x) && TORCH_VERSION_MINOR >= (y)) || (TORCH_VERSION_MAJOR > (x))) #endif // pull request https://github.com/pytorch/pytorch/pull/63414 introduced @@ -207,13 +206,12 @@ std::vector ExecuteATenOperator(const char* op_name, const cha aten_op.op->getOperation()(&stack); #endif - std::vector result; - for (const auto& ret : torch::jit::pop(stack, aten_op.return_size)) { + TORCH_INTERNAL_ASSERT(output_size == aten_op.return_size); + size_t output_index = 0; + for (const auto& ret : torch::jit::pop(stack, output_size)) { const auto& tensor = ret.toTensor(); - result.emplace_back(at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous())); + dlpack_outputs[output_index++] = at::toDLPack(tensor.is_contiguous() ? 
tensor : tensor.contiguous()); } - - return result; } size_t is_tensor_argument_address() { return reinterpret_cast(&IsTensorArgument); } diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/setup.py b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/setup.py similarity index 93% rename from orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/setup.py rename to onnxruntime/python/torch_cpp_extensions/aten_op_executor/setup.py index 5485170be8..e09449291b 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/setup.py +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/setup.py @@ -4,7 +4,8 @@ # -------------------------------------------------------------------------- import os -from setuptools import setup, Extension + +from setuptools import setup from torch.utils import cpp_extension filename = os.path.join(os.path.dirname(__file__), "aten_op_executor.cc") diff --git a/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py b/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py new file mode 100644 index 0000000000..f1b8e03ace --- /dev/null +++ b/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py @@ -0,0 +1,37 @@ +import threading +from functools import wraps + +import torch + +from onnxruntime.capi import _pybind_state as _C + +from .aten_op_executor import execute_aten_operator_address, is_tensor_argument_address + + +def run_once_aten_op_executor(f): + """ + Decorator to run a function only once. + :param f: function to be run only once during execution time despite the number of calls + :return: The original function with the params passed to it if it hasn't already been run before + """ + + @wraps(f) + def aten_op_executor_wrapper(*args, **kwargs): + if not aten_op_executor_wrapper.has_run: + with aten_op_executor_wrapper.lock: + if not aten_op_executor_wrapper.has_run: + aten_op_executor_wrapper.has_run = True + return f(*args, **kwargs) + + aten_op_executor_wrapper.lock = threading.Lock() + aten_op_executor_wrapper.has_run = False + return aten_op_executor_wrapper + + +@run_once_aten_op_executor +def load_aten_op_executor_cpp_extension(): + _C.register_aten_op_executor(str(is_tensor_argument_address()), str(execute_aten_operator_address())) + + +def init_aten_op_executor(): + load_aten_op_executor_cpp_extension() diff --git a/onnxruntime/python/torch_cpp_extensions/setup.py b/onnxruntime/python/torch_cpp_extensions/setup.py new file mode 100644 index 0000000000..f7bf2ab1f7 --- /dev/null +++ b/onnxruntime/python/torch_cpp_extensions/setup.py @@ -0,0 +1,19 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +import os + +from setuptools import setup +from torch.utils import cpp_extension + +filename = os.path.join(os.path.dirname(__file__), "aten_op_executor/aten_op_executor.cc") + +setup( + name="ort_torch_ext", + version="1.0", + ext_modules=[cpp_extension.CppExtension(name="ort_torch_ext.aten_op_executor", sources=[filename])], + packages=["ort_torch_ext"], + cmdclass={"build_ext": cpp_extension.BuildExtension}, +) diff --git a/onnxruntime/test/python/contrib_ops/aten_op_tests.py b/onnxruntime/test/python/contrib_ops/aten_op_tests.py new file mode 100644 index 0000000000..0ae8fe6c28 --- /dev/null +++ b/onnxruntime/test/python/contrib_ops/aten_op_tests.py @@ -0,0 +1,131 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + +import io +import unittest + +import numpy as np +import onnx +import torch +from onnx import TensorProto, helper +from ort_torch_ext import init_aten_op_executor +from torch.onnx import export + +import onnxruntime as ort + + +class OrtOpTests(unittest.TestCase): + def test_aten_embedding(self): + class NeuralNetEmbedding(torch.nn.Module): + def __init__(self, num_embeddings, embedding_dim, hidden_size): + super(NeuralNetEmbedding, self).__init__() + self.embedding = torch.nn.Embedding(num_embeddings, embedding_dim) + self.linear = torch.nn.Linear(embedding_dim, hidden_size) + + def forward(self, input): + embedding_result = self.embedding(input) + return embedding_result, self.linear(embedding_result) + + N, num_embeddings, embedding_dim, hidden_size = 64, 32, 128, 128 + model = NeuralNetEmbedding(num_embeddings, embedding_dim, hidden_size) + + with torch.no_grad(): + x = torch.randint(high=num_embeddings, size=(N,), dtype=torch.int64) + dynamic_axes = {"x": {0: "x_dim0"}, "y": {0: "y_dim0", 1: "y_dim1"}} + + f = io.BytesIO() + + export( + model, + x, + f=f, + input_names=["x"], + output_names=["y"], + dynamic_axes=dynamic_axes, + opset_version=14, + ) + + exported_model = onnx.load_model_from_string(f.getvalue()) + + # PyTorch exporter emitting ATen op is still under development. Currently convert it manually for testing. + for node in exported_model.graph.node: + if node.op_type == "Gather": + node.domain = "org.pytorch.aten" + node.op_type = "ATen" + attr = node.attribute.add() + attr.name = "operator" + attr.type = 3 + attr.s = "embedding".encode() + exported_model.graph.node.append( + helper.make_node( + "Constant", + [], + ["padding_idx"], + value=helper.make_tensor("padding_idx", TensorProto.INT64, (), [-1]), + ) + ) + exported_model.graph.node.append( + helper.make_node( + "Constant", + [], + ["scale_grad_by_freq"], + value=helper.make_tensor("scale_grad_by_freq", TensorProto.BOOL, (), [False]), + ) + ) + exported_model.graph.node.append( + helper.make_node( + "Constant", + [], + ["sparse"], + value=helper.make_tensor("sparse", TensorProto.BOOL, (), [False]), + ) + ) + node.input.append("padding_idx") + node.input.append("scale_grad_by_freq") + node.input.append("sparse") + exported_model.graph.value_info.append( + helper.make_value_info( + name=node.output[0], + type_proto=helper.make_tensor_type_proto( + elem_type=TensorProto.FLOAT, shape=[node.output[0] + "_dim0", node.output[0] + "_dim1"] + ), + ) + ) + break + + # The ONNX graph to run contains ATen Op. + assert any(node.op_type == "ATen" for node in exported_model.graph.node) + + init_aten_op_executor() + + # Run w/o IO binding. 
+ for _ in range(8): + x = torch.randint(high=num_embeddings, size=(N,), dtype=torch.int64) + pt_y1, pt_y2 = model(x) + session = ort.InferenceSession(exported_model.SerializeToString(), providers=["CPUExecutionProvider"]) + ort_y1, ort_y2 = session.run([], {"x": x.numpy()}) + np.testing.assert_almost_equal(ort_y1, pt_y1.detach().numpy()) + np.testing.assert_almost_equal(ort_y2, pt_y2.detach().numpy()) + + # Run w/ IO binding. + for _ in range(8): + x = torch.randint(high=num_embeddings, size=(N,), dtype=torch.int64) + ort_x = ort.OrtValue.ortvalue_from_numpy(x.detach().numpy(), "cpu") + pt_y1, pt_y2 = model(x) + np_y1 = np.zeros(tuple(pt_y1.size()), dtype=np.float32) + np_y2 = np.zeros(tuple(pt_y2.size()), dtype=np.float32) + ort_y1 = ort.OrtValue.ortvalue_from_numpy(np_y1, "cpu") + ort_y2 = ort.OrtValue.ortvalue_from_numpy(np_y2, "cpu") + session = ort.InferenceSession(exported_model.SerializeToString(), providers=["CPUExecutionProvider"]) + io_binding = session.io_binding() + io_binding.bind_ortvalue_input(exported_model.graph.input[0].name, ort_x) + io_binding.bind_ortvalue_output(exported_model.graph.output[0].name, ort_y1) + io_binding.bind_ortvalue_output(exported_model.graph.output[1].name, ort_y2) + session.run_with_iobinding(io_binding) + np.testing.assert_almost_equal(np_y1, pt_y1.detach().numpy()) + np.testing.assert_almost_equal(np_y2, pt_y2.detach().numpy()) + + +if __name__ == "__main__": + unittest.main() diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 21a1a2b681..34435efa67 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -3384,24 +3384,6 @@ Return true if all elements are true and false otherwise. } }); -#ifdef ENABLE_TRAINING - ONNX_CONTRIB_OPERATOR_SCHEMA(ATen) - .SetDomain(kPytorchAtenDomain) - .SinceVersion(1) - .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL) - .SetDoc("ATen") - .Input(0, "inputs", "ATen Op inputs.", "T", OpSchema::Variadic, - /*is_homogeneous*/ false, - /*min_arity*/ 1) - .Output(0, "outputs", "ATen Op outputs.", "T", OpSchema::Variadic, - /*is_homogeneous*/ false, - /*min_arity*/ 1) - .Attr("operator", "Name of ATen operator.", AttributeProto::STRING) - .Attr("overload_name", "Overload name of ATen operator.", AttributeProto::STRING, false) - .TypeConstraint("T", OpSchema::all_tensor_types_with_bfloat(), - "Allow inputs and outputs to be any kind of tensor."); -#endif - ONNX_CONTRIB_OPERATOR_SCHEMA(PythonOp) .SetDomain(kMSDomain) .SinceVersion(1) diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index afa4b338f2..a6d4ba01db 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -21,9 +21,6 @@ #include "orttraining/core/framework/ortmodule_graph_builder.h" #include "orttraining/core/graph/gradient_definition_registry.h" #include "python/onnxruntime_pybind_mlvalue.h" - -#include "orttraining/training_ops/cpu/aten_ops/aten_op_executor.h" - #include "orttraining/python/orttraining_pybind_common.h" #ifdef ENABLE_TRAINING_TORCH_INTEROP @@ -555,16 +552,6 @@ for every transfered tensor. 
m.def("get_mpi_context_world_size", []() -> int { return MPIContext::GetInstance().GetWorldSize(); }); #endif - m.def("register_aten_op_executor", - [](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void { - size_t is_tensor_argument_address_int, aten_op_executor_address_int; - ORT_THROW_IF_ERROR( - ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address_int)); - ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int)); - void* p_is_tensor_argument = reinterpret_cast(is_tensor_argument_address_int); - void* p_aten_op_executor = reinterpret_cast(aten_op_executor_address_int); - contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor); - }); m.def("register_forward_runner", [](py::object obj) -> void { #ifdef ENABLE_TRAINING_TORCH_INTEROP auto& pool = onnxruntime::language_interop_ops::torch::OrtTorchFunctionPool::GetInstance(); diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index a4d07cb338..c0d1e2cb40 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -25,6 +25,7 @@ # 'is_tensor' is optional, if not present, the default is False. import json + from onnxruntime.capi import _pybind_state as C @@ -85,7 +86,7 @@ def register_gradient(domain, name, *attributes): # For ATen op, we need to provide op_name and overload name. -@register_gradient("org.pytorch.aten", "ATen", "aten::embedding", "") +@register_gradient("org.pytorch.aten", "ATen", "embedding", "") def embedding_gradient(): return [ ("Constant", [], ["Const_0"], {"value": {"value": 0, "dtype": "int", "is_tensor": True}}), @@ -95,12 +96,12 @@ def embedding_gradient(): ("ATen", "org.pytorch.aten"), ["GO(0)", "I(1)", "Gather_X_0", "I(2)", "I(3)", "I(4)"], ["GI(0)"], - {"operator": {"value": "aten::embedding_backward", "dtype": "string"}}, + {"operator": {"value": "embedding_backward", "dtype": "string"}}, ), ] -@register_gradient("org.pytorch.aten", "ATen", "aten::diagonal", "") +@register_gradient("org.pytorch.aten", "ATen", "diagonal", "") def diagonal_gradient(): return [ ("Shape", ["I(0)"], ["Shape_X"]), @@ -108,19 +109,19 @@ def diagonal_gradient(): ("ATen", "org.pytorch.aten"), ["GO(0)", "Shape_X", "I(1)", "I(2)", "I(3)"], ["GI(0)"], - {"operator": {"value": "aten::diagonal_backward", "dtype": "string"}}, + {"operator": {"value": "diagonal_backward", "dtype": "string"}}, ), ] -@register_gradient("org.pytorch.aten", "ATen", "aten::max_pool2d_with_indices", "") +@register_gradient("org.pytorch.aten", "ATen", "max_pool2d_with_indices", "") def max_pool2d_gradient(): return [ ( ("ATen", "org.pytorch.aten"), ["GO(0)", "I(0)", "I(1)", "I(2)", "I(3)", "I(4)", "I(5)", "O(1)"], ["GI(0)"], - {"operator": {"value": "aten::max_pool2d_with_indices_backward", "dtype": "string"}}, + {"operator": {"value": "max_pool2d_with_indices_backward", "dtype": "string"}}, ), ] @@ -143,8 +144,8 @@ def minmax_gradient(): ] -min_gradient = register_gradient("org.pytorch.aten", "ATen", "aten::min", "")(minmax_gradient) -max_gradient = register_gradient("org.pytorch.aten", "ATen", "aten::max", "")(minmax_gradient) +min_gradient = register_gradient("org.pytorch.aten", "ATen", "min", "")(minmax_gradient) +max_gradient = 
register_gradient("org.pytorch.aten", "ATen", "max", "")(minmax_gradient) def minmax_dim_gradient(): @@ -161,16 +162,16 @@ def minmax_dim_gradient(): ("ATen", "org.pytorch.aten"), ["GO(0)", "I(1)", "O(1)", "Shape_X", "I(2)"], ["GI(0)"], - {"operator": {"value": "aten::value_selecting_reduction_backward", "dtype": "string"}}, + {"operator": {"value": "value_selecting_reduction_backward", "dtype": "string"}}, ), ] -min_dim_gradient = register_gradient("org.pytorch.aten", "ATen", "aten::min", "dim")(minmax_dim_gradient) -max_dim_gradient = register_gradient("org.pytorch.aten", "ATen", "aten::max", "dim")(minmax_dim_gradient) +min_dim_gradient = register_gradient("org.pytorch.aten", "ATen", "min", "dim")(minmax_dim_gradient) +max_dim_gradient = register_gradient("org.pytorch.aten", "ATen", "max", "dim")(minmax_dim_gradient) -@register_gradient("org.pytorch.aten", "ATen", "aten::unfold", "") +@register_gradient("org.pytorch.aten", "ATen", "unfold", "") def unfold_gradient(): return [ ("Shape", ["I(0)"], ["Shape_X"]), @@ -178,58 +179,58 @@ def unfold_gradient(): ("ATen", "org.pytorch.aten"), ["GO(0)", "Shape_X", "I(1)", "I(2)", "I(3)"], ["GI(0)"], - {"operator": {"value": "aten::unfold_backward", "dtype": "string"}}, + {"operator": {"value": "unfold_backward", "dtype": "string"}}, ), ] -@register_gradient("org.pytorch.aten", "ATen", "aten::avg_pool2d", "") +@register_gradient("org.pytorch.aten", "ATen", "avg_pool2d", "") def avg_pool2d_gradient(): return [ ( ("ATen", "org.pytorch.aten"), ["GO(0)", "I(0)", "I(1)", "I(2)", "I(3)", "I(4)", "I(5)", "I(6)"], ["GI(0)"], - {"operator": {"value": "aten::avg_pool2d_backward", "dtype": "string"}}, + {"operator": {"value": "avg_pool2d_backward", "dtype": "string"}}, ), ] -@register_gradient("org.pytorch.aten", "ATen", "aten::_adaptive_avg_pool2d", "") +@register_gradient("org.pytorch.aten", "ATen", "_adaptive_avg_pool2d", "") def adaptive_avg_pool2d_gradient(): return [ ( ("ATen", "org.pytorch.aten"), ["GO(0)", "I(0)"], ["GI(0)"], - {"operator": {"value": "aten::_adaptive_avg_pool2d_backward", "dtype": "string"}}, + {"operator": {"value": "_adaptive_avg_pool2d_backward", "dtype": "string"}}, ), ] -CustomGradientRegistry.register_custom_stop_gradient_edges([0], "org.pytorch.aten", "ATen", "aten::argmax", "") -CustomGradientRegistry.register_custom_stop_gradient_edges([0], "org.pytorch.aten", "ATen", "aten::multinomial", "") +CustomGradientRegistry.register_custom_stop_gradient_edges([0], "org.pytorch.aten", "ATen", "argmax", "") +CustomGradientRegistry.register_custom_stop_gradient_edges([0], "org.pytorch.aten", "ATen", "multinomial", "") -@register_gradient("org.pytorch.aten", "ATen", "aten::binary_cross_entropy_with_logits", "") +@register_gradient("org.pytorch.aten", "ATen", "binary_cross_entropy_with_logits", "") def binary_cross_entropy_with_logits_gradient(): return [ ( ("ATen", "org.pytorch.aten"), ["GO(0)", "I(0)", "I(1)", "I(2)", "I(3)", "I(4)"], ["GI(0)"], - {"operator": {"value": "aten::binary_cross_entropy_with_logits_backward", "dtype": "string"}}, + {"operator": {"value": "binary_cross_entropy_with_logits_backward", "dtype": "string"}}, ), ] -@register_gradient("org.pytorch.aten", "ATen", "aten::numpy_T", "") +@register_gradient("org.pytorch.aten", "ATen", "numpy_T", "") def numpy_T_gradient(): return [ ( ("ATen", "org.pytorch.aten"), ["GO(0)"], ["GI(0)"], - {"operator": {"value": "aten::numpy_T", "dtype": "string"}}, + {"operator": {"value": "numpy_T", "dtype": "string"}}, ), ] diff --git 
a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 0ffd2e5fcf..526a6d559f 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -3,10 +3,10 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from torch.onnx import register_custom_op_symbolic -from torch.onnx.symbolic_helper import parse_args, _get_tensor_dim_size, _get_tensor_sizes -import torch.onnx.symbolic_helper as sym_help import torch +import torch.onnx.symbolic_helper as sym_help +from torch.onnx import register_custom_op_symbolic +from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args class CustomOpSymbolicRegistry: @@ -73,7 +73,7 @@ def nll_loss(g, self, target, weight, reduction, ignore_index): @register_symbolic("embedding") def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): output = g.op( - "org.pytorch.aten::ATen", weight, indices, padding_idx, scale_grad_by_freq, sparse, operator_s="aten::embedding" + "org.pytorch.aten::ATen", weight, indices, padding_idx, scale_grad_by_freq, sparse, operator_s="embedding" ) indices_shape = _get_tensor_sizes(indices) if indices_shape is not None and hasattr(weight.type(), "with_sizes"): @@ -84,19 +84,19 @@ def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): @register_symbolic("bitwise_or") def bitwise_or(g, self, other): - return g.op("org.pytorch.aten::ATen", self, other, operator_s="aten::bitwise_or", overload_name_s="Tensor") + return g.op("org.pytorch.aten::ATen", self, other, operator_s="bitwise_or", overload_name_s="Tensor") @register_symbolic("diagonal") def diagonal(g, self, offset, dim1, dim2): - return g.op("org.pytorch.aten::ATen", self, offset, dim1, dim2, operator_s="aten::diagonal") + return g.op("org.pytorch.aten::ATen", self, offset, dim1, dim2, operator_s="diagonal") @register_symbolic("multinomial") def multinomial(g, self, num_samples, replacement=False, generator=None): if generator is not None and not sym_help._is_none(generator): raise RuntimeError("Unsupported: ONNX does not support generator for multinomial") - return g.op("org.pytorch.aten::ATen", self, num_samples, replacement, generator, operator_s="aten::multinomial") + return g.op("org.pytorch.aten::ATen", self, num_samples, replacement, generator, operator_s="multinomial") @register_symbolic("max_pool2d") @@ -112,7 +112,7 @@ def max_pool2d(g, self, kernel_size, stride, padding, dilation, ceil_mode): padding, dilation, ceil_mode, - operator_s="aten::max_pool2d_with_indices", + operator_s="max_pool2d_with_indices", outputs=2, )[0] @@ -121,38 +121,34 @@ def max_pool2d(g, self, kernel_size, stride, padding, dilation, ceil_mode): def max(g, self, dim_or_y=None, keepdim=None): # torch.max(input), returns the max value in the tensor if dim_or_y is None and keepdim is None: - return g.op("org.pytorch.aten::ATen", self, operator_s="aten::max") + return g.op("org.pytorch.aten::ATen", self, operator_s="max") # torch.max(input, other) if keepdim is None: return g.op("Max", self, dim_or_y) # torch.max(input, dim, keepdim), returns (max_values, max_indices) - return g.op( - "org.pytorch.aten::ATen", self, dim_or_y, keepdim, operator_s="aten::max", overload_name_s="dim", outputs=2 - ) + return g.op("org.pytorch.aten::ATen", self, 
dim_or_y, keepdim, operator_s="max", overload_name_s="dim", outputs=2) @register_symbolic("min") def min(g, self, dim_or_y=None, keepdim=None): # torch.min(input), returns the min value in the tensor if dim_or_y is None and keepdim is None: - return g.op("org.pytorch.aten::ATen", self, operator_s="aten::min") + return g.op("org.pytorch.aten::ATen", self, operator_s="min") # torch.min(input, other) if keepdim is None: return g.op("Min", self, dim_or_y) # torch.min(input, dim, keepdim), returns (min_values, min_indices) - return g.op( - "org.pytorch.aten::ATen", self, dim_or_y, keepdim, operator_s="aten::min", overload_name_s="dim", outputs=2 - ) + return g.op("org.pytorch.aten::ATen", self, dim_or_y, keepdim, operator_s="min", overload_name_s="dim", outputs=2) @register_symbolic("unfold") def unfold(g, input, dimension, size, step): - return g.op("org.pytorch.aten::ATen", input, dimension, size, step, operator_s="aten::unfold") + return g.op("org.pytorch.aten::ATen", input, dimension, size, step, operator_s="unfold") @register_symbolic("argmax") def argmax(g, input, dim, keepdim): - return g.op("org.pytorch.aten::ATen", input, dim, keepdim, operator_s="aten::argmax") + return g.op("org.pytorch.aten::ATen", input, dim, keepdim, operator_s="argmax") @register_symbolic("avg_pool2d") @@ -169,13 +165,13 @@ def avg_pool2d(g, self, kernel_size, stride, padding, ceil_mode, count_include_p ceil_mode, count_include_pad, divisor_override, - operator_s="aten::avg_pool2d", + operator_s="avg_pool2d", ) @register_symbolic("adaptive_avg_pool2d") def adaptive_avg_pool2d(g, self, output_size): - return g.op("org.pytorch.aten::ATen", self, output_size, operator_s="aten::_adaptive_avg_pool2d") + return g.op("org.pytorch.aten::ATen", self, output_size, operator_s="_adaptive_avg_pool2d") @register_symbolic("binary_cross_entropy_with_logits") @@ -191,7 +187,7 @@ def binary_cross_entropy_with_logits(g, self, target, weight, pos_weight, reduct weight, pos_weight, reduction, - operator_s="aten::binary_cross_entropy_with_logits", + operator_s="binary_cross_entropy_with_logits", ) from torch.onnx.symbolic_opset12 import binary_cross_entropy_with_logits as bce @@ -209,7 +205,7 @@ def numpy_T(g, self): else: # if we don't have dim information we cannot # output a permute so use ATen instead - return g.op("com.microsoft::ATenOp", self, name_s="aten::numpy_T") + return g.op("org.pytorch.aten::ATen", self, operator_s="numpy_T") @register_symbolic("squeeze") diff --git a/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op.cc b/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op.cc deleted file mode 100644 index d6b84a5425..0000000000 --- a/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op.cc +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "orttraining/training_ops/cpu/aten_ops/aten_op.h" - -#include "core/dlpack/dlpack_converter.h" -#include "core/framework/op_kernel_context_internal.h" -#include "orttraining/training_ops/cpu/aten_ops/aten_op_executor.h" - -namespace onnxruntime { -namespace contrib { - -ONNX_OPERATOR_KERNEL_EX(ATen, kPytorchAtenDomain, 1, kCpuExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::AllTensorAndSequenceTensorTypes()), - ATen); - -Status ATen::Compute(OpKernelContext* p_ctx) const { - auto* p_ctx_internal = static_cast<OpKernelContextInternal*>(p_ctx); - std::vector<DLManagedTensor*> dlpacks; - for (int i = 0; i < p_ctx_internal->InputCount(); i++) { - const OrtValue* p_ort_value = p_ctx_internal->GetInputMLValue(i); - if (!p_ort_value) { - dlpacks.emplace_back(nullptr); - } else { - OrtValue ort_value = *p_ort_value; - dlpacks.emplace_back(dlpack::OrtValueToDlpack(ort_value)); - } - } - - auto result = aten_ops::ATenOperatorExecutor::Instance()(op_name_, overload_name_, dlpacks); - for (size_t i = 0; i < result.size(); i++) { - ORT_RETURN_IF_ERROR(p_ctx_internal->SetOutputMLValue(static_cast<int>(i), dlpack::DlpackToOrtValue(result[i]))); - } - - return Status::OK(); -} - -bool IsATenOperatorExecutorInitialized() { - return aten_ops::ATenOperatorExecutor::Instance().IsInitialized(); -} - -Status ExecuteReduceSumATen(OpKernelContext* p_ctx, const gsl::span<const int64_t>& axes, bool keepdims) { - ORT_ENFORCE(aten_ops::ATenOperatorExecutor::Instance().IsInitialized() && !axes.empty()); - std::vector<DLManagedTensor*> dlpacks; - auto* p_ctx_internal = static_cast<OpKernelContextInternal*>(p_ctx); - OrtValue ort_value = *p_ctx_internal->GetInputMLValue(0); - dlpacks.emplace_back(dlpack::OrtValueToDlpack(ort_value)); - OrtValue axes_tensor; - OrtValue keepdims_tensor; - TensorShapeVector axes_tensor_shape(1, static_cast<int64_t>(axes.size())); - TensorShapeVector keepdims_tensor_shape(1, 1); - auto ml_tensor = DataTypeImpl::GetType<Tensor>(); - OrtMemoryInfo info("Cpu", OrtDeviceAllocator); - auto axes_tensor_obj = std::make_unique<Tensor>(DataTypeImpl::GetType<int64_t>(), axes_tensor_shape, - const_cast<void*>(reinterpret_cast<const void*>(&axes[0])), info); - axes_tensor.Init(axes_tensor_obj.release(), ml_tensor, ml_tensor->GetDeleteFunc()); - auto keepdims_tensor_obj = std::make_unique<Tensor>(DataTypeImpl::GetType<bool>(), keepdims_tensor_shape, reinterpret_cast<void*>(&keepdims), info); - keepdims_tensor.Init(keepdims_tensor_obj.release(), ml_tensor, ml_tensor->GetDeleteFunc()); - dlpacks.emplace_back(dlpack::OrtValueToDlpack(axes_tensor)); - dlpacks.emplace_back(dlpack::OrtValueToDlpack(keepdims_tensor)); - dlpacks.emplace_back(nullptr); - auto result = aten_ops::ATenOperatorExecutor::Instance()("aten::sum", "dim_IntList", dlpacks); - ORT_RETURN_IF_ERROR(p_ctx_internal->SetOutputMLValue(0, dlpack::DlpackToOrtValue(result[0]))); - return Status::OK(); -} - -} // namespace contrib -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc index 2475ffb9bb..ec503cd9b6 100644 --- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc @@ -98,7 +98,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Recv) class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, RecordEvent); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WaitEvent); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, YieldOp); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider,
kPytorchAtenDomain, 1, ATen); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SummaryScalar); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SummaryHistogram); @@ -208,7 +207,6 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen)>, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cuda/aten_ops/aten_op.cc b/orttraining/orttraining/training_ops/cuda/aten_ops/aten_op.cc deleted file mode 100644 index 3a1594d87f..0000000000 --- a/orttraining/orttraining/training_ops/cuda/aten_ops/aten_op.cc +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/shared_library/provider_api.h" -#include "orttraining/training_ops/cpu/aten_ops/aten_op.h" -#include "core/providers/cuda/cuda_fwd.h" - -namespace onnxruntime { -namespace cuda { - -ONNX_OPERATOR_KERNEL_EX(ATen, kPytorchAtenDomain, 1, kCudaExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::AllTensorAndSequenceTensorTypes()), - onnxruntime::contrib::ATen); - -} // namespace cuda -} // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 6faf6edd40..0868f1760e 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -220,7 +220,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Adas class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, RecordEvent); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WaitEvent); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, YieldOp); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen); #ifdef ENABLE_TRAINING_TORCH_INTEROP class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PythonOp); @@ -443,7 +442,6 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen)>, #ifdef ENABLE_TRAINING_TORCH_INTEROP BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc index 5d121b57d1..44a5988523 100644 --- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc +++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc @@ -195,7 +195,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Adas class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, RecordEvent); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, WaitEvent); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, YieldOp); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kPytorchAtenDomain, 1, ATen); #ifdef ENABLE_TRAINING_TORCH_INTEROP class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PythonOp); @@ -392,7 +391,6 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, -
BuildKernelCreateInfo, #ifdef ENABLE_TRAINING_TORCH_INTEROP BuildKernelCreateInfo, diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml new file mode 100644 index 0000000000..a97c363526 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml @@ -0,0 +1,77 @@ +jobs: +- job: Linux_Build + timeoutInMinutes: 120 + workspace: + clean: all + pool: Linux-CPU-2019 + steps: + - checkout: self + clean: true + submodules: recursive + + - task: NodeTool@0 + inputs: + versionSpec: '12.16.3' + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_aten_cpu + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimecpubuildaten + + - task: CmdLine@2 + displayName: 'build' + inputs: + script: | + mkdir -p $HOME/.onnx + docker run --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + onnxruntimecpubuildaten \ + /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ + --build_dir /build --cmake_generator Ninja \ + --config Release \ + --skip_submodule_sync \ + --build_shared_lib \ + --parallel \ + --build_wheel \ + --skip_tests \ + --cmake_extra_defines onnxruntime_ENABLE_ATEN=ON + workingDirectory: $(Build.SourcesDirectory) + + - task: CmdLine@2 + displayName: 'install ort_torch_ext and launch test' + inputs: + script: | + mkdir -p $HOME/.onnx + docker run --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + onnxruntimecpubuildaten \ + bash -c "rm -rf /build/Release/onnxruntime /build/Release/pybind11 && \ + /opt/python/cp38-cp38/bin/python3 -m pip install /build/Release/dist/*.whl && \ + /opt/python/cp38-cp38/bin/python3 -m pip install /onnxruntime_src/onnxruntime/python/torch_cpp_extensions && \ + /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/onnxruntime/test/python/contrib_ops/aten_op_tests.py && \ + /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ + --build_dir /build --cmake_generator Ninja \ + --config Release \ + --skip_submodule_sync \ + --build_shared_lib \ + --parallel \ + --build_wheel \ + --test \ + --cmake_extra_defines onnxruntime_ENABLE_ATEN=ON" + workingDirectory: $(Build.SourcesDirectory) + + - template: templates/clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_aten_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_aten_cpu new file mode 100644 index 0000000000..ad3e783040 --- /dev/null +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_aten_cpu @@ -0,0 +1,10 @@ +FROM quay.io/pypa/manylinux2014_x86_64:latest + +ADD scripts /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps_aten.sh && rm -rf /tmp/scripts + +ARG BUILD_UID=1001 +ARG BUILD_USER=onnxruntimedev +RUN adduser --uid $BUILD_UID $BUILD_USER +WORKDIR /home/$BUILD_USER +USER $BUILD_USER diff --git 
a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh new file mode 100755 index 0000000000..0a8bbef0c8 --- /dev/null +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh @@ -0,0 +1,112 @@ +#!/bin/bash +set -e -x + +# Development tools and libraries +yum -y install \ + graphviz + +# Download a file from the internet +function GetFile { + local uri=$1 + local path=$2 + local force=${3:-false} + local download_retries=${4:-5} + local retry_wait_time_seconds=${5:-30} + + if [[ -f $path ]]; then + if [[ $force = false ]]; then + echo "File '$path' already exists. Skipping download" + return 0 + else + rm -rf $path + fi + fi + + if [[ -f $uri ]]; then + echo "'$uri' is a file path, copying file to '$path'" + cp $uri $path + return $? + fi + + echo "Downloading $uri" + # Use aria2c if available, otherwise use curl + if command -v aria2c > /dev/null; then + aria2c -q -d $(dirname $path) -o $(basename $path) "$uri" + else + curl "$uri" -sSL --retry $download_retries --retry-delay $retry_wait_time_seconds --create-dirs -o "$path" --fail + fi + + return $? +} + +if [ ! -d "/opt/conda/bin" ]; then + PYTHON_EXES=("/opt/python/cp37-cp37m/bin/python3.7" "/opt/python/cp38-cp38/bin/python3.8" "/opt/python/cp39-cp39/bin/python3.9") +else + PYTHON_EXES=("/opt/conda/bin/python") +fi + +os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) + +SYS_LONG_BIT=$(getconf LONG_BIT) +mkdir -p /tmp/src +GLIBC_VERSION=$(getconf GNU_LIBC_VERSION | cut -f 2 -d \.) + +DISTRIBUTOR=$(lsb_release -i -s) + +if [[ "$DISTRIBUTOR" = "CentOS" && $SYS_LONG_BIT = "64" ]]; then + LIBDIR="lib64" +else + LIBDIR="lib" +fi + +cd /tmp/src + +echo "Installing azcopy" +mkdir -p /tmp/azcopy +GetFile https://aka.ms/downloadazcopy-v10-linux /tmp/azcopy/azcopy.tar.gz +tar --strip 1 -xf /tmp/azcopy/azcopy.tar.gz -C /tmp/azcopy +cp /tmp/azcopy/azcopy /usr/bin + +echo "Installing Ninja" +GetFile https://github.com/ninja-build/ninja/archive/v1.10.0.tar.gz /tmp/src/ninja-linux.tar.gz +tar -zxf ninja-linux.tar.gz +cd ninja-1.10.0 +cmake -Bbuild-cmake -H. +cmake --build build-cmake +mv ./build-cmake/ninja /usr/bin +echo "Installing Node.js" +GetFile https://nodejs.org/dist/v16.14.2/node-v16.14.2-linux-x64.tar.gz /tmp/src/node-v16.14.2-linux-x64.tar.gz +tar --strip 1 -xf /tmp/src/node-v16.14.2-linux-x64.tar.gz -C /usr + +cd /tmp/src +GetFile https://downloads.gradle-dn.com/distributions/gradle-6.3-bin.zip /tmp/src/gradle-6.3-bin.zip +unzip /tmp/src/gradle-6.3-bin.zip +mv /tmp/src/gradle-6.3 /usr/local/gradle +
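The tail of this script (below) installs the Python-side requirements, including a CPU build of torch for each interpreter. A hypothetical post-install smoke test for any one of those interpreters — assuming the ORT wheel and the `ort_torch_ext` extension from this PR have also been installed, as the pipeline above does — might look like:

```python
# Hypothetical smoke test, not part of install_deps_aten.sh: verify that torch
# imports and that the ORT torch extension can register the ATen executor.
import torch  # noqa: F401
from ort_torch_ext import init_aten_op_executor

init_aten_op_executor()
print("ATen op executor registered")
```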
[ -x "$(command -v protoc)" ]; then + source ${0/%install_deps_aten\.sh/..\/install_protobuf.sh} +fi + +export ONNX_ML=1 +export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" + +for PYTHON_EXE in "${PYTHON_EXES[@]}" +do + ${PYTHON_EXE} -m pip install -r ${0/%install_deps_aten\.sh/requirements\.txt} + if ![[ ${PYTHON_EXE} = "/opt/python/cp310-cp310/bin/python3.10" ]]; then + ${PYTHON_EXE} -m pip install -r ${0/%install_deps_aten\.sh/..\/training\/ortmodule\/stage1\/requirements_torch_cpu\/requirements.txt} + else + ${PYTHON_EXE} -m pip install torch==1.11.0 + fi +done + +cd /tmp/src +GetFile 'https://sourceware.org/pub/valgrind/valgrind-3.16.1.tar.bz2' /tmp/src/valgrind-3.16.1.tar.bz2 +tar -jxvf valgrind-3.16.1.tar.bz2 +cd valgrind-3.16.1 +./configure --prefix=/usr --libdir=/usr/lib64 --enable-only64bit --enable-tls +make -j$(getconf _NPROCESSORS_ONLN) +make install + +cd / +rm -rf /tmp/src