diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 6369d1666c..7ed747da0a 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -1183,6 +1183,7 @@ Do not modify directly.*
|QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)|
+|QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
| |
| |
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 8a46537c97..cce11c4795 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -299,7 +299,7 @@ InlinedVector> GenerateTransformers(
transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps));
- transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps));
+ transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps));
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp
new file mode 100644
index 0000000000..3683ab7b0b
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp
@@ -0,0 +1,133 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "precomp.h"
+
+namespace Dml
+{
+// QuickGelu as a DML operator graph: Y = X * sigmoid(alpha * X); the alpha scaling node is omitted when alpha == 1.
+class DmlOperatorQuickGelu : public DmlOperator
+{
+public:
+ DmlOperatorQuickGelu(const MLOperatorKernelCreationContext& kernelCreationContext)
+ : DmlOperator(kernelCreationContext)
+ {
+ ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 1); // exactly one input tensor is accepted
+ ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1); // exactly one output tensor is accepted
+ DmlOperator::Initialize(kernelCreationContext);
+
+ ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs.size() == 1); // presumably populated by Initialize() above
+ ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs.size() == 1);
+ const float alpha = kernelCreationContext.GetAttribute(AttrName::Alpha); // scale applied to the input before the sigmoid
+
+ std::vector inputDescs = GetDmlInputDescs();
+ std::vector outputDescs = GetDmlOutputDescs(); // NOTE(review): never read below; every node uses inputDescs[0] - confirm whether it can be dropped
+
+ // 1. Scale the input by alpha (identity op with ScaleBias); the desc is only filled in, and the node only added, when alpha != 1
+ DML_SCALE_BIAS scaleBias{alpha, 0.0f}; // scale = alpha, bias = 0
+ DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC mulAlphaDesc{};
+ if (alpha != 1.0f)
+ {
+ mulAlphaDesc.InputTensor = &inputDescs[0];
+ mulAlphaDesc.OutputTensor = &inputDescs[0]; // output reuses the input desc; presumably fine because the op is element-wise (same shape/type) - confirm
+ mulAlphaDesc.ScaleBias = &scaleBias;
+ }
+ DML_OPERATOR_DESC dmlMulAlphaDesc = { DML_OPERATOR_ELEMENT_WISE_IDENTITY, &mulAlphaDesc }; // built unconditionally, but only pushed into opDescs when alpha != 1
+
+ // 2. Apply the sigmoid activation function (to the alpha-scaled input, or to the raw input when alpha == 1)
+ DML_ACTIVATION_SIGMOID_OPERATOR_DESC sigmoidDesc{};
+ sigmoidDesc.InputTensor = &inputDescs[0];
+ sigmoidDesc.OutputTensor = &inputDescs[0];
+ DML_OPERATOR_DESC dmlSigmoidDesc = { DML_OPERATOR_ACTIVATION_SIGMOID, &sigmoidDesc };
+
+ // 3. Multiply the sigmoid result with the original (unscaled) input
+ DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC multiplyDesc{};
+ multiplyDesc.ATensor = &inputDescs[0];
+ multiplyDesc.BTensor = &inputDescs[0];
+ multiplyDesc.OutputTensor = &inputDescs[0];
+ DML_OPERATOR_DESC dmlMultiplyDesc = { DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &multiplyDesc };
+
+ enum NodeIndex // positions of the nodes inside opDescs; used by the edge descs below
+ {
+ sigmoidNodeIndex,
+ multiplyNodeIndex,
+ mulAlphaNodeIndex, // listed last: this node is only appended when alpha != 1, so the two indices above hold in both cases
+ nodeCount, // NOTE(review): not referenced in this constructor
+ };
+
+ // Construct the graph; the push order into opDescs has to match NodeIndex (sigmoid = 0, multiply = 1, optional mulAlpha = 2)
+ std::vector opDescs;
+ opDescs.reserve(3);
+ opDescs.push_back(&dmlSigmoidDesc);
+ opDescs.push_back(&dmlMultiplyDesc);
+
+ std::vector inputEdges;
+ inputEdges.reserve(2);
+
+ std::vector intermediateEdges;
+ intermediateEdges.reserve(2);
+
+ std::vector outputEdges;
+ outputEdges.reserve(1);
+
+ if (alpha != 1.0f) // scaled path: input -> mulAlpha -> sigmoid
+ {
+ opDescs.push_back(&dmlMulAlphaDesc);
+
+ DML_INPUT_GRAPH_EDGE_DESC inputToMulAlphaEdge{}; // graph input 0 -> mulAlpha
+ inputToMulAlphaEdge.GraphInputIndex = 0;
+ inputToMulAlphaEdge.ToNodeIndex = mulAlphaNodeIndex;
+ inputToMulAlphaEdge.ToNodeInputIndex = 0;
+ inputEdges.push_back(inputToMulAlphaEdge);
+
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC mulAlphaToSigmoidEdge{}; // mulAlpha output -> sigmoid
+ mulAlphaToSigmoidEdge.FromNodeIndex = mulAlphaNodeIndex;
+ mulAlphaToSigmoidEdge.FromNodeOutputIndex = 0;
+ mulAlphaToSigmoidEdge.ToNodeIndex = sigmoidNodeIndex;
+ mulAlphaToSigmoidEdge.ToNodeInputIndex = 0;
+ intermediateEdges.push_back(mulAlphaToSigmoidEdge);
+ }
+ else // alpha == 1: no scaling node, input feeds the sigmoid directly
+ {
+ DML_INPUT_GRAPH_EDGE_DESC inputToSigmoidEdge{}; // graph input 0 -> sigmoid
+ inputToSigmoidEdge.GraphInputIndex = 0;
+ inputToSigmoidEdge.ToNodeIndex = sigmoidNodeIndex;
+ inputToSigmoidEdge.ToNodeInputIndex = 0;
+ inputEdges.push_back(inputToSigmoidEdge);
+ }
+
+ DML_INPUT_GRAPH_EDGE_DESC inputToMultiplyEdge{}; // graph input 0 (unscaled) -> multiply input 0
+ inputToMultiplyEdge.GraphInputIndex = 0;
+ inputToMultiplyEdge.ToNodeIndex = multiplyNodeIndex;
+ inputToMultiplyEdge.ToNodeInputIndex = 0;
+ inputEdges.push_back(inputToMultiplyEdge);
+
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC sigmoidToMultiplyEdge{}; // sigmoid output -> multiply input 1
+ sigmoidToMultiplyEdge.FromNodeIndex = sigmoidNodeIndex;
+ sigmoidToMultiplyEdge.FromNodeOutputIndex = 0;
+ sigmoidToMultiplyEdge.ToNodeIndex = multiplyNodeIndex;
+ sigmoidToMultiplyEdge.ToNodeInputIndex = 1;
+ intermediateEdges.push_back(sigmoidToMultiplyEdge);
+
+ DML_OUTPUT_GRAPH_EDGE_DESC multiplyToOutputEdge{}; // multiply output -> graph output 0
+ multiplyToOutputEdge.FromNodeIndex = multiplyNodeIndex;
+ multiplyToOutputEdge.FromNodeOutputIndex = 0;
+ multiplyToOutputEdge.GraphOutputIndex = 0;
+ outputEdges.push_back(multiplyToOutputEdge);
+
+ MLOperatorGraphDesc operatorGraphDesc = {}; // only stores counts and pointers into the local vectors above, which are still alive at the call below
+ operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size());
+ operatorGraphDesc.inputEdges = inputEdges.data();
+ operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size());
+ operatorGraphDesc.intermediateEdges = intermediateEdges.data();
+ operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size());
+ operatorGraphDesc.outputEdges = outputEdges.data();
+ operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); // 2 nodes when alpha == 1, 3 otherwise
+ operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+ SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
+ }
+};
+
+DML_OP_DEFINE_CREATION_FUNCTION(QuickGelu, DmlOperatorQuickGelu); // presumably the definition matching DML_OP_EXTERN_CREATION_FUNCTION(QuickGelu) in OperatorRegistration.cpp
+
+} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index 0677757b54..349efdab3e 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -381,6 +381,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Shape);
DML_OP_EXTERN_CREATION_FUNCTION(Size);
DML_OP_EXTERN_CREATION_FUNCTION(Attention);
DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
+DML_OP_EXTERN_CREATION_FUNCTION(QuickGelu);
DML_OP_EXTERN_QUERY_FUNCTION(MaxPool);
DML_OP_EXTERN_QUERY_FUNCTION(Slice);
@@ -878,6 +879,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
{REG_INFO_MS( 1, SkipLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QuerySkipLayerNormalization)},
{REG_INFO_MS( 1, EmbedLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
+ {REG_INFO_MS( 1, QuickGelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, GroupNorm, typeNameListGroupNorm, supportedTypeListGroupNorm, DmlGraphSupport::Supported)},
};
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index 9c01aac536..c0d1562f8a 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -1546,6 +1546,7 @@ using ShapeInferenceHelper_IsInf = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_Mod = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_BitShift= GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_Round = GetBroadcastedOutputShapeHelper;
+using ShapeInferenceHelper_QuickGelu = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_ReduceSum = ReduceHelper;
using ShapeInferenceHelper_ReduceMean = ReduceHelper;
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index 96c10e8dc4..57464786d7 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -410,6 +410,7 @@ namespace OperatorHelper
static const int sc_sinceVer_Attention = 1;
static const int sc_sinceVer_SkipLayerNormalization = 1;
static const int sc_sinceVer_EmbedLayerNormalization = 1;
+ static const int sc_sinceVer_QuickGelu = 1;
static const int sc_sinceVer_GroupNorm = 1;
} // namespace MsftOperatorSet1