[DML EP] Add SkipLayerNormalization (#13849)

### Description

Add SkipLayerNormalization for the DML EP
This commit is contained in:
Patrice Vignola 2022-12-07 01:49:14 -08:00 committed by GitHub
parent 004a1538d3
commit 96d8d2c278
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 179 additions and 4 deletions

View file

@ -1128,6 +1128,7 @@ Do not modify directly.*
|QLinearAdd|*in* A:**T**<br> *in* A_scale:**tensor(float)**<br> *in* A_zero_point:**T**<br> *in* B:**T**<br> *in* B_scale:**tensor(float)**<br> *in* B_zero_point:**T**<br> *in* C_scale:**tensor(float)**<br> *in* C_zero_point:**T**<br> *out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**|1+|**T** = tensor(float), tensor(float16)|
| |
| |
|**Operator Domain:** *com.microsoft.dml*||||

View file

@ -271,7 +271,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<GatherToSliceFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<MatmulTransposeFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<BiasGeluFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<BiasGeluFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<BiasSoftmaxFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<BiasDropoutFusion>(cuda_rocm_eps));
#ifdef ENABLE_TRAINING
@ -280,7 +280,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<SceLossGradBiasFusion>(cpu_cuda_rocm_eps));
#endif
transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<SkipLayerNormFusion>(cpu_cuda_dml_rocm_eps));
transformers.emplace_back(std::make_unique<FastGeluFusion>(cpu_cuda_rocm_eps));
transformers.emplace_back(std::make_unique<QuickGeluFusion>(cpu_cuda_rocm_eps));

View file

@ -0,0 +1,164 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "precomp.h"
namespace Dml
{
class DmlOperatorSkipLayerNormalization : public DmlOperator
{
public:
DmlOperatorSkipLayerNormalization(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext)
{
std::vector<std::optional<uint32_t>> kernelInputIndices = {0, 1, 2, 3, 4};
DmlOperator::Initialize(
kernelCreationContext,
kernelInputIndices,
std::nullopt,
kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0),
std::nullopt,
kernelCreationContext.GetTensorShapeDescription().GetInputTensorDimensionCount(0));
const float epsilon = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Epsilon, DefaultEpsilon);
int32_t onnxAxis = kernelCreationContext.GetOptionalAttribute<int32_t>(AttrName::Axis, -1);
uint32_t inputDimCount = kernelCreationContext.GetTensorShapeDescription().GetInputTensorDimensionCount(0);
onnxAxis = OperatorHelper::HandleNegativeAxis(onnxAxis, inputDimCount);
std::vector<uint32_t> onnxAxes(inputDimCount - onnxAxis);
std::iota(onnxAxes.begin(), onnxAxes.end(), onnxAxis);
assert(m_inputTensorDescs.size() == 5);
assert(m_outputTensorDescs.size() == 1);
auto inputDesc = m_inputTensorDescs[0].GetDmlDesc();
auto skipDesc = m_inputTensorDescs[1].GetDmlDesc();
auto gammaDesc = m_inputTensorDescs[2].GetDmlDesc();
auto betaDesc = m_inputTensorDescs[3].GetDmlDesc();
auto biasDesc = m_inputTensorDescs[4].GetDmlDesc();
auto outputDesc = m_outputTensorDescs[0].GetDmlDesc();
TensorDesc inputSkipBiasTensorDesc(m_inputTensorDescs[0].GetDmlDataType(), m_inputTensorDescs[0].GetSizes());
DML_TENSOR_DESC inputSkipBiasDmlTensorDesc = inputSkipBiasTensorDesc.GetDmlDesc();
DML_ELEMENT_WISE_ADD_OPERATOR_DESC inputSkipAddDesc = {};
inputSkipAddDesc.ATensor = &inputDesc;
inputSkipAddDesc.BTensor = &skipDesc;
inputSkipAddDesc.OutputTensor = &inputSkipBiasDmlTensorDesc;
DML_OPERATOR_DESC inputSkipAddOpDesc = { DML_OPERATOR_ELEMENT_WISE_ADD, &inputSkipAddDesc };
DML_ELEMENT_WISE_ADD_OPERATOR_DESC inputSkipBiasAddDesc = {};
inputSkipBiasAddDesc.ATensor = &inputSkipBiasDmlTensorDesc;
inputSkipBiasAddDesc.BTensor = &biasDesc;
inputSkipBiasAddDesc.OutputTensor = &inputSkipBiasDmlTensorDesc;
DML_OPERATOR_DESC inputSkipBiasAddOpDesc = { DML_OPERATOR_ELEMENT_WISE_ADD, &inputSkipBiasAddDesc };
DML_MEAN_VARIANCE_NORMALIZATION1_OPERATOR_DESC mvnDesc = {};
mvnDesc.InputTensor = &inputSkipBiasDmlTensorDesc;
mvnDesc.ScaleTensor = &gammaDesc;
mvnDesc.BiasTensor = betaDesc.Desc ? &betaDesc : nullptr;
mvnDesc.OutputTensor = &outputDesc;
mvnDesc.Axes = onnxAxes.data();
mvnDesc.AxisCount = gsl::narrow_cast<uint32_t>(onnxAxes.size());
mvnDesc.NormalizeVariance = true;
mvnDesc.Epsilon = epsilon;
mvnDesc.FusedActivation = nullptr;
DML_OPERATOR_DESC mvnOpDesc = { DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1, &mvnDesc };
// Construct the graph
std::vector<const DML_OPERATOR_DESC*> opDescs;
opDescs.reserve(3);
std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
inputEdges.reserve(5);
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
intermediateEdges.reserve(2);
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
outputEdges.reserve(1);
// Insert the Input + Skip operation into the graph
opDescs.push_back(&inputSkipAddOpDesc);
DML_INPUT_GRAPH_EDGE_DESC dataInputEdge = {};
dataInputEdge.GraphInputIndex = 0;
dataInputEdge.ToNodeIndex = 0;
dataInputEdge.ToNodeInputIndex = 0;
inputEdges.push_back(std::move(dataInputEdge));
DML_INPUT_GRAPH_EDGE_DESC skipInputEdge = {};
skipInputEdge.GraphInputIndex = 1;
skipInputEdge.ToNodeIndex = 0;
skipInputEdge.ToNodeInputIndex = 1;
inputEdges.push_back(std::move(skipInputEdge));
// Insert the InputSkip + Bias operation into the graph
if (biasDesc.Desc)
{
opDescs.push_back(&inputSkipBiasAddOpDesc);
DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
intermediateEdge.FromNodeIndex = 0;
intermediateEdge.FromNodeOutputIndex = 0;
intermediateEdge.ToNodeIndex = 1;
intermediateEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(std::move(intermediateEdge));
DML_INPUT_GRAPH_EDGE_DESC biasInputEdge = {};
biasInputEdge.GraphInputIndex = 4;
biasInputEdge.ToNodeIndex = 1;
biasInputEdge.ToNodeInputIndex = 1;
inputEdges.push_back(std::move(biasInputEdge));
}
// Insert the MVN operation into the graph
opDescs.push_back(&mvnOpDesc);
DML_INTERMEDIATE_GRAPH_EDGE_DESC intermediateEdge = {};
intermediateEdge.FromNodeIndex = biasDesc.Desc ? 1 : 0;
intermediateEdge.FromNodeOutputIndex = 0;
intermediateEdge.ToNodeIndex = biasDesc.Desc ? 2 : 1;
intermediateEdge.ToNodeInputIndex = 0;
intermediateEdges.push_back(std::move(intermediateEdge));
DML_INPUT_GRAPH_EDGE_DESC gammaInputEdge = {};
gammaInputEdge.GraphInputIndex = 2;
gammaInputEdge.ToNodeIndex = biasDesc.Desc ? 2 : 1;
gammaInputEdge.ToNodeInputIndex = 1;
inputEdges.push_back(std::move(gammaInputEdge));
if (betaDesc.Desc)
{
DML_INPUT_GRAPH_EDGE_DESC betaInputEdge = {};
betaInputEdge.GraphInputIndex = 3;
betaInputEdge.ToNodeIndex = biasDesc.Desc ? 2 : 1;
betaInputEdge.ToNodeInputIndex = 2;
inputEdges.push_back(std::move(betaInputEdge));
}
DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
outputEdge.GraphOutputIndex = 0;
outputEdge.FromNodeIndex = biasDesc.Desc ? 2 : 1;
outputEdge.FromNodeOutputIndex = 0;
outputEdges.push_back(std::move(outputEdge));
MLOperatorGraphDesc operatorGraphDesc = {};
operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
operatorGraphDesc.inputEdges = inputEdges.data();
operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
operatorGraphDesc.intermediateEdges = intermediateEdges.data();
operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
operatorGraphDesc.outputEdges = outputEdges.data();
operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
operatorGraphDesc.nodesAsOpDesc = opDescs.data();
SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
}
};
DML_OP_DEFINE_CREATION_FUNCTION(SkipLayerNormalization, DmlOperatorSkipLayerNormalization);
} // namespace Dml

View file

@ -100,6 +100,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(BatchNormalization15);
DML_OP_EXTERN_CREATION_FUNCTION(LayerNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LayerNormalization17);
DML_OP_EXTERN_CREATION_FUNCTION(SkipLayerNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LRN);
DML_OP_EXTERN_CREATION_FUNCTION(MeanVarianceNormalization);
DML_OP_EXTERN_CREATION_FUNCTION(LpNormalization);
@ -748,6 +749,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)},
{REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)},
{REG_INFO_MS( 1, SkipLayerNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
};
template<typename T>

View file

@ -1405,6 +1405,7 @@ using ShapeInferenceHelper_LRN = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_MeanVarianceNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LayerNormalization17 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_SkipLayerNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_LpNormalization = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_RNN = RecurrentHelper;
using ShapeInferenceHelper_GRU = RecurrentHelper;

View file

@ -395,6 +395,7 @@ namespace OperatorHelper
static const int sc_sinceVer_FusedMatMul = 1;
static const int sc_sinceVer_QLinearSigmoid = 1;
static const int sc_sinceVer_Attention = 1;
static const int sc_sinceVer_SkipLayerNormalization = 1;
} // namespace MsftOperatorSet1
} // namespace OperatorHelper

View file

@ -37,6 +37,7 @@ static void RunTest(
std::vector<int64_t> output_dims = input_dims;
auto rocm_ep = DefaultRocmExecutionProvider();
auto dml_ep = DefaultDmlExecutionProvider();
if (!use_float16) {
OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain);
test.AddInput<float>("input", input_dims, input_data);
@ -55,6 +56,7 @@ static void RunTest(
test.AddOutput<float>("output", output_dims, output_data);
test.Run();
} else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) ||
dml_ep != nullptr ||
rocm_ep != nullptr) {
OpTester test("SkipLayerNormalization", 1, onnxruntime::kMSDomain);
test.AddInput<MLFloat16>("input", input_dims, ToFloat16(input_data));
@ -73,7 +75,9 @@ static void RunTest(
test.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
if (rocm_ep != nullptr) {
if (dml_ep != nullptr) {
execution_providers.push_back(DefaultDmlExecutionProvider());
} else if (rocm_ep != nullptr) {
execution_providers.push_back(DefaultRocmExecutionProvider());
} else {
execution_providers.push_back(DefaultCudaExecutionProvider());

View file

@ -339,8 +339,10 @@ struct TensorCheck<MLFloat16> {
const bool has_rel_err = params.relative_error_.has_value();
float threshold = 0.001f;
#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)
#if defined(USE_TENSORRT) || defined(ENABLE_TRAINING) || defined(USE_CUDA) || defined(USE_ROCM)
threshold = 0.005f;
#elif defined(USE_DML)
threshold = 0.008f;
#endif
for (int i = 0; i < size; ++i) {
if (std::isnan(f_expected[i])) {