From 99257eb8e3d528f05fa4248df2c843147d795920 Mon Sep 17 00:00:00 2001 From: "Tang, Cheng" Date: Mon, 15 Nov 2021 08:16:20 -0800 Subject: [PATCH] support build option to include external graph transformers (#9478) * temp code * support external graph transformer from build script * remove debug code * add test case * support register rewrite rule * fix source_group issue if external source is not share any common prefix * fix python code style checker * resolve merge conflict Co-authored-by: Cheng Tang --- cmake/CMakeLists.txt | 7 + cmake/onnxruntime_optimizer.cmake | 10 ++ onnxruntime/core/session/environment.cc | 3 + .../optimizer/graph_transformer_registry.cc | 33 +++++ .../optimizer/graph_transformer_registry.h | 93 +++++++++++++ .../core/optimizer/graph_transformer_utils.cc | 3 + .../test/external_transformers_test.py | 126 ++++++++++++++++++ .../test_external_transformers.cc | 35 +++++ tools/ci_build/build.py | 21 ++- 9 files changed, 330 insertions(+), 1 deletion(-) create mode 100644 orttraining/orttraining/core/optimizer/graph_transformer_registry.cc create mode 100644 orttraining/orttraining/core/optimizer/graph_transformer_registry.h create mode 100644 orttraining/orttraining/test/external_transformer/test/external_transformers_test.py create mode 100644 orttraining/orttraining/test/external_transformer/test_exeternal_transformers/test_external_transformers.cc diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index eda1a68a67..6027be8fed 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -171,6 +171,9 @@ set(ONNX_CUSTOM_PROTOC_EXECUTABLE "" CACHE STRING "Specify custom protoc executa # pre-build python path option(onnxruntime_PREBUILT_PYTORCH_PATH "Path to pytorch installation dir") +# external transformer src path +option(onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH "Path to external transformer src dir") + if (onnxruntime_USE_CUDA) set(onnxruntime_DISABLE_RTTI OFF) endif() @@ -334,6 +337,10 @@ if (onnxruntime_BUILD_WEBASSEMBLY) endif() endif() +if (onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH) + add_definitions(-DORT_TRAINING_EXTERNAL_GRAPH_TRANSFORMERS=1) +endif() + # ORT build with as much excluded as possible. Supports ORT flatbuffers models only. if (onnxruntime_MINIMAL_BUILD) add_compile_definitions(ORT_MINIMAL_BUILD) diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index b7e77f6dbc..9afa2c2e48 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -57,6 +57,16 @@ file(GLOB onnxruntime_optimizer_srcs CONFIGURE_DEPENDS ${onnxruntime_optimizer_s source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_optimizer_srcs}) +if (onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH) + set(onnxruntime_external_transformer_src_patterns) + list(APPEND onnxruntime_external_transformer_src_patterns + "${onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH}/*.cc" + "${onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH}/*.cpp" + ) + file(GLOB onnxruntime_external_transformer_src ${onnxruntime_external_transformer_src_patterns}) + list(APPEND onnxruntime_optimizer_srcs ${onnxruntime_external_transformer_src}) +endif() + onnxruntime_add_static_library(onnxruntime_optimizer ${onnxruntime_optimizer_srcs}) install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/optimizer DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core) diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index a29a80ae08..f9d82b3a5b 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -36,6 +36,8 @@ #include "orttraining/core/graph/loss_function_registry.h" #include "orttraining/core/graph/optimizer_builder.h" #include "orttraining/core/graph/optimizer_graph_builder_registry.h" +#include "orttraining/core/optimizer/graph_transformer_registry.h" + #endif namespace onnxruntime { @@ -247,6 +249,7 @@ Status Environment::Initialize(std::unique_ptr logging_ training::OptimizerBuilderRegistry::GetInstance().RegisterBuilders(); training::OptimizerGraphBuilderRegistry::GetInstance().RegisterGraphBuilders(); // + training::GraphTransformerRegistry::GetInstance().RegisterExternalGraphTransformers(); #endif }); diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_registry.cc b/orttraining/orttraining/core/optimizer/graph_transformer_registry.cc new file mode 100644 index 0000000000..d1a80f53fc --- /dev/null +++ b/orttraining/orttraining/core/optimizer/graph_transformer_registry.cc @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/core/optimizer/graph_transformer_registry.h" + +namespace onnxruntime { +namespace training { + +#ifdef ORT_TRAINING_EXTERNAL_GRAPH_TRANSFORMERS +void RegisterTrainingExternalTransformers(); +#endif + +void GraphTransformerRegistry::RegisterExternalGraphTransformers() { +#ifdef ORT_TRAINING_EXTERNAL_GRAPH_TRANSFORMERS + RegisterTrainingExternalTransformers(); +#endif +} + +void GenerateExternalTransformers( + TransformerLevel level, + bool before_gradient_builder, + const std::unordered_set& ep_list, + std::vector>& output) { + auto& registered_transformers = GraphTransformerRegistry::GetInstance().GetAllRegisteredTransformers(); + for (auto& [k, v] : registered_transformers) { + if (v.before_gradient_builder != before_gradient_builder || v.level != level) + continue; + output.push_back(GraphTransformerRegistry::GetInstance().CreateTransformer(k, ep_list)); + } +} + +} // namespace training +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_registry.h b/orttraining/orttraining/core/optimizer/graph_transformer_registry.h new file mode 100644 index 0000000000..75692c5848 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/graph_transformer_registry.h @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "orttraining/core/graph/generic_registry.h" +#include "core/optimizer/graph_transformer.h" +#include "core/optimizer/graph_transformer_level.h" +#include "core/optimizer/rule_based_graph_transformer.h" + +namespace onnxruntime { +namespace training { + +typedef GenericRegistry&> // supported EP list + GraphTransformerRegistryType; + +typedef std::function(const std::unordered_set&)> GraphTransformerCreator; + +struct GraphTransformerMeta { + TransformerLevel level; + bool before_gradient_builder; +}; + +class GraphTransformerRegistry { + public: + + static GraphTransformerRegistry& GetInstance() { + static GraphTransformerRegistry instance; + return instance; + } + + void RegisterExternalGraphTransformers(); + + void Register(const std::string& name, const GraphTransformerCreator& creator, const GraphTransformerMeta& meta) { + ORT_ENFORCE(!transformer_registry_.Contains(name), "Fail to register, the entry exists:", name); + transformer_registry_.Register(name, creator); + name_to_meta_map_.insert({name, meta}); + } + + const std::unordered_map& GetAllRegisteredTransformers() { + return name_to_meta_map_; + } + + std::unique_ptr CreateTransformer(const std::string& name, const std::unordered_set& ep_list) const { + return transformer_registry_.MakeUnique(name, ep_list); + } + + private: + GraphTransformerRegistry() = default; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphTransformerRegistry); + + GraphTransformerRegistryType transformer_registry_; + std::unordered_map name_to_meta_map_; +}; + +class GraphTransformerRegisterOnce final { + public: + GraphTransformerRegisterOnce(const std::string& name, const GraphTransformerCreator& creator, TransformerLevel level, bool before_gradient_builder) { + GraphTransformerRegistry::GetInstance().Register(name, creator, {level, before_gradient_builder}); + } +}; + +#define ONNX_REGISTER_EXTERNAL_GRAPH_TRANSFORMER(name, level, flag) \ + ONNX_REGISTER_EXTERNAL_GRAPH_TRANSFORMER_UNIQ(__COUNTER__, name, level, flag) +#define ONNX_REGISTER_EXTERNAL_GRAPH_TRANSFORMER_UNIQ(Counter, name, level, flag) \ + static ONNX_UNUSED onnxruntime::training::GraphTransformerRegisterOnce \ + graph_transformer_register_once##name##Counter( \ + #name, [](const std::unordered_set& eps) { \ + return std::make_unique(eps); \ + }, TransformerLevel::level, flag); + +#define ONNX_REGISTER_EXTERNAL_REWRITE_RULE(name, level, flag) \ + ONNX_REGISTER_EXTERNAL_REWRITE_RULE_UNIQ(__COUNTER__, name, level, flag) +#define ONNX_REGISTER_EXTERNAL_REWRITE_RULE_UNIQ(Counter, name, level, flag) \ + static ONNX_UNUSED onnxruntime::training::GraphTransformerRegisterOnce \ + graph_transformer_register_once##name##Counter( \ + #name, [](const std::unordered_set& eps) { \ + auto rule_base_transformer = std::make_unique(#name, eps); \ + ORT_THROW_IF_ERROR(rule_base_transformer->Register(std::make_unique())); \ + return rule_base_transformer; \ + }, TransformerLevel::level, flag); + +void GenerateExternalTransformers( + TransformerLevel level, + bool before_gradient_builder, + const std::unordered_set& ep_list, + std::vector>& output); + +} // namespace training +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 51d8df43f9..a2e93f8357 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -52,6 +52,7 @@ #include "orttraining/core/optimizer/insert_output_rewriter.h" #include "orttraining/core/optimizer/localized_recompute.h" #include "orttraining/core/optimizer/loss_rewriter.h" +#include "orttraining/core/optimizer/graph_transformer_registry.h" #include "orttraining/core/optimizer/transformer_layer_recompute.h" namespace onnxruntime { @@ -157,6 +158,8 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(std::move(rule_transformer)); } + GenerateExternalTransformers(level, true, compatible_eps, transformers); + if (rules_and_transformers_to_disable.empty()) { return transformers; } else { diff --git a/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py b/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py new file mode 100644 index 0000000000..2d87a9e824 --- /dev/null +++ b/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py @@ -0,0 +1,126 @@ +import sys +import threading +import time + +class OutputGrabber(object): + """ + Class used to grab standard output or another stream. + """ + escape_char = "\b" + + def __init__(self, stream=None, threaded=False): + self.origstream = stream + self.threaded = threaded + if self.origstream is None: + self.origstream = sys.stdout + self.origstreamfd = self.origstream.fileno() + self.capturedtext = "" + # Create a pipe so the stream can be captured: + self.pipe_out, self.pipe_in = os.pipe() + + def __enter__(self): + self.start() + return self + + def __exit__(self, type, value, traceback): + self.stop() + + def start(self): + """ + Start capturing the stream data. + """ + self.capturedtext = "" + # Save a copy of the stream: + self.streamfd = os.dup(self.origstreamfd) + # Replace the original stream with our write pipe: + os.dup2(self.pipe_in, self.origstreamfd) + if self.threaded: + # Start thread that will read the stream: + self.workerThread = threading.Thread(target=self.readOutput) + self.workerThread.start() + # Make sure that the thread is running and os.read() has executed: + time.sleep(0.01) + + def stop(self): + """ + Stop capturing the stream data and save the text in `capturedtext`. + """ + # Print the escape character to make the readOutput method stop: + self.origstream.write(self.escape_char) + # Flush the stream to make sure all our data goes in before + # the escape character: + self.origstream.flush() + if self.threaded: + # wait until the thread finishes so we are sure that + # we have until the last character: + self.workerThread.join() + else: + self.readOutput() + # Close the pipe: + os.close(self.pipe_in) + os.close(self.pipe_out) + # Restore the original stream: + os.dup2(self.streamfd, self.origstreamfd) + # Close the duplicate stream: + os.close(self.streamfd) + + def readOutput(self): + """ + Read the stream data (one byte at a time) + and save the text in `capturedtext`. + """ + while True: + char = os.read(self.pipe_out,1).decode(self.origstream.encoding) + if not char or self.escape_char in char: + break + self.capturedtext += char + +import torch +from onnxruntime.capi import _pybind_state as torch_ort_eager +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import os +from onnxruntime.training import optim, orttrainer, orttrainer_options +import unittest + +def my_loss(x, target): + return F.nll_loss(F.log_softmax(x, dim=1), target) + +class NeuralNet(nn.Module): + def __init__(self, input_size, hidden_size, num_classes): + super(NeuralNet, self).__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(hidden_size, num_classes) + + def forward(self, x, target): + out = self.fc1(x) + out = self.relu(out) + out = self.fc2(out) + return my_loss(out, target) + +class OrtEPTests(unittest.TestCase): + def test_external_graph_transformer_triggering(self): + input_size = 784 + hidden_size = 500 + num_classes = 10 + batch_size = 128 + model = NeuralNet(input_size, hidden_size, num_classes) + + model_desc = {'inputs': [('x', [batch_size, input_size]), + ('target', [batch_size,])], + 'outputs': [('loss', [], True)]} + optim_config = optim.SGDConfig() + opts = orttrainer.ORTTrainerOptions({'device':{'id':'cpu'}}) + model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) + # because orttrainer is lazy initialized, feed in a random data to trigger the graph transformer + data = torch.rand(batch_size, input_size) + target = torch.randint(0, 10, (batch_size,)) + + with OutputGrabber() as out: + loss = model.train_step(data, target) + assert '******************Trigger Customized Graph Transformer: MyGraphTransformer!' in out.capturedtext + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/orttraining/orttraining/test/external_transformer/test_exeternal_transformers/test_external_transformers.cc b/orttraining/orttraining/test/external_transformer/test_exeternal_transformers/test_external_transformers.cc new file mode 100644 index 0000000000..2699242adb --- /dev/null +++ b/orttraining/orttraining/test/external_transformer/test_exeternal_transformers/test_external_transformers.cc @@ -0,0 +1,35 @@ +#include "core/optimizer/rewrite_rule.h" +#include "orttraining/core/optimizer/graph_transformer_registry.h" +#include "onnx/defs/schema.h" +#include +#include + +namespace onnxruntime { +namespace training { + +class MyRewriteRule : public RewriteRule { +public: + MyRewriteRule() noexcept + : RewriteRule("MyRewriteRule") { + } + std::vector TargetOpTypes() const noexcept override { + return {}; + } + +private: + bool SatisfyCondition(const Graph& /*graph*/, const Node& /*node*/, const logging::Logger& /*logger*/) const override { + return true; + } + + Status Apply(Graph& /*graph*/, Node& /*node*/, RewriteRuleEffect& /*rule_effect*/, const logging::Logger& /*logger*/) const override{ + std::cout << "******************Trigger Customized Graph Transformer: MyGraphTransformer!" << std::endl; + return Status::OK(); + } +}; + +void RegisterTrainingExternalTransformers() { + ONNX_REGISTER_EXTERNAL_REWRITE_RULE(MyRewriteRule, Level1, true); +} + +} +} diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 9d462efa8b..261985735e 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -560,7 +560,15 @@ def parse_arguments(): parser.add_argument( "--enable_external_custom_op_schemas", action='store_true', help="Enable registering user defined custom operation schemas at shared library load time.\ - This feature is only supported/available on Ubuntu.") + This feature is only supported/available on Ubuntu.") + + parser.add_argument( + "--external_graph_transformer_path", type=str, + help="path to the external graph transformer dir.") + + parser.add_argument( + "--test_external_transformer_example", action='store_true', + help="run the example external transformer test, mainly used in CI pipeline.") return parser.parse_args() @@ -813,6 +821,8 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home else "OFF"), "-Donnxruntime_NVCC_THREADS=" + str(args.parallel), ] + if args.external_graph_transformer_path: + cmake_args.append("-Donnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH=" + args.external_graph_transformer_path) # It should be default ON in CI build pipelines, and OFF in packaging pipelines. # And OFF for the people who are not actively developing onnx runtime. add_cmake_define_without_override(cmake_extra_defines, "onnxruntime_DEV_MODE", use_dev_mode(args)) @@ -1637,6 +1647,15 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): # run eager mode test args_list = [sys.executable, os.path.join(cwd, 'eager_test')] run_subprocess(args_list, cwd=cwd, dll_path=dll_path, python_path=cwd) + if args.test_external_transformer_example: + run_subprocess([sys.executable, + os.path.join(source_dir, + 'orttraining', + 'orttraining', + 'test', + 'external_transformer', + 'test', + 'external_transformers_test.py')], cwd=cwd, dll_path=dll_path) try: import onnx # noqa