diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 9c19ac4f7e..a76178f16f 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -66,6 +66,7 @@ Do not modify directly.*
* com.microsoft.QuantizeBFP
* com.microsoft.QuantizeLinear
* com.microsoft.QuantizeWithOrder
+ * com.microsoft.QuickGelu
* com.microsoft.Range
* com.microsoft.ReduceSumInteger
* com.microsoft.Rfft
@@ -3476,6 +3477,43 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.QuickGelu**
+
+ Compute x * Sigmoid(alpha * x).
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- alpha : float
+- Alpha value.
+
+
+#### Inputs
+
+
+- X : T
+- The input data as Tensor.
+
+
+#### Outputs
+
+
+- Y : T
+- The output.
+
+
+#### Type Constraints
+
+
+- T : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)
+- Constrain input and output types to float tensors.
+
+
+
### **com.microsoft.Range**
Creates a sequence of numbers that begins at `start` and extends by increments of `delta`
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 1f0cac6835..a72bc8b2d3 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -434,6 +434,7 @@ Do not modify directly.*
|QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSoftmax|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* x_zero_point:**T**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)|
+|QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|SampleOp|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float)|
@@ -791,6 +792,7 @@ Do not modify directly.*
|QOrderedMatMul|*in* A:**Q**
*in* scale_A:**S**
*in* B:**Q**
*in* scale_B:**S**
*in* scale_Y:**S**
*in* bias:**S**
*in* C:**Q**
*in* scale_C:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)|
|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float16)
**T2** = tensor(int8), tensor(uint8)|
|QuantizeWithOrder|*in* input:**F**
*in* scale_input:**S**
*out* output:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)|
+|QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)|
|TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/activations.cc b/onnxruntime/contrib_ops/cpu/activations.cc
index 6a9fbe52ef..556699192d 100644
--- a/onnxruntime/contrib_ops/cpu/activations.cc
+++ b/onnxruntime/contrib_ops/cpu/activations.cc
@@ -34,5 +34,13 @@ ONNX_OPERATOR_KERNEL_EX(
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()),
Gelu);
+// Register the QuickGelu kernel class with the CPU execution provider under kMSDomain, version 1.
+// The kernel definition constrains the type parameter "T" of the operator.
+ONNX_OPERATOR_KERNEL_EX(
+ QuickGelu,
+ kMSDomain,
+ 1,
+ kCpuExecutionProvider,
+ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()),
+ QuickGelu);
+
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h
index 82a60a375f..4a4f76c2ef 100644
--- a/onnxruntime/contrib_ops/cpu/activations.h
+++ b/onnxruntime/contrib_ops/cpu/activations.h
@@ -93,5 +93,46 @@ class Gelu : public OpKernel {
}
};
+// QuickGelu computes Y = X * Sigmoid(alpha * X) element-wise.
+//
+// Implement a new one instead of inheriting from ElementWiseRangedTransform so that we can call
+// MlasComputeLogistic instead of using Eigen for better perf.
+template <typename T>
+class QuickGelu : public OpKernel {
+ public:
+  // Reads the optional "alpha" attribute; 1.702f is used when the attribute is absent.
+  explicit QuickGelu(const OpKernelInfo& info) : OpKernel(info) {
+    alpha_ = info.GetAttrOrDefault<float>("alpha", 1.702f);
+  }
+
+  // Allocates output 0 with the shape of input 0 and fills it with x * Sigmoid(alpha * x).
+  // The flat element range is split into tasks of `length_per_task` elements that are handed to the
+  // operator thread pool; every task touches a disjoint slice of the input/output buffers.
+  Status Compute(OpKernelContext* context) const override {
+    const Tensor* input = context->Input<Tensor>(0);
+    const T* input_data = input->template Data<T>();
+    Tensor* output = context->Output(0, input->Shape());
+    T* output_data = output->template MutableData<T>();
+    concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
+    const int64_t elem_count = input->Shape().Size();
+    constexpr int64_t length_per_task = 4096;  // this number comes from FastGelu.
+    // Round up so that a trailing partial slice still gets its own task.
+    const int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
+    concurrency::ThreadPool::TryBatchParallelFor(
+        tp, static_cast<std::ptrdiff_t>(task_count),
+        [&](ptrdiff_t task_idx) {
+          const int64_t start = static_cast<int64_t>(task_idx) * length_per_task;
+          const T* p_input = input_data + start;
+          T* p_output = output_data + start;
+          // The last task may cover fewer than length_per_task elements.
+          const int64_t count = std::min(length_per_task, elem_count - start);
+
+          if (alpha_ != 1.0f) {
+            // Stage alpha * x in the output buffer, then apply the logistic function in place.
+            for (int64_t i = 0; i < count; i++) {
+              p_output[i] = p_input[i] * alpha_;
+            }
+            MlasComputeLogistic(p_output, p_output, static_cast<size_t>(count));
+          } else {
+            // Silu (alpha == 1): x * 1.0f == x, so skip the scaling pass and feed the input directly.
+            MlasComputeLogistic(p_input, p_output, static_cast<size_t>(count));
+          }
+
+          // p_output now holds Sigmoid(alpha * x); multiply by x to finish.
+          for (int64_t i = 0; i < count; i++) {
+            p_output[i] = p_input[i] * p_output[i];
+          }
+        },
+        0);
+    return Status::OK();
+  }
+
+ private:
+  float alpha_;  // scale applied to x before the sigmoid
+};
+
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index d4273aeb96..f10fa6233c 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -41,6 +41,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BiasG
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGramRepeatBlock);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu);
// ******** Start: Quantization ******************* //
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16);
@@ -219,6 +220,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
// These ops were experimental ops in onnx domain which have been removed now. We add them here as
// contrib ops to main backward compatibility
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.cc b/onnxruntime/contrib_ops/cuda/activation/activations.cc
index 113104287c..8eecde2b95 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations.cc
+++ b/onnxruntime/contrib_ops/cuda/activation/activations.cc
@@ -50,6 +50,7 @@ UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain);
+UNARY_ACTIVATION_OP_HFD(QuickGelu, 1, kMSDomain);
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, float)
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.h b/onnxruntime/contrib_ops/cuda/activation/activations.h
index 763e6cb922..ab339f276c 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations.h
+++ b/onnxruntime/contrib_ops/cuda/activation/activations.h
@@ -77,6 +77,20 @@ class Gelu final : public UnaryElementwise {
MAKE_FUNC_CTX_NULL()
};
+// CUDA kernel class for QuickGelu, i.e. x * Sigmoid(alpha * x) applied element-wise.
+template
+class QuickGelu final : public UnaryElementwise {
+ public:
+ // Reads the optional "alpha" attribute; 1.702f is used when the attribute is absent.
+ QuickGelu(const OpKernelInfo& info) : UnaryElementwise(info) {
+ alpha_ = info.GetAttrOrDefault("alpha", 1.702f);
+ }
+
+ // Declared here only; the definition lives outside this header.
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+ // NOTE(review): presumably generates the helper that packs alpha_ into the CtxAlpha functor
+ // context consumed by the device implementation - confirm against the macro definition.
+ MAKE_FUNC_CTX_ALPHA()
+ float alpha_;
+};
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu
index 65a1dc992a..0c856815fd 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu
@@ -50,6 +50,17 @@ struct OP_Gelu : public CtxGelu {
}
};
+// Device functor for QuickGelu: returns a * Sigmoid(alpha * a), with `alpha` taken from the
+// CtxQuickGelu base.
+template <typename T>
+struct OP_QuickGelu : public CtxQuickGelu {
+  __device__ __inline__ T operator()(const T& a) const {
+    T v = a * static_cast<T>(alpha);
+    T one = static_cast<T>(1.f);
+    T zero = static_cast<T>(0.f);
+    // e = exp(-|v|) lies in [0, 1], so the exponential can never overflow.
+    T e = v >= zero ? _Exp(-v) : _Exp(v);
+    // For v < 0 use e / (1 + e) rather than 1 - 1 / (1 + e): the subtraction cancels catastrophically
+    // once e drops below the type's epsilon and would flush small sigmoid values to exactly zero.
+    T sigmoid = v >= zero ? one / (one + e) : e / (one + e);
+    return a * sigmoid;
+  }
+};
+
#define UNARY_ACTIVATION_IMPL(name) \
UNARY_ACTIVATION_IMPL_DECLARATION(name) { \
UnaryElementWiseImpl(stream, \
diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h
index 56ece01e46..5d18283a39 100644
--- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h
+++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h
@@ -12,21 +12,14 @@ typedef onnxruntime::cuda::CtxAlphaBeta CtxAffine;
typedef onnxruntime::cuda::CtxAlphaBeta CtxParametricSoftplus;
typedef onnxruntime::cuda::CtxAlphaBeta CtxScaledTanh;
typedef onnxruntime::cuda::CtxNull CtxGelu;
+typedef onnxruntime::cuda::CtxAlpha CtxQuickGelu;
#define UNARY_CONTRIB_ACTIVATION_OPS() \
UNARY_ACTIVATION_OP_NAME(ScaledTanh) \
UNARY_ACTIVATION_OP_NAME(Affine) \
UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \
- UNARY_ACTIVATION_OP_NAME(Gelu)
-
-#define UNARY_ACTIVATION_IMPL_DECLARATION(name) \
- template \
- void Impl_##name( \
- cudaStream_t stream, \
- const T* input_data, \
- T* output_data, \
- const Ctx##name* func_ctx, \
- size_t count)
+ UNARY_ACTIVATION_OP_NAME(Gelu) \
+ UNARY_ACTIVATION_OP_NAME(QuickGelu)
#define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name);
UNARY_CONTRIB_ACTIVATION_OPS()
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index f63490f5e7..603a6ec8bb 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -19,6 +19,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility
@@ -130,6 +133,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo, // backward compatibility
BuildKernelCreateInfo, // backward compatibility
BuildKernelCreateInfo, // backward compatibility
diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
index d997213a0a..b92efc3a61 100644
--- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc
@@ -19,6 +19,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGelu);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility
@@ -124,6 +127,9 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo, // backward compatibility
BuildKernelCreateInfo, // backward compatibility
BuildKernelCreateInfo, // backward compatibility
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index bb1cf49350..0cbf041238 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -558,6 +558,36 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BiasGelu, 1,
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
+// Doc string attached to the QuickGelu schema below.
+constexpr const char* QuickGelu_ver1_doc = R"DOC(Compute x * Sigmoid(alpha * x).)DOC";
+// Schema of QuickGelu (kMSDomain, since version 1): one input X and one output Y sharing type T, plus
+// a float attribute "alpha" whose default is 1.702. Type and shape of Y are propagated from X.
+ONNX_MS_OPERATOR_SET_SCHEMA(
+ QuickGelu, 1,
+ OpSchema()
+ .SetDomain(kMSDomain)
+ .SinceVersion(1)
+ .SetDoc(QuickGelu_ver1_doc)
+ .Attr("alpha", "Alpha value.", AttributeProto::FLOAT, 1.702f)
+ .Input(0, "X", "The input data as Tensor.", "T")
+ .Output(0, "Y", "The output.", "T")
+ .TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
+ "Constrain input and output types to float tensors.")
+ .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
+ // Function body: expands QuickGelu into Mul -> Sigmoid -> Mul using the default ("") domain at
+ // opset 13. It is context dependent because the Alpha constant is built with the element type
+ // of input 0.
+ .SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx, const OpSchema& schema,
+ FunctionProto& functionProto) {
+ // Without a known tensor element type for input 0 no body can be built.
+ auto* tp = ctx.getInputType(0);
+ if ((tp == nullptr) || (!tp->has_tensor_type())) return false;
+ auto elem_type = (TensorProto_DataType)(tp->tensor_type().elem_type());
+ // Fall back to the schema default when the node does not set "alpha".
+ auto* alpha_attr = ctx.getAttribute("alpha");
+ float alpha = (alpha_attr != nullptr) ? alpha_attr->f() : 1.702f;
+ FunctionBuilder builder(functionProto);
+ builder.AddOpset("", 13).Const("Alpha", ToTensor(alpha, elem_type)).Add(R"(
+ CX = Mul (Alpha, X)
+ SIGMOIDCX = Sigmoid (CX)
+ Y = Mul (X, SIGMOIDCX)
+ )");
+ schema.BuildFunction(functionProto);
+ return true;
+ }));
+
// Used to be ONNX 1.7 Inverse(12)
// Comment out docs not to increase the binary size
//
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index 47d051ffda..fe0d9956eb 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -64,6 +64,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FusedGemm);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FusedMatMul);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatherND);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Gelu);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QuickGelu);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GreedySearch);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GridSample);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Inverse);
@@ -142,6 +143,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
+ fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index a8b3147bfb..21be292a61 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -57,6 +57,7 @@
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
#include "core/optimizer/qdq_transformer/qdq_s8_to_u8.h"
#include "core/optimizer/qdq_transformer/relu_quantizelinear.h"
+#include "core/optimizer/quick_gelu_fusion.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
@@ -275,6 +276,7 @@ InlinedVector> GenerateTransformers(
transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps));
+ transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps));
diff --git a/onnxruntime/core/optimizer/quick_gelu_fusion.cc b/onnxruntime/core/optimizer/quick_gelu_fusion.cc
new file mode 100644
index 0000000000..93de7a64bd
--- /dev/null
+++ b/onnxruntime/core/optimizer/quick_gelu_fusion.cc
@@ -0,0 +1,101 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/optimizer/quick_gelu_fusion.h"
+
+#include "core/graph/graph_utils.h"
+#include "core/optimizer/initializer.h"
+#include "core/optimizer/utils.h"
+
+using namespace ONNX_NAMESPACE;
+using namespace onnxruntime::common;
+
+namespace onnxruntime {
+
+/**
+Rewrite x*sigmoid(alpha*x) or x*sigmoid(x) to QuickGelu.
+
+Matched patterns (the operands of either Mul may appear in any order):
+  Mul(x, alpha) -> Sigmoid -> Mul(x, .)   =>  QuickGelu(x) with the scalar constant as "alpha"
+  Sigmoid(x) -> Mul(x, .)                 =>  QuickGelu(x) with alpha = 1.0
+
+The output of the first Mul and the output of Sigmoid disappear after the fusion, so each of them
+must have exactly one consumer and must not be a graph output.
+*/
+Status QuickGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
+  GraphViewer graph_viewer(graph);
+  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
+  for (auto node_index : node_topology_list) {
+    auto* p_node = graph.GetNode(node_index);
+    if (!p_node) continue;  // e.g. the node was removed by an earlier fusion in this pass.
+
+    Node& node = *p_node;
+    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
+
+    InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
+
+    // Step 1: if this node is a Mul with a scalar constant float/double/float16 input, remember that
+    // constant as alpha and which input slot holds it.
+    int alpha_index = -1;
+    float alpha = 1.0f;
+    // CheckOutputEdges also rejects nodes that produce a graph output; GetOutputEdgesCount() alone only
+    // counts consumer nodes and would let the fusion silently drop a graph output.
+    if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14}) &&
+        graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) &&
+        optimizer_utils::CheckOutputEdges(graph, node, 1)) {
+      for (int i = 0; i < static_cast<int>(node.InputDefs().size()); ++i) {
+        const NodeArg& input_arg = *(node.InputDefs()[i]);
+        if (!optimizer_utils::IsScalar(input_arg)) continue;
+        const TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
+        if (!tensor_proto) continue;
+        Initializer init_const{*tensor_proto, graph.ModelPath()};
+        const auto data_type = tensor_proto->data_type();
+        if (data_type == TensorProto_DataType_FLOAT) {
+          alpha = *(init_const.data<float>());
+          alpha_index = i;
+          break;
+        } else if (data_type == TensorProto_DataType_DOUBLE) {
+          alpha = static_cast<float>(*(init_const.data<double>()));
+          alpha_index = i;
+          break;
+        } else if (data_type == TensorProto_DataType_FLOAT16) {
+          alpha = math::halfToFloat(init_const.data<MLFloat16>()->val);
+          alpha_index = i;
+          break;
+        }
+      }
+    }
+
+    NodeArg* quick_gelu_input_arg = nullptr;
+    Node* p_sigmoid_node = p_node;
+    // If alpha_index is not -1, it means the node is Mul node and it has a scalar input.
+    // We expect the output of Mul node is consumed by a Sigmoid node.
+    // If alpha_index is -1, it means current node is expected to be a Sigmoid node.
+    if (alpha_index != -1) {
+      // Mul has two inputs, so the non-alpha operand sits in the other slot.
+      quick_gelu_input_arg = node.MutableInputDefs()[(alpha_index + 1) % 2];
+      nodes_to_fuse.emplace_back(node);
+      p_sigmoid_node = graph.GetNode(node.OutputNodesBegin()->Index());
+    }
+
+    // Step 2: the Sigmoid, whose only consumer must be the final Mul.
+    Node& sigmoid_node = *p_sigmoid_node;
+    if (!graph_utils::IsSupportedOptypeVersionAndDomain(sigmoid_node, "Sigmoid", {6, 13}) ||
+        !graph_utils::IsSupportedProvider(sigmoid_node, GetCompatibleExecutionProviders()) ||
+        !optimizer_utils::CheckOutputEdges(graph, sigmoid_node, 1)) {
+      continue;
+    }
+    nodes_to_fuse.emplace_back(sigmoid_node);
+    if (!quick_gelu_input_arg) {
+      quick_gelu_input_arg = sigmoid_node.MutableInputDefs()[0];
+    }
+
+    // Step 3: the final Mul must multiply the Sigmoid output with the very same x.
+    Node& mul_node = *graph.GetNode(sigmoid_node.OutputNodesBegin()->Index());
+    int sigmoid_output_index = optimizer_utils::IndexOfNodeInput(mul_node, *sigmoid_node.MutableOutputDefs()[0]);
+    if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
+        !graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders()) ||
+        mul_node.MutableInputDefs()[(sigmoid_output_index + 1) % 2]->Name() != quick_gelu_input_arg->Name()) {
+      continue;
+    }
+    nodes_to_fuse.emplace_back(mul_node);
+
+    // Step 4: replace the matched nodes with one QuickGelu node that reuses the final Mul's output arg.
+    NodeArg* quick_gelu_output_arg = mul_node.MutableOutputDefs()[0];
+    Node& quick_gelu_node =
+        graph.AddNode(graph.GenerateNodeName("QuickGelu"), "QuickGelu", "QuickGelu", std::array{quick_gelu_input_arg},
+                      std::array{quick_gelu_output_arg}, {}, kMSDomain);
+    quick_gelu_node.AddAttribute("alpha", alpha);
+    quick_gelu_node.SetExecutionProviderType(node.GetExecutionProviderType());
+    graph_utils::FinalizeNodeFusion(graph, nodes_to_fuse, quick_gelu_node);
+    modified = true;
+  }
+
+  return Status::OK();
+}
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/quick_gelu_fusion.h b/onnxruntime/core/optimizer/quick_gelu_fusion.h
new file mode 100644
index 0000000000..2131219515
--- /dev/null
+++ b/onnxruntime/core/optimizer/quick_gelu_fusion.h
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/graph_transformer.h"
+
+namespace onnxruntime {
+
+/**
+ * @brief Rewrite graph fusing x*sigmoid(alpha*x) or x*sigmoid(x) to QuickGelu.
+ *
+ * The transformer is registered under the name "QuickGeluFusion". The set passed to the constructor
+ * is forwarded to GraphTransformer as the compatible execution providers; it defaults to an empty set.
+ */
+class QuickGeluFusion : public GraphTransformer {
+ public:
+ QuickGeluFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept
+ : GraphTransformer("QuickGeluFusion", compatible_execution_providers) {}
+
+ // Performs the rewrite on `graph`; sets `modified` when at least one fusion was applied.
+ Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
+};
+
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/activation/activations_impl.h b/onnxruntime/core/providers/cuda/activation/activations_impl.h
index 53359ae7a7..954cff766a 100644
--- a/onnxruntime/core/providers/cuda/activation/activations_impl.h
+++ b/onnxruntime/core/providers/cuda/activation/activations_impl.h
@@ -48,7 +48,7 @@ typedef CtxAlpha CtxThresholdedRelu;
#define UNARY_ACTIVATION_IMPL_DECLARATION(name) \
template \
void Impl_##name( \
- cudaStream_t stream, \
+ cudaStream_t stream, \
const T* input_data, \
T* output_data, \
const Ctx##name* func_ctx, \
diff --git a/onnxruntime/core/providers/cuda/cu_inc/binary_elementwise_impl.cuh b/onnxruntime/core/providers/cuda/cu_inc/binary_elementwise_impl.cuh
index 2949d1b62c..a41888d0df 100644
--- a/onnxruntime/core/providers/cuda/cu_inc/binary_elementwise_impl.cuh
+++ b/onnxruntime/core/providers/cuda/cu_inc/binary_elementwise_impl.cuh
@@ -74,7 +74,7 @@ __global__ void _BinaryElementWiseSimple(
const T1* lhs_data,
const T2* rhs_data,
T* output_data,
- const FuncT& func,
+ const FuncT func,
CUDA_LONG N) {
CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
T1 lvalue[NumElementsPerThread];
diff --git a/onnxruntime/test/contrib_ops/activation_op_test.cc b/onnxruntime/test/contrib_ops/activation_op_test.cc
index 9e7b22608f..40423ace45 100644
--- a/onnxruntime/test/contrib_ops/activation_op_test.cc
+++ b/onnxruntime/test/contrib_ops/activation_op_test.cc
@@ -47,6 +47,53 @@ TEST_F(ActivationOpTest, Gelu) {
"Gelu", input_values, [](float x) { return x * 0.5f * (1.0f + std::erf(x * static_cast(M_SQRT1_2))); }, {},
false, 1, kMSDomain);
}
-} // namespace test
+// Checks the QuickGelu op against a reference x * sigmoid(alpha * x) for alpha = 1.702,
+// alpha = 1 (Silu) and a negative alpha.
+TEST_F(ActivationOpTest, QuickGelu) {
+ // QuickGelu is not a single activation, so some corner values in input_values will not work.
+ // Use a dedicated set of inputs instead.
+ std::vector> quick_gelu_input_values{{-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f}};
+
+ // Positive alpha.
+ {
+ float alpha = 1.702f;
+ TestActivationOp(
+ "QuickGelu", quick_gelu_input_values,
+ [alpha](float x) {
+ auto tmp = x * alpha;
+ // exp is only evaluated on -|tmp| (a value <= 0), so it cannot overflow.
+ auto y = 1.f / (1.f + std::exp(-std::abs(tmp))); // safe sigmoid
+ // y is sigmoid(|tmp|); mirror it for negative tmp.
+ y = tmp >= 0 ? y : 1 - y;
+ return x * y;
+ },
+ {{"alpha", alpha}}, false, 1, kMSDomain);
+ }
+
+ // Silu = x*sigmoid(x), i.e., alpha = 1.0f.
+ {
+ float alpha = 1.0f;
+ TestActivationOp(
+ "QuickGelu", quick_gelu_input_values,
+ [alpha](float x) {
+ auto tmp = x * alpha;
+ auto y = 1.f / (1.f + std::exp(-std::abs(tmp))); // safe sigmoid
+ y = tmp >= 0 ? y : 1 - y;
+ return x * y;
+ },
+ {{"alpha", alpha}}, false, 1, kMSDomain);
+ }
+
+ // Negative alpha.
+ {
+ float alpha = -1.702f;
+ TestActivationOp(
+ "QuickGelu", quick_gelu_input_values,
+ [alpha](float x) {
+ auto tmp = x * alpha;
+ auto y = 1.f / (1.f + std::exp(-std::abs(tmp))); // safe sigmoid
+ y = tmp >= 0 ? y : 1 - y;
+ return x * y;
+ },
+ {{"alpha", alpha}}, false, 1, kMSDomain);
+ }
+}
+
+} // namespace test
} // namespace onnxruntime
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index c391ca2048..aaea49e962 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -60,6 +60,7 @@
#include "core/optimizer/noop_elimination.h"
#include "core/optimizer/not_where_fusion.h"
#include "core/optimizer/propagate_cast_ops.h"
+#include "core/optimizer/quick_gelu_fusion.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
@@ -3508,6 +3509,218 @@ TEST_F(GraphTransformationTests, FastGeluFusionWithCastsTest3) {
ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1);
}
+// Exercises QuickGeluFusion on six graphs: the alpha form and the alpha-free form, each with both
+// operand orders of the Mul nodes, plus two graphs in which an intermediate value has a second
+// consumer and the fusion is expected to leave the graph untouched.
+TEST_F(GraphTransformationTests, QuickGelu) {
+ // Sigmoid(x*alpha)*x, float
+ {
+ const float alpha = 1.702f;
+ auto build_test_case = [&](ModelTestBuilder& builder) {
+ auto* input_arg = builder.MakeInput({{2, 3, 3, 3}});
+ auto* alpha_arg = builder.MakeInitializer({}, {alpha});
+ auto* mul_out_0 = builder.MakeIntermediate();
+ auto* sigmoid_out = builder.MakeIntermediate();
+ auto* mul_out_1 = builder.MakeOutput();
+
+ builder.AddNode("Mul", {input_arg, alpha_arg}, {mul_out_0});
+ builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out});
+ builder.AddNode("Mul", {sigmoid_out, input_arg}, {mul_out_1});
+ };
+
+ auto pre_graph_checker = [&](Graph& graph) {
+ ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
+ ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
+ };
+
+ // All three nodes collapse into one QuickGelu whose alpha equals the scalar initializer.
+ auto post_graph_checker = [&](Graph& graph) {
+ ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 0);
+ ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 0);
+ ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 1);
+ for (auto& node : graph.Nodes()) {
+ if (node.OpType() == "QuickGelu") {
+ auto& attrs = node.GetAttributes();
+ ASSERT_TRUE(attrs.find("alpha") != attrs.end());
+ ASSERT_EQ(alpha, attrs.at("alpha").f());
+ }
+ }
+ };
+
+ std::unique_ptr transformer = std::make_unique();
+ TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
+ pre_graph_checker, post_graph_checker);
+ }
+
+ // x*Sigmoid(alpha*x), MLFloat16
+ // Same pattern as above with the operands of both Mul nodes swapped and a negative alpha.
+ {
+ const float alpha = -1.f;
+ auto build_test_case = [&](ModelTestBuilder& builder) {
+ auto* input_arg = builder.MakeInput({{2, 3, 3, 3}});
+ auto* alpha_arg = builder.MakeInitializer({}, {static_cast(alpha)});
+ auto* mul_out_0 = builder.MakeIntermediate();
+ auto* sigmoid_out = builder.MakeIntermediate();
+ auto* mul_out_1 = builder.MakeOutput();
+
+ builder.AddNode("Mul", {alpha_arg, input_arg}, {mul_out_0});
+ builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out});
+ builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out_1});
+ };
+
+ auto pre_graph_checker = [&](Graph& graph) {
+ ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
+ ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
+ };
+
+ auto post_graph_checker = [&](Graph& graph) {
+ ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 0);
+ ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 0);
+ ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 1);
+ for (auto& node : graph.Nodes()) {
+ if (node.OpType() == "QuickGelu") {
+ auto& attrs = node.GetAttributes();
+ ASSERT_TRUE(attrs.find("alpha") != attrs.end());
+ ASSERT_EQ(alpha, attrs.at("alpha").f());
+ }
+ }
+ };
+
+ std::unique_ptr transformer = std::make_unique();
+ TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
+ pre_graph_checker, post_graph_checker);
+ }
+
+ // Sigmoid's output is consumed by other node.
+ // The extra Identity consumer gives Sigmoid two output edges, so no fusion is expected.
+ {
+ const float alpha = 1.702f;
+ auto build_test_case = [&](ModelTestBuilder& builder) {
+ auto* input_arg = builder.MakeInput({{2, 3, 3, 3}});
+ auto* alpha_arg = builder.MakeInitializer({}, {alpha});
+ auto* mul_out_0 = builder.MakeIntermediate();
+ auto* sigmoid_out = builder.MakeIntermediate();
+ auto* mul_out_1 = builder.MakeOutput();
+ auto* identity_out = builder.MakeOutput();
+
+ builder.AddNode("Mul", {alpha_arg, input_arg}, {mul_out_0});
+ builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out});
+ builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out_1});
+ builder.AddNode("Identity", {sigmoid_out}, {identity_out});
+ };
+
+ auto pre_graph_checker = [&](Graph& graph) {
+ ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
+ ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
+ };
+
+ // Graph must be unchanged.
+ auto post_graph_checker = [&](Graph& graph) {
+ ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
+ ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
+ ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 0);
+ };
+
+ std::unique_ptr transformer = std::make_unique();
+ TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
+ pre_graph_checker, post_graph_checker);
+ }
+
+ // First Mul's output is consumed by other node.
+ // The extra Identity consumer gives the first Mul two output edges, so no fusion is expected.
+ {
+ const float alpha = -1.f;
+ auto build_test_case = [&](ModelTestBuilder& builder) {
+ auto* input_arg = builder.MakeInput({{2, 3, 3, 3}});
+ auto* alpha_arg = builder.MakeInitializer({}, {static_cast(alpha)});
+ auto* mul_out_0 = builder.MakeIntermediate();
+ auto* sigmoid_out = builder.MakeIntermediate();
+ auto* mul_out_1 = builder.MakeOutput();
+ auto* identity_out = builder.MakeOutput();
+
+ builder.AddNode("Mul", {alpha_arg, input_arg}, {mul_out_0});
+ builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out});
+ builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out_1});
+ builder.AddNode("Identity", {mul_out_0}, {identity_out});
+ };
+
+ auto pre_graph_checker = [&](Graph& graph) {
+ ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
+ ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
+ };
+
+ // Graph must be unchanged.
+ auto post_graph_checker = [&](Graph& graph) {
+ ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
+ ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
+ ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 0);
+ };
+
+ std::unique_ptr transformer = std::make_unique();
+ TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
+ pre_graph_checker, post_graph_checker);
+ }
+
+ // Sigmoid(x)*x, float
+ // There is no scaling Mul here, so the fused node is expected to carry alpha == 1.0f.
+ {
+ auto build_test_case = [&](ModelTestBuilder& builder) {
+ auto* input_arg = builder.MakeInput({{2, 3, 3, 3}});
+ auto* sigmoid_out = builder.MakeIntermediate();
+ auto* mul_out = builder.MakeOutput();
+
+ builder.AddNode("Sigmoid", {input_arg}, {sigmoid_out});
+ builder.AddNode("Mul", {sigmoid_out, input_arg}, {mul_out});
+ };
+
+ auto pre_graph_checker = [&](Graph& graph) {
+ ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 1);
+ ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
+ };
+
+ auto post_graph_checker = [&](Graph& graph) {
+ ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 0);
+ ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 0);
+ ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 1);
+ for (auto& node : graph.Nodes()) {
+ if (node.OpType() == "QuickGelu") {
+ auto& attrs = node.GetAttributes();
+ ASSERT_TRUE(attrs.find("alpha") != attrs.end());
+ ASSERT_EQ(1.0f, attrs.at("alpha").f());
+ }
+ }
+ };
+
+ std::unique_ptr transformer = std::make_unique();
+ TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
+ pre_graph_checker, post_graph_checker);
+ }
+
+ // x*Sigmoid(x), MLFloat16
+ // Same as the previous case with the operands of the Mul swapped.
+ {
+ auto build_test_case = [&](ModelTestBuilder& builder) {
+ auto* input_arg = builder.MakeInput({{2, 3, 3, 3}});
+ auto* sigmoid_out = builder.MakeIntermediate();
+ auto* mul_out = builder.MakeOutput();
+
+ builder.AddNode("Sigmoid", {input_arg}, {sigmoid_out});
+ builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out});
+ };
+
+ auto pre_graph_checker = [&](Graph& graph) {
+ ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 1);
+ ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
+ };
+
+ auto post_graph_checker = [&](Graph& graph) {
+ ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 0);
+ ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 0);
+ ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 1);
+ for (auto& node : graph.Nodes()) {
+ if (node.OpType() == "QuickGelu") {
+ auto& attrs = node.GetAttributes();
+ ASSERT_TRUE(attrs.find("alpha") != attrs.end());
+ ASSERT_EQ(1.0f, attrs.at("alpha").f());
+ }
+ }
+ };
+
+ std::unique_ptr transformer = std::make_unique();
+ TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
+ pre_graph_checker, post_graph_checker);
+ }
+}
+
struct BiasSoftmaxFusionTester {
std::shared_ptr p_model_;
Status model_load_;
diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc
index 3242893bc8..c8e0f2bd5e 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder.cc
@@ -704,6 +704,11 @@ IMPLEMENT_GRADIENT_BUILDER(GetSigmoidGradient) {
{GI(0)})};
}
+IMPLEMENT_GRADIENT_BUILDER(GetQuickGeluGradient) {  // Gradient builder for QuickGelu: emits a single QuickGeluGrad node (kMSDomain, version 1).
+ return std::vector{  // inputs {GO(0), I(0)}, output {GI(0)}; presumably output-gradient 0, forward input 0 and input-gradient 0 - confirm against the GO/I/GI helpers
+ NodeDef(OpDef{"QuickGeluGrad", kMSDomain, 1}, {GO(0), I(0)}, {GI(0)}, SrcNodeAttributes())};  // SrcNodeAttributes() is handed to the grad node; presumably this forwards the forward node's alpha - confirm
+}
+
IMPLEMENT_GRADIENT_BUILDER(GetSoftmaxGradient) {
return std::vector{
NodeDef(OpDef{SrcNodeOpsetVersion() < 13 ? "SoftmaxGrad" : "SoftmaxGrad_13", kMSDomain, 1},
diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h
index 174c650201..cecf37d415 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.h
+++ b/orttraining/orttraining/core/graph/gradient_builder.h
@@ -43,6 +43,7 @@ DECLARE_GRADIENT_BUILDER(GetConvGradient)
DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient)
DECLARE_GRADIENT_BUILDER(GetSqueezeGradient)
DECLARE_GRADIENT_BUILDER(GetSigmoidGradient)
+DECLARE_GRADIENT_BUILDER(GetQuickGeluGradient)
DECLARE_GRADIENT_BUILDER(GetSoftmaxGradient)
DECLARE_GRADIENT_BUILDER(GetLogSoftmaxGradient)
DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyGradient)
diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc
index 35da30c0be..c5693fb616 100755
--- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc
@@ -74,6 +74,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient);
REGISTER_GRADIENT_BUILDER("Unsqueeze", GetUnsqueezeGradient);
REGISTER_GRADIENT_BUILDER("Sigmoid", GetSigmoidGradient);
+ REGISTER_GRADIENT_BUILDER("QuickGelu", GetQuickGeluGradient);
REGISTER_GRADIENT_BUILDER("Softmax", GetSoftmaxGradient);
REGISTER_GRADIENT_BUILDER("LogSoftmax", GetLogSoftmaxGradient);
REGISTER_GRADIENT_BUILDER("SoftmaxCrossEntropy", GetSoftmaxCrossEntropyGradient);
diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc
index 7fba133954..a0e010c419 100644
--- a/orttraining/orttraining/core/graph/training_op_defs.cc
+++ b/orttraining/orttraining/core/graph/training_op_defs.cc
@@ -2634,6 +2634,19 @@ Example 4:
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
});
+ ONNX_CONTRIB_OPERATOR_SCHEMA(QuickGeluGrad)  // Schema for QuickGeluGrad: inputs dY and X, output dX, one float attribute "alpha".
+ .SetDomain(kMSDomain)
+ .SinceVersion(1)
+ .SetDoc("QuickGeluGrad")
+ .Attr("alpha", "Alpha value.", AttributeProto::FLOAT, 1.702f)  // 1.702f is presumably the default used when the attribute is omitted - confirm against OpSchema::Attr
+ .AllowUncheckedAttributes()
+ .Input(0, "dY", "The gradient tensor from output.", "T")
+ .Input(1, "X", "The input tensor. ", "T")
+ .Output(0, "dX", "Gradient of the input.", "T")
+ .TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},  // dY, X and dX all share the single type parameter T
+ "Constrain input and output types to float tensors.")
+ .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);  // per the helper's name, dX's type and shape follow input 0 (dY)
+
ONNX_CONTRIB_OPERATOR_SCHEMA(TanhGrad)
.SetDomain(kMSDomain)
.SinceVersion(1)
diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
index 3149600ee8..421c5c8663 100644
--- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
+++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
@@ -40,6 +40,7 @@
#include "core/optimizer/noop_elimination.h"
#include "core/optimizer/not_where_fusion.h"
#include "core/optimizer/propagate_cast_ops.h"
+#include "core/optimizer/quick_gelu_fusion.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
@@ -99,6 +100,7 @@ std::vector> GeneratePreTrainingTransformers(
transformers.emplace_back(std::make_unique(compatible_eps));
transformers.emplace_back(std::make_unique(compatible_eps));
transformers.emplace_back(std::make_unique(compatible_eps));
+ transformers.emplace_back(std::make_unique(compatible_eps));
transformers.emplace_back(std::make_unique(compatible_eps));
transformers.emplace_back(std::make_unique(compatible_eps));
diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
index 6f1b50fd43..0dc062e095 100644
--- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc
+++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc
@@ -152,15 +152,16 @@ void GenerateRandomDataWithOneHot(std::vector>& x_datas, std:
void UnaryOpGradientTest(const std::string& op_type, const std::string& domain = kOnnxDomain,
const int opset_version = 9,
std::vector>* execution_providers = nullptr,
- std::function* transformer = nullptr) {
+ std::function* transformer = nullptr,
+ const std::vector& attributes = {},
+ float error_tolerance = 1e-3f) {
TensorShape shape({2, 3, 4});
TensorInfo x_info{shape, true, transformer};
float max_error;
- float error_tolerance = 1e-3f;
GradientChecker gradient_checker;
OpDef op_def{op_type, domain, opset_version};
- ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info}, {shape}, &max_error, {}, true, false,
+ ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info}, {shape}, &max_error, attributes, true, false,
execution_providers));
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
@@ -1548,6 +1549,23 @@ TEST(GradientCheckerTest, DISABLED_BatchNormalizationGrad) {
TEST(GradientCheckerTest, SigmoidGrad) { UnaryOpGradientTest("Sigmoid"); }
+TEST(GradientCheckerTest, QuickGeluGrad) {  // Gradient-checks com.microsoft QuickGelu for three alpha settings, each with tolerance 5e-2.
+ // Default alpha = 1.702 (no attribute passed); the tolerance is relaxed to 5e-2 due to failures on Windows for some seeds.
+ { UnaryOpGradientTest("QuickGelu", kMSDomain, 1, nullptr, nullptr, {}, 5e-2f); }
+
+ // Silu, alpha = 1.0.
+ {
+ std::vector attributes = {MakeAttribute("alpha", 1.0f)};
+ UnaryOpGradientTest("QuickGelu", kMSDomain, 1, nullptr, nullptr, attributes, 5e-2f);
+ }
+
+ // Negative alpha.
+ {
+ std::vector attributes = {MakeAttribute("alpha", -1.702f)};
+ UnaryOpGradientTest("QuickGelu", kMSDomain, 1, nullptr, nullptr, attributes, 5e-2f);
+ }
+}
+
void GradientCheckerSoftmaxGradHelper(bool is_log_softmax, int version = 11) {
TensorShape shape({2, 3, 4});
float max_error;
diff --git a/orttraining/orttraining/test/training_ops/cpu/activation/activation_op_test.cc b/orttraining/orttraining/test/training_ops/cpu/activation/activation_op_test.cc
index c2d0cf21bb..fc551ae75b 100644
--- a/orttraining/orttraining/test/training_ops/cpu/activation/activation_op_test.cc
+++ b/orttraining/orttraining/test/training_ops/cpu/activation/activation_op_test.cc
@@ -83,6 +83,12 @@ constexpr float SigmoidGrad(float dy, float y) {
constexpr float TanhGrad(float dy, float y) {
return dy * (1 - y * y);
}
+
+float QuickGeluGrad(float dy, float x, float alpha) {  // Scalar reference: dy * d/dx[x * sigmoid(alpha * x)].
+ float v = x * alpha;
+ float sigmoid = v >= 0 ? 1.f / (1.f + std::exp(-v)) : 1.f - 1.f / (1 + std::exp(v));  // sigmoid(v); either branch only calls exp on a non-positive argument
+ return dy * sigmoid * (1 + v * (1 - sigmoid));  // sigmoid + v * sigmoid * (1 - sigmoid), scaled by the incoming gradient dy
+}
} // namespace
TEST(GeluGradTest, Basic) {
@@ -199,6 +205,50 @@ TEST(TanhGradTest, Basic) {
{}, 1, kMSDomain);
}
+TEST(QuickGeluGradTest, Basic) {  // Compares the QuickGeluGrad op with the scalar QuickGeluGrad() reference for three alpha values.
+ const std::vector x_vals = {-10.0f, -1.0f, 0.0f, 1.0f, 10.0f};  // negative, zero and positive inputs
+ const std::vector dY(5, 1.0f);  // incoming gradient of 1 for each of the five inputs
+
+ // Positive alpha.
+ {
+ const float alpha = 1.702f;
+ TestElementwiseGradientOp(
+ "QuickGeluGrad", {{"dY", dY}, {"X", x_vals}},
+ [alpha](const std::vector& params) {  // expected-value callback, given one element per op input
+ ORT_ENFORCE(params.size() == 2);
+ const auto dy = params[0], x = params[1];  // presumably in the same order as the inputs above: dY, then X
+ return QuickGeluGrad(dy, x, alpha);
+ },
+ {{"alpha", alpha}}, 1, kMSDomain);  // the same alpha is passed to the op as its attribute
+ }
+
+ // Silu = x*sigmoid(x), i.e., alpha = 1.0f.
+ {
+ const float alpha = 1.0f;
+ TestElementwiseGradientOp(
+ "QuickGeluGrad", {{"dY", dY}, {"X", x_vals}},
+ [alpha](const std::vector& params) {
+ ORT_ENFORCE(params.size() == 2);
+ const auto dy = params[0], x = params[1];
+ return QuickGeluGrad(dy, x, alpha);
+ },
+ {{"alpha", alpha}}, 1, kMSDomain);
+ }
+
+ // Negative alpha.
+ {
+ const float alpha = -1.702f;
+ TestElementwiseGradientOp(
+ "QuickGeluGrad", {{"dY", dY}, {"X", x_vals}},
+ [alpha](const std::vector& params) {
+ ORT_ENFORCE(params.size() == 2);
+ const auto dy = params[0], x = params[1];
+ return QuickGeluGrad(dy, x, alpha);
+ },
+ {{"alpha", alpha}}, 1, kMSDomain);
+ }
+}
+
namespace {
template
void TestBiasGeluGradBroadcastBias(const std::string& op, int opset_version, const std::string& domain,
diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc
index 66a28f61c1..da74aae921 100644
--- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc
+++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc
@@ -50,6 +50,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gathe
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GeluGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SigmoidGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TanhGrad);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGeluGrad);
// REVIEW(mzs): ConstEigenVectorArrayMap.cast,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
// REVIEW(mzs): ConstEigenVectorArrayMap.cast,
diff --git a/orttraining/orttraining/training_ops/cpu/op_gradients.cc b/orttraining/orttraining/training_ops/cpu/op_gradients.cc
index bbc018a831..e1b02cb7e3 100644
--- a/orttraining/orttraining/training_ops/cpu/op_gradients.cc
+++ b/orttraining/orttraining/training_ops/cpu/op_gradients.cc
@@ -1,15 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#include "op_gradients.h"
-#include "core/util/math.h"
-#include "core/util/math_cpuonly.h"
+#include "orttraining/training_ops/cpu/op_gradients.h"
+
+#include "core/mlas/inc/mlas.h"
#include "core/providers/common.h"
-#include
-#include "core/util/math.h"
#include "core/providers/cpu/math/element_wise_ops.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/cpu/tensor/transpose.h"
+#include "core/util/math.h"
+#include "core/util/math_cpuonly.h"
+#include
#include "gsl/gsl"
namespace onnxruntime {
@@ -210,5 +211,44 @@ Status TanhGrad::Compute(OpKernelContext* context) const {
dx = dy * (1 - y * y);
return Status::OK();
}
+
+ONNX_OPERATOR_KERNEL_EX(QuickGeluGrad, kMSDomain, 1, kCpuExecutionProvider,  // Defines the QuickGeluGrad kernel for kMSDomain version 1 on the CPU execution provider.
+ KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()),
+ QuickGeluGrad);
+
+template
+Status QuickGeluGrad::Compute(OpKernelContext* context) const {  // Element-wise: dX = dY * s * (1 + alpha * X * (1 - s)), where s = MlasComputeLogistic(alpha * X).
+ auto& dY = *context->Input(0);
+ const T* dY_data = dY.template Data();
+ auto& X = *context->Input(1);
+ const T* X_data = X.template Data();
+ auto& dX = *context->Output(0, dY.Shape());  // dX takes dY's shape; X is indexed with the same offsets and its size is not checked here
+ T* dX_data = dX.template MutableData();
+ concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
+ int64_t elem_count = dY.Shape().Size();
+ constexpr int64_t length_per_task = 4096; // this number comes from FastGelu.
+ int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;  // ceil(elem_count / length_per_task)
+ concurrency::ThreadPool::TryBatchParallelFor(
+ tp, static_cast(task_count),
+ [&](ptrdiff_t task_idx) {  // each task handles one chunk of up to length_per_task elements
+ const auto start = task_idx * length_per_task;
+ const T* p_dy = dY_data + start;
+ const T* p_x = X_data + start;
+ T* p_dx = dX_data + start;
+ int64_t count = std::min(length_per_task, elem_count - start);  // the last chunk may be shorter
+ for (int64_t i = 0; i < count; i++) {
+ p_dx[i] = p_x[i] * alpha_;  // stage alpha * x in the output buffer
+ }
+
+ MlasComputeLogistic(p_dx, p_dx, count);  // same pointer as input and output; per its name, p_dx then holds logistic(alpha * x)
+
+ for (int64_t i = 0; i < count; i++) {
+ p_dx[i] = p_dy[i] * p_dx[i] * (1.f + alpha_ * p_x[i] * (1.f - p_dx[i]));  // right-hand side reads the staged p_dx[i] before it is overwritten
+ }
+ },
+ 0);
+ return Status::OK();
+}
+
} // namespace contrib
} // namespace onnxruntime
diff --git a/orttraining/orttraining/training_ops/cpu/op_gradients.h b/orttraining/orttraining/training_ops/cpu/op_gradients.h
index 4a6ec568b4..be1268dac5 100644
--- a/orttraining/orttraining/training_ops/cpu/op_gradients.h
+++ b/orttraining/orttraining/training_ops/cpu/op_gradients.h
@@ -33,6 +33,20 @@ class SigmoidGrad final : public OpKernel {
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SigmoidGrad);
};
+template
+class QuickGeluGrad final : public OpKernel {  // OpKernel for QuickGeluGrad; stores the "alpha" attribute for Compute().
+ public:
+ explicit QuickGeluGrad(const OpKernelInfo& info) : OpKernel(info) {
+ alpha_ = info.GetAttrOrDefault("alpha", 1.702f);  // 1.702f is the fallback passed to GetAttrOrDefault for "alpha"
+ }
+
+ Status Compute(OpKernelContext* context) const override;
+
+ private:
+ ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QuickGeluGrad);
+ float alpha_;  // value of the "alpha" attribute, read once at construction
+};
+
template
class TanhGrad final : public OpKernel {
public:
diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc b/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc
index f42a523848..7c0c064c1f 100644
--- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc
+++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad.cc
@@ -47,6 +47,7 @@ ACTIVATION_GRAD_OP_HFD(GeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(FastGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(ReluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(SigmoidGrad, 1, kMSDomain);
+ACTIVATION_GRAD_OP_HFD(QuickGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(TanhGrad, 1, kMSDomain);
} // namespace cuda
diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad.h b/orttraining/orttraining/training_ops/cuda/activation/activations_grad.h
index 444565a371..34de4ef8bb 100644
--- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad.h
+++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad.h
@@ -55,6 +55,20 @@ class SigmoidGrad final : public BinaryElementwise {
MAKE_FUNC_CTX_NULL()
};
+template
+class QuickGeluGrad final : public BinaryElementwise {  // Binary element-wise kernel class for QuickGeluGrad; stores the "alpha" attribute.
+ public:
+ QuickGeluGrad(const OpKernelInfo& info) : BinaryElementwise(info) {
+ alpha_ = info.GetAttrOrDefault("alpha", 1.702f);  // 1.702f is the fallback passed to GetAttrOrDefault for "alpha"
+ }
+
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+ MAKE_FUNC_CTX_ALPHA()  // NOTE(review): presumably builds the functor context from alpha_ - confirm against the macro definition
+ float alpha_;  // value of the "alpha" attribute, read once at construction
+};
+
template
class TanhGrad final : public BinaryElementwise {
public:
diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu
index 1633cc45b7..2c23a3ed87 100644
--- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu
+++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.cu
@@ -46,6 +46,17 @@ struct OP_SigmoidGrad : public CtxSigmoidGrad {
}
};
+template
+struct OP_QuickGeluGrad : public CtxQuickGeluGrad {  // Device functor: per-element dy * s * (1 + v * (1 - s)), with v = x * alpha and s = sigmoid(v).
+ __device__ __inline__ T operator()(const T& dy, const T& x) const {
+ T v = x * static_cast(alpha);  // alpha is not declared in this struct; presumably a member of the CtxQuickGeluGrad base - confirm
+ T one = static_cast(1.f);
+ T zero = static_cast(0.f);
+ T sigmoid = v >= zero ? one / (one + _Exp(-v)) : one - one / (one + _Exp(v));  // either branch only calls _Exp on a non-positive argument
+ return dy * sigmoid * (one + v * (one - sigmoid));
+ }
+};
+
template
struct OP_TanhGrad : public CtxTanhGrad {
__device__ __inline__ T operator()(const T& dy, const T& y) const {
diff --git a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.h b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.h
index af18144377..8e925f0484 100644
--- a/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.h
+++ b/orttraining/orttraining/training_ops/cuda/activation/activations_grad_impl.h
@@ -11,13 +11,15 @@ typedef onnxruntime::cuda::CtxNull CtxGeluGrad;
typedef onnxruntime::cuda::CtxNull CtxFastGeluGrad;
typedef onnxruntime::cuda::CtxNull CtxReluGrad;
typedef onnxruntime::cuda::CtxNull CtxSigmoidGrad;
+typedef onnxruntime::cuda::CtxAlpha CtxQuickGeluGrad;
typedef onnxruntime::cuda::CtxNull CtxTanhGrad;
-#define ACTIVATION_GRAD_OPS() \
- ACTIVATION_GRAD_OP_NAME(GeluGrad) \
- ACTIVATION_GRAD_OP_NAME(FastGeluGrad) \
- ACTIVATION_GRAD_OP_NAME(ReluGrad) \
- ACTIVATION_GRAD_OP_NAME(SigmoidGrad) \
+#define ACTIVATION_GRAD_OPS() \
+ ACTIVATION_GRAD_OP_NAME(GeluGrad) \
+ ACTIVATION_GRAD_OP_NAME(FastGeluGrad) \
+ ACTIVATION_GRAD_OP_NAME(ReluGrad) \
+ ACTIVATION_GRAD_OP_NAME(SigmoidGrad) \
+ ACTIVATION_GRAD_OP_NAME(QuickGeluGrad) \
ACTIVATION_GRAD_OP_NAME(TanhGrad)
#define BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \
diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
index 5265f908fe..1a48aae152 100644
--- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
+++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc
@@ -112,6 +112,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, SigmoidGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SigmoidGrad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGeluGrad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGeluGrad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad);
+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad);
@@ -335,6 +339,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc
index f3964a3173..cb10290449 100644
--- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc
+++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc
@@ -109,6 +109,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, SigmoidGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SigmoidGrad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGeluGrad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGeluGrad);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad);
+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TanhGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TanhGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad);
@@ -301,6 +305,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,