From 9191e0425902eeed92ee67de70bd0cf6c6cd89b6 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Wed, 5 Apr 2023 10:49:34 -0700 Subject: [PATCH] [DML EP] Add QuickGelu (#15220) --- docs/OperatorKernels.md | 1 + .../core/optimizer/graph_transformer_utils.cc | 2 +- .../src/Operators/DmlOperatorQuickGelu.cpp | 133 ++++++++++++++++++ .../src/Operators/OperatorRegistration.cpp | 2 + .../dml/OperatorAuthorHelper/OperatorHelper.h | 1 + .../OperatorAuthorHelper/OperatorVersions.h | 1 + 6 files changed, 139 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 6369d1666c..7ed747da0a 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1183,6 +1183,7 @@ Do not modify directly.* |QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QLinearSigmoid|*in* X:**T**
*in* X_scale:**tensor(float)**
*in* X_zero_point:**T**
*in* Y_scale:**tensor(float)**
*in* Y_zero_point:**T**
*out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| +|QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| | | | | diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 8a46537c97..cce11c4795 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -299,7 +299,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp new file mode 100644 index 0000000000..3683ab7b0b --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ + +class DmlOperatorQuickGelu : public DmlOperator +{ +public: + DmlOperatorQuickGelu(const MLOperatorKernelCreationContext& kernelCreationContext) + : DmlOperator(kernelCreationContext) + { + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 1); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1); + DmlOperator::Initialize(kernelCreationContext); + + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs.size() == 1); + ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs.size() == 1); + const float alpha = kernelCreationContext.GetAttribute(AttrName::Alpha); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + // 1. Apply the alpha if needed + DML_SCALE_BIAS scaleBias{alpha, 0.0f}; + DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC mulAlphaDesc{}; + if (alpha != 1.0f) + { + mulAlphaDesc.InputTensor = &inputDescs[0]; + mulAlphaDesc.OutputTensor = &inputDescs[0]; + mulAlphaDesc.ScaleBias = &scaleBias; + } + DML_OPERATOR_DESC dmlMulAlphaDesc = { DML_OPERATOR_ELEMENT_WISE_IDENTITY, &mulAlphaDesc }; + + // 2. Apply the sigmoid activation function + DML_ACTIVATION_SIGMOID_OPERATOR_DESC sigmoidDesc{}; + sigmoidDesc.InputTensor = &inputDescs[0]; + sigmoidDesc.OutputTensor = &inputDescs[0]; + DML_OPERATOR_DESC dmlSigmoidDesc = { DML_OPERATOR_ACTIVATION_SIGMOID, &sigmoidDesc }; + + // 3. Multiply the sigmoid result with the original input + DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC multiplyDesc{}; + multiplyDesc.ATensor = &inputDescs[0]; + multiplyDesc.BTensor = &inputDescs[0]; + multiplyDesc.OutputTensor = &inputDescs[0]; + DML_OPERATOR_DESC dmlMultiplyDesc = { DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &multiplyDesc }; + + enum NodeIndex + { + sigmoidNodeIndex, + multiplyNodeIndex, + mulAlphaNodeIndex, + nodeCount, + }; + + // Construct the graph + std::vector opDescs; + opDescs.reserve(3); + opDescs.push_back(&dmlSigmoidDesc); + opDescs.push_back(&dmlMultiplyDesc); + + std::vector inputEdges; + inputEdges.reserve(2); + + std::vector intermediateEdges; + intermediateEdges.reserve(2); + + std::vector outputEdges; + outputEdges.reserve(1); + + if (alpha != 1.0f) + { + opDescs.push_back(&dmlMulAlphaDesc); + + DML_INPUT_GRAPH_EDGE_DESC inputToMulAlphaEdge{}; + inputToMulAlphaEdge.GraphInputIndex = 0; + inputToMulAlphaEdge.ToNodeIndex = mulAlphaNodeIndex; + inputToMulAlphaEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToMulAlphaEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC mulAlphaToSigmoidEdge{}; + mulAlphaToSigmoidEdge.FromNodeIndex = mulAlphaNodeIndex; + mulAlphaToSigmoidEdge.FromNodeOutputIndex = 0; + mulAlphaToSigmoidEdge.ToNodeIndex = sigmoidNodeIndex; + mulAlphaToSigmoidEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(mulAlphaToSigmoidEdge); + } + else + { + DML_INPUT_GRAPH_EDGE_DESC inputToSigmoidEdge{}; + inputToSigmoidEdge.GraphInputIndex = 0; + inputToSigmoidEdge.ToNodeIndex = sigmoidNodeIndex; + inputToSigmoidEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToSigmoidEdge); + } + + DML_INPUT_GRAPH_EDGE_DESC inputToMultiplyEdge{}; + inputToMultiplyEdge.GraphInputIndex = 0; + inputToMultiplyEdge.ToNodeIndex = multiplyNodeIndex; + inputToMultiplyEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToMultiplyEdge); + + DML_INTERMEDIATE_GRAPH_EDGE_DESC sigmoidToMultiplyEdge{}; + sigmoidToMultiplyEdge.FromNodeIndex = sigmoidNodeIndex; + sigmoidToMultiplyEdge.FromNodeOutputIndex = 0; + sigmoidToMultiplyEdge.ToNodeIndex = multiplyNodeIndex; + sigmoidToMultiplyEdge.ToNodeInputIndex = 1; + intermediateEdges.push_back(sigmoidToMultiplyEdge); + + DML_OUTPUT_GRAPH_EDGE_DESC multiplyToOutputEdge{}; + multiplyToOutputEdge.FromNodeIndex = multiplyNodeIndex; + multiplyToOutputEdge.FromNodeOutputIndex = 0; + multiplyToOutputEdge.GraphOutputIndex = 0; + outputEdges.push_back(multiplyToOutputEdge); + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(QuickGelu, DmlOperatorQuickGelu); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 0677757b54..349efdab3e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -381,6 +381,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Shape); DML_OP_EXTERN_CREATION_FUNCTION(Size); DML_OP_EXTERN_CREATION_FUNCTION(Attention); DML_OP_EXTERN_CREATION_FUNCTION(NonZero); +DML_OP_EXTERN_CREATION_FUNCTION(QuickGelu); DML_OP_EXTERN_QUERY_FUNCTION(MaxPool); DML_OP_EXTERN_QUERY_FUNCTION(Slice); @@ -878,6 +879,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)}, {REG_INFO_MS( 1, SkipLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QuerySkipLayerNormalization)}, {REG_INFO_MS( 1, EmbedLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, QuickGelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, GroupNorm, typeNameListGroupNorm, supportedTypeListGroupNorm, DmlGraphSupport::Supported)}, }; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 9c01aac536..c0d1562f8a 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1546,6 +1546,7 @@ using ShapeInferenceHelper_IsInf = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_Mod = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_BitShift= GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_Round = GetBroadcastedOutputShapeHelper; +using ShapeInferenceHelper_QuickGelu = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_ReduceSum = ReduceHelper; using ShapeInferenceHelper_ReduceMean = ReduceHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 96c10e8dc4..57464786d7 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -410,6 +410,7 @@ namespace OperatorHelper static const int sc_sinceVer_Attention = 1; static const int sc_sinceVer_SkipLayerNormalization = 1; static const int sc_sinceVer_EmbedLayerNormalization = 1; + static const int sc_sinceVer_QuickGelu = 1; static const int sc_sinceVer_GroupNorm = 1; } // namespace MsftOperatorSet1