mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
QuickGelu Fusion (#12417)
Some models have QuickGelu(x)=x*sigmoid(1.702x), which takes 3 Ops for forward and 5 Ops for backward. This PR fuses them into a single Op named QuickGelu and its gradient QuickGeluGrad. For CUDA, tested on a V100 using an input tensor with shape [64,128,2048] and float16 type: Before, FW takes 335us, BW takes 614us. After, FW takes 115us, BW takes 139us, which is much faster. For the CPU kernel, using the same shape and float type: Before, FW takes ~10ms, BW takes ~49ms (the per-op timings below sum to these totals): Mul: 3480[µs] Sigmoid: 1996[µs] Mul: 4789[µs] Mul: 4642[µs] Mul: 4195[µs] SigmoidGrad: 18328[µs] Mul: 2988[µs] Sum: 18576[µs]. After, FW takes ~4ms, BW takes ~5ms, which is also much faster: QuickGelu: 3939[µs] QuickGeluGrad: 5089[µs] Co-authored-by: Vincent Wang <weicwang@microsoft.com>
This commit is contained in:
parent
20c3c35c33
commit
8b0669bf63
36 changed files with 752 additions and 26 deletions
|
|
@ -66,6 +66,7 @@ Do not modify directly.*
|
|||
* <a href="#com.microsoft.QuantizeBFP">com.microsoft.QuantizeBFP</a>
|
||||
* <a href="#com.microsoft.QuantizeLinear">com.microsoft.QuantizeLinear</a>
|
||||
* <a href="#com.microsoft.QuantizeWithOrder">com.microsoft.QuantizeWithOrder</a>
|
||||
* <a href="#com.microsoft.QuickGelu">com.microsoft.QuickGelu</a>
|
||||
* <a href="#com.microsoft.Range">com.microsoft.Range</a>
|
||||
* <a href="#com.microsoft.ReduceSumInteger">com.microsoft.ReduceSumInteger</a>
|
||||
* <a href="#com.microsoft.Rfft">com.microsoft.Rfft</a>
|
||||
|
|
@ -3476,6 +3477,43 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.QuickGelu"></a><a name="com.microsoft.quickgelu">**com.microsoft.QuickGelu**</a>
|
||||
|
||||
Compute x * Sigmoid(alpha * x).
|
||||
|
||||
#### Version
|
||||
|
||||
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
|
||||
|
||||
#### Attributes
|
||||
|
||||
<dl>
|
||||
<dt><tt>alpha</tt> : float</dt>
|
||||
<dd>Alpha value.</dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>X</tt> : T</dt>
|
||||
<dd>The input data as Tensor.</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs
|
||||
|
||||
<dl>
|
||||
<dt><tt>Y</tt> : T</dt>
|
||||
<dd>The output.</dd>
|
||||
</dl>
|
||||
|
||||
#### Type Constraints
|
||||
|
||||
<dl>
|
||||
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
|
||||
<dd>Constrain input and output types to float tensors.</dd>
|
||||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.Range"></a><a name="com.microsoft.range">**com.microsoft.Range**</a>
|
||||
|
||||
Creates a sequence of numbers that begins at `start` and extends by increments of `delta`
|
||||
|
|
|
|||
|
|
@ -434,6 +434,7 @@ Do not modify directly.*
|
|||
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|
||||
|QLinearSoftmax|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* x_zero_point:**T**<br> *in* y_scale:**tensor(float)**<br> *in* y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|
||||
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
|
||||
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|
||||
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|
||||
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|
||||
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float)|
|
||||
|
|
@ -791,6 +792,7 @@ Do not modify directly.*
|
|||
|QOrderedMatMul|*in* A:**Q**<br> *in* scale_A:**S**<br> *in* B:**Q**<br> *in* scale_B:**S**<br> *in* scale_Y:**S**<br> *in* bias:**S**<br> *in* C:**Q**<br> *in* scale_C:**S**<br> *out* Y:**Q**|1+|**Q** = tensor(int8)<br/> **S** = tensor(float)|
|
||||
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float16)<br/> **T2** = tensor(int8), tensor(uint8)|
|
||||
|QuantizeWithOrder|*in* input:**F**<br> *in* scale_input:**S**<br> *out* output:**Q**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|
||||
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|
||||
|
|
|
|||
|
|
@ -34,5 +34,13 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
Gelu<float>);
|
||||
|
||||
// Registers the QuickGelu kernel for the CPU execution provider: operator name QuickGelu,
// domain kMSDomain, version 1, with type constraint "T" limited to float tensors and
// QuickGelu<float> as the implementing class.
ONNX_OPERATOR_KERNEL_EX(
|
||||
QuickGelu,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCpuExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
QuickGelu<float>);
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -93,5 +93,46 @@ class Gelu : public OpKernel {
|
|||
}
|
||||
};
|
||||
|
||||
// Implement a new one instead of inheriting from ElementWiseRangedTransform so that we can call
|
||||
// MlasComputeLogistic instead of using Eigen for better perf.
|
||||
template <typename T>
|
||||
class QuickGelu : public OpKernel {
|
||||
public:
|
||||
QuickGelu(const OpKernelInfo& info) : OpKernel(info) { alpha_ = info.GetAttrOrDefault<float>("alpha", 1.702f); }
|
||||
|
||||
Status Compute(OpKernelContext* context) const override {
|
||||
const Tensor* input = context->Input<Tensor>(0);
|
||||
const T* input_data = input->template Data<T>();
|
||||
Tensor* output = context->Output(0, input->Shape());
|
||||
T* output_data = output->template MutableData<T>();
|
||||
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
|
||||
int64_t elem_count = input->Shape().Size();
|
||||
constexpr int64_t length_per_task = 4096; // this number comes from FastGelu.
|
||||
int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
|
||||
concurrency::ThreadPool::TryBatchParallelFor(
|
||||
tp, static_cast<int32_t>(task_count),
|
||||
[&](ptrdiff_t task_idx) {
|
||||
const auto start = task_idx * length_per_task;
|
||||
const T* p_input = input_data + start;
|
||||
T* p_output = output_data + start;
|
||||
int64_t count = std::min(length_per_task, elem_count - start);
|
||||
for (int64_t i = 0; i < count; i++) {
|
||||
p_output[i] = p_input[i] * alpha_;
|
||||
}
|
||||
|
||||
MlasComputeLogistic(p_output, p_output, count);
|
||||
|
||||
for (int64_t i = 0; i < count; i++) {
|
||||
p_output[i] = p_input[i] * p_output[i];
|
||||
}
|
||||
},
|
||||
0);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
float alpha_;
|
||||
};
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BiasG
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGramRepeatBlock);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu);
|
||||
|
||||
// ******** Start: Quantization ******************* //
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16);
|
||||
|
|
@ -219,6 +220,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGramRepeatBlock)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
|
||||
// These ops were experimental ops in onnx domain which have been removed now. We add them here as
|
||||
// contrib ops to maintain backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, Affine)>,
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain);
|
|||
UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain);
|
||||
UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain);
|
||||
UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain);
|
||||
UNARY_ACTIVATION_OP_HFD(QuickGelu, 1, kMSDomain);
|
||||
|
||||
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16)
|
||||
REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, float)
|
||||
|
|
|
|||
|
|
@ -77,6 +77,20 @@ class Gelu final : public UnaryElementwise {
|
|||
MAKE_FUNC_CTX_NULL()
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class QuickGelu final : public UnaryElementwise {
|
||||
public:
|
||||
QuickGelu(const OpKernelInfo& info) : UnaryElementwise(info) {
|
||||
alpha_ = info.GetAttrOrDefault<float>("alpha", 1.702f);
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
MAKE_FUNC_CTX_ALPHA()
|
||||
float alpha_;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -50,6 +50,17 @@ struct OP_Gelu<half> : public CtxGelu {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct OP_QuickGelu : public CtxQuickGelu {
|
||||
__device__ __inline__ T operator()(const T& a) const {
|
||||
T v = a * static_cast<T>(alpha);
|
||||
T one = static_cast<T>(1.f);
|
||||
T zero = static_cast<T>(0.f);
|
||||
T sigmoid = v >= zero ? one / (one + _Exp(-v)) : one - one / (one + _Exp(v));
|
||||
return a * sigmoid;
|
||||
}
|
||||
};
|
||||
|
||||
#define UNARY_ACTIVATION_IMPL(name) \
|
||||
UNARY_ACTIVATION_IMPL_DECLARATION(name) { \
|
||||
UnaryElementWiseImpl(stream, \
|
||||
|
|
|
|||
|
|
@ -12,21 +12,14 @@ typedef onnxruntime::cuda::CtxAlphaBeta CtxAffine;
|
|||
typedef onnxruntime::cuda::CtxAlphaBeta CtxParametricSoftplus;
|
||||
typedef onnxruntime::cuda::CtxAlphaBeta CtxScaledTanh;
|
||||
typedef onnxruntime::cuda::CtxNull CtxGelu;
|
||||
typedef onnxruntime::cuda::CtxAlpha CtxQuickGelu;
|
||||
|
||||
#define UNARY_CONTRIB_ACTIVATION_OPS() \
|
||||
UNARY_ACTIVATION_OP_NAME(ScaledTanh) \
|
||||
UNARY_ACTIVATION_OP_NAME(Affine) \
|
||||
UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \
|
||||
UNARY_ACTIVATION_OP_NAME(Gelu)
|
||||
|
||||
#define UNARY_ACTIVATION_IMPL_DECLARATION(name) \
|
||||
template <typename T> \
|
||||
void Impl_##name( \
|
||||
cudaStream_t stream, \
|
||||
const T* input_data, \
|
||||
T* output_data, \
|
||||
const Ctx##name* func_ctx, \
|
||||
size_t count)
|
||||
UNARY_ACTIVATION_OP_NAME(Gelu) \
|
||||
UNARY_ACTIVATION_OP_NAME(QuickGelu)
|
||||
|
||||
#define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name);
|
||||
UNARY_CONTRIB_ACTIVATION_OPS()
|
||||
|
|
|
|||
|
|
@ -19,6 +19,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility
|
||||
|
|
@ -130,6 +133,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul)>, // backward compatibility
|
||||
|
|
|
|||
|
|
@ -19,6 +19,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, BiasGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility
|
||||
|
|
@ -124,6 +127,9 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, BiasGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, BiasGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul)>, // backward compatibility
|
||||
|
|
|
|||
|
|
@ -558,6 +558,36 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BiasGelu, 1,
|
|||
"Constrain input and output types to float tensors.")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
|
||||
|
||||
// Schema for the QuickGelu contrib operator: Y = X * Sigmoid(alpha * X), with a single float
// attribute "alpha" (default 1.702f), one input X and one output Y sharing type constraint T.
// Shape and type of Y are propagated from the first input.
constexpr const char* QuickGelu_ver1_doc = R"DOC(Compute x * Sigmoid(alpha * x).)DOC";
|
||||
ONNX_MS_OPERATOR_SET_SCHEMA(
|
||||
QuickGelu, 1,
|
||||
OpSchema()
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc(QuickGelu_ver1_doc)
|
||||
.Attr("alpha", "Alpha value.", AttributeProto::FLOAT, 1.702f)
|
||||
.Input(0, "X", "The input data as Tensor.", "T")
|
||||
.Output(0, "Y", "The output.", "T")
|
||||
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
|
||||
"Constrain input and output types to float tensors.")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
|
||||
// Function body builder: expresses QuickGelu as Mul -> Sigmoid -> Mul over the default ("")
// domain at opset 13. It depends on the context because the Alpha constant is created with
// the element type of input 0.
.SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx, const OpSchema& schema,
|
||||
FunctionProto& functionProto) {
|
||||
auto* tp = ctx.getInputType(0);
|
||||
// Without a known tensor type for X no body is built; report that by returning false.
if ((tp == nullptr) || (!tp->has_tensor_type())) return false;
|
||||
auto elem_type = (TensorProto_DataType)(tp->tensor_type().elem_type());
|
||||
auto* alpha_attr = ctx.getAttribute("alpha");
|
||||
// Fall back to the same default as the attribute declaration when alpha is not set on the node.
float alpha = (alpha_attr != nullptr) ? alpha_attr->f() : 1.702f;
|
||||
FunctionBuilder builder(functionProto);
|
||||
builder.AddOpset("", 13).Const("Alpha", ToTensor(alpha, elem_type)).Add(R"(
|
||||
CX = Mul (Alpha, X)
|
||||
SIGMOIDCX = Sigmoid (CX)
|
||||
Y = Mul (X, SIGMOIDCX)
|
||||
)");
|
||||
schema.BuildFunction(functionProto);
|
||||
return true;
|
||||
}));
|
||||
|
||||
// Used to be ONNX 1.7 Inverse(12)
|
||||
// Comment out docs not to increase the binary size
|
||||
//
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FusedGemm);
|
|||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FusedMatMul);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatherND);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Gelu);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QuickGelu);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GreedySearch);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GridSample);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Inverse);
|
||||
|
|
@ -142,6 +143,7 @@ class OpSet_Microsoft_ver1 {
|
|||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FusedMatMul)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatherND)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Gelu)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, QuickGelu)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GreedySearch)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GridSample)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Inverse)>());
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@
|
|||
#include "core/optimizer/qdq_transformer/qdq_propagation.h"
|
||||
#include "core/optimizer/qdq_transformer/qdq_s8_to_u8.h"
|
||||
#include "core/optimizer/qdq_transformer/relu_quantizelinear.h"
|
||||
#include "core/optimizer/quick_gelu_fusion.h"
|
||||
#include "core/optimizer/relu_clip_fusion.h"
|
||||
#include "core/optimizer/reshape_fusion.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
|
|
@ -275,6 +276,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_cuda_rocm_eps));
|
||||
|
||||
transformers.emplace_back(std::make_unique<FastGeluFusion>(cpu_cuda_rocm_eps));
|
||||
transformers.emplace_back(std::make_unique<QuickGeluFusion>(cpu_cuda_rocm_eps));
|
||||
|
||||
transformers.emplace_back(std::make_unique<MatMulScaleFusion>(cpu_cuda_dml_rocm_eps));
|
||||
|
||||
|
|
|
|||
101
onnxruntime/core/optimizer/quick_gelu_fusion.cc
Normal file
101
onnxruntime/core/optimizer/quick_gelu_fusion.cc
Normal file
|
|
@ -0,0 +1,101 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/optimizer/quick_gelu_fusion.h"
|
||||
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
Rewrite x*sigmoid(alpha*x) or x*sigmoid(x) to QuickGelu.
|
||||
*/
|
||||
Status QuickGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
|
||||
GraphViewer graph_viewer(graph);
|
||||
const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
|
||||
for (auto node_index : node_topology_list) {
|
||||
auto* p_node = graph.GetNode(node_index);
|
||||
if (!p_node) continue;
|
||||
|
||||
Node& node = *p_node;
|
||||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
InlinedVector<std::reference_wrapper<Node>> nodes_to_fuse;
|
||||
|
||||
int alpha_index = -1;
|
||||
float alpha = 1.0f;
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Mul", {7, 13, 14}) &&
|
||||
graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) && node.GetOutputEdgesCount() == 1) {
|
||||
for (int i = 0; i < static_cast<int>(node.InputDefs().size()); ++i) {
|
||||
const NodeArg& input_arg = *(node.InputDefs()[i]);
|
||||
if (!optimizer_utils::IsScalar(input_arg)) continue;
|
||||
const TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name());
|
||||
if (!tensor_proto) continue;
|
||||
Initializer init_const{*tensor_proto, graph.ModelPath()};
|
||||
const auto data_type = tensor_proto->data_type();
|
||||
if (data_type == TensorProto_DataType_FLOAT) {
|
||||
alpha = *(init_const.data<float>());
|
||||
alpha_index = i;
|
||||
break;
|
||||
} else if (data_type == TensorProto_DataType_DOUBLE) {
|
||||
alpha = static_cast<float>(*(init_const.data<double>()));
|
||||
alpha_index = i;
|
||||
break;
|
||||
} else if (data_type == TensorProto_DataType_FLOAT16) {
|
||||
alpha = math::halfToFloat(init_const.data<MLFloat16>()->val);
|
||||
alpha_index = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
NodeArg* quick_gelu_input_arg = nullptr;
|
||||
Node* p_sigmoid_node = p_node;
|
||||
// If alpha_index is not -1, it means the node is Mul node and it has a scalar input.
|
||||
// We expect the output of Mul node is consumed by a Sigmoid node.
|
||||
// If alpha_index is -1, it means current node is expected to be a Sigmoid node.
|
||||
if (alpha_index != -1) {
|
||||
quick_gelu_input_arg = node.MutableInputDefs()[(alpha_index + 1) % 2];
|
||||
nodes_to_fuse.emplace_back(node);
|
||||
p_sigmoid_node = graph.GetNode(node.OutputNodesBegin()->Index());
|
||||
}
|
||||
|
||||
Node& sigmoid_node = *p_sigmoid_node;
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(sigmoid_node, "Sigmoid", {6, 13}) ||
|
||||
!graph_utils::IsSupportedProvider(sigmoid_node, GetCompatibleExecutionProviders()) ||
|
||||
sigmoid_node.GetOutputEdgesCount() != 1) {
|
||||
continue;
|
||||
}
|
||||
nodes_to_fuse.emplace_back(sigmoid_node);
|
||||
if (!quick_gelu_input_arg) {
|
||||
quick_gelu_input_arg = sigmoid_node.MutableInputDefs()[0];
|
||||
}
|
||||
|
||||
Node& mul_node = *graph.GetNode(sigmoid_node.OutputNodesBegin()->Index());
|
||||
int sigmoid_output_index = optimizer_utils::IndexOfNodeInput(mul_node, *sigmoid_node.MutableOutputDefs()[0]);
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) ||
|
||||
!graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders()) ||
|
||||
mul_node.MutableInputDefs()[(sigmoid_output_index + 1) % 2]->Name() != quick_gelu_input_arg->Name()) {
|
||||
continue;
|
||||
}
|
||||
nodes_to_fuse.emplace_back(mul_node);
|
||||
|
||||
NodeArg* quick_gelu_output_arg = mul_node.MutableOutputDefs()[0];
|
||||
Node& quick_gelu_node =
|
||||
graph.AddNode(graph.GenerateNodeName("QuickGelu"), "QuickGelu", "QuickGelu", std::array{quick_gelu_input_arg},
|
||||
std::array{quick_gelu_output_arg}, {}, kMSDomain);
|
||||
quick_gelu_node.AddAttribute("alpha", alpha);
|
||||
quick_gelu_node.SetExecutionProviderType(node.GetExecutionProviderType());
|
||||
graph_utils::FinalizeNodeFusion(graph, nodes_to_fuse, quick_gelu_node);
|
||||
modified = true;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
21
onnxruntime/core/optimizer/quick_gelu_fusion.h
Normal file
21
onnxruntime/core/optimizer/quick_gelu_fusion.h
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
/**
|
||||
* @brief Rewrite graph fusing x*sigmoid(alpha*x) or x*sigmoid(x) to QuickGelu.
*
* The constructor passes the transformer name "QuickGeluFusion" and the set of compatible
* execution providers (empty by default) to the GraphTransformer base class.
* ApplyImpl is declared here and defined in quick_gelu_fusion.cc.
|
||||
*/
|
||||
class QuickGeluFusion : public GraphTransformer {
|
||||
public:
|
||||
QuickGeluFusion(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept
|
||||
: GraphTransformer("QuickGeluFusion", compatible_execution_providers) {}
|
||||
|
||||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -48,7 +48,7 @@ typedef CtxAlpha CtxThresholdedRelu;
|
|||
#define UNARY_ACTIVATION_IMPL_DECLARATION(name) \
|
||||
template <typename T> \
|
||||
void Impl_##name( \
|
||||
cudaStream_t stream, \
|
||||
cudaStream_t stream, \
|
||||
const T* input_data, \
|
||||
T* output_data, \
|
||||
const Ctx##name* func_ctx, \
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ __global__ void _BinaryElementWiseSimple(
|
|||
const T1* lhs_data,
|
||||
const T2* rhs_data,
|
||||
T* output_data,
|
||||
const FuncT& func,
|
||||
const FuncT func,
|
||||
CUDA_LONG N) {
|
||||
CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
|
||||
T1 lvalue[NumElementsPerThread];
|
||||
|
|
|
|||
|
|
@ -47,6 +47,53 @@ TEST_F(ActivationOpTest, Gelu) {
|
|||
"Gelu", input_values, [](float x) { return x * 0.5f * (1.0f + std::erf(x * static_cast<float>(M_SQRT1_2))); }, {},
|
||||
false, 1, kMSDomain);
|
||||
}
|
||||
} // namespace test
|
||||
|
||||
TEST_F(ActivationOpTest, QuickGelu) {
  // QuickGelu is a composite of several operations rather than a single activation, so some
  // corner values in input_values will not work; use a dedicated set of inputs instead.
  std::vector<std::vector<float>> quick_gelu_input_values{{-1.0f, 0, 1.0f, 100.0f, -100.0f, 1000.0f, -1000.0f}};

  // Cases, in order: positive alpha, Silu = x*sigmoid(x) (i.e., alpha = 1.0f), negative alpha.
  for (const float alpha : {1.702f, 1.0f, -1.702f}) {
    TestActivationOp<float>(
        "QuickGelu", quick_gelu_input_values,
        [alpha](float x) {
          const float scaled = x * alpha;
          // Safe sigmoid: evaluate on -|scaled| and mirror the result for negative arguments.
          const float sig = 1.f / (1.f + std::exp(-std::abs(scaled)));
          const float gate = scaled >= 0 ? sig : 1 - sig;
          return x * gate;
        },
        {{"alpha", alpha}}, false, 1, kMSDomain);
  }
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@
|
|||
#include "core/optimizer/noop_elimination.h"
|
||||
#include "core/optimizer/not_where_fusion.h"
|
||||
#include "core/optimizer/propagate_cast_ops.h"
|
||||
#include "core/optimizer/quick_gelu_fusion.h"
|
||||
#include "core/optimizer/relu_clip_fusion.h"
|
||||
#include "core/optimizer/reshape_fusion.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
|
|
@ -3508,6 +3509,218 @@ TEST_F(GraphTransformationTests, FastGeluFusionWithCastsTest3) {
|
|||
ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, QuickGelu) {
|
||||
// Sigmoid(x*alpha)*x, float
|
||||
{
|
||||
const float alpha = 1.702f;
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
|
||||
auto* alpha_arg = builder.MakeInitializer<float>({}, {alpha});
|
||||
auto* mul_out_0 = builder.MakeIntermediate();
|
||||
auto* sigmoid_out = builder.MakeIntermediate();
|
||||
auto* mul_out_1 = builder.MakeOutput();
|
||||
|
||||
builder.AddNode("Mul", {input_arg, alpha_arg}, {mul_out_0});
|
||||
builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out});
|
||||
builder.AddNode("Mul", {sigmoid_out, input_arg}, {mul_out_1});
|
||||
};
|
||||
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
|
||||
};
|
||||
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 0);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 0);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 1);
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "QuickGelu") {
|
||||
auto& attrs = node.GetAttributes();
|
||||
ASSERT_TRUE(attrs.find("alpha") != attrs.end());
|
||||
ASSERT_EQ(alpha, attrs.at("alpha").f());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<QuickGeluFusion>();
|
||||
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker);
|
||||
}
|
||||
|
||||
// x*Sigmoid(alpha*x), MLFloat16
|
||||
{
|
||||
const float alpha = -1.f;
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<MLFloat16>({{2, 3, 3, 3}});
|
||||
auto* alpha_arg = builder.MakeInitializer<MLFloat16>({}, {static_cast<MLFloat16>(alpha)});
|
||||
auto* mul_out_0 = builder.MakeIntermediate();
|
||||
auto* sigmoid_out = builder.MakeIntermediate();
|
||||
auto* mul_out_1 = builder.MakeOutput();
|
||||
|
||||
builder.AddNode("Mul", {alpha_arg, input_arg}, {mul_out_0});
|
||||
builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out});
|
||||
builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out_1});
|
||||
};
|
||||
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
|
||||
};
|
||||
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 0);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 0);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 1);
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "QuickGelu") {
|
||||
auto& attrs = node.GetAttributes();
|
||||
ASSERT_TRUE(attrs.find("alpha") != attrs.end());
|
||||
ASSERT_EQ(alpha, attrs.at("alpha").f());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<QuickGeluFusion>();
|
||||
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker);
|
||||
}
|
||||
|
||||
// Sigmoid's output is consumed by other node.
|
||||
{
|
||||
const float alpha = 1.702f;
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
|
||||
auto* alpha_arg = builder.MakeInitializer<float>({}, {alpha});
|
||||
auto* mul_out_0 = builder.MakeIntermediate();
|
||||
auto* sigmoid_out = builder.MakeIntermediate();
|
||||
auto* mul_out_1 = builder.MakeOutput();
|
||||
auto* identity_out = builder.MakeOutput();
|
||||
|
||||
builder.AddNode("Mul", {alpha_arg, input_arg}, {mul_out_0});
|
||||
builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out});
|
||||
builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out_1});
|
||||
builder.AddNode("Identity", {sigmoid_out}, {identity_out});
|
||||
};
|
||||
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
|
||||
};
|
||||
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 0);
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<QuickGeluFusion>();
|
||||
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker);
|
||||
}
|
||||
|
||||
// First Mul's output is consumed by other node.
|
||||
{
|
||||
const float alpha = -1.f;
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<MLFloat16>({{2, 3, 3, 3}});
|
||||
auto* alpha_arg = builder.MakeInitializer<MLFloat16>({}, {static_cast<MLFloat16>(alpha)});
|
||||
auto* mul_out_0 = builder.MakeIntermediate();
|
||||
auto* sigmoid_out = builder.MakeIntermediate();
|
||||
auto* mul_out_1 = builder.MakeOutput();
|
||||
auto* identity_out = builder.MakeOutput();
|
||||
|
||||
builder.AddNode("Mul", {alpha_arg, input_arg}, {mul_out_0});
|
||||
builder.AddNode("Sigmoid", {mul_out_0}, {sigmoid_out});
|
||||
builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out_1});
|
||||
builder.AddNode("Identity", {mul_out_0}, {identity_out});
|
||||
};
|
||||
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
|
||||
};
|
||||
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 2);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 0);
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<QuickGeluFusion>();
|
||||
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker);
|
||||
}
|
||||
|
||||
// Sigmoid(x)*x, float
|
||||
{
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
|
||||
auto* sigmoid_out = builder.MakeIntermediate();
|
||||
auto* mul_out = builder.MakeOutput();
|
||||
|
||||
builder.AddNode("Sigmoid", {input_arg}, {sigmoid_out});
|
||||
builder.AddNode("Mul", {sigmoid_out, input_arg}, {mul_out});
|
||||
};
|
||||
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 1);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
|
||||
};
|
||||
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 0);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 0);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 1);
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "QuickGelu") {
|
||||
auto& attrs = node.GetAttributes();
|
||||
ASSERT_TRUE(attrs.find("alpha") != attrs.end());
|
||||
ASSERT_EQ(1.0f, attrs.at("alpha").f());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<QuickGeluFusion>();
|
||||
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker);
|
||||
}
|
||||
|
||||
// x*Sigmoid(x), MLFloat16
|
||||
{
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<MLFloat16>({{2, 3, 3, 3}});
|
||||
auto* sigmoid_out = builder.MakeIntermediate();
|
||||
auto* mul_out = builder.MakeOutput();
|
||||
|
||||
builder.AddNode("Sigmoid", {input_arg}, {sigmoid_out});
|
||||
builder.AddNode("Mul", {input_arg, sigmoid_out}, {mul_out});
|
||||
};
|
||||
|
||||
auto pre_graph_checker = [&](Graph& graph) {
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 1);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 1);
|
||||
};
|
||||
|
||||
auto post_graph_checker = [&](Graph& graph) {
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Mul"], 0);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["Sigmoid"], 0);
|
||||
ASSERT_EQ(CountOpsInGraph(graph)["com.microsoft.QuickGelu"], 1);
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node.OpType() == "QuickGelu") {
|
||||
auto& attrs = node.GetAttributes();
|
||||
ASSERT_TRUE(attrs.find("alpha") != attrs.end());
|
||||
ASSERT_EQ(1.0f, attrs.at("alpha").f());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<QuickGeluFusion>();
|
||||
TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker);
|
||||
}
|
||||
}
|
||||
|
||||
struct BiasSoftmaxFusionTester {
|
||||
std::shared_ptr<Model> p_model_;
|
||||
Status model_load_;
|
||||
|
|
|
|||
|
|
@ -704,6 +704,11 @@ IMPLEMENT_GRADIENT_BUILDER(GetSigmoidGradient) {
|
|||
{GI(0)})};
|
||||
}
|
||||
|
||||
// Gradient builder for QuickGelu: emits one QuickGeluGrad node (kMSDomain, version 1)
// that consumes the output gradient GO(0) and the forward input I(0) and produces GI(0).
// The node's attributes are whatever SrcNodeAttributes() returns (presumably the forward
// node's attributes such as alpha -- confirm against the helper's definition).
IMPLEMENT_GRADIENT_BUILDER(GetQuickGeluGradient) {
  std::vector<NodeDef> grad_nodes;
  grad_nodes.push_back(
      NodeDef(OpDef{"QuickGeluGrad", kMSDomain, 1}, {GO(0), I(0)}, {GI(0)}, SrcNodeAttributes()));
  return grad_nodes;
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetSoftmaxGradient) {
|
||||
return std::vector<NodeDef>{
|
||||
NodeDef(OpDef{SrcNodeOpsetVersion() < 13 ? "SoftmaxGrad" : "SoftmaxGrad_13", kMSDomain, 1},
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ DECLARE_GRADIENT_BUILDER(GetConvGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetUnsqueezeGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetSqueezeGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetSigmoidGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetQuickGeluGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetSoftmaxGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetLogSoftmaxGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyGradient)
|
||||
|
|
|
|||
|
|
@ -74,6 +74,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("Squeeze", GetSqueezeGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Unsqueeze", GetUnsqueezeGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Sigmoid", GetSigmoidGradient);
|
||||
REGISTER_GRADIENT_BUILDER("QuickGelu", GetQuickGeluGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Softmax", GetSoftmaxGradient);
|
||||
REGISTER_GRADIENT_BUILDER("LogSoftmax", GetLogSoftmaxGradient);
|
||||
REGISTER_GRADIENT_BUILDER("SoftmaxCrossEntropy", GetSoftmaxCrossEntropyGradient);
|
||||
|
|
|
|||
|
|
@ -2634,6 +2634,19 @@ Example 4:
|
|||
return ONNX_NAMESPACE::FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {onnx_opset_13});
|
||||
});
|
||||
|
||||
// Schema of the QuickGeluGrad contrib op (kMSDomain, since version 1).
// Inputs are dY (gradient flowing in from the output) and X (the forward input); the single
// output is dX. All three share type constraint T (float16/float/double/bfloat16 tensors).
// The optional float attribute "alpha" defaults to 1.702, and the output type/shape is
// propagated from the first input (dY).
ONNX_CONTRIB_OPERATOR_SCHEMA(QuickGeluGrad)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc("QuickGeluGrad")
|
||||
.Attr("alpha", "Alpha value.", AttributeProto::FLOAT, 1.702f)
|
||||
.AllowUncheckedAttributes()
|
||||
.Input(0, "dY", "The gradient tensor from output.", "T")
|
||||
.Input(1, "X", "The input tensor. ", "T")
|
||||
.Output(0, "dX", "Gradient of the input.", "T")
|
||||
.TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"},
|
||||
"Constrain input and output types to float tensors.")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(TanhGrad)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@
|
|||
#include "core/optimizer/noop_elimination.h"
|
||||
#include "core/optimizer/not_where_fusion.h"
|
||||
#include "core/optimizer/propagate_cast_ops.h"
|
||||
#include "core/optimizer/quick_gelu_fusion.h"
|
||||
#include "core/optimizer/relu_clip_fusion.h"
|
||||
#include "core/optimizer/reshape_fusion.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
|
|
@ -99,6 +100,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
|
|||
transformers.emplace_back(std::make_unique<LayerNormFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<SimplifiedLayerNormFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<FastGeluFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<QuickGeluFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<SoftmaxCrossEntropyLossInternalFusion>(compatible_eps));
|
||||
transformers.emplace_back(std::make_unique<GatherToSplitFusion>(compatible_eps));
|
||||
|
||||
|
|
|
|||
|
|
@ -152,15 +152,16 @@ void GenerateRandomDataWithOneHot(std::vector<std::vector<float>>& x_datas, std:
|
|||
void UnaryOpGradientTest(const std::string& op_type, const std::string& domain = kOnnxDomain,
|
||||
const int opset_version = 9,
|
||||
std::vector<std::unique_ptr<IExecutionProvider>>* execution_providers = nullptr,
|
||||
std::function<float(float)>* transformer = nullptr) {
|
||||
std::function<float(float)>* transformer = nullptr,
|
||||
const std::vector<ONNX_NAMESPACE::AttributeProto>& attributes = {},
|
||||
float error_tolerance = 1e-3f) {
|
||||
TensorShape shape({2, 3, 4});
|
||||
TensorInfo x_info{shape, true, transformer};
|
||||
float max_error;
|
||||
float error_tolerance = 1e-3f;
|
||||
GradientChecker<float, float, float> gradient_checker;
|
||||
OpDef op_def{op_type, domain, opset_version};
|
||||
|
||||
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info}, {shape}, &max_error, {}, true, false,
|
||||
ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info}, {shape}, &max_error, attributes, true, false,
|
||||
execution_providers));
|
||||
|
||||
EXPECT_IS_TINIER_THAN(max_error, error_tolerance);
|
||||
|
|
@ -1548,6 +1549,23 @@ TEST(GradientCheckerTest, DISABLED_BatchNormalizationGrad) {
|
|||
|
||||
TEST(GradientCheckerTest, SigmoidGrad) { UnaryOpGradientTest("Sigmoid"); }
|
||||
|
||||
// Numerically checks the QuickGelu gradient for the default alpha, for alpha = 1 (Silu)
// and for a negative alpha.
TEST(GradientCheckerTest, QuickGeluGrad) {
  // The tolerance is relaxed because the check failed on Windows for some seeds.
  constexpr float kTolerance = 5e-2f;

  const auto check_with = [&](const std::vector<ONNX_NAMESPACE::AttributeProto>& attributes) {
    UnaryOpGradientTest("QuickGelu", kMSDomain, 1, nullptr, nullptr, attributes, kTolerance);
  };

  // No attributes: the op falls back to its default alpha = 1.702.
  check_with({});

  // Silu, alpha = 1.0.
  check_with({MakeAttribute("alpha", 1.0f)});

  // Negative alpha.
  check_with({MakeAttribute("alpha", -1.702f)});
}
|
||||
|
||||
void GradientCheckerSoftmaxGradHelper(bool is_log_softmax, int version = 11) {
|
||||
TensorShape shape({2, 3, 4});
|
||||
float max_error;
|
||||
|
|
|
|||
|
|
@ -83,6 +83,12 @@ constexpr float SigmoidGrad(float dy, float y) {
|
|||
constexpr float TanhGrad(float dy, float y) {
|
||||
return dy * (1 - y * y);
|
||||
}
|
||||
|
||||
// Reference QuickGelu gradient for a single element.
//
// With v = alpha * x and s = sigmoid(v), returns dy * s * (1 + v * (1 - s)), which is
// dy times the derivative of x * sigmoid(alpha * x) with respect to x.
//
// @param dy     incoming gradient for this element
// @param x      forward input value
// @param alpha  QuickGelu scaling factor
float QuickGeluGrad(float dy, float x, float alpha) {
  const float v = x * alpha;
  // Only ever exponentiate a non-positive argument so std::exp cannot overflow.
  const float e = std::exp(-std::fabs(v));
  // For v < 0 use e / (1 + e) instead of 1 - 1 / (1 + e): the latter cancels to exactly 0
  // as soon as e drops below float epsilon, discarding the (small but non-zero) tail.
  const float sigmoid = v >= 0.f ? 1.f / (1.f + e) : e / (1.f + e);
  return dy * sigmoid * (1.f + v * (1.f - sigmoid));
}
|
||||
} // namespace
|
||||
|
||||
TEST(GeluGradTest, Basic) {
|
||||
|
|
@ -199,6 +205,50 @@ TEST(TanhGradTest, Basic) {
|
|||
{}, 1, kMSDomain);
|
||||
}
|
||||
|
||||
// Compares the QuickGeluGrad kernel against the scalar reference QuickGeluGrad() on a small
// set of inputs, once per alpha value.
TEST(QuickGeluGradTest, Basic) {
  const std::vector<float> x_vals = {-10.0f, -1.0f, 0.0f, 1.0f, 10.0f};
  const std::vector<float> dY(5, 1.0f);

  // In order: a positive alpha (1.702), Silu = x*sigmoid(x) (alpha = 1.0f), and a negative alpha.
  for (const float alpha : {1.702f, 1.0f, -1.702f}) {
    TestElementwiseGradientOp(
        "QuickGeluGrad", {{"dY", dY}, {"X", x_vals}},
        [alpha](const std::vector<float>& params) {
          // params holds one element per op input: {dy, x}.
          ORT_ENFORCE(params.size() == 2);
          return QuickGeluGrad(params[0], params[1], alpha);
        },
        {{"alpha", alpha}}, 1, kMSDomain);
  }
}
|
||||
|
||||
namespace {
|
||||
template <typename TComputeGeluGradScalarFn>
|
||||
void TestBiasGeluGradBroadcastBias(const std::string& op, int opset_version, const std::string& domain,
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gathe
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GeluGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SigmoidGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TanhGrad);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGeluGrad);
|
||||
|
||||
// REVIEW(mzs): ConstEigenVectorArrayMap.cast<MLFloat16>() does not seem to be supported.
|
||||
// However these types work on GPU implementation.
|
||||
|
|
@ -163,6 +164,7 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GeluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SigmoidGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TanhGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGeluGrad)>,
|
||||
// REVIEW(mzs): ConstEigenVectorArrayMap.cast<MLFloat16>() does not seem to be supported.
|
||||
// However these types work on GPU implementation.
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, DropoutGrad)>,
|
||||
|
|
|
|||
|
|
@ -1,15 +1,16 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "op_gradients.h"
|
||||
#include "core/util/math.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include "orttraining/training_ops/cpu/op_gradients.h"
|
||||
|
||||
#include "core/mlas/inc/mlas.h"
|
||||
#include "core/providers/common.h"
|
||||
#include <unsupported/Eigen/SpecialFunctions>
|
||||
#include "core/util/math.h"
|
||||
#include "core/providers/cpu/math/element_wise_ops.h"
|
||||
#include "core/providers/cpu/math/matmul_helper.h"
|
||||
#include "core/providers/cpu/tensor/transpose.h"
|
||||
#include "core/util/math.h"
|
||||
#include "core/util/math_cpuonly.h"
|
||||
#include <unsupported/Eigen/SpecialFunctions>
|
||||
#include "gsl/gsl"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -210,5 +211,44 @@ Status TanhGrad<T>::Compute(OpKernelContext* context) const {
|
|||
dx = dy * (1 - y * y);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Registers the QuickGeluGrad kernel (kMSDomain, version 1) for the CPU execution provider.
// Type constraint "T" is limited to float tensors, and the kernel class is QuickGeluGrad<float>.
ONNX_OPERATOR_KERNEL_EX(QuickGeluGrad, kMSDomain, 1, kCpuExecutionProvider,
|
||||
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
QuickGeluGrad<float>);
|
||||
|
||||
template <typename T>
|
||||
Status QuickGeluGrad<T>::Compute(OpKernelContext* context) const {
|
||||
auto& dY = *context->Input<Tensor>(0);
|
||||
const T* dY_data = dY.template Data<T>();
|
||||
auto& X = *context->Input<Tensor>(1);
|
||||
const T* X_data = X.template Data<T>();
|
||||
auto& dX = *context->Output(0, dY.Shape());
|
||||
T* dX_data = dX.template MutableData<T>();
|
||||
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
|
||||
int64_t elem_count = dY.Shape().Size();
|
||||
constexpr int64_t length_per_task = 4096; // this number comes from FastGelu.
|
||||
int64_t task_count = (elem_count + length_per_task - 1) / length_per_task;
|
||||
concurrency::ThreadPool::TryBatchParallelFor(
|
||||
tp, static_cast<int32_t>(task_count),
|
||||
[&](ptrdiff_t task_idx) {
|
||||
const auto start = task_idx * length_per_task;
|
||||
const T* p_dy = dY_data + start;
|
||||
const T* p_x = X_data + start;
|
||||
T* p_dx = dX_data + start;
|
||||
int64_t count = std::min(length_per_task, elem_count - start);
|
||||
for (int64_t i = 0; i < count; i++) {
|
||||
p_dx[i] = p_x[i] * alpha_;
|
||||
}
|
||||
|
||||
MlasComputeLogistic(p_dx, p_dx, count);
|
||||
|
||||
for (int64_t i = 0; i < count; i++) {
|
||||
p_dx[i] = p_dy[i] * p_dx[i] * (1.f + alpha_ * p_x[i] * (1.f - p_dx[i]));
|
||||
}
|
||||
},
|
||||
0);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -33,6 +33,20 @@ class SigmoidGrad final : public OpKernel {
|
|||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SigmoidGrad);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class QuickGeluGrad final : public OpKernel {
|
||||
public:
|
||||
explicit QuickGeluGrad(const OpKernelInfo& info) : OpKernel(info) {
|
||||
alpha_ = info.GetAttrOrDefault<float>("alpha", 1.702f);
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QuickGeluGrad);
|
||||
float alpha_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class TanhGrad final : public OpKernel {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ ACTIVATION_GRAD_OP_HFD(GeluGrad, 1, kMSDomain);
|
|||
ACTIVATION_GRAD_OP_HFD(FastGeluGrad, 1, kMSDomain);
|
||||
ACTIVATION_GRAD_OP_HFD(ReluGrad, 1, kMSDomain);
|
||||
ACTIVATION_GRAD_OP_HFD(SigmoidGrad, 1, kMSDomain);
|
||||
ACTIVATION_GRAD_OP_HFD(QuickGeluGrad, 1, kMSDomain);
|
||||
ACTIVATION_GRAD_OP_HFD(TanhGrad, 1, kMSDomain);
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -55,6 +55,20 @@ class SigmoidGrad final : public BinaryElementwise<ShouldNotBroadcast> {
|
|||
MAKE_FUNC_CTX_NULL()
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class QuickGeluGrad final : public BinaryElementwise<ShouldNotBroadcast> {
|
||||
public:
|
||||
QuickGeluGrad(const OpKernelInfo& info) : BinaryElementwise(info) {
|
||||
alpha_ = info.GetAttrOrDefault<float>("alpha", 1.702f);
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
MAKE_FUNC_CTX_ALPHA()
|
||||
float alpha_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class TanhGrad final : public BinaryElementwise<ShouldNotBroadcast> {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -46,6 +46,17 @@ struct OP_SigmoidGrad : public CtxSigmoidGrad {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct OP_QuickGeluGrad : public CtxQuickGeluGrad {
|
||||
__device__ __inline__ T operator()(const T& dy, const T& x) const {
|
||||
T v = x * static_cast<T>(alpha);
|
||||
T one = static_cast<T>(1.f);
|
||||
T zero = static_cast<T>(0.f);
|
||||
T sigmoid = v >= zero ? one / (one + _Exp(-v)) : one - one / (one + _Exp(v));
|
||||
return dy * sigmoid * (one + v * (one - sigmoid));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct OP_TanhGrad : public CtxTanhGrad {
|
||||
__device__ __inline__ T operator()(const T& dy, const T& y) const {
|
||||
|
|
|
|||
|
|
@ -11,13 +11,15 @@ typedef onnxruntime::cuda::CtxNull CtxGeluGrad;
|
|||
typedef onnxruntime::cuda::CtxNull CtxFastGeluGrad;
|
||||
typedef onnxruntime::cuda::CtxNull CtxReluGrad;
|
||||
typedef onnxruntime::cuda::CtxNull CtxSigmoidGrad;
|
||||
typedef onnxruntime::cuda::CtxAlpha CtxQuickGeluGrad;
|
||||
typedef onnxruntime::cuda::CtxNull CtxTanhGrad;
|
||||
|
||||
#define ACTIVATION_GRAD_OPS() \
|
||||
ACTIVATION_GRAD_OP_NAME(GeluGrad) \
|
||||
ACTIVATION_GRAD_OP_NAME(FastGeluGrad) \
|
||||
ACTIVATION_GRAD_OP_NAME(ReluGrad) \
|
||||
ACTIVATION_GRAD_OP_NAME(SigmoidGrad) \
|
||||
#define ACTIVATION_GRAD_OPS() \
|
||||
ACTIVATION_GRAD_OP_NAME(GeluGrad) \
|
||||
ACTIVATION_GRAD_OP_NAME(FastGeluGrad) \
|
||||
ACTIVATION_GRAD_OP_NAME(ReluGrad) \
|
||||
ACTIVATION_GRAD_OP_NAME(SigmoidGrad) \
|
||||
ACTIVATION_GRAD_OP_NAME(QuickGeluGrad) \
|
||||
ACTIVATION_GRAD_OP_NAME(TanhGrad)
|
||||
|
||||
#define BINARY_ELEMENTWISE_IMPL_DECLARATION(name) \
|
||||
|
|
|
|||
|
|
@ -112,6 +112,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, SigmoidGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SigmoidGrad);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGeluGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGeluGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad);
|
||||
|
|
@ -335,6 +339,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SigmoidGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, SigmoidGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SigmoidGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, QuickGeluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, QuickGeluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TanhGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TanhGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad)>,
|
||||
|
|
|
|||
|
|
@ -109,6 +109,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, SigmoidGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SigmoidGrad);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGeluGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGeluGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad);
|
||||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TanhGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TanhGrad);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad);
|
||||
|
|
@ -301,6 +305,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SigmoidGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, SigmoidGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SigmoidGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, QuickGeluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, QuickGeluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, QuickGeluGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TanhGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TanhGrad)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TanhGrad)>,
|
||||
|
|
|
|||
Loading…
Reference in a new issue