diff --git a/onnxruntime/contrib_ops/cpu/fused_conv.cc b/onnxruntime/contrib_ops/cpu/fused_conv.cc
new file mode 100644
index 0000000000..ae8f81e812
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/fused_conv.cc
@@ -0,0 +1,17 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "fused_conv.h"
+
+namespace onnxruntime {
+namespace contrib {
+// Registers FusedConv (since version 1, element type float) as a CPU kernel through the MS-domain kernel macro.
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+    FusedConv,
+    1,
+    float,
+    KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType()),
+    FusedConv);
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/fused_conv.h b/onnxruntime/contrib_ops/cpu/fused_conv.h
new file mode 100644
index 0000000000..329eb82990
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/fused_conv.h
@@ -0,0 +1,28 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/cpu/nn/conv_impl.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+// Conv kernel variant that also applies an activation to the convolution output.
+// The activation name ("activation", default empty) and its slope parameter
+// ("alpha", default 0.01f) are read from the node attributes at construction.
+template
+class FusedConv : public Conv {
+ public:
+  explicit FusedConv(const OpKernelInfo& info) : Conv(info) {
+    Conv::activation_ = info.GetAttrOrDefault("activation", "");  // empty name means "no activation"
+    Conv::alpha_ = info.GetAttrOrDefault("alpha", 0.01f);
+  }
+
+  // Delegates to Conv::Compute, which consumes activation_ / alpha_ set above.
+  Status Compute(OpKernelContext* context) const override {
+    return Conv::Compute(context);
+  }
+};
+}  // namespace contrib
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index c60e93e66b..db9f6d55ab 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -6,7 +6,11 @@
 #include "core/graph/contrib_ops/contrib_defs.h"
 #include "core/graph/contrib_ops/range_schema_defs.h"
 #include "core/graph/op.h"
+#include "onnx/defs/shape_inference.h"
 
+namespace ONNX_NAMESPACE {
+void convPoolTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, bool use_dilation, bool require_kernel_shape);
+}
 namespace onnxruntime {
 namespace contrib {
 using ::ONNX_NAMESPACE::AttributeProto;
@@ -28,6 +32,69 @@ void RegisterContribSchemas() {
 Sample echo operator.)DOC");
 
   // register schemas for more operators here
+  ONNX_CONTRIB_OPERATOR_SCHEMA(FusedConv)
+      .SetDomain(kMSDomain)
+      .SinceVersion(1)
+      .SetDoc(R"DOC(
+The fused convolution operator schema is the same as Conv besides it includes an attribute
+activation.)DOC")
+      .Attr(
+          "auto_pad",
+          "",
+          AttributeProto::STRING,
+          std::string("NOTSET"))
+      .Attr(
+          "kernel_shape",
+          "",
+          AttributeProto::INTS,
+          OPTIONAL)
+      .Attr(
+          "dilations",
+          "",
+          AttributeProto::INTS,
+          OPTIONAL)
+      .Attr(
+          "strides", "", AttributeProto::INTS, OPTIONAL)
+      .Attr("pads",
+            "",
+            AttributeProto::INTS, OPTIONAL)
+      .Attr(
+          "group",
+          "",
+          AttributeProto::INT,
+          static_cast(1))
+      .Attr(
+          "activation",
+          "",
+          AttributeProto::STRING,
+          OPTIONAL)
+      // The FusedConv kernel reads "alpha" and the fusion pass copies it from LeakyRelu, so it must be declared here.
+      .Attr(
+          "alpha",
+          "",
+          AttributeProto::FLOAT,
+          OPTIONAL)
+      .Input(
+          0,
+          "X",
+          "",
+          "T")
+      .Input(
+          1,
+          "W",
+          "",
+          "T")
+      .Input(2, "B", "", "T", OpSchema::Optional)
+      .Output(
+          0,
+          "Y",
+          "",
+          "T")
+      .TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors")
+      .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+        // This schema declares "dilations" and an OPTIONAL "kernel_shape": use dilations, do not require kernel_shape.
+        ONNX_NAMESPACE::convPoolTypeAndShapeInference(ctx, true, false);
+      });
 
   ONNX_CONTRIB_OPERATOR_SCHEMA(ExpandDims)
       .SetDomain(kMSDomain)
diff --git a/onnxruntime/core/graph/conv_activation_fusion.cc b/onnxruntime/core/graph/conv_activation_fusion.cc
new file mode 100644
index 0000000000..5b4649c69b
--- /dev/null
+++ b/onnxruntime/core/graph/conv_activation_fusion.cc
@@ -0,0 +1,86 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/graph/initializer.h"
+#include "core/graph/conv_activation_fusion.h"
+#include "core/graph/graph_utils.h"
+
+using namespace onnx;
+using namespace ::onnxruntime::common;
+namespace onnxruntime {
+
+namespace {
+bool IsFusableActivation(const Node& node) {  // true for LeakyRelu / Relu / Sigmoid / Tanh nodes accepted by the version-6 check
+  return utils::IsSupportedOptypeVersionAndDomain(node, "LeakyRelu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Relu", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", 6) || utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", 6);
+}
+}  // namespace
+// Replaces each Conv that has exactly one output edge into a fusable activation with one FusedConv node; sets `modified` if any pair was fused.
+Status ConvActivationFusion::Apply(Graph& graph, bool& modified) const {
+  GraphViewer graph_viewer(graph);
+  const auto& order = graph_viewer.GetNodesInTopologicalOrder();
+
+  std::vector removed_nodes;  // Conv/activation pairs to delete once the whole graph has been scanned
+  for (auto index : order) {
+    auto node = graph.GetNode(index);
+    if (!utils::IsSupportedOptypeVersionAndDomain(*node, "Conv", 1) || node->GetOutputEdgesCount() != 1) {
+      continue;
+    }
+    const Node& next_node = *(node->OutputNodesBegin());
+    if (!IsFusableActivation(next_node) || graph.IsNodeOutputsInGraphOutputs(next_node)) {
+      continue;  // skipped when the activation output is a graph output
+    }
+
+    Node* conv_node = node;
+    const Node& act_node = next_node;
+
+    // The fused node takes over the Conv's input defs, output defs and attributes.
+    Node& fused_conv = graph.AddNode(graph.GenerateNodeName("fused " + conv_node->Name()), "FusedConv",
+                                     "fused Conv " + conv_node->Name() + " with activation " + act_node.OpType(),
+                                     conv_node->MutableInputDefs(),
+                                     conv_node->MutableOutputDefs(),
+                                     &conv_node->GetAttributes(),
+                                     "com.microsoft");
+
+    // Record the op type of the fused activation; the kernel selects the function to apply by this name.
+    fused_conv.AddAttribute("activation", act_node.OpType());
+
+    // LeakyRelu carries attributes of its own (alpha); copy them onto the fused node.
+    if (act_node.OpType() == "LeakyRelu") {
+      const NodeAttributes& attrs = act_node.GetAttributes();
+      for (const auto& attr : attrs) {
+        fused_conv.AddAttribute(attr.first, attr.second);
+      }
+    }
+
+    // Rewire every consumer of the activation output to read the fused node's output instead.
+    const NodeArg* act_output_def = act_node.OutputDefs()[0];
+    NodeArg* fused_conv_output_def = fused_conv.MutableOutputDefs()[0];
+    for (auto it = act_node.OutputNodesBegin(); it != act_node.OutputNodesEnd(); ++it) {
+      auto output_node = graph.GetNode((*it).Index());
+      if (!output_node) {
+        return Status(ONNXRUNTIME, INVALID_ARGUMENT);
+      }
+
+      auto& input_defs = output_node->MutableInputDefs();
+      for (auto& def : input_defs) {
+        if (def == act_output_def) {
+          def = fused_conv_output_def;
+        }
+      }
+    }
+
+    removed_nodes.push_back(act_node.Index());
+    removed_nodes.push_back(conv_node->Index());
+  }
+
+  for (auto i : removed_nodes) {
+    graph.RemoveNode(i);
+  }
+
+  if (!removed_nodes.empty()) {
+    modified = true;
+    ONNXRUNTIME_RETURN_IF_ERROR(graph.Resolve());  // re-resolve so the rewritten graph is validated
+  }
+  return Status::OK();
+}
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/graph/conv_activation_fusion.h b/onnxruntime/core/graph/conv_activation_fusion.h
new file mode 100644
index 0000000000..a7158b4a0e
--- /dev/null
+++ b/onnxruntime/core/graph/conv_activation_fusion.h
@@ -0,0 +1,16 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/graph/graph_transformer.h"
+
+namespace onnxruntime {
+// Graph transformer registered as "ConvActivationFusion" ("Fusing Activation into Conv").
+class ConvActivationFusion : public onnxruntime::GraphTransformer {
+ public:
+  ConvActivationFusion() noexcept : onnxruntime::GraphTransformer("ConvActivationFusion", "Fusing Activation into Conv") {}
+  Status Apply(onnxruntime::Graph& graph, bool& modified) const override;
+};
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/nn/conv.cc b/onnxruntime/core/providers/cpu/nn/conv.cc
index 01b38657e8..26ac4a90c0 100644
--- a/onnxruntime/core/providers/cpu/nn/conv.cc
+++ b/onnxruntime/core/providers/cpu/nn/conv.cc
@@ -4,9 +4,159 @@
 #include "core/providers/cpu/nn/conv_impl.h"
 
 namespace onnxruntime {
+// Float Conv: MLAS for 2-D/3-D kernels, Im2col + Gemm otherwise; fuse_activation is applied to the output in both paths.
+template <>
+Status Conv::Compute(OpKernelContext* context) const {
+  size_t num_inputs = OpKernel::Node().InputDefs().size();
+  const Tensor* X = context->Input(0);
+  const Tensor* W = context->Input(1);
+  const Tensor* B = num_inputs == 3 ? context->Input(2) : nullptr;  // bias is optional
+  const int64_t N = X->Shape()[0];
+  const int64_t C = X->Shape()[1];
+  const int64_t M = W->Shape()[0];
+  ONNXRUNTIME_RETURN_IF_ERROR(ValidateInputShape(X, W));
+
+  std::vector kernel_shape = ComputeKernelShape(W->Shape());
+
+  if (kernel_shape.size() + 2 != W->Shape().NumDimensions()) {
+    return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape num_dims is not compatible with W num_dims.",
+                                   " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
+                                   " W: ", W->Shape().ToString().c_str());
+  }
+
+  for (size_t i = 0; i < kernel_shape.size(); ++i) {
+    if (kernel_shape[i] != W->Shape()[i + 2]) {
+      return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.",
+                                     " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
+                                     " W: ", W->Shape().ToString().c_str());
+    }
+  }
+
+  std::vector pads(pads_);
+  if (pads.empty()) {
+    pads.resize(kernel_shape.size() * 2, 0);  // default: no padding
+  }
+  std::vector dilations(dilations_);
+  if (dilations.empty()) {
+    dilations.resize(kernel_shape.size(), 1);  // default: dilation 1
+  }
+  std::vector strides(strides_);
+  if (strides.empty()) {
+    strides.resize(kernel_shape.size(), 1);  // default: stride 1
+  }
+
+  std::vector Y_dims;
+  Y_dims.insert(Y_dims.begin(), {N, M});
+  TensorShape input_shape = X->Shape().Slice(2);
+  ONNXRUNTIME_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims));
+  Tensor* Y = context->Output(0, TensorShape(Y_dims));
+  TensorShape output_shape = Y->Shape().Slice(2);
+
+  AllocatorPtr alloc;
+  ONNXRUNTIME_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc));
+
+  const float* Xdata = X->template Data();
+  float* Ydata = Y->template MutableData();
+
+  const size_t kernel_rank = kernel_shape.size();
+
+  if (kernel_rank == 2 || kernel_rank == 3) {
+    MLAS_CONV_PARAMETERS Parameters;
+    size_t WorkingBufferSize;
+    MlasConvPrepare(&Parameters,
+                    kernel_rank,
+                    static_cast(N),
+                    static_cast(group_),
+                    static_cast(C / group_),
+                    input_shape.GetDims().data(),
+                    kernel_shape.data(),
+                    dilations.data(),
+                    pads.data(),
+                    strides.data(),
+                    output_shape.GetDims().data(),
+                    static_cast(M / group_),
+                    &WorkingBufferSize);
+
+    auto working_data = WorkingBufferSize > 0 ? alloc->Alloc(sizeof(float) * WorkingBufferSize) : nullptr;
+    BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc));
+
+    MlasConv(&Parameters,
+             Xdata,
+             W->template Data(),
+             B != nullptr ? B->template Data() : nullptr,
+             static_cast(working_buffer.get()),
+             Ydata);
+
+    //TODO: this will be replaced with Tracy's changes.
+    fuse_activation(activation_, Ydata, Y->Shape().Size(), alpha_);  // whole output tensor in one pass
+
+  } else {
+    const int64_t input_image_size = input_shape.Size();
+    const int64_t output_image_size = output_shape.Size();
+    const int64_t kernel_size = TensorShape(kernel_shape).Size();
+    const int64_t X_offset = C / group_ * input_image_size;
+    const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_;
+    const int64_t W_offset = W->Shape().Size() / group_;
+    const int64_t kernel_dim = C / group_ * kernel_size;
+    const int64_t col_buffer_size = kernel_dim * output_image_size;
+
+    auto col_data = alloc->Alloc(sizeof(float) * col_buffer_size);
+    BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc));
+    float* col_buffer_data = static_cast(col_buffer.get());
+
+    TensorShape image_shape = X->Shape().Slice(1);
+    std::vector col_buffer_shape{kernel_dim};
+    col_buffer_shape.insert(col_buffer_shape.end(), output_shape.GetDims().begin(),
+                            output_shape.GetDims().end());
+
+    for (int image_id = 0; image_id < N; ++image_id) {
+      for (int group_id = 0; group_id < group_; ++group_id) {
+        math::Im2colNd(
+            Xdata + group_id * X_offset,
+            image_shape.GetDims().data(),
+            col_buffer_shape.data(),
+            C * input_image_size,
+            col_buffer_size,
+            kernel_shape.data(),
+            strides.data(),
+            dilations.data(),
+            pads.data(),
+            static_cast(kernel_shape.size()),
+            col_buffer_data,
+            &CPUMathUtil::Instance());
+        math::Gemm(
+            CblasNoTrans,
+            CblasNoTrans,
+            M / group_,
+            output_image_size,
+            kernel_dim,
+            1,
+            W->template Data() + group_id * W_offset,
+            col_buffer_data,
+            0,
+            Ydata + group_id * Y_offset,
+            &CPUMathUtil::Instance());
+      }
+
+      if (B != nullptr) {
+        auto Ymatrix = EigenMatrixMap(Ydata, output_image_size, M);
+        auto Bvec = ConstEigenVectorMap(B->template Data(), M);
+        Ymatrix.rowwise() += Bvec.transpose();
+      }
+
+      fuse_activation(activation_, Ydata, Y_offset * group_, alpha_);  // one image's output, after the bias add
+
+      Xdata += X_offset * group_;
+      Ydata += Y_offset * group_;
+    }
+  }
+
+  return Status::OK();
+}
+
 ONNX_CPU_OPERATOR_KERNEL(
     Conv,
     1,
     KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()),
     Conv);
-}
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/nn/conv_base.h b/onnxruntime/core/providers/cpu/nn/conv_base.h
index 68c2cbdf28..e80acb245d 100644
--- a/onnxruntime/core/providers/cpu/nn/conv_base.h
+++ b/onnxruntime/core/providers/cpu/nn/conv_base.h
@@ -10,7 +10,6 @@
 #include "core/util/math.h"
 
 namespace onnxruntime {
-namespace {
 
 // helper function
 template
@@ -58,7 +57,6 @@ Status ComputePadAndOutputShape(
   }
   return Status::OK();
 }
-}  // namespace
 
 // base class used by Conv and ConvTranspose
 class ConvBase {
@@ -122,21 +120,21 @@ class ConvBase {
 
     if (X->Shape().NumDimensions() != W->Shape().NumDimensions()) {
       return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "X num_dims does not match W num_dims.",
-                               " X: ", X->Shape().ToString().c_str(),
-                               " W: ", W->Shape().ToString().c_str());
+                                     " X: ", X->Shape().ToString().c_str(),
+                                     " W: ", W->Shape().ToString().c_str());
     }
 
     if (C != W->Shape()[1] * group_) {
       return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Input channels C is not equal to kernel channels * group.",
-                               " C: ", C,
-                               " kernel channels: ", W->Shape()[1],
-                               " group: ", group_);
+                                     " C: ", C,
+                                     " kernel channels: ", W->Shape()[1],
+                                     " group: ", group_);
     }
 
     if (M % group_ != 0) {
       return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output channels M is not divisible by group.",
-                               " M: ", M,
-                               " group: ", group_);
+                                     " M: ", M,
+                                     " group: ", group_);
     }
     return Status::OK();
   }
@@ -179,6 +177,8 @@ class ConvBase {
   std::vector strides_;
   std::vector pads_;
   std::vector dilations_;
+  std::string activation_;  // name of the fused activation; empty means none
+  float alpha_ = 0.01f;     // initialized so a plain Conv never passes an indeterminate value to fuse_activation; same default as FusedConv
 
  private:
   std::vector kernel_shape_;  // must use ComputeKernelShape(...), instead of kernel_shape_
diff --git a/onnxruntime/core/providers/cpu/nn/conv_impl.h b/onnxruntime/core/providers/cpu/nn/conv_impl.h
index e23ab9d715..40038f9cb1 100644
--- a/onnxruntime/core/providers/cpu/nn/conv_impl.h
+++ b/onnxruntime/core/providers/cpu/nn/conv_impl.h
@@ -23,6 +23,23 @@
 #include "core/mlas/inc/mlas.h"
 
 namespace onnxruntime {
+template
+void fuse_activation(const std::string& activation, T* y_data, size_t size, float alpha) {  // applies the named activation in place to y_data[0..size)
+  EigenVectorArrayMap y_vec(y_data, size);  // array view over the caller's buffer; assignments below write through to y_data
+  if (activation.empty()) {
+    return;  // no activation requested: buffer left untouched
+  } else if (activation == "Relu") {
+    y_vec = y_vec.cwiseMax(0);  // element-wise max(y, 0)
+  } else if (activation == "Sigmoid") {
+    y_vec = (y_vec >= 0).select(1 / (1. + (-y_vec.abs()).exp()), 1 - 1 / (1. + (-y_vec.abs()).exp()));  // both branches use exp(-|y|), so exp never receives a positive argument
+  } else if (activation == "Tanh") {
+    y_vec = y_vec.tanh();
+  } else if (activation == "LeakyRelu") {
+    y_vec = (y_vec >= 0).select(y_vec, (T)alpha * y_vec);  // y where y >= 0, alpha * y otherwise
+  } else {
+    ONNXRUNTIME_NOT_IMPLEMENTED("Not implemented fused activation: ", activation);  // any other name is reported via this macro
+  }
+}
 
 template
 Status Conv::Compute(OpKernelContext* context) const {
@@ -40,15 +57,15 @@ Status Conv::Compute(OpKernelContext* context) const {
 
   if (kernel_shape.size() + 2 != W->Shape().NumDimensions()) {
     return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape num_dims is not compatible with W num_dims.",
-                             " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
-                             " W: ", W->Shape().ToString().c_str());
+                                   " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
+                                   " W: ", W->Shape().ToString().c_str());
   }
 
   for (size_t i = 0; i < kernel_shape.size(); ++i) {
     if (kernel_shape[i] != W->Shape()[i + 2]) {
       return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.",
-                               " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
-                               " W: ", W->Shape().ToString().c_str());
+                                     " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(),
+                                     " W: ", W->Shape().ToString().c_str());
     }
   }
 
@@ -151,6 +168,7 @@ Status Conv::Compute(OpKernelContext* context) const {
       auto Bvec = ConstEigenVectorMap(B->template Data(), M);
       Ymatrix.rowwise() += Bvec.transpose();
     }
+    fuse_activation(activation_, Ydata, Y_offset * group_, alpha_);  // applied per image, after the optional bias add
 
     Xdata += X_offset * group_;
     Ydata += Y_offset * group_;
@@ -160,146 +178,6 @@ Status
Conv::Compute(OpKernelContext* context) const { } template <> -Status Conv::Compute(OpKernelContext* context) const { - size_t num_inputs = OpKernel::Node().InputDefs().size(); - const Tensor* X = context->Input(0); - const Tensor* W = context->Input(1); - const Tensor* B = num_inputs == 3 ? context->Input(2) : nullptr; - const int64_t N = X->Shape()[0]; - const int64_t C = X->Shape()[1]; - const int64_t M = W->Shape()[0]; - ONNXRUNTIME_RETURN_IF_ERROR(ValidateInputShape(X, W)); - - std::vector kernel_shape = ComputeKernelShape(W->Shape()); - - if (kernel_shape.size() + 2 != W->Shape().NumDimensions()) { - return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape num_dims is not compatible with W num_dims.", - " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(), - " W: ", W->Shape().ToString().c_str()); - } - - for (size_t i = 0; i < kernel_shape.size(); ++i) { - if (kernel_shape[i] != W->Shape()[i + 2]) { - return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "kernel_shape is not compatible with W shape.", - " kernel_shape: ", TensorShape(kernel_shape).ToString().c_str(), - " W: ", W->Shape().ToString().c_str()); - } - } - - std::vector pads(pads_); - if (pads.empty()) { - pads.resize(kernel_shape.size() * 2, 0); - } - std::vector dilations(dilations_); - if (dilations.empty()) { - dilations.resize(kernel_shape.size(), 1); - } - std::vector strides(strides_); - if (strides.empty()) { - strides.resize(kernel_shape.size(), 1); - } - - std::vector Y_dims; - Y_dims.insert(Y_dims.begin(), {N, M}); - TensorShape input_shape = X->Shape().Slice(2); - ONNXRUNTIME_RETURN_IF_ERROR(InferOutputShape(input_shape, kernel_shape, strides, dilations, &pads, &Y_dims)); - Tensor* Y = context->Output(0, TensorShape(Y_dims)); - TensorShape output_shape = Y->Shape().Slice(2); - - AllocatorPtr alloc; - ONNXRUNTIME_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); - - const float* Xdata = X->template Data(); - float* Ydata = Y->template MutableData(); - - const 
size_t kernel_rank = kernel_shape.size(); - - if (kernel_rank == 2 || kernel_rank == 3) { - MLAS_CONV_PARAMETERS Parameters; - size_t WorkingBufferSize; - MlasConvPrepare(&Parameters, - kernel_rank, - static_cast(N), - static_cast(group_), - static_cast(C / group_), - input_shape.GetDims().data(), - kernel_shape.data(), - dilations.data(), - pads.data(), - strides.data(), - output_shape.GetDims().data(), - static_cast(M / group_), - &WorkingBufferSize); - - auto working_data = WorkingBufferSize > 0 ? alloc->Alloc(sizeof(float) * WorkingBufferSize) : nullptr; - BufferUniquePtr working_buffer(working_data, BufferDeleter(alloc)); - - MlasConv(&Parameters, - Xdata, - W->template Data(), - B != nullptr ? B->template Data() : nullptr, - static_cast(working_buffer.get()), - Ydata); - } else { - const int64_t input_image_size = input_shape.Size(); - const int64_t output_image_size = output_shape.Size(); - const int64_t kernel_size = TensorShape(kernel_shape).Size(); - const int64_t X_offset = C / group_ * input_image_size; - const int64_t Y_offset = Y->Shape().Size() / Y->Shape()[0] / group_; - const int64_t W_offset = W->Shape().Size() / group_; - const int64_t kernel_dim = C / group_ * kernel_size; - const int64_t col_buffer_size = kernel_dim * output_image_size; - - auto col_data = alloc->Alloc(sizeof(float) * col_buffer_size); - BufferUniquePtr col_buffer(col_data, BufferDeleter(alloc)); - float* col_buffer_data = static_cast(col_buffer.get()); - - TensorShape image_shape = X->Shape().Slice(1); - std::vector col_buffer_shape{kernel_dim}; - col_buffer_shape.insert(col_buffer_shape.end(), output_shape.GetDims().begin(), - output_shape.GetDims().end()); - - for (int image_id = 0; image_id < N; ++image_id) { - for (int group_id = 0; group_id < group_; ++group_id) { - math::Im2colNd( - Xdata + group_id * X_offset, - image_shape.GetDims().data(), - col_buffer_shape.data(), - C * input_image_size, - col_buffer_size, - kernel_shape.data(), - strides.data(), - dilations.data(), 
- pads.data(), - static_cast(kernel_shape.size()), - col_buffer_data, - &CPUMathUtil::Instance()); - math::Gemm( - CblasNoTrans, - CblasNoTrans, - M / group_, - output_image_size, - kernel_dim, - 1, - W->template Data() + group_id * W_offset, - col_buffer_data, - 0, - Ydata + group_id * Y_offset, - &CPUMathUtil::Instance()); - } - - if (B != nullptr) { - auto Ymatrix = EigenMatrixMap(Ydata, output_image_size, M); - auto Bvec = ConstEigenVectorMap(B->template Data(), M); - Ymatrix.rowwise() += Bvec.transpose(); - } - - Xdata += X_offset * group_; - Ydata += Y_offset * group_; - } - } - - return Status::OK(); -} +Status Conv::Compute(OpKernelContext* context) const; } // namespace onnxruntime diff --git a/onnxruntime/test/ir/graph_transform_test.cc b/onnxruntime/test/ir/graph_transform_test.cc index be6a3b6b06..ff457defbd 100644 --- a/onnxruntime/test/ir/graph_transform_test.cc +++ b/onnxruntime/test/ir/graph_transform_test.cc @@ -10,6 +10,8 @@ #include "core/graph/conv_bn_fusion.h" #include "core/graph/conv_mul_fusion.h" #include "core/graph/conv_add_fusion.h" +#include "core/graph/conv_activation_fusion.h" +#include "core/platform/env.h" #include "test/capturing_sink.h" #include "test/test_environment.h" @@ -71,6 +73,25 @@ TEST(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) { ASSERT_TRUE(session_object.Initialize().IsOK()); } +TEST(GraphTransformationTests, FuseConvActivation) { + SessionOptions so; + so.session_logid = "GraphTransformationTests.LoadModelToTransform"; + std::string activations[] = {"relu", "sigmoid", "softsign", "tanh", "leakyrelu"}; + + for (std::string act : activations) { + InferenceSession session_object{so, &DefaultLoggingManager()}; + std::string model_uri = MODEL_FOLDER + "fusion/conv_" + act + ".onnx"; + ASSERT_TRUE(session_object.Load(model_uri).IsOK()); + + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK()); + std::unique_ptr ConvActivationFusion_transformer = std::make_unique(); + 
session_object.RegisterGraphTransformer(std::move(ConvActivationFusion_transformer)); + + ASSERT_TRUE(session_object.Initialize().IsOK()); + } +} + TEST(GraphTransformationTests, FuseConvBNNoBias) { string model_uri = MODEL_FOLDER + "fusion/fuse-conv-bn-no-bias.onnx"; diff --git a/onnxruntime/test/testdata/transform/fusion/conv_leakyrelu.onnx b/onnxruntime/test/testdata/transform/fusion/conv_leakyrelu.onnx new file mode 100644 index 0000000000..c7ebe41cca --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/conv_leakyrelu.onnx @@ -0,0 +1,27 @@ +onnx_conv_tanh:³ +* +X +WY"Conv* + kernel_shape@@@@  +" +YZ" LeakyRelu* +alphaحجL>  +test-modelZ +X + + + + +Z +W + + + + +b +Z + + + + +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/fusion/conv_relu.onnx b/onnxruntime/test/testdata/transform/fusion/conv_relu.onnx new file mode 100644 index 0000000000..571f92d1ce --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/conv_relu.onnx @@ -0,0 +1,26 @@ +onnx_conv_relu:‌ +* +X +WY"Conv* + kernel_shape@@@@  + +YZ"Relu +test-modelZ +X + + + + +Z +W + + + + +b +Z + + + + +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/fusion/conv_sigmoid.onnx b/onnxruntime/test/testdata/transform/fusion/conv_sigmoid.onnx new file mode 100644 index 0000000000..4a4ecc1836 --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/conv_sigmoid.onnx @@ -0,0 +1,26 @@ +onnx_conv_sigmoid:  +* +X +WY"Conv* + kernel_shape@@@@  + +YZ"Sigmoid +test-modelZ +X + + + + +Z +W + + + + +b +Z + + + + +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/fusion/conv_softsign.onnx b/onnxruntime/test/testdata/transform/fusion/conv_softsign.onnx new file mode 100644 index 0000000000..bd58595a08 --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/conv_softsign.onnx @@ -0,0 +1,26 @@ +onnx_conv_softsign:، +* +X +WY"Conv* + kernel_shape@@@@  + +YZ"Softsign +test-modelZ +X + + + + +Z +W + + + + +b +Z + + 
+ + +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/fusion/conv_tanh.onnx b/onnxruntime/test/testdata/transform/fusion/conv_tanh.onnx new file mode 100644 index 0000000000..6c0dcf3773 --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/conv_tanh.onnx @@ -0,0 +1,26 @@ +onnx_conv_tanh:‌ +* +X +WY"Conv* + kernel_shape@@@@  + +YZ"Tanh +test-modelZ +X + + + + +Z +W + + + + +b +Z + + + + +B \ No newline at end of file