Fix usage of enable_training_ops and reduce ifdef complexity for training builds (#13888)

### Description
Fix the usage of `enable_training_ops` and reduce `#ifdef` complexity for
training builds.




### Motivation and Context
This is the second refactoring PR towards creating a dedicated build for
on-device training, and it aims to reduce some of the `#ifdef` complexity. We can set
`ENABLE_TRAINING_OPS` in CMake whenever either `ENABLE_TRAINING` or
`ENABLE_TRAINING_ON_DEVICE` is selected; that way we don't have to write
`#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_ON_DEVICE)`
everywhere in the code.

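A minimal CMake sketch of the pattern described above (not the verbatim change in this commit; the `onnxruntime_ENABLE_TRAINING_ON_DEVICE` option is an assumption about the planned on-device-training build, while this commit only derives the flag from `onnxruntime_ENABLE_TRAINING`):

```cmake
# Sketch: derive a single "training ops" switch from the individual training builds.
# onnxruntime_ENABLE_TRAINING_ON_DEVICE is an assumed/planned option; this commit
# only sets the flag from onnxruntime_ENABLE_TRAINING.
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_ON_DEVICE)
  set(onnxruntime_ENABLE_TRAINING_OPS ON)
endif()

if (onnxruntime_ENABLE_TRAINING_OPS)
  # Source files then need a single guard:
  #   #ifdef ENABLE_TRAINING_OPS
  # instead of:
  #   #if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_ON_DEVICE)
  add_compile_definitions(ENABLE_TRAINING_OPS)
endif()
```

Kernels that only make sense in a full training build stay behind `ENABLE_TRAINING`, which is why several hunks below split the old combined guards into an `ENABLE_TRAINING_OPS` part and an `ENABLE_TRAINING` part.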
Ashwini Khade, 2022-12-14 08:32:46 -08:00, committed by GitHub
commit 6090d8cd6e (parent 7894d44d2d)
21 changed files with 161 additions and 119 deletions


@ -353,6 +353,7 @@ file (STRINGS "${REPO_ROOT}/VERSION_NUMBER" ORT_VERSION)
if (onnxruntime_ENABLE_TRAINING)
set(onnxruntime_ENABLE_ATEN ON)
set(onnxruntime_ENABLE_TRAINING_OPS ON)
endif()
find_package(Threads)
@ -1270,7 +1271,6 @@ endif()
if (onnxruntime_ENABLE_TRAINING)
add_compile_definitions(ENABLE_TRAINING)
add_compile_definitions(ENABLE_TRAINING_OPS)
add_compile_definitions(ENABLE_STRIDED_TENSORS)
if (UNIX)


@ -47,7 +47,7 @@ else()
target_include_directories(onnxruntime_framework PRIVATE ${ONNXRUNTIME_ROOT} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
endif()
# Needed for the provider interface, as it includes training headers when training is enabled
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING_OPS)
target_include_directories(onnxruntime_framework PRIVATE ${ORTTRAINING_ROOT})
if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
onnxruntime_add_include_to_target(onnxruntime_framework Python::Module)


@ -80,7 +80,7 @@ if (onnxruntime_ENABLE_TRAINING)
endif()
set(onnxruntime_graph_lib_src ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src})
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING_OPS)
list(APPEND onnxruntime_graph_lib_src ${orttraining_graph_src})
endif()
@ -107,7 +107,7 @@ endif()
target_include_directories(onnxruntime_graph PRIVATE ${ONNXRUNTIME_ROOT})
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING_OPS)
target_include_directories(onnxruntime_graph PRIVATE ${ORTTRAINING_ROOT})
if (onnxruntime_USE_NCCL)
@ -119,7 +119,7 @@ set_target_properties(onnxruntime_graph PROPERTIES FOLDER "ONNXRuntime")
set_target_properties(onnxruntime_graph PROPERTIES LINKER_LANGUAGE CXX)
install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/graph DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core)
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src})
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING_OPS)
source_group(TREE ${ORTTRAINING_ROOT} FILES ${orttraining_graph_src})
endif()


@ -51,7 +51,7 @@ endfunction()
# `target`.
function(add_op_reduction_include_dirs target)
set(op_reduction_include_dirs "${op_reduction_root}/onnxruntime")
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING_OPS)
list(APPEND op_reduction_include_dirs "${op_reduction_root}/orttraining")
endif()
# add include directories BEFORE so they are searched first, giving op reduction file paths precedence
@ -166,8 +166,7 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs})
endif()
if (onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING)
file(GLOB_RECURSE onnxruntime_cpu_training_ops_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.h"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/*.cc"
@ -191,6 +190,8 @@ if (onnxruntime_ENABLE_TRAINING_OPS)
"${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/gist/*.h"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/tensorboard/*.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/tensorboard/*.h"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/torch/*.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cpu/torch/*.h"
)
list(REMOVE_ITEM onnxruntime_providers_src ${onnxruntime_cpu_full_training_only_srcs})
@ -276,7 +277,7 @@ target_include_directories(onnxruntime_providers PRIVATE ${ONNXRUNTIME_ROOT} ${e
onnxruntime_add_include_to_target(onnxruntime_providers re2::re2)
add_dependencies(onnxruntime_providers onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING_OPS)
target_include_directories(onnxruntime_providers PRIVATE ${ORTTRAINING_ROOT})
endif()
@ -384,38 +385,50 @@ if (onnxruntime_USE_CUDA)
list(APPEND onnxruntime_providers_cuda_src ${onnxruntime_cuda_contrib_ops_cc_srcs} ${onnxruntime_cuda_contrib_ops_cu_srcs})
endif()
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING_OPS)
file(GLOB_RECURSE onnxruntime_cuda_training_ops_cc_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.h"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.h"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.cc"
)
file(GLOB_RECURSE onnxruntime_cuda_training_ops_cu_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cu"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/*.cuh"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.cu"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.cuh"
)
# NCCL is not support in Windows build
if (WIN32)
list(REMOVE_ITEM onnxruntime_cuda_training_ops_cc_srcs
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_common.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_kernels.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/megatron.cc"
)
elseif (NOT onnxruntime_USE_NCCL)
list(REMOVE_ITEM onnxruntime_cuda_training_ops_cc_srcs
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_common.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_kernels.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/megatron.cc"
)
endif()
source_group(TREE ${ORTTRAINING_ROOT} FILES ${onnxruntime_cuda_training_ops_cc_srcs} ${onnxruntime_cuda_training_ops_cu_srcs})
list(APPEND onnxruntime_providers_cuda_src ${onnxruntime_cuda_training_ops_cc_srcs} ${onnxruntime_cuda_training_ops_cu_srcs})
if(NOT onnxruntime_ENABLE_TRAINING)
file(GLOB_RECURSE onnxruntime_cuda_full_training_only_srcs
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/*.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/*.h"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/communication/*.h"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/record.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/record.h"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/wait.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/wait.h"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/yield.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/controlflow/yield.h"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/gist/*.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/gist/*.h"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/gist/*.cu"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/torch/*.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/torch/*.h"
)
list(REMOVE_ITEM onnxruntime_providers_cuda_src ${onnxruntime_cuda_full_training_only_srcs})
elseif(WIN32 OR NOT onnxruntime_USE_NCCL)
# NCCL is not support in Windows build
file(GLOB_RECURSE onnxruntime_cuda_nccl_op_srcs
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_common.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/nccl_kernels.cc"
"${ORTTRAINING_SOURCE_DIR}/training_ops/cuda/collective/megatron.cc"
)
list(REMOVE_ITEM onnxruntime_providers_cuda_src ${onnxruntime_cuda_nccl_op_srcs})
endif()
endif()
if (onnxruntime_REDUCED_OPS_BUILD)
@ -453,7 +466,7 @@ if (onnxruntime_USE_CUDA)
endif()
onnxruntime_add_include_to_target(onnxruntime_providers_cuda onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers)
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING_OPS)
onnxruntime_add_include_to_target(onnxruntime_providers_cuda onnxruntime_training)
target_link_libraries(onnxruntime_providers_cuda PRIVATE onnxruntime_training)
if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
@ -481,7 +494,7 @@ if (onnxruntime_USE_CUDA)
target_link_libraries(onnxruntime_providers_cuda PRIVATE nvToolsExt)
endif()
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING_OPS)
target_include_directories(onnxruntime_providers_cuda PRIVATE ${ORTTRAINING_ROOT} ${MPI_CXX_INCLUDE_DIRS})
if(onnxruntime_USE_MPI)
target_link_libraries(onnxruntime_providers_cuda PRIVATE ${MPI_LIBRARIES} ${MPI_CXX_LINK_FLAGS})
@ -561,7 +574,7 @@ if (onnxruntime_USE_DNNL)
set_target_properties(onnxruntime_providers_dnnl PROPERTIES LINKER_LANGUAGE CXX)
# Needed for the provider interface, as it includes training headers when training is enabled
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING_OPS)
target_include_directories(onnxruntime_providers_dnnl PRIVATE ${ORTTRAINING_ROOT})
endif()
@ -693,7 +706,7 @@ if (onnxruntime_USE_TENSORRT)
endif()
# Needed for the provider interface, as it includes training headers when training is enabled
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING_OPS)
target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ORTTRAINING_ROOT})
if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt Python::Module)
@ -808,7 +821,7 @@ if (onnxruntime_USE_OPENVINO)
endif()
# Needed for the provider interface, as it includes training headers when training is enabled
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING_OPS)
target_include_directories(onnxruntime_providers_openvino PRIVATE ${ORTTRAINING_ROOT})
endif()
@ -1332,7 +1345,7 @@ if (onnxruntime_USE_ROCM)
list(APPEND onnxruntime_providers_rocm_src ${onnxruntime_rocm_generated_contrib_ops_cc_srcs} ${onnxruntime_rocm_generated_contrib_ops_cu_srcs})
endif()
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING_OPS)
file(GLOB_RECURSE onnxruntime_rocm_training_ops_cc_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/training_ops/rocm/*.h"
"${ORTTRAINING_SOURCE_DIR}/training_ops/rocm/*.cc"
@ -1370,7 +1383,7 @@ if (onnxruntime_USE_ROCM)
endif()
onnxruntime_add_include_to_target(onnxruntime_providers_rocm onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers Boost::mp11 safeint_interface)
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING_OPS)
onnxruntime_add_include_to_target(onnxruntime_providers_rocm onnxruntime_training)
target_link_libraries(onnxruntime_providers_rocm PRIVATE onnxruntime_training)
if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
@ -1402,7 +1415,7 @@ if (onnxruntime_USE_ROCM)
set_target_properties(onnxruntime_providers_rocm PROPERTIES LINKER_LANGUAGE CXX)
set_target_properties(onnxruntime_providers_rocm PROPERTIES FOLDER "ONNXRuntime")
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING)
target_include_directories(onnxruntime_providers_rocm PRIVATE ${ORTTRAINING_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining ${MPI_CXX_INCLUDE_DIRS})
if(onnxruntime_USE_MPI)
target_link_libraries(onnxruntime_providers_rocm PRIVATE ${MPI_LIBRARIES} ${MPI_CXX_LINK_FLAGS})


@ -51,7 +51,7 @@ if (onnxruntime_USE_ROCM)
# ROCM provider sources are generated, need to add include directory for generated headers
target_include_directories(onnxruntime_session PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining)
endif()
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING_OPS)
target_include_directories(onnxruntime_session PRIVATE ${ORTTRAINING_ROOT})
endif()


@ -114,7 +114,7 @@ if (onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB)
re2::re2
)
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING)
bundle_static_library(onnxruntime_webassembly tensorboard)
endif()
@ -192,7 +192,7 @@ else()
target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK)
endif()
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
if (onnxruntime_ENABLE_TRAINING)
target_link_libraries(onnxruntime_webassembly PRIVATE tensorboard)
endif()


@ -10,7 +10,7 @@
#include "contrib_ops/cpu/cpu_contrib_kernels.h"
#endif
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
#if defined(ENABLE_TRAINING_OPS)
#include "orttraining/training_ops/cpu/cpu_training_kernels.h"
#endif
@ -2190,7 +2190,7 @@ Status RegisterCPUKernels(KernelRegistry& kernel_registry) {
#ifndef DISABLE_CONTRIB_OPS
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::RegisterCpuContribKernels(kernel_registry));
#endif
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
#if defined(ENABLE_TRAINING_OPS)
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::RegisterCpuTrainingKernels(kernel_registry));
#endif
return Status::OK();


@ -38,16 +38,19 @@
#endif
#endif
#ifdef ENABLE_TRAINING
#ifdef ENABLE_TRAINING_OPS
#include "orttraining/training_ops/cpu/controlflow/group.h"
#include "orttraining/training_ops/cpu/controlflow/record.h"
#include "orttraining/training_ops/cpu/controlflow/wait.h"
#include "orttraining/training_ops/cpu/controlflow/yield.h"
#include "orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.h"
#include "orttraining/training_ops/cpu/tensor/split.h"
#include "orttraining/training_ops/cpu/optimizer/adamw/adamwbase.h"
#endif
#ifdef ENABLE_TRAINING
#include "orttraining/training_ops/cpu/controlflow/record.h"
#include "orttraining/training_ops/cpu/controlflow/wait.h"
#include "orttraining/training_ops/cpu/controlflow/yield.h"
#endif
#include "cpu_provider_shared.h"
namespace onnxruntime {
@ -248,16 +251,13 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
#endif
#endif
#ifdef ENABLE_TRAINING
void contrib__record_event_in_tensor(const Tensor& event_id_tensor) override { return contrib::record_event_in_tensor(event_id_tensor); }
void contrib__wait_event_in_tensor(const Tensor& event_id_tensor) override { return contrib::wait_event_in_tensor(event_id_tensor); }
#ifdef ENABLE_TRAINING_OPS
Status contrib__Group__Compute(const contrib::Group* p, OpKernelContext* context) override { return p->Group::Compute(context); }
Status contrib__PassThrough__Compute(const contrib::PassThrough* p, OpKernelContext* context) override { return p->PassThrough::Compute(context); }
void contrib__VerifyLogitWeightAndLabelShape(const TensorShape& logit_shape, const TensorShape& label_shape, const TensorShape* weight_shape) override { contrib::VerifyLogitWeightAndLabelShape(logit_shape, label_shape, weight_shape); }
void contrib__GetNDCFromLogitAndLabelShape(const TensorShape& logit_shape, const TensorShape& label_shape, int64_t& N_D, int64_t& C) override { contrib::GetNDCFromLogitAndLabelShape(logit_shape, label_shape, N_D, C); }
void contrib__GetPermutationAndShape(bool ncd_to_ndc, const TensorShape& tensor_shape, TensorShapeVector& new_shape, std::vector<size_t>& permutations) override { contrib::GetPermutationAndShape(ncd_to_ndc, tensor_shape, new_shape, permutations); }
Status contrib__PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector<int64_t>& split_sizes) override { return contrib::PrepareForTrainingCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis, after_dims_excluding_split, split_sizes); }
Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) override { return p->YieldOp::Compute(context); }
// From cpu/optimizer/adamwbase.h (direct)
Status contrib__AdamWOptimizerBase__PrepareForCompute(const contrib::AdamWOptimizerBase* p, OpKernelContext* ctx,
contrib__AdamWOptimizerBase__Prepare& prepare) override {
@ -270,6 +270,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
const TensorSeq* values, TensorSeq* updated_values) override {
return p->AdamWOptimizerBase::GenerateOutputs(ctx, number_of_values, values, updated_values);
}
#endif
#ifdef ENABLE_TRAINING
void contrib__record_event_in_tensor(const Tensor& event_id_tensor) override { return contrib::record_event_in_tensor(event_id_tensor); }
void contrib__wait_event_in_tensor(const Tensor& event_id_tensor) override { return contrib::wait_event_in_tensor(event_id_tensor); }
Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) override { return p->YieldOp::Compute(context); }
// From aten_op.h (direct)
bool contrib__IsATenOperatorExecutorInitialized() override {


@ -178,20 +178,24 @@ struct ProviderHostCPU {
#endif
#endif
#ifdef ENABLE_TRAINING
virtual void contrib__record_event_in_tensor(const Tensor& event_id_tensor) = 0;
virtual void contrib__wait_event_in_tensor(const Tensor& event_id_tensor) = 0;
#ifdef ENABLE_TRAINING_OPS
virtual Status contrib__Group__Compute(const contrib::Group* p, OpKernelContext* context) = 0;
virtual Status contrib__PassThrough__Compute(const contrib::PassThrough* p, OpKernelContext* context) = 0;
virtual void contrib__VerifyLogitWeightAndLabelShape(const TensorShape& logit_shape, const TensorShape& label_shape, const TensorShape* weight_shape) = 0;
virtual void contrib__GetNDCFromLogitAndLabelShape(const TensorShape& logit_shape, const TensorShape& label_shape, int64_t& N_D, int64_t& C) = 0;
virtual void contrib__GetPermutationAndShape(bool ncd_to_ndc, const TensorShape& tensor_shape, TensorShapeVector& new_shape, std::vector<size_t>& permutations) = 0;
virtual Status contrib__PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector<int64_t>& split_sizes) = 0;
virtual Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) = 0;
// From cpu/optimizer/adamwbase.h
virtual Status contrib__AdamWOptimizerBase__PrepareForCompute(const contrib::AdamWOptimizerBase* p, OpKernelContext* ctx, contrib__AdamWOptimizerBase__Prepare& prepare) = 0;
virtual Status contrib__AdamWOptimizerBase__GenerateOutputs(const contrib::AdamWOptimizerBase* p, OpKernelContext* ctx, size_t number_of_values,
const TensorSeq* values, TensorSeq* updated_values) = 0;
#endif
#ifdef ENABLE_TRAINING
virtual void contrib__record_event_in_tensor(const Tensor& event_id_tensor) = 0;
virtual void contrib__wait_event_in_tensor(const Tensor& event_id_tensor) = 0;
virtual Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) = 0;
// From aten_op.h
virtual bool contrib__IsATenOperatorExecutorInitialized() = 0;
virtual Status contrib__ExecuteReduceSumATen(OpKernelContext* p_ctx, const gsl::span<const int64_t>& axes, bool keepdims) = 0;
@ -252,15 +256,19 @@ struct EinsumTypedComputeProcessor {
Status Run() { return g_host_cpu.EinsumTypedComputeProcessor__Run(this); }
};
#ifdef ENABLE_TRAINING
#ifdef ENABLE_TRAINING_OPS
namespace contrib {
inline void record_event_in_tensor(const Tensor& event_id_tensor) { return g_host_cpu.contrib__record_event_in_tensor(event_id_tensor); }
inline void wait_event_in_tensor(const Tensor& event_id_tensor) { return g_host_cpu.contrib__wait_event_in_tensor(event_id_tensor); }
inline void VerifyLogitWeightAndLabelShape(const TensorShape& logit_shape, const TensorShape& label_shape, const TensorShape* weight_shape) { g_host_cpu.contrib__VerifyLogitWeightAndLabelShape(logit_shape, label_shape, weight_shape); }
inline void GetNDCFromLogitAndLabelShape(const TensorShape& logit_shape, const TensorShape& label_shape, int64_t& N_D, int64_t& C) { g_host_cpu.contrib__GetNDCFromLogitAndLabelShape(logit_shape, label_shape, N_D, C); }
inline void GetPermutationAndShape(bool ncd_to_ndc, const TensorShape& tensor_shape, TensorShapeVector& new_shape, std::vector<size_t>& permutations) { g_host_cpu.contrib__GetPermutationAndShape(ncd_to_ndc, tensor_shape, new_shape, permutations); }
inline Status PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector<int64_t>& split_sizes) { return g_host_cpu.contrib__PrepareForTrainingCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis, after_dims_excluding_split, split_sizes); }
} // namespace contrib
#endif
#ifdef ENABLE_TRAINING
namespace contrib {
inline void record_event_in_tensor(const Tensor& event_id_tensor) { return g_host_cpu.contrib__record_event_in_tensor(event_id_tensor); }
inline void wait_event_in_tensor(const Tensor& event_id_tensor) { return g_host_cpu.contrib__wait_event_in_tensor(event_id_tensor); }
// From aten_op.h
inline bool IsATenOperatorExecutorInitialized() { return g_host_cpu.contrib__IsATenOperatorExecutorInitialized(); }


@ -12,7 +12,7 @@
#include "core/framework/op_kernel_type_control_utils.h"
#include "core/providers/common.h"
#include "core/providers/op_kernel_type_control.h"
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
#if defined(ENABLE_TRAINING_OPS)
#include "orttraining/training_ops/cpu/tensor/gather_elements_grad_impl.h"
#endif
@ -394,7 +394,7 @@ Status Scatter<EnabledDataTypes>::Compute(OpKernelContext* context) const {
return status;
}
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
#if defined(ENABLE_TRAINING_OPS)
namespace contrib {


@ -16,7 +16,7 @@
#include "contrib_ops/cuda/cuda_contrib_kernels.h"
#endif
#ifdef ENABLE_TRAINING
#ifdef ENABLE_TRAINING_OPS
#include "orttraining/training_ops/cuda/cuda_training_kernels.h"
#endif
@ -2294,7 +2294,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::cuda::RegisterCudaContribKernels(kernel_registry));
#endif
#ifdef ENABLE_TRAINING
#ifdef ENABLE_TRAINING_OPS
ORT_RETURN_IF_ERROR(::onnxruntime::cuda::RegisterCudaTrainingKernels(kernel_registry));
#endif


@ -308,7 +308,7 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle,
long long int strideC,
int batch_count,
const cudaDeviceProp& prop) {
#ifdef ENABLE_TRAINING
#ifdef ENABLE_TRAINING_OPS
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, CUBLAS_TF32_TENSOR_OP_MATH);
#else


@ -3,7 +3,7 @@
#include "core/providers/cuda/tensor/gather_elements_impl.h"
#include "core/providers/cuda/tensor/scatter_elements_impl.h"
#ifdef ENABLE_TRAINING
#ifdef ENABLE_TRAINING_OPS
#include "orttraining/training_ops/cuda/tensor/gather_elements_grad_impl.h"
#endif
@ -259,7 +259,7 @@ GATHER_SCATTER_ELEMENTS_SPECIALIZED_IMPL(double)
#undef GATHER_SCATTER_ELEMENTS_SPECIALIZED_IMPL
#undef GATHER_SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL
#ifdef ENABLE_TRAINING
#ifdef ENABLE_TRAINING_OPS
template <class T>
struct FuncAtomicAdd {


@ -29,7 +29,7 @@ void GatherNDImpl(
const size_t slice_size,
const int64_t* input_slice_offsets_data);
#ifdef ENABLE_TRAINING
#ifdef ENABLE_TRAINING_OPS
template <typename T>
void GatherNDGradImpl(
cudaStream_t stream,


@ -118,7 +118,7 @@ Status SliceImpl(cudaStream_t stream, const size_t element_size, const int32_t d
input_data, output_data, N);
}
#ifdef ENABLE_TRAINING
#ifdef ENABLE_TRAINING_OPS
Status SliceImplGrad(cudaStream_t stream, const size_t element_size, const int32_t dimension_count,
const TArray<int64_t>& starts, const TArray<int64_t>& steps, const TArray<int64_t>& input_strides,
const TArray<fast_divmod>& output_strides, const void* input_data, void* output_data,
@ -126,7 +126,7 @@ Status SliceImplGrad(cudaStream_t stream, const size_t element_size, const int32
return SliceImplEx<true>(stream, element_size, dimension_count, starts, steps, input_strides, output_strides,
input_data, output_data, N);
}
#endif // ENABLE_TRAINING
#endif // ENABLE_TRAINING_OPS
} // namespace cuda
} // namespace onnxruntime


@ -19,7 +19,7 @@ Status SliceImpl(cudaStream_t stream,
void* output_data,
const size_t N);
#ifdef ENABLE_TRAINING
#ifdef ENABLE_TRAINING_OPS
Status SliceImplGrad(cudaStream_t stream,
const size_t element_size,
const int32_t dimension_count,
@ -30,7 +30,7 @@ Status SliceImplGrad(cudaStream_t stream,
const void* input_data,
void* output_data,
const size_t N);
#endif // ENABLE_TRAINING
#endif // ENABLE_TRAINING_OPS
} // namespace cuda
} // namespace onnxruntime


@ -15,7 +15,7 @@
#include "contrib_ops/rocm/rocm_contrib_kernels.h"
#endif
#ifdef ENABLE_TRAINING
#ifdef ENABLE_TRAINING_OPS
#include "orttraining/training_ops/rocm/rocm_training_kernels.h"
#endif
@ -2260,7 +2260,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::rocm::RegisterRocmContribKernels(kernel_registry));
#endif
#ifdef ENABLE_TRAINING
#ifdef ENABLE_TRAINING_OPS
ORT_RETURN_IF_ERROR(::onnxruntime::rocm::RegisterRocmTrainingKernels(kernel_registry));
#endif


@ -37,10 +37,13 @@
#endif
#endif
#ifdef ENABLE_TRAINING
#ifdef ENABLE_TRAINING_OPS
#include "orttraining/training_ops/cpu/controlflow/group.h"
#include "orttraining/training_ops/cpu/controlflow/yield.h"
#include "orttraining/training_ops/cpu/optimizer/adamw/adamwbase.h"
#endif
#ifdef ENABLE_TRAINING
#include "orttraining/training_ops/cpu/controlflow/yield.h"
#ifdef ENABLE_TRAINING_TORCH_INTEROP
#include "orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.h"
@ -621,21 +624,23 @@ Status Scan<8>::SetupSubgraphExecutionInfo(const SessionState& session_state, co
template <>
Status Scan<9>::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) { return g_host_cpu.Scan__SetupSubgraphExecutionInfo(this, session_state, attribute_name, subgraph_session_state); }
#ifdef ENABLE_TRAINING
#ifdef ENABLE_TRAINING_OPS
namespace contrib {
Status Group::Compute(OpKernelContext* context) const { return g_host_cpu.contrib__Group__Compute(this, context); }
Status PassThrough::Compute(OpKernelContext* context) const { return g_host_cpu.contrib__PassThrough__Compute(this, context); }
Status YieldOp::Compute(OpKernelContext* context) const { return g_host_cpu.contrib__YieldOp__Compute(this, context); }
Status AdamWOptimizerBase::PrepareForCompute(OpKernelContext* ctx, AdamWOptimizerBase::Prepare& prepare) const {
return g_host_cpu.contrib__AdamWOptimizerBase__PrepareForCompute(this, ctx, reinterpret_cast<contrib__AdamWOptimizerBase__Prepare&>(prepare));
}
Status AdamWOptimizerBase::GenerateOutputs(OpKernelContext* ctx, size_t number_of_values,
const TensorSeq* values, TensorSeq* updated_values) const {
return g_host_cpu.contrib__AdamWOptimizerBase__GenerateOutputs(this, ctx, number_of_values, values, updated_values);
}
}
#endif
#ifdef ENABLE_TRAINING
namespace contrib {
Status YieldOp::Compute(OpKernelContext* context) const { return g_host_cpu.contrib__YieldOp__Compute(this, context); }
} // namespace contrib
#ifdef ENABLE_TRAINING_TORCH_INTEROP


@ -13,7 +13,7 @@
#include "core/graph/contrib_ops/internal_nhwc_onnx_opset.h"
#include "core/graph/contrib_ops/ms_opset.h"
#include "core/graph/contrib_ops/onnx_deprecated_opset.h"
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
#if defined(ENABLE_TRAINING_OPS)
#include "onnx/defs/operator_sets_training.h"
#endif
#endif
@ -31,7 +31,7 @@
#include "core/platform/tracing.h"
#endif
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
#if defined(ENABLE_TRAINING_OPS)
#include "orttraining/core/graph/training_op_defs.h"
#endif
#ifdef ENABLE_TRAINING
@ -256,11 +256,11 @@ Status Environment::Initialize(std::unique_ptr<logging::LoggingManager> logging_
RegisterOnnxMLOperatorSetSchema();
#endif
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
#if defined(ENABLE_TRAINING_OPS)
RegisterOnnxTrainingOperatorSetSchema();
#endif
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
#if defined(ENABLE_TRAINING_OPS)
// preserve this order until <training schemas>: this depends on operatorsetschema registration.
training::RegisterTrainingOpSchemas();
#endif


@ -6,7 +6,7 @@
namespace onnxruntime {
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
#if defined(ENABLE_TRAINING_OPS)
ONNX_CPU_OPERATOR_KERNEL(
Dropout,
7,


@ -153,24 +153,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Gath
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Scale);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Scale);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Scale);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistBinarizeEncoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeEncoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, GistBinarizeEncoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistBinarizeDecoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeDecoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, GistBinarizeDecoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, bool, GistPack1Encoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack1Encoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, bool, GistPack1Decoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack1Decoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack8Encoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistPack8Encoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack8Decoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistPack8Decoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack16Encoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack16Decoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPackMsfp15Encoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPackMsfp15Decoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_float, BatchNormInternal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double_double, BatchNormInternal);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormInternal);
@ -212,6 +194,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inpl
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant);
// the kernels within the following ifdef are not included in a build with
// --enable_training_ops but without --enable_training
#ifdef ENABLE_TRAINING
#if defined(ORT_USE_NCCL) || defined(USE_MPI)
// P2P communication operators.
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Send);
@ -226,6 +211,25 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Reco
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WaitEvent);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, YieldOp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistBinarizeEncoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeEncoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, GistBinarizeEncoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistBinarizeDecoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeDecoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, GistBinarizeDecoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, bool, GistPack1Encoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack1Encoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, bool, GistPack1Decoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack1Decoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack8Encoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistPack8Encoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack8Decoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistPack8Decoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack16Encoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack16Decoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPackMsfp15Encoder);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPackMsfp15Decoder);
#ifdef ENABLE_TRAINING_TORCH_INTEROP
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PythonOp);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PythonOpGrad);
@ -238,6 +242,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Nccl
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MegatronF);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MegatronG);
#endif
#endif
Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
@ -380,24 +385,6 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Scale)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Scale)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Scale)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistBinarizeEncoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeEncoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, GistBinarizeEncoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistBinarizeDecoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeDecoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, GistBinarizeDecoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, bool, GistPack1Encoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack1Encoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, bool, GistPack1Decoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack1Decoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack8Encoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistPack8Encoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack8Decoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistPack8Decoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack16Encoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack16Decoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPackMsfp15Encoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPackMsfp15Decoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float_float, BatchNormInternal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double_double, BatchNormInternal)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16_MLFloat16, BatchNormInternal)>,
@ -440,6 +427,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(
kCudaExecutionProvider, kMSDomain, 1, float, FakeQuant)>,
// the kernels within the following ifdef are not included in a build with
// --enable_training_ops but without --enable_training
#ifdef ENABLE_TRAINING
// P2P communication operators.
#if defined(ORT_USE_NCCL) || defined(USE_MPI)
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Send)>,
@ -454,6 +444,25 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, WaitEvent)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, YieldOp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistBinarizeEncoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeEncoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, GistBinarizeEncoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistBinarizeDecoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistBinarizeDecoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, GistBinarizeDecoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, bool, GistPack1Encoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack1Encoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, bool, GistPack1Decoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack1Decoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack8Encoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistPack8Encoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack8Decoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, GistPack8Decoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack16Encoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPack16Decoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPackMsfp15Encoder)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, GistPackMsfp15Decoder)>,
#ifdef ENABLE_TRAINING_TORCH_INTEROP
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PythonOp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PythonOpGrad)>,
@ -465,6 +474,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, NcclReduceScatter)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MegatronF)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MegatronG)>,
#endif
#endif
};