mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Introduce new optimizer MatMul + BatchNormalization (#17915)
### Description
Introduce new ORT L1 optimizer under RewriteRule category to fuse MatMul
+ BatchNormalization node. This optimizer look for a specific pattern
observed in one of the impacting customer models and fuse the Matmul and
Batchnormalization node into a Gemm node. For details on the pattern
matching and fusion please refer to the comment section of
`matmul_bn_fusion.cc`.
To visualize, this optimizer will replace following subgraph to a Gemm
node.
<pre>
MatMul GEMM
| |
Reshape ^ ---> Reshape ^
| |
Transpose ^ Transpose ^
|
BatchNormalization
Note: ^ means there can be >=0 occurrence(s) of that node.
Few example fusable pattern:
* - MatMul -> Reshape -> Transpose -> BatchNormalization ---> GEMM ->
Reshape -> Transpose
* - MatMul -> Reshape -> BatchNormalization ---> GEMM -> Reshape
* - MatMul -> Transpose -> BatchNormalization ---> GEMM -> Transpose
* - MatMul -> Reshape -> Reshape -> BatchNormalization ---> GEMM ->
Reshape -> Reshape
* - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization --->
GEMM -> Reshape -> Transpose -> Reshape
* - MatMul -> BatchNormalization ---> GEMM
</pre>
Note: This optimizer may evolve in the future to be more generic in
terms of the pattern matching.
### Motivation and Context
- Why is this change required? What problem does it solve?
One of the user of ORT+DML ep needs this to better target the model to
DML. But this transformation applies more broadly, so added L1
optimizer.
<!-- - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
fcb48ae260
commit
8a5f299a13
11 changed files with 543 additions and 9 deletions
|
|
@ -50,6 +50,7 @@
|
|||
#include "core/optimizer/matmul_integer_to_float.h"
|
||||
#include "core/optimizer/matmul_scale_fusion.h"
|
||||
#include "core/optimizer/matmul_transpose_fusion.h"
|
||||
#include "core/optimizer/matmul_bn_fusion.h"
|
||||
#include "core/optimizer/nchwc_transformer.h"
|
||||
#include "core/optimizer/noop_elimination.h"
|
||||
#include "core/optimizer/not_where_fusion.h"
|
||||
|
|
@ -127,6 +128,7 @@ InlinedVector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(
|
|||
rules.push_back(std::make_unique<ConvAddFusion>());
|
||||
rules.push_back(std::make_unique<ConvMulFusion>());
|
||||
rules.push_back(std::make_unique<ConvBNFusion>());
|
||||
rules.push_back(std::make_unique<MatmulBNFusion>());
|
||||
rules.push_back(std::make_unique<ClipQuantFusion>());
|
||||
rules.push_back(std::make_unique<ReluQuantFusion>());
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -289,7 +289,11 @@ Initializer& Initializer::sqrt() {
|
|||
namespace {
|
||||
template <typename T>
|
||||
struct ScaleByAxis {
|
||||
void operator()(Tensor& data, const Tensor& scalers, const size_t block_size, const size_t num_blocks) const {
|
||||
void operator()(Tensor& data,
|
||||
const Tensor& scalers,
|
||||
const size_t block_size,
|
||||
const size_t num_blocks,
|
||||
const bool column_major) const {
|
||||
ToNumeric<T> to_numeric;
|
||||
const auto scaler_size = scalers.Shape().Size();
|
||||
T* dst = data.MutableData<T>();
|
||||
|
|
@ -301,24 +305,32 @@ struct ScaleByAxis {
|
|||
}
|
||||
} else {
|
||||
for (size_t block_offset = 0, i = 0; i < num_blocks; i++) {
|
||||
const auto numeric_scaler = to_numeric(scalers_data[i]);
|
||||
for (size_t j = 0; j < block_size; ++j, ++block_offset) {
|
||||
dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler);
|
||||
if (column_major) {
|
||||
for (size_t j = 0; j < block_size; ++j, ++block_offset) {
|
||||
const auto numeric_scaler = to_numeric(scalers_data[j]);
|
||||
dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler);
|
||||
}
|
||||
} else {
|
||||
const auto numeric_scaler = to_numeric(scalers_data[i]);
|
||||
for (size_t j = 0; j < block_size; ++j, ++block_offset) {
|
||||
dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void Initializer::scale_by_axis(const Initializer& scalers, int axis) {
|
||||
void Initializer::scale_by_axis(const Initializer& scalers, int axis, bool column_major) {
|
||||
ORT_ENFORCE(axis >= 0, "Axis must be non-negative");
|
||||
const size_t block_size = narrow<size_t>(data_.Shape().SizeFromDimension(gsl::narrow_cast<size_t>(axis)));
|
||||
const size_t num_blocks = size() / block_size;
|
||||
ORT_ENFORCE(scalers.size() == 1 || scalers.size() == num_blocks, "Invalid other(scalers) size");
|
||||
ORT_ENFORCE(scalers.size() == 1 ||
|
||||
(column_major ? scalers.size() == block_size : scalers.size() == num_blocks),
|
||||
"Invalid other(scalers) size");
|
||||
utils::MLTypeCallDispatcher<MLFloat16, BFloat16, float, double, int32_t, int64_t> t_disp(data_.GetElementType());
|
||||
t_disp.Invoke<ScaleByAxis>(data_, scalers.data_, block_size, num_blocks);
|
||||
t_disp.Invoke<ScaleByAxis>(data_, scalers.data_, block_size, num_blocks, column_major);
|
||||
}
|
||||
#endif // ORT_EXTENDED_MINIMAL_BUILD
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ class Initializer final {
|
|||
|
||||
Initializer& sqrt();
|
||||
|
||||
void scale_by_axis(const Initializer& other, int axis);
|
||||
void scale_by_axis(const Initializer& other, int axis, bool column_major = false);
|
||||
#endif // ORT_EXTENDED_MINIMAL_BUILD
|
||||
private:
|
||||
std::string name_;
|
||||
|
|
|
|||
230
onnxruntime/core/optimizer/matmul_bn_fusion.cc
Normal file
230
onnxruntime/core/optimizer/matmul_bn_fusion.cc
Normal file
|
|
@ -0,0 +1,230 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/optimizer/matmul_bn_fusion.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace {
|
||||
const std::vector<std::pair<std::string, InlinedVector<ONNX_NAMESPACE::OperatorSetVersion>>> ignorable_nodes{
|
||||
{"Reshape", {1, 5, 13, 14, 19}},
|
||||
{"Transpose", {1, 13}}};
|
||||
const std::pair<std::string, InlinedVector<ONNX_NAMESPACE::OperatorSetVersion>> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}};
|
||||
} // namespace
|
||||
|
||||
bool NodeIsIgnorable(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) {
|
||||
const Node* curr_node = graph.GetNode(curr_node_index);
|
||||
|
||||
// curr_node has different execution provider then it's parent or
|
||||
// has output edge != 1 (this condition will handle the case when ignorable node
|
||||
// is graph output i.e. a graph like this "MatMul->Transpose")
|
||||
if (curr_node->GetExecutionProviderType() != root_node.GetExecutionProviderType() ||
|
||||
curr_node->GetOutputEdgesCount() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// curr_node can be any of the ignorable_nodes.
|
||||
for (size_t index = 0; index < ignorable_nodes.size(); index++) {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, ignorable_nodes[index].first, ignorable_nodes[index].second)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
std::optional<NodeIndex> MatchPath(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) {
|
||||
while (NodeIsIgnorable(graph, root_node, curr_node_index)) {
|
||||
curr_node_index = graph.GetNode(curr_node_index)->OutputNodesBegin()->Index();
|
||||
}
|
||||
|
||||
// curr_node is neither ignorable nor dest
|
||||
const Node* curr_node = graph.GetNode(curr_node_index);
|
||||
if (curr_node->OpType() != dest.first) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (curr_node->GetExecutionProviderType() == root_node.GetExecutionProviderType() &&
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, dest.first, dest.second)) {
|
||||
return curr_node_index;
|
||||
}
|
||||
|
||||
// either curr_node has different execution provider or
|
||||
// has invalid opset.
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
/*
|
||||
* Given a MatMul node, it will verify the following pattern.
|
||||
* MatMul GEMM
|
||||
* | |
|
||||
* Reshape ^ ---> Reshape ^
|
||||
* | |
|
||||
* Transpose ^ Transpose ^
|
||||
* |
|
||||
* BatchNormalization
|
||||
* Note: ^ means there can be 0 or any occurrences of that node.
|
||||
* Few example fusable pattern:
|
||||
* - MatMul -> Reshape -> Transpose -> BatchNormalization ---> GEMM -> Reshape -> Transpose
|
||||
* - MatMul -> Reshape -> BatchNormalization ---> GEMM -> Reshape
|
||||
* - MatMul -> Transpose -> BatchNormalization ---> GEMM -> Transpose
|
||||
* - MatMul -> Reshape -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Reshape
|
||||
* - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Transpose -> Reshape
|
||||
* - MatMul -> BatchNormalization ---> GEMM
|
||||
* Other Conditions:
|
||||
* - B tensor of MatMul should be constant.
|
||||
* - scale, B, mean, var tensors of BatchNormalization should be constant.
|
||||
* - Every node in the path, except the BatchNormalization, should have only 1 output edge.
|
||||
*/
|
||||
bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const {
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9, 13}) ||
|
||||
node.GetOutputEdgesCount() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (graph.NodeProducesGraphOutput(node)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// because <node> is not producing graph output, it means it will have a child node
|
||||
NodeIndex child_node_index = node.OutputNodesBegin()->Index();
|
||||
std::optional<NodeIndex> batch_norm_index = MatchPath(graph, node, child_node_index);
|
||||
if (!batch_norm_index.has_value()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const Node* batch_norm_node = graph.GetNode(*batch_norm_index);
|
||||
|
||||
// Check that the appropriate inputs to the Matmul and BN nodes are constants.
|
||||
if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) ||
|
||||
!graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[1]) ||
|
||||
!graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[2]) ||
|
||||
!graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[3]) ||
|
||||
!graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[4])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// First output from BN is required. Others are optional. If any optional outputs exist we can't fuse.
|
||||
const auto& output_defs = batch_norm_node->OutputDefs();
|
||||
if (output_defs.size() > 1) {
|
||||
for (size_t i = 1, end = output_defs.size(); i < end; ++i) {
|
||||
if (output_defs[i] != nullptr && output_defs[i]->Exists()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/*
|
||||
* BatchNormalization: [https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc]
|
||||
* Scale * ((Input - Mean) / sqrt(Variance + Epsilon)) + Bias // ignore the FusedActivation in the above definition, that's very specific to DML
|
||||
* Expanding out the terms:
|
||||
* Output = (Scale / sqrt(Variance + Epsilon)) * Input + (Scale / sqrt(Variance + Epsilon)) * -Mean + Bias
|
||||
* Here,
|
||||
* [Scale/sqrt(Variance + Epsilon)] is constant, and let's call it `alpha`
|
||||
* [(Scale / sqrt(Variance + Epsilon)) * -Mean + Bias] is also constant, and let's call it `beta`
|
||||
* Output = alpha * Input + beta, Input = B tensor of MatMul.
|
||||
*
|
||||
*/
|
||||
Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
|
||||
NodeIndex child_node_index = matmul_node.OutputNodesBegin()->Index();
|
||||
NodeIndex batch_norm_node_index = MatchPath(graph, matmul_node, child_node_index).value();
|
||||
|
||||
Node& batch_norm_node = *graph.GetNode(batch_norm_node_index); // need mutable node, that's why extracting node from graph
|
||||
|
||||
// only perform fusion if epsilon is present and is of float_32 type
|
||||
auto epsilon_attribute = batch_norm_node.GetAttributes().find("epsilon");
|
||||
if (epsilon_attribute == batch_norm_node.GetAttributes().end() ||
|
||||
epsilon_attribute->second.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT) {
|
||||
return Status::OK();
|
||||
}
|
||||
const float epsilon = epsilon_attribute->second.f();
|
||||
|
||||
const onnx::TensorProto* scale_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[1]->Name());
|
||||
ORT_ENFORCE(scale_tensor);
|
||||
const onnx::TensorProto* bias_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[2]->Name());
|
||||
ORT_ENFORCE(bias_tensor);
|
||||
const onnx::TensorProto* mean_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[3]->Name());
|
||||
ORT_ENFORCE(mean_tensor);
|
||||
const onnx::TensorProto* var_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[4]->Name());
|
||||
ORT_ENFORCE(var_tensor);
|
||||
const onnx::TensorProto* matmul_b_tensor = graph_utils::GetConstantInitializer(graph, matmul_node.InputDefs()[1]->Name());
|
||||
ORT_ENFORCE(matmul_b_tensor);
|
||||
|
||||
if (!optimizer_utils::IsFloatingPointDataType(*matmul_b_tensor) ||
|
||||
!optimizer_utils::IsFloatingPointDataType(*scale_tensor) ||
|
||||
!optimizer_utils::IsFloatingPointDataType(*bias_tensor) ||
|
||||
!optimizer_utils::IsFloatingPointDataType(*mean_tensor) ||
|
||||
!optimizer_utils::IsFloatingPointDataType(*var_tensor) ||
|
||||
scale_tensor->dims_size() != 1 ||
|
||||
bias_tensor->dims_size() != 1 ||
|
||||
mean_tensor->dims_size() != 1 ||
|
||||
var_tensor->dims_size() != 1 ||
|
||||
scale_tensor->dims(0) != matmul_b_tensor->dims(1) ||
|
||||
bias_tensor->dims(0) != matmul_b_tensor->dims(1) ||
|
||||
mean_tensor->dims(0) != matmul_b_tensor->dims(1) ||
|
||||
var_tensor->dims(0) != matmul_b_tensor->dims(1)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/*
|
||||
* temp = scale / sqrt(var + epsilon)
|
||||
* output = (temp * Input) - ((temp * mean) + bias)
|
||||
*/
|
||||
Initializer scale(*scale_tensor, graph.ModelPath());
|
||||
Initializer bias(*bias_tensor, graph.ModelPath());
|
||||
Initializer mean(*mean_tensor, graph.ModelPath());
|
||||
Initializer var(*var_tensor, graph.ModelPath());
|
||||
Initializer matmul_b(*matmul_b_tensor, graph.ModelPath());
|
||||
|
||||
var.add(epsilon);
|
||||
var.sqrt();
|
||||
scale.div(var); // this is the temp
|
||||
matmul_b.scale_by_axis(scale, 1, true);
|
||||
|
||||
mean.mul(scale);
|
||||
bias.sub(mean);
|
||||
|
||||
// create B tensorProto for new Gemm node from <matmulB> initializer.
|
||||
ONNX_NAMESPACE::TensorProto new_gemm_b_tensor(*matmul_b_tensor);
|
||||
matmul_b.ToProto(new_gemm_b_tensor);
|
||||
const std::string new_gemm_b_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmB_" + matmul_b_tensor->name());
|
||||
new_gemm_b_tensor.set_name(new_gemm_b_name);
|
||||
NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializer(graph, new_gemm_b_tensor);
|
||||
|
||||
// create bias tensorProto for new Gemm node from <bias> initializer.
|
||||
ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor(*bias_tensor);
|
||||
bias.ToProto(new_gemm_bias_tensor);
|
||||
const std::string new_gemm_bias_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmBias");
|
||||
new_gemm_bias_tensor.set_name(new_gemm_bias_name);
|
||||
NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer(graph, new_gemm_bias_tensor);
|
||||
|
||||
Node& gemm_node = graph.AddNode(
|
||||
graph.GenerateNodeArgName("MatMulBnFusion_Gemm"),
|
||||
"Gemm",
|
||||
"Generated from Matmul BatchNormalization fusion",
|
||||
{matmul_node.MutableInputDefs()[0], &new_gemm_b_node_arg, &new_gemm_bias_node_arg},
|
||||
matmul_node.MutableOutputDefs(),
|
||||
nullptr,
|
||||
kOnnxDomain);
|
||||
|
||||
// Remove MatMul node.
|
||||
Node* node = graph.GetNode(matmul_node.Index());
|
||||
graph_utils::RemoveNodeOutputEdges(graph, *node);
|
||||
graph.RemoveNode(matmul_node.Index());
|
||||
|
||||
// Delete optional empty output defs.
|
||||
// Delete BatchNormalization node and update the input of the child of BatchNormalization
|
||||
batch_norm_node.MutableOutputDefs().resize(1);
|
||||
NodeIndex batch_norm_parent_index = graph.GetNode(child_node_index)->OpType() == "BatchNormalization" ? gemm_node.Index() : batch_norm_node.InputNodesBegin()->Index();
|
||||
graph_utils::FinalizeNodeFusion(graph, *graph.GetNode(batch_norm_parent_index), batch_norm_node);
|
||||
|
||||
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
27
onnxruntime/core/optimizer/matmul_bn_fusion.h
Normal file
27
onnxruntime/core/optimizer/matmul_bn_fusion.h
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
/*
|
||||
* This fusion submerges a BatchNormalization operator to it's super
|
||||
* precedding MatMul operator, if and only if MatmulBNFusion::SatisfyCondition()
|
||||
* is true.
|
||||
*/
|
||||
class MatmulBNFusion : public RewriteRule {
|
||||
public:
|
||||
MatmulBNFusion() : RewriteRule("MatMul_BatchNormalization_Fusion") {}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"MatMul"};
|
||||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -31,6 +31,7 @@
|
|||
#include "core/optimizer/conv_add_act_fusion.h"
|
||||
#include "core/optimizer/conv_add_fusion.h"
|
||||
#include "core/optimizer/conv_bn_fusion.h"
|
||||
#include "core/optimizer/matmul_bn_fusion.h"
|
||||
#include "core/optimizer/conv_mul_fusion.h"
|
||||
#include "core/optimizer/div_mul_fusion.h"
|
||||
#include "core/optimizer/dropout_elimination.h"
|
||||
|
|
@ -964,6 +965,268 @@ TEST_F(GraphTransformationTests, FuseConvBNNoBias) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, FuseMatmulBNWithInBetweenNodes) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx";
|
||||
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
std::string expected_output_name;
|
||||
GraphViewer graphViewer(graph);
|
||||
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
|
||||
auto& node = *graph.GetNode(node_index);
|
||||
if (node.OpType() == "MatMul") {
|
||||
expected_output_name = node.OutputDefs()[0]->Name();
|
||||
}
|
||||
}
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
|
||||
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
|
||||
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["BatchNormalization"], 0);
|
||||
ASSERT_EQ(op_to_count["MatMul"], 0);
|
||||
ASSERT_EQ(op_to_count["Gemm"], 1);
|
||||
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Gemm") {
|
||||
ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
|
||||
<< "fusion should produce the same output name as the MatMul node";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutputWithInBetweenNodes) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx";
|
||||
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
std::string expected_output_name;
|
||||
GraphViewer graphViewer(graph);
|
||||
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
|
||||
auto& node = *graph.GetNode(node_index);
|
||||
if (node.OpType() == "MatMul") {
|
||||
expected_output_name = node.OutputDefs()[0]->Name();
|
||||
} else if (node.OpType() == "BatchNormalization") {
|
||||
node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg("", nullptr));
|
||||
}
|
||||
}
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
|
||||
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
|
||||
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["BatchNormalization"], 0);
|
||||
ASSERT_EQ(op_to_count["MatMul"], 0);
|
||||
ASSERT_EQ(op_to_count["Gemm"], 1);
|
||||
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Gemm") {
|
||||
ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
|
||||
<< "fusion should produce the same output name as the MatMul node";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// should not fuse
|
||||
TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutputWithInBetweenNodes) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx";
|
||||
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
GraphViewer graphViewer(graph);
|
||||
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
|
||||
auto& node = *graph.GetNode(node_index);
|
||||
if (node.OpType() == "BatchNormalization") {
|
||||
// additional non-empty output to batchNormalization
|
||||
ONNX_NAMESPACE::TypeProto optional_output_tensor_type;
|
||||
optional_output_tensor_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TypeProto::kTensorType);
|
||||
auto& arg = graph.GetOrCreateNodeArg("bn_optional_output", &optional_output_tensor_type);
|
||||
node.MutableOutputDefs().push_back(&arg);
|
||||
}
|
||||
}
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
|
||||
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
|
||||
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["BatchNormalization"], 1);
|
||||
ASSERT_EQ(op_to_count["MatMul"], 1);
|
||||
ASSERT_EQ(op_to_count["Gemm"], 0);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, FuseMatmulBNDirectly) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-directly.onnx";
|
||||
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
std::string expected_output_name;
|
||||
GraphViewer graphViewer(graph);
|
||||
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
|
||||
auto& node = *graph.GetNode(node_index);
|
||||
if (node.OpType() == "BatchNormalization") {
|
||||
expected_output_name = node.OutputDefs()[0]->Name();
|
||||
}
|
||||
}
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
|
||||
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
|
||||
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["BatchNormalization"], 0);
|
||||
ASSERT_EQ(op_to_count["MatMul"], 0);
|
||||
ASSERT_EQ(op_to_count["Gemm"], 1);
|
||||
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Gemm") {
|
||||
ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
|
||||
<< "fusion should produce the same output name as the last node";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyReshape) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-reshape.onnx";
|
||||
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
std::string expected_output_name;
|
||||
GraphViewer graphViewer(graph);
|
||||
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
|
||||
auto& node = *graph.GetNode(node_index);
|
||||
if (node.OpType() == "MatMul") {
|
||||
expected_output_name = node.OutputDefs()[0]->Name();
|
||||
}
|
||||
}
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
|
||||
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
|
||||
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["BatchNormalization"], 0);
|
||||
ASSERT_EQ(op_to_count["MatMul"], 0);
|
||||
ASSERT_EQ(op_to_count["Gemm"], 1);
|
||||
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Gemm") {
|
||||
ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
|
||||
<< "fusion should produce the same output name as the MatMul node";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyTranspose) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-transpose.onnx";
|
||||
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
std::string expected_output_name;
|
||||
GraphViewer graphViewer(graph);
|
||||
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
|
||||
auto& node = *graph.GetNode(node_index);
|
||||
if (node.OpType() == "MatMul") {
|
||||
expected_output_name = node.OutputDefs()[0]->Name();
|
||||
}
|
||||
}
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
|
||||
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
|
||||
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["BatchNormalization"], 0);
|
||||
ASSERT_EQ(op_to_count["MatMul"], 0);
|
||||
ASSERT_EQ(op_to_count["Gemm"], 1);
|
||||
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Gemm") {
|
||||
ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name)
|
||||
<< "fusion should produce the same output name as the MatMul node";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, FuseMatmulBNWithoutBatchNormalization) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-transpose.onnx";
|
||||
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
GraphViewer graphViewer(graph);
|
||||
for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) {
|
||||
auto& node = *graph.GetNode(node_index);
|
||||
if (node.OpType() == "BatchNormalization") {
|
||||
graph_utils::RemoveNode(graph, node);
|
||||
}
|
||||
}
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
|
||||
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
|
||||
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["MatMul"], 1);
|
||||
}
|
||||
|
||||
// should not fuse
|
||||
TEST_F(GraphTransformationTests, FuseMatmulBNWithNonIgnorableNode) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-non-ignorable-node.onnx";
|
||||
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>("RuleTransformerL1");
|
||||
ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique<MatmulBNFusion>()));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1));
|
||||
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_EQ(op_to_count["BatchNormalization"], 1);
|
||||
ASSERT_EQ(op_to_count["MatMul"], 1);
|
||||
ASSERT_EQ(op_to_count["Gemm"], 0);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, DontFuseConvWithBNWithOptionalOutputs) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-no-bias.onnx";
|
||||
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-non-ignorable-node.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-non-ignorable-node.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue