QuickGelu Fusion (#12417)

Some models have QuickGelu(x)=x*sigmoid(1.702x), which has 3 Ops for
forward and 5 Ops for backward. The PR is to fuse this to a single Op
named QuickGelu and its gradient QuickGeluGrad.

For CUDA, tested in V100 using input tensor with shape [64,128,2048] and
float16 type:
Before, FW takes 335us, BW takes 614us

![image](https://user-images.githubusercontent.com/11661208/182291335-15188709-ffe7-44d1-9d14-0b544cbe5e55.png)

After, FW takes 115us, BW takes 139us, which is much faster.

![image](https://user-images.githubusercontent.com/11661208/182291502-f0b5161c-b95c-45fc-90f8-ad0c592d2433.png)

For CPU kernel, using same shape and float type:
Before, FW takes 10us, BW takes 49us
Mul: 3480[µs]
Sigmoid: 1996[µs]
Mul: 4789[µs]
Mul: 4642[µs]
Mul: 4195[µs]
SigmoidGrad: 18328[µs]
Mul: 2988[µs]
Sum: 18576[µs]

After, FW takes 4us, BW takes 5us, which is also much faster.
QuickGelu: 3939[µs]
QuickGeluGrad: 5089[µs]

Co-authored-by: Vincent Wang <weicwang@microsoft.com>
This commit is contained in:
Vincent Wang 2022-10-28 18:12:07 +08:00 committed by GitHub
parent 20c3c35c33
commit 8b0669bf63
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
36 changed files with 752 additions and 26 deletions

View file

@ -66,6 +66,7 @@ Do not modify directly.*
* <a href="#com.microsoft.QuantizeBFP">com.microsoft.QuantizeBFP</a>
* <a href="#com.microsoft.QuantizeLinear">com.microsoft.QuantizeLinear</a>
* <a href="#com.microsoft.QuantizeWithOrder">com.microsoft.QuantizeWithOrder</a>
* <a href="#com.microsoft.QuickGelu">com.microsoft.QuickGelu</a>
* <a href="#com.microsoft.Range">com.microsoft.Range</a>
* <a href="#com.microsoft.ReduceSumInteger">com.microsoft.ReduceSumInteger</a>
* <a href="#com.microsoft.Rfft">com.microsoft.Rfft</a>
@ -3476,6 +3477,43 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>
### <a name="com.microsoft.QuickGelu"></a><a name="com.microsoft.quickgelu">**com.microsoft.QuickGelu**</a>
Compute x * Sigmoid(alpha * x).
#### Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Attributes
<dl>
<dt><tt>alpha</tt> : float</dt>
<dd>Alpha value.</dd>
</dl>
#### Inputs
<dl>
<dt><tt>X</tt> : T</dt>
<dd>The input data as Tensor.</dd>
</dl>
#### Outputs
<dl>
<dt><tt>Y</tt> : T</dt>
<dd>The output.</dd>
</dl>
#### Type Constraints
<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>
### <a name="com.microsoft.Range"></a><a name="com.microsoft.range">**com.microsoft.Range**</a>
Creates a sequence of numbers that begins at `start` and extends by increments of `delta`

View file

@ -434,6 +434,7 @@ Do not modify directly.*
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSoftmax|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float)|
@ -791,6 +792,7 @@ Do not modify directly.*
|QOrderedMatMul|*in* A:**Q**<br> *in* scale_A:**S**<br> *in* B:**Q**<br> *in* scale_B:**S**<br> *in* scale_Y:**S**<br> *in* bias:**S**<br> *in* C:**Q**<br> *in* scale_C:**S**<br> *out* Y:**Q**|1+|**Q** = tensor(int8)<br/> **S** = tensor(float)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float16)<br/> **T2** = tensor(int8), tensor(uint8)|
|QuantizeWithOrder|*in* input:**F**<br> *in* scale_input:**S**<br> *out* output:**Q**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|

View file

@ -34,5 +34,13 @@ ONNX_OPERATOR_KERNEL_EX(
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gelu<float>);
ONNX_OPERATOR_KERNEL_EX(
QuickGelu,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
QuickGelu<float>);
} // namespace contrib
} // namespace onnxruntime

View file

@ -93,5 +93,46 @@ class Gelu : public OpKernel {
}
};
// Implement a new one instead of inheriting from ElementWiseRangedTransform so that we can call
// MlasComputeLogistic instead of using Eigen for better perf.
template <typename T>
class QuickGelu : public OpKernel {
public:
QuickGelu(const OpKernelInfo& info) : OpKernel(info) { alpha_ = info.GetAttrOrDefault<float>("alpha", 1.702f); }
Status Compute(OpKernelContext* context) const override {
const Tensor* input = context->Input<Tensor>(0);
const T* input_data = input->template Data<T>();
Tensor* output = context->Output(0, input->Shape());
T* output_data = output->template MutableData<T>();
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
int64_t elem_count = input->Shape().Size();
constexpr int64_t length_per_task = 4096; // this number comes from FastGelu.
int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
concurrency::ThreadPool::TryBatchParallelFor(
tp, static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
const auto start = task_idx * length_per_task;
const T* p_input = input_data + start;
T* p_output = output_data + start;
int64_t count = std::min(length_per_task, elem_count - start);
for (int64_t i = 0; i < count; i++) {
p_output[i] = p_input[i] * alpha_;
}
MlasComputeLogistic(p_output, p_output, count);
for (int64_t i = 0; i < count; i++) {
p_output[i] = p_input[i] * p_output[i];
}
},
0);
return Status::OK();
}
private:
float alpha_;
};
} // namespace contrib
} // namespace onnxruntime

View file

@ -41,6 +41,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BiasG
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGramRepeatBlock);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu);
// ******** Start: Quantization ******************* //
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16);
@ -219,6 +220,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGramRepeatBlock)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
// These ops were experimental ops in onnx domain which have been removed now. We add them here as
// contrib ops to main backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Affine)>,

View file

@ -50,6 +50,7 @@ UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain);
UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain);
UNARY_ACTIVATION_OP_HFD(QuickGelu, 1, kMSDomain);
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16)
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, float)

View file

@ -77,6 +77,20 @@ class Gelu final : public UnaryElementwise {
MAKE_FUNC_CTX_NULL()
};
template <typename T>
class QuickGelu final : public UnaryElementwise {
public:
QuickGelu(const OpKernelInfo& info) : UnaryElementwise(info) {
alpha_ = info.GetAttrOrDefault<float>("alpha", 1.702f);
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_ALPHA()
float alpha_;
};
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -50,6 +50,17 @@ struct OP_Gelu<half> : public CtxGelu {
}
};
template <typename T>
struct OP_QuickGelu : public CtxQuickGelu {
__device__ __inline__ T operator()(const T& a) const {
T v = a * static_cast<T>(alpha);
T one = static_cast<T>(1.f);
T zero = static_cast<T>(0.f);
T sigmoid = v >= zero ? one / (one + _Exp(-v)) : one - one / (one + _Exp(v));
return a * sigmoid;
}
};
#define UNARY_ACTIVATION_IMPL(name) \
UNARY_ACTIVATION_IMPL_DECLARATION(name) { \
UnaryElementWiseImpl(stream, \

View file

@ -12,21 +12,14 @@ typedef onnxruntime::cuda::CtxAlphaBeta CtxAffine;
typedef onnxruntime::cuda::CtxAlphaBeta CtxParametricSoftplus;
typedef onnxruntime::cuda::CtxAlphaBeta CtxScaledTanh;
typedef onnxruntime::cuda::CtxNull CtxGelu;
typedef onnxruntime::cuda::CtxAlpha CtxQuickGelu;
#define UNARY_CONTRIB_ACTIVATION_OPS() \
UNARY_ACTIVATION_OP_NAME(ScaledTanh) \
UNARY_ACTIVATION_OP_NAME(Affine) \
UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \
UNARY_ACTIVATION_OP_NAME(Gelu)
#define UNARY_ACTIVATION_IMPL_DECLARATION(name) \
template <typename T> \
void Impl_##name( \
cudaStream_t stream, \
const T* input_data, \
T* output_data, \
const Ctx##name* func_ctx, \
size_t count)
UNARY_ACTIVATION_OP_NAME(Gelu) \
UNARY_ACTIVATION_OP_NAME(QuickGelu)
#define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name);
UNARY_CONTRIB_ACTIVATION_OPS()

View file

@ -19,6 +19,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility
@ -130,6 +133,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul)>, // backward compatibility

View file

@ -19,6 +19,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility
@ -124,6 +127,9 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TransposeMatMul)>, // backward compatibility
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul)>, // backward compatibility

View file

@ -558,6 +558,36 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BiasGelu, 1,
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
constexpr const char* QuickGelu_ver1_doc = R"DOC(Compute x * Sigmoid(alpha * x).)DOC";
ONNX_MS_OPERATOR_SET_SCHEMA(
QuickGelu, 1,
OpSchema()
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc(QuickGelu_ver1_doc)
.Attr("alpha", "Alpha value.", AttributeProto::FLOAT, 1.702f)
.Input(0, "X", "The input data as Tensor.", "T")
.Output(0, "Y", "The output.", "T")
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
.SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx, const OpSchema& schema,
FunctionProto& functionProto) {
auto* tp = ctx.getInputType(0);
if ((tp == nullptr) || (!tp->has_tensor_type())) return false;
auto elem_type = (TensorProto_DataType)(tp->tensor_type().elem_type());
auto* alpha_attr = ctx.getAttribute("alpha");
float alpha = (alpha_attr != nullptr) ? alpha_attr->f() : 1.702f;
FunctionBuilder builder(functionProto);
builder.AddOpset("", 13).Const("Alpha", ToTensor(alpha, elem_type)).Add(R"(
CX = Mul (Alpha, X)
SIGMOIDCX = Sigmoid (CX)
Y = Mul (X, SIGMOIDCX)
)");
schema.BuildFunction(functionProto);
return true;
}));
// Used to be ONNX 1.7 Inverse(12)
// Comment out docs not to increase the binary size
//

View file

@ -64,6 +64,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FusedGemm);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FusedMatMul);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatherND);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Gelu);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QuickGelu);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GreedySearch);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GridSample);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Inverse);
@ -142,6 +143,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FusedMatMul)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatherND)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Gelu)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QuickGelu)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GreedySearch)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GridSample)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Inverse)>());

View file

@ -57,6 +57,7 @@
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
#include "core/optimizer/qdq_transformer/qdq_s8_to_u8.h"
#include "core/optimizer/qdq_transformer/relu_quantizelinear.h"
#include "core/optimizer/quick_gelu_fusion.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
@ -275,6 +276,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<FastGeluFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<QuickGeluFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<MatMulScaleFusion>(cpu_cuda_dml_rocm_eps));

View file

@ -0,0 +1,101 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/optimizer/quick_gelu_fusion.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/initializer.h"
#include "core/optimizer/utils.h"
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
/**
Rewrite x*sigmoid(alpha*x) or x*sigmoid(x) to QuickGelu.
*/
Status QuickGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
GraphViewer graph_viewer(graph);
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
for (auto node_index : node_topology_list) {
auto* p_node = graph.GetNode(node_index);
if (!p_node) continue;
Node& node = *p_node;
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
int alpha_index = -1;
float alpha = 1.0f;
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14}) &&
graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) && node.GetOutputEdgesCount() == 1) {
for (int i = 0; i < static_cast<int>(node.InputDefs().size()); ++i) {
const NodeArg& input_arg = *(node.InputDefs()[i]);
if (!optimizer_utils::IsScalar(input_arg)) continue;
const TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
if (!tensor_proto) continue;
Initializer init_const{*tensor_proto, graph.ModelPath()};
const auto data_type = tensor_proto->data_type();
if (data_type == TensorProto_DataType_FLOAT) {
alpha = *(init_const.data<float>());
alpha_index = i;
break;
} else if (data_type == TensorProto_DataType_DOUBLE) {
alpha = static_cast<float>(*(init_const.data<double>()));
alpha_index = i;
break;
} else if (data_type == TensorProto_DataType_FLOAT16) {
alpha = math::halfToFloat(init_const.data<MLFloat16>()->val);
alpha_index = i;
break;
}
}
}
NodeArg* quick_gelu_input_arg = nullptr;
Node* p_sigmoid_node = p_node;
// If alpha_index is not -1, it means the node is Mul node and it has a scalar input.
// We expect the output of Mul node is consumed by a Sigmoid node.
// If alpha_index is -1, it means current node is expected to be a Sigmoid node.
if (alpha_index != -1) {
quick_gelu_input_arg = node.MutableInputDefs()[(alpha_index + 1) % 2];
nodes_to_fuse.emplace_back(node);
p_sigmoid_node = graph.GetNode(node.OutputNodesBegin()->Index());
}
Node& sigmoid_node = *p_sigmoid_node;
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sigmoid_node, "Sigmoid", {6, 13}) ||
!graph_utils::IsSupportedProvider(sigmoid_node, GetCompatibleExecutionProviders()) ||
sigmoid_node.GetOutputEdgesCount() != 1) {
continue;
}
nodes_to_fuse.emplace_back(sigmoid_node);
if (!quick_gelu_input_arg) {
quick_gelu_input_arg = sigmoid_node.MutableInputDefs()[0];
}
Node& mul_node = *graph.GetNode(sigmoid_node.OutputNodesBegin()->Index());
int sigmoid_output_index = optimizer_utils::IndexOfNodeInput(mul_node, *sigmoid_node.MutableOutputDefs()[0]);
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
!graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders()) ||
mul_node.MutableInputDefs()[(sigmoid_output_index + 1) % 2]->Name() != quick_gelu_input_arg->Name()) {
continue;
}
nodes_to_fuse.emplace_back(mul_node);
NodeArg* quick_gelu_output_arg = mul_node.MutableOutputDefs()[0];
Node& quick_gelu_node =
graph.AddNode(graph.GenerateNodeName("QuickGelu"), "QuickGelu", "QuickGelu", std::array{quick_gelu_input_arg},
std::array{quick_gelu_output_arg}, {}, kMSDomain);
quick_gelu_node.AddAttribute("alpha", alpha);
quick_gelu_node.SetExecutionProviderType(node.GetExecutionProviderType());
graph_utils::FinalizeNodeFusion(graph, nodes_to_fuse, quick_gelu_node);
modified = true;
}
return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,21 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/graph_transformer.h"
namespace onnxruntime {
/**
* @brief Rewrite graph fusing x*sigmoid(alpha*x) or x*sigmoid(x) to QuickGelu.
*/
class QuickGeluFusion : public GraphTransformer {
public:
QuickGeluFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
: GraphTransformer("QuickGeluFusion", compatible_execution_providers) {}
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -48,7 +48,7 @@ typedef CtxAlpha CtxThresholdedRelu;
#define UNARY_ACTIVATION_IMPL_DECLARATION(name) \
template <typename T> \
void Impl_##name( \
cudaStream_t stream, \
cudaStream_t stream, \
const T* input_data, \
T* output_data, \
const Ctx##name* func_ctx, \

View file

@ -74,7 +74,7 @@ __global__ void _BinaryElementWiseSimple(
const T1* lhs_data,
const T2* rhs_data,
T* output_data,
const FuncT& func,
const FuncT func,
CUDA_LONG N) {
CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
T1 lvalue[NumElementsPerThread];

View file

@ -47,6 +47,53 @@ TEST_F(ActivationOpTest, Gelu) {
"Gelu", input_values, [](float x) { return x * 0.5f * (1.0f + std::erf(x * static_cast<float>(M_SQRT1_2))); }, {},
false, 1, kMSDomain);
}
} // namespace test
TEST_F(ActivationOpTest, QuickGelu) {
// QuickGelu is not a single activation, some corner values in input_values will not work.
std::vector<std::vector<float>> quick_gelu_input_values{{-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f}};
// Positive alpha.
{
float alpha = 1.702f;
TestActivationOp<float>(
"QuickGelu", quick_gelu_input_values,
[alpha](float x) {
auto tmp = x * alpha;
auto y = 1.f / (1.f + std::exp(-std::abs(tmp))); // safe sigmoid
y = tmp >= 0 ? y : 1 - y;
return x * y;
},
{{"alpha", alpha}}, false, 1, kMSDomain);
}
// Silu = x*sigmoid(x), i.e., alpha = 1.0f.
{
float alpha = 1.0f;
TestActivationOp<float>(
"QuickGelu", quick_gelu_input_values,
[alpha](float x) {
auto tmp = x * alpha;
auto y = 1.f / (1.f + std::exp(-std::abs(tmp))); // safe sigmoid
y = tmp >= 0 ? y : 1 - y;
return x * y;
},
{{"alpha", alpha}}, false, 1, kMSDomain);
}
// Negative alpha.
{
float alpha = -1.702f;
TestActivationOp<float>(
"QuickGelu", quick_gelu_input_values,
[alpha](float x) {
auto tmp = x * alpha;
auto y = 1.f / (1.f + std::exp(-std::abs(tmp))); // safe sigmoid
y = tmp >= 0 ? y : 1 - y;
return x * y;
},
{{"alpha", alpha}}, false, 1, kMSDomain);
}
}
} // namespace test
} // namespace onnxruntime

View file

@ -60,6 +60,7 @@
#include "core/optimizer/noop_elimination.h"
#include "core/optimizer/not_where_fusion.h"
#include "core/optimizer/propagate_cast_ops.h"
#include "core/optimizer/quick_gelu_fusion.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
@ -3508,6 +3509,218 @@ TEST_F(GraphTransformationTests, FastGeluFusionWithCastsTest3) {
ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1);
}
TEST_F(GraphTransformationTests, QuickGelu) {
// Sigmoid(x*alpha)*x, float
{
const float alpha = 1.702f;
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
auto* alpha_arg = builder.MakeInitializer<float>({}, {alpha});
auto* mul_out_0 = builder.MakeIntermediate();
auto* sigmoid_out = builder.MakeIntermediate();
auto* mul_out_1 = builder.MakeOutput();
builder.AddNode("Mul", {input_arg, alpha_arg}, {mul_out_0});
builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out});
builder.AddNode("Mul", {sigmoid_out, input_arg}, {mul_out_1});
};
auto pre_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
};
auto post_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 0);
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 0);
ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 1);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "QuickGelu") {
auto& attrs = node.GetAttributes();
ASSERT_TRUE(attrs.find("alpha") != attrs.end());
ASSERT_EQ(alpha, attrs.at("alpha").f());
}
}
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<QuickGeluFusion>();
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker);
}
// x*Sigmoid(alpha*x), MLFloat16
{
const float alpha = -1.f;
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<MLFloat16>({{2, 3, 3, 3}});
auto* alpha_arg = builder.MakeInitializer<MLFloat16>({}, {static_cast<MLFloat16>(alpha)});
auto* mul_out_0 = builder.MakeIntermediate();
auto* sigmoid_out = builder.MakeIntermediate();
auto* mul_out_1 = builder.MakeOutput();
builder.AddNode("Mul", {alpha_arg, input_arg}, {mul_out_0});
builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out});
builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out_1});
};
auto pre_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
};
auto post_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 0);
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 0);
ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 1);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "QuickGelu") {
auto& attrs = node.GetAttributes();
ASSERT_TRUE(attrs.find("alpha") != attrs.end());
ASSERT_EQ(alpha, attrs.at("alpha").f());
}
}
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<QuickGeluFusion>();
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker);
}
// Sigmoid's output is consumed by other node.
{
const float alpha = 1.702f;
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
auto* alpha_arg = builder.MakeInitializer<float>({}, {alpha});
auto* mul_out_0 = builder.MakeIntermediate();
auto* sigmoid_out = builder.MakeIntermediate();
auto* mul_out_1 = builder.MakeOutput();
auto* identity_out = builder.MakeOutput();
builder.AddNode("Mul", {alpha_arg, input_arg}, {mul_out_0});
builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out});
builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out_1});
builder.AddNode("Identity", {sigmoid_out}, {identity_out});
};
auto pre_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
};
auto post_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 0);
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<QuickGeluFusion>();
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker);
}
// First Mul's output is consumed by other node.
{
const float alpha = -1.f;
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<MLFloat16>({{2, 3, 3, 3}});
auto* alpha_arg = builder.MakeInitializer<MLFloat16>({}, {static_cast<MLFloat16>(alpha)});
auto* mul_out_0 = builder.MakeIntermediate();
auto* sigmoid_out = builder.MakeIntermediate();
auto* mul_out_1 = builder.MakeOutput();
auto* identity_out = builder.MakeOutput();
builder.AddNode("Mul", {alpha_arg, input_arg}, {mul_out_0});
builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out});
builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out_1});
builder.AddNode("Identity", {mul_out_0}, {identity_out});
};
auto pre_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
};
auto post_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 0);
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<QuickGeluFusion>();
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker);
}
// Sigmoid(x)*x, float
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
auto* sigmoid_out = builder.MakeIntermediate();
auto* mul_out = builder.MakeOutput();
builder.AddNode("Sigmoid", {input_arg}, {sigmoid_out});
builder.AddNode("Mul", {sigmoid_out, input_arg}, {mul_out});
};
auto pre_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 1);
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
};
auto post_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 0);
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 0);
ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 1);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "QuickGelu") {
auto& attrs = node.GetAttributes();
ASSERT_TRUE(attrs.find("alpha") != attrs.end());
ASSERT_EQ(1.0f, attrs.at("alpha").f());
}
}
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<QuickGeluFusion>();
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker);
}
// x*Sigmoid(x), MLFloat16
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<MLFloat16>({{2, 3, 3, 3}});
auto* sigmoid_out = builder.MakeIntermediate();
auto* mul_out = builder.MakeOutput();
builder.AddNode("Sigmoid", {input_arg}, {sigmoid_out});
builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out});
};
auto pre_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 1);
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
};
auto post_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 0);
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 0);
ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 1);
for (auto& node : graph.Nodes()) {
if (node.OpType() == "QuickGelu") {
auto& attrs = node.GetAttributes();
ASSERT_TRUE(attrs.find("alpha") != attrs.end());
ASSERT_EQ(1.0f, attrs.at("alpha").f());
}
}
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<QuickGeluFusion>();
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker);
}
}
struct BiasSoftmaxFusionTester {
std::shared_ptr<Model> p_model_;
Status model_load_;

View file

@ -704,6 +704,11 @@ IMPLEMENT_GRADIENT_BUILDER(GetSigmoidGradient) {
{GI(0)})};
}
IMPLEMENT_GRADIENT_BUILDER(GetQuickGeluGradient) {
return std::vector<NodeDef>{
NodeDef(OpDef{"QuickGeluGrad", kMSDomain, 1}, {GO(0), I(0)}, {GI(0)}, SrcNodeAttributes())};
}
IMPLEMENT_GRADIENT_BUILDER(GetSoftmaxGradient) {
return std::vector<NodeDef>{
NodeDef(OpDef{SrcNodeOpsetVersion() < 13 ? "SoftmaxGrad" : "SoftmaxGrad_13", kMSDomain, 1},

View file

@ -43,6 +43,7 @@ DECLARE_GRADIENT_BUILDER(GetConvGradient)
DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient)
DECLARE_GRADIENT_BUILDER(GetSqueezeGradient)
DECLARE_GRADIENT_BUILDER(GetSigmoidGradient)
DECLARE_GRADIENT_BUILDER(GetQuickGeluGradient)
DECLARE_GRADIENT_BUILDER(GetSoftmaxGradient)
DECLARE_GRADIENT_BUILDER(GetLogSoftmaxGradient)
DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyGradient)

View file

@ -74,6 +74,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient);
REGISTER_GRADIENT_BUILDER("Unsqueeze", GetUnsqueezeGradient);
REGISTER_GRADIENT_BUILDER("Sigmoid", GetSigmoidGradient);
REGISTER_GRADIENT_BUILDER("QuickGelu", GetQuickGeluGradient);
REGISTER_GRADIENT_BUILDER("Softmax", GetSoftmaxGradient);
REGISTER_GRADIENT_BUILDER("LogSoftmax", GetLogSoftmaxGradient);
REGISTER_GRADIENT_BUILDER("SoftmaxCrossEntropy", GetSoftmaxCrossEntropyGradient);

View file

@ -2634,6 +2634,19 @@ Example 4:
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
});
ONNX_CONTRIB_OPERATOR_SCHEMA(QuickGeluGrad)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc("QuickGeluGrad")
.Attr("alpha", "Alpha value.", AttributeProto::FLOAT, 1.702f)
.AllowUncheckedAttributes()
.Input(0, "dY", "The gradient tensor from output.", "T")
.Input(1, "X", "The input tensor. ", "T")
.Output(0, "dX", "Gradient of the input.", "T")
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
"Constrain input and output types to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
ONNX_CONTRIB_OPERATOR_SCHEMA(TanhGrad)
.SetDomain(kMSDomain)
.SinceVersion(1)

View file

@ -40,6 +40,7 @@
#include "core/optimizer/noop_elimination.h"
#include "core/optimizer/not_where_fusion.h"
#include "core/optimizer/propagate_cast_ops.h"
#include "core/optimizer/quick_gelu_fusion.h"
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/reshape_fusion.h"
#include "core/optimizer/rule_based_graph_transformer.h"
@ -99,6 +100,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<FastGeluFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<QuickGeluFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<SoftmaxCrossEntropyLossInternalFusion>(compatible_eps));
transformers.emplace_back(std::make_unique<GatherToSplitFusion>(compatible_eps));

View file

@ -152,15 +152,16 @@ void GenerateRandomDataWithOneHot(std::vector<std::vector<float>>& x_datas, std:
void UnaryOpGradientTest(const std::string& op_type, const std::string& domain = kOnnxDomain,
const int opset_version = 9,
std::vector<std::unique_ptr<IExecutionProvider>>* execution_providers = nullptr,
std::function<float(float)>* transformer = nullptr) {
std::function<float(float)>* transformer = nullptr,
const std::vector<ONNX_NAMESPACE::AttributeProto>& attributes = {},
float error_tolerance = 1e-3f) {
TensorShape shape({2, 3, 4});
TensorInfo x_info{shape, true, transformer};
float max_error;
float error_tolerance = 1e-3f;
GradientChecker<float, float, float> gradient_checker;
OpDef op_def{op_type, domain, opset_version};
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info}, {shape}, &max_error, {}, true, false,
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info}, {shape}, &max_error, attributes, true, false,
execution_providers));
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
@ -1548,6 +1549,23 @@ TEST(GradientCheckerTest, DISABLED_BatchNormalizationGrad) {
TEST(GradientCheckerTest, SigmoidGrad) { UnaryOpGradientTest("Sigmoid"); }
TEST(GradientCheckerTest, QuickGeluGrad) {
// Default alpha = 1.702, relax the tolerance due failure on Win for some seed.
{ UnaryOpGradientTest("QuickGelu", kMSDomain, 1, nullptr, nullptr, {}, 5e-2f); }
// Silu, alpha = 1.0.
{
std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {MakeAttribute("alpha", 1.0f)};
UnaryOpGradientTest("QuickGelu", kMSDomain, 1, nullptr, nullptr, attributes, 5e-2f);
}
// Negative alpha.
{
std::vector<ONNX_NAMESPACE::AttributeProto> attributes = {MakeAttribute("alpha", -1.702f)};
UnaryOpGradientTest("QuickGelu", kMSDomain, 1, nullptr, nullptr, attributes, 5e-2f);
}
}
void GradientCheckerSoftmaxGradHelper(bool is_log_softmax, int version = 11) {
TensorShape shape({2, 3, 4});
float max_error;

View file

@ -83,6 +83,12 @@ constexpr float SigmoidGrad(float dy, float y) {
constexpr float TanhGrad(float dy, float y) {
return dy * (1 - y * y);
}
float QuickGeluGrad(float dy, float x, float alpha) {
float v = x * alpha;
float sigmoid = v >= 0 ? 1.f / (1.f + std::exp(-v)) : 1.f - 1.f / (1 + std::exp(v));
return dy * sigmoid * (1 + v * (1 - sigmoid));
}
} // namespace
TEST(GeluGradTest, Basic) {
@ -199,6 +205,50 @@ TEST(TanhGradTest, Basic) {
{}, 1, kMSDomain);
}
TEST(QuickGeluGradTest, Basic) {
const std::vector<float> x_vals = {-10.0f, -1.0f, 0.0f, 1.0f, 10.0f};
const std::vector<float> dY(5, 1.0f);
// Positive alpha.
{
const float alpha = 1.702f;
TestElementwiseGradientOp(
"QuickGeluGrad", {{"dY", dY}, {"X", x_vals}},
[alpha](const std::vector<float>& params) {
ORT_ENFORCE(params.size() == 2);
const auto dy = params[0], x = params[1];
return QuickGeluGrad(dy, x, alpha);
},
{{"alpha", alpha}}, 1, kMSDomain);
}
// Silu = x*sigmoid(x), i.e., alpha = 1.0f.
{
const float alpha = 1.0f;
TestElementwiseGradientOp(
"QuickGeluGrad", {{"dY", dY}, {"X", x_vals}},
[alpha](const std::vector<float>& params) {
ORT_ENFORCE(params.size() == 2);
const auto dy = params[0], x = params[1];
return QuickGeluGrad(dy, x, alpha);
},
{{"alpha", alpha}}, 1, kMSDomain);
}
// Negative alpha.
{
const float alpha = -1.702f;
TestElementwiseGradientOp(
"QuickGeluGrad", {{"dY", dY}, {"X", x_vals}},
[alpha](const std::vector<float>& params) {
ORT_ENFORCE(params.size() == 2);
const auto dy = params[0], x = params[1];
return QuickGeluGrad(dy, x, alpha);
},
{{"alpha", alpha}}, 1, kMSDomain);
}
}
namespace {
template <typename TComputeGeluGradScalarFn>
void TestBiasGeluGradBroadcastBias(const std::string& op, int opset_version, const std::string& domain,

View file

@ -50,6 +50,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gathe
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GeluGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SigmoidGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TanhGrad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGeluGrad);
// REVIEW(mzs): ConstEigenVectorArrayMap.cast<MLFLoat16) does not seem to be supported.
// However these types work on GPU implementation.
@ -163,6 +164,7 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SigmoidGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGeluGrad)>,
// REVIEW(mzs): ConstEigenVectorArrayMap.cast<MLFLoat16) does not seem to be supported.
// However these types work on GPU implementation.
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, DropoutGrad)>,

View file

@ -1,15 +1,16 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "op_gradients.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
#include "orttraining/training_ops/cpu/op_gradients.h"
#include "core/mlas/inc/mlas.h"
#include "core/providers/common.h"
#include <unsupported/Eigen/SpecialFunctions>
#include "core/util/math.h"
#include "core/providers/cpu/math/element_wise_ops.h"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/cpu/tensor/transpose.h"
#include "core/util/math.h"
#include "core/util/math_cpuonly.h"
#include <unsupported/Eigen/SpecialFunctions>
#include "gsl/gsl"
namespace onnxruntime {
@ -210,5 +211,44 @@ Status TanhGrad<T>::Compute(OpKernelContext* context) const {
dx = dy * (1 - y * y);
return Status::OK();
}
ONNX_OPERATOR_KERNEL_EX(QuickGeluGrad, kMSDomain, 1, kCpuExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
QuickGeluGrad<float>);
template <typename T>
Status QuickGeluGrad<T>::Compute(OpKernelContext* context) const {
auto& dY = *context->Input<Tensor>(0);
const T* dY_data = dY.template Data<T>();
auto& X = *context->Input<Tensor>(1);
const T* X_data = X.template Data<T>();
auto& dX = *context->Output(0, dY.Shape());
T* dX_data = dX.template MutableData<T>();
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
int64_t elem_count = dY.Shape().Size();
constexpr int64_t length_per_task = 4096; // this number comes from FastGelu.
int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
concurrency::ThreadPool::TryBatchParallelFor(
tp, static_cast<int32_t>(task_count),
[&](ptrdiff_t task_idx) {
const auto start = task_idx * length_per_task;
const T* p_dy = dY_data + start;
const T* p_x = X_data + start;
T* p_dx = dX_data + start;
int64_t count = std::min(length_per_task, elem_count - start);
for (int64_t i = 0; i < count; i++) {
p_dx[i] = p_x[i] * alpha_;
}
MlasComputeLogistic(p_dx, p_dx, count);
for (int64_t i = 0; i < count; i++) {
p_dx[i] = p_dy[i] * p_dx[i] * (1.f + alpha_ * p_x[i] * (1.f - p_dx[i]));
}
},
0);
return Status::OK();
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -33,6 +33,20 @@ class SigmoidGrad final : public OpKernel {
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SigmoidGrad);
};
template <typename T>
class QuickGeluGrad final : public OpKernel {
public:
explicit QuickGeluGrad(const OpKernelInfo& info) : OpKernel(info) {
alpha_ = info.GetAttrOrDefault<float>("alpha", 1.702f);
}
Status Compute(OpKernelContext* context) const override;
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QuickGeluGrad);
float alpha_;
};
template <typename T>
class TanhGrad final : public OpKernel {
public:

View file

@ -47,6 +47,7 @@ ACTIVATION_GRAD_OP_HFD(GeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(FastGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(ReluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(SigmoidGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(QuickGeluGrad, 1, kMSDomain);
ACTIVATION_GRAD_OP_HFD(TanhGrad, 1, kMSDomain);
} // namespace cuda

View file

@ -55,6 +55,20 @@ class SigmoidGrad final : public BinaryElementwise<ShouldNotBroadcast> {
MAKE_FUNC_CTX_NULL()
};
template <typename T>
class QuickGeluGrad final : public BinaryElementwise<ShouldNotBroadcast> {
public:
QuickGeluGrad(const OpKernelInfo& info) : BinaryElementwise(info) {
alpha_ = info.GetAttrOrDefault<float>("alpha", 1.702f);
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
MAKE_FUNC_CTX_ALPHA()
float alpha_;
};
template <typename T>
class TanhGrad final : public BinaryElementwise<ShouldNotBroadcast> {
public:

View file

@ -46,6 +46,17 @@ struct OP_SigmoidGrad : public CtxSigmoidGrad {
}
};
template <typename T>
struct OP_QuickGeluGrad : public CtxQuickGeluGrad {
__device__ __inline__ T operator()(const T& dy, const T& x) const {
T v = x * static_cast<T>(alpha);
T one = static_cast<T>(1.f);
T zero = static_cast<T>(0.f);
T sigmoid = v >= zero ? one / (one + _Exp(-v)) : one - one / (one + _Exp(v));
return dy * sigmoid * (one + v * (one - sigmoid));
}
};
template <typename T>
struct OP_TanhGrad : public CtxTanhGrad {
__device__ __inline__ T operator()(const T& dy, const T& y) const {

View file

@ -11,13 +11,15 @@ typedef onnxruntime::cuda::CtxNull CtxGeluGrad;
typedef onnxruntime::cuda::CtxNull CtxFastGeluGrad;
typedef onnxruntime::cuda::CtxNull CtxReluGrad;
typedef onnxruntime::cuda::CtxNull CtxSigmoidGrad;
typedef onnxruntime::cuda::CtxAlpha CtxQuickGeluGrad;
typedef onnxruntime::cuda::CtxNull CtxTanhGrad;
#define ACTIVATION_GRAD_OPS() \
ACTIVATION_GRAD_OP_NAME(GeluGrad) \
ACTIVATION_GRAD_OP_NAME(FastGeluGrad) \
ACTIVATION_GRAD_OP_NAME(ReluGrad) \
ACTIVATION_GRAD_OP_NAME(SigmoidGrad) \
#define ACTIVATION_GRAD_OPS() \
ACTIVATION_GRAD_OP_NAME(GeluGrad) \
ACTIVATION_GRAD_OP_NAME(FastGeluGrad) \
ACTIVATION_GRAD_OP_NAME(ReluGrad) \
ACTIVATION_GRAD_OP_NAME(SigmoidGrad) \
ACTIVATION_GRAD_OP_NAME(QuickGeluGrad) \
ACTIVATION_GRAD_OP_NAME(TanhGrad)
#define BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \

View file

@ -112,6 +112,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, SigmoidGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SigmoidGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGeluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGeluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad);
@ -335,6 +339,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SigmoidGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, SigmoidGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SigmoidGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad)>,

View file

@ -109,6 +109,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, SigmoidGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SigmoidGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGeluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGeluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TanhGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TanhGrad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad);
@ -301,6 +305,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SigmoidGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, SigmoidGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SigmoidGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TanhGrad)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad)>,