mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
[DML EP] Add QuickGelu (#15220)
This commit is contained in:
parent
a96e19abc4
commit
9191e04259
6 changed files with 139 additions and 1 deletions
|
|
@ -1183,6 +1183,7 @@ Do not modify directly.*
|
|||
|QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|
||||
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|
||||
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|
||||
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
| |
|
||||
| |
|
||||
|
|
|
|||
|
|
@ -299,7 +299,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
|
|||
transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_cuda_dml_rocm_eps));
|
||||
|
||||
transformers.emplace_back(std::make_unique<FastGeluFusion>(cpu_cuda_rocm_eps));
|
||||
transformers.emplace_back(std::make_unique<QuickGeluFusion>(cpu_cuda_rocm_eps));
|
||||
transformers.emplace_back(std::make_unique<QuickGeluFusion>(cpu_cuda_dml_rocm_eps));
|
||||
|
||||
transformers.emplace_back(std::make_unique<MatMulScaleFusion>(cpu_cuda_dml_rocm_eps));
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,133 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "precomp.h"
|
||||
|
||||
namespace Dml
|
||||
{
|
||||
|
||||
class DmlOperatorQuickGelu : public DmlOperator
|
||||
{
|
||||
public:
|
||||
DmlOperatorQuickGelu(const MLOperatorKernelCreationContext& kernelCreationContext)
|
||||
: DmlOperator(kernelCreationContext)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 1);
|
||||
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1);
|
||||
DmlOperator::Initialize(kernelCreationContext);
|
||||
|
||||
ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs.size() == 1);
|
||||
ML_CHECK_VALID_ARGUMENT(m_outputTensorDescs.size() == 1);
|
||||
const float alpha = kernelCreationContext.GetAttribute<float>(AttrName::Alpha);
|
||||
|
||||
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
|
||||
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
|
||||
|
||||
// 1. Apply the alpha if needed
|
||||
DML_SCALE_BIAS scaleBias{alpha, 0.0f};
|
||||
DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC mulAlphaDesc{};
|
||||
if (alpha != 1.0f)
|
||||
{
|
||||
mulAlphaDesc.InputTensor = &inputDescs[0];
|
||||
mulAlphaDesc.OutputTensor = &inputDescs[0];
|
||||
mulAlphaDesc.ScaleBias = &scaleBias;
|
||||
}
|
||||
DML_OPERATOR_DESC dmlMulAlphaDesc = { DML_OPERATOR_ELEMENT_WISE_IDENTITY, &mulAlphaDesc };
|
||||
|
||||
// 2. Apply the sigmoid activation function
|
||||
DML_ACTIVATION_SIGMOID_OPERATOR_DESC sigmoidDesc{};
|
||||
sigmoidDesc.InputTensor = &inputDescs[0];
|
||||
sigmoidDesc.OutputTensor = &inputDescs[0];
|
||||
DML_OPERATOR_DESC dmlSigmoidDesc = { DML_OPERATOR_ACTIVATION_SIGMOID, &sigmoidDesc };
|
||||
|
||||
// 3. Multiply the sigmoid result with the original input
|
||||
DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC multiplyDesc{};
|
||||
multiplyDesc.ATensor = &inputDescs[0];
|
||||
multiplyDesc.BTensor = &inputDescs[0];
|
||||
multiplyDesc.OutputTensor = &inputDescs[0];
|
||||
DML_OPERATOR_DESC dmlMultiplyDesc = { DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &multiplyDesc };
|
||||
|
||||
enum NodeIndex
|
||||
{
|
||||
sigmoidNodeIndex,
|
||||
multiplyNodeIndex,
|
||||
mulAlphaNodeIndex,
|
||||
nodeCount,
|
||||
};
|
||||
|
||||
// Construct the graph
|
||||
std::vector<const DML_OPERATOR_DESC*> opDescs;
|
||||
opDescs.reserve(3);
|
||||
opDescs.push_back(&dmlSigmoidDesc);
|
||||
opDescs.push_back(&dmlMultiplyDesc);
|
||||
|
||||
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
|
||||
inputEdges.reserve(2);
|
||||
|
||||
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
|
||||
intermediateEdges.reserve(2);
|
||||
|
||||
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
|
||||
outputEdges.reserve(1);
|
||||
|
||||
if (alpha != 1.0f)
|
||||
{
|
||||
opDescs.push_back(&dmlMulAlphaDesc);
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC inputToMulAlphaEdge{};
|
||||
inputToMulAlphaEdge.GraphInputIndex = 0;
|
||||
inputToMulAlphaEdge.ToNodeIndex = mulAlphaNodeIndex;
|
||||
inputToMulAlphaEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(inputToMulAlphaEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC mulAlphaToSigmoidEdge{};
|
||||
mulAlphaToSigmoidEdge.FromNodeIndex = mulAlphaNodeIndex;
|
||||
mulAlphaToSigmoidEdge.FromNodeOutputIndex = 0;
|
||||
mulAlphaToSigmoidEdge.ToNodeIndex = sigmoidNodeIndex;
|
||||
mulAlphaToSigmoidEdge.ToNodeInputIndex = 0;
|
||||
intermediateEdges.push_back(mulAlphaToSigmoidEdge);
|
||||
}
|
||||
else
|
||||
{
|
||||
DML_INPUT_GRAPH_EDGE_DESC inputToSigmoidEdge{};
|
||||
inputToSigmoidEdge.GraphInputIndex = 0;
|
||||
inputToSigmoidEdge.ToNodeIndex = sigmoidNodeIndex;
|
||||
inputToSigmoidEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(inputToSigmoidEdge);
|
||||
}
|
||||
|
||||
DML_INPUT_GRAPH_EDGE_DESC inputToMultiplyEdge{};
|
||||
inputToMultiplyEdge.GraphInputIndex = 0;
|
||||
inputToMultiplyEdge.ToNodeIndex = multiplyNodeIndex;
|
||||
inputToMultiplyEdge.ToNodeInputIndex = 0;
|
||||
inputEdges.push_back(inputToMultiplyEdge);
|
||||
|
||||
DML_INTERMEDIATE_GRAPH_EDGE_DESC sigmoidToMultiplyEdge{};
|
||||
sigmoidToMultiplyEdge.FromNodeIndex = sigmoidNodeIndex;
|
||||
sigmoidToMultiplyEdge.FromNodeOutputIndex = 0;
|
||||
sigmoidToMultiplyEdge.ToNodeIndex = multiplyNodeIndex;
|
||||
sigmoidToMultiplyEdge.ToNodeInputIndex = 1;
|
||||
intermediateEdges.push_back(sigmoidToMultiplyEdge);
|
||||
|
||||
DML_OUTPUT_GRAPH_EDGE_DESC multiplyToOutputEdge{};
|
||||
multiplyToOutputEdge.FromNodeIndex = multiplyNodeIndex;
|
||||
multiplyToOutputEdge.FromNodeOutputIndex = 0;
|
||||
multiplyToOutputEdge.GraphOutputIndex = 0;
|
||||
outputEdges.push_back(multiplyToOutputEdge);
|
||||
|
||||
MLOperatorGraphDesc operatorGraphDesc = {};
|
||||
operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
|
||||
operatorGraphDesc.inputEdges = inputEdges.data();
|
||||
operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
|
||||
operatorGraphDesc.intermediateEdges = intermediateEdges.data();
|
||||
operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
|
||||
operatorGraphDesc.outputEdges = outputEdges.data();
|
||||
operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
|
||||
operatorGraphDesc.nodesAsOpDesc = opDescs.data();
|
||||
SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
|
||||
}
|
||||
};
|
||||
|
||||
// Associates the QuickGelu operator name with the DmlOperatorQuickGelu class.
// NOTE(review): presumably this defines the creation function that the
// DML_OP_EXTERN_CREATION_FUNCTION(QuickGelu) declaration refers to — confirm against the macro definition.
DML_OP_DEFINE_CREATION_FUNCTION(QuickGelu, DmlOperatorQuickGelu);
|
||||
|
||||
} // namespace Dml
|
||||
|
|
@ -381,6 +381,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Shape);
|
|||
DML_OP_EXTERN_CREATION_FUNCTION(Size);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(Attention);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
|
||||
DML_OP_EXTERN_CREATION_FUNCTION(QuickGelu);
|
||||
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(MaxPool);
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(Slice);
|
||||
|
|
@ -878,6 +879,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
|
||||
{REG_INFO_MS( 1, SkipLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QuerySkipLayerNormalization)},
|
||||
{REG_INFO_MS( 1, EmbedLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, QuickGelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
|
||||
{REG_INFO_MS( 1, GroupNorm, typeNameListGroupNorm, supportedTypeListGroupNorm, DmlGraphSupport::Supported)},
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -1546,6 +1546,7 @@ using ShapeInferenceHelper_IsInf = GetBroadcastedOutputShapeHelper;
|
|||
using ShapeInferenceHelper_Mod = GetBroadcastedOutputShapeHelper;
|
||||
using ShapeInferenceHelper_BitShift= GetBroadcastedOutputShapeHelper;
|
||||
using ShapeInferenceHelper_Round = GetBroadcastedOutputShapeHelper;
|
||||
using ShapeInferenceHelper_QuickGelu = GetOutputShapeAsInputShapeHelper;
|
||||
|
||||
using ShapeInferenceHelper_ReduceSum = ReduceHelper;
|
||||
using ShapeInferenceHelper_ReduceMean = ReduceHelper;
|
||||
|
|
|
|||
|
|
@ -410,6 +410,7 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_Attention = 1;
|
||||
static const int sc_sinceVer_SkipLayerNormalization = 1;
|
||||
static const int sc_sinceVer_EmbedLayerNormalization = 1;
|
||||
static const int sc_sinceVer_QuickGelu = 1;
|
||||
static const int sc_sinceVer_GroupNorm = 1;
|
||||
} // namespace MsftOperatorSet1
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue