Mirror of https://github.com/saymrwulf/onnxruntime.git, synced 2026-05-14 20:48:00 +00:00
[DML EP] Add SkipLayerNormalization (#13849)
### Description

Add SkipLayerNormalization for the DML EP.
parent 004a1538d3
commit 96d8d2c278

8 changed files with 179 additions and 4 deletions
```diff
@@ -1128,6 +1128,7 @@ Do not modify directly.*
 |QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
 |QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
 |QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
+|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)|
 | |
 | |
 |**Operator Domain:** *com.microsoft.dml*||||
```
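For reference, the contract the new table row documents: per the com.microsoft contrib-op definition, SkipLayerNormalization fuses the residual add with layer normalization (`bias` and `beta` are optional inputs; `mean` and `inv_std_var` are optional outputs):

$$x = \text{input} + \text{skip} + \text{bias}, \qquad \text{output} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}}\,\gamma + \beta,$$

where the mean and variance are taken over the normalized axes (by default the last dimension).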
```diff
@@ -271,7 +271,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       transformers.emplace_back(std::make_unique<GatherToSliceFusion>(cpu_cuda_rocm_eps));

       transformers.emplace_back(std::make_unique<MatmulTransposeFusion>(cpu_cuda_rocm_eps));
-      transformers.emplace_back(std::make_unique<BiasGeluFusion>(cpu_cuda_rocm_eps));
+      transformers.emplace_back(std::make_unique<BiasGeluFusion>(cpu_cuda_dml_rocm_eps));
       transformers.emplace_back(std::make_unique<BiasSoftmaxFusion>(cpu_cuda_rocm_eps));
       transformers.emplace_back(std::make_unique<BiasDropoutFusion>(cuda_rocm_eps));
 #ifdef ENABLE_TRAINING
```
```diff
@@ -280,7 +280,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       transformers.emplace_back(std::make_unique<SceLossGradBiasFusion>(cpu_cuda_rocm_eps));
 #endif

-      transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_cuda_rocm_eps));
+      transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_cuda_dml_rocm_eps));

       transformers.emplace_back(std::make_unique<FastGeluFusion>(cpu_cuda_rocm_eps));
       transformers.emplace_back(std::make_unique<QuickGeluFusion>(cpu_cuda_rocm_eps));
```
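Both hunks swap `cpu_cuda_rocm_eps` for a provider set that also names the DML EP, so these fusions now fire for nodes assigned to DirectML. A minimal sketch of what such a set looks like, assuming it follows the convention of the neighboring `*_eps` declarations earlier in `GenerateTransformers` (the exact declaration is not shown in this diff):

```cpp
// Sketch: compatible-EP set consulted by each fusion pass before rewriting a node.
const InlinedHashSet<std::string_view> cpu_cuda_dml_rocm_eps = {
    onnxruntime::kCpuExecutionProvider,
    onnxruntime::kCudaExecutionProvider,
    onnxruntime::kDmlExecutionProvider,
    onnxruntime::kRocmExecutionProvider};
```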
```diff
@@ -0,0 +1,164 @@
```

```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "precomp.h"

namespace Dml
{

class DmlOperatorSkipLayerNormalization : public DmlOperator
{
public:
    DmlOperatorSkipLayerNormalization(const MLOperatorKernelCreationContext& kernelCreationContext)
    :   DmlOperator(kernelCreationContext)
    {
        std::vector<std::optional<uint32_t>> kernelInputIndices = {0, 1, 2, 3, 4};

        DmlOperator::Initialize(
            kernelCreationContext,
            kernelInputIndices,
            std::nullopt,
            kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0),
            std::nullopt,
            kernelCreationContext.GetTensorShapeDescription().GetInputTensorDimensionCount(0));

        const float epsilon = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Epsilon, DefaultEpsilon);

        int32_t onnxAxis = kernelCreationContext.GetOptionalAttribute<int32_t>(AttrName::Axis, -1);
        uint32_t inputDimCount = kernelCreationContext.GetTensorShapeDescription().GetInputTensorDimensionCount(0);
        onnxAxis = OperatorHelper::HandleNegativeAxis(onnxAxis, inputDimCount);
        std::vector<uint32_t> onnxAxes(inputDimCount - onnxAxis);
        std::iota(onnxAxes.begin(), onnxAxes.end(), onnxAxis);

        assert(m_inputTensorDescs.size() == 5);
        assert(m_outputTensorDescs.size() == 1);

        auto inputDesc = m_inputTensorDescs[0].GetDmlDesc();
        auto skipDesc = m_inputTensorDescs[1].GetDmlDesc();
        auto gammaDesc = m_inputTensorDescs[2].GetDmlDesc();
        auto betaDesc = m_inputTensorDescs[3].GetDmlDesc();
        auto biasDesc = m_inputTensorDescs[4].GetDmlDesc();
        auto outputDesc = m_outputTensorDescs[0].GetDmlDesc();

        TensorDesc inputSkipBiasTensorDesc(m_inputTensorDescs[0].GetDmlDataType(), m_inputTensorDescs[0].GetSizes());
        DML_TENSOR_DESC inputSkipBiasDmlTensorDesc = inputSkipBiasTensorDesc.GetDmlDesc();

        DML_ELEMENT_WISE_ADD_OPERATOR_DESC inputSkipAddDesc = {};
        inputSkipAddDesc.ATensor = &inputDesc;
        inputSkipAddDesc.BTensor = &skipDesc;
        inputSkipAddDesc.OutputTensor = &inputSkipBiasDmlTensorDesc;
        DML_OPERATOR_DESC inputSkipAddOpDesc = { DML_OPERATOR_ELEMENT_WISE_ADD, &inputSkipAddDesc };

        DML_ELEMENT_WISE_ADD_OPERATOR_DESC inputSkipBiasAddDesc = {};
        inputSkipBiasAddDesc.ATensor = &inputSkipBiasDmlTensorDesc;
        inputSkipBiasAddDesc.BTensor = &biasDesc;
        inputSkipBiasAddDesc.OutputTensor = &inputSkipBiasDmlTensorDesc;
        DML_OPERATOR_DESC inputSkipBiasAddOpDesc = { DML_OPERATOR_ELEMENT_WISE_ADD, &inputSkipBiasAddDesc };

        DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC mvnDesc = {};
        mvnDesc.InputTensor = &inputSkipBiasDmlTensorDesc;
        mvnDesc.ScaleTensor = &gammaDesc;
        mvnDesc.BiasTensor = betaDesc.Desc ? &betaDesc : nullptr;
        mvnDesc.OutputTensor = &outputDesc;
        mvnDesc.Axes = onnxAxes.data();
        mvnDesc.AxisCount = gsl::narrow_cast<uint32_t>(onnxAxes.size());
        mvnDesc.NormalizeVariance = true;
        mvnDesc.Epsilon = epsilon;
        mvnDesc.FusedActivation = nullptr;
        DML_OPERATOR_DESC mvnOpDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &mvnDesc };

        // Construct the graph
        std::vector<const DML_OPERATOR_DESC*> opDescs;
        opDescs.reserve(3);

        std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
        inputEdges.reserve(5);

        std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
        intermediateEdges.reserve(2);

        std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
        outputEdges.reserve(1);

        // Insert the Input + Skip operation into the graph
        opDescs.push_back(&inputSkipAddOpDesc);

        DML_INPUT_GRAPH_EDGE_DESC dataInputEdge = {};
        dataInputEdge.GraphInputIndex = 0;
        dataInputEdge.ToNodeIndex = 0;
        dataInputEdge.ToNodeInputIndex = 0;
        inputEdges.push_back(std::move(dataInputEdge));

        DML_INPUT_GRAPH_EDGE_DESC skipInputEdge = {};
        skipInputEdge.GraphInputIndex = 1;
        skipInputEdge.ToNodeIndex = 0;
        skipInputEdge.ToNodeInputIndex = 1;
        inputEdges.push_back(std::move(skipInputEdge));

        // Insert the InputSkip + Bias operation into the graph
        if (biasDesc.Desc)
        {
            opDescs.push_back(&inputSkipBiasAddOpDesc);

            DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
            intermediateEdge.FromNodeIndex = 0;
            intermediateEdge.FromNodeOutputIndex = 0;
            intermediateEdge.ToNodeIndex = 1;
            intermediateEdge.ToNodeInputIndex = 0;
            intermediateEdges.push_back(std::move(intermediateEdge));

            DML_INPUT_GRAPH_EDGE_DESC biasInputEdge = {};
            biasInputEdge.GraphInputIndex = 4;
            biasInputEdge.ToNodeIndex = 1;
            biasInputEdge.ToNodeInputIndex = 1;
            inputEdges.push_back(std::move(biasInputEdge));
        }

        // Insert the MVN operation into the graph
        opDescs.push_back(&mvnOpDesc);

        DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
        intermediateEdge.FromNodeIndex = biasDesc.Desc ? 1 : 0;
        intermediateEdge.FromNodeOutputIndex = 0;
        intermediateEdge.ToNodeIndex = biasDesc.Desc ? 2 : 1;
        intermediateEdge.ToNodeInputIndex = 0;
        intermediateEdges.push_back(std::move(intermediateEdge));

        DML_INPUT_GRAPH_EDGE_DESC gammaInputEdge = {};
        gammaInputEdge.GraphInputIndex = 2;
        gammaInputEdge.ToNodeIndex = biasDesc.Desc ? 2 : 1;
        gammaInputEdge.ToNodeInputIndex = 1;
        inputEdges.push_back(std::move(gammaInputEdge));

        if (betaDesc.Desc)
        {
            DML_INPUT_GRAPH_EDGE_DESC betaInputEdge = {};
            betaInputEdge.GraphInputIndex = 3;
            betaInputEdge.ToNodeIndex = biasDesc.Desc ? 2 : 1;
            betaInputEdge.ToNodeInputIndex = 2;
            inputEdges.push_back(std::move(betaInputEdge));
        }

        DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
        outputEdge.GraphOutputIndex = 0;
        outputEdge.FromNodeIndex = biasDesc.Desc ? 2 : 1;
        outputEdge.FromNodeOutputIndex = 0;
        outputEdges.push_back(std::move(outputEdge));

        MLOperatorGraphDesc operatorGraphDesc = {};
        operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
        operatorGraphDesc.inputEdges = inputEdges.data();
        operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
        operatorGraphDesc.intermediateEdges = intermediateEdges.data();
        operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
        operatorGraphDesc.outputEdges = outputEdges.data();
        operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
        operatorGraphDesc.nodesAsOpDesc = opDescs.data();

        SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
    }
};

DML_OP_DEFINE_CREATION_FUNCTION(SkipLayerNormalization, DmlOperatorSkipLayerNormalization);

} // namespace Dml
```
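For orientation, the constructor above assembles the following fused DML graph; this is a sketch read off the edge wiring, and the node indices shift down by one when the optional `bias` input is absent:

```cpp
// graph inputs:  0 = input, 1 = skip, 2 = gamma, 3 = beta, 4 = bias
//
// node 0: Add(input, skip)                -> intermediate
// node 1: Add(intermediate, bias)         -> intermediate   (only when bias is bound)
// last:   MVN1(intermediate, gamma, beta) -> graph output 0 (normalizes over onnxAxes)
```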
```diff
@@ -100,6 +100,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization);
 DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization15);
 DML_OP_EXTERN_CREATION_FUNCTION(LayerNormalization);
 DML_OP_EXTERN_CREATION_FUNCTION(LayerNormalization17);
+DML_OP_EXTERN_CREATION_FUNCTION(SkipLayerNormalization);
 DML_OP_EXTERN_CREATION_FUNCTION(LRN);
 DML_OP_EXTERN_CREATION_FUNCTION(MeanVarianceNormalization);
 DML_OP_EXTERN_CREATION_FUNCTION(LpNormalization);
```
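The extern declaration above pairs with the definition the new operator file provides, letting the registration table reference the kernel's creation function. Both lines appear in this commit:

```cpp
// In the registration translation unit (this hunk):
DML_OP_EXTERN_CREATION_FUNCTION(SkipLayerNormalization);

// In the new operator file (shown earlier in this commit):
DML_OP_DEFINE_CREATION_FUNCTION(SkipLayerNormalization, DmlOperatorSkipLayerNormalization);
```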
```diff
@@ -748,6 +749,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
     {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
     {REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
     {REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
+    {REG_INFO_MS( 1, SkipLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
 };

 template<typename T>
```
```diff
@@ -1405,6 +1405,7 @@ using ShapeInferenceHelper_LRN = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_MeanVarianceNormalization = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_LayerNormalization = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_LayerNormalization17 = GetOutputShapeAsInputShapeHelper;
+using ShapeInferenceHelper_SkipLayerNormalization = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_LpNormalization = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_RNN = RecurrentHelper;
 using ShapeInferenceHelper_GRU = RecurrentHelper;
```
```diff
@@ -395,6 +395,7 @@ namespace OperatorHelper
     static const int sc_sinceVer_FusedMatMul = 1;
     static const int sc_sinceVer_QLinearSigmoid = 1;
     static const int sc_sinceVer_Attention = 1;
+    static const int sc_sinceVer_SkipLayerNormalization = 1;
 } // namespace MsftOperatorSet1

 } // namespace OperatorHelper
```
```diff
@@ -37,6 +37,7 @@ static void RunTest(
   std::vector<int64_t> output_dims = input_dims;

   auto rocm_ep = DefaultRocmExecutionProvider();
+  auto dml_ep = DefaultDmlExecutionProvider();
   if (!use_float16) {
     OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain);
     test.AddInput<float>("input", input_dims, input_data);
```
```diff
@@ -55,6 +56,7 @@ static void RunTest(
     test.AddOutput<float>("output", output_dims, output_data);
     test.Run();
   } else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) ||
+             dml_ep != nullptr ||
              rocm_ep != nullptr) {
     OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain);
     test.AddInput<MLFloat16>("input", input_dims, ToFloat16(input_data));
```
```diff
@@ -73,7 +75,9 @@ static void RunTest(
     test.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));

     std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
-    if (rocm_ep != nullptr) {
+    if (dml_ep != nullptr) {
+      execution_providers.push_back(DefaultDmlExecutionProvider());
+    } else if (rocm_ep != nullptr) {
       execution_providers.push_back(DefaultRocmExecutionProvider());
     } else {
       execution_providers.push_back(DefaultCudaExecutionProvider());
```
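Beyond this hunk, the populated `execution_providers` vector is what drives the float16 run against exactly one provider. A sketch of the call, assuming the `OpTester::Run` overload used elsewhere in the onnxruntime test suite (parameter names here are illustrative):

```cpp
// Hypothetical continuation: run only on the single selected provider.
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
         /*excluded_provider_types*/ {}, /*run_options*/ nullptr,
         &execution_providers);
```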
```diff
@@ -339,8 +339,10 @@ struct TensorCheck<MLFloat16> {
   const bool has_rel_err = params.relative_error_.has_value();

   float threshold = 0.001f;
-#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
+#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || defined(USE_ROCM)
   threshold = 0.005f;
+#elif defined(USE_DML)
+  threshold = 0.008f;
 #endif
   for (int i = 0; i < size; ++i) {
     if (std::isnan(f_expected[i])) {
```
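For scale, the relaxed DML tolerance is still only a handful of fp16 ulps: the format carries a 10-bit fraction, so near 1.0 one ulp is

$$2^{-10} \approx 9.8 \times 10^{-4}, \qquad 0.008 \approx 8 \times 2^{-10},$$

i.e. the DML path is allowed roughly eight ulps of absolute error at unit magnitude, versus about five for the CUDA/ROCm paths.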