diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 38040fba09..01c0902022 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1118,6 +1118,7 @@ Do not modify directly.* | | |**Operator Domain:** *com.microsoft*|||| |Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*in* key:**T**
*in* value:**T**
*out* output:**T**
*out* present:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasGelu.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasGelu.cpp new file mode 100644 index 0000000000..5bf865c012 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasGelu.cpp @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ + +class DmlOperatorBiasGelu : public DmlOperator +{ +public: + DmlOperatorBiasGelu(const MLOperatorKernelCreationContext& kernelCreationContext) + : DmlOperator(kernelCreationContext) + { + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 2); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1); + + // Broadcast bias to have the same dimensions as the input + std::vector inputTensorShape = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0); + DmlOperator::Initialize(kernelCreationContext, std::nullopt, std::nullopt, inputTensorShape); + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + ML_CHECK_VALID_ARGUMENT(inputDescs.size() == 2); + ML_CHECK_VALID_ARGUMENT(outputDescs.size() == 1); + + TensorDesc biasInputTensorDesc(m_inputTensorDescs[0].GetDmlDataType(), m_inputTensorDescs[0].GetSizes()); + DML_TENSOR_DESC biasInputDmlTensorDesc = biasInputTensorDesc.GetDmlDesc(); + + DML_ACTIVATION_GELU_OPERATOR_DESC geluDesc = {}; + DML_OPERATOR_DESC geluOpDesc = { DML_OPERATOR_ACTIVATION_GELU, &geluDesc }; + + DML_ELEMENT_WISE_ADD1_OPERATOR_DESC addDesc = {}; + addDesc.ATensor = &inputDescs[0]; + addDesc.BTensor = &inputDescs[1]; + addDesc.FusedActivation = &geluOpDesc; + addDesc.OutputTensor = &outputDescs[0]; + DML_OPERATOR_DESC addOpDesc = { DML_OPERATOR_ELEMENT_WISE_ADD1, &addDesc }; + + SetDmlOperatorDesc(addOpDesc, kernelCreationContext); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(BiasGelu, DmlOperatorBiasGelu); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index e5f82afb50..d32ad0cb40 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -233,6 +233,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Erf); DML_OP_EXTERN_CREATION_FUNCTION(Where); DML_OP_EXTERN_CREATION_FUNCTION(Shrink); DML_OP_EXTERN_CREATION_FUNCTION(Gelu); +DML_OP_EXTERN_CREATION_FUNCTION(BiasGelu); DML_OP_EXTERN_CREATION_FUNCTION(OneHot); DML_OP_EXTERN_CREATION_FUNCTION(EyeLike); DML_OP_EXTERN_CREATION_FUNCTION(MaxUnpool); @@ -714,6 +715,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation // Contrib operators {REG_INFO_MS( 1, Gelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, BiasGelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, QLinearSigmoid, typeNameListDefault, supportedTypeListQLinearSigmoid, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)}, {REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 829745251f..7df7a51b78 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1557,6 +1557,7 @@ using ShapeInferenceHelper_ParametricSoftplus = GetOutputShapeAsInputShapeHelper using ShapeInferenceHelper_Dropout = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Shrink = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Gelu = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_BiasGelu = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity7 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity13 = GetOutputShapeAsInputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 781c6b898a..224c6fe4af 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -390,6 +390,7 @@ namespace OperatorHelper static const int sc_sinceVer_ConvTransposeWithDynamicPads = 1; static const int sc_sinceVer_QLinearAdd = 1; static const int sc_sinceVer_Gelu = 1; + static const int sc_sinceVer_BiasGelu = 1; static const int sc_sinceVer_FusedMatMul = 1; static const int sc_sinceVer_QLinearSigmoid = 1; static const int sc_sinceVer_Attention = 1; diff --git a/onnxruntime/test/contrib_ops/element_wise_ops_test.cc b/onnxruntime/test/contrib_ops/element_wise_ops_test.cc index 87de630534..43dde645e3 100644 --- a/onnxruntime/test/contrib_ops/element_wise_ops_test.cc +++ b/onnxruntime/test/contrib_ops/element_wise_ops_test.cc @@ -113,7 +113,7 @@ TEST(BiasGeluTest, Two_One_Dim) { RunBiasGeluTest(input_a_data, input_b_data, {2, 4}, {4}); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) TEST(BiasGeluTest, Two_One_Dim_fp16) { #ifdef USE_CUDA int min_cuda_architecture = 530; @@ -187,6 +187,8 @@ TEST(BiasGeluTest, Two_One_Dim_bfloat16) { execution_providers.push_back(DefaultRocmExecutionProvider()); #elif USE_DNNL execution_providers.push_back(DefaultDnnlExecutionProvider()); +#elif USE_DML + execution_providers.push_back(DefaultDmlExecutionProvider()); #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); }