From 6090d8cd6eb89ab8da000ba41ccabb48ce798feb Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Wed, 14 Dec 2022 08:32:46 -0800 Subject: [PATCH] Fix usage of enable_training_ops and reduce ifdef complexity for training builds (#13888) ### Description Fix usage of enable_training_ops and reduce ifdef complexity for training builds. ### Motivation and Context This is the second refactoring PR towards creating a dedicated build for on device training. This PR aims to reduce some complexity. We can set ENABLE_TRAINING_OPS in cmake when either ENABLE_TRAINING or ENABLE_TRAINING_ON_DEVICE is selected; this way we don't have to use if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_ON_DEVICE) everywhere in the code. --- cmake/CMakeLists.txt | 2 +- cmake/onnxruntime_framework.cmake | 2 +- cmake/onnxruntime_graph.cmake | 6 +- cmake/onnxruntime_providers.cmake | 77 +++++++++-------- cmake/onnxruntime_session.cmake | 2 +- cmake/onnxruntime_webassembly.cmake | 4 +- .../providers/cpu/cpu_execution_provider.cc | 4 +- .../core/providers/cpu/cpu_provider_shared.cc | 22 +++-- .../core/providers/cpu/cpu_provider_shared.h | 24 ++++-- .../core/providers/cpu/tensor/scatter.cc | 4 +- .../providers/cuda/cuda_execution_provider.cc | 4 +- .../providers/cuda/shared_inc/fpgeneric.h | 2 +- .../cuda/tensor/gather_elements_impl.cu | 4 +- .../providers/cuda/tensor/gather_nd_impl.h | 2 +- .../core/providers/cuda/tensor/slice_impl.cu | 4 +- .../core/providers/cuda/tensor/slice_impl.h | 4 +- .../providers/rocm/rocm_execution_provider.cc | 4 +- .../provider_bridge_provider.cc | 17 ++-- onnxruntime/core/session/environment.cc | 8 +- .../training_ops/cpu/nn/dropout_7.cc | 2 +- .../cuda/cuda_training_kernels.cc | 82 +++++++++++-------- 21 files changed, 161 insertions(+), 119 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index e745c92522..d17c93b7f9 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt 
@@ -353,6 +353,7 @@ file (STRINGS "${REPO_ROOT}/VERSION_NUMBER" ORT_VERSION) if (onnxruntime_ENABLE_TRAINING) set(onnxruntime_ENABLE_ATEN ON) + set(onnxruntime_ENABLE_TRAINING_OPS ON) endif() find_package(Threads) @@ -1270,7 +1271,6 @@ endif() if (onnxruntime_ENABLE_TRAINING) add_compile_definitions(ENABLE_TRAINING) - add_compile_definitions(ENABLE_TRAINING_OPS) add_compile_definitions(ENABLE_STRIDED_TENSORS) if (UNIX) diff --git a/cmake/onnxruntime_framework.cmake b/cmake/onnxruntime_framework.cmake index 79a3b32016..575a8903e1 100644 --- a/cmake/onnxruntime_framework.cmake +++ b/cmake/onnxruntime_framework.cmake @@ -47,7 +47,7 @@ else() target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) endif() # Needed for the provider interface, as it includes training headers when training is enabled -if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) +if (onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(onnxruntime_framework PRIVATE ${ORTTRAINING_ROOT}) if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) onnxruntime_add_include_to_target(onnxruntime_framework Python::Module) diff --git a/cmake/onnxruntime_graph.cmake b/cmake/onnxruntime_graph.cmake index 12f4cd1033..e43b1197b0 100644 --- a/cmake/onnxruntime_graph.cmake +++ b/cmake/onnxruntime_graph.cmake @@ -80,7 +80,7 @@ if (onnxruntime_ENABLE_TRAINING) endif() set(onnxruntime_graph_lib_src ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src}) -if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) +if (onnxruntime_ENABLE_TRAINING_OPS) list(APPEND onnxruntime_graph_lib_src ${orttraining_graph_src}) endif() @@ -107,7 +107,7 @@ endif() target_include_directories(onnxruntime_graph PRIVATE ${ONNXRUNTIME_ROOT}) -if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) +if (onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(onnxruntime_graph PRIVATE ${ORTTRAINING_ROOT}) if 
(onnxruntime_USE_NCCL) @@ -119,7 +119,7 @@ set_target_properties(onnxruntime_graph PROPERTIES FOLDER "ONNXRuntime") set_target_properties(onnxruntime_graph PROPERTIES LINKER_LANGUAGE CXX) install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/graph DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core) source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src}) -if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) +if (onnxruntime_ENABLE_TRAINING_OPS) source_group(TREE ${ORTTRAINING_ROOT} FILES ${orttraining_graph_src}) endif() diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index b80e75e8c5..fe89280571 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -51,7 +51,7 @@ endfunction() # `target`. function(add_op_reduction_include_dirs target) set(op_reduction_include_dirs "${op_reduction_root}/onnxruntime") - if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) + if (onnxruntime_ENABLE_TRAINING_OPS) list(APPEND op_reduction_include_dirs "${op_reduction_root}/orttraining") endif() # add include directories BEFORE so they are searched first, giving op reduction file paths precedence @@ -166,8 +166,7 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS) list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs}) endif() - -if (onnxruntime_ENABLE_TRAINING_OPS) +if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING) file(GLOB_RECURSE onnxruntime_cpu_training_ops_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.h" "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.cc" @@ -191,6 +190,8 @@ if (onnxruntime_ENABLE_TRAINING_OPS) "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/gist/*.h" "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/tensorboard/*.cc" "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/tensorboard/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/torch/*.cc" + 
"${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/torch/*.h" ) list(REMOVE_ITEM onnxruntime_providers_src ${onnxruntime_cpu_full_training_only_srcs}) @@ -276,7 +277,7 @@ target_include_directories(onnxruntime_providers PRIVATE ${ONNXRUNTIME_ROOT} ${e onnxruntime_add_include_to_target(onnxruntime_providers re2::re2) add_dependencies(onnxruntime_providers onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) -if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) +if (onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(onnxruntime_providers PRIVATE ${ORTTRAINING_ROOT}) endif() @@ -384,38 +385,50 @@ if (onnxruntime_USE_CUDA) list(APPEND onnxruntime_providers_cuda_src ${onnxruntime_cuda_contrib_ops_cc_srcs} ${onnxruntime_cuda_contrib_ops_cu_srcs}) endif() - if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) + if (onnxruntime_ENABLE_TRAINING_OPS) file(GLOB_RECURSE onnxruntime_cuda_training_ops_cc_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.h" "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.h" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.cc" ) file(GLOB_RECURSE onnxruntime_cuda_training_ops_cu_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cu" "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cuh" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.cu" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.cuh" ) - # NCCL is not support in Windows build - if (WIN32) - list(REMOVE_ITEM onnxruntime_cuda_training_ops_cc_srcs - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_common.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_kernels.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/megatron.cc" - ) - elseif (NOT onnxruntime_USE_NCCL) - list(REMOVE_ITEM onnxruntime_cuda_training_ops_cc_srcs - 
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_common.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_kernels.cc" - "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/megatron.cc" - ) - endif() - source_group(TREE ${ORTTRAINING_ROOT} FILES ${onnxruntime_cuda_training_ops_cc_srcs} ${onnxruntime_cuda_training_ops_cu_srcs}) list(APPEND onnxruntime_providers_cuda_src ${onnxruntime_cuda_training_ops_cc_srcs} ${onnxruntime_cuda_training_ops_cu_srcs}) + + if(NOT onnxruntime_ENABLE_TRAINING) + file(GLOB_RECURSE onnxruntime_cuda_full_training_only_srcs + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/*.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/record.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/record.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/wait.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/wait.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/yield.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/yield.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/gist/*.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/gist/*.h" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/gist/*.cu" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/torch/*.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/torch/*.h" + ) + + list(REMOVE_ITEM onnxruntime_providers_cuda_src ${onnxruntime_cuda_full_training_only_srcs}) + elseif(WIN32 OR NOT onnxruntime_USE_NCCL) + # NCCL is not support in Windows build + file(GLOB_RECURSE onnxruntime_cuda_nccl_op_srcs + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_common.cc" + "${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_kernels.cc" + 
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/megatron.cc" + ) + + list(REMOVE_ITEM onnxruntime_providers_cuda_src ${onnxruntime_cuda_nccl_op_srcs}) + endif() endif() if (onnxruntime_REDUCED_OPS_BUILD) @@ -453,7 +466,7 @@ if (onnxruntime_USE_CUDA) endif() onnxruntime_add_include_to_target(onnxruntime_providers_cuda onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers) - if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) + if (onnxruntime_ENABLE_TRAINING_OPS) onnxruntime_add_include_to_target(onnxruntime_providers_cuda onnxruntime_training) target_link_libraries(onnxruntime_providers_cuda PRIVATE onnxruntime_training) if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) @@ -481,7 +494,7 @@ if (onnxruntime_USE_CUDA) target_link_libraries(onnxruntime_providers_cuda PRIVATE nvToolsExt) endif() - if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) + if (onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(onnxruntime_providers_cuda PRIVATE ${ORTTRAINING_ROOT} ${MPI_CXX_INCLUDE_DIRS}) if(onnxruntime_USE_MPI) target_link_libraries(onnxruntime_providers_cuda PRIVATE ${MPI_LIBRARIES} ${MPI_CXX_LINK_FLAGS}) @@ -561,7 +574,7 @@ if (onnxruntime_USE_DNNL) set_target_properties(onnxruntime_providers_dnnl PROPERTIES LINKER_LANGUAGE CXX) # Needed for the provider interface, as it includes training headers when training is enabled - if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) + if (onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(onnxruntime_providers_dnnl PRIVATE ${ORTTRAINING_ROOT}) endif() @@ -693,7 +706,7 @@ if (onnxruntime_USE_TENSORRT) endif() # Needed for the provider interface, as it includes training headers when training is enabled - if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) + if (onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ORTTRAINING_ROOT}) if 
(onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt Python::Module) @@ -808,7 +821,7 @@ if (onnxruntime_USE_OPENVINO) endif() # Needed for the provider interface, as it includes training headers when training is enabled - if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) + if (onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(onnxruntime_providers_openvino PRIVATE ${ORTTRAINING_ROOT}) endif() @@ -1332,7 +1345,7 @@ if (onnxruntime_USE_ROCM) list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_rocm_generated_contrib_ops_cc_srcs} ${onnxruntime_rocm_generated_contrib_ops_cu_srcs}) endif() - if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) + if (onnxruntime_ENABLE_TRAINING_OPS) file(GLOB_RECURSE onnxruntime_rocm_training_ops_cc_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/training_ops/rocm/*.h" "${ORTTRAINING_SOURCE_DIR}/training_ops/rocm/*.cc" @@ -1370,7 +1383,7 @@ if (onnxruntime_USE_ROCM) endif() onnxruntime_add_include_to_target(onnxruntime_providers_rocm onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers Boost::mp11 safeint_interface) - if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) + if (onnxruntime_ENABLE_TRAINING_OPS) onnxruntime_add_include_to_target(onnxruntime_providers_rocm onnxruntime_training) target_link_libraries(onnxruntime_providers_rocm PRIVATE onnxruntime_training) if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) @@ -1402,7 +1415,7 @@ if (onnxruntime_USE_ROCM) set_target_properties(onnxruntime_providers_rocm PROPERTIES LINKER_LANGUAGE CXX) set_target_properties(onnxruntime_providers_rocm PROPERTIES FOLDER "ONNXRuntime") - if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) + if (onnxruntime_ENABLE_TRAINING) target_include_directories(onnxruntime_providers_rocm PRIVATE ${ORTTRAINING_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining ${MPI_CXX_INCLUDE_DIRS}) 
if(onnxruntime_USE_MPI) target_link_libraries(onnxruntime_providers_rocm PRIVATE ${MPI_LIBRARIES} ${MPI_CXX_LINK_FLAGS}) diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index 2d0b3dcb27..5105eeb77f 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -51,7 +51,7 @@ if (onnxruntime_USE_ROCM) # ROCM provider sources are generated, need to add include directory for generated headers target_include_directories(onnxruntime_session PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining) endif() -if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) +if (onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(onnxruntime_session PRIVATE ${ORTTRAINING_ROOT}) endif() diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 18d6bf4e57..25761aa841 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -114,7 +114,7 @@ if (onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB) re2::re2 ) - if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) + if (onnxruntime_ENABLE_TRAINING) bundle_static_library(onnxruntime_webassembly tensorboard) endif() @@ -192,7 +192,7 @@ else() target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK) endif() - if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) + if (onnxruntime_ENABLE_TRAINING) target_link_libraries(onnxruntime_webassembly PRIVATE tensorboard) endif() diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index d944130bb6..f7f12efe55 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -10,7 +10,7 @@ #include "contrib_ops/cpu/cpu_contrib_kernels.h" #endif -#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS) +#if defined(ENABLE_TRAINING_OPS) #include 
"orttraining/training_ops/cpu/cpu_training_kernels.h" #endif @@ -2190,7 +2190,7 @@ Status RegisterCPUKernels(KernelRegistry& kernel_registry) { #ifndef DISABLE_CONTRIB_OPS ORT_RETURN_IF_ERROR(::onnxruntime::contrib::RegisterCpuContribKernels(kernel_registry)); #endif -#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS) +#if defined(ENABLE_TRAINING_OPS) ORT_RETURN_IF_ERROR(::onnxruntime::contrib::RegisterCpuTrainingKernels(kernel_registry)); #endif return Status::OK(); diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index 2122269f29..94536a7832 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -38,16 +38,19 @@ #endif #endif -#ifdef ENABLE_TRAINING +#ifdef ENABLE_TRAINING_OPS #include "orttraining/training_ops/cpu/controlflow/group.h" -#include "orttraining/training_ops/cpu/controlflow/record.h" -#include "orttraining/training_ops/cpu/controlflow/wait.h" -#include "orttraining/training_ops/cpu/controlflow/yield.h" #include "orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.h" #include "orttraining/training_ops/cpu/tensor/split.h" #include "orttraining/training_ops/cpu/optimizer/adamw/adamwbase.h" #endif +#ifdef ENABLE_TRAINING +#include "orttraining/training_ops/cpu/controlflow/record.h" +#include "orttraining/training_ops/cpu/controlflow/wait.h" +#include "orttraining/training_ops/cpu/controlflow/yield.h" +#endif + #include "cpu_provider_shared.h" namespace onnxruntime { @@ -248,16 +251,13 @@ struct ProviderHostCPUImpl : ProviderHostCPU { #endif #endif -#ifdef ENABLE_TRAINING - void contrib__record_event_in_tensor(const Tensor& event_id_tensor) override { return contrib::record_event_in_tensor(event_id_tensor); } - void contrib__wait_event_in_tensor(const Tensor& event_id_tensor) override { return contrib::wait_event_in_tensor(event_id_tensor); } +#ifdef ENABLE_TRAINING_OPS Status 
contrib__Group__Compute(const contrib::Group* p, OpKernelContext* context) override { return p->Group::Compute(context); } Status contrib__PassThrough__Compute(const contrib::PassThrough* p, OpKernelContext* context) override { return p->PassThrough::Compute(context); } void contrib__VerifyLogitWeightAndLabelShape(const TensorShape& logit_shape, const TensorShape& label_shape, const TensorShape* weight_shape) override { contrib::VerifyLogitWeightAndLabelShape(logit_shape, label_shape, weight_shape); } void contrib__GetNDCFromLogitAndLabelShape(const TensorShape& logit_shape, const TensorShape& label_shape, int64_t& N_D, int64_t& C) override { contrib::GetNDCFromLogitAndLabelShape(logit_shape, label_shape, N_D, C); } void contrib__GetPermutationAndShape(bool ncd_to_ndc, const TensorShape& tensor_shape, TensorShapeVector& new_shape, std::vector& permutations) override { contrib::GetPermutationAndShape(ncd_to_ndc, tensor_shape, new_shape, permutations); } Status contrib__PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector& split_sizes) override { return contrib::PrepareForTrainingCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis, after_dims_excluding_split, split_sizes); } - Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) override { return p->YieldOp::Compute(context); } // From cpu/optimizer/adamwbase.h (direct) Status contrib__AdamWOptimizerBase__PrepareForCompute(const contrib::AdamWOptimizerBase* p, OpKernelContext* ctx, contrib__AdamWOptimizerBase__Prepare& prepare) override { @@ -270,6 +270,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU { const TensorSeq* values, TensorSeq* updated_values) override { return p->AdamWOptimizerBase::GenerateOutputs(ctx, number_of_values, values, updated_values); } +#endif + +#ifdef ENABLE_TRAINING + void 
contrib__record_event_in_tensor(const Tensor& event_id_tensor) override { return contrib::record_event_in_tensor(event_id_tensor); } + void contrib__wait_event_in_tensor(const Tensor& event_id_tensor) override { return contrib::wait_event_in_tensor(event_id_tensor); } + Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) override { return p->YieldOp::Compute(context); } // From aten_op.h (direct) bool contrib__IsATenOperatorExecutorInitialized() override { diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 9b77a0c774..ebc65ece53 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -178,20 +178,24 @@ struct ProviderHostCPU { #endif #endif -#ifdef ENABLE_TRAINING - virtual void contrib__record_event_in_tensor(const Tensor& event_id_tensor) = 0; - virtual void contrib__wait_event_in_tensor(const Tensor& event_id_tensor) = 0; +#ifdef ENABLE_TRAINING_OPS virtual Status contrib__Group__Compute(const contrib::Group* p, OpKernelContext* context) = 0; virtual Status contrib__PassThrough__Compute(const contrib::PassThrough* p, OpKernelContext* context) = 0; virtual void contrib__VerifyLogitWeightAndLabelShape(const TensorShape& logit_shape, const TensorShape& label_shape, const TensorShape* weight_shape) = 0; virtual void contrib__GetNDCFromLogitAndLabelShape(const TensorShape& logit_shape, const TensorShape& label_shape, int64_t& N_D, int64_t& C) = 0; virtual void contrib__GetPermutationAndShape(bool ncd_to_ndc, const TensorShape& tensor_shape, TensorShapeVector& new_shape, std::vector& permutations) = 0; virtual Status contrib__PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector& split_sizes) = 0; - virtual Status contrib__YieldOp__Compute(const 
contrib::YieldOp* p, OpKernelContext* context) = 0; // From cpu/optimizer/adamwbase.h virtual Status contrib__AdamWOptimizerBase__PrepareForCompute(const contrib::AdamWOptimizerBase* p, OpKernelContext* ctx, contrib__AdamWOptimizerBase__Prepare& prepare) = 0; virtual Status contrib__AdamWOptimizerBase__GenerateOutputs(const contrib::AdamWOptimizerBase* p, OpKernelContext* ctx, size_t number_of_values, const TensorSeq* values, TensorSeq* updated_values) = 0; +#endif + +#ifdef ENABLE_TRAINING + virtual void contrib__record_event_in_tensor(const Tensor& event_id_tensor) = 0; + virtual void contrib__wait_event_in_tensor(const Tensor& event_id_tensor) = 0; + virtual Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) = 0; + // From aten_op.h virtual bool contrib__IsATenOperatorExecutorInitialized() = 0; virtual Status contrib__ExecuteReduceSumATen(OpKernelContext* p_ctx, const gsl::span& axes, bool keepdims) = 0; @@ -252,15 +256,19 @@ struct EinsumTypedComputeProcessor { Status Run() { return g_host_cpu.EinsumTypedComputeProcessor__Run(this); } }; -#ifdef ENABLE_TRAINING +#ifdef ENABLE_TRAINING_OPS namespace contrib { -inline void record_event_in_tensor(const Tensor& event_id_tensor) { return g_host_cpu.contrib__record_event_in_tensor(event_id_tensor); } -inline void wait_event_in_tensor(const Tensor& event_id_tensor) { return g_host_cpu.contrib__wait_event_in_tensor(event_id_tensor); } - inline void VerifyLogitWeightAndLabelShape(const TensorShape& logit_shape, const TensorShape& label_shape, const TensorShape* weight_shape) { g_host_cpu.contrib__VerifyLogitWeightAndLabelShape(logit_shape, label_shape, weight_shape); } inline void GetNDCFromLogitAndLabelShape(const TensorShape& logit_shape, const TensorShape& label_shape, int64_t& N_D, int64_t& C) { g_host_cpu.contrib__GetNDCFromLogitAndLabelShape(logit_shape, label_shape, N_D, C); } inline void GetPermutationAndShape(bool ncd_to_ndc, const TensorShape& tensor_shape, TensorShapeVector& 
new_shape, std::vector& permutations) { g_host_cpu.contrib__GetPermutationAndShape(ncd_to_ndc, tensor_shape, new_shape, permutations); } inline Status PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector& split_sizes) { return g_host_cpu.contrib__PrepareForTrainingCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis, after_dims_excluding_split, split_sizes); } +} // namespace contrib +#endif + +#ifdef ENABLE_TRAINING +namespace contrib { +inline void record_event_in_tensor(const Tensor& event_id_tensor) { return g_host_cpu.contrib__record_event_in_tensor(event_id_tensor); } +inline void wait_event_in_tensor(const Tensor& event_id_tensor) { return g_host_cpu.contrib__wait_event_in_tensor(event_id_tensor); } // From aten_op.h inline bool IsATenOperatorExecutorInitialized() { return g_host_cpu.contrib__IsATenOperatorExecutorInitialized(); } diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc index 1701b2beed..7e5a304445 100644 --- a/onnxruntime/core/providers/cpu/tensor/scatter.cc +++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc @@ -12,7 +12,7 @@ #include "core/framework/op_kernel_type_control_utils.h" #include "core/providers/common.h" #include "core/providers/op_kernel_type_control.h" -#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS) +#if defined(ENABLE_TRAINING_OPS) #include "orttraining/training_ops/cpu/tensor/gather_elements_grad_impl.h" #endif @@ -394,7 +394,7 @@ Status Scatter::Compute(OpKernelContext* context) const { return status; } -#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS) +#if defined(ENABLE_TRAINING_OPS) namespace contrib { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 258c226207..7de3efc55c 100644 
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -16,7 +16,7 @@ #include "contrib_ops/cuda/cuda_contrib_kernels.h" #endif -#ifdef ENABLE_TRAINING +#ifdef ENABLE_TRAINING_OPS #include "orttraining/training_ops/cuda/cuda_training_kernels.h" #endif @@ -2294,7 +2294,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { ORT_RETURN_IF_ERROR(::onnxruntime::contrib::cuda::RegisterCudaContribKernels(kernel_registry)); #endif -#ifdef ENABLE_TRAINING +#ifdef ENABLE_TRAINING_OPS ORT_RETURN_IF_ERROR(::onnxruntime::cuda::RegisterCudaTrainingKernels(kernel_registry)); #endif diff --git a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h index fbe7a94166..7c95b246dd 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h @@ -308,7 +308,7 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, long long int strideC, int batch_count, const cudaDeviceProp& prop) { -#ifdef ENABLE_TRAINING +#ifdef ENABLE_TRAINING_OPS #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, CUBLAS_TF32_TENSOR_OP_MATH); #else diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu index a1ceab4341..51389483ad 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu @@ -3,7 +3,7 @@ #include "core/providers/cuda/tensor/gather_elements_impl.h" #include "core/providers/cuda/tensor/scatter_elements_impl.h" -#ifdef ENABLE_TRAINING +#ifdef ENABLE_TRAINING_OPS #include "orttraining/training_ops/cuda/tensor/gather_elements_grad_impl.h" #endif @@ -259,7 +259,7 @@ GATHER_SCATTER_ELEMENTS_SPECIALIZED_IMPL(double) #undef 
GATHER_SCATTER_ELEMENTS_SPECIALIZED_IMPL #undef GATHER_SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL -#ifdef ENABLE_TRAINING +#ifdef ENABLE_TRAINING_OPS template struct FuncAtomicAdd { diff --git a/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.h b/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.h index 828f6ab6af..e424ec9da8 100644 --- a/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/gather_nd_impl.h @@ -29,7 +29,7 @@ void GatherNDImpl( const size_t slice_size, const int64_t* input_slice_offsets_data); -#ifdef ENABLE_TRAINING +#ifdef ENABLE_TRAINING_OPS template void GatherNDGradImpl( cudaStream_t stream, diff --git a/onnxruntime/core/providers/cuda/tensor/slice_impl.cu b/onnxruntime/core/providers/cuda/tensor/slice_impl.cu index 60a3e4e550..df392b45e9 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/slice_impl.cu @@ -118,7 +118,7 @@ Status SliceImpl(cudaStream_t stream, const size_t element_size, const int32_t d input_data, output_data, N); } -#ifdef ENABLE_TRAINING +#ifdef ENABLE_TRAINING_OPS Status SliceImplGrad(cudaStream_t stream, const size_t element_size, const int32_t dimension_count, const TArray& starts, const TArray& steps, const TArray& input_strides, const TArray& output_strides, const void* input_data, void* output_data, @@ -126,7 +126,7 @@ Status SliceImplGrad(cudaStream_t stream, const size_t element_size, const int32 return SliceImplEx(stream, element_size, dimension_count, starts, steps, input_strides, output_strides, input_data, output_data, N); } -#endif // ENABLE_TRAINING +#endif // ENABLE_TRAINING_OPS } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/slice_impl.h b/onnxruntime/core/providers/cuda/tensor/slice_impl.h index 2064b3e9d9..ced1070d0f 100644 --- a/onnxruntime/core/providers/cuda/tensor/slice_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/slice_impl.h @@ 
-19,7 +19,7 @@ Status SliceImpl(cudaStream_t stream, void* output_data, const size_t N); -#ifdef ENABLE_TRAINING +#ifdef ENABLE_TRAINING_OPS Status SliceImplGrad(cudaStream_t stream, const size_t element_size, const int32_t dimension_count, @@ -30,7 +30,7 @@ Status SliceImplGrad(cudaStream_t stream, const void* input_data, void* output_data, const size_t N); -#endif // ENABLE_TRAINING +#endif // ENABLE_TRAINING_OPS } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 5334fe70d6..e4af7068b7 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -15,7 +15,7 @@ #include "contrib_ops/rocm/rocm_contrib_kernels.h" #endif -#ifdef ENABLE_TRAINING +#ifdef ENABLE_TRAINING_OPS #include "orttraining/training_ops/rocm/rocm_training_kernels.h" #endif @@ -2260,7 +2260,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { ORT_RETURN_IF_ERROR(::onnxruntime::contrib::rocm::RegisterRocmContribKernels(kernel_registry)); #endif -#ifdef ENABLE_TRAINING +#ifdef ENABLE_TRAINING_OPS ORT_RETURN_IF_ERROR(::onnxruntime::rocm::RegisterRocmTrainingKernels(kernel_registry)); #endif diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 418a4a1ce8..65a45cb1a1 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -37,10 +37,13 @@ #endif #endif -#ifdef ENABLE_TRAINING +#ifdef ENABLE_TRAINING_OPS #include "orttraining/training_ops/cpu/controlflow/group.h" -#include "orttraining/training_ops/cpu/controlflow/yield.h" #include "orttraining/training_ops/cpu/optimizer/adamw/adamwbase.h" +#endif + +#ifdef ENABLE_TRAINING +#include 
"orttraining/training_ops/cpu/controlflow/yield.h" #ifdef ENABLE_TRAINING_TORCH_INTEROP #include "orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h" @@ -621,21 +624,23 @@ Status Scan<8>::SetupSubgraphExecutionInfo(const SessionState& session_state, co template <> Status Scan<9>::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) { return g_host_cpu.Scan__SetupSubgraphExecutionInfo(this, session_state, attribute_name, subgraph_session_state); } -#ifdef ENABLE_TRAINING +#ifdef ENABLE_TRAINING_OPS namespace contrib { Status Group::Compute(OpKernelContext* context) const { return g_host_cpu.contrib__Group__Compute(this, context); } Status PassThrough::Compute(OpKernelContext* context) const { return g_host_cpu.contrib__PassThrough__Compute(this, context); } -Status YieldOp::Compute(OpKernelContext* context) const { return g_host_cpu.contrib__YieldOp__Compute(this, context); } - Status AdamWOptimizerBase::PrepareForCompute(OpKernelContext* ctx, AdamWOptimizerBase::Prepare& prepare) const { return g_host_cpu.contrib__AdamWOptimizerBase__PrepareForCompute(this, ctx, reinterpret_cast(prepare)); } - Status AdamWOptimizerBase::GenerateOutputs(OpKernelContext* ctx, size_t number_of_values, const TensorSeq* values, TensorSeq* updated_values) const { return g_host_cpu.contrib__AdamWOptimizerBase__GenerateOutputs(this, ctx, number_of_values, values, updated_values); } +} +#endif +#ifdef ENABLE_TRAINING +namespace contrib { +Status YieldOp::Compute(OpKernelContext* context) const { return g_host_cpu.contrib__YieldOp__Compute(this, context); } } // namespace contrib #ifdef ENABLE_TRAINING_TORCH_INTEROP diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index e9bc637d8b..cf0e70853a 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -13,7 +13,7 @@ #include 
"core/graph/contrib_ops/internal_nhwc_onnx_opset.h" #include "core/graph/contrib_ops/ms_opset.h" #include "core/graph/contrib_ops/onnx_deprecated_opset.h" -#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS) +#if defined(ENABLE_TRAINING_OPS) #include "onnx/defs/operator_sets_training.h" #endif #endif @@ -31,7 +31,7 @@ #include "core/platform/tracing.h" #endif -#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS) +#if defined(ENABLE_TRAINING_OPS) #include "orttraining/core/graph/training_op_defs.h" #endif #ifdef ENABLE_TRAINING @@ -256,11 +256,11 @@ Status Environment::Initialize(std::unique_ptr logging_ RegisterOnnxMLOperatorSetSchema(); #endif -#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS) +#if defined(ENABLE_TRAINING_OPS) RegisterOnnxTrainingOperatorSetSchema(); #endif -#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS) +#if defined(ENABLE_TRAINING_OPS) // preserve this order until : this depends on operatorsetschema registration. training::RegisterTrainingOpSchemas(); #endif diff --git a/orttraining/orttraining/training_ops/cpu/nn/dropout_7.cc b/orttraining/orttraining/training_ops/cpu/nn/dropout_7.cc index 4aea989c9f..7707011e59 100644 --- a/orttraining/orttraining/training_ops/cpu/nn/dropout_7.cc +++ b/orttraining/orttraining/training_ops/cpu/nn/dropout_7.cc @@ -6,7 +6,7 @@ namespace onnxruntime { -#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS) +#if defined(ENABLE_TRAINING_OPS) ONNX_CPU_OPERATOR_KERNEL( Dropout, 7, diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index fe54adb680..642a71fe4c 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -153,24 +153,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Gath class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Scale); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Scale); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Scale); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistBinarizeEncoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeEncoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, GistBinarizeEncoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistBinarizeDecoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeDecoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, GistBinarizeDecoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, bool, GistPack1Encoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack1Encoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, bool, GistPack1Decoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack1Decoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack8Encoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistPack8Encoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack8Decoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistPack8Decoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack16Encoder); -class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack16Decoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPackMsfp15Encoder); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPackMsfp15Decoder); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_float, BatchNormInternal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double_double, BatchNormInternal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormInternal); @@ -212,6 +194,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inpl class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant); +// the kernels within the following ifdef are not included in a build with +// --enable_training_ops but without --enable_training +#ifdef ENABLE_TRAINING #if defined(ORT_USE_NCCL) || defined(USE_MPI) // P2P communication operators. 
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Send); @@ -226,6 +211,25 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Reco class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WaitEvent); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, YieldOp); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistBinarizeEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, GistBinarizeEncoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistBinarizeDecoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeDecoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, GistBinarizeDecoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, bool, GistPack1Encoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack1Encoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, bool, GistPack1Decoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack1Decoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack8Encoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistPack8Encoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack8Decoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistPack8Decoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 
1, float, GistPack16Encoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack16Decoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPackMsfp15Encoder); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPackMsfp15Decoder); + #ifdef ENABLE_TRAINING_TORCH_INTEROP class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PythonOp); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PythonOpGrad); @@ -238,6 +242,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Nccl class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MegatronF); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MegatronG); #endif +#endif Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -380,24 +385,6 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -440,6 +427,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, +// the kernels within the following ifdef are not included in a build with +// --enable_training_ops but without --enable_training +#ifdef ENABLE_TRAINING // P2P communication operators. 
#if defined(ORT_USE_NCCL) || defined(USE_MPI) BuildKernelCreateInfo, @@ -454,6 +444,25 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + #ifdef ENABLE_TRAINING_TORCH_INTEROP BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -465,6 +474,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, +#endif #endif };