use gemm to replace matmul + add (#234)

* matmul add fusion

* add shape check on Gemm input C

* walk around the issue with RemoveNode

* update the version support

* If MatMul has shape [K] * [K, N], update it to [1, K] * [K, N], so that it can work for Gemm

* Fuse Gemm+Activation into FusedGemm

* test

* revert the change which fuse the matmul with shape [K]*[K, N] to Gemm as shape [1, K]*[K, N], this may cause runtime failure, as the we can't change input data shape.

* revert the change which change the shape for Matmul from [K]*[K, N] to [1, K]*[K, N]. It enables fuse Matmul + Add to Gemm, but the issue is the data is not aware of this, so the data shape is still [K]*[K, N] and cause runtime issue.

* 1. Fix build issue for CUDA
2. Update Gemm so that we can fuse Matmul [K] * [K, N] + Add [1, N] into Gemm with shape [1,K] * [K, N] + [1, N]

* Fix build issue

* Fuse the activation node even it connects the output

* resolve the merge conflicts

* Add test model for Gemm+Activation fusion
This commit is contained in:
Hector Li 2019-01-22 15:21:55 -08:00 committed by GitHub
parent 8b55596dfe
commit 647cc2dced
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 475 additions and 26 deletions

View file

@ -10,6 +10,7 @@ namespace contrib {
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Ngram);
@ -38,6 +39,7 @@ void RegisterContribKernels(KernelRegistry& kernel_registry) {
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>());
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Ngram)>());

View file

@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "fused_gemm.h"
namespace onnxruntime {
namespace contrib {
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
FusedGemm,
1,
float,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
FusedGemm<float, float, float, float>);
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cpu/math/gemm.h"
namespace onnxruntime {
namespace contrib {
template <typename T_X,
typename T_W,
typename T_B,
typename T_Y>
class FusedGemm : public Gemm<T_X, T_W, T_B, T_Y> {
public:
FusedGemm(const OpKernelInfo& info) : Gemm<T_X, T_W, T_B, T_Y>(info) {
Gemm<T_X, T_W, T_B, T_Y>::activation_ = info.GetAttrOrDefault<std::string>("activation", "");
Gemm<T_X, T_W, T_B, T_Y>::leaky_relu_alpha_ = info.GetAttrOrDefault("leaky_relu_alpha", 0.01f);
}
Status Compute(OpKernelContext* context) const override {
return Gemm<T_X, T_W, T_B, T_Y>::Compute(context);
}
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -208,6 +208,96 @@ activation.)DOC")
ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, false, true);
});
ONNX_CONTRIB_OPERATOR_SCHEMA(FusedGemm)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc(R"DOC(
The FusedGemm operator schema is the same as Gemm besides it includes attributes
activation and leaky_relu_alpha.)DOC")
.Input(
0,
"A",
"Input tensor A. "
"The shape of A should be (M, K) if transA is 0, "
"or (K, M) if transA is non-zero.",
"T")
.Input(
1,
"B",
"Input tensor B. "
"The shape of B should be (K, N) if transB is 0, "
"or (N, K) if transB is non-zero.",
"T")
.Input(
2,
"C",
"Input tensor C. "
"The shape of C should be unidirectional broadcastable to (M, N).",
"T")
.Output(0, "Y", "Output tensor of shape (M, N).", "T")
.TypeConstraint(
"T",
{"tensor(float16)",
"tensor(float)",
"tensor(double)",
"tensor(uint32)",
"tensor(uint64)",
"tensor(int32)",
"tensor(int64)"},
"Constrain input and output types to float/int tensors.")
.Attr(
"transA",
"Whether A should be transposed",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr(
"transB",
"Whether B should be transposed",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr(
"alpha",
"Scalar multiplier for the product of input tensors A * B.",
AttributeProto::FLOAT,
1.0f)
.Attr(
"beta",
"Scalar multiplier for input tensor C.",
AttributeProto::FLOAT,
1.0f)
.Attr(
"activation",
"",
AttributeProto::STRING,
OPTIONAL)
.Attr(
"leaky_relu_alpha",
"",
AttributeProto::FLOAT,
OPTIONAL)
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (hasNInputShapes(ctx, 2)) {
auto transAAttr = ctx.getAttribute("transA");
bool transA =
transAAttr ? static_cast<int>(transAAttr->i()) != 0 : false;
auto transBAttr = ctx.getAttribute("transB");
bool transB =
transBAttr ? static_cast<int>(transBAttr->i()) != 0 : false;
auto& first_input_shape = getInputShape(ctx, 0);
auto& second_input_shape = getInputShape(ctx, 1);
if (first_input_shape.dim_size() != 2)
fail_shape_inference("First input does not have rank 2");
if (second_input_shape.dim_size() != 2)
fail_shape_inference("Second input does not have rank 2");
updateOutputShape(
ctx,
0,
{first_input_shape.dim(transA ? 1 : 0),
second_input_shape.dim(transB ? 0 : 1)});
}
});
ONNX_CONTRIB_OPERATOR_SCHEMA(ExpandDims)
.SetDomain(kMSDomain)
.SinceVersion(1)

View file

@ -0,0 +1,108 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/graph/initializer.h"
#include "core/graph/gemm_activation_fusion.h"
#include "core/graph/graph_utils.h"
#include <deque>
using namespace onnx;
using namespace ::onnxruntime::common;
namespace onnxruntime {
namespace {
bool IsFusableActivation(const Node& node) {
return utils::IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Relu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", 6);
}
void HandleActivationNodeEdges(Graph& g, const Node& act, Node& fused_gemm) {
Node::EdgeSet output_edges;
for (auto it = act.OutputEdgesBegin(); it != act.OutputEdgesEnd(); ++it) {
output_edges.insert(*it);
}
//remove output edge of activation
//connect fused_gemm node and nodes after activation nodes
for (auto& output_edge : output_edges) {
NodeIndex dst_node_index = output_edge.GetNode().Index();
int src_arg_index = output_edge.GetSrcArgIndex();
int dst_arg_index = output_edge.GetDstArgIndex();
g.RemoveEdge(act.Index(), dst_node_index, src_arg_index, dst_arg_index);
g.AddEdge(fused_gemm.Index(), dst_node_index, 0, dst_arg_index);
}
}
} // namespace
Status GemmActivationFusion::Apply(Graph& graph, bool& modified) const {
GraphViewer graph_viewer(graph);
const auto& order = graph_viewer.GetNodesInTopologicalOrder();
std::deque<onnxruntime::NodeIndex> removed_nodes;
for (auto index : order) {
auto node = graph.GetNode(index);
if (!(utils::IsSupportedOptypeVersionAndDomain(*node, "Gemm", 7) || utils::IsSupportedOptypeVersionAndDomain(*node, "Gemm", 9)) || node->GetOutputEdgesCount() != 1) {
continue;
}
const Node& next_node = *(node->OutputNodesBegin());
if (!IsFusableActivation(next_node)) {
continue;
}
Node* gemm_node = node;
const Node& act_node = next_node;
Node& fused_gemm = graph.AddNode(graph.GenerateNodeName("fused " + gemm_node->Name()), "FusedGemm",
"fused Gemm " + gemm_node->Name() + "with activation " + act_node.OpType(),
gemm_node->MutableInputDefs(),
graph.IsNodeOutputsInGraphOutputs(next_node) ? const_cast<Node&>(act_node).MutableOutputDefs() : gemm_node->MutableOutputDefs(),
&gemm_node->GetAttributes(),
"com.microsoft");
//Add a new attribute to specify the activation type
fused_gemm.AddAttribute("activation", act_node.OpType());
//Add optional attributes for activations
if (act_node.OpType() == "LeakyRelu") {
const NodeAttributes attrs = act_node.GetAttributes();
for (auto it = attrs.begin(); it != attrs.end(); ++it) {
fused_gemm.AddAttribute("leaky_relu_" + it->first, it->second);
}
}
if (!graph.IsNodeOutputsInGraphOutputs(next_node)) {
HandleActivationNodeEdges(graph, act_node, fused_gemm);
// Replace the input of the node following activation node
const NodeArg* act_output_def = act_node.OutputDefs()[0];
NodeArg* fused_gemm_output_def = fused_gemm.MutableOutputDefs()[0];
for (auto it = act_node.OutputNodesBegin(); it != act_node.OutputNodesEnd(); ++it) {
auto output_node = graph.GetNode((*it).Index());
if (!output_node) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT);
}
auto& input_defs = output_node->MutableInputDefs();
for (auto& def : input_defs) {
if (def == act_output_def) {
def = fused_gemm_output_def;
}
}
}
}
removed_nodes.push_front(gemm_node->Index());
removed_nodes.push_front(act_node.Index());
}
for (auto node : removed_nodes) {
graph.RemoveNode(node);
}
if (!removed_nodes.empty()) {
modified = true;
ORT_RETURN_IF_ERROR(graph.Resolve());
}
return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/graph/graph_transformer.h"
namespace onnxruntime {
class GemmActivationFusion : public onnxruntime::GraphTransformer {
public:
GemmActivationFusion() noexcept : onnxruntime::GraphTransformer("GemmActivationFusion", "Fusing Activation into Gemm") {}
Status Apply(onnxruntime::Graph& graph, bool& modified) const override;
};
} // namespace onnxruntime

View file

@ -0,0 +1,106 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/graph/initializer.h"
#include "core/graph/matmul_add_fusion.h"
#include "core/graph/graph_utils.h"
#include <deque>
using namespace onnx;
using namespace ::onnxruntime::common;
namespace onnxruntime {
Status MatMulAddFusion::Apply(Graph& graph, bool& modified) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
std::deque<onnxruntime::NodeIndex> removed_nodes;
for (auto node_index : node_topology_list) {
auto node = graph.GetNode(node_index);
if (nullptr == node ||
!(utils::IsSupportedOptypeVersionAndDomain(*node, "MatMul", 1) || utils::IsSupportedOptypeVersionAndDomain(*node, "MatMul", 9)) ||
node->GetOutputEdgesCount() != 1) {
continue;
}
auto next_node_itr = node->OutputNodesBegin();
if (next_node_itr == node->OutputNodesEnd()) {
continue;
}
const Node& next_node = (*next_node_itr);
if (!utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", 7)) {
continue;
}
Node* matmul_node = node;
Node& add_node = const_cast<Node&>(next_node);
std::vector<NodeArg> input_args, output_args;
auto matmul_input_defs = matmul_node->MutableInputDefs();
auto add_input_defs = add_node.MutableInputDefs();
// Gemm only support float, so the inputs of MatMul
auto matmul_type = matmul_input_defs[0]->Type();
auto add_type = add_input_defs[0]->Type();
if ((*matmul_type) != "tensor(float)" || (*add_type) != "tensor(float)") {
continue;
}
// Gemm only support Matrix, need to check the shape of MatMul and Add
auto matmul_a_shape = matmul_input_defs[0]->Shape();
auto matmul_b_shape = matmul_input_defs[1]->Shape();
if (nullptr == matmul_a_shape || nullptr == matmul_b_shape ) {
continue;
} else if (1 == matmul_a_shape->dim_size() && 2 == matmul_b_shape->dim_size()) {
// MatMul has shape [K] * [K, N], reset it to [1, K] * [K, N], so that it can work for Gemm
auto mutable_matmul_a_shape = const_cast<onnx::TensorShapeProto*>(matmul_a_shape);
auto dim_0 = mutable_matmul_a_shape->mutable_dim(0);
auto dim_1 = (const_cast<onnx::TensorShapeProto*>(matmul_a_shape))->add_dim();
(*dim_1) = (*dim_0);
dim_0->set_dim_value(1);
} if (2 != matmul_a_shape->dim_size() || 2 != matmul_b_shape->dim_size()) {
// Gemm only support Matrix
continue;
}
auto matmul_output_name = matmul_node->OutputDefs()[0]->Name();
auto gemm_input_defs = matmul_input_defs;
if (matmul_output_name == add_input_defs[0]->Name()) {
// matmul output as Add_A, should use Add_B as input C for gemm
// Gemm only support unidirectional broadcast on C
if (add_input_defs[1]->Shape()->dim_size() > 2) {
continue;
}
gemm_input_defs.push_back(add_input_defs[1]);
} else {
// matmul output as Add_B, should use Add_A as input C for gemm
// Gemm only support unidirectional broadcast on C
if (add_input_defs[0]->Shape()->dim_size() > 2) {
continue;
}
gemm_input_defs.push_back(add_input_defs[0]);
}
graph.AddNode(graph.GenerateNodeName("gemm"),
"Gemm",
"fused Matmul and Add " + add_node.OpType(),
gemm_input_defs,
add_node.MutableOutputDefs());
removed_nodes.push_front(matmul_node->Index());
removed_nodes.push_front(add_node.Index());
}
// Have to remove node in reversed order for now to walk around the issue in RemoveNode
for (auto it = removed_nodes.begin(); it != removed_nodes.end(); ++it) {
graph.RemoveNode(*it);
}
if (!removed_nodes.empty()) {
modified = true;
ORT_RETURN_IF_ERROR(graph.Resolve());
}
return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/graph/graph_transformer.h"
namespace onnxruntime {
class MatMulAddFusion : public onnxruntime::GraphTransformer {
public:
MatMulAddFusion() noexcept : onnxruntime::GraphTransformer("MatMulAddFusion", "Fusing MatMul and Add into Gemm") {}
Status Apply(onnxruntime::Graph& graph, bool& modified) const override;
};
} // namespace onnxruntime

View file

@ -15,7 +15,7 @@ template <typename T_X,
typename T_W,
typename T_B,
typename T_Y>
class Gemm final : public OpKernel {
class Gemm : public OpKernel {
public:
Gemm(const OpKernelInfo& info) : OpKernel(info) {
int64_t temp;
@ -42,6 +42,7 @@ class Gemm final : public OpKernel {
int64_t N = helper.N();
int64_t K = helper.K();
auto Y = context->Output(0, TensorShape({M, N}));
T_Y* y_data = Y->template MutableData<T_Y>();
//bias
// Todo: we might should move this part into math::gemm to let eigen
@ -101,9 +102,11 @@ class Gemm final : public OpKernel {
X->template Data<T_X>(),
W->template Data<T_W>(),
beta_,
Y->template MutableData<T_Y>(),
y_data,
&CPUMathUtil::Instance());
FuseActivation<T_Y>(activation_, y_data, M * N, leaky_relu_alpha_);
return Status::OK();
}
@ -112,6 +115,11 @@ class Gemm final : public OpKernel {
CBLAS_TRANSPOSE trans_B_;
float alpha_;
float beta_;
protected:
// For fused gemm + activation
std::string activation_;
float leaky_relu_alpha_;
};
} // namespace onnxruntime

View file

@ -10,15 +10,15 @@ class GemmHelper {
public:
GemmHelper(const TensorShape& left, bool trans_left, const TensorShape& right, bool trans_right, const TensorShape& bias) {
//dimension check
ORT_ENFORCE(left.NumDimensions() == 2);
ORT_ENFORCE(left.NumDimensions() == 2 || left.NumDimensions() == 1);
ORT_ENFORCE(right.NumDimensions() == 2);
if (trans_left) {
M_ = left[1];
K_ = left[0];
M_ = left.NumDimensions() == 2 ? left[1] : left[0];
K_ = left.NumDimensions() == 2 ? left[0] :1 ;
} else {
M_ = left[0];
K_ = left[1];
M_ = left.NumDimensions() == 2 ? left[0] : 1;
K_ = left.NumDimensions() == 2 ? left[1] : left[0];
}
int k_dim;

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "core/providers/cpu/nn/conv_impl.h"
#include "core/util/math_cpuonly.h"
namespace onnxruntime {
@ -144,7 +145,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
Ymatrix.rowwise() += Bvec.transpose();
}
fuse_activation(activation_, Ydata, Y_offset * group_, alpha_);
FuseActivation(activation_, Ydata, Y_offset * group_, alpha_);
Xdata += X_offset * group_;
Ydata += Y_offset * group_;

View file

@ -23,23 +23,6 @@
#include "core/mlas/inc/mlas.h"
namespace onnxruntime {
template <typename T>
void fuse_activation(const std::string& activation, T* y_data, size_t size, float alpha) {
EigenVectorArrayMap<T> y_vec(y_data, size);
if (activation.empty()) {
return;
} else if (activation == "Relu") {
y_vec = y_vec.cwiseMax(0);
} else if (activation == "Sigmoid") {
y_vec = (y_vec >= 0).select(1 / (1. + (-y_vec.abs()).exp()), 1 - 1 / (1. + (-y_vec.abs()).exp()));
} else if (activation == "Tanh") {
y_vec = y_vec.tanh();
} else if (activation == "LeakyRelu") {
y_vec = (y_vec >= 0).select(y_vec, (T)alpha * y_vec);
} else {
ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", activation);
}
}
template <typename T>
Status Conv<T>::Compute(OpKernelContext* context) const {
@ -155,7 +138,7 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
auto Bvec = ConstEigenVectorMap<T>(B->template Data<T>(), M);
Ymatrix.rowwise() += Bvec.transpose();
}
fuse_activation(activation_, Ydata, Y_offset * group_, alpha_);
FuseActivation(activation_, Ydata, Y_offset * group_, alpha_);
Xdata += X_offset * group_;
Ydata += Y_offset * group_;

View file

@ -76,4 +76,25 @@ class CPUMathUtil {
CPUMathUtil() = default;
};
template <typename T>
void FuseActivation(const std::string& activation, T* y_data, size_t size, float leaky_relu_alpha) {
if (activation.empty()) {
return;
}
EigenVectorArrayMap<T> y_vec(y_data, size);
if (activation == "Relu") {
y_vec = y_vec.cwiseMax(0);
} else if (activation == "Sigmoid") {
y_vec = (y_vec >= 0).select(1 / (1. + (-y_vec.abs()).exp()), 1 - 1 / (1. + (-y_vec.abs()).exp()));
} else if (activation == "Tanh") {
y_vec = y_vec.tanh();
} else if (activation == "LeakyRelu") {
y_vec = (y_vec >= 0).select(y_vec, (T)leaky_relu_alpha * y_vec);
} else {
ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", activation);
}
}
} // namespace onnxruntime

View file

@ -11,6 +11,8 @@
#include "core/graph/conv_mul_fusion.h"
#include "core/graph/conv_add_fusion.h"
#include "core/graph/conv_activation_fusion.h"
#include "core/graph/matmul_add_fusion.h"
#include "core/graph/gemm_activation_fusion.h"
#include "core/platform/env.h"
#include "test/capturing_sink.h"
@ -197,5 +199,60 @@ TEST(GraphTransformationTests, FuseConvAddMul3D) {
ASSERT_TRUE(st.IsOK()) << st;
}
TEST(GraphTransformationTests, MatMulAddFusion_two_input) {
string model_uri = MODEL_FOLDER + "matmul_add_fusion/2Input/model.onnx";
SessionOptions so;
so.session_logid = "GraphTransformationTests.LoadModelToTransform";
InferenceSession session_object{so, &DefaultLoggingManager()};
ASSERT_TRUE(session_object.Load(model_uri).IsOK());
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());
std::unique_ptr<MatMulAddFusion> matmul_add_fusion_transformer = std::make_unique<MatMulAddFusion>();
session_object.RegisterGraphTransformer(std::move(matmul_add_fusion_transformer));
ASSERT_TRUE(session_object.Initialize().IsOK());
}
TEST(GraphTransformationTests, MatMulAddFusion_three_input) {
string model_uri = MODEL_FOLDER + "matmul_add_fusion/3Input/model.onnx";
SessionOptions so;
so.session_logid = "GraphTransformationTests.LoadModelToTransform";
InferenceSession session_object{so, &DefaultLoggingManager()};
ASSERT_TRUE(session_object.Load(model_uri).IsOK());
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());
std::unique_ptr<MatMulAddFusion> matmul_add_fusion_transformer = std::make_unique<MatMulAddFusion>();
session_object.RegisterGraphTransformer(std::move(matmul_add_fusion_transformer));
ASSERT_TRUE(session_object.Initialize().IsOK());
}
TEST(GraphTransformationTests, Gemm_Relu_three_input) {
string model_uri = MODEL_FOLDER + "matmul_add_fusion/3Input/gemm_relu.onnx";
SessionOptions so;
so.session_logid = "GraphTransformationTests.LoadModelToTransform";
InferenceSession session_object{so, &DefaultLoggingManager()};
ASSERT_TRUE(session_object.Load(model_uri).IsOK());
std::shared_ptr<Model> p_model;
ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());
std::unique_ptr<GemmActivationFusion> gemm_activation_fusion_transformer = std::make_unique<GemmActivationFusion>();
session_object.RegisterGraphTransformer(std::move(gemm_activation_fusion_transformer));
ASSERT_TRUE(session_object.Initialize().IsOK());
}
} // namespace test
} // namespace onnxruntime