diff --git a/cmake/onnxruntime_graph.cmake b/cmake/onnxruntime_graph.cmake index 25ee83d15f..0e4a7e004d 100644 --- a/cmake/onnxruntime_graph.cmake +++ b/cmake/onnxruntime_graph.cmake @@ -60,6 +60,12 @@ file(GLOB_RECURSE onnxruntime_ir_defs_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/defs/*.cc" ) +if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING) + set(orttraining_graph_src + "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.cc" + "${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h" + ) +endif() if (onnxruntime_ENABLE_TRAINING) file(GLOB_RECURSE orttraining_graph_src CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/core/graph/*.h" @@ -68,7 +74,7 @@ if (onnxruntime_ENABLE_TRAINING) endif() set(onnxruntime_graph_lib_src ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src}) -if (onnxruntime_ENABLE_TRAINING) +if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) list(APPEND onnxruntime_graph_lib_src ${orttraining_graph_src}) endif() @@ -83,7 +89,7 @@ endif() target_include_directories(onnxruntime_graph PRIVATE ${ONNXRUNTIME_ROOT}) -if (onnxruntime_ENABLE_TRAINING) +if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(onnxruntime_graph PRIVATE ${ORTTRAINING_ROOT}) if (onnxruntime_USE_NCCL) @@ -95,7 +101,7 @@ set_target_properties(onnxruntime_graph PROPERTIES FOLDER "ONNXRuntime") set_target_properties(onnxruntime_graph PROPERTIES LINKER_LANGUAGE CXX) install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/graph DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core) source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src}) -if (onnxruntime_ENABLE_TRAINING) +if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) source_group(TREE ${ORTTRAINING_ROOT} FILES ${orttraining_graph_src}) endif() diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index 564dc53dfe..4850d48c0a 100644 --- 
a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -22,6 +22,6 @@ set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime") if (onnxruntime_USE_CUDA) target_include_directories(onnxruntime_session PRIVATE ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) endif() -if (onnxruntime_ENABLE_TRAINING) +if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(onnxruntime_session PRIVATE ${ORTTRAINING_ROOT}) endif() diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 3bd28d630a..07636bbc15 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -8,7 +8,7 @@ #if !defined(ORT_MINIMAL_BUILD) #include "onnx/defs/operator_sets.h" #include "onnx/defs/operator_sets_ml.h" -#if defined(ENABLE_TRAINING) +#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS) #include "onnx/defs/operator_sets_training.h" #endif #endif @@ -30,8 +30,10 @@ #include "core/platform/tracing.h" #endif -#ifdef ENABLE_TRAINING +#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS) #include "orttraining/core/graph/training_op_defs.h" +#endif +#ifdef ENABLE_TRAINING #include "orttraining/core/graph/gradient_builder_registry.h" #include "orttraining/core/graph/loss_function_registry.h" #include "orttraining/core/graph/optimizer_builder.h" @@ -180,17 +182,20 @@ Status Environment::Initialize(std::unique_ptr logging_ RegisterOnnxMLOperatorSetSchema(); #endif -#ifdef ENABLE_TRAINING +#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS) RegisterOnnxTrainingOperatorSetSchema(); #endif -#ifdef ENABLE_TRAINING - // preserve this order: this depends on operatorsetschema registration. +#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS) + // preserve this order: RegisterTrainingOpSchemas() depends on the operatorsetschema registration above. NOTE(review): the "until ..." condition was lost from this comment - confirm with the author. 
training::RegisterTrainingOpSchemas(); +#endif +#ifdef ENABLE_TRAINING training::GradientBuilderRegistry::GetInstance().RegisterGradientBuilders(); training::LossFunctionRegistry::GetInstance().RegisterNonOperatorLossFunctions(); training::OptimizerBuilderRegistry::GetInstance().RegisterBuilders(); training::OptimizerGraphBuilderRegistry::GetInstance().RegisterGraphBuilders(); + // end of the ENABLE_TRAINING-only registry setup (gradient, loss function and optimizer builders). #endif }); diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index 104e62e62e..d49726619b 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -7,6 +7,55 @@ namespace onnxruntime { namespace test { +#if defined(ENABLE_TRAINING_OPS) +namespace { +void TestElementwiseGradientOp( + const char* op, + const std::vector>>& inputs, + std::function&)> expected_func, + const std::unordered_map attrs = {}, + int opset_version = 7, const char* domain = kOnnxDomain) { + const auto first_input = inputs.begin(); + ASSERT_NE(first_input, inputs.end()); + for (auto input = first_input; input != inputs.end(); ++input) { + if (input == first_input) continue; + ASSERT_EQ(first_input->second.size(), input->second.size()); + } + + OpTester test(op, opset_version, domain); + + for (auto attr : attrs) { + test.AddAttribute(attr.first, attr.second); + } + + const auto input_size = first_input->second.size(); + std::vector dims{static_cast(input_size)}; + + std::vector expected_vals; + for (size_t i = 0; i < input_size; i++) { + std::vector params(inputs.size()); + std::transform( + inputs.begin(), inputs.end(), params.begin(), + [i](const std::pair>& input) { + return input.second[i]; + }); + expected_vals.push_back(expected_func(params)); + } + + for (const auto& input : inputs) { + test.AddInput(input.first.c_str(), dims, input.second); + } + test.AddOutput("dX", dims, expected_vals); + 
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}); +} + +float ReluGrad(float dy, float x) { + return x > 0 ? dy : 0; +} +} +#endif + TEST_F(ActivationOpTest, Sigmoid) { TestActivationOp("Sigmoid", input_values, @@ -210,5 +259,23 @@ TEST_F(ActivationOpNoInfTest, Softsign) { {}, false); // Disable TensorRT because result mismatches } +#if defined(ENABLE_TRAINING_OPS) +TEST(ReluGradInferenceTest, Basic) { + const std::vector x_vals = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f}; + const std::vector dY(7, 1.0f); + + TestElementwiseGradientOp( + "ReluGrad", + {{"dY", dY}, {"X", x_vals}}, + [](const std::vector& params) { + ORT_ENFORCE(params.size() == 2); + const auto dy = params[0], x = params[1]; + + return ReluGrad(dy, x); + }, + {}, 1, kMSDomain); +} +#endif + } // namespace test } // namespace onnxruntime