mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-03 23:49:44 +00:00
use gemm to replace matmul + add (#234)
* matmul add fusion * add shape check on Gemm input C * work around the issue with RemoveNode * update the version support * If MatMul has shape [K] * [K, N], update it to [1, K] * [K, N], so that it can work for Gemm * Fuse Gemm+Activation into FusedGemm * test * revert the change which fuses the MatMul with shape [K]*[K, N] to Gemm as shape [1, K]*[K, N]; this may cause runtime failure, as we can't change the input data shape. * revert the change which changes the shape for MatMul from [K]*[K, N] to [1, K]*[K, N]. It enables fusing MatMul + Add to Gemm, but the issue is the data is not aware of this, so the data shape is still [K]*[K, N] and causes a runtime issue. * 1. Fix build issue for CUDA 2. Update Gemm so that we can fuse MatMul [K] * [K, N] + Add [1, N] into Gemm with shape [1,K] * [K, N] + [1, N] * Fix build issue * Fuse the activation node even if it connects to the output * resolve the merge conflicts * Add test model for Gemm+Activation fusion
This commit is contained in:
parent
8b55596dfe
commit
647cc2dced
24 changed files with 475 additions and 26 deletions
|
|
@ -10,6 +10,7 @@ namespace contrib {
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Ngram);
|
||||
|
|
@ -38,6 +39,7 @@ void RegisterContribKernels(KernelRegistry& kernel_registry) {
|
|||
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>());
|
||||
kernel_registry.Register(BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Ngram)>());
|
||||
|
|
|
|||
15
onnxruntime/contrib_ops/cpu/fused_gemm.cc
Normal file
15
onnxruntime/contrib_ops/cpu/fused_gemm.cc
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "fused_gemm.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
// Defines the float-typed FusedGemm kernel (version 1) and binds it to the
// FusedGemm<float, float, float, float> implementation. The type constraint
// "T" is restricted to float tensors.
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
|
||||
FusedGemm,
|
||||
1,
|
||||
float,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
FusedGemm<float, float, float, float>);
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
26
onnxruntime/contrib_ops/cpu/fused_gemm.h
Normal file
26
onnxruntime/contrib_ops/cpu/fused_gemm.h
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cpu/math/gemm.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
template <typename T_X,
|
||||
typename T_W,
|
||||
typename T_B,
|
||||
typename T_Y>
|
||||
class FusedGemm : public Gemm<T_X, T_W, T_B, T_Y> {
|
||||
public:
|
||||
FusedGemm(const OpKernelInfo& info) : Gemm<T_X, T_W, T_B, T_Y>(info) {
|
||||
Gemm<T_X, T_W, T_B, T_Y>::activation_ = info.GetAttrOrDefault<std::string>("activation", "");
|
||||
Gemm<T_X, T_W, T_B, T_Y>::leaky_relu_alpha_ = info.GetAttrOrDefault("leaky_relu_alpha", 0.01f);
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
return Gemm<T_X, T_W, T_B, T_Y>::Compute(context);
|
||||
}
|
||||
};
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -208,6 +208,96 @@ activation.)DOC")
|
|||
ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, false, true);
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(FusedGemm)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc(R"DOC(
|
||||
The FusedGemm operator schema is the same as Gemm besides it includes attributes
|
||||
activation and leaky_relu_alpha.)DOC")
|
||||
.Input(
|
||||
0,
|
||||
"A",
|
||||
"Input tensor A. "
|
||||
"The shape of A should be (M, K) if transA is 0, "
|
||||
"or (K, M) if transA is non-zero.",
|
||||
"T")
|
||||
.Input(
|
||||
1,
|
||||
"B",
|
||||
"Input tensor B. "
|
||||
"The shape of B should be (K, N) if transB is 0, "
|
||||
"or (N, K) if transB is non-zero.",
|
||||
"T")
|
||||
.Input(
|
||||
2,
|
||||
"C",
|
||||
"Input tensor C. "
|
||||
"The shape of C should be unidirectional broadcastable to (M, N).",
|
||||
"T")
|
||||
.Output(0, "Y", "Output tensor of shape (M, N).", "T")
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(float16)",
|
||||
"tensor(float)",
|
||||
"tensor(double)",
|
||||
"tensor(uint32)",
|
||||
"tensor(uint64)",
|
||||
"tensor(int32)",
|
||||
"tensor(int64)"},
|
||||
"Constrain input and output types to float/int tensors.")
|
||||
.Attr(
|
||||
"transA",
|
||||
"Whether A should be transposed",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(0))
|
||||
.Attr(
|
||||
"transB",
|
||||
"Whether B should be transposed",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(0))
|
||||
.Attr(
|
||||
"alpha",
|
||||
"Scalar multiplier for the product of input tensors A * B.",
|
||||
AttributeProto::FLOAT,
|
||||
1.0f)
|
||||
.Attr(
|
||||
"beta",
|
||||
"Scalar multiplier for input tensor C.",
|
||||
AttributeProto::FLOAT,
|
||||
1.0f)
|
||||
.Attr(
|
||||
"activation",
|
||||
"",
|
||||
AttributeProto::STRING,
|
||||
OPTIONAL)
|
||||
.Attr(
|
||||
"leaky_relu_alpha",
|
||||
"",
|
||||
AttributeProto::FLOAT,
|
||||
OPTIONAL)
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
if (hasNInputShapes(ctx, 2)) {
|
||||
auto transAAttr = ctx.getAttribute("transA");
|
||||
bool transA =
|
||||
transAAttr ? static_cast<int>(transAAttr->i()) != 0 : false;
|
||||
auto transBAttr = ctx.getAttribute("transB");
|
||||
bool transB =
|
||||
transBAttr ? static_cast<int>(transBAttr->i()) != 0 : false;
|
||||
auto& first_input_shape = getInputShape(ctx, 0);
|
||||
auto& second_input_shape = getInputShape(ctx, 1);
|
||||
if (first_input_shape.dim_size() != 2)
|
||||
fail_shape_inference("First input does not have rank 2");
|
||||
if (second_input_shape.dim_size() != 2)
|
||||
fail_shape_inference("Second input does not have rank 2");
|
||||
updateOutputShape(
|
||||
ctx,
|
||||
0,
|
||||
{first_input_shape.dim(transA ? 1 : 0),
|
||||
second_input_shape.dim(transB ? 0 : 1)});
|
||||
}
|
||||
});
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(ExpandDims)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
|
|
|
|||
108
onnxruntime/core/graph/gemm_activation_fusion.cc
Normal file
108
onnxruntime/core/graph/gemm_activation_fusion.cc
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/graph/initializer.h"
|
||||
#include "core/graph/gemm_activation_fusion.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include <deque>
|
||||
|
||||
using namespace onnx;
|
||||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
namespace {
|
||||
bool IsFusableActivation(const Node& node) {
|
||||
return utils::IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Relu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", 6);
|
||||
}
|
||||
|
||||
// Moves every output edge of the activation node `act` onto `fused_gemm`:
// each edge is removed from `act` and re-added with `fused_gemm` output 0 as
// its source, keeping the original destination node and input slot.
void HandleActivationNodeEdges(Graph& g, const Node& act, Node& fused_gemm) {
  // Snapshot the edges first; the graph is mutated while we walk them.
  Node::EdgeSet edges_to_move(act.OutputEdgesBegin(), act.OutputEdgesEnd());

  for (auto& edge : edges_to_move) {
    const NodeIndex consumer_index = edge.GetNode().Index();
    const int act_output_slot = edge.GetSrcArgIndex();
    const int consumer_input_slot = edge.GetDstArgIndex();

    // Detach the consumer from the activation ...
    g.RemoveEdge(act.Index(), consumer_index, act_output_slot, consumer_input_slot);
    // ... and attach it to the single output of the fused node.
    g.AddEdge(fused_gemm.Index(), consumer_index, 0, consumer_input_slot);
  }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Fuses each Gemm node whose single consumer is a fusable activation
// (see IsFusableActivation) into one com.microsoft FusedGemm node.
//
// The fused node takes the Gemm inputs and attributes, records the activation
// type in the "activation" attribute and copies LeakyRelu attributes with a
// "leaky_relu_" prefix. Its outputs are the activation's outputs when those
// are graph outputs, otherwise the Gemm's outputs, in which case the
// activation's consumers are rewired to the fused node.
//
// `modified` is set to true (and the graph re-resolved) only when at least
// one fusion happened. Returns the Resolve() error, INVALID_ARGUMENT when a
// consumer node cannot be found, or OK.
Status GemmActivationFusion::Apply(Graph& graph, bool& modified) const {
  GraphViewer graph_viewer(graph);
  const auto& order = graph_viewer.GetNodesInTopologicalOrder();

  std::deque<onnxruntime::NodeIndex> removed_nodes;
  for (auto index : order) {
    auto node = graph.GetNode(index);
    // GetNode() may hand back nullptr; guard before dereferencing, as
    // MatMulAddFusion does.
    if (nullptr == node ||
        !(utils::IsSupportedOptypeVersionAndDomain(*node, "Gemm", 7) ||
          utils::IsSupportedOptypeVersionAndDomain(*node, "Gemm", 9)) ||
        node->GetOutputEdgesCount() != 1) {
      continue;
    }

    const Node& next_node = *(node->OutputNodesBegin());
    if (!IsFusableActivation(next_node)) {
      continue;
    }

    Node* gemm_node = node;
    const Node& act_node = next_node;

    // When the activation produces a graph output the fused node must keep
    // producing that very NodeArg; otherwise it reuses the Gemm output.
    const bool act_is_graph_output = graph.IsNodeOutputsInGraphOutputs(act_node);

    Node& fused_gemm = graph.AddNode(graph.GenerateNodeName("fused " + gemm_node->Name()), "FusedGemm",
                                     "fused Gemm " + gemm_node->Name() + " with activation " + act_node.OpType(),
                                     gemm_node->MutableInputDefs(),
                                     act_is_graph_output ? const_cast<Node&>(act_node).MutableOutputDefs()
                                                         : gemm_node->MutableOutputDefs(),
                                     &gemm_node->GetAttributes(),
                                     "com.microsoft");

    // Add a new attribute to specify the activation type
    fused_gemm.AddAttribute("activation", act_node.OpType());

    // Add optional attributes for activations
    if (act_node.OpType() == "LeakyRelu") {
      const NodeAttributes& attrs = act_node.GetAttributes();
      for (const auto& attr : attrs) {
        fused_gemm.AddAttribute("leaky_relu_" + attr.first, attr.second);
      }
    }

    if (!act_is_graph_output) {
      HandleActivationNodeEdges(graph, act_node, fused_gemm);

      // Replace the input of the node following activation node
      // NOTE(review): HandleActivationNodeEdges has already removed the
      // activation's output edges, so this walk over its output nodes may
      // visit nothing - confirm whether AddEdge updates the consumer's
      // input defs or whether this loop should run before the rewiring.
      const NodeArg* act_output_def = act_node.OutputDefs()[0];
      NodeArg* fused_gemm_output_def = fused_gemm.MutableOutputDefs()[0];
      for (auto it = act_node.OutputNodesBegin(); it != act_node.OutputNodesEnd(); ++it) {
        auto output_node = graph.GetNode((*it).Index());
        if (!output_node) {
          return Status(ONNXRUNTIME, INVALID_ARGUMENT);
        }

        auto& input_defs = output_node->MutableInputDefs();
        for (auto& def : input_defs) {
          if (def == act_output_def) {
            def = fused_gemm_output_def;
          }
        }
      }
    }

    // push_front so that nodes are later removed in reverse order.
    removed_nodes.push_front(gemm_node->Index());
    removed_nodes.push_front(act_node.Index());
  }

  for (auto node_index : removed_nodes) {
    graph.RemoveNode(node_index);
  }

  if (!removed_nodes.empty()) {
    modified = true;
    ORT_RETURN_IF_ERROR(graph.Resolve());
  }
  return Status::OK();
}
|
||||
} // namespace onnxruntime
|
||||
16
onnxruntime/core/graph/gemm_activation_fusion.h
Normal file
16
onnxruntime/core/graph/gemm_activation_fusion.h
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/graph/graph_transformer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
class GemmActivationFusion : public onnxruntime::GraphTransformer {
|
||||
public:
|
||||
GemmActivationFusion() noexcept : onnxruntime::GraphTransformer("GemmActivationFusion", "Fusing Activation into Gemm") {}
|
||||
Status Apply(onnxruntime::Graph& graph, bool& modified) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
106
onnxruntime/core/graph/matmul_add_fusion.cc
Normal file
106
onnxruntime/core/graph/matmul_add_fusion.cc
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/graph/initializer.h"
|
||||
#include "core/graph/matmul_add_fusion.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include <deque>
|
||||
|
||||
using namespace onnx;
|
||||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
// Replaces each MatMul node whose single consumer is an Add (version 7) with
// one Gemm node.
//
// A pair is fused only when:
//   * the first MatMul input and the first Add input are both tensor(float),
//   * MatMul's inputs have known shapes of rank [2, 2] or [1, 2],
//   * the Add operand that is not the MatMul output has a known shape of
//     rank <= 2, so it can serve as Gemm input C.
// For a rank-1 A of shape [K] the recorded shape is rewritten to [1, K].
//
// `modified` is set to true (and the graph re-resolved) only when at least
// one fusion happened. Returns the Resolve() error or OK.
Status MatMulAddFusion::Apply(Graph& graph, bool& modified) const {
  GraphViewer graph_viewer(graph);
  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
  std::deque<onnxruntime::NodeIndex> removed_nodes;

  for (auto node_index : node_topology_list) {
    auto node = graph.GetNode(node_index);
    if (nullptr == node ||
        !(utils::IsSupportedOptypeVersionAndDomain(*node, "MatMul", 1) ||
          utils::IsSupportedOptypeVersionAndDomain(*node, "MatMul", 9)) ||
        node->GetOutputEdgesCount() != 1) {
      continue;
    }

    auto next_node_itr = node->OutputNodesBegin();
    if (next_node_itr == node->OutputNodesEnd()) {
      continue;
    }

    const Node& next_node = (*next_node_itr);
    if (!utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", 7)) {
      continue;
    }

    Node* matmul_node = node;
    Node& add_node = const_cast<Node&>(next_node);
    auto matmul_input_defs = matmul_node->MutableInputDefs();
    auto add_input_defs = add_node.MutableInputDefs();

    // Gemm only supports float, so both MatMul and Add must operate on float
    // tensors. Missing type information means the pair is left untouched.
    auto matmul_type = matmul_input_defs[0]->Type();
    auto add_type = add_input_defs[0]->Type();
    if (nullptr == matmul_type || nullptr == add_type ||
        (*matmul_type) != "tensor(float)" || (*add_type) != "tensor(float)") {
      continue;
    }

    // Gemm only supports matrices, need to check the shape of MatMul and Add
    auto matmul_a_shape = matmul_input_defs[0]->Shape();
    auto matmul_b_shape = matmul_input_defs[1]->Shape();
    if (nullptr == matmul_a_shape || nullptr == matmul_b_shape) {
      continue;
    }

    // [K] * [K, N] is accepted too; A is promoted to [1, K] further down.
    const bool a_is_vector = 1 == matmul_a_shape->dim_size() && 2 == matmul_b_shape->dim_size();
    if (!a_is_vector && (2 != matmul_a_shape->dim_size() || 2 != matmul_b_shape->dim_size())) {
      // Gemm only supports matrices
      continue;
    }

    // The Add operand that is not the MatMul output becomes Gemm input C.
    const auto& matmul_output_name = matmul_node->OutputDefs()[0]->Name();
    NodeArg* gemm_c_def = (matmul_output_name == add_input_defs[0]->Name()) ? add_input_defs[1]
                                                                             : add_input_defs[0];

    // Gemm only supports unidirectional broadcast on C. Without shape
    // information this cannot be verified, so skip instead of dereferencing
    // a null shape.
    auto gemm_c_shape = gemm_c_def->Shape();
    if (nullptr == gemm_c_shape || gemm_c_shape->dim_size() > 2) {
      continue;
    }

    // All checks passed: only now touch the graph, so an abandoned fusion
    // never leaves a rewritten shape behind.
    if (a_is_vector) {
      // MatMul has shape [K] * [K, N], reset it to [1, K] * [K, N], so that it can work for Gemm
      auto mutable_matmul_a_shape = const_cast<onnx::TensorShapeProto*>(matmul_a_shape);
      auto dim_1 = mutable_matmul_a_shape->add_dim();
      auto dim_0 = mutable_matmul_a_shape->mutable_dim(0);
      (*dim_1) = (*dim_0);
      dim_0->set_dim_value(1);
    }

    auto gemm_input_defs = matmul_input_defs;
    gemm_input_defs.push_back(gemm_c_def);

    graph.AddNode(graph.GenerateNodeName("gemm"),
                  "Gemm",
                  "fused Matmul and Add " + add_node.OpType(),
                  gemm_input_defs,
                  add_node.MutableOutputDefs());

    // push_front so that nodes are later removed in reverse order.
    removed_nodes.push_front(matmul_node->Index());
    removed_nodes.push_front(add_node.Index());
  }

  // Have to remove nodes in reversed order for now to work around the issue in RemoveNode
  for (auto removed_index : removed_nodes) {
    graph.RemoveNode(removed_index);
  }

  if (!removed_nodes.empty()) {
    modified = true;
    ORT_RETURN_IF_ERROR(graph.Resolve());
  }

  return Status::OK();
}
|
||||
} // namespace onnxruntime
|
||||
16
onnxruntime/core/graph/matmul_add_fusion.h
Normal file
16
onnxruntime/core/graph/matmul_add_fusion.h
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/graph/graph_transformer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
class MatMulAddFusion : public onnxruntime::GraphTransformer {
|
||||
public:
|
||||
MatMulAddFusion() noexcept : onnxruntime::GraphTransformer("MatMulAddFusion", "Fusing MatMul and Add into Gemm") {}
|
||||
Status Apply(onnxruntime::Graph& graph, bool& modified) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -15,7 +15,7 @@ template <typename T_X,
|
|||
typename T_W,
|
||||
typename T_B,
|
||||
typename T_Y>
|
||||
class Gemm final : public OpKernel {
|
||||
class Gemm : public OpKernel {
|
||||
public:
|
||||
Gemm(const OpKernelInfo& info) : OpKernel(info) {
|
||||
int64_t temp;
|
||||
|
|
@ -42,6 +42,7 @@ class Gemm final : public OpKernel {
|
|||
int64_t N = helper.N();
|
||||
int64_t K = helper.K();
|
||||
auto Y = context->Output(0, TensorShape({M, N}));
|
||||
T_Y* y_data = Y->template MutableData<T_Y>();
|
||||
|
||||
//bias
|
||||
// Todo: we might should move this part into math::gemm to let eigen
|
||||
|
|
@ -101,9 +102,11 @@ class Gemm final : public OpKernel {
|
|||
X->template Data<T_X>(),
|
||||
W->template Data<T_W>(),
|
||||
beta_,
|
||||
Y->template MutableData<T_Y>(),
|
||||
y_data,
|
||||
&CPUMathUtil::Instance());
|
||||
|
||||
FuseActivation<T_Y>(activation_, y_data, M * N, leaky_relu_alpha_);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -112,6 +115,11 @@ class Gemm final : public OpKernel {
|
|||
CBLAS_TRANSPOSE trans_B_;
|
||||
float alpha_;
|
||||
float beta_;
|
||||
|
||||
protected:
|
||||
// For fused gemm + activation
|
||||
std::string activation_;
|
||||
float leaky_relu_alpha_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -10,15 +10,15 @@ class GemmHelper {
|
|||
public:
|
||||
GemmHelper(const TensorShape& left, bool trans_left, const TensorShape& right, bool trans_right, const TensorShape& bias) {
|
||||
//dimension check
|
||||
ORT_ENFORCE(left.NumDimensions() == 2);
|
||||
ORT_ENFORCE(left.NumDimensions() == 2 || left.NumDimensions() == 1);
|
||||
ORT_ENFORCE(right.NumDimensions() == 2);
|
||||
|
||||
if (trans_left) {
|
||||
M_ = left[1];
|
||||
K_ = left[0];
|
||||
M_ = left.NumDimensions() == 2 ? left[1] : left[0];
|
||||
K_ = left.NumDimensions() == 2 ? left[0] :1 ;
|
||||
} else {
|
||||
M_ = left[0];
|
||||
K_ = left[1];
|
||||
M_ = left.NumDimensions() == 2 ? left[0] : 1;
|
||||
K_ = left.NumDimensions() == 2 ? left[1] : left[0];
|
||||
}
|
||||
|
||||
int k_dim;
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cpu/nn/conv_impl.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
|
|
@ -144,7 +145,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
Ymatrix.rowwise() += Bvec.transpose();
|
||||
}
|
||||
|
||||
fuse_activation(activation_, Ydata, Y_offset * group_, alpha_);
|
||||
FuseActivation(activation_, Ydata, Y_offset * group_, alpha_);
|
||||
|
||||
Xdata += X_offset * group_;
|
||||
Ydata += Y_offset * group_;
|
||||
|
|
|
|||
|
|
@ -23,23 +23,6 @@
|
|||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
template <typename T>
|
||||
void fuse_activation(const std::string& activation, T* y_data, size_t size, float alpha) {
|
||||
EigenVectorArrayMap<T> y_vec(y_data, size);
|
||||
if (activation.empty()) {
|
||||
return;
|
||||
} else if (activation == "Relu") {
|
||||
y_vec = y_vec.cwiseMax(0);
|
||||
} else if (activation == "Sigmoid") {
|
||||
y_vec = (y_vec >= 0).select(1 / (1. + (-y_vec.abs()).exp()), 1 - 1 / (1. + (-y_vec.abs()).exp()));
|
||||
} else if (activation == "Tanh") {
|
||||
y_vec = y_vec.tanh();
|
||||
} else if (activation == "LeakyRelu") {
|
||||
y_vec = (y_vec >= 0).select(y_vec, (T)alpha * y_vec);
|
||||
} else {
|
||||
ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", activation);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Conv<T>::Compute(OpKernelContext* context) const {
|
||||
|
|
@ -155,7 +138,7 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
|
|||
auto Bvec = ConstEigenVectorMap<T>(B->template Data<T>(), M);
|
||||
Ymatrix.rowwise() += Bvec.transpose();
|
||||
}
|
||||
fuse_activation(activation_, Ydata, Y_offset * group_, alpha_);
|
||||
FuseActivation(activation_, Ydata, Y_offset * group_, alpha_);
|
||||
|
||||
Xdata += X_offset * group_;
|
||||
Ydata += Y_offset * group_;
|
||||
|
|
|
|||
|
|
@ -76,4 +76,25 @@ class CPUMathUtil {
|
|||
CPUMathUtil() = default;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void FuseActivation(const std::string& activation, T* y_data, size_t size, float leaky_relu_alpha) {
|
||||
if (activation.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
EigenVectorArrayMap<T> y_vec(y_data, size);
|
||||
|
||||
if (activation == "Relu") {
|
||||
y_vec = y_vec.cwiseMax(0);
|
||||
} else if (activation == "Sigmoid") {
|
||||
y_vec = (y_vec >= 0).select(1 / (1. + (-y_vec.abs()).exp()), 1 - 1 / (1. + (-y_vec.abs()).exp()));
|
||||
} else if (activation == "Tanh") {
|
||||
y_vec = y_vec.tanh();
|
||||
} else if (activation == "LeakyRelu") {
|
||||
y_vec = (y_vec >= 0).select(y_vec, (T)leaky_relu_alpha * y_vec);
|
||||
} else {
|
||||
ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", activation);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@
|
|||
#include "core/graph/conv_mul_fusion.h"
|
||||
#include "core/graph/conv_add_fusion.h"
|
||||
#include "core/graph/conv_activation_fusion.h"
|
||||
#include "core/graph/matmul_add_fusion.h"
|
||||
#include "core/graph/gemm_activation_fusion.h"
|
||||
#include "core/platform/env.h"
|
||||
|
||||
#include "test/capturing_sink.h"
|
||||
|
|
@ -197,5 +199,60 @@ TEST(GraphTransformationTests, FuseConvAddMul3D) {
|
|||
ASSERT_TRUE(st.IsOK()) << st;
|
||||
}
|
||||
|
||||
// Loads the two-input MatMul+Add model and checks that a session with the
// MatMulAddFusion transformer registered initializes successfully.
TEST(GraphTransformationTests, MatMulAddFusion_two_input) {
  const string model_uri = MODEL_FOLDER + "matmul_add_fusion/2Input/model.onnx";

  SessionOptions session_options;
  session_options.session_logid = "GraphTransformationTests.LoadModelToTransform";
  InferenceSession session_object{session_options, &DefaultLoggingManager()};
  ASSERT_TRUE(session_object.Load(model_uri).IsOK());

  // The model must also load on its own, outside of the session.
  std::shared_ptr<Model> p_model;
  ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());

  session_object.RegisterGraphTransformer(std::make_unique<MatMulAddFusion>());

  ASSERT_TRUE(session_object.Initialize().IsOK());
}
|
||||
|
||||
// Loads the three-input MatMul+Add model and checks that a session with the
// MatMulAddFusion transformer registered initializes successfully.
TEST(GraphTransformationTests, MatMulAddFusion_three_input) {
  const string model_uri = MODEL_FOLDER + "matmul_add_fusion/3Input/model.onnx";

  SessionOptions session_options;
  session_options.session_logid = "GraphTransformationTests.LoadModelToTransform";
  InferenceSession session_object{session_options, &DefaultLoggingManager()};
  ASSERT_TRUE(session_object.Load(model_uri).IsOK());

  // The model must also load on its own, outside of the session.
  std::shared_ptr<Model> p_model;
  ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());

  session_object.RegisterGraphTransformer(std::make_unique<MatMulAddFusion>());

  ASSERT_TRUE(session_object.Initialize().IsOK());
}
|
||||
|
||||
// Loads the Gemm+Relu model and checks that a session with the
// GemmActivationFusion transformer registered initializes successfully.
TEST(GraphTransformationTests, Gemm_Relu_three_input) {
  const string model_uri = MODEL_FOLDER + "matmul_add_fusion/3Input/gemm_relu.onnx";

  SessionOptions session_options;
  session_options.session_logid = "GraphTransformationTests.LoadModelToTransform";
  InferenceSession session_object{session_options, &DefaultLoggingManager()};
  ASSERT_TRUE(session_object.Load(model_uri).IsOK());

  // The model must also load on its own, outside of the session.
  std::shared_ptr<Model> p_model;
  ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());

  session_object.RegisterGraphTransformer(std::make_unique<GemmActivationFusion>());

  ASSERT_TRUE(session_object.Initialize().IsOK());
}
|
||||
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/2Input/model.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/2Input/model.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/2Input/test_data_0/input_0.pb
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/2Input/test_data_0/input_0.pb
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/2Input/test_data_0/input_1.pb
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/2Input/test_data_0/input_1.pb
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/2Input/test_data_0/output_0.pb
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/2Input/test_data_0/output_0.pb
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/gemm_relu.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/gemm_relu.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/model.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/model.onnx
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/input_0.pb
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/input_0.pb
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/input_1.pb
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/input_1.pb
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/input_2.pb
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/input_2.pb
vendored
Normal file
Binary file not shown.
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/output_0.pb
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/matmul_add_fusion/3Input/test_data_0/output_0.pb
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue