mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
fuse Conv+Add+activation for CPU from different op-branch (#10987)
* Fuse op conv Add and activation from two branch * simplify code Co-authored-by: Jicheng Wen <jicwen@microsoft.com>
This commit is contained in:
parent
79e4ed8064
commit
11a4ca741d
11 changed files with 457 additions and 32 deletions
|
|
@ -739,6 +739,7 @@ struct MLAS_CONV_PARAMETERS {
|
|||
size_t InputSize;
|
||||
size_t OutputSize;
|
||||
size_t K;
|
||||
float Beta;
|
||||
MLAS_CONV_ALGORITHM Algorithm;
|
||||
ptrdiff_t ThreadCount;
|
||||
union {
|
||||
|
|
@ -752,25 +753,23 @@ struct MLAS_CONV_PARAMETERS {
|
|||
} u;
|
||||
};
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
MlasConvPrepare(
|
||||
MLAS_CONV_PARAMETERS* Parameters,
|
||||
size_t Dimensions,
|
||||
size_t BatchCount,
|
||||
size_t GroupCount,
|
||||
size_t InputChannels,
|
||||
const int64_t* InputShape,
|
||||
const int64_t* KernelShape,
|
||||
const int64_t* DilationShape,
|
||||
const int64_t* Padding,
|
||||
const int64_t* StrideShape,
|
||||
const int64_t* OutputShape,
|
||||
size_t FilterCount,
|
||||
const MLAS_ACTIVATION* Activation,
|
||||
size_t* WorkingBufferSize,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
);
|
||||
void MLASCALL
|
||||
MlasConvPrepare(MLAS_CONV_PARAMETERS* Parameters,
|
||||
size_t Dimensions,
|
||||
size_t BatchCount,
|
||||
size_t GroupCount,
|
||||
size_t InputChannels,
|
||||
const int64_t* InputShape,
|
||||
const int64_t* KernelShape,
|
||||
const int64_t* DilationShape,
|
||||
const int64_t* Padding,
|
||||
const int64_t* StrideShape,
|
||||
const int64_t* OutputShape,
|
||||
size_t FilterCount,
|
||||
const MLAS_ACTIVATION* Activation,
|
||||
size_t* WorkingBufferSize,
|
||||
float Beta,
|
||||
MLAS_THREADPOOL* ThreadPool);
|
||||
|
||||
void
|
||||
MLASCALL
|
||||
|
|
|
|||
|
|
@ -571,7 +571,7 @@ Return Value:
|
|||
//
|
||||
|
||||
size_t CountK;
|
||||
float beta = 0.0f;
|
||||
float beta = Parameters->Beta;
|
||||
float* SegmentOutput = Output + SegmentStartN + n;
|
||||
|
||||
for (size_t k = 0; k < K; k += CountK) {
|
||||
|
|
@ -934,9 +934,9 @@ Return Value:
|
|||
// Invoke the threaded GEMM directly with the input tensor.
|
||||
//
|
||||
|
||||
MlasGemm(CblasNoTrans, Parameters->u.GemmDirect.TransB, FilterCount,
|
||||
OutputSize, K, 1.0f, filter, K, Input, Parameters->u.GemmDirect.ldb, 0.0f,
|
||||
Output, OutputSize, ThreadPool);
|
||||
MlasGemm(CblasNoTrans, Parameters->u.GemmDirect.TransB, FilterCount, OutputSize,
|
||||
K, 1.0f, filter, K, Input, Parameters->u.GemmDirect.ldb,
|
||||
Parameters->Beta, Output, OutputSize, ThreadPool);
|
||||
|
||||
//
|
||||
// Apply the activation with optional bias.
|
||||
|
|
@ -962,7 +962,8 @@ Return Value:
|
|||
}
|
||||
|
||||
MlasGemm(CblasNoTrans, CblasNoTrans, FilterCount, OutputSize, K, 1.0f, filter,
|
||||
K, WorkingBuffer, OutputSize, 0.0f, Output, OutputSize, ThreadPool);
|
||||
K, WorkingBuffer, OutputSize, Parameters->Beta, Output, OutputSize,
|
||||
ThreadPool);
|
||||
|
||||
//
|
||||
// Apply the activation with optional bias.
|
||||
|
|
@ -1038,6 +1039,7 @@ MlasConvPrepare(
|
|||
size_t FilterCount,
|
||||
const MLAS_ACTIVATION* Activation,
|
||||
size_t* WorkingBufferSize,
|
||||
float Beta,
|
||||
MLAS_THREADPOOL* ThreadPool
|
||||
)
|
||||
/*++
|
||||
|
|
@ -1100,6 +1102,7 @@ Return Value:
|
|||
Parameters->GroupCount = GroupCount;
|
||||
Parameters->InputChannels = InputChannels;
|
||||
Parameters->FilterCount = FilterCount;
|
||||
Parameters->Beta = Beta;
|
||||
|
||||
size_t InputSize = 1;
|
||||
size_t OutputSize = 1;
|
||||
|
|
|
|||
285
onnxruntime/core/optimizer/conv_add_act_fusion.cc
Normal file
285
onnxruntime/core/optimizer/conv_add_act_fusion.cc
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <deque>
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/conv_add_act_fusion.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/graph/node_attr_utils.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
namespace {
|
||||
|
||||
namespace selectors {
|
||||
bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) {
|
||||
if (!node_arg.Exists()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto* type_proto = node_arg.TypeAsProto();
|
||||
if (!type_proto) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t actual_data_type;
|
||||
if (!utils::TryGetElementDataType(*type_proto, actual_data_type)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return data_type == actual_data_type;
|
||||
}
|
||||
|
||||
const Node* GetLoneConsumerNode(const GraphViewer& graph_viewer, const Node& node) {
|
||||
if (!optimizer_utils::CheckOutputEdges(graph_viewer.GetGraph(), node, 1)) {
|
||||
return nullptr;
|
||||
}
|
||||
return &*node.OutputNodesBegin();
|
||||
}
|
||||
|
||||
class ConvAddActivation : public NodeSelector {
|
||||
public:
|
||||
ConvAddActivation() = default;
|
||||
|
||||
std::optional<NodesToOptimizeIndices> Select(const GraphViewer& graph_viewer, const Node& node) const override {
|
||||
const std::string_view node_ep = node.GetExecutionProviderType();
|
||||
if (node_ep != kCpuExecutionProvider || !HasElementDataType(*node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) {
|
||||
return std::nullopt;
|
||||
}
|
||||
// we can't assign `conv_node` as the producer-node, even it is, because we have to make sure
|
||||
// 1. Its type is 'conv', 2. it has to satisfy the other requirements,like shape, please refer to SelectConvProducer for more info
|
||||
const Node* conv_node = nullptr;
|
||||
const auto* add_node = GetLoneConsumerNode(graph_viewer, node);
|
||||
if (!add_node) {
|
||||
return std::nullopt;
|
||||
}
|
||||
// Let's support addition first, leave any-element-wise-op fusion in the future.
|
||||
// what we want to here is that:
|
||||
// 1 find the Add node, 2 find it's producer node and make sure it's a conv node
|
||||
// 3 find the next node and check if it's a activation node, if yes, we will fuse conv+add+activation or conv+add
|
||||
//
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(*add_node, "Add", {7, 13, 14})) {
|
||||
conv_node = SelectProducerConv(*add_node);
|
||||
}
|
||||
if (!conv_node) {
|
||||
return std::nullopt;
|
||||
}
|
||||
// GetLoneConsumerNode will ensure outputedge_count is 1
|
||||
const auto* act_node = GetLoneConsumerNode(graph_viewer, *add_node);
|
||||
// even the next node is not a activation node, it's also fine.
|
||||
if (!act_node) {
|
||||
// we can't fuse add-activation when add_node has multiple consumer nodes
|
||||
act_node = nullptr;
|
||||
} else if (SelectActivation(graph_viewer, *act_node)) {
|
||||
// this branch is deliberately empty as we want to keep 'act_node' as remains.
|
||||
} else {
|
||||
act_node = nullptr;
|
||||
}
|
||||
|
||||
NodesToOptimizeIndicesBuilder builder{};
|
||||
builder.target_node = conv_node->Index();
|
||||
builder.output_nodes = {add_node->Index()};
|
||||
if (act_node) {
|
||||
builder.output_nodes.push_back(act_node->Index());
|
||||
}
|
||||
return builder.Build();
|
||||
}
|
||||
|
||||
static bool SelectActivation(const GraphViewer& graph_viewer, const Node& activation_node) {
|
||||
auto is_supported_cpu_ep_activation = [&graph_viewer](const Node& activation_node) {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "Relu", {6, 13, 14}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "Sigmoid", {6, 13}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "Tanh", {6, 13}) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "LeakyRelu", {6})) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "Clip", {6, 11, 12, 13})) {
|
||||
float min, max;
|
||||
if (!optimizer_utils::GetClipConstantMinMax(graph_viewer.GetGraph(), activation_node, min, max)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(activation_node, "HardSigmoid", {6})) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
return is_supported_cpu_ep_activation(activation_node);
|
||||
}
|
||||
|
||||
const Node* SelectProducerConv(const Node& node) const {
|
||||
InlinedVector<const Node*> inputs_node;
|
||||
constexpr int32_t kTensorDims = 4; // NCHW
|
||||
const auto& input_defs = node.InputDefs();
|
||||
|
||||
for (auto producer_node_ptr = node.InputNodesBegin(); producer_node_ptr != node.InputNodesEnd(); ++producer_node_ptr) {
|
||||
const Node* producer_node = dynamic_cast<const Node*>(&(*producer_node_ptr));
|
||||
inputs_node.push_back(producer_node);
|
||||
}
|
||||
size_t input_defs_count = input_defs.size();
|
||||
if (input_defs_count != 2 || inputs_node.size() > input_defs_count) {
|
||||
return nullptr;
|
||||
}
|
||||
// Test if all of inputs have an equal shape.
|
||||
auto* input_0_shape = input_defs[0]->Shape();
|
||||
// Check if ONNX shape inferencing has computed a precise dimension value.
|
||||
if ((input_0_shape == nullptr) || (input_0_shape->dim_size() != kTensorDims)) {
|
||||
return nullptr;
|
||||
}
|
||||
for (int i = 0; i < kTensorDims; i++) {
|
||||
auto& input_0_dim = input_0_shape->dim(i);
|
||||
// even though zero-dim is valid, but we don't support here
|
||||
if (!utils::HasDimValue(input_0_dim) || (input_0_dim.dim_value() == 0)) {
|
||||
if (!utils::HasDimParam(input_0_dim)) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
// we can't fuse them if shape is not matched, it will happens when broadcast-Add
|
||||
for (size_t n = 1; n < input_defs_count; n++) {
|
||||
auto* input_n_shape = input_defs[n]->Shape();
|
||||
if (input_n_shape == nullptr || (input_n_shape->dim_size() != kTensorDims)) {
|
||||
return nullptr;
|
||||
}
|
||||
for (int i = 0; i < kTensorDims; i++) {
|
||||
auto& input_0_dim = input_0_shape->dim(i);
|
||||
auto& input_n_dim = input_n_shape->dim(i);
|
||||
if (!utils::HasDimValue(input_n_dim) || (input_0_dim.dim_value() != input_n_dim.dim_value())) {
|
||||
if (!utils::HasDimParam(input_0_dim) || !utils::HasDimParam(input_n_dim) || (input_0_dim.dim_param() != input_n_dim.dim_param())) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If one of the inputs to the Add node is a convolution, then
|
||||
// attempt to fuse the addition into the convolution itself.
|
||||
for (size_t n = 0; (n < inputs_node.size()) && inputs_node[n]; n++) {
|
||||
const auto& producer_input_defs = inputs_node[n]->InputDefs();
|
||||
const auto& producer_input_args_count = inputs_node[n]->InputArgCount();
|
||||
size_t pre_input_defs_count = producer_input_defs.size();
|
||||
// Check if this is a single use convolution that hasn't already
|
||||
// been fused with another Add/Sum node. The Add/Sum can also only be
|
||||
// fused if the convolution isn't itself fused with an activation.
|
||||
if ((inputs_node[n]->OpType() == "Conv") && (pre_input_defs_count < 4) && (producer_input_args_count.size() < 4) &&
|
||||
(graph_utils::GetNodeAttribute(*inputs_node[n], "activation") == nullptr) && (inputs_node[n]->GetOutputEdgesCount() == 1)) {
|
||||
if (pre_input_defs_count < 3) {
|
||||
// The optional bias parameter is empty so set to an empty string.
|
||||
// TODO, add a new null arguments for bias
|
||||
continue;
|
||||
}
|
||||
return inputs_node[n];
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace selectors
|
||||
|
||||
namespace actions {
|
||||
using NTO = NodesToOptimize;
|
||||
|
||||
class FuseConvAddActivation : public ReplaceWithNew {
|
||||
private:
|
||||
std::string OpType(const RuntimeState&) const override { return "FusedConv"; }
|
||||
|
||||
std::string Domain(const RuntimeState&) const override { return kMSDomain; }
|
||||
|
||||
NodeAttributes ExtraAttributes(const RuntimeState& state) const override {
|
||||
NodeAttributes extra_fused_conv_attributes;
|
||||
|
||||
const auto* activation = state.selected_nodes.Output(state.selected_nodes.num_outputs-1);
|
||||
if (state.selected_nodes.num_outputs == 1 || activation->OpType() == "Add") {
|
||||
//activation node is the last node in conv+add+activation fusion pattern, while conv+add is also possible
|
||||
return extra_fused_conv_attributes;
|
||||
}
|
||||
ORT_ENFORCE(activation != nullptr, "Expected activation node.");
|
||||
|
||||
const auto& activation_op_type = activation->OpType();
|
||||
utils::SetNodeAttribute(utils::MakeAttribute("activation", activation_op_type), extra_fused_conv_attributes);
|
||||
|
||||
InlinedVector<float> activation_params;
|
||||
if (activation_op_type == "LeakyRelu") {
|
||||
activation_params.push_back(graph_utils::GetNodeAttribute(*activation, "alpha")->f());
|
||||
} else if (activation_op_type == "Clip") {
|
||||
float min, max;
|
||||
ORT_ENFORCE(optimizer_utils::GetClipConstantMinMax(state.graph, *activation, min, max),
|
||||
"Failed to get Clip min/max constants.");
|
||||
activation_params.push_back(min);
|
||||
activation_params.push_back(max);
|
||||
} else if (activation_op_type == "HardSigmoid") {
|
||||
auto* alpha_attr = graph_utils::GetNodeAttribute(*activation, "alpha");
|
||||
auto* beta_attr = graph_utils::GetNodeAttribute(*activation, "beta");
|
||||
float alpha = (alpha_attr == nullptr ? 0.2f : alpha_attr->f());
|
||||
float beta = (beta_attr == nullptr ? 0.5f : beta_attr->f());
|
||||
activation_params.push_back(alpha);
|
||||
activation_params.push_back(beta);
|
||||
}
|
||||
|
||||
if (!activation_params.empty()) {
|
||||
utils::SetNodeAttribute(utils::MakeAttribute("activation_params", activation_params),
|
||||
extra_fused_conv_attributes);
|
||||
}
|
||||
|
||||
return extra_fused_conv_attributes;
|
||||
}
|
||||
|
||||
std::vector<NodeAndMoveInfo> ValueMoves(const RuntimeState& state) const override {
|
||||
const auto& conv = state.selected_nodes.Target();
|
||||
ORT_ENFORCE(conv.GetOutputEdgesCount() == 1 && conv.OutputNodesBegin()->OpType() == "Add",
|
||||
"Expected Conv then Add.");
|
||||
|
||||
const auto add_input_idx = 1 - conv.OutputEdgesBegin()->GetDstArgIndex();
|
||||
|
||||
const auto conv_location = NTO::NodeLocation{NTO::NodeType::kTarget, 0};
|
||||
const auto add_location = NTO::NodeLocation{NTO::NodeType::kOutput, 0};
|
||||
const auto activation_location = NTO::NodeLocation{NTO::NodeType::kOutput, 1};
|
||||
//Conv+add+activation
|
||||
if (state.selected_nodes.num_outputs == 2) {
|
||||
return {
|
||||
MoveAll(conv_location, ArgType::kInput), // move all inputs from conv
|
||||
MoveAndAppend(add_location, ArgType::kInput, add_input_idx, ArgType::kInput), // append add input
|
||||
MoveAll(activation_location, ArgType::kOutput), // move all outputs from relu
|
||||
};
|
||||
} else {
|
||||
//Conv+Add only
|
||||
return {
|
||||
MoveAll(conv_location, ArgType::kInput), // move all inputs from conv
|
||||
MoveAndAppend(add_location, ArgType::kInput, add_input_idx, ArgType::kInput), // append add input
|
||||
MoveAll(add_location, ArgType::kOutput), // move all outputs from relu
|
||||
};
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace actions
|
||||
|
||||
void RegisterConvAddActivationFusionRules(SelectorActionRegistry& registry) {
|
||||
const auto name = "ConvAddAct";
|
||||
auto action = std::make_unique<actions::FuseConvAddActivation>();
|
||||
auto selector = std::make_unique<selectors::ConvAddActivation>();
|
||||
registry.RegisterSelectorAndAction(name, {{"Conv", {1, 11}}},
|
||||
std::move(selector), std::move(action));
|
||||
}
|
||||
|
||||
|
||||
SelectorActionRegistry CreateSelectorActionRegistry() {
|
||||
SelectorActionRegistry registry{};
|
||||
RegisterConvAddActivationFusionRules(registry);
|
||||
return registry;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
ConvAddActivationFusion::ConvAddActivationFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers,
|
||||
const SatApplyContextVariant& apply_context)
|
||||
: SelectorActionTransformer{
|
||||
"ConvAddActivationFusion", CreateSelectorActionRegistry(), apply_context, compatible_execution_providers} {
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
25
onnxruntime/core/optimizer/conv_add_act_fusion.h
Normal file
25
onnxruntime/core/optimizer/conv_add_act_fusion.h
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/optimizer/selectors_actions/selector_action_transformer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
@Class ConvAddActivationFusion
|
||||
|
||||
Transformer that optimizes the graph by using NCHW nodes and a more general version of convaddrelu.
|
||||
This Fusion pattern is used to fuse Conv Add Activation together from different branch, The reason
|
||||
is that we assume the graph would be executed by sequential executor. then the orders of branch running doesn't matter
|
||||
*/
|
||||
class ConvAddActivationFusion : public SelectorActionTransformer {
|
||||
public:
|
||||
ConvAddActivationFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
|
||||
const SatApplyContextVariant& apply_context = {});
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -12,6 +12,7 @@
|
|||
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
|
||||
#include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h"
|
||||
#include "core/session/onnxruntime_session_options_config_keys.h"
|
||||
#include "core/optimizer/conv_add_act_fusion.h"
|
||||
|
||||
#if !defined(ORT_MINIMAL_BUILD)
|
||||
|
||||
|
|
@ -161,7 +162,9 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
InlinedVector<std::unique_ptr<GraphTransformer>> transformers;
|
||||
const bool disable_quant_qdq =
|
||||
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1";
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
const InlinedHashSet<std::string_view> cpu_ep = {onnxruntime::kCpuExecutionProvider};
|
||||
#endif
|
||||
switch (level) {
|
||||
case TransformerLevel::Level1: {
|
||||
// RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run)
|
||||
|
|
@ -198,7 +201,6 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
const bool enable_gelu_approximation =
|
||||
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableGeluApproximation, "0") == "1";
|
||||
|
||||
const InlinedHashSet<std::string_view> cpu_ep = {onnxruntime::kCpuExecutionProvider};
|
||||
const InlinedHashSet<std::string_view> cuda_rocm_eps = {onnxruntime::kCudaExecutionProvider,
|
||||
onnxruntime::kRocmExecutionProvider};
|
||||
const InlinedHashSet<std::string_view> cpu_cuda_rocm_eps = {onnxruntime::kCpuExecutionProvider,
|
||||
|
|
@ -263,6 +265,12 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
}
|
||||
auto cpu_allocator = cpu_execution_provider.GetAllocator(0, OrtMemTypeDefault);
|
||||
transformers.emplace_back(std::make_unique<NhwcTransformer>(std::move(cpu_allocator)));
|
||||
// NCHWCtransformer should have a higher priority versus this. Because NCHWCtransformer also do the similiar things
|
||||
// of fusion patterns and target on CPU. However, NCHWCtransformer will reorder the layout to nchwc which is only available for
|
||||
// x86-64 cpu, not edge cpu like arm. But This tranformer could be used by opencl-ep/cpu-ep. So
|
||||
// we will prefer NhwcTransformer once ort runs on x86-64 CPU, otherwise ConvAddActivationFusion is enabled.
|
||||
// this PR #6351 implemented similiar fusion-pattern but only for CUDA, and can only fuse conv-add-relu, while we can fuse more activation.
|
||||
transformers.emplace_back(std::make_unique<ConvAddActivationFusion>(cpu_ep));
|
||||
#endif
|
||||
} break;
|
||||
|
||||
|
|
|
|||
|
|
@ -154,9 +154,10 @@ Status Conv<T>::Compute(OpKernelContext* context) const {
|
|||
|
||||
Status Conv<float>::Compute(OpKernelContext* context) const {
|
||||
size_t num_inputs = OpKernel::Node().InputDefs().size();
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
const auto* W = context->Input<Tensor>(1);
|
||||
const Tensor* B = num_inputs == 3 ? context->Input<Tensor>(2) : nullptr;
|
||||
const Tensor* X = context->Input<Tensor>(0);
|
||||
const Tensor* W = context->Input<Tensor>(1);
|
||||
const Tensor* B = num_inputs >= 3 ? context->Input<Tensor>(2) : nullptr;
|
||||
const Tensor* Sum = num_inputs >= 4 ? context->Input<Tensor>(3) : nullptr;
|
||||
const int64_t N = X->Shape()[0];
|
||||
const int64_t C = X->Shape()[1];
|
||||
const int64_t M = W->Shape()[0];
|
||||
|
|
@ -195,7 +196,18 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
const auto* Xdata = X->template Data<float>();
|
||||
const auto* Bdata = B != nullptr ? B->template Data<float>() : nullptr;
|
||||
auto* Ydata = Y->template MutableData<float>();
|
||||
|
||||
// Check for the optional Conv/Sum fusion.
|
||||
float Beta = 0.0f;
|
||||
if (Sum != nullptr) {
|
||||
const auto& sum_shape = Sum->Shape();
|
||||
ORT_RETURN_IF_NOT(Y->Shape() == sum_shape, "output and sum shape must match");
|
||||
// If the output was not allocated inplace with the sum tensor, then copy here.
|
||||
const auto* sum_data = Sum->template Data<float>();
|
||||
if (Ydata != sum_data) {
|
||||
memcpy(Ydata, sum_data, sum_shape.Size() * sizeof(float));
|
||||
}
|
||||
Beta = 1.0f;
|
||||
}
|
||||
const size_t kernel_rank = kernel_shape.size();
|
||||
concurrency::ThreadPool* thread_pool = context->GetOperatorThreadPool();
|
||||
|
||||
|
|
@ -216,6 +228,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
static_cast<size_t>(M / conv_attrs_.group),
|
||||
&activation_,
|
||||
&WorkingBufferSize,
|
||||
Beta,
|
||||
thread_pool);
|
||||
|
||||
auto* working_data = WorkingBufferSize > 0 ? alloc->Alloc(SafeInt<size_t>(sizeof(float)) * WorkingBufferSize)
|
||||
|
|
@ -266,7 +279,7 @@ Status Conv<float>::Compute(OpKernelContext* context) const {
|
|||
1,
|
||||
W->template Data<float>() + group_id * W_offset,
|
||||
col_buffer_data,
|
||||
0,
|
||||
Beta,
|
||||
Ydata + group_id * Y_offset,
|
||||
thread_pool);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class Conv<float> : public OpKernel {
|
|||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
|
||||
protected:
|
||||
MLAS_ACTIVATION activation_;
|
||||
|
||||
|
|
|
|||
|
|
@ -205,6 +205,31 @@ TEST(FusedConvTest, Conv2D_Bias_Z_Relu) {
|
|||
}
|
||||
|
||||
#endif
|
||||
|
||||
TEST(FusedConvTest, Cpu_Conv2D_Bias_Z_Relu) {
|
||||
ConvOpAndTestAttributes attrs = {
|
||||
"", // auto_pad
|
||||
vector<int64_t>{1, 1}, // dilations
|
||||
1, // group
|
||||
vector<int64_t>{2, 2}, // kernel_shape
|
||||
vector<int64_t>{0, 0, 0, 0}, // pads
|
||||
vector<int64_t>{1, 1}, // strides
|
||||
"Relu" // activation
|
||||
};
|
||||
|
||||
vector<float> X = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
|
||||
vector<int64_t> X_shape = {1, 1, 3, 3};
|
||||
vector<float> W = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
|
||||
vector<int64_t> W_shape = {2, 1, 2, 2};
|
||||
vector<int64_t> Y_shape = {1, 2, 2, 2};
|
||||
vector<float> B = {1.0f, -1.0f};
|
||||
vector<int64_t> B_shape = {2};
|
||||
vector<float> Z = {-1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f};
|
||||
vector<int64_t> Z_shape = {1, 2, 2, 2};
|
||||
auto expected_vals = {12.0f, 17.0f, 25.0f, 29.0f, 11.0f, 15.0f, 23.0f, 28.0f};
|
||||
TestConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, providers_except_cpu);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
|
|
@ -109,6 +109,7 @@ void SCONV_NCHW(benchmark::State& state, const char* /*dummy*/) {
|
|||
static_cast<size_t>(output_channels_per_group),
|
||||
&activation,
|
||||
&WorkingBufferSize,
|
||||
0.0f,
|
||||
nullptr);
|
||||
|
||||
auto X = RandomVectorUniform(x_shape, -2.0, 2.0);
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ class MlasConv2DTest : public MlasTestBase {
|
|||
FilterCount,
|
||||
&Activation,
|
||||
&WorkingBufferSize,
|
||||
0.0f,
|
||||
threadpool_);
|
||||
|
||||
MlasConv(&Parameters,
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@
|
|||
#include "core/optimizer/concat_slice_elimination.h"
|
||||
#include "core/optimizer/constant_folding.h"
|
||||
#include "core/optimizer/conv_activation_fusion.h"
|
||||
#include "core/optimizer/conv_add_act_fusion.h"
|
||||
#include "core/optimizer/conv_add_fusion.h"
|
||||
#include "core/optimizer/conv_bn_fusion.h"
|
||||
#include "core/optimizer/conv_mul_fusion.h"
|
||||
|
|
@ -751,6 +752,70 @@ TEST_F(GraphTransformationTests, FuseCudaConvAdd) {
|
|||
|
||||
#endif
|
||||
|
||||
|
||||
#if !defined(DISABLE_CONTRIB_OPS)
|
||||
// Conv->Add->Relu will be transformed to FusedConv
|
||||
TEST_F(GraphTransformationTests, FuseCpuConvAddRelu) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/conv_add_relu.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
for (auto& node : p_model->MainGraph().Nodes()) {
|
||||
node.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
}
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 1);
|
||||
ASSERT_TRUE(op_to_count["Relu"] == 1);
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<ConvAddActivationFusion>(), TransformerLevel::Level3));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger_));
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0); // Add removed from graph
|
||||
ASSERT_TRUE(op_to_count["Relu"] == 0); // Relu removed from graph
|
||||
}
|
||||
|
||||
// Conv->Add->Relu will be partly fused to Conv_Add->Relu since there is Identity depend on Add
|
||||
TEST_F(GraphTransformationTests, FuseCpuConvAddReluIdentity) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/conv_add_relu_identity.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
for (auto& node : p_model->MainGraph().Nodes()) {
|
||||
node.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
}
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 1);
|
||||
ASSERT_TRUE(op_to_count["Relu"] == 1);
|
||||
ASSERT_TRUE(op_to_count["Identity"] == 1);
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<ConvAddActivationFusion>(), TransformerLevel::Level3));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger_));
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0); // Add removed
|
||||
ASSERT_TRUE(op_to_count["Relu"] == 1); // Relu remains
|
||||
ASSERT_TRUE(op_to_count["Identity"] == 1); // Identity remains
|
||||
}
|
||||
|
||||
// Conv->Add will be transformed to FusedConv
|
||||
TEST_F(GraphTransformationTests, FuseCpuConvAdd) {
|
||||
auto model_uri = MODEL_FOLDER "fusion/conv_add.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
for (auto& node : p_model->MainGraph().Nodes()) {
|
||||
node.SetExecutionProviderType(kCpuExecutionProvider);
|
||||
}
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 1);
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<ConvAddActivationFusion>(), TransformerLevel::Level3));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level3, *logger_));
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0); // Add removed
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if !defined(DISABLE_CONTRIB_OPS)
|
||||
TEST_F(GraphTransformationTests, FuseConvActivation) {
|
||||
#ifdef USE_CUDA
|
||||
|
|
|
|||
Loading…
Reference in a new issue