Enable training ops in inference (#7783)

* Enable training ops in inference

* fix a build error

* relu test name is the same as trainig test
This commit is contained in:
Sunghoon 2021-05-21 13:06:14 -07:00 committed by GitHub
parent b852b73e84
commit 1fbc04d691
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 87 additions and 9 deletions

View file

@ -60,6 +60,12 @@ file(GLOB_RECURSE onnxruntime_ir_defs_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/defs/*.cc"
)
if (onnxruntime_ENABLE_TRAINING_OPS AND NOT onnxruntime_ENABLE_TRAINING)
set(orttraining_graph_src
"${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.cc"
"${ORTTRAINING_SOURCE_DIR}/core/graph/training_op_defs.h"
)
endif()
if (onnxruntime_ENABLE_TRAINING)
file(GLOB_RECURSE orttraining_graph_src CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/core/graph/*.h"
@ -68,7 +74,7 @@ if (onnxruntime_ENABLE_TRAINING)
endif()
set(onnxruntime_graph_lib_src ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src})
if (onnxruntime_ENABLE_TRAINING)
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
list(APPEND onnxruntime_graph_lib_src ${orttraining_graph_src})
endif()
@ -83,7 +89,7 @@ endif()
target_include_directories(onnxruntime_graph PRIVATE ${ONNXRUNTIME_ROOT})
if (onnxruntime_ENABLE_TRAINING)
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
target_include_directories(onnxruntime_graph PRIVATE ${ORTTRAINING_ROOT})
if (onnxruntime_USE_NCCL)
@ -95,7 +101,7 @@ set_target_properties(onnxruntime_graph PROPERTIES FOLDER "ONNXRuntime")
set_target_properties(onnxruntime_graph PROPERTIES LINKER_LANGUAGE CXX)
install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/graph DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core)
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_graph_src} ${onnxruntime_ir_defs_src})
if (onnxruntime_ENABLE_TRAINING)
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
source_group(TREE ${ORTTRAINING_ROOT} FILES ${orttraining_graph_src})
endif()

View file

@ -22,6 +22,6 @@ set_target_properties(onnxruntime_session PROPERTIES FOLDER "ONNXRuntime")
if (onnxruntime_USE_CUDA)
target_include_directories(onnxruntime_session PRIVATE ${onnxruntime_CUDNN_HOME}/include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
endif()
if (onnxruntime_ENABLE_TRAINING)
if (onnxruntime_ENABLE_TRAINING OR onnxruntime_ENABLE_TRAINING_OPS)
target_include_directories(onnxruntime_session PRIVATE ${ORTTRAINING_ROOT})
endif()

View file

@ -8,7 +8,7 @@
#if !defined(ORT_MINIMAL_BUILD)
#include "onnx/defs/operator_sets.h"
#include "onnx/defs/operator_sets_ml.h"
#if defined(ENABLE_TRAINING)
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
#include "onnx/defs/operator_sets_training.h"
#endif
#endif
@ -30,8 +30,10 @@
#include "core/platform/tracing.h"
#endif
#ifdef ENABLE_TRAINING
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
#include "orttraining/core/graph/training_op_defs.h"
#endif
#ifdef ENABLE_TRAINING
#include "orttraining/core/graph/gradient_builder_registry.h"
#include "orttraining/core/graph/loss_function_registry.h"
#include "orttraining/core/graph/optimizer_builder.h"
@ -180,17 +182,20 @@ Status Environment::Initialize(std::unique_ptr<logging::LoggingManager> logging_
RegisterOnnxMLOperatorSetSchema();
#endif
#ifdef ENABLE_TRAINING
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
RegisterOnnxTrainingOperatorSetSchema();
#endif
#ifdef ENABLE_TRAINING
// preserve this order: this depends on operatorsetschema registration.
#if defined(ENABLE_TRAINING) || defined(ENABLE_TRAINING_OPS)
// preserve this order until <training schemas>: this depends on operatorsetschema registration.
training::RegisterTrainingOpSchemas();
#endif
#ifdef ENABLE_TRAINING
training::GradientBuilderRegistry::GetInstance().RegisterGradientBuilders();
training::LossFunctionRegistry::GetInstance().RegisterNonOperatorLossFunctions();
training::OptimizerBuilderRegistry::GetInstance().RegisterBuilders();
training::OptimizerGraphBuilderRegistry::GetInstance().RegisterGraphBuilders();
// <training schemas>
#endif
});

View file

@ -7,6 +7,55 @@
namespace onnxruntime {
namespace test {
#if defined(ENABLE_TRAINING_OPS)
namespace {
void TestElementwiseGradientOp(
const char* op,
const std::vector<std::pair<std::string, std::vector<float>>>& inputs,
std::function<float(const std::vector<float>&)> expected_func,
const std::unordered_map<std::string, float> attrs = {},
int opset_version = 7, const char* domain = kOnnxDomain) {
const auto first_input = inputs.begin();
ASSERT_NE(first_input, inputs.end());
for (auto input = first_input; input != inputs.end(); ++input) {
if (input == first_input) continue;
ASSERT_EQ(first_input->second.size(), input->second.size());
}
OpTester test(op, opset_version, domain);
for (auto attr : attrs) {
test.AddAttribute(attr.first, attr.second);
}
const auto input_size = first_input->second.size();
std::vector<int64_t> dims{static_cast<int64_t>(input_size)};
std::vector<float> expected_vals;
for (size_t i = 0; i < input_size; i++) {
std::vector<float> params(inputs.size());
std::transform(
inputs.begin(), inputs.end(), params.begin(),
[i](const std::pair<std::string, std::vector<float>>& input) {
return input.second[i];
});
expected_vals.push_back(expected_func(params));
}
for (const auto& input : inputs) {
test.AddInput<float>(input.first.c_str(), dims, input.second);
}
test.AddOutput<float>("dX", dims, expected_vals);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {});
}
float ReluGrad(float dy, float x) {
return x > 0 ? dy : 0;
}
}
#endif
TEST_F(ActivationOpTest, Sigmoid) {
TestActivationOp<float>("Sigmoid",
input_values,
@ -210,5 +259,23 @@ TEST_F(ActivationOpNoInfTest, Softsign) {
{}, false); // Disable TensorRT because result mismatches
}
#if defined(ENABLE_TRAINING_OPS)
TEST(ReluGradInferenceTest, Basic) {
const std::vector<float> x_vals = {-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f};
const std::vector<float> dY(7, 1.0f);
TestElementwiseGradientOp(
"ReluGrad",
{{"dY", dY}, {"X", x_vals}},
[](const std::vector<float>& params) {
ORT_ENFORCE(params.size() == 2);
const auto dy = params[0], x = params[1];
return ReluGrad(dy, x);
},
{}, 1, kMSDomain);
}
#endif
} // namespace test
} // namespace onnxruntime