mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Enable training ops in inference (#7783)
* Enable training ops in inference * fix a build error * relu test name is the same as trainig test
This commit is contained in:
parent
b852b73e84
commit
1fbc04d691
4 changed files with 87 additions and 9 deletions
|
|
@ -60,6 +60,12 @@ file(GLOB_RECURSE onnxruntime_ir_defs_src CONFIGURE_DEPENDS
|
|||
"${ONNXRUNTIME_ROOT}/core/defs/*.cc"
|
||||
)
|
||||
|
||||
if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING)
|
||||
set(orttraining_graph_src
|
||||
"${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.cc"
|
||||
"${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h"
|
||||
)
|
||||
endif()
|
||||
if (onnxruntime_ENABLE_TRAINING)
|
||||
file(GLOB_RECURSE orttraining_graph_src CONFIGURE_DEPENDS
|
||||
"${ORTTRAINING_SOURCE_DIR}/core/graph/*.h"
|
||||
|
|
@ -68,7 +74,7 @@ if (onnxruntime_ENABLE_TRAINING)
|
|||
endif()
|
||||
|
||||
set(onnxruntime_graph_lib_src ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src})
|
||||
if (onnxruntime_ENABLE_TRAINING)
|
||||
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
|
||||
list(APPEND onnxruntime_graph_lib_src ${orttraining_graph_src})
|
||||
endif()
|
||||
|
||||
|
|
@ -83,7 +89,7 @@ endif()
|
|||
|
||||
target_include_directories(onnxruntime_graph PRIVATE ${ONNXRUNTIME_ROOT})
|
||||
|
||||
if (onnxruntime_ENABLE_TRAINING)
|
||||
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
|
||||
target_include_directories(onnxruntime_graph PRIVATE ${ORTTRAINING_ROOT})
|
||||
|
||||
if (onnxruntime_USE_NCCL)
|
||||
|
|
@ -95,7 +101,7 @@ set_target_properties(onnxruntime_graph PROPERTIES FOLDER "ONNXRuntime")
|
|||
set_target_properties(onnxruntime_graph PROPERTIES LINKER_LANGUAGE CXX)
|
||||
install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/graph DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core)
|
||||
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src})
|
||||
if (onnxruntime_ENABLE_TRAINING)
|
||||
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
|
||||
source_group(TREE ${ORTTRAINING_ROOT} FILES ${orttraining_graph_src})
|
||||
endif()
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,6 @@ set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime")
|
|||
if (onnxruntime_USE_CUDA)
|
||||
target_include_directories(onnxruntime_session PRIVATE ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
|
||||
endif()
|
||||
if (onnxruntime_ENABLE_TRAINING)
|
||||
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
|
||||
target_include_directories(onnxruntime_session PRIVATE ${ORTTRAINING_ROOT})
|
||||
endif()
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
#include "onnx/defs/operator_sets.h"
|
||||
#include "onnx/defs/operator_sets_ml.h"
|
||||
#if defined(ENABLE_TRAINING)
|
||||
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
|
||||
#include "onnx/defs/operator_sets_training.h"
|
||||
#endif
|
||||
#endif
|
||||
|
|
@ -30,8 +30,10 @@
|
|||
#include "core/platform/tracing.h"
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
|
||||
#include "orttraining/core/graph/training_op_defs.h"
|
||||
#endif
|
||||
#ifdef ENABLE_TRAINING
|
||||
#include "orttraining/core/graph/gradient_builder_registry.h"
|
||||
#include "orttraining/core/graph/loss_function_registry.h"
|
||||
#include "orttraining/core/graph/optimizer_builder.h"
|
||||
|
|
@ -180,17 +182,20 @@ Status Environment::Initialize(std::unique_ptr<logging::LoggingManager> logging_
|
|||
RegisterOnnxMLOperatorSetSchema();
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
|
||||
RegisterOnnxTrainingOperatorSetSchema();
|
||||
#endif
|
||||
|
||||
#ifdef ENABLE_TRAINING
|
||||
// preserve this order: this depends on operatorsetschema registration.
|
||||
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
|
||||
// preserve this order until <training schemas>: this depends on operatorsetschema registration.
|
||||
training::RegisterTrainingOpSchemas();
|
||||
#endif
|
||||
#ifdef ENABLE_TRAINING
|
||||
training::GradientBuilderRegistry::GetInstance().RegisterGradientBuilders();
|
||||
training::LossFunctionRegistry::GetInstance().RegisterNonOperatorLossFunctions();
|
||||
training::OptimizerBuilderRegistry::GetInstance().RegisterBuilders();
|
||||
training::OptimizerGraphBuilderRegistry::GetInstance().RegisterGraphBuilders();
|
||||
// <training schemas>
|
||||
#endif
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,55 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
#if defined(ENABLE_TRAINING_OPS)
|
||||
namespace {
|
||||
void TestElementwiseGradientOp(
|
||||
const char* op,
|
||||
const std::vector<std::pair<std::string, std::vector<float>>>& inputs,
|
||||
std::function<float(const std::vector<float>&)> expected_func,
|
||||
const std::unordered_map<std::string, float> attrs = {},
|
||||
int opset_version = 7, const char* domain = kOnnxDomain) {
|
||||
const auto first_input = inputs.begin();
|
||||
ASSERT_NE(first_input, inputs.end());
|
||||
for (auto input = first_input; input != inputs.end(); ++input) {
|
||||
if (input == first_input) continue;
|
||||
ASSERT_EQ(first_input->second.size(), input->second.size());
|
||||
}
|
||||
|
||||
OpTester test(op, opset_version, domain);
|
||||
|
||||
for (auto attr : attrs) {
|
||||
test.AddAttribute(attr.first, attr.second);
|
||||
}
|
||||
|
||||
const auto input_size = first_input->second.size();
|
||||
std::vector<int64_t> dims{static_cast<int64_t>(input_size)};
|
||||
|
||||
std::vector<float> expected_vals;
|
||||
for (size_t i = 0; i < input_size; i++) {
|
||||
std::vector<float> params(inputs.size());
|
||||
std::transform(
|
||||
inputs.begin(), inputs.end(), params.begin(),
|
||||
[i](const std::pair<std::string, std::vector<float>>& input) {
|
||||
return input.second[i];
|
||||
});
|
||||
expected_vals.push_back(expected_func(params));
|
||||
}
|
||||
|
||||
for (const auto& input : inputs) {
|
||||
test.AddInput<float>(input.first.c_str(), dims, input.second);
|
||||
}
|
||||
test.AddOutput<float>("dX", dims, expected_vals);
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {});
|
||||
}
|
||||
|
||||
float ReluGrad(float dy, float x) {
|
||||
return x > 0 ? dy : 0;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
TEST_F(ActivationOpTest, Sigmoid) {
|
||||
TestActivationOp<float>("Sigmoid",
|
||||
input_values,
|
||||
|
|
@ -210,5 +259,23 @@ TEST_F(ActivationOpNoInfTest, Softsign) {
|
|||
{}, false); // Disable TensorRT because result mismatches
|
||||
}
|
||||
|
||||
#if defined(ENABLE_TRAINING_OPS)
|
||||
TEST(ReluGradInferenceTest, Basic) {
|
||||
const std::vector<float> x_vals = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f};
|
||||
const std::vector<float> dY(7, 1.0f);
|
||||
|
||||
TestElementwiseGradientOp(
|
||||
"ReluGrad",
|
||||
{{"dY", dY}, {"X", x_vals}},
|
||||
[](const std::vector<float>& params) {
|
||||
ORT_ENFORCE(params.size() == 2);
|
||||
const auto dy = params[0], x = params[1];
|
||||
|
||||
return ReluGrad(dy, x);
|
||||
},
|
||||
{}, 1, kMSDomain);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue