mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
Gelu fusion - kernel and transformer (#1746)
* Gelu contrib op & transformer * Gelu kernels for CPU&cuda * Merged PR 5034: fix a condition for gelu transformer The ONNX models doesn't guarantee to assign an unique name to each node, so the previous condition could fail. (cherry picked from commit e335ef5466444cb0aae45f885ea3a825ed9f1088) * Fix builds * remove useless comments * fix test failure when nocontribp * Move impelmentation under KMSdomain * fix comments * fix linux build * Fix few comments * fix linux build
This commit is contained in:
parent
b0665262c0
commit
9959e84906
21 changed files with 377 additions and 64 deletions
|
|
@ -25,5 +25,13 @@ ONNX_CPU_OPERATOR_KERNEL(
|
|||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
ThresholdedRelu<float>);
|
||||
|
||||
} // namespace contrib
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
Gelu,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Gelu<float>);
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include <unsupported/Eigen/SpecialFunctions>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -28,5 +29,19 @@ class ScaledTanh final : public OpKernel {
|
|||
const float beta_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Gelu : public OpKernel {
|
||||
public:
|
||||
Gelu(const OpKernelInfo& info) : OpKernel(info) {}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
const auto* X = context->Input<Tensor>(0);
|
||||
Tensor* Y = context->Output(0, X->Shape());
|
||||
EIGEN_X_VAR(xm);
|
||||
EIGEN_Y = xm * 0.5f * ((xm * static_cast<float>(M_SQRT1_2)).erf() + 1.0f);
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -29,8 +29,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, CDist);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, CDist);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gelu);
|
||||
|
||||
// This section includes all opkernel declarations for former experimental ops which have now been removed from onnx.
|
||||
// This section includes all op kernel declarations for former experimental ops which have now been removed from onnx.
|
||||
// To maintain backward compatibility these are added as contrib ops.
|
||||
// Note: the domain for all contrib ops should be MSDomain. However since these ops started out as onnx domain ops
|
||||
// we cannot change the domain now as this will break backward compatibility.
|
||||
|
|
@ -55,7 +56,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Par
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ScaledTanh);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ThresholdedRelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderInput);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderOutput);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, Conv);
|
||||
|
|
@ -104,6 +104,7 @@ void RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, CDist)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, CDist)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gelu)>,
|
||||
|
||||
// These ops were experimental ops in onnx domain which have been removed now. We add them here as
|
||||
// contrib ops to main backward compatibility
|
||||
|
|
@ -127,7 +128,8 @@ void RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ParametricSoftplus)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ThresholdedRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale)>};
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
kernel_registry.Register(function_table_entry());
|
||||
|
|
|
|||
|
|
@ -10,10 +10,10 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
#define REGISTER_ACTIVATION_KERNEL(x, ver, T) \
|
||||
#define REGISTER_ACTIVATION_KERNEL(x, ver, domain, T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
x, \
|
||||
kOnnxDomain, \
|
||||
domain, \
|
||||
ver, \
|
||||
T, \
|
||||
kCudaExecutionProvider, \
|
||||
|
|
@ -37,23 +37,23 @@ namespace cuda {
|
|||
return Status::OK(); \
|
||||
}
|
||||
|
||||
#define UNARY_ACTIVATION_OP_TYPED(name, ver, T) \
|
||||
REGISTER_ACTIVATION_KERNEL(name, ver, T) \
|
||||
#define UNARY_ACTIVATION_OP_TYPED(name, ver, domain, T) \
|
||||
REGISTER_ACTIVATION_KERNEL(name, ver, domain, T) \
|
||||
UNARY_ACTIVATION_COMPUTE(name, T)
|
||||
|
||||
#define UNARY_ACTIVATION_OP_HFD(name, ver) \
|
||||
UNARY_ACTIVATION_OP_TYPED(name, ver, MLFloat16) \
|
||||
UNARY_ACTIVATION_OP_TYPED(name, ver, float) \
|
||||
UNARY_ACTIVATION_OP_TYPED(name, ver, double)
|
||||
#define UNARY_ACTIVATION_OP_HFD(name, ver, domain) \
|
||||
UNARY_ACTIVATION_OP_TYPED(name, ver, domain, MLFloat16) \
|
||||
UNARY_ACTIVATION_OP_TYPED(name, ver, domain, float) \
|
||||
UNARY_ACTIVATION_OP_TYPED(name, ver, domain, double)
|
||||
|
||||
UNARY_ACTIVATION_OP_HFD(Affine, 1);
|
||||
UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1);
|
||||
UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1);
|
||||
UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain);
|
||||
UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain);
|
||||
UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain);
|
||||
UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain);
|
||||
|
||||
|
||||
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, MLFloat16)
|
||||
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, float)
|
||||
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, double)
|
||||
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16)
|
||||
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, float)
|
||||
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, double)
|
||||
|
||||
} //namespace cuda
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -66,6 +66,17 @@ class ScaledTanh final : public UnaryElementwise {
|
|||
float beta_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class Gelu final : public UnaryElementwise {
|
||||
public:
|
||||
Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
MAKE_FUNC_CTX_NULL()
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} //namespace contrib
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -36,6 +36,13 @@ struct OP_ScaledTanh : public CtxScaledTanh {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct OP_Gelu : public CtxGelu {
|
||||
__device__ __inline__ T operator()(const T& a) const {
|
||||
return a * _Normcdf(a);
|
||||
}
|
||||
};
|
||||
|
||||
#define UNARY_ACTIVATION_IMPL(name) \
|
||||
UNARY_ACTIVATION_IMPL_DECLARATION(name) { \
|
||||
UnaryElementWiseImpl(input_data, \
|
||||
|
|
|
|||
|
|
@ -11,11 +11,13 @@ namespace cuda {
|
|||
typedef onnxruntime::cuda::CtxAlphaBeta CtxAffine;
|
||||
typedef onnxruntime::cuda::CtxAlphaBeta CtxParametricSoftplus;
|
||||
typedef onnxruntime::cuda::CtxAlphaBeta CtxScaledTanh;
|
||||
typedef onnxruntime::cuda::CtxNull CtxGelu;
|
||||
|
||||
#define UNARY_CONTRIB_ACTIVATION_OPS() \
|
||||
UNARY_ACTIVATION_OP_NAME(ScaledTanh) \
|
||||
UNARY_ACTIVATION_OP_NAME(Affine) \
|
||||
UNARY_ACTIVATION_OP_NAME(ParametricSoftplus)
|
||||
#define UNARY_CONTRIB_ACTIVATION_OPS() \
|
||||
UNARY_ACTIVATION_OP_NAME(ScaledTanh) \
|
||||
UNARY_ACTIVATION_OP_NAME(Affine) \
|
||||
UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \
|
||||
UNARY_ACTIVATION_OP_NAME(Gelu)
|
||||
|
||||
#define UNARY_ACTIVATION_IMPL_DECLARATION(name) \
|
||||
template <typename T> \
|
||||
|
|
|
|||
|
|
@ -9,6 +9,9 @@ using namespace onnxruntime::common;
|
|||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Gelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Gelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Gelu);
|
||||
|
||||
// These ops were experimental ops in onnx domain which have been removed now. We add them here as
|
||||
// contrib ops to maintain backward compatibility
|
||||
|
|
@ -34,10 +37,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu);
|
||||
|
||||
|
||||
void RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Gelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Gelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Gelu)>,
|
||||
|
||||
// These ops were experimental ops in onnx domain which have been removed now. We add them here as
|
||||
// contrib ops to maintain backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Affine)>,
|
||||
|
|
@ -60,7 +65,7 @@ void RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu)>
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu)>,
|
||||
};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
#include "core/graph/op.h"
|
||||
#include "onnx/defs/schema.h"
|
||||
#include "onnx/defs/shape_inference.h"
|
||||
#include "onnx/defs/function.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
|
||||
#ifdef MICROSOFT_INTERNAL
|
||||
|
|
@ -1723,10 +1724,23 @@ Example 4:
|
|||
RegisterNchwcSchemas();
|
||||
}
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(Gelu)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
|
||||
.SetDoc("Gelu")
|
||||
.Input(0, "X", "The input data as Tensor.", "T")
|
||||
.Output(0, "Y", "The output.", "T")
|
||||
.TypeConstraint(
|
||||
"T",
|
||||
{"tensor(float16)", "tensor(float)", "tensor(double)"},
|
||||
"Constrain input and output types to float tensors.")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
|
||||
|
||||
#ifdef MICROSOFT_INTERNAL
|
||||
// register internal ops
|
||||
RegisterInternalSchemas();
|
||||
#endif
|
||||
}
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
179
onnxruntime/core/optimizer/gelu_fusion.cc
Normal file
179
onnxruntime/core/optimizer/gelu_fusion.cc
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/gelu_fusion.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "float.h"
|
||||
#include <deque>
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace ::onnxruntime::common;
|
||||
namespace onnxruntime {
|
||||
|
||||
static bool CheckConstantInput(const Graph& graph, const NodeArg& input_arg, float expected_value) {
|
||||
auto shape = input_arg.Shape();
|
||||
auto dim_size = shape->dim_size();
|
||||
if (dim_size != 0) {
|
||||
// only check scalar.
|
||||
return false;
|
||||
}
|
||||
|
||||
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
|
||||
if (tensor_proto == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto init_const = std::make_unique<Initializer>(tensor_proto);
|
||||
const auto data_type = tensor_proto->data_type();
|
||||
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
|
||||
float* val = init_const->data<float>();
|
||||
float diff = std::abs(val[0] - static_cast<float>(expected_value));
|
||||
if (diff > FLT_EPSILON) {
|
||||
return false;
|
||||
}
|
||||
} else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) {
|
||||
double* val = init_const->data<double>();
|
||||
double diff = std::abs(val[0] - static_cast<double>(expected_value));
|
||||
if (diff > DBL_EPSILON) {
|
||||
return false;
|
||||
}
|
||||
} else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
|
||||
MLFloat16* val = init_const->data<MLFloat16>();
|
||||
float diff = std::abs(math::halfToFloat(val[0].val) - static_cast<float>(expected_value));
|
||||
if (diff > FLT_EPSILON) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Gelu supports limited data types.
|
||||
static std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)"};
|
||||
|
||||
static bool IsSupportedDataType(const Node& node) {
|
||||
for (const auto& input_arg : node.InputDefs()) {
|
||||
if (std::find(supported_data_types.begin(), supported_data_types.end(),
|
||||
*(input_arg->Type())) == supported_data_types.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
std::deque<onnxruntime::NodeIndex> removed_nodes;
|
||||
|
||||
for (auto node_index : node_topology_list) {
|
||||
auto& div = *graph.GetNode(node_index);
|
||||
ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7}) ||
|
||||
!graph_utils::IsSupportedProvider(div, GetCompatibleExecutionProviders()) ||
|
||||
div.GetOutputEdgesCount() != 1 ||
|
||||
!IsSupportedDataType(div)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check second input is sqrt(2)
|
||||
if (!CheckConstantInput(graph, *(div.MutableInputDefs()[1]), static_cast<float>(M_SQRT2))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const Node& erf_node = *(div.OutputNodesBegin());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(erf_node, "Erf", {9}) ||
|
||||
erf_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
erf_node.GetOutputEdgesCount() != 1 ||
|
||||
!IsSupportedDataType(erf_node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const Node& add_node = *(erf_node.OutputNodesBegin());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) ||
|
||||
add_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
add_node.GetOutputEdgesCount() != 1 ||
|
||||
!IsSupportedDataType(add_node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check the other input node(e.g. not of type Erf) is 1.0f.
|
||||
const Node& add_first_input_node = *(add_node.InputNodesBegin());
|
||||
int add_const_input_index = 0;
|
||||
if (add_first_input_node.OpType().compare("Erf") == 0) {
|
||||
add_const_input_index = 1;
|
||||
}
|
||||
const auto& add_const_input_arg = add_node.InputDefs()[add_const_input_index];
|
||||
if (!CheckConstantInput(graph, *add_const_input_arg, 1.0f)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const Node& mul_node = *(add_node.OutputNodesBegin());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) ||
|
||||
mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
!IsSupportedDataType(mul_node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const Node* mul2_node = nullptr;
|
||||
for (auto iter = mul_node.InputNodesBegin(); iter != mul_node.InputNodesEnd(); ++iter) {
|
||||
if ((*iter).OpType().compare("Mul") == 0) {
|
||||
// find the other input node of Mul
|
||||
mul2_node = &(*iter);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (mul2_node == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(*mul2_node, "Mul", {7}) ||
|
||||
mul2_node->GetExecutionProviderType() != div.GetExecutionProviderType() ||
|
||||
mul2_node->GetOutputEdgesCount() != 1 ||
|
||||
!IsSupportedDataType(*mul2_node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check the other input node(e.g. not of type Add) is 0.5f.
|
||||
int mul_const_input_index = 0;
|
||||
if (mul2_node->InputDefs()[0]->Name() == div.MutableInputDefs()[0]->Name()) {
|
||||
mul_const_input_index = 1;
|
||||
}
|
||||
|
||||
const auto& mul_const_input_arg = mul2_node->InputDefs()[mul_const_input_index];
|
||||
if (!CheckConstantInput(graph, *mul_const_input_arg, 0.5f)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::vector<NodeArg*> gelu_input_defs{div.MutableInputDefs()[0]};
|
||||
const std::vector<NodeArg*> gelu_output_defs{const_cast<NodeArg*>(mul_node.OutputDefs()[0])};
|
||||
Node& gelu_node = graph.AddNode(graph.GenerateNodeName("Gelu"),
|
||||
"Gelu",
|
||||
"fused Gelu subgraphs ",
|
||||
gelu_input_defs,
|
||||
gelu_output_defs, {}, kMSDomain);
|
||||
|
||||
// Assign provider to this new node. Provider should be same as the provider for old node.
|
||||
gelu_node.SetExecutionProviderType(div.GetExecutionProviderType());
|
||||
|
||||
removed_nodes.push_front(div.Index());
|
||||
removed_nodes.push_front(erf_node.Index());
|
||||
removed_nodes.push_front(add_node.Index());
|
||||
removed_nodes.push_front(mul2_node->Index());
|
||||
removed_nodes.push_front(mul_node.Index());
|
||||
}
|
||||
|
||||
// Have to remove node in reversed order for now to walk around the issue in RemoveNode
|
||||
for (onnxruntime::NodeIndex removed_node : removed_nodes) {
|
||||
graph.RemoveNode(removed_node);
|
||||
}
|
||||
|
||||
if (!removed_nodes.empty()) {
|
||||
modified = true;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace onnxruntime
|
||||
27
onnxruntime/core/optimizer/gelu_fusion.h
Normal file
27
onnxruntime/core/optimizer/gelu_fusion.h
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
@Class GeluFusion
|
||||
|
||||
Rewrite graph fusing Gelu activation subgraph to a single Gelu node.
|
||||
|
||||
The formula corresponding to Gelu activation subgraph:
|
||||
x * 0.5 * (1.0 + erf(x / sqrt(2.0))), where x is the input.
|
||||
|
||||
*/
|
||||
class GeluFusion : public GraphTransformer {
|
||||
public:
|
||||
GeluFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("GeluFusion", compatible_execution_providers) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -18,6 +18,7 @@
|
|||
#include "core/optimizer/shape_to_initializer.h"
|
||||
#include "core/optimizer/nchwc_transformer.h"
|
||||
#include "core/optimizer/free_dim_override_transformer.h"
|
||||
#include "core/optimizer/gelu_fusion.h"
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/session/inference_session.h"
|
||||
|
||||
|
|
@ -114,6 +115,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
|
|||
transformers.emplace_back(onnxruntime::make_unique<GemmActivationFusion>(l2_execution_providers));
|
||||
transformers.emplace_back(onnxruntime::make_unique<MatMulAddFusion>(l2_execution_providers));
|
||||
transformers.emplace_back(onnxruntime::make_unique<ConvActivationFusion>(l2_execution_providers));
|
||||
transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(l2_execution_providers));
|
||||
#endif
|
||||
} break;
|
||||
|
||||
|
|
|
|||
|
|
@ -45,5 +45,4 @@ Status Tanh<float>::Compute(OpKernelContext* context) const {
|
|||
MlasComputeTanh(X->template Data<float>(), Y->template MutableData<float>(), x_shape.Size());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -188,5 +188,4 @@ class ThresholdedRelu final : public OpKernel {
|
|||
private:
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -34,16 +33,16 @@ typedef CtxNull CtxSoftsign;
|
|||
typedef CtxNull CtxTanh;
|
||||
typedef CtxAlpha CtxThresholdedRelu;
|
||||
|
||||
#define UNARY_ACTIVATION_OPS() \
|
||||
UNARY_ACTIVATION_OP_NAME(Elu) \
|
||||
UNARY_ACTIVATION_OP_NAME(HardSigmoid) \
|
||||
UNARY_ACTIVATION_OP_NAME(LeakyRelu) \
|
||||
UNARY_ACTIVATION_OP_NAME(Relu) \
|
||||
UNARY_ACTIVATION_OP_NAME(Selu) \
|
||||
UNARY_ACTIVATION_OP_NAME(Sigmoid) \
|
||||
UNARY_ACTIVATION_OP_NAME(Softplus) \
|
||||
UNARY_ACTIVATION_OP_NAME(Softsign) \
|
||||
UNARY_ACTIVATION_OP_NAME(Tanh) \
|
||||
#define UNARY_ACTIVATION_OPS() \
|
||||
UNARY_ACTIVATION_OP_NAME(Elu) \
|
||||
UNARY_ACTIVATION_OP_NAME(HardSigmoid) \
|
||||
UNARY_ACTIVATION_OP_NAME(LeakyRelu) \
|
||||
UNARY_ACTIVATION_OP_NAME(Relu) \
|
||||
UNARY_ACTIVATION_OP_NAME(Selu) \
|
||||
UNARY_ACTIVATION_OP_NAME(Sigmoid) \
|
||||
UNARY_ACTIVATION_OP_NAME(Softplus) \
|
||||
UNARY_ACTIVATION_OP_NAME(Softsign) \
|
||||
UNARY_ACTIVATION_OP_NAME(Tanh) \
|
||||
UNARY_ACTIVATION_OP_NAME(ThresholdedRelu)
|
||||
|
||||
#define UNARY_ACTIVATION_IMPL_DECLARATION(name) \
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ __global__ void _BinaryElementWiseSimple(
|
|||
const T* lhs_data,
|
||||
const T* rhs_data,
|
||||
T* output_data,
|
||||
FuncT func,
|
||||
const FuncT& func,
|
||||
CUDA_LONG N) {
|
||||
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
|
||||
output_data[id] = func(lhs_data[IncL ? id : 0], rhs_data[IncR ? id : 0]);
|
||||
|
|
|
|||
|
|
@ -174,6 +174,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
|
|||
template <typename T>
|
||||
__device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }
|
||||
|
||||
template <typename T>
|
||||
__device__ __inline__ T _Normcdf(T a);
|
||||
|
||||
template <>
|
||||
__device__ __inline__ float _Normcdf(float a) { return normcdff(a); }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ double _Normcdf(double a) { return normcdf(a); }
|
||||
|
||||
template <>
|
||||
__device__ __inline__ half _Normcdf(half a) { return half(normcdff((float)a)); }
|
||||
|
||||
// We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer
|
||||
// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type.
|
||||
|
||||
|
|
|
|||
|
|
@ -14,8 +14,9 @@ void TestActivationContribOp(const char* szOp, std::vector<float>& input_vals,
|
|||
std::function<float(float)> expected_func,
|
||||
const std::unordered_map<std::string, float> attribs = {},
|
||||
bool is_tensorrt_supported = true,
|
||||
int opset_version = 7) {
|
||||
OpTester test(szOp, opset_version);
|
||||
int opset_version = 7,
|
||||
const char* domain = kOnnxDomain) {
|
||||
OpTester test(szOp, opset_version, domain);
|
||||
|
||||
for (auto attr : attribs)
|
||||
test.AddAttribute(attr.first, attr.second);
|
||||
|
|
@ -45,10 +46,11 @@ std::vector<float> input_values = {
|
|||
|
||||
TEST(ActivationContribOpTest, ThresholdedRelu_version_1_to_9) {
|
||||
float alpha = 0.1f;
|
||||
TestActivationContribOp("ThresholdedRelu",
|
||||
input_values,
|
||||
[alpha](float x) { return (x >= alpha) ? x : 0; },
|
||||
{{"alpha", alpha}}, true, 1);
|
||||
TestActivationContribOp(
|
||||
"ThresholdedRelu",
|
||||
input_values,
|
||||
[alpha](float x) { return (x >= alpha) ? x : 0; },
|
||||
{{"alpha", alpha}}, true, 1);
|
||||
}
|
||||
|
||||
TEST(ActivationContribOpTest, ScaledTanh) {
|
||||
|
|
@ -77,5 +79,13 @@ TEST(ActivationContribOpTest, ParametricSoftplus) {
|
|||
{{"alpha", alpha}, {"beta", beta}});
|
||||
}
|
||||
|
||||
TEST(ActivationContribOpTest, Gelu) {
|
||||
TestActivationContribOp(
|
||||
"Gelu",
|
||||
input_values,
|
||||
[](float x) { return x * 0.5f * (1.0f + std::erf(x * static_cast<float>(M_SQRT1_2))); },
|
||||
{}, false, 1, kMSDomain);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@
|
|||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "core/optimizer/constant_folding.h"
|
||||
#include "core/optimizer/shape_to_initializer.h"
|
||||
#include "core/optimizer/gelu_fusion.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
|
@ -537,7 +538,6 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) {
|
|||
}
|
||||
|
||||
TEST(GraphTransformationTests, ReluClipFusion) {
|
||||
|
||||
// Clip op schema changed for opset version 11. Until Clip op is updated in ORT hard coding this model to use
|
||||
// older opset.
|
||||
Model model("ReluClipFusion", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), {{"", 10}}, {});
|
||||
|
|
@ -606,5 +606,26 @@ TEST(GraphTransformationTests, ReluClipFusion) {
|
|||
}
|
||||
}
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
TEST(GraphTransformationTests, GeluFusionTest) {
|
||||
string model_uri = MODEL_FOLDER + "fusion/gelu.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_TRUE(Model::Load(model_uri, p_model).IsOK());
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2);
|
||||
auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2);
|
||||
ASSERT_TRUE(ret.IsOK());
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Div"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Erf"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Mul"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Gelu"] == 1);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -12,15 +12,14 @@ void TestUnaryElementwiseOp(const char* szOp, std::vector<float>& input_vals,
|
|||
std::function<float(float)> expected_func,
|
||||
const std::unordered_map<std::string, float> attribs = {},
|
||||
bool is_tensorrt_supported = true,
|
||||
int opset_version = 7) {
|
||||
OpTester test(szOp, opset_version);
|
||||
int opset_version = 7, const char* domain = kOnnxDomain) {
|
||||
OpTester test(szOp, opset_version, domain);
|
||||
|
||||
for (auto attr : attribs)
|
||||
test.AddAttribute(attr.first, attr.second);
|
||||
|
||||
std::vector<int64_t> dims{(int64_t)input_vals.size()};
|
||||
|
||||
|
||||
std::vector<float> expected_vals;
|
||||
for (const auto& iv : input_vals)
|
||||
expected_vals.push_back(expected_func(iv));
|
||||
|
|
@ -36,11 +35,11 @@ void TestUnaryElementwiseOp(const char* szOp, std::vector<float>& input_vals,
|
|||
|
||||
//Disabled because of accuracy issues for MYRIAD FP16 and VAD_M
|
||||
#if defined(OPENVINO_CONFIG_MYRIAD) || defined(OPENVINO_CONFIG_VAD_M)
|
||||
int relu = strcmp(szOp, "Relu");
|
||||
int leaky = strcmp(szOp, "LeakyRelu");
|
||||
if(relu == 0 || leaky == 0){
|
||||
excluded_providers.insert(kOpenVINOExecutionProvider);
|
||||
}
|
||||
int relu = strcmp(szOp, "Relu");
|
||||
int leaky = strcmp(szOp, "LeakyRelu");
|
||||
if (relu == 0 || leaky == 0) {
|
||||
excluded_providers.insert(kOpenVINOExecutionProvider);
|
||||
}
|
||||
#endif
|
||||
|
||||
test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded_providers);
|
||||
|
|
@ -108,10 +107,11 @@ TEST(ActivationOpTest, LeakyRelu) {
|
|||
|
||||
TEST(ActivationOpTest, ThresholdedRelu) {
|
||||
float alpha = 0.1f;
|
||||
TestUnaryElementwiseOp("ThresholdedRelu",
|
||||
input_vals,
|
||||
[alpha](float x) { return (x >= alpha) ? x : 0; },
|
||||
{{"alpha", alpha}}, true, 10);
|
||||
TestUnaryElementwiseOp(
|
||||
"ThresholdedRelu",
|
||||
input_vals,
|
||||
[alpha](float x) { return (x >= alpha) ? x : 0; },
|
||||
{{"alpha", alpha}}, true, 10);
|
||||
}
|
||||
|
||||
TEST(ActivationOpTest, Selu) {
|
||||
|
|
@ -204,9 +204,10 @@ TEST(ActivationOpTest, Softplus) {
|
|||
}
|
||||
|
||||
TEST(ActivationOpTest, Softsign) {
|
||||
TestUnaryElementwiseOp("Softsign",
|
||||
no_inf_input_vals,
|
||||
[](float x) { return x / (1 + std::abs(x)); }, {}, false); // Disable TensorRT because result mismatches
|
||||
TestUnaryElementwiseOp(
|
||||
"Softsign",
|
||||
no_inf_input_vals,
|
||||
[](float x) { return x / (1 + std::abs(x)); }, {}, false); // Disable TensorRT because result mismatches
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/gelu.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/gelu.onnx
vendored
Normal file
Binary file not shown.
Loading…
Reference in a new issue