Gelu fusion - kernel and transformer (#1746)

* Gelu contrib op & transformer

* Gelu kernels for CPU&cuda

* Merged PR 5034: fix a condition for gelu transformer

ONNX models are not guaranteed to assign a unique name to each node, so the previous condition could fail.

(cherry picked from commit e335ef5466444cb0aae45f885ea3a825ed9f1088)

* Fix builds

* remove useless comments

* fix test failure when contrib ops are disabled

* Move implementation under kMSDomain

* fix comments

* fix linux build

* Fix few comments

* fix linux build
This commit is contained in:
pengwa 2019-10-03 19:34:46 +08:00 committed by GitHub
parent b0665262c0
commit 9959e84906
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 377 additions and 64 deletions

View file

@ -25,5 +25,13 @@ ONNX_CPU_OPERATOR_KERNEL(
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
ThresholdedRelu<float>);
} // namespace contrib
// Registers Gelu<float> as the CPU kernel for the contrib "Gelu" op (kMSDomain, since version 1).
// Only tensor(float) is accepted for type parameter "T" on this provider.
// MayInplace(0, 0): presumably allows output 0 to reuse input 0's buffer — confirm against KernelDefBuilder.
ONNX_OPERATOR_KERNEL_EX(
Gelu,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gelu<float>);
} // namespace contrib
} // namespace onnxruntime

View file

@ -6,6 +6,7 @@
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/util/math_cpuonly.h"
#include <unsupported/Eigen/SpecialFunctions>
namespace onnxruntime {
namespace contrib {
@ -28,5 +29,19 @@ class ScaledTanh final : public OpKernel {
const float beta_;
};
// CPU kernel for the contrib Gelu activation.
//
// Computes, element-wise, Y = X * 0.5 * (erf(X / sqrt(2)) + 1), i.e. X scaled by the
// standard normal CDF evaluated at X. The output has the same shape as the input.
//
// Declared `final` (like the sibling activation kernels) and with an `explicit`
// constructor so an OpKernelInfo is never silently converted into a kernel.
template <typename T>
class Gelu final : public OpKernel {
 public:
  explicit Gelu(const OpKernelInfo& info) : OpKernel(info) {}

  // Reads input 0, allocates output 0 with the same shape and writes Gelu(X) into it.
  Status Compute(OpKernelContext* context) const override {
    const auto* X = context->Input<Tensor>(0);
    Tensor* Y = context->Output(0, X->Shape());
    // EIGEN_X_VAR / EIGEN_Y are project macros; presumably they wrap the buffers of the
    // locals named X and Y as Eigen arrays, so those local names must stay as they are.
    EIGEN_X_VAR(xm);
    // M_SQRT1_2 == 1 / sqrt(2); multiplying by it is the division by sqrt(2) in the formula.
    EIGEN_Y = xm * 0.5f * ((xm * static_cast<float>(M_SQRT1_2)).erf() + 1.0f);
    return Status::OK();
  }
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -29,8 +29,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, CDist);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, CDist);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gelu);
// This section includes all opkernel declarations for former experimental ops which have now been removed from onnx.
// This section includes all op kernel declarations for former experimental ops which have now been removed from onnx.
// To maintain backward compatibility these are added as contrib ops.
// Note: the domain for all contrib ops should be MSDomain. However since these ops started out as onnx domain ops
// we cannot change the domain now as this will break backward compatibility.
@ -55,7 +56,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Par
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ScaledTanh);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ThresholdedRelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderInput);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, ReorderOutput);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSNchwcDomain, 1, float, Conv);
@ -104,6 +104,7 @@ void RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, CDist)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, CDist)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gelu)>,
// These ops were experimental ops in onnx domain which have been removed now. We add them here as
// contrib ops to maintain backward compatibility
@ -127,7 +128,8 @@ void RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ParametricSoftplus)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, ThresholdedRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale)>};
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Scale)>,
};
for (auto& function_table_entry : function_table) {
kernel_registry.Register(function_table_entry());

View file

@ -10,10 +10,10 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
#define REGISTER_ACTIVATION_KERNEL(x, ver, T) \
#define REGISTER_ACTIVATION_KERNEL(x, ver, domain, T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
x, \
kOnnxDomain, \
domain, \
ver, \
T, \
kCudaExecutionProvider, \
@ -37,23 +37,23 @@ namespace cuda {
return Status::OK(); \
}
#define UNARY_ACTIVATION_OP_TYPED(name, ver, T) \
REGISTER_ACTIVATION_KERNEL(name, ver, T) \
#define UNARY_ACTIVATION_OP_TYPED(name, ver, domain, T) \
REGISTER_ACTIVATION_KERNEL(name, ver, domain, T) \
UNARY_ACTIVATION_COMPUTE(name, T)
#define UNARY_ACTIVATION_OP_HFD(name, ver) \
UNARY_ACTIVATION_OP_TYPED(name, ver, MLFloat16) \
UNARY_ACTIVATION_OP_TYPED(name, ver, float) \
UNARY_ACTIVATION_OP_TYPED(name, ver, double)
#define UNARY_ACTIVATION_OP_HFD(name, ver, domain) \
UNARY_ACTIVATION_OP_TYPED(name, ver, domain, MLFloat16) \
UNARY_ACTIVATION_OP_TYPED(name, ver, domain, float) \
UNARY_ACTIVATION_OP_TYPED(name, ver, domain, double)
UNARY_ACTIVATION_OP_HFD(Affine, 1);
UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1);
UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1);
UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain);
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, MLFloat16)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, float)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, double)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, float)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, double)
} //namespace cuda
} // namespace contrib

View file

@ -66,6 +66,17 @@ class ScaledTanh final : public UnaryElementwise {
float beta_;
};
// CUDA kernel class for the contrib Gelu op, templated on the tensor element type T.
// ComputeInternal is only declared here; its definition lives outside this class body.
template <typename T>
class Gelu final : public UnaryElementwise {
public:
Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
private:
// Project macro; presumably supplies the empty functor-context factory used by
// activations that carry no attributes (Gelu has none) — confirm against its definition.
MAKE_FUNC_CTX_NULL()
};
} // namespace cuda
} //namespace contrib
} // namespace contrib
} // namespace onnxruntime

View file

@ -36,6 +36,13 @@ struct OP_ScaledTanh : public CtxScaledTanh {
}
};
// Device functor for Gelu: scales the input by the value _Normcdf returns for it.
// Inherits CtxGelu, so it carries no extra per-op state of its own.
template <typename T>
struct OP_Gelu : public CtxGelu {
  __device__ __inline__ T operator()(const T& x) const {
    const T cdf_at_x = _Normcdf(x);
    return x * cdf_at_x;
  }
};
#define UNARY_ACTIVATION_IMPL(name) \
UNARY_ACTIVATION_IMPL_DECLARATION(name) { \
UnaryElementWiseImpl(input_data, \

View file

@ -11,11 +11,13 @@ namespace cuda {
typedef onnxruntime::cuda::CtxAlphaBeta CtxAffine;
typedef onnxruntime::cuda::CtxAlphaBeta CtxParametricSoftplus;
typedef onnxruntime::cuda::CtxAlphaBeta CtxScaledTanh;
typedef onnxruntime::cuda::CtxNull CtxGelu;
#define UNARY_CONTRIB_ACTIVATION_OPS() \
UNARY_ACTIVATION_OP_NAME(ScaledTanh) \
UNARY_ACTIVATION_OP_NAME(Affine) \
UNARY_ACTIVATION_OP_NAME(ParametricSoftplus)
#define UNARY_CONTRIB_ACTIVATION_OPS() \
UNARY_ACTIVATION_OP_NAME(ScaledTanh) \
UNARY_ACTIVATION_OP_NAME(Affine) \
UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \
UNARY_ACTIVATION_OP_NAME(Gelu)
#define UNARY_ACTIVATION_IMPL_DECLARATION(name) \
template <typename T> \

View file

@ -9,6 +9,9 @@ using namespace onnxruntime::common;
namespace onnxruntime {
namespace contrib {
namespace cuda {
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Gelu);
// These ops were experimental ops in onnx domain which have been removed now. We add them here as
// contrib ops to maintain backward compatibility
@ -34,10 +37,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu);
void RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Gelu)>,
// These ops were experimental ops in onnx domain which have been removed now. We add them here as
// contrib ops to maintain backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, Affine)>,
@ -60,7 +65,7 @@ void RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ThresholdedRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ThresholdedRelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu)>
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ThresholdedRelu)>,
};
for (auto& function_table_entry : function_table) {

View file

@ -9,6 +9,7 @@
#include "core/graph/op.h"
#include "onnx/defs/schema.h"
#include "onnx/defs/shape_inference.h"
#include "onnx/defs/function.h"
#include "core/mlas/inc/mlas.h"
#ifdef MICROSOFT_INTERNAL
@ -1723,10 +1724,23 @@ Example 4:
RegisterNchwcSchemas();
}
// Schema for the contrib Gelu op (kMSDomain, since version 1, marked EXPERIMENTAL).
// It takes a single input X and produces a single output Y, both of type parameter T,
// where T is restricted to float16 / float / double tensors. The output's type and shape
// are propagated from the first input.
ONNX_CONTRIB_OPERATOR_SCHEMA(Gelu)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
.SetDoc("Gelu")
.Input(0, "X", "The input data as Tensor.", "T")
.Output(0, "Y", "The output.", "T")
.TypeConstraint(
"T",
{"tensor(float16)", "tensor(float)", "tensor(double)"},
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
#ifdef MICROSOFT_INTERNAL
// register internal ops
RegisterInternalSchemas();
#endif
}
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -0,0 +1,179 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/initializer.h"
#include "core/optimizer/gelu_fusion.h"
#include "core/graph/graph_utils.h"
#include "float.h"
#include <deque>
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {
static bool CheckConstantInput(const Graph& graph, const NodeArg& input_arg, float expected_value) {
auto shape = input_arg.Shape();
auto dim_size = shape->dim_size();
if (dim_size != 0) {
// only check scalar.
return false;
}
const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
if (tensor_proto == nullptr) {
return false;
}
auto init_const = std::make_unique<Initializer>(tensor_proto);
const auto data_type = tensor_proto->data_type();
if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
float* val = init_const->data<float>();
float diff = std::abs(val[0] - static_cast<float>(expected_value));
if (diff > FLT_EPSILON) {
return false;
}
} else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) {
double* val = init_const->data<double>();
double diff = std::abs(val[0] - static_cast<double>(expected_value));
if (diff > DBL_EPSILON) {
return false;
}
} else if (data_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
MLFloat16* val = init_const->data<MLFloat16>();
float diff = std::abs(math::halfToFloat(val[0].val) - static_cast<float>(expected_value));
if (diff > FLT_EPSILON) {
return false;
}
}
return true;
}
// Gelu supports limited data types.
static const std::vector<std::string> supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)"};

// Returns true iff every input of `node` has a known type that is one of the tensor
// types listed in `supported_data_types`.
static bool IsSupportedDataType(const Node& node) {
  for (const auto& input_arg : node.InputDefs()) {
    // Type() is null when the arg's type has not been set; an input of unknown type
    // cannot be shown to be supported, so it is rejected instead of dereferenced.
    const auto* input_type = input_arg->Type();
    if (input_type == nullptr ||
        std::find(supported_data_types.begin(), supported_data_types.end(), *input_type) ==
            supported_data_types.end()) {
      return false;
    }
  }
  return true;
}
// Replaces every occurrence of the Gelu sub-graph
//
//     Mul( Mul(x, 0.5), Add( Erf( Div(x, sqrt(2)) ), 1.0 ) )
//
// with a single Gelu node in kMSDomain that consumes x and produces the final Mul's output.
//
// Matching starts at each Div node and walks forward through Erf, Add and the final Mul,
// then locates the `x * 0.5` Mul among the final Mul's producers. The constants must be
// scalar constant initializers; every node must be on the Div's execution provider and use
// a supported data type; every intermediate node must have exactly one output edge.
//
// Sets `modified` to true when at least one sub-graph was fused. Returns an error only if
// recursing into a node's subgraphs fails.
//
// NOTE(review): only output *edges* are counted, so an intermediate value that is also a
// graph output does not appear to be detected here — confirm whether that needs a guard.
Status GeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
  GraphViewer graph_viewer(graph);
  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
  std::deque<onnxruntime::NodeIndex> removed_nodes;

  for (auto node_index : node_topology_list) {
    auto* p_div = graph.GetNode(node_index);
    if (p_div == nullptr) {
      continue;  // Nothing to match at this index.
    }
    auto& div = *p_div;

    ORT_RETURN_IF_ERROR(Recurse(div, modified, graph_level));

    if (!graph_utils::IsSupportedOptypeVersionAndDomain(div, "Div", {7}) ||
        !graph_utils::IsSupportedProvider(div, GetCompatibleExecutionProviders()) ||
        div.GetOutputEdgesCount() != 1 ||
        !IsSupportedDataType(div)) {
      continue;
    }

    // Check second input is sqrt(2)
    if (!CheckConstantInput(graph, *(div.InputDefs()[1]), static_cast<float>(M_SQRT2))) {
      continue;
    }

    // `x`: the value the whole sub-graph is computed from, and the fused node's only input.
    NodeArg* const gelu_input_arg = div.MutableInputDefs()[0];
    const std::string& gelu_input_name = gelu_input_arg->Name();

    const Node& erf_node = *(div.OutputNodesBegin());
    if (!graph_utils::IsSupportedOptypeVersionAndDomain(erf_node, "Erf", {9}) ||
        erf_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
        erf_node.GetOutputEdgesCount() != 1 ||
        !IsSupportedDataType(erf_node)) {
      continue;
    }

    const Node& add_node = *(erf_node.OutputNodesBegin());
    if (!graph_utils::IsSupportedOptypeVersionAndDomain(add_node, "Add", {7}) ||
        add_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
        add_node.GetOutputEdgesCount() != 1 ||
        !IsSupportedDataType(add_node)) {
      continue;
    }

    // Check the other input (i.e. the one not produced by Erf) is 1.0f.
    // The Erf side is identified by arg name: looking at Add's first producer node cannot
    // tell the slots apart, because an initializer has no producer node and Erf is then the
    // first producer whichever slot it feeds.
    const std::string& erf_output_name = erf_node.OutputDefs()[0]->Name();
    const int add_const_input_index = (add_node.InputDefs()[0]->Name() == erf_output_name) ? 1 : 0;
    const auto& add_const_input_arg = add_node.InputDefs()[add_const_input_index];
    if (!CheckConstantInput(graph, *add_const_input_arg, 1.0f)) {
      continue;
    }

    const Node& mul_node = *(add_node.OutputNodesBegin());
    if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7}) ||
        mul_node.GetExecutionProviderType() != div.GetExecutionProviderType() ||
        !IsSupportedDataType(mul_node)) {
      continue;
    }

    // Find the other producer of the final Mul: it must itself be a Mul (x * 0.5).
    const Node* mul2_node = nullptr;
    for (auto iter = mul_node.InputNodesBegin(); iter != mul_node.InputNodesEnd(); ++iter) {
      if ((*iter).OpType().compare("Mul") == 0) {
        mul2_node = &(*iter);
        break;
      }
    }
    if (mul2_node == nullptr) {
      continue;
    }

    if (!graph_utils::IsSupportedOptypeVersionAndDomain(*mul2_node, "Mul", {7}) ||
        mul2_node->GetExecutionProviderType() != div.GetExecutionProviderType() ||
        mul2_node->GetOutputEdgesCount() != 1 ||
        !IsSupportedDataType(*mul2_node)) {
      continue;
    }

    // One input of this Mul must be the very same `x` that feeds the Div, and the other must
    // be 0.5f. Without the first requirement, y * 0.5 * (1 + erf(x / sqrt(2))) with y != x
    // would be rewritten to Gelu(x), which computes something different.
    int mul_const_input_index = 0;
    if (mul2_node->InputDefs()[0]->Name() == gelu_input_name) {
      mul_const_input_index = 1;
    } else if (mul2_node->InputDefs()[1]->Name() == gelu_input_name) {
      mul_const_input_index = 0;
    } else {
      continue;
    }

    const auto& mul_const_input_arg = mul2_node->InputDefs()[mul_const_input_index];
    if (!CheckConstantInput(graph, *mul_const_input_arg, 0.5f)) {
      continue;
    }

    const std::vector<NodeArg*> gelu_input_defs{gelu_input_arg};
    const std::vector<NodeArg*> gelu_output_defs{const_cast<NodeArg*>(mul_node.OutputDefs()[0])};
    Node& gelu_node = graph.AddNode(graph.GenerateNodeName("Gelu"),
                                    "Gelu",
                                    "fused Gelu subgraphs ",
                                    gelu_input_defs,
                                    gelu_output_defs, {}, kMSDomain);

    // Assign provider to this new node. Provider should be same as the provider for old node.
    gelu_node.SetExecutionProviderType(div.GetExecutionProviderType());

    // push_front so that the removal loop below visits consumers before their producers.
    removed_nodes.push_front(div.Index());
    removed_nodes.push_front(erf_node.Index());
    removed_nodes.push_front(add_node.Index());
    removed_nodes.push_front(mul2_node->Index());
    removed_nodes.push_front(mul_node.Index());
  }

  // Have to remove nodes in reversed order for now to work around the issue in RemoveNode.
  for (onnxruntime::NodeIndex removed_node : removed_nodes) {
    graph.RemoveNode(removed_node);
  }

  if (!removed_nodes.empty()) {
    modified = true;
  }

  return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/graph_transformer.h"
namespace onnxruntime {
/**
@Class GeluFusion

Graph transformer that rewrites a graph by fusing each Gelu activation subgraph into a
single Gelu node.

The subgraph that is matched corresponds to the formula
x * 0.5 * (1.0 + erf(x / sqrt(2.0))), where x is the input.
*/
class GeluFusion : public GraphTransformer {
public:
// compatible_execution_providers is forwarded unchanged to the GraphTransformer base,
// together with the transformer name "GeluFusion"; it defaults to an empty set.
GeluFusion(const std::unordered_set<std::string>& compatible_execution_providers = {}) noexcept
: GraphTransformer("GeluFusion", compatible_execution_providers) {}
// Performs the fusion on `graph`; defined out of line.
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
};
} // namespace onnxruntime

View file

@ -18,6 +18,7 @@
#include "core/optimizer/shape_to_initializer.h"
#include "core/optimizer/nchwc_transformer.h"
#include "core/optimizer/free_dim_override_transformer.h"
#include "core/optimizer/gelu_fusion.h"
#include "core/mlas/inc/mlas.h"
#include "core/session/inference_session.h"
@ -114,6 +115,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
transformers.emplace_back(onnxruntime::make_unique<GemmActivationFusion>(l2_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<MatMulAddFusion>(l2_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<ConvActivationFusion>(l2_execution_providers));
transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(l2_execution_providers));
#endif
} break;

View file

@ -45,5 +45,4 @@ Status Tanh<float>::Compute(OpKernelContext* context) const {
MlasComputeTanh(X->template Data<float>(), Y->template MutableData<float>(), x_shape.Size());
return Status::OK();
}
} // namespace onnxruntime

View file

@ -188,5 +188,4 @@ class ThresholdedRelu final : public OpKernel {
private:
const float alpha_;
};
} // namespace onnxruntime

View file

@ -1,6 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
@ -34,16 +33,16 @@ typedef CtxNull CtxSoftsign;
typedef CtxNull CtxTanh;
typedef CtxAlpha CtxThresholdedRelu;
#define UNARY_ACTIVATION_OPS() \
UNARY_ACTIVATION_OP_NAME(Elu) \
UNARY_ACTIVATION_OP_NAME(HardSigmoid) \
UNARY_ACTIVATION_OP_NAME(LeakyRelu) \
UNARY_ACTIVATION_OP_NAME(Relu) \
UNARY_ACTIVATION_OP_NAME(Selu) \
UNARY_ACTIVATION_OP_NAME(Sigmoid) \
UNARY_ACTIVATION_OP_NAME(Softplus) \
UNARY_ACTIVATION_OP_NAME(Softsign) \
UNARY_ACTIVATION_OP_NAME(Tanh) \
#define UNARY_ACTIVATION_OPS() \
UNARY_ACTIVATION_OP_NAME(Elu) \
UNARY_ACTIVATION_OP_NAME(HardSigmoid) \
UNARY_ACTIVATION_OP_NAME(LeakyRelu) \
UNARY_ACTIVATION_OP_NAME(Relu) \
UNARY_ACTIVATION_OP_NAME(Selu) \
UNARY_ACTIVATION_OP_NAME(Sigmoid) \
UNARY_ACTIVATION_OP_NAME(Softplus) \
UNARY_ACTIVATION_OP_NAME(Softsign) \
UNARY_ACTIVATION_OP_NAME(Tanh) \
UNARY_ACTIVATION_OP_NAME(ThresholdedRelu)
#define UNARY_ACTIVATION_IMPL_DECLARATION(name) \

View file

@ -51,7 +51,7 @@ __global__ void _BinaryElementWiseSimple(
const T* lhs_data,
const T* rhs_data,
T* output_data,
FuncT func,
const FuncT& func,
CUDA_LONG N) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);
output_data[id] = func(lhs_data[IncL ? id : 0], rhs_data[IncR ? id : 0]);

View file

@ -174,6 +174,18 @@ __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
template <typename T>
__device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }
// Standard normal cumulative distribution function, selected by element type.
// Only the float, double and half specializations below are provided.
template <typename T>
__device__ __inline__ T _Normcdf(T x);

// float: forwards to the single-precision CUDA routine.
template <>
__device__ __inline__ float _Normcdf(float x) {
  return normcdff(x);
}

// double: forwards to the double-precision CUDA routine.
template <>
__device__ __inline__ double _Normcdf(double x) {
  return normcdf(x);
}

// half: widened to float, evaluated in single precision, then narrowed back to half.
template <>
__device__ __inline__ half _Normcdf(half x) {
  const float widened = static_cast<float>(x);
  return half(normcdff(widened));
}
// We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer
// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type.

View file

@ -14,8 +14,9 @@ void TestActivationContribOp(const char* szOp, std::vector<float>& input_vals,
std::function<float(float)> expected_func,
const std::unordered_map<std::string, float> attribs = {},
bool is_tensorrt_supported = true,
int opset_version = 7) {
OpTester test(szOp, opset_version);
int opset_version = 7,
const char* domain = kOnnxDomain) {
OpTester test(szOp, opset_version, domain);
for (auto attr : attribs)
test.AddAttribute(attr.first, attr.second);
@ -45,10 +46,11 @@ std::vector<float> input_values = {
TEST(ActivationContribOpTest, ThresholdedRelu_version_1_to_9) {
float alpha = 0.1f;
TestActivationContribOp("ThresholdedRelu",
input_values,
[alpha](float x) { return (x >= alpha) ? x : 0; },
{{"alpha", alpha}}, true, 1);
TestActivationContribOp(
"ThresholdedRelu",
input_values,
[alpha](float x) { return (x >= alpha) ? x : 0; },
{{"alpha", alpha}}, true, 1);
}
TEST(ActivationContribOpTest, ScaledTanh) {
@ -77,5 +79,13 @@ TEST(ActivationContribOpTest, ParametricSoftplus) {
{{"alpha", alpha}, {"beta", beta}});
}
// Checks the contrib Gelu op (kMSDomain, opset 1) against the erf-based reference formula.
// The `false` argument is the helper's is_tensorrt_supported flag.
TEST(ActivationContribOpTest, Gelu) {
  const auto gelu_reference = [](float x) {
    return x * 0.5f * (1.0f + std::erf(x * static_cast<float>(M_SQRT1_2)));
  };
  TestActivationContribOp("Gelu", input_values, gelu_reference, {}, false, 1, kMSDomain);
}
} // namespace test
} // namespace onnxruntime

View file

@ -29,6 +29,7 @@
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/constant_folding.h"
#include "core/optimizer/shape_to_initializer.h"
#include "core/optimizer/gelu_fusion.h"
using namespace std;
using namespace ONNX_NAMESPACE;
@ -537,7 +538,6 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) {
}
TEST(GraphTransformationTests, ReluClipFusion) {
// Clip op schema changed for opset version 11. Until Clip op is updated in ORT hard coding this model to use
// older opset.
Model model("ReluClipFusion", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), {{"", 10}}, {});
@ -606,5 +606,26 @@ TEST(GraphTransformationTests, ReluClipFusion) {
}
}
#ifndef DISABLE_CONTRIB_OPS
// Loads the Gelu reference model, runs GeluFusion as a Level2 transformer, and verifies
// that the Div/Add/Erf/Mul nodes are all gone and exactly one Gelu node remains.
TEST(GraphTransformationTests, GeluFusionTest) {
  string model_path = MODEL_FOLDER + "fusion/gelu.onnx";
  std::shared_ptr<Model> model;
  ASSERT_TRUE(Model::Load(model_path, model).IsOK());
  Graph& main_graph = model->MainGraph();

  onnxruntime::GraphTransformerManager transformer_mgr{5};
  transformer_mgr.Register(std::make_unique<GeluFusion>(), TransformerLevel::Level2);
  auto apply_status = transformer_mgr.ApplyTransformers(main_graph, TransformerLevel::Level2);
  ASSERT_TRUE(apply_status.IsOK());

  std::map<std::string, int> op_counts = CountOpsInGraph(main_graph);
  // Every op type that made up the original subgraph must have been fused away.
  for (const char* fused_away : {"Div", "Add", "Erf", "Mul"}) {
    ASSERT_TRUE(op_counts[fused_away] == 0);
  }
  ASSERT_TRUE(op_counts["Gelu"] == 1);
}
#endif
} // namespace test
} // namespace onnxruntime

View file

@ -12,15 +12,14 @@ void TestUnaryElementwiseOp(const char* szOp, std::vector<float>& input_vals,
std::function<float(float)> expected_func,
const std::unordered_map<std::string, float> attribs = {},
bool is_tensorrt_supported = true,
int opset_version = 7) {
OpTester test(szOp, opset_version);
int opset_version = 7, const char* domain = kOnnxDomain) {
OpTester test(szOp, opset_version, domain);
for (auto attr : attribs)
test.AddAttribute(attr.first, attr.second);
std::vector<int64_t> dims{(int64_t)input_vals.size()};
std::vector<float> expected_vals;
for (const auto& iv : input_vals)
expected_vals.push_back(expected_func(iv));
@ -36,11 +35,11 @@ void TestUnaryElementwiseOp(const char* szOp, std::vector<float>& input_vals,
//Disabled because of accuracy issues for MYRIAD FP16 and VAD_M
#if defined(OPENVINO_CONFIG_MYRIAD) || defined(OPENVINO_CONFIG_VAD_M)
int relu = strcmp(szOp, "Relu");
int leaky = strcmp(szOp, "LeakyRelu");
if(relu == 0 || leaky == 0){
excluded_providers.insert(kOpenVINOExecutionProvider);
}
int relu = strcmp(szOp, "Relu");
int leaky = strcmp(szOp, "LeakyRelu");
if (relu == 0 || leaky == 0) {
excluded_providers.insert(kOpenVINOExecutionProvider);
}
#endif
test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded_providers);
@ -108,10 +107,11 @@ TEST(ActivationOpTest, LeakyRelu) {
TEST(ActivationOpTest, ThresholdedRelu) {
float alpha = 0.1f;
TestUnaryElementwiseOp("ThresholdedRelu",
input_vals,
[alpha](float x) { return (x >= alpha) ? x : 0; },
{{"alpha", alpha}}, true, 10);
TestUnaryElementwiseOp(
"ThresholdedRelu",
input_vals,
[alpha](float x) { return (x >= alpha) ? x : 0; },
{{"alpha", alpha}}, true, 10);
}
TEST(ActivationOpTest, Selu) {
@ -204,9 +204,10 @@ TEST(ActivationOpTest, Softplus) {
}
TEST(ActivationOpTest, Softsign) {
TestUnaryElementwiseOp("Softsign",
no_inf_input_vals,
[](float x) { return x / (1 + std::abs(x)); }, {}, false); // Disable TensorRT because result mismatches
TestUnaryElementwiseOp(
"Softsign",
no_inf_input_vals,
[](float x) { return x / (1 + std::abs(x)); }, {}, false); // Disable TensorRT because result mismatches
}
} // namespace test

Binary file not shown.