support build option to include external graph transformers (#9478)

* temp code

* support external graph transformer from build script

* remove debug code

* add test case

* support registering rewrite rules

* fix source_group issue when the external sources do not share a common prefix

* fix python code style checker

* resolve merge conflict

Co-authored-by: Cheng Tang <chenta@microsoft.com>
This commit is contained in:
Tang, Cheng 2021-11-15 08:16:20 -08:00 committed by GitHub
parent 6e09fc5152
commit 99257eb8e3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 330 additions and 1 deletions

View file

@ -171,6 +171,9 @@ set(ONNX_CUSTOM_PROTOC_EXECUTABLE "" CACHE STRING "Specify custom protoc executa
# pre-built PyTorch path
option(onnxruntime_PREBUILT_PYTORCH_PATH "Path to pytorch installation dir")
# Directory holding external graph transformer sources.
# Declared as a string cache entry instead of option(): option() only models
# ON/OFF booleans, whereas this setting carries a filesystem path. The empty
# default evaluates to false in if(), so the feature stays off unless a path
# is supplied (e.g. -Donnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH=<dir>).
set(onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH "" CACHE STRING "Path to external transformer src dir")
# When CUDA is enabled, force the onnxruntime_DISABLE_RTTI switch to OFF.
if (onnxruntime_USE_CUDA)
set(onnxruntime_DISABLE_RTTI OFF)
endif()
@ -334,6 +337,10 @@ if (onnxruntime_BUILD_WEBASSEMBLY)
endif()
endif()
# Let the C++ sources know that an external graph transformer directory was
# supplied. add_compile_definitions() is used (as for ORT_MINIMAL_BUILD below)
# instead of the legacy add_definitions(-D...) form.
if (onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH)
  add_compile_definitions(ORT_TRAINING_EXTERNAL_GRAPH_TRANSFORMERS=1)
endif()
# ORT build with as much excluded as possible. Supports ORT flatbuffers models only.
if (onnxruntime_MINIMAL_BUILD)
add_compile_definitions(ORT_MINIMAL_BUILD)

View file

@ -57,6 +57,16 @@ file(GLOB onnxruntime_optimizer_srcs CONFIGURE_DEPENDS ${onnxruntime_optimizer_s
source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_optimizer_srcs})
# Fold user-supplied graph transformer sources (*.cc / *.cpp found directly in
# onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH) into the optimizer library.
# The files are appended after the source_group(TREE ...) call so they are not
# passed to it; presumably because an external directory need not live under
# REPO_ROOT.
if (onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH)
  set(onnxruntime_external_transformer_src_patterns
    "${onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH}/*.cc"
    "${onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH}/*.cpp"
  )
  # CONFIGURE_DEPENDS re-runs the glob at build time so files added to or
  # removed from the external directory are noticed without a manual
  # re-configure (same as the optimizer glob above).
  file(GLOB onnxruntime_external_transformer_src CONFIGURE_DEPENDS
    ${onnxruntime_external_transformer_src_patterns}
  )
  if (NOT onnxruntime_external_transformer_src)
    message(WARNING
      "onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH is set to "
      "'${onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH}' but no *.cc or *.cpp files were found there; "
      "no external graph transformer sources will be compiled.")
  endif()
  list(APPEND onnxruntime_optimizer_srcs ${onnxruntime_external_transformer_src})
endif()
onnxruntime_add_static_library(onnxruntime_optimizer ${onnxruntime_optimizer_srcs})
install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/optimizer DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core)

View file

@ -36,6 +36,8 @@
#include "orttraining/core/graph/loss_function_registry.h"
#include "orttraining/core/graph/optimizer_builder.h"
#include "orttraining/core/graph/optimizer_graph_builder_registry.h"
#include "orttraining/core/optimizer/graph_transformer_registry.h"
#endif
namespace onnxruntime {
@ -247,6 +249,7 @@ Status Environment::Initialize(std::unique_ptr<logging::LoggingManager> logging_
training::OptimizerBuilderRegistry::GetInstance().RegisterBuilders();
training::OptimizerGraphBuilderRegistry::GetInstance().RegisterGraphBuilders();
// <training schemas>
training::GraphTransformerRegistry::GetInstance().RegisterExternalGraphTransformers();
#endif
});

View file

@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "orttraining/core/optimizer/graph_transformer_registry.h"
namespace onnxruntime {
namespace training {
#if defined(ORT_TRAINING_EXTERNAL_GRAPH_TRANSFORMERS)
// Registration hook; only declared here, its definition is expected to come
// from the external transformer sources compiled into this build.
void RegisterTrainingExternalTransformers();
#endif

// Invokes the external registration hook. When the build was configured
// without ORT_TRAINING_EXTERNAL_GRAPH_TRANSFORMERS this function does nothing.
void GraphTransformerRegistry::RegisterExternalGraphTransformers() {
#if defined(ORT_TRAINING_EXTERNAL_GRAPH_TRANSFORMERS)
  RegisterTrainingExternalTransformers();
#endif
}
// Appends to `output` one newly created transformer for every entry in the
// GraphTransformerRegistry whose metadata matches both `level` and
// `before_gradient_builder`. `ep_list` is forwarded to the creator of each
// selected transformer. Entries are visited in the iteration order of the
// registry's unordered_map.
void GenerateExternalTransformers(
    TransformerLevel level,
    bool before_gradient_builder,
    const std::unordered_set<std::string>& ep_list,
    std::vector<std::unique_ptr<GraphTransformer>>& output) {
  auto& registry = GraphTransformerRegistry::GetInstance();
  for (const auto& entry : registry.GetAllRegisteredTransformers()) {
    const GraphTransformerMeta& meta = entry.second;
    const bool selected = meta.before_gradient_builder == before_gradient_builder &&
                          meta.level == level;
    if (selected) {
      output.push_back(registry.CreateTransformer(entry.first, ep_list));
    }
  }
}
} // namespace training
} // namespace onnxruntime

View file

@ -0,0 +1,93 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <unordered_map>
#include <functional>
#include "orttraining/core/graph/generic_registry.h"
#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/graph_transformer_level.h"
#include "core/optimizer/rule_based_graph_transformer.h"
namespace onnxruntime {
namespace training {
// Name-keyed registry of GraphTransformer factories; each factory receives the
// supported EP list as its only argument.
using GraphTransformerRegistryType =
    GenericRegistry<GraphTransformer,
                    const std::unordered_set<std::string>&>;  // supported EP list

// Callable that builds a GraphTransformer from the supported EP list.
using GraphTransformerCreator =
    std::function<std::unique_ptr<GraphTransformer>(const std::unordered_set<std::string>&)>;
// Selection metadata stored next to each registered transformer.
// GenerateExternalTransformers() only instantiates a transformer when both
// fields equal the values it is called with.
struct GraphTransformerMeta {
// Optimization level the transformer was registered for.
TransformerLevel level;
// Flag given at registration; compared against the `before_gradient_builder`
// argument of GenerateExternalTransformers().
bool before_gradient_builder;
};
// Process-wide registry of externally supplied graph transformers.
// For each registered name it keeps a creator (in `transformer_registry_`) and
// the selection metadata (in `name_to_meta_map_`).
class GraphTransformerRegistry {
 public:
  // Returns the single shared instance (function-local static).
  static GraphTransformerRegistry& GetInstance() {
    static GraphTransformerRegistry instance;
    return instance;
  }

  // Calls the external registration hook when the build enables it;
  // defined in the accompanying .cc file.
  void RegisterExternalGraphTransformers();

  // Registers `creator` and `meta` under `name`.
  // ORT_ENFORCE fails if `name` has already been registered.
  void Register(const std::string& name, const GraphTransformerCreator& creator, const GraphTransformerMeta& meta) {
    ORT_ENFORCE(!transformer_registry_.Contains(name), "Fail to register, the entry exists:", name);
    transformer_registry_.Register<GraphTransformer>(name, creator);
    name_to_meta_map_.insert({name, meta});
  }

  // Read-only view of the name -> metadata map. Marked const (like
  // CreateTransformer) because it only hands out a reference to a member.
  const std::unordered_map<std::string, GraphTransformerMeta>& GetAllRegisteredTransformers() const {
    return name_to_meta_map_;
  }

  // Creates a new transformer through the creator registered under `name`,
  // passing `ep_list` to it.
  std::unique_ptr<GraphTransformer> CreateTransformer(const std::string& name, const std::unordered_set<std::string>& ep_list) const {
    return transformer_registry_.MakeUnique(name, ep_list);
  }

 private:
  GraphTransformerRegistry() = default;
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphTransformerRegistry);

  GraphTransformerRegistryType transformer_registry_;
  std::unordered_map<std::string, GraphTransformerMeta> name_to_meta_map_;
};
class GraphTransformerRegisterOnce final {
public:
GraphTransformerRegisterOnce(const std::string& name, const GraphTransformerCreator& creator, TransformerLevel level, bool before_gradient_builder) {
GraphTransformerRegistry::GetInstance().Register(name, creator, {level, before_gradient_builder});
}
};
// Registers the GraphTransformer subclass `name` as an external transformer.
//   name  - class name; also used (stringized) as the registry key. The class
//           must be constructible from `const std::unordered_set<std::string>&`.
//   level - enumerator of TransformerLevel (e.g. Level1), without the scope.
//   flag  - stored as GraphTransformerMeta::before_gradient_builder.
//
// The extra _UNIQ_HELPER level is required: a macro argument that is an operand
// of `##` is NOT macro-expanded, so pasting `Counter` directly in the macro that
// receives `__COUNTER__` would produce the literal identifier suffix
// `__COUNTER__` instead of a unique number. Passing it through one more macro
// expands it first.
#define ONNX_REGISTER_EXTERNAL_GRAPH_TRANSFORMER(name, level, flag) \
  ONNX_REGISTER_EXTERNAL_GRAPH_TRANSFORMER_UNIQ_HELPER(__COUNTER__, name, level, flag)
#define ONNX_REGISTER_EXTERNAL_GRAPH_TRANSFORMER_UNIQ_HELPER(Counter, name, level, flag) \
  ONNX_REGISTER_EXTERNAL_GRAPH_TRANSFORMER_UNIQ(Counter, name, level, flag)
#define ONNX_REGISTER_EXTERNAL_GRAPH_TRANSFORMER_UNIQ(Counter, name, level, flag) \
  static ONNX_UNUSED onnxruntime::training::GraphTransformerRegisterOnce          \
      graph_transformer_register_once##name##Counter(                             \
          #name, [](const std::unordered_set<std::string>& eps) {                 \
            return std::make_unique<name>(eps);                                   \
          },                                                                      \
          TransformerLevel::level, flag);
// Registers the RewriteRule subclass `name` as an external transformer. The
// rule is default-constructed and wrapped in a RuleBasedGraphTransformer that
// is also named `name`; a failure to register the rule is raised through
// ORT_THROW_IF_ERROR when the creator runs.
//   name  - rule class name; also used (stringized) as the registry key.
//   level - enumerator of TransformerLevel (e.g. Level1), without the scope.
//   flag  - stored as GraphTransformerMeta::before_gradient_builder.
//
// The extra _UNIQ_HELPER level is required: a macro argument that is an operand
// of `##` is NOT macro-expanded, so pasting `Counter` directly in the macro that
// receives `__COUNTER__` would produce the literal identifier suffix
// `__COUNTER__` instead of a unique number. Passing it through one more macro
// expands it first.
#define ONNX_REGISTER_EXTERNAL_REWRITE_RULE(name, level, flag) \
  ONNX_REGISTER_EXTERNAL_REWRITE_RULE_UNIQ_HELPER(__COUNTER__, name, level, flag)
#define ONNX_REGISTER_EXTERNAL_REWRITE_RULE_UNIQ_HELPER(Counter, name, level, flag) \
  ONNX_REGISTER_EXTERNAL_REWRITE_RULE_UNIQ(Counter, name, level, flag)
#define ONNX_REGISTER_EXTERNAL_REWRITE_RULE_UNIQ(Counter, name, level, flag)                      \
  static ONNX_UNUSED onnxruntime::training::GraphTransformerRegisterOnce                          \
      graph_transformer_register_once##name##Counter(                                             \
          #name, [](const std::unordered_set<std::string>& eps) {                                 \
            auto rule_base_transformer = std::make_unique<RuleBasedGraphTransformer>(#name, eps); \
            ORT_THROW_IF_ERROR(rule_base_transformer->Register(std::make_unique<name>()));        \
            return rule_base_transformer;                                                         \
          },                                                                                      \
          TransformerLevel::level, flag);
// Appends to `output` a newly created instance of every registered external
// transformer whose metadata matches both `level` and
// `before_gradient_builder`; `ep_list` is passed to each transformer's creator.
void GenerateExternalTransformers(
TransformerLevel level,
bool before_gradient_builder,
const std::unordered_set<std::string>& ep_list,
std::vector<std::unique_ptr<GraphTransformer>>& output);
} // namespace training
} // namespace onnxruntime

View file

@ -52,6 +52,7 @@
#include "orttraining/core/optimizer/insert_output_rewriter.h"
#include "orttraining/core/optimizer/localized_recompute.h"
#include "orttraining/core/optimizer/loss_rewriter.h"
#include "orttraining/core/optimizer/graph_transformer_registry.h"
#include "orttraining/core/optimizer/transformer_layer_recompute.h"
namespace onnxruntime {
@ -157,6 +158,8 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
transformers.emplace_back(std::move(rule_transformer));
}
GenerateExternalTransformers(level, true, compatible_eps, transformers);
if (rules_and_transformers_to_disable.empty()) {
return transformers;
} else {

View file

@ -0,0 +1,126 @@
import sys
import threading
import time
class OutputGrabber(object):
    """
    Class used to grab standard output or another stream.

    The stream's file descriptor is temporarily redirected into a pipe, so
    anything written to that descriptor between `start()` and `stop()` is
    captured. Use it as a context manager or call `start()` / `stop()`
    explicitly; the captured text is stored in `capturedtext`.

    Args:
        stream: object exposing `fileno()`, `write()`, `flush()` and
            `encoding`; defaults to `sys.stdout`.
        threaded: when True the pipe is drained by a worker thread while
            capturing; otherwise it is drained inside `stop()`.
    """
    escape_char = "\b"

    def __init__(self, stream=None, threaded=False):
        self.origstream = stream
        self.threaded = threaded
        if self.origstream is None:
            self.origstream = sys.stdout
        self.origstreamfd = self.origstream.fileno()
        self.capturedtext = ""
        # Create a pipe so the stream can be captured:
        self.pipe_out, self.pipe_in = os.pipe()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        self.stop()

    def start(self):
        """
        Start capturing the stream data.
        """
        self.capturedtext = ""
        # stop() closes the pipe, so a grabber that is started again needs a
        # fresh one:
        if self.pipe_in is None or self.pipe_out is None:
            self.pipe_out, self.pipe_in = os.pipe()
        # Save a copy of the stream:
        self.streamfd = os.dup(self.origstreamfd)
        # Replace the original stream with our write pipe:
        os.dup2(self.pipe_in, self.origstreamfd)
        if self.threaded:
            # Start thread that will read the stream:
            self.workerThread = threading.Thread(target=self.readOutput)
            self.workerThread.start()
            # Make sure that the thread is running and os.read() has executed:
            time.sleep(0.01)

    def stop(self):
        """
        Stop capturing the stream data and save the text in `capturedtext`.
        """
        try:
            # Print the escape character to make the readOutput method stop:
            self.origstream.write(self.escape_char)
            # Flush the stream to make sure all our data goes in before
            # the escape character:
            self.origstream.flush()
            if self.threaded:
                # wait until the thread finishes so we are sure that
                # we have until the last character:
                self.workerThread.join()
            else:
                self.readOutput()
        finally:
            # Always give the stream its descriptor back, even if reading the
            # captured data failed:
            os.dup2(self.streamfd, self.origstreamfd)
            # Close the duplicate stream:
            os.close(self.streamfd)
            # Close the pipe:
            os.close(self.pipe_in)
            os.close(self.pipe_out)
            self.pipe_in = None
            self.pipe_out = None

    def readOutput(self):
        """
        Read the stream data (one byte at a time) until the escape character
        or end-of-file and append the decoded text to `capturedtext`.

        The bytes are collected first and decoded in one go: decoding each
        byte separately fails on characters that are encoded as more than one
        byte.
        """
        encoding = getattr(self.origstream, "encoding", None) or "utf-8"
        terminator = self.escape_char.encode(encoding)
        collected = bytearray()
        while True:
            chunk = os.read(self.pipe_out, 1)
            if not chunk:
                break
            collected += chunk
            if collected.endswith(terminator):
                # Drop the escape character itself from the result.
                del collected[-len(terminator):]
                break
        self.capturedtext += collected.decode(encoding, errors="replace")
import torch
from onnxruntime.capi import _pybind_state as torch_ort_eager
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from onnxruntime.training import optim, orttrainer, orttrainer_options
import unittest
def my_loss(x, target):
    """Negative log-likelihood of `target` under log-softmax of `x` along dim 1."""
    log_probabilities = F.log_softmax(x, dim=1)
    return F.nll_loss(log_probabilities, target)
class NeuralNet(nn.Module):
    """Linear -> ReLU -> Linear network whose forward pass returns the loss."""

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x, target):
        """Run `x` through the layers and return `my_loss` of the result against `target`."""
        logits = self.fc2(self.relu(self.fc1(x)))
        return my_loss(logits, target)
class OrtEPTests(unittest.TestCase):
    """Checks that an externally registered graph transformer is triggered."""

    def test_external_graph_transformer_triggering(self):
        """Run one training step and look for the marker line in captured stdout."""
        input_size = 784
        hidden_size = 500
        num_classes = 10
        batch_size = 128
        model = NeuralNet(input_size, hidden_size, num_classes)

        model_desc = {'inputs': [('x', [batch_size, input_size]),
                                 ('target', [batch_size, ])],
                      'outputs': [('loss', [], True)]}
        optim_config = optim.SGDConfig()
        opts = orttrainer.ORTTrainerOptions({'device': {'id': 'cpu'}})
        model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts)

        # because orttrainer is lazy initialized, feed in a random data to trigger the graph transformer
        data = torch.rand(batch_size, input_size)
        # labels are drawn from the same number of classes the model was built with
        target = torch.randint(0, num_classes, (batch_size,))

        with OutputGrabber() as out:
            model.train_step(data, target)

        # assertIn (instead of a bare assert) keeps the check active under
        # `python -O` and reports the captured text on failure.
        self.assertIn('******************Trigger Customized Graph Transformer: MyGraphTransformer!',
                      out.capturedtext)
if __name__ == '__main__':
unittest.main()

View file

@ -0,0 +1,35 @@
#include "core/optimizer/rewrite_rule.h"
#include "orttraining/core/optimizer/graph_transformer_registry.h"
#include "onnx/defs/schema.h"
#include <memory>
#include <iostream>
namespace onnxruntime {
namespace training {
// Example rewrite rule for the external transformer test: it leaves the graph
// untouched and prints a marker line to stdout each time it is applied.
class MyRewriteRule : public RewriteRule {
 public:
  MyRewriteRule() noexcept : RewriteRule("MyRewriteRule") {}

  // Returns an empty op-type list.
  // NOTE(review): presumably an empty list makes the rule apply to every op
  // type - confirm against RewriteRule.
  std::vector<std::string> TargetOpTypes() const noexcept override {
    return std::vector<std::string>{};
  }

 private:
  // The rule accepts every node it is offered.
  bool SatisfyCondition(const Graph& /*graph*/,
                        const Node& /*node*/,
                        const logging::Logger& /*logger*/) const override {
    return true;
  }

  // Prints the marker line (std::endl also flushes stdout) and reports
  // success without modifying the graph or `rule_effect`.
  Status Apply(Graph& /*graph*/,
               Node& /*node*/,
               RewriteRuleEffect& /*rule_effect*/,
               const logging::Logger& /*logger*/) const override {
    std::cout << "******************Trigger Customized Graph Transformer: MyGraphTransformer!" << std::endl;
    return Status::OK();
  }
};
// Registration hook called by GraphTransformerRegistry::RegisterExternalGraphTransformers()
// when ORT_TRAINING_EXTERNAL_GRAPH_TRANSFORMERS is defined.
// Registers MyRewriteRule for TransformerLevel::Level1 with the
// before-gradient-builder flag set to true.
void RegisterTrainingExternalTransformers() {
ONNX_REGISTER_EXTERNAL_REWRITE_RULE(MyRewriteRule, Level1, true);
}
}
}

View file

@ -560,7 +560,15 @@ def parse_arguments():
parser.add_argument(
"--enable_external_custom_op_schemas", action='store_true',
help="Enable registering user defined custom operation schemas at shared library load time.\
This feature is only supported/available on Ubuntu.")
This feature is only supported/available on Ubuntu.")
parser.add_argument(
"--external_graph_transformer_path", type=str,
help="path to the external graph transformer dir.")
parser.add_argument(
"--test_external_transformer_example", action='store_true',
help="run the example external transformer test, mainly used in CI pipeline.")
return parser.parse_args()
@ -813,6 +821,8 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home
else "OFF"),
"-Donnxruntime_NVCC_THREADS=" + str(args.parallel),
]
if args.external_graph_transformer_path:
cmake_args.append("-Donnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH=" + args.external_graph_transformer_path)
# It should be default ON in CI build pipelines, and OFF in packaging pipelines.
# And OFF for the people who are not actively developing onnx runtime.
add_cmake_define_without_override(cmake_extra_defines, "onnxruntime_DEV_MODE", use_dev_mode(args))
@ -1637,6 +1647,15 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
# run eager mode test
args_list = [sys.executable, os.path.join(cwd, 'eager_test')]
run_subprocess(args_list, cwd=cwd, dll_path=dll_path, python_path=cwd)
if args.test_external_transformer_example:
run_subprocess([sys.executable,
os.path.join(source_dir,
'orttraining',
'orttraining',
'test',
'external_transformer',
'test',
'external_transformers_test.py')], cwd=cwd, dll_path=dll_path)
try:
import onnx # noqa