mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Update BiasGelu fusion and related ops (#23518)
### Description (1) Update BiasGelu fusion to support onnx Gelu-20 Since onnx Gelu-20 supports float/double/bf16/fp16, here we update related ops to support these data types in CUDA and ROCm execution providers: (2) Add double support for Gelu/FastGelu op in CUDA/ROCm execution provider (3) Add BFloat16 support for Gelu ops in CUDA execution provider (4) Add unit tests (5) Update operator documents ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/23491
This commit is contained in:
parent
4dde74a393
commit
0bb4ea6797
18 changed files with 193 additions and 11 deletions
|
|
@ -1754,7 +1754,7 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
#### Type Constraints
|
||||
|
||||
<dl>
|
||||
<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
|
||||
<dt><tt>T</tt> : tensor(float), tensor(double), tensor(float16), tensor(bfloat16)</dt>
|
||||
<dd>Constrain input and output types to float or half tensors.</dd>
|
||||
</dl>
|
||||
|
||||
|
|
|
|||
|
|
@ -912,11 +912,11 @@ Do not modify directly.*
|
|||
|DequantizeWithOrder|*in* input:**Q**<br> *in* scale_input:**S**<br> *out* output:**F**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|
||||
|DynamicTimeWarping|*in* input:**F**<br> *out* output:**I**|1+|**F** = tensor(float)<br/> **I** = tensor(int32)|
|
||||
|EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|FastGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)|
|
||||
|FastGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|
||||
|FusedConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|
||||
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|
||||
|GatedRelativePositionBias|*in* query_layer:**T**<br> *in* query_bias:**T**<br> *in* rel_pos:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* eco_a:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|
||||
|GemmFloat8|*in* A:**TA**<br> *in* B:**TB**<br> *in* C:**TC**<br> *in* scaleA:**TS**<br> *in* scaleB:**TS**<br> *in* scaleY:**TS**<br> *out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)<br/> **TS** = tensor(float)|
|
||||
|GemmaRotaryEmbedding|*in* emb:**U**<br> *in* q:**T**<br> *in* q_rot:**T**<br> *in* k:**T**<br> *in* k_rot:**T**<br> *out* output1:**T**<br> *out* output2:**T**|1+|**T** = tensor(float16)<br/> **U** = tensor(float)|
|
||||
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ namespace cuda {
|
|||
REGISTER_KERNEL_TYPED(float)
|
||||
REGISTER_KERNEL_TYPED(MLFloat16)
|
||||
REGISTER_KERNEL_TYPED(BFloat16)
|
||||
REGISTER_KERNEL_TYPED(double)
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
|
|
|
|||
|
|
@ -25,10 +25,14 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
namespace cuda {
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GridSample);
|
||||
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FastGelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, FastGelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FastGelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Gelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Gelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, Gelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Gelu);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, BiasGelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasSplitGelu);
|
||||
|
|
@ -154,7 +158,6 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, uint8_t_MLFloat16, DequantizeLinear);
|
|||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_int8_t, QAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_int8_t, QAttention);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FusedConv);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu);
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, TransposeMatMul); // backward compatibility
|
||||
class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FusedMatMul);
|
||||
class CUDA_MS_OP_CLASS_NAME(1, QOrderedMatMul);
|
||||
|
|
@ -234,10 +237,13 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, GridSample)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, FastGelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, FastGelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, FastGelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, Gelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, double, Gelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, Gelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, Gelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, BiasGelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, BiasSplitGelu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, float, BiasSplitGelu)>,
|
||||
|
|
@ -362,7 +368,6 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, UnfoldTensor)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, DynamicTimeWarping)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_CLASS_NAME(1, Trilu)>,
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FastGelu)>,
|
||||
// TransposedMatMul is still here for backward compatibility
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, FusedMatMul)>,
|
||||
|
|
|
|||
|
|
@ -66,10 +66,12 @@ class ElementwiseTunableOp : public TunableOp<ElementwiseParams<T>> {
|
|||
}
|
||||
|
||||
ELEMENTWISE_FWD_DECL(FastGeLU, float);
|
||||
ELEMENTWISE_FWD_DECL(FastGeLU, double);
|
||||
ELEMENTWISE_FWD_DECL(FastGeLU, half);
|
||||
ELEMENTWISE_FWD_DECL(FastGeLU, BFloat16);
|
||||
|
||||
ELEMENTWISE_FWD_DECL(GeLU, float);
|
||||
ELEMENTWISE_FWD_DECL(GeLU, double);
|
||||
ELEMENTWISE_FWD_DECL(GeLU, half);
|
||||
ELEMENTWISE_FWD_DECL(GeLU, BFloat16);
|
||||
|
||||
|
|
|
|||
|
|
@ -4,5 +4,6 @@
|
|||
#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh"
|
||||
|
||||
ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, float);
|
||||
ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, double);
|
||||
ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, half);
|
||||
ELEMENTWISE_KERNEL_IMPL(functor::FastGeLU, BFloat16);
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
#include "contrib_ops/rocm/bert/elementwise_impl/impl.cuh"
|
||||
|
||||
ELEMENTWISE_KERNEL_IMPL(functor::GeLU, double);
|
||||
ELEMENTWISE_KERNEL_IMPL(functor::GeLU, float);
|
||||
ELEMENTWISE_KERNEL_IMPL(functor::GeLU, half);
|
||||
ELEMENTWISE_KERNEL_IMPL(functor::GeLU, BFloat16);
|
||||
|
|
|
|||
|
|
@ -11,10 +11,13 @@ namespace contrib {
|
|||
namespace rocm {
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu);
|
||||
|
|
@ -126,7 +129,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedConv);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedConv);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul);
|
||||
// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul);
|
||||
|
|
@ -173,10 +175,13 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, GridSample)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FastGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FastGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FastGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, Gelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, Gelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, Gelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, Gelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BiasGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasSplitGelu)>,
|
||||
|
|
@ -287,7 +292,6 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
|
|||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention)>,
|
||||
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Trilu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
|
||||
// TransposedMatMul is still here for backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
|
||||
|
|
|
|||
|
|
@ -1490,7 +1490,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
.Input(0, "X", "input tensor", "T")
|
||||
.Input(1, "bias", "bias tensor", "T", OpSchema::Optional)
|
||||
.Output(0, "Y", "output tensor", "T")
|
||||
.TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float or half tensors.")
|
||||
.TypeConstraint("T", {"tensor(float)", "tensor(double)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float or half tensors.")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)
|
||||
.SetContextDependentFunctionBodyBuilder([](const FunctionBodyBuildContext& ctx, const OpSchema& schema, FunctionProto& functionProto) {
|
||||
// fastgelu(x) =
|
||||
|
|
|
|||
|
|
@ -61,7 +61,10 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
}
|
||||
|
||||
const Node& next_node = (*next_node_itr);
|
||||
if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||
|
||||
|
||||
bool is_onnx_gelu = graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {20}, kOnnxDomain);
|
||||
if (!(is_onnx_gelu ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Gelu", {1}, kMSDomain) ||
|
||||
graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "FastGelu", {1}, kMSDomain)) ||
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
continue;
|
||||
|
|
@ -72,6 +75,12 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
continue;
|
||||
}
|
||||
|
||||
bool is_approximate = is_fast_gelu;
|
||||
if (is_onnx_gelu) {
|
||||
const ONNX_NAMESPACE::AttributeProto* attribute = graph_utils::GetNodeAttribute(next_node, "approximate");
|
||||
is_approximate = (attribute != nullptr) && utils::HasString(*attribute) && (attribute->s() == "tanh");
|
||||
}
|
||||
|
||||
if (graph.NodeProducesGraphOutput(node)) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -79,7 +88,7 @@ Status BiasGeluFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
Node& add_node = node;
|
||||
Node& gelu_node = const_cast<Node&>(next_node);
|
||||
std::string op_type = "BiasGelu";
|
||||
if (is_fast_gelu) op_type = "FastGelu";
|
||||
if (is_approximate) op_type = "FastGelu";
|
||||
|
||||
Node& gelu_add_fusion_node = graph.AddNode(graph.GenerateNodeName(op_type),
|
||||
op_type,
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ namespace cuda {
|
|||
|
||||
REGISTER_KERNEL_TYPED(float)
|
||||
REGISTER_KERNEL_TYPED(MLFloat16)
|
||||
REGISTER_KERNEL_TYPED(BFloat16)
|
||||
REGISTER_KERNEL_TYPED(double)
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -80,6 +81,7 @@ namespace contrib::cuda {
|
|||
|
||||
REGISTER_CONTRIB_KERNEL_TYPED(float)
|
||||
REGISTER_CONTRIB_KERNEL_TYPED(MLFloat16)
|
||||
REGISTER_CONTRIB_KERNEL_TYPED(BFloat16)
|
||||
REGISTER_CONTRIB_KERNEL_TYPED(double)
|
||||
|
||||
#undef REGISTER_CONTRIB_KERNEL_TYPED
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ Status LaunchGeluKernel(
|
|||
|
||||
SPECIALIZED_GELU_IMPL(float);
|
||||
SPECIALIZED_GELU_IMPL(half);
|
||||
SPECIALIZED_GELU_IMPL(BFloat16);
|
||||
SPECIALIZED_GELU_IMPL(double);
|
||||
|
||||
#undef SPECIALIZED_GELU_IMPL
|
||||
|
|
|
|||
|
|
@ -389,7 +389,7 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16_8) {
|
|||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
|
||||
#ifdef USE_CUDA
|
||||
int min_cuda_architecture = 530;
|
||||
int min_cuda_architecture = 800;
|
||||
if (!HasCudaEnvironment(min_cuda_architecture)) {
|
||||
LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16";
|
||||
return;
|
||||
|
|
@ -440,5 +440,43 @@ TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
|
|||
}
|
||||
#endif
|
||||
|
||||
// CUDA and ROCm only for double type.
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM)
|
||||
TEST(FastGeluTest, FastGeluWithBias_Double) {
|
||||
OpTester tester("FastGelu", 1, onnxruntime::kMSDomain);
|
||||
|
||||
int batch_size = 1;
|
||||
int sequence_length = 2;
|
||||
int hidden_size = 4;
|
||||
|
||||
std::vector<double> X = {
|
||||
0.8, -0.5, 0.0, 1.0,
|
||||
0.5, 0.2, 0.3, -0.6};
|
||||
|
||||
std::vector<double> B = {
|
||||
-0.5, 0.6, 1.2, 2.1};
|
||||
|
||||
std::vector<double> Y = {
|
||||
0.185371, 0.053983, 1.061703, 3.097373,
|
||||
0.000000, 0.630432, 1.399572, 1.399572};
|
||||
|
||||
std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
|
||||
std::vector<int64_t> bias_dims = {hidden_size};
|
||||
std::vector<int64_t> output_dims = input_dims;
|
||||
|
||||
tester.AddInput<double>("X", input_dims, X);
|
||||
tester.AddInput<double>("bias", bias_dims, B);
|
||||
tester.AddOutput<double>("Y", output_dims, Y);
|
||||
|
||||
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
|
||||
#ifdef USE_CUDA
|
||||
execution_providers.push_back(DefaultCudaExecutionProvider());
|
||||
#elif USE_ROCM
|
||||
execution_providers.push_back(DefaultRocmExecutionProvider());
|
||||
#endif
|
||||
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -4781,6 +4781,46 @@ TEST_F(GraphTransformationTests, BiasGeluTest) {
|
|||
ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 1);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, BiasOnnxGeluTest) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/bias_onnx_gelu_fusion.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Gelu"] == 0);
|
||||
ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 0);
|
||||
ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 1);
|
||||
}
|
||||
|
||||
TEST_F(GraphTransformationTests, BiasOnnxFastGeluTest) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/bias_onnx_fast_gelu_fusion.onnx";
|
||||
std::shared_ptr<Model> p_model;
|
||||
ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
|
||||
Graph& graph = p_model->MainGraph();
|
||||
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
const InlinedHashSet<std::string_view> no_limit_empty_ep_list = {};
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(
|
||||
std::make_unique<GeluFusion>(no_limit_empty_ep_list, TransformerLevel::Level2), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<BiasGeluFusion>(), TransformerLevel::Level2));
|
||||
ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Add"] == 0);
|
||||
ASSERT_TRUE(op_to_count["Gelu"] == 0);
|
||||
ASSERT_TRUE(op_to_count["com.microsoft.FastGelu"] == 1);
|
||||
ASSERT_TRUE(op_to_count["com.microsoft.BiasGelu"] == 0);
|
||||
}
|
||||
|
||||
// BiasGelu allows input switching based on input dimensions.
|
||||
// This test validates the input edges are plugged correct in the optimized graph.
|
||||
TEST_F(GraphTransformationTests, BiasGeluSwitchedInputOrder) {
|
||||
|
|
|
|||
Binary file not shown.
|
|
@ -29,3 +29,40 @@ graph = helper.make_graph(
|
|||
|
||||
model = helper.make_model(graph)
|
||||
onnx.save(model, r"bias_gelu_fusion.onnx")
|
||||
|
||||
graph = helper.make_graph(
|
||||
[
|
||||
helper.make_node("Add", ["X", "B"], ["add0_out"], "add0"),
|
||||
helper.make_node("Gelu", ["add0_out"], ["Y"], "gelu"),
|
||||
],
|
||||
"Gelu_Add_Fusion", # name
|
||||
[ # inputs
|
||||
helper.make_tensor_value_info("X", TensorProto.FLOAT, ["batch", "seqlen", 1024]),
|
||||
helper.make_tensor_value_info("B", TensorProto.FLOAT, [1024]),
|
||||
],
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["batch", "seqlen", 1024]),
|
||||
],
|
||||
)
|
||||
|
||||
model = helper.make_model(graph)
|
||||
onnx.save(model, r"bias_onnx_gelu_fusion.onnx")
|
||||
|
||||
|
||||
graph = helper.make_graph(
|
||||
[
|
||||
helper.make_node("Add", ["X", "B"], ["add0_out"], "add0"),
|
||||
helper.make_node("Gelu", ["add0_out"], ["Y"], "gelu", approximate="tanh"),
|
||||
],
|
||||
"Gelu_Add_Fusion", # name
|
||||
[ # inputs
|
||||
helper.make_tensor_value_info("X", TensorProto.FLOAT, ["batch", "seqlen", 1024]),
|
||||
helper.make_tensor_value_info("B", TensorProto.FLOAT, [1024]),
|
||||
],
|
||||
[ # outputs
|
||||
helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["batch", "seqlen", 1024]),
|
||||
],
|
||||
)
|
||||
|
||||
model = helper.make_model(graph)
|
||||
onnx.save(model, r"bias_onnx_fast_gelu_fusion.onnx")
|
||||
|
|
|
|||
21
onnxruntime/test/testdata/transform/fusion/bias_onnx_fast_gelu_fusion.onnx
vendored
Normal file
21
onnxruntime/test/testdata/transform/fusion/bias_onnx_fast_gelu_fusion.onnx
vendored
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
|
||||
:½
|
||||
|
||||
X
|
||||
Badd0_outadd0"Add
|
||||
1
|
||||
add0_outYgelu"Gelu*
|
||||
approximate"tanh Gelu_Add_FusionZ#
|
||||
X
|
||||
|
||||
batch
|
||||
seqlen
|
||||
€Z
|
||||
B
|
||||
|
||||
€b#
|
||||
Y
|
||||
|
||||
batch
|
||||
seqlen
|
||||
€B
|
||||
20
onnxruntime/test/testdata/transform/fusion/bias_onnx_gelu_fusion.onnx
vendored
Normal file
20
onnxruntime/test/testdata/transform/fusion/bias_onnx_gelu_fusion.onnx
vendored
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
|
||||
:¥
|
||||
|
||||
X
|
||||
Badd0_outadd0"Add
|
||||
|
||||
add0_outYgelu"GeluGelu_Add_FusionZ#
|
||||
X
|
||||
|
||||
batch
|
||||
seqlen
|
||||
€Z
|
||||
B
|
||||
|
||||
€b#
|
||||
Y
|
||||
|
||||
batch
|
||||
seqlen
|
||||
€B
|
||||
Loading…
Reference in a new issue