From 88402f52934ac4a8458e064cbb0f61507e9cf914 Mon Sep 17 00:00:00 2001 From: Tiago Koji Castro Shibata Date: Mon, 29 Jun 2020 00:54:43 -0700 Subject: [PATCH 01/13] Make DML operator registration constexpr (#4219) * Make DML operator registration constexpr * Refactor requiredConstantCpuInputs template * Revert "Refactor requiredConstantCpuInputs template" MSVC crashes compiling the new constexpr with "Internal compiler error" * Fix braces style --- .../src/Operators/OperatorRegistration.cpp | 183 +++++++++--------- 1 file changed, 95 insertions(+), 88 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 6f3942b19d..4557dd568f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -65,7 +65,7 @@ struct OperatorRegistrationInformation gsl::span supportedTensorDataTypes; DmGraphSupport DmGraphSupport; - std::vector requiredConstantCpuInputs; + std::pair, int> requiredConstantCpuInputs = {{}, 0}; // For use by operators such as Sum, which may require multiple calls to DML, in which case they // can't be represented as nodes in an optimized graph yet. @@ -234,60 +234,67 @@ DML_OP_EXTERN_QUERY_FUNCTION(MaxPool); DML_OP_EXTERN_QUERY_FUNCTION(Slice); DML_OP_EXTERN_QUERY_FUNCTION(Resize); -const static char* const typeNameListDefault[1] = {"T"}; -const static char* const typeNameListTwo[2] = { "T1", "T2" }; -const static char* const typeNameListThree[3] = { "T1", "T2", "T3" }; -const static char* const typeNameListFour[4] = { "T1", "T2", "T3", "T4" }; -const static char* const typeNameListTopK[2] = { "T", "I" }; -const static char* const typeNameListLogicalComparison[2] = { "T", "T1" }; -const static char* const typeNameListConstantOfShape[2] = { "T1", "T2" }; -const static char* const typeNameListScatterGather[2] = { "T", "Tind" }; -const static char* const typeNameListScatterGatherND[1] = { "T" }; // Tind is curiously missing, only allowing 64-bit. -const static char* const typeNameListSlice10[2] = { "T", "Tind" }; -const static char* const typeNameListWhere[2] = { "B", "T" }; -const static char* const typeNameListEyeLike[1] = { "T2" }; -const static SupportedTensorDataTypes supportedTypeListAll[1] = {SupportedTensorDataTypes::All}; -const static SupportedTensorDataTypes supportedTypeListFloat32[1] = {SupportedTensorDataTypes::Float32}; -const static SupportedTensorDataTypes supportedTypeListFloat16to32[1] = {SupportedTensorDataTypes::Float16to32}; -const static SupportedTensorDataTypes supportedTypeListFloat16to32Int32[1] = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32}; -const static SupportedTensorDataTypes supportedTypeListInt8to32[1] = {SupportedTensorDataTypes::Int8to32}; -const static SupportedTensorDataTypes supportedTypeListInt32to64AndFloat16to32[1] = {SupportedTensorDataTypes::Int32to64|SupportedTensorDataTypes::Float16to32}; -const static SupportedTensorDataTypes supportedTypeListNumericDefault[1] = { SupportedTensorDataTypes::NumericDefault }; -const static SupportedTensorDataTypes supportedTypeListAllScalars[1] = { SupportedTensorDataTypes::AllScalars }; -const static SupportedTensorDataTypes supportedTypeListBool[1] = {SupportedTensorDataTypes::Bool}; -const static SupportedTensorDataTypes supportedTypeListTopK[2] = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64}; -const static SupportedTensorDataTypes supportedTypeListIndices[1] = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; -const static SupportedTensorDataTypes supportedTypeListCast[2] = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 }; -const static SupportedTensorDataTypes supportedTypeListScalars8to32[1] = { SupportedTensorDataTypes::Scalars8to32 }; -const static SupportedTensorDataTypes supportedTypeListScatterGather[2] = { SupportedTensorDataTypes::Scalars8to32, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; -const static SupportedTensorDataTypes supportedTypeListScatterGatherND[1] = { SupportedTensorDataTypes::Scalars8to32 }; -const static SupportedTensorDataTypes supportedTypeListSlice10[2] = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; -const static SupportedTensorDataTypes supportedTypeListQuantizeLinear[2] = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 }; -const static SupportedTensorDataTypes supportedTypeListDequantizeLinear[2] = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 }; -const static SupportedTensorDataTypes supportedTypeListQuantize[2] = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 }; -const static SupportedTensorDataTypes supportedTypeListIsNan[2] = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool }; -const static SupportedTensorDataTypes supportedTypeListIsInf[2] = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool }; -const static SupportedTensorDataTypes supportedTypeListConstantOfShape[2] = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::Float16to32 }; -const static SupportedTensorDataTypes supportedTypeListWhere[2] = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::AllScalars }; -const static SupportedTensorDataTypes supportedTypeListOneHot[3] = /* indices, depth, values */ { SupportedTensorDataTypes::Int32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 }; -const static SupportedTensorDataTypes supportedTypeListLogicalComparison7[2] = /* A&B,C */ { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool }; -const static SupportedTensorDataTypes supportedTypeListLogicalComparison9[2] = /* A&B,C */ { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Bool }; -const static SupportedTensorDataTypes supportedTypeListSigned[1] = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 }; -const static SupportedTensorDataTypes supportedTypeListRange[1] = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32}; -const static SupportedTensorDataTypes supportedTypeListInteger[3] = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 }; -const static SupportedTensorDataTypes supportedTypeListPadWithoutFloat16[1] = { SupportedTensorDataTypes::Int8to32 | SupportedTensorDataTypes::Float32 }; -const static SupportedTensorDataTypes supportedTypeListQLinearMatMul[3] = { +constexpr static std::array typeNameListDefault = {"T"}; +constexpr static std::array typeNameListTwo = { "T1", "T2" }; +constexpr static std::array typeNameListThree = { "T1", "T2", "T3" }; +constexpr static std::array typeNameListFour = { "T1", "T2", "T3", "T4" }; +constexpr static std::array typeNameListTopK = { "T", "I" }; +constexpr static std::array typeNameListLogicalComparison = { "T", "T1" }; +constexpr static std::array typeNameListConstantOfShape = { "T1", "T2" }; +constexpr static std::array typeNameListScatterGather = { "T", "Tind" }; +constexpr static std::array typeNameListScatterGatherND = { "T" }; // Tind is curiously missing, only allowing 64-bit. +constexpr static std::array typeNameListSlice10 = { "T", "Tind" }; +constexpr static std::array typeNameListWhere = { "B", "T" }; +constexpr static std::array typeNameListEyeLike = { "T2" }; +constexpr static std::array supportedTypeListAll = {SupportedTensorDataTypes::All}; +constexpr static std::array supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32}; +constexpr static std::array supportedTypeListFloat16to32 = {SupportedTensorDataTypes::Float16to32}; +constexpr static std::array supportedTypeListFloat16to32Int32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32}; +constexpr static std::array supportedTypeListInt8to32 = {SupportedTensorDataTypes::Int8to32}; +constexpr static std::array supportedTypeListInt32to64AndFloat16to32 = {SupportedTensorDataTypes::Int32to64|SupportedTensorDataTypes::Float16to32}; +constexpr static std::array supportedTypeListNumericDefault = { SupportedTensorDataTypes::NumericDefault }; +constexpr static std::array supportedTypeListAllScalars = { SupportedTensorDataTypes::AllScalars }; +constexpr static std::array supportedTypeListBool = {SupportedTensorDataTypes::Bool}; +constexpr static std::array supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64}; +constexpr static std::array supportedTypeListIndices = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; +constexpr static std::array supportedTypeListCast = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 }; +constexpr static std::array supportedTypeListScalars8to32 = { SupportedTensorDataTypes::Scalars8to32 }; +constexpr static std::array supportedTypeListScatterGather = { SupportedTensorDataTypes::Scalars8to32, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; +constexpr static std::array supportedTypeListScatterGatherND = { SupportedTensorDataTypes::Scalars8to32 }; +constexpr static std::array supportedTypeListSlice10 = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; +constexpr static std::array supportedTypeListQuantizeLinear = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 }; +constexpr static std::array supportedTypeListDequantizeLinear = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 }; +constexpr static std::array supportedTypeListQuantize = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 }; +constexpr static std::array supportedTypeListIsNan = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool }; +constexpr static std::array supportedTypeListIsInf = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool }; +constexpr static std::array supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::Float16to32 }; +constexpr static std::array supportedTypeListWhere = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::AllScalars }; +constexpr static std::array supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Int32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 }; +constexpr static std::array supportedTypeListLogicalComparison7 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool }; +constexpr static std::array supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Bool }; +constexpr static std::array supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 }; +constexpr static std::array supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32}; +constexpr static std::array supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 }; +constexpr static std::array supportedTypeListPadWithoutFloat16 = { SupportedTensorDataTypes::Int8to32 | SupportedTensorDataTypes::Float32 }; +constexpr static std::array supportedTypeListQLinearMatMul = { SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 }; -const static SupportedTensorDataTypes supportedTypeListQLinearConv[4] = { +constexpr static std::array supportedTypeListQLinearConv = { SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 }; +template +constexpr auto requiredConstantCpuInputs(Args... args) +{ + std::array inputs = {static_cast(args)...}; + return std::make_pair(inputs, static_cast(sizeof...(args))); +} + // Define a single row of registration information. #define REG_INFO(version, operatorName, ...) \ #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction, false, false, ##__VA_ARGS__, @@ -311,7 +318,7 @@ const static SupportedTensorDataTypes supportedTypeListQLinearConv[4] = { #define REG_INFO_MSDML(version, operatorName, ...) \ #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDmlDomain, Create##operatorName, ShapeInferenceFunction, false, false, ##__VA_ARGS__, -const static OperatorRegistrationInformation operatorRegistrationInformationTable[] = +constexpr static OperatorRegistrationInformation operatorRegistrationInformationTable[] = { /// Domain/Type, Ver, Name, TypeNames, Types, Graph Support, Required const CPU inputs, /// Input count required for graph support, @@ -327,9 +334,9 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {}, std::nullopt, QueryMaxPool)}, - {REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {}, std::nullopt, QueryMaxPool)}, - {REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {}, std::nullopt, QueryMaxPool)}, + {REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, {REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, @@ -346,7 +353,7 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO( 7, RNN, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::NotSupported)}, {REG_INFO( 7, GRU, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::NotSupported)}, {REG_INFO( 7, LSTM, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::NotSupported)}, - {REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {2})}, + {REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(2))}, // Data Reorganization Layers {REG_INFO( 7, Split, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, @@ -355,16 +362,16 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, {REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, // Adds negative axis. {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, - {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, {1, 2, 3, 4}, std::nullopt, QuerySlice)}, // Adds negative axes. - {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, {1, 2, 3, 4}, std::nullopt, QuerySlice)}, + {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes. + {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, {REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, - {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListPadWithoutFloat16, DmGraphSupport::Supported, {1, 2} /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 + {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListPadWithoutFloat16, DmGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)}, {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)}, {REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported, {1})}, - {REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {1})}, - {REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmGraphSupport::NotSupported, {0})}, + {REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmGraphSupport::NotSupported, requiredConstantCpuInputs(0))}, {REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)}, {REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)}, {REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)}, @@ -384,7 +391,7 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, {REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, {REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported)}, - {REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported, {1})}, + {REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmGraphSupport::Supported, requiredConstantCpuInputs(1))}, // Elementwise {REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, @@ -396,19 +403,19 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO( 7, Ceil, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {1,2})}, + {REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, {REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmGraphSupport::Supported)}, {REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmGraphSupport::Supported)}, {REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmGraphSupport::Supported)}, {REG_INFO( 7, Div, typeNameListDefault, supportedTypeListFloat16to32Int32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {}, 2)}, - {REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {}, 2)}, - {REG_INFO( 7, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {}, 2)}, - {REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {}, 2)}, - {REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {}, 2)}, - {REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {}, 2)}, - {REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {}, 2)}, - {REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {}, 2)}, + {REG_INFO( 7, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 8, Sum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 7, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 7, Cos, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, Sin, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, Tan, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, @@ -472,10 +479,10 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO( 7, Crop, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, ImageScaler, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO_VER( 7, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {1} /*scales*/)}, - {REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {1} /*scales*/)}, - {REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {1} /*scales*/)}, - {REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {1, 2, 3} /*roi, scales, sizes*/, std::nullopt, QueryResize)}, + {REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, + {REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, + {REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, + {REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, // Activation Functions {REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, @@ -510,10 +517,10 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmGraphSupport::Supported)}, - {REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmGraphSupport::Supported, {1})}, - {REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmGraphSupport::Supported, {1})}, - {REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmGraphSupport::Supported, {1})}, - {REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmGraphSupport::Supported, {1})}, + {REG_INFO_VER( 10, TopK, typeNameListTopK, supportedTypeListTopK, DmGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO_VER( 11, TopK, typeNameListTopK, supportedTypeListTopK, DmGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO( 9, OneHot, typeNameListThree, supportedTypeListOneHot, DmGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO( 11, OneHot, typeNameListThree, supportedTypeListOneHot, DmGraphSupport::Supported, requiredConstantCpuInputs(1))}, // Fused operators {REG_INFO_MSDML(1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, @@ -523,19 +530,19 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO_MSDML(1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO_MSDML(1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO_MSDML(1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {}, 2)}, - - {REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmGraphSupport::Supported)}, - {REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, - {REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {1})}, - {REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmGraphSupport::Supported, {0,1,2})}, + {REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {2})}, - {REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {2})}, // 11 is identical to 9. + {REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmGraphSupport::Supported)}, + {REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, + {REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmGraphSupport::Supported)}, + {REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, + {REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, + {REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO( 11, Range, typeNameListDefault, supportedTypeListRange, DmGraphSupport::Supported, requiredConstantCpuInputs(0,1,2))}, + + {REG_INFO( 9, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(2))}, + {REG_INFO( 11, MaxUnpool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, requiredConstantCpuInputs(2))}, // 11 is identical to 9. {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmGraphSupport::NotSupported)}, {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmGraphSupport::NotSupported)}, @@ -644,8 +651,8 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry) kernelSupportsGraph, // supportsGraph information.requiredInputCountForDmlGraphSupport ? &(*information.requiredInputCountForDmlGraphSupport) : nullptr, information.requiresFloatFormatsForGraph, - information.requiredConstantCpuInputs.empty() ? nullptr : information.requiredConstantCpuInputs.data(), - static_cast(information.requiredConstantCpuInputs.size()) + information.requiredConstantCpuInputs.first.data(), + static_cast(information.requiredConstantCpuInputs.second) )); } } From 5f4e63ede66eced1144e1019cd02b6e3d09b3470 Mon Sep 17 00:00:00 2001 From: gwang-msft <62914304+gwang-msft@users.noreply.github.com> Date: Mon, 29 Jun 2020 11:55:45 -0700 Subject: [PATCH 02/13] Add nhwc support for NNAPI EP, add concat op, handle concurrent calls to NNAPI model (#4356) * add support to internally transpose nchw input to nhwc and only transpose back if it is necessary * more changes in nchw<->nhc, fixed small issue in concat * Add option for NNAPI to run on [all device]s/[cpu onl]y/[non-cpu only] * minor code style changes --- .../nnapi_builtin/builders/model_builder.cc | 126 ++++- .../nnapi_builtin/builders/model_builder.h | 43 +- .../nnapi_builtin/builders/op_builder.cc | 442 ++++++++++++++---- .../nnapi/nnapi_builtin/builders/op_builder.h | 6 +- .../nnapi/nnapi_builtin/builders/shaper.cc | 51 +- .../nnapi/nnapi_builtin/builders/shaper.h | 6 + .../providers/nnapi/nnapi_builtin/model.cc | 24 +- .../providers/nnapi/nnapi_builtin/model.h | 18 +- .../nnapi_builtin/nnapi_execution_provider.cc | 94 ++-- 9 files changed, 633 insertions(+), 177 deletions(-) diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc index 6d840eab7a..98844f9c73 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc @@ -62,7 +62,7 @@ bool IsValidSupportedNodesVec(const std::vector& supported_node_vec, std::vector> ModelBuilder::GetSupportedNodes() { std::vector> supported_node_vecs; - int32_t android_sdk_ver = nnapi_ ? nnapi_->android_sdk_version : 0; + int32_t android_sdk_ver = GetAndroidSdkVer(); #ifdef __ANDROID__ if (android_sdk_ver < 27) { LOGS_DEFAULT(VERBOSE) << "Android API level " @@ -126,6 +126,7 @@ void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) { void ModelBuilder::Prepare() { nnapi_model_ = std::unique_ptr(new Model()); THROW_ON_ERROR(nnapi_->ANeuralNetworksModel_create(&nnapi_model_->model_)); + GetTargetDevices(); PreprocessInitializers(); RegisterInitializers(); RegisterModelInputs(); @@ -144,9 +145,40 @@ static size_t GetPaddedByteSize(size_t size) { return (size + kDefaultByteAlignmentForNNAPI - 1) & ~(kDefaultByteAlignmentForNNAPI - 1); } +void ModelBuilder::GetTargetDevices() { + // GetTargetDevices is only supported on API 29+ + if (GetAndroidSdkVer() < 29) + return; + + if (target_device_option_ == TargetDeviceOption::ALL_DEVICES) + return; + + const std::string nnapi_cpu("nnapi-reference"); + uint32_t num_devices = 0; + THROW_ON_ERROR_WITH_NOTE(nnapi_->ANeuralNetworks_getDeviceCount(&num_devices), + "Getting list of available devices"); + + for (uint32_t i = 0; i < num_devices; i++) { + ANeuralNetworksDevice* device = nullptr; + const char* device_name = nullptr; + THROW_ON_ERROR_WITH_NOTE(nnapi_->ANeuralNetworks_getDevice(i, &device), + "Getting list of available devices"); + + THROW_ON_ERROR_WITH_NOTE(nnapi_->ANeuralNetworksDevice_getName(device, &device_name), + "Getting list of available devices"); + + bool device_is_cpu = nnapi_cpu == device_name; + if ((target_device_option_ == TargetDeviceOption::CPU_DISABLED && !device_is_cpu) || + (target_device_option_ == TargetDeviceOption::CPU_ONLY && device_is_cpu)) { + nnapi_target_devices_.push_back(device); + LOGS_DEFAULT(VERBOSE) << "Target device [" << device_name << "] added"; + } + } +} + void ModelBuilder::GetAllInitializers() { for (const auto& tensor : model_proto_.graph().initializer()) { - initializers_.insert({tensor.name(), tensor}); + initializers_.emplace(tensor.name(), tensor); } } @@ -190,7 +222,7 @@ void ModelBuilder::RegisterInitializers() { OperandType operand_type(type, shape); shaper_.AddShape(name, operand_type.dimensions); - auto index = AddNewOperand(name, operand_type); + auto index = AddNewOperand(name, operand_type, false /* is_nhwc */); const size_t size = operand_type.GetOperandBlobByteSize(); const size_t padded_size = GetPaddedByteSize(size); sizeAll += padded_size; @@ -264,7 +296,7 @@ void ModelBuilder::RegisterModelInputs() { OperandType operand_type(type, shape); shaper_.AddShape(input_name, operand_type.dimensions); - auto index = AddNewOperand(input_name, operand_type); + auto index = AddNewOperand(input_name, operand_type, false /* is_nhwc */); input_index_vec_.push_back(index); nnapi_model_->AddInput(input_name, operand_type); @@ -279,8 +311,15 @@ void ModelBuilder::RegisterModelOutputs() { ORT_THROW("The output of graph is not registered" + output_name); } - output_index_vec_.push_back(operand_indices_[output_name]); - nnapi_model_->AddOutput(output_name, operand_types_.at(output_name)); + std::string nnapi_output_name = output_name; + if (IsOperandNHWC(output_name)) { + // We need to transpose the output still in nhwc back to nchw + nnapi_output_name = GetUniqueName(output_name + "_nhwc_to_nchw"); + TransposeNHWCToNCHW(*this, output_name, nnapi_output_name); + } + + output_index_vec_.push_back(operand_indices_[nnapi_output_name]); + nnapi_model_->AddOutput(output_name, nnapi_output_name, operand_types_.at(nnapi_output_name)); } } @@ -290,11 +329,12 @@ void ModelBuilder::RegisterModelShaper() { } uint32_t ModelBuilder::AddNewOperand(const std::string& name, - const android::nn::wrapper::OperandType& operand_type) { + const OperandType& operand_type, + bool is_nhwc) { THROW_ON_ERROR(nnapi_->ANeuralNetworksModel_addOperand( nnapi_model_->model_, &operand_type.operandType)); auto idx = next_index_++; - RegisterOperand(name, idx, operand_type); + RegisterOperand(name, idx, operand_type, is_nhwc); return idx; } @@ -304,12 +344,14 @@ uint32_t ModelBuilder::AddNewNNAPIOperand(const OperandType& operand_type) { return next_index_++; } -void ModelBuilder::RegisterOperand(const std::string& name, - uint32_t index, - const OperandType& operand_type) { +void ModelBuilder::RegisterOperand(const std::string& name, uint32_t index, + const OperandType& operand_type, bool is_nhwc) { operand_indices_[name] = index; - operand_types_.insert({name, operand_type}); + operand_types_.emplace(name, operand_type); operands_.insert(name); + + if (is_nhwc) + RegisterNHWCOperand(name); } void ModelBuilder::SetOperandValue(uint32_t index, @@ -334,7 +376,7 @@ uint32_t ModelBuilder::AddOperandFromPersistMemoryBuffer( const std::string& name, const void* buffer, const android::nn::wrapper::OperandType& operand_type) { shaper_.AddShape(name, operand_type.dimensions); - auto index = AddNewOperand(name, operand_type); + auto index = AddNewOperand(name, operand_type, false /* is_nhwc */); const size_t size = operand_type.GetOperandBlobByteSize(); // for small size operand, the value will be copied @@ -369,10 +411,11 @@ void ModelBuilder::AddOperations() { void ModelBuilder::AddOperation(int op, const std::vector& input_indices, const std::vector& output_names, - const std::vector& types) { + const std::vector& types, + const std::vector& is_nhwc_vec) { std::vector output_indices; for (size_t i = 0; i < types.size(); i++) { - output_indices.push_back(AddNewOperand(output_names[i], types[i])); + output_indices.push_back(AddNewOperand(output_names[i], types[i], is_nhwc_vec[i])); } THROW_ON_ERROR_WITH_NOTE( @@ -393,7 +436,8 @@ std::unique_ptr ModelBuilder::Compile() { &output_index_vec_[0]), "on identifyInputsAndOutputs"); - if (use_fp16_) { + // relax fp32tofp16 is only available on API 28+ + if (use_fp16_ && GetAndroidSdkVer() > 27) { THROW_ON_ERROR_WITH_NOTE( nnapi_->ANeuralNetworksModel_relaxComputationFloat32toFloat16( nnapi_model_->model_, true), @@ -404,9 +448,17 @@ std::unique_ptr ModelBuilder::Compile() { nnapi_->ANeuralNetworksModel_finish(nnapi_model_->model_), "on model finish"); - THROW_ON_ERROR_WITH_NOTE( - nnapi_->ANeuralNetworksCompilation_create(nnapi_model_->model_, &nnapi_model_->compilation_), - "on create"); + if (!nnapi_target_devices_.empty()) { + THROW_ON_ERROR_WITH_NOTE( + nnapi_->ANeuralNetworksCompilation_createForDevices( + nnapi_model_->model_, nnapi_target_devices_.data(), + nnapi_target_devices_.size(), &nnapi_model_->compilation_), + "on createForDevices"); + } else { + THROW_ON_ERROR_WITH_NOTE( + nnapi_->ANeuralNetworksCompilation_create(nnapi_model_->model_, &nnapi_model_->compilation_), + "on create"); + } THROW_ON_ERROR_WITH_NOTE( nnapi_->ANeuralNetworksCompilation_setPreference( @@ -475,5 +527,41 @@ std::string ModelBuilder::GetUniqueName(const std::string& base_name) { return unique_name; } +void ModelBuilder::RegisterNHWCOperand(const std::string& name) { + nhwc_operands_.insert(name); +} + +bool ModelBuilder::IsOperandNHWC(const std::string& name) { + return Contains(nhwc_operands_, name); +} + +bool ModelBuilder::GetNCHWOperand(const std::string& nhwc_name, std::string& nchw_name) { + if (Contains(nhwc_to_nchw_map_, nhwc_name)) { + nchw_name = nhwc_to_nchw_map_[nhwc_name]; + return true; + } + return false; +} + +bool ModelBuilder::GetNHWCOperand(const std::string& nchw_name, std::string& nhwc_name) { + if (Contains(nchw_to_nhwc_map_, nchw_name)) { + nhwc_name = nchw_to_nhwc_map_[nchw_name]; + return true; + } + return false; +} + +void ModelBuilder::SetNHWCToNCHWOperandMap(const std::string& nhwc_name, + const std::string& nchw_name) { + ORT_ENFORCE(!Contains(nhwc_to_nchw_map_, nhwc_name), "A previous nchw to nhwc map exists"); + nhwc_to_nchw_map_[nhwc_name] = nchw_name; +} + +void ModelBuilder::SetNCHWToNHWCOperandMap(const std::string& nchw_name, + const std::string& nhwc_name) { + ORT_ENFORCE(!Contains(nchw_to_nhwc_map_, nchw_name), "A previous nchw to nhwc map exists"); + nchw_to_nhwc_map_[nchw_name] = nhwc_name; +} + } // namespace nnapi } // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h index a475ab6659..d9ca4a1b69 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h @@ -17,6 +17,15 @@ class ModelBuilder { public: using Shape = Shaper::Shape; + enum class TargetDeviceOption : int8_t { + ALL_DEVICES, // use all avaliable target devices + /* TODO support this option + SINGLE_DEVICE, // use a single target device, must be given + */ + CPU_DISABLED, // use all avaliable target devices except CPU + CPU_ONLY, // use CPU only + }; + ModelBuilder(ONNX_NAMESPACE::ModelProto& model_proto); ~ModelBuilder() = default; @@ -29,7 +38,8 @@ class ModelBuilder { // Add an NNAPI operation (operator) void AddOperation(int op, const std::vector& input_indices, const std::vector& output_names, - const std::vector& types); + const std::vector& types, + const std::vector& is_nhwc_vec); // Find if an output has a fuseable activation (Relu) int32_t FindActivation(const std::string& output); @@ -48,7 +58,8 @@ class ModelBuilder { // Register informations for a particular operand void RegisterOperand(const std::string& name, uint32_t index, - const android::nn::wrapper::OperandType& operand_type); + const android::nn::wrapper::OperandType& operand_type, + bool is_nhwc); // Generate an unique name for intermediate result std::string GetUniqueName(const std::string& base_name); @@ -84,6 +95,18 @@ class ModelBuilder { const ONNX_NAMESPACE::ModelProto& GetOnnxModel() const { return model_proto_; } + void RegisterNHWCOperand(const std::string& name); + bool IsOperandNHWC(const std::string& name); + + // Get the operand transposed to nchw/nhwc from given nhwc/nchw operand, if it exists + bool GetNCHWOperand(const std::string& nhwc_name, std::string& nchw_name); + bool GetNHWCOperand(const std::string& nchw_name, std::string& nhwc_name); + + void SetNHWCToNCHWOperandMap(const std::string& nhwc_name, + const std::string& nchw_name); + void SetNCHWToNHWCOperandMap(const std::string& nchw_name, + const std::string& nhwc_name); + private: const NnApi* nnapi_{nullptr}; ONNX_NAMESPACE::ModelProto& model_proto_; @@ -91,7 +114,7 @@ class ModelBuilder { uint32_t name_token_{0}; - bool use_nchw_{true}; + bool use_nchw_{false}; bool use_fp16_{false}; android::nn::wrapper::ExecutePreference exe_pref_{ android::nn::wrapper::ExecutePreference::PREFER_FAST_SINGLE_ANSWER}; @@ -109,11 +132,21 @@ class ModelBuilder { std::unordered_map> op_builders_; + // Operands in nhwc + std::unordered_set nhwc_operands_; + + // Maps between nhwc and nchw, and vice versa + std::unordered_map nhwc_to_nchw_map_; + std::unordered_map nchw_to_nhwc_map_; + std::vector input_index_vec_; std::vector output_index_vec_; std::unordered_set unique_names_; + TargetDeviceOption target_device_option_{TargetDeviceOption::ALL_DEVICES}; + std::vector nnapi_target_devices_; + uint32_t next_index_ = 0; bool IsNodeSupported(const ONNX_NAMESPACE::NodeProto& node); @@ -121,6 +154,7 @@ class ModelBuilder { // Convert the onnx model to ANeuralNetworksModel void Prepare(); + void GetTargetDevices(); void GetAllInitializers(); void PreprocessInitializers(); void RegisterInitializers(); @@ -134,7 +168,8 @@ class ModelBuilder { uint32_t AddNewNNAPIOperand(const android::nn::wrapper::OperandType& type); uint32_t AddNewOperand(const std::string& name, - const android::nn::wrapper::OperandType& operand_type); + const android::nn::wrapper::OperandType& operand_type, + bool is_nhwc); IOpBuilder* GetOpBuilder(const ONNX_NAMESPACE::NodeProto& node); }; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc index 4a3c14dd9a..3b69ee44a6 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.cc @@ -30,12 +30,87 @@ const float* GetTensorFloatData(const ONNX_NAMESPACE::TensorProto& tensor) { : tensor.float_data().data(); } +void AddTransposeOperator(ModelBuilder& model_builder, + const std::string& input, + const std::string& perm_name, + vector perm, + const std::string& output, + bool output_is_nhwc) { + auto& shaper(model_builder.GetShaper()); + const auto& operand_indices(model_builder.GetOperandIndices()); + const auto& operand_types(model_builder.GetOperandTypes()); + + std::vector input_indices; + input_indices.push_back(operand_indices.at(input)); // input + + ModelBuilder::Shape perm_dimen = {SafeInt(perm.size())}; + OperandType perm_operand_type(Type::TENSOR_INT32, perm_dimen); + uint32_t perm_idx = model_builder.AddOperandFromPersistMemoryBuffer( + perm_name, perm.data(), perm_operand_type); + + input_indices.push_back(perm_idx); // permutation + shaper.Transpose(input, perm, output); + const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); + model_builder.AddOperation(ANEURALNETWORKS_TRANSPOSE, input_indices, {output}, + {output_operand_type}, {output_is_nhwc}); +} + +void TransposeBetweenNCHWAndNHWC(ModelBuilder& model_builder, + const std::string& input, + const std::string& output, + bool nchw_to_nhwc) { + ORT_ENFORCE(!model_builder.UseNCHW(), "model_builder.UseNCHW() is on"); + const auto& shaper(model_builder.GetShaper()); + ORT_ENFORCE( + 4 == shaper[input].size(), + "TransposeNCHWToNHWC input has to be a 4d tensor, actual dimensions: " + + std::to_string(shaper[input].size())); + + std::string perm_name; + vector perm; + if (nchw_to_nhwc) { + perm_name = model_builder.GetUniqueName(input + "nchw_to_nhwc_perm"); + perm = {0, 2, 3, 1}; + } else { // nhwc_to_nchw + perm_name = model_builder.GetUniqueName(input + "nhwc_to_nchw_perm"); + perm = {0, 3, 1, 2}; + } + + AddTransposeOperator(model_builder, input, perm_name, perm, output, nchw_to_nhwc); + + if (nchw_to_nhwc) { + model_builder.SetNCHWToNHWCOperandMap(input, output); + } else { // nhwc_to_nchw + model_builder.SetNHWCToNCHWOperandMap(input, output); + } + + LOGS_DEFAULT(VERBOSE) << "Operand [" << input << "] with shape " + << Shape2String(shaper[input]) + << " is transposed " + << (nchw_to_nhwc ? "nchw_to_nhwc" : "nhwc_to_nchw") + << " to [" << output << "] with shape " + << Shape2String(shaper[output]); +} + +void TransposeNHWCToNCHW(ModelBuilder& model_builder, + const std::string& input, + const std::string& output) { + TransposeBetweenNCHWAndNHWC(model_builder, input, output, false /* nchw_to_nhwc */); +} + +void TransposeNCHWToNHWC(ModelBuilder& model_builder, + const std::string& input, + const std::string& output) { + TransposeBetweenNCHWAndNHWC(model_builder, input, output, true /* nchw_to_nhwc */); +} + void AddBinaryOperator(int32_t op_type, ModelBuilder& model_builder, const std::string& input1, const std::string& input2, int32_t fuse_code, - const std::string& output) { + const std::string& output, + bool output_is_nhwc) { auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); @@ -46,41 +121,7 @@ void AddBinaryOperator(int32_t op_type, input_indices.push_back(model_builder.AddOperandFromScalar(fuse_code)); shaper.Eltwise(input1, input2, output); const OperandType output_operand_type(operand_types.at(input1).type, shaper[output]); - model_builder.AddOperation(op_type, input_indices, {output}, {output_operand_type}); -} - -void AddPoolOperator(int32_t op_type, - ModelBuilder& model_builder, - const std::string& input, - const vector& onnx_pads, - const vector& onnx_strides, - const vector& kernel_shape, - int32_t fuse_code, - const std::string& output) { - auto& shaper(model_builder.GetShaper()); - const auto& operand_indices(model_builder.GetOperandIndices()); - const auto& operand_types(model_builder.GetOperandTypes()); - bool use_nchw = model_builder.UseNCHW(); - - std::vector input_indices; - input_indices.push_back(operand_indices.at(input)); - input_indices.push_back(model_builder.AddOperandFromScalar(onnx_pads[1])); - input_indices.push_back(model_builder.AddOperandFromScalar(onnx_pads[3])); - input_indices.push_back(model_builder.AddOperandFromScalar(onnx_pads[0])); - input_indices.push_back(model_builder.AddOperandFromScalar(onnx_pads[2])); - input_indices.push_back(model_builder.AddOperandFromScalar(onnx_strides[1])); - input_indices.push_back(model_builder.AddOperandFromScalar(onnx_strides[0])); - input_indices.push_back(model_builder.AddOperandFromScalar(kernel_shape[1])); - input_indices.push_back(model_builder.AddOperandFromScalar(kernel_shape[0])); - input_indices.push_back(model_builder.AddOperandFromScalar(fuse_code)); - input_indices.push_back(model_builder.AddOperandFromScalar(use_nchw)); - - shaper.Pool(input, - onnx_pads, onnx_strides, kernel_shape, - use_nchw, - output); - const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); - model_builder.AddOperation(op_type, input_indices, {output}, {output_operand_type}); + model_builder.AddOperation(op_type, input_indices, {output}, {output_operand_type}, {output_is_nhwc}); } int GetType(const ONNX_NAMESPACE::ModelProto& model_proto, @@ -159,7 +200,7 @@ Shaper::Shape GetShape(const ONNX_NAMESPACE::ModelProto& model_proto, } enum DataLayout { - L_NCHW = 0, + L_0231 = 0, L_1230 = 1, }; @@ -187,8 +228,8 @@ uint32_t AddInitializerInNewLayout(ModelBuilder& model_builder, auto out_t = shape[0], in_t = shape[1], h_t = shape[2], w_t = shape[3]; ModelBuilder::Shape dest_shape; - if (new_layout == L_NCHW) - dest_shape = {out_t, h_t, w_t, in_t}; // L_NCHW + if (new_layout == L_0231) + dest_shape = {out_t, h_t, w_t, in_t}; // L_0231 else dest_shape = {in_t, h_t, w_t, out_t}; // L_1230 for depthwise conv weight @@ -205,7 +246,7 @@ uint32_t AddInitializerInNewLayout(ModelBuilder& model_builder, w; uint32_t nnapi_idx; - if (new_layout == L_NCHW) { // L_NCHW + if (new_layout == L_0231) { // L_0231 nnapi_idx = out * h_t * w_t * in_t + h * w_t * in_t + w * in_t + @@ -389,12 +430,33 @@ void BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, else { ORT_THROW("UnaryOpBuilder, unknown op: " + op); } - const auto& input1 = node.input(0); - const auto& input2 = node.input(1); + std::string input1 = node.input(0); + std::string input2 = node.input(1); + bool input1_is_nhwc = model_builder.IsOperandNHWC(input1); + bool input2_is_nhwc = model_builder.IsOperandNHWC(input2); + bool output_is_nhwc = false; + + if (input1_is_nhwc == input2_is_nhwc) { + output_is_nhwc = input1_is_nhwc; + } else if (input1_is_nhwc) { + // need transpsoe input1 back to nchw + const auto& nhwc_input = node.input(0); + if (!model_builder.GetNCHWOperand(nhwc_input, input1)) { + input1 = model_builder.GetUniqueName(nhwc_input + "_nhwc_to_nchw"); + TransposeNHWCToNCHW(model_builder, nhwc_input, input1); + } + } else { // input2_is_nhwc + // need transpsoe input2 back to nchw + const auto& nhwc_input = node.input(1); + if (!model_builder.GetNCHWOperand(nhwc_input, input2)) { + input2 = model_builder.GetUniqueName(nhwc_input + "_nhwc_to_nchw"); + TransposeNHWCToNCHW(model_builder, nhwc_input, input2); + } + } + const auto& output = node.output(0); int32_t fuse_code = model_builder.FindActivation(output); - AddBinaryOperator(op_code, model_builder, - input1, input2, fuse_code, output); + AddBinaryOperator(op_code, model_builder, input1, input2, fuse_code, output, output_is_nhwc); } #pragma endregion @@ -415,16 +477,17 @@ void ReluOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& input = node.input(0); const auto& output = node.output(0); + bool output_is_nhwc = model_builder.IsOperandNHWC(input); shaper.Identity(input, output); const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); // skip this relu if it is some op's fuse output if (Contains(model_builder.GetFusedActivations(), node.name())) { - model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type); + model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, output_is_nhwc); } else { std::vector input_indices; input_indices.push_back(operand_indices.at(input)); - model_builder.AddOperation(ANEURALNETWORKS_RELU, input_indices, {output}, {output_operand_type}); + model_builder.AddOperation(ANEURALNETWORKS_RELU, input_indices, {output}, {output_operand_type}, {output_is_nhwc}); } } @@ -462,31 +525,35 @@ bool TransposeOpBuilder::IsOpSupportedImpl(ModelBuilder& model_builder, void TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const ONNX_NAMESPACE::NodeProto& node) { auto& shaper(model_builder.GetShaper()); - const auto& operand_indices(model_builder.GetOperandIndices()); - const auto& operand_types(model_builder.GetOperandTypes()); + + auto input = node.input(0); + const auto& output = node.output(0); NodeAttrHelper helper(node); - - const auto& input = node.input(0); - std::vector input_indices; - input_indices.push_back(operand_indices.at(input)); // input - vector perm = helper.Get("perm", vector()); auto input_dims = shaper[input].size(); if (perm.empty()) { for (int32_t i = input_dims - 1; i >= 0; i--) perm.push_back(i); + } else { + ORT_ENFORCE(perm.size() == input_dims, "Perm and input should have same dimension"); } - ModelBuilder::Shape perm_dimen = {SafeInt(input_dims)}; - std::string perm_name = model_builder.GetUniqueName(node.name() + input + "perm"); - OperandType perm_operand_type(Type::TENSOR_INT32, perm_dimen); - uint32_t perm_idx = model_builder.AddOperandFromPersistMemoryBuffer(perm_name, perm.data(), perm_operand_type); - input_indices.push_back(perm_idx); + if (model_builder.IsOperandNHWC(input)) { + ORT_ENFORCE(input_dims == 4, "Only 4D shape can be nhwc"); - const auto& output = node.output(0); - shaper.Transpose(input, perm, output); - const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); - model_builder.AddOperation(ANEURALNETWORKS_TRANSPOSE, input_indices, {output}, {output_operand_type}); + // we are using nhwc here, but the axis is in nchw, need to transpose axis from nchw to nhwc + const int32_t axis_nchw_to_nhwc[4]{0, 3, 1, 2}; + for (size_t i = 0; i < perm.size(); i++) + perm[i] = axis_nchw_to_nhwc[perm[i]]; + } + + std::string perm_name = model_builder.GetUniqueName(node.name() + input + "perm"); + + // It is possible this onnx transpose operator can be nchw->nhwc, but so far I don't see + // any scenario will do this since onnx is nchw only, assume the output is always not nhwc + // even it is, there will be extra transpose in the onnx model to convert it back to nchw + // before conv/pool/... operators + AddTransposeOperator(model_builder, input, perm_name, perm, output, false /* is_nhwc */); } #pragma endregion op_transpose @@ -551,7 +618,17 @@ void ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& operand_types(model_builder.GetOperandTypes()); const auto& initializers(model_builder.GetInitializerTensors()); - const auto& input = node.input(0); + auto input = node.input(0); + + if (model_builder.IsOperandNHWC(input)) { + // We want to transpose nhwc operand back to nchw before reshape + const auto& nhwc_input = node.input(0); + if (!model_builder.GetNCHWOperand(nhwc_input, input)) { + input = model_builder.GetUniqueName(nhwc_input + "_nhwc_to_nchw"); + TransposeNHWCToNCHW(model_builder, nhwc_input, input); + } + } + const auto& output = node.output(0); std::vector input_indices; input_indices.push_back(operand_indices.at(input)); // input @@ -576,7 +653,8 @@ void ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, shaper.Reshape(input, shape, output); const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); - model_builder.AddOperation(ANEURALNETWORKS_RESHAPE, input_indices, {output}, {output_operand_type}); + model_builder.AddOperation(ANEURALNETWORKS_RESHAPE, input_indices, + {output}, {output_operand_type}, {false}); } #pragma endregion op_reshape @@ -676,10 +754,13 @@ void BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil const auto tensor_b_name = model_builder.GetUniqueName(node.name() + input + "_imm_b"); const auto tensor_imm_product_name = model_builder.GetUniqueName(node.name() + input + "_imm_mul"); ModelBuilder::Shape tensor_a_dimen; - if (model_builder.UseNCHW()) - tensor_a_dimen = {size, 1, 1}; // {C, H, W} - else + + bool input_is_nhwc = model_builder.IsOperandNHWC(input); + bool output_is_nhwc = input_is_nhwc; + if (input_is_nhwc) tensor_a_dimen = {size}; + else // input is nchw + tensor_a_dimen = {size, 1, 1}; // {C, H, W} shaper.AddShape(tensor_a_name, tensor_a_dimen); shaper.AddShape(tensor_b_name, tensor_a_dimen); @@ -693,7 +774,8 @@ void BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil model_builder, input, tensor_a_name, ANEURALNETWORKS_FUSED_NONE, - tensor_imm_product_name); + tensor_imm_product_name, + output_is_nhwc); // Add int32_t fuse_code = model_builder.FindActivation(output); @@ -701,7 +783,8 @@ void BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil model_builder, tensor_imm_product_name, tensor_b_name, fuse_code, - output); + output, + output_is_nhwc); } #pragma endregion op_batchnormalization @@ -782,17 +865,36 @@ bool PoolOpBuilder::IsOpSupportedImpl(ModelBuilder& model_builder, void PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const ONNX_NAMESPACE::NodeProto& node) { auto& shaper(model_builder.GetShaper()); + const auto& operand_indices(model_builder.GetOperandIndices()); + const auto& operand_types(model_builder.GetOperandTypes()); + NodeAttrHelper helper(node); - const auto& input = node.input(0); + auto input = node.input(0); + bool use_nchw = model_builder.UseNCHW(); + bool input_is_nhwc = model_builder.IsOperandNHWC(input); + bool output_is_nhwc = false; + if (use_nchw) { + ORT_ENFORCE(!input_is_nhwc, "model_builder.UseNCHW() but input is NHWC"); + } else { + output_is_nhwc = true; + if (!input_is_nhwc) { + const auto& nchw_input = node.input(0); + if (!model_builder.GetNHWCOperand(nchw_input, input)) { + input = model_builder.GetUniqueName(nchw_input + "_nchw_to_nhwc"); + TransposeNCHWToNHWC(model_builder, nchw_input, input); + } + } + } + const auto& output = node.output(0); const auto& op = node.op_type(); - int32_t operationType; + int32_t op_type; if (op == "AveragePool" || op == "GlobalAveragePool") - operationType = ANEURALNETWORKS_AVERAGE_POOL_2D; + op_type = ANEURALNETWORKS_AVERAGE_POOL_2D; else // (op == "MaxPool" || op == "GlobalMaxPool") - operationType = ANEURALNETWORKS_MAX_POOL_2D; + op_type = ANEURALNETWORKS_MAX_POOL_2D; vector onnx_pads, onnx_strides, kernel_shape; if (op == "AveragePool" || op == "MaxPool") { @@ -811,12 +913,25 @@ void PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } int32_t fuse_code = model_builder.FindActivation(output); - AddPoolOperator(operationType, - model_builder, - input, - onnx_pads, onnx_strides, kernel_shape, - fuse_code, - output); + std::vector input_indices; + input_indices.push_back(operand_indices.at(input)); + input_indices.push_back(model_builder.AddOperandFromScalar(onnx_pads[1])); + input_indices.push_back(model_builder.AddOperandFromScalar(onnx_pads[3])); + input_indices.push_back(model_builder.AddOperandFromScalar(onnx_pads[0])); + input_indices.push_back(model_builder.AddOperandFromScalar(onnx_pads[2])); + input_indices.push_back(model_builder.AddOperandFromScalar(onnx_strides[1])); + input_indices.push_back(model_builder.AddOperandFromScalar(onnx_strides[0])); + input_indices.push_back(model_builder.AddOperandFromScalar(kernel_shape[1])); + input_indices.push_back(model_builder.AddOperandFromScalar(kernel_shape[0])); + input_indices.push_back(model_builder.AddOperandFromScalar(fuse_code)); + input_indices.push_back(model_builder.AddOperandFromScalar(use_nchw)); + + shaper.Pool(input, + onnx_pads, onnx_strides, kernel_shape, + use_nchw, + output); + const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); + model_builder.AddOperation(op_type, input_indices, {output}, {output_operand_type}, {output_is_nhwc}); } #pragma endregion op_pool @@ -877,7 +992,6 @@ void ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& operand_types(model_builder.GetOperandTypes()); const auto& initializers(model_builder.GetInitializerTensors()); NodeAttrHelper helper(node); - bool use_nchw = model_builder.UseNCHW(); // onnx strides are in the order height, width // while nnapi strides are in the order width, height @@ -892,7 +1006,23 @@ void ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto onnx_dilations = helper.Get("dilations", vector{1, 1}); const auto group = helper.Get("group", 1); - const auto& input = node.input(0); + auto input = node.input(0); + bool use_nchw = model_builder.UseNCHW(); + bool input_is_nhwc = model_builder.IsOperandNHWC(input); + bool output_is_nhwc = false; + if (use_nchw) { + ORT_ENFORCE(!input_is_nhwc, "model_builder.UseNCHW() but input is NHWC"); + } else { + output_is_nhwc = true; + if (!input_is_nhwc) { + const auto& nchw_input = node.input(0); + if (!model_builder.GetNHWCOperand(nchw_input, input)) { + input = model_builder.GetUniqueName(nchw_input + "_nchw_to_nhwc"); + TransposeNCHWToNHWC(model_builder, nchw_input, input); + } + } + } + const auto& weight = node.input(1); const auto& output = node.output(0); @@ -905,7 +1035,7 @@ void ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (conv2d) { input_indices.push_back(AddInitializerInNewLayout( - model_builder, weight, L_NCHW)); + model_builder, weight, L_0231)); } else { // depthwise_conv2d input_indices.push_back(AddInitializerInNewLayout( model_builder, weight, L_1230)); @@ -927,13 +1057,13 @@ void ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& weight_type = operand_types.at(weight).type; if (weight_type == Type::TENSOR_FLOAT32) { - float buffer[bias_dimen[0]]; - for (uint32_t i = 0; i < bias_dimen[0]; i++) { + vector buffer(bias_dimen[0]); + for (uint32_t i = 0; i < buffer.size(); i++) { buffer[i] = 0.f; } OperandType operandType(Type::TENSOR_FLOAT32, bias_dimen); bias_idx_val = model_builder.AddOperandFromPersistMemoryBuffer( - bias, &buffer[0], operandType); + bias, buffer.data(), operandType); } else { ORT_THROW("Unknown weight type " + TypeToStr(weight_type)); } @@ -952,7 +1082,7 @@ void ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } int32_t fuse_code = model_builder.FindActivation(output); input_indices.push_back(model_builder.AddOperandFromScalar(fuse_code)); - // TODO support API 27 + // TODO support API 28 input_indices.push_back(model_builder.AddOperandFromScalar(use_nchw)); input_indices.push_back(model_builder.AddOperandFromScalar(onnx_dilations[1])); input_indices.push_back(model_builder.AddOperandFromScalar(onnx_dilations[0])); @@ -973,7 +1103,7 @@ void ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); - model_builder.AddOperation(operationCode, input_indices, {output}, {output_operand_type}); + model_builder.AddOperation(operationCode, input_indices, {output}, {output_operand_type}, {output_is_nhwc}); } #pragma endregion op_conv @@ -1017,6 +1147,8 @@ void CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& input = node.input(0); const auto& output = node.output(0); + bool output_is_nhwc = model_builder.IsOperandNHWC(input); + auto to = helper.Get("to", 0); Type type; switch (to) { @@ -1035,7 +1167,8 @@ void CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, input_indices.push_back(operand_indices.at(input)); shaper.Identity(input, output); const OperandType output_operand_type(type, shaper[output]); - model_builder.AddOperation(ANEURALNETWORKS_CAST, input_indices, {output}, {output_operand_type}); + model_builder.AddOperation(ANEURALNETWORKS_CAST, input_indices, {output}, + {output_operand_type}, {output_is_nhwc}); } #pragma endregion @@ -1076,7 +1209,16 @@ void SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& operand_types(model_builder.GetOperandTypes()); NodeAttrHelper helper(node); - const auto& input = node.input(0); + auto input = node.input(0); + if (model_builder.IsOperandNHWC(input)) { + // We want to transpose nhwc operand back to nchw before softmax + const auto& nhwc_input = node.input(0); + if (!model_builder.GetNCHWOperand(nhwc_input, input)) { + input = model_builder.GetUniqueName(nhwc_input + "_nhwc_to_nchw"); + TransposeNHWCToNCHW(model_builder, nhwc_input, input); + } + } + const auto& output = node.output(0); float beta = 1.f; int32_t axis = helper.Get("axis", 1); @@ -1087,7 +1229,8 @@ void SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, shaper.Identity(input, output); const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); - model_builder.AddOperation(ANEURALNETWORKS_SOFTMAX, input_indices, {output}, {output_operand_type}); + model_builder.AddOperation(ANEURALNETWORKS_SOFTMAX, input_indices, {output}, + {output_operand_type}, {false}); } #pragma endregion @@ -1110,12 +1253,14 @@ void IdentityOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& input = node.input(0); const auto& output = node.output(0); + bool output_is_nhwc = model_builder.IsOperandNHWC(input); + std::vector input_indices; input_indices.push_back(operand_indices.at(input)); // input shaper.Identity(input, output); const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); - model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type); + model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type, output_is_nhwc); } #pragma endregion @@ -1141,6 +1286,16 @@ bool GemmOpBuilder::IsOpSupportedImpl( const auto& op = node.op_type(); const auto& initializers(model_builder.GetInitializerTensors()); + if (GetShape(model_builder.GetOnnxModel(), node.input(0)).size() != 2) { + LOGS_DEFAULT(VERBOSE) << "A must be 2D"; + return false; + } + + if (GetShape(model_builder.GetOnnxModel(), node.input(0)).size() != 2) { + LOGS_DEFAULT(VERBOSE) << "B must be 2D"; + return false; + } + if (op == "MatMul") { // Only support A*B B is an initializer if (!Contains(initializers, node.input(1))) { LOGS_DEFAULT(VERBOSE) << "B of MatMul must be known"; @@ -1172,7 +1327,10 @@ bool GemmOpBuilder::IsOpSupportedImpl( const auto c_shape = GetShape(model_builder.GetOnnxModel(), node.input(2)); if (c_shape.size() != 1 || c_shape[0] != (transB == 0 ? b_shape[1] : b_shape[0])) { - LOGS_DEFAULT(VERBOSE) << "C of Gemm must be a vector of b_shape[0]"; + LOGS_DEFAULT(VERBOSE) << "C of Gemm must be a vector of b_shape[0]" + << " b_shape: " << Shape2String(b_shape) + << " c_shape: " << Shape2String(c_shape); + return false; } } @@ -1243,8 +1401,8 @@ void GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, shaper.FC(input1, input2, output); const OperandType output_operand_type(operand_types.at(input1).type, shaper[output]); - model_builder.AddOperation(ANEURALNETWORKS_FULLY_CONNECTED, - input_indices, {output}, {output_operand_type}); + model_builder.AddOperation(ANEURALNETWORKS_FULLY_CONNECTED, input_indices, {output}, + {output_operand_type}, {false}); } #pragma endregion @@ -1285,6 +1443,8 @@ void UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& input = node.input(0); const auto& output = node.output(0); + bool output_is_nhwc = model_builder.IsOperandNHWC(input); + shaper.Identity(input, output); const OperandType output_operand_type(operand_types.at(input).type, shaper[output]); @@ -1312,7 +1472,97 @@ void UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } std::vector input_indices; input_indices.push_back(operand_indices.at(input)); - model_builder.AddOperation(op_code, input_indices, {output}, {output_operand_type}); + model_builder.AddOperation(op_code, input_indices, {output}, {output_operand_type}, {output_is_nhwc}); +} + +#pragma endregion + +#pragma region op_concat + +class ConcatOpBuilder : public BaseOpBuilder { + private: + bool IsOpSupportedImpl(ModelBuilder& model_builder, + const ONNX_NAMESPACE::NodeProto& node) override; + + void AddToModelBuilderImpl(ModelBuilder& model_builder, + const ONNX_NAMESPACE::NodeProto& node) override; +}; + +bool ConcatOpBuilder::IsOpSupportedImpl(ModelBuilder& model_builder, + const ONNX_NAMESPACE::NodeProto& node) { + if (GetShape(model_builder.GetOnnxModel(), node.input(0)).size() > 4) { + LOGS_DEFAULT(VERBOSE) << "Concat supports at most 4D shape"; + return false; + } + + return true; +} + +void ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const ONNX_NAMESPACE::NodeProto& node) { + auto& shaper(model_builder.GetShaper()); + const auto& operand_indices(model_builder.GetOperandIndices()); + const auto& operand_types(model_builder.GetOperandTypes()); + NodeAttrHelper helper(node); + + std::vector input_indices; + const auto& input0 = node.input(0); + bool all_input_have_same_layout = true; + bool output_is_nhwc = false; + + // First we want to see if all the input are smae layout + for (int i = 0; i < node.input_size() - 1; i++) { + all_input_have_same_layout = + all_input_have_same_layout && + model_builder.IsOperandNHWC(node.input(i)) == model_builder.IsOperandNHWC(node.input(i + 1)); + } + + std::vector inputs; + inputs.reserve(node.input_size()); + if (all_input_have_same_layout) { + // if all the inputs are of same layout, output will be the same layout + if (model_builder.IsOperandNHWC(input0)) { + output_is_nhwc = true; + } + + for (const auto& input : node.input()) { + input_indices.push_back(operand_indices.at(input)); + inputs.push_back(input); + } + } else { + // if all the inputs are not same layout, + // will need transpos those nhwc tensors back to nchw + for (auto input : node.input()) { + if (model_builder.IsOperandNHWC(input)) { + std::string nhwc_input = input; + input = model_builder.GetUniqueName(input + "_nhwc_to_nchw"); + TransposeNHWCToNCHW(model_builder, nhwc_input, input); + } + input_indices.push_back(operand_indices.at(input)); + inputs.push_back(input); + } + } + + int32_t axis = helper.Get("axis", 1); + int rank = shaper[input0].size(); + if (axis < 0) { // NNAPI does not support negative axis + axis = rank + axis; + } + + if (output_is_nhwc) { + ORT_ENFORCE(rank == 4, "nhwc is only on 4d shape, input " + input0 + + " has rank: " + std::to_string(rank)); + // we are using nhwc here, but the axis is in nwhw, need to transpose axis from nchw to nhwc + const uint32_t axis_nchw_to_nhwc[4]{0, 3, 1, 2}; + axis = axis_nchw_to_nhwc[axis]; + } + input_indices.push_back(model_builder.AddOperandFromScalar(axis)); + + const auto& output = node.output(0); + shaper.Concat(inputs, axis, output); + const OperandType output_operand_type(operand_types.at(input0).type, shaper[output]); + model_builder.AddOperation(ANEURALNETWORKS_CONCATENATION, input_indices, {output}, + {output_operand_type}, {output_is_nhwc}); } #pragma endregion @@ -1368,6 +1618,8 @@ CreateOpBuilders() { op_map.emplace("Tanh", unary_op_builder); } + op_map.emplace("Concat", std::make_shared()); + return op_map; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.h index 128972768c..d0901ee1c1 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.h @@ -31,5 +31,9 @@ class IOpBuilder { std::unordered_map> CreateOpBuilders(); +void TransposeNHWCToNCHW(ModelBuilder& model_builder, + const std::string& input, + const std::string& output); + } // namespace nnapi -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.cc index ad2217aee6..3f9a2251d7 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.cc @@ -345,6 +345,39 @@ void Shaper::FC(const std::string& input1_name, const std::string& input2_name, } } +void Shaper::Concat(const std::vector& input_names, + const int32_t axis, + const std::string& output_name) { + std::vector dimens; + for (const auto& input_name : input_names) { + auto& dimen = shape_map_.at(input_name); + if (!dimens.empty()) { + for (size_t i = 0; i < dimens[0].size(); i++) { + if ((int32_t)i == axis) + continue; + + ORT_ENFORCE(dimen[i] == dimens[0][i], "Wrong input for concat"); + } + } + + dimens.push_back(shape_map_.at(input_name)); + } + + auto output_dimen = dimens[0]; + for (size_t i = 1; i < dimens.size(); i++) { + output_dimen[axis] += dimens[i][axis]; + } + + shape_map_[output_name] = output_dimen; + + if (!shaper_finalized_) { + shape_ops_.push_back( + [input_names, axis, output_name](Shaper& shaper) { + shaper.Concat(input_names, axis, output_name); + }); + } +} + void Shaper::AddShape(const std::string& name, const Shape& shape) { shape_map_[name] = shape; } @@ -354,10 +387,12 @@ void Shaper::UpdateShape(const std::string& name, const Shape& new_shape) { "Cannot UpdateShape while shaper is not finalized"); const auto& old_shape = shape_map_.at(name); - if (old_shape != new_shape && Product(shape_map_.at(name)) != 0) - ORT_THROW("The shape should be same size or old shape has size 0"); + if (old_shape != new_shape) { + if (Product(old_shape) != 0) + ORT_THROW("The shape should be same size or old shape has size 0 (dynamic shape)"); - shape_map_[name] = new_shape; + shape_map_[name] = new_shape; + } } void Shaper::UpdateDynamicDimensions() { @@ -373,3 +408,13 @@ void Shaper::Clear() { shape_map_.clear(); shape_ops_.clear(); } + +std::string Shape2String(const Shaper::Shape& shape) { + std::ostringstream os; + os << "[ "; + for (const auto& dim : shape) + os << dim << " "; + + os << "]"; + return os.str(); +} diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.h index 6a4818bcce..862c6134d9 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/shaper.h @@ -47,6 +47,10 @@ class Shaper { const std::string& input2_name, const std::string& output_name); + void Concat(const std::vector& input_names, + const int32_t axis, + const std::string& output_name); + // If the shape of certain input is dynamic // Use the following 2 functions to update the particular shape // and calculate the new output shape @@ -68,3 +72,5 @@ class Shaper { std::unordered_map shape_map_; std::vector> shape_ops_; }; + +std::string Shape2String(const Shaper::Shape& shape); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/model.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/model.cc index 4fcd7b6bb3..a22dcf22cc 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/model.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/model.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "model.h" #include "core/providers/nnapi/nnapi_builtin/builders/helper.h" #include "core/providers/nnapi/nnapi_builtin/nnapi_lib/nnapi_implementation.h" @@ -25,12 +27,18 @@ Model::~Model() { void Model::AddInput(const std::string& name, const android::nn::wrapper::OperandType& operand_type) { input_names_.push_back(name); - operand_types_.insert({name, operand_type}); + operand_types_.emplace(name, operand_type); } -void Model::AddOutput(const std::string& name, const android::nn::wrapper::OperandType& operand_type) { - output_names_.push_back(name); - operand_types_.insert({name, operand_type}); +void Model::AddOutput(const std::string& onnx_output_name, + const std::string& nnapi_output_name, + const android::nn::wrapper::OperandType& operand_type) { + LOGS_DEFAULT(VERBOSE) << "Model::AddOutput output name " << onnx_output_name + << " shape " << Shape2String(operand_type.dimensions); + + output_names_.push_back(onnx_output_name); + onnx_to_nnapi_output_map_.emplace(onnx_output_name, nnapi_output_name); + operand_types_.emplace(nnapi_output_name, operand_type); } const std::vector& Model::GetInputs() const { @@ -46,9 +54,10 @@ const android::nn::wrapper::OperandType& Model::GetInputType(const std::string& } const android::nn::wrapper::OperandType Model::GetOutputType(const std::string& name) const { - const auto& output_type = operand_types_.at(name); + const auto& nnapi_output_name = onnx_to_nnapi_output_map_.at(name); + const auto& output_type = operand_types_.at(nnapi_output_name); android::nn::wrapper::OperandType type( - output_type.type, shaper_for_exeuction_[name], output_type.operandType.scale, output_type.operandType.zeroPoint); + output_type.type, shaper_for_exeuction_[nnapi_output_name], output_type.operandType.scale, output_type.operandType.zeroPoint); return type; } @@ -79,6 +88,9 @@ void Model::SetInputBuffer(const int32_t index, const InputBuffer& input) { void Model::SetOutputBuffer(const int32_t index, const OutputBuffer& output) { PrepareForExecution(); + LOGS_DEFAULT(VERBOSE) << "Model::SetOutputBuffer, output shape " + << Shape2String(output.type.dimensions); + THROW_ON_ERROR(nnapi_->ANeuralNetworksExecution_setOutput( execution_, index, &output.type.operandType, output.buffer, output.type.GetOperandBlobByteSize())); } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/model.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/model.h index 985c02eb66..75c3fbb0f9 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/model.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/model.h @@ -4,6 +4,7 @@ #pragma once #include "builders/shaper.h" +#include "core/platform/ort_mutex.h" #include "nnapi_lib/NeuralNetworksWrapper.h" namespace onnxruntime { @@ -90,6 +91,9 @@ class Model { // Execute the NNAPI model void Predict(); + // Mutex for exclusive lock to this model object + OrtMutex& GetMutex() { return mutex_; } + private: const NnApi* nnapi_{nullptr}; bool prepared_for_exe_ = false; @@ -112,9 +116,21 @@ class Model { std::unordered_map input_map_; std::unordered_map output_map_; + // We may transpose the nnapi output to nchw with a different name + // This is map is to lookup the nnapi output from the onnx output + std::unordered_map onnx_to_nnapi_output_map_; + + OrtMutex mutex_; + Model(); void AddInput(const std::string& name, const android::nn::wrapper::OperandType& operand_type); - void AddOutput(const std::string& name, const android::nn::wrapper::OperandType& operand_type); + + // It is possible that the actual output from NNAPI model is not the same as the name of + // the output from the onnx model, need to have both names and add mapping between them + void AddOutput(const std::string& onnx_output_name, + const std::string& nnapi_output_name, + const android::nn::wrapper::OperandType& operand_type); + void SetShaper(const Shaper shaper) { shaper_ = shaper; } void SetInputBuffer(const int32_t index, const InputBuffer& input); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index 93b0c4f6ce..91201cb517 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -66,12 +66,12 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view onnxruntime::Graph& graph_build = model.MainGraph(); for (const auto& node : graph_view.Nodes()) { std::vector inputs, outputs; - for (auto input : node.InputDefs()) { + for (auto* input : node.InputDefs()) { auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); inputs.push_back(&n_input); all_node_inputs.insert(input->Name()); } - for (auto output : node.OutputDefs()) { + for (auto* output : node.OutputDefs()) { auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); outputs.push_back(&n_output); } @@ -113,9 +113,9 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view for (const auto& index : group) { sub_graph->nodes.push_back(node_index[index]); - const auto node = graph_view.GetNode(node_index[index]); + const auto* node = graph_view.GetNode(node_index[index]); - for (const auto& input : node->InputDefs()) { + for (const auto* input : node->InputDefs()) { const auto it = fused_outputs.find(input); if (it != fused_outputs.end()) { fused_outputs.erase(it); @@ -135,7 +135,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view std::unordered_set processed_outputs; for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { const auto node_idx = it->GetNode().Index(); - const auto output = node->OutputDefs()[it->GetSrcArgIndex()]; + const auto* output = node->OutputDefs()[it->GetSrcArgIndex()]; if (node_set.find(node_idx) != node_set.end()) { const auto iter = fused_inputs.find(output); @@ -152,7 +152,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view processed_outputs.insert(output); } - for (const auto& output : node->OutputDefs()) { + for (const auto* output : node->OutputDefs()) { if (processed_outputs.find(output) != processed_outputs.end()) continue; @@ -209,13 +209,6 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view return result; } -std::string GetShape(const std::vector& dimensions) { - std::string ret = ""; - for (auto dim : dimensions) - ret += std::to_string(dim) + " "; - return "[" + ret + "]"; -} - common::Status NnapiExecutionProvider::Compile(const std::vector& fused_nodes, std::vector& node_compute_funcs) { using namespace android::nn::wrapper; @@ -235,6 +228,8 @@ common::Status NnapiExecutionProvider::Compile(const std::vector nnapi_model = builder.Compile(); // Build map from input name to its index in input definitions @@ -275,8 +270,6 @@ common::Status NnapiExecutionProvider::Compile(const std::vector(state); const size_t num_inputs = ort.KernelContext_GetInputCount(context); const size_t num_outputs = ort.KernelContext_GetOutputCount(context); @@ -316,45 +309,50 @@ common::Status NnapiExecutionProvider::Compile(const std::vectorSetInputBuffers(inputs); - std::vector outputs; - outputs.reserve(num_outputs); - for (size_t i = 0; i < num_outputs; i++) { - const auto output_name = model->GetOutputs()[i]; - const auto model_output_type = model->GetOutputType(output_name); - const auto output_shape = model_output_type.dimensions; + // From this point we will need to take the exclusive lock on the model until the Predict is + // performed, to block other threads (if any) to modify this particular model + { + std::unique_lock lock(model->GetMutex()); + model->SetInputBuffers(inputs); + std::vector outputs; + outputs.reserve(num_outputs); + for (size_t i = 0; i < num_outputs; i++) { + const auto output_name = model->GetOutputs()[i]; + const auto model_output_type = model->GetOutputType(output_name); + const auto output_shape = model_output_type.dimensions; - std::vector int64_output_shape(output_shape.begin(), - output_shape.end()); - auto output_idx = model->GetMappedOutputIdx(output_name); - auto* output_tensor = ort.KernelContext_GetOutput(context, output_idx, - int64_output_shape.data(), - int64_output_shape.size()); + std::vector int64_output_shape(output_shape.begin(), + output_shape.end()); + auto output_idx = model->GetMappedOutputIdx(output_name); + auto* output_tensor = ort.KernelContext_GetOutput(context, output_idx, + int64_output_shape.data(), + int64_output_shape.size()); - void* output_buffer = nullptr; - switch (model_output_type.type) { - case Type::TENSOR_FLOAT32: - output_buffer = ort.GetTensorMutableData(output_tensor); - break; - case Type::TENSOR_INT32: - output_buffer = ort.GetTensorMutableData(output_tensor); - break; - default: - ORT_THROW("Unsupported output type: " + TypeToStr(model_output_type.type)); - break; + void* output_buffer = nullptr; + switch (model_output_type.type) { + case Type::TENSOR_FLOAT32: + output_buffer = ort.GetTensorMutableData(output_tensor); + break; + case Type::TENSOR_INT32: + output_buffer = ort.GetTensorMutableData(output_tensor); + break; + default: + return Status(common::ONNXRUNTIME, common::FAIL, + "Unsupported output type: " + TypeToStr(model_output_type.type)); + break; + } + + if (model_output_type.GetOperandBlobByteSize() == 0) { + return Status(common::ONNXRUNTIME, common::FAIL, "We do not support dynamic output shape for now"); + } + + outputs.push_back({output_buffer, std::move(model_output_type)}); } - if (model_output_type.GetOperandBlobByteSize() == 0) { - return Status(common::ONNXRUNTIME, common::FAIL, "We do not support dynamic output shape for now"); - } - - outputs.push_back({output_buffer, std::move(model_output_type)}); + model->SetOutputBuffers(outputs); + model->Predict(); } - model->SetOutputBuffers(outputs); - - model->Predict(); - return Status::OK(); }; From 35a048ef9b9e840abde5e04cf335af09aaac8729 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Mon, 29 Jun 2020 14:27:06 -0700 Subject: [PATCH 03/13] Ignore one failed test in DML (#4366) 2020-06-29 08:51:32.9157882 [E:onnxruntime:Default, runner.cc:452 DataRunner::RunTaskImpl] keras2coreml_Dense_ImageNet:output=output1:expected 0.233292 (3e6ee400), got 0.231783 (3e6d587b), diff: 0.00150879, tol=0.00123329 idx=52. 1 of 255 differ --- onnxruntime/test/onnx/main.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 3ae3ebee4b..db82906abe 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -433,7 +433,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { static const ORTCHAR_T* cuda_flaky_tests[] = { ORT_TSTR("fp16_inception_v1"), ORT_TSTR("fp16_shufflenet"), ORT_TSTR("fp16_tiny_yolov2")}; - static const ORTCHAR_T* dml_disabled_tests[] = {ORT_TSTR("mlperf_ssd_resnet34_1200"), ORT_TSTR("mlperf_ssd_mobilenet_300"), ORT_TSTR("mask_rcnn"), ORT_TSTR("faster_rcnn"), ORT_TSTR("tf_pnasnet_large"), ORT_TSTR("zfnet512")}; + static const ORTCHAR_T* dml_disabled_tests[] = {ORT_TSTR("mlperf_ssd_resnet34_1200"), ORT_TSTR("mlperf_ssd_mobilenet_300"), ORT_TSTR("mask_rcnn"), ORT_TSTR("faster_rcnn"), ORT_TSTR("tf_pnasnet_large"), ORT_TSTR("zfnet512"), ORT_TSTR("keras2coreml_Dense_ImageNet")}; static const ORTCHAR_T* dnnl_disabled_tests[] = {ORT_TSTR("test_densenet121"), ORT_TSTR("test_resnet18v2"), ORT_TSTR("test_resnet34v2"), ORT_TSTR("test_resnet50v2"), ORT_TSTR("test_resnet101v2"), ORT_TSTR("test_resnet101v2"), ORT_TSTR("test_vgg19"), ORT_TSTR("tf_inception_resnet_v2"), ORT_TSTR("tf_inception_v1"), ORT_TSTR("tf_inception_v3"), ORT_TSTR("tf_inception_v4"), ORT_TSTR("tf_mobilenet_v1_1.0_224"), ORT_TSTR("tf_mobilenet_v2_1.0_224"), ORT_TSTR("tf_mobilenet_v2_1.4_224"), ORT_TSTR("tf_nasnet_large"), ORT_TSTR("tf_pnasnet_large"), ORT_TSTR("tf_resnet_v1_50"), ORT_TSTR("tf_resnet_v1_101"), ORT_TSTR("tf_resnet_v1_101"), From 465140b3842289ddc86d908bffe47c0701037059 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Mon, 29 Jun 2020 16:07:42 -0700 Subject: [PATCH 04/13] Misc fixes to Conv and ConvTranspose CUDA kernels (#4281) --- onnxruntime/core/providers/cuda/nn/conv.cc | 17 ++-- .../core/providers/cuda/nn/conv_transpose.cc | 88 +++++++++++-------- 2 files changed, 60 insertions(+), 45 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index f7004ee2be..0cccf3b41d 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -95,10 +95,6 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { Tensor* Y = context->Output(0, TensorShape(s_.y_dims)); y_data = reinterpret_cast(Y->template MutableData()); - // special case when there is a dim value of 0 in the shape. - if (Y->Shape().Size() == 0) - return Status::OK(); - std::vector x_dims_cudnn = x_dims; std::vector y_dims_cudnn = y_dims; if (rank < 2) { @@ -112,12 +108,21 @@ Status Conv::ComputeInternal(OpKernelContext* context) const { strides.push_back(1); dilations.push_back(1); } - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); if (w_dims_changed) ORT_RETURN_IF_ERROR(s_.filter_desc.Set(w_dims, CudnnTensor::GetDataType())); + // Special case when there is a dim value of 0 in the shape. + // Return only after we have cached the following for subsequent runs : + // 1) `w_dims` in the `filter_desc` + // 2) `y_dims` in s_.y_dims + if (Y->Shape().Size() == 0) { + return Status::OK(); + } + + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType())); + cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, mode, CudnnTensor::GetDataType())); diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 3fe3a53f71..e258d8d685 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -67,7 +67,7 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ { std::lock_guard lock(s_.mutex); - // TODO: add a global cache if need to handle cases for multiple frames running simultaneuously with different batch_size + // TODO: add a global cache if need to handle cases for multiple frames running simultaneously with different batch_size bool input_dims_changed = (s_.last_x_dims != x_dims); bool w_dims_changed = (s_.last_w_dims != w_dims); if (input_dims_changed || w_dims_changed) { @@ -82,11 +82,6 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ ConvTransposeAttributes::Prepare p; ORT_RETURN_IF_ERROR(conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding)); - // Bail out early if one of the dimensions is zero. - if (p.Y->Shape().Size() == 0) { - return Status::OK(); - } - auto y_dims = p.Y->Shape().GetDims(); if (x_dimensions == 3) { y_dims.insert(y_dims.begin() + 2, 1); @@ -98,12 +93,20 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ } s_.y_dims = y_dims; - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType())); - if (w_dims_changed) ORT_RETURN_IF_ERROR(s_.filter_desc.Set(w_dims, CudnnTensor::GetDataType())); + // Special case when there is a dim value of 0 in the shape. + // Return only after we have cached the following for subsequent runs : + // 1) `w_dims` in the `filter_desc` + // 2) `y_dims` in s_.y_dims + if (p.Y->Shape().Size() == 0) { + return Status::OK(); + } + + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType())); + cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, mode, CudnnTensor::GetDataType())); @@ -155,42 +158,49 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_ s_.algo = perf.algo; s_.workspace_bytes = perf.memory; } - } - if (!y_data) { - auto y_dims = s_.y_dims; - if (x_dimensions == 3) { - y_dims.erase(y_dims.begin() + 2); + // The following block will be executed in case there has been no change in the shapes of the + // input and the filter compared to the previous run + if (!y_data) { + auto y_dims = s_.y_dims; + if (x_dimensions == 3) { + y_dims.erase(y_dims.begin() + 2); + } + Tensor* Y = context->Output(0, TensorShape(y_dims)); + y_data = reinterpret_cast(Y->template MutableData()); + + // Bail out early if one of the output dimensions is zero. + if (Y->Shape().Size() == 0) { + return Status::OK(); + } } - Tensor* Y = context->Output(0, TensorShape(y_dims)); - y_data = reinterpret_cast(Y->template MutableData()); - } - const auto alpha = Consts::One; - const auto beta = Consts::Zero; + const auto alpha = Consts::One; + const auto beta = Consts::Zero; - IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes); + IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes); - CUDNN_RETURN_IF_ERROR( - cudnnConvolutionBackwardData( - CudnnHandle(), - &alpha, - s_.filter_desc, - w_data, - s_.x_tensor, - x_data, - s_.conv_desc, - s_.algo, - workspace.get(), - s_.workspace_bytes, - &beta, - s_.y_tensor, - y_data)); + CUDNN_RETURN_IF_ERROR( + cudnnConvolutionBackwardData( + CudnnHandle(), + &alpha, + s_.filter_desc, + w_data, + s_.x_tensor, + x_data, + s_.conv_desc, + s_.algo, + workspace.get(), + s_.workspace_bytes, + &beta, + s_.y_tensor, + y_data)); - if (has_bias) { - const Tensor* B = dynamic_padding ? context->Input(3) : context->Input(2); - auto b_data = reinterpret_cast(B->template Data()); - CUDNN_RETURN_IF_ERROR(cudnnAddTensor(CudnnHandle(), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, y_data)); + if (has_bias) { + const Tensor* B = dynamic_padding ? context->Input(3) : context->Input(2); + auto b_data = reinterpret_cast(B->template Data()); + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(CudnnHandle(), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, y_data)); + } } return Status::OK(); From 2601f8e1b48bc8783d150cc6b7389bc357dda101 Mon Sep 17 00:00:00 2001 From: Weixing Zhang Date: Mon, 29 Jun 2020 21:46:13 -0700 Subject: [PATCH 05/13] Support to build CUDA EP for NV Ampere GPU (#4345) Co-authored-by: Weixing Zhang --- cmake/CMakeLists.txt | 10 ++++++---- cmake/onnxruntime_providers.cmake | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 11c51fb763..f807aaeb9c 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -834,9 +834,8 @@ if (onnxruntime_USE_CUDA) string(APPEND CMAKE_CUDA_FLAGS "-cudart shared") endif() enable_language(CUDA) - string(REGEX REPLACE "([0-9]+)\\.([0-9]+).*" "\\1" CUDA_VERSION_MAJOR "${CMAKE_CUDA_COMPILER_VERSION}") - message( STATUS "CUDA_VERSION_MAJOR: ${CUDA_VERSION_MAJOR}") - if (CUDA_VERSION_MAJOR EQUAL 11) + message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}") + if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11) set(CMAKE_CUDA_STANDARD 14) else() set(CMAKE_CUDA_STANDARD 11) @@ -867,13 +866,16 @@ if (onnxruntime_USE_CUDA) list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${ONNXRUNTIME_CUDA_LIBRARIES}) # the following compute capabilities are deprecated in CUDA 11 Toolkit - if (CUDA_VERSION_MAJOR LESS 11) + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_30,code=sm_30") # K series set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_50,code=sm_50") # M series endif() set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_52,code=sm_52") # M60 set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_60,code=sm_60") # P series set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_70,code=sm_70") # V series + if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11) + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_80,code=sm_80") # A series + endif() set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr --default-stream legacy") if (NOT WIN32) set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} --expt-relaxed-constexpr --compiler-options -fPIC") diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 4d9fc96061..ea3c9bbeeb 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -246,7 +246,7 @@ if (onnxruntime_USE_CUDA) set_target_properties(onnxruntime_providers_cuda PROPERTIES LINKER_LANGUAGE CUDA) set_target_properties(onnxruntime_providers_cuda PROPERTIES FOLDER "ONNXRuntime") - if (CUDA_VERSION_MAJOR LESS 11) + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11) target_include_directories(onnxruntime_providers_cuda PRIVATE ${PROJECT_SOURCE_DIR}/external/cub) endif() From 55f25a4bbf67fd26b091ef2758ddfbd0404f0d8e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 29 Jun 2020 23:26:23 -0700 Subject: [PATCH 06/13] Update Attention op to support attention mask for GPT-2 (#4330) * Support another two format of mask_index input: 2D attention mask, or 1D mask index with end and start positions. * Update dynamic axes of gpt2 with past state * Update script to fuse model with attention mask --- onnxruntime/contrib_ops/cpu/bert/attention.cc | 37 +- .../contrib_ops/cpu/bert/attention_cpu_base.h | 42 +- .../contrib_ops/cpu/bert/attention_helper.h | 62 +- .../contrib_ops/cuda/bert/attention.cc | 4 +- .../contrib_ops/cuda/bert/attention_impl.cu | 343 ++++++++--- .../contrib_ops/cuda/bert/attention_impl.h | 29 +- .../quantization/attention_quantization.cc | 1 + .../core/graph/contrib_ops/contrib_defs.cc | 13 +- .../tools/transformers/benchmark_gpt2.py | 62 +- .../transformers/fusion_gpt_attention.py | 46 +- .../test/contrib_ops/attention_op_test.cc | 546 +++++++++++++++++- 11 files changed, 969 insertions(+), 216 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 760e86a1a9..8fc118c7e7 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -48,6 +48,7 @@ Status AttentionBase::CheckInputs(const Tensor* input, dims.size()); } int batch_size = static_cast(dims[0]); + int sequence_length = static_cast(dims[1]); int hidden_size = static_cast(dims[2]); if (hidden_size % num_heads_ != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -74,25 +75,10 @@ Status AttentionBase::CheckInputs(const Tensor* input, } if (bias_dims[0] != weights_dims[1]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 2 dimension 0 should have same length as dimension 1 of input 1"); - } - - if (mask_index != nullptr) { // mask_index is optional - // unidirectional (like GPT2) does not need mask input. Here we do not allowed the input for unidirectional. - if (is_unidirectional_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is not allowed for unidirectional"); - } - - const auto& mask_dims = mask_index->Shape().GetDims(); - if (mask_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1 dimension, got ", - mask_dims.size()); - } - if (static_cast(mask_dims[0]) != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' and 'input' shall have same length at dimension 0"); - } + "Input 'bias' dimension 0 should have same length as dimension 1 of input 'weights'"); } + int past_sequence_length = 0; if (past != nullptr) { // past is optional if (!is_unidirectional_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past' is only allowed for unidirectional"); @@ -115,8 +101,24 @@ Status AttentionBase::CheckInputs(const Tensor* input, if (static_cast(past_dims[4]) != hidden_size / num_heads_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'past' dimension 2 shall have length of ", hidden_size / num_heads_); } + past_sequence_length = static_cast(past_dims[3]); } + if (mask_index != nullptr) { // mask_index is optional + const auto& mask_dims = mask_index->Shape().GetDims(); + if (mask_dims.size() == 1) { + if (static_cast(mask_dims[0]) != batch_size && static_cast(mask_dims[0]) != 2 * batch_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' dimension 0 shall have length of batch_size or 2 * batch_size"); + } + } else if (mask_dims.size() == 2) { + if (static_cast(mask_dims[0]) != batch_size || static_cast(mask_dims[1]) != past_sequence_length + sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with raw attention mask shall have shape batch_size x (past_sequence_length + sequence_length)"); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1 or 2 dimensions, got ", + mask_dims.size()); + } + } return Status::OK(); } @@ -174,7 +176,6 @@ Status Attention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); auto* tp = context->GetOperatorThreadPool(); - // Compute Q, K, V // gemm_data(BS, 3NH) = input(BS, NH) x weights(NH, 3NH) + bias(3NH) auto gemm_data = allocator->Alloc(SafeInt(batch_size) * sequence_length * 3 * hidden_size * element_size); diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index 1f01504ded..4252098dbf 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -48,26 +48,21 @@ class AttentionCPUBase : public AttentionBase { auto attention_probs = allocator->Alloc(attention_probs_bytes); BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator)); - size_t mask_data_bytes = 0; - if (mask_index != nullptr) { - mask_data_bytes = SafeInt(batch_size) * sequence_length * all_sequence_length * sizeof(T); - } else if (is_unidirectional_) { - mask_data_bytes = SafeInt(sequence_length) * all_sequence_length * sizeof(T); - } - void* mask_data = nullptr; - if (mask_data_bytes > 0) { + if (mask_index != nullptr || (is_unidirectional_ && sequence_length > 1)) { + size_t mask_data_bytes = SafeInt(batch_size) * sequence_length * all_sequence_length * sizeof(T); mask_data = allocator->Alloc(mask_data_bytes); memset(mask_data, 0, mask_data_bytes); } BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator)); const int32_t* mask_index_data = mask_index != nullptr ? mask_index->template Data() : nullptr; + const std::vector* mask_index_dims = mask_index != nullptr ? &(mask_index->Shape().GetDims()) : nullptr; const T* past_data = past != nullptr ? past->template Data() : nullptr; T* present_data = present != nullptr ? present->template MutableData() : nullptr; ComputeAttentionProbs(static_cast(attention_probs), Q, K, - mask_index_data, static_cast(mask_data), + mask_index_data, mask_index_dims, static_cast(mask_data), batch_size, sequence_length, past_sequence_length, head_size, past_data, present_data, tp); @@ -89,17 +84,18 @@ class AttentionCPUBase : public AttentionBase { // 1 x mask_data(B, N, S, S*) // II.attention_probs(B, N, S, S*) = Softmax(attention_probs) template - void ComputeAttentionProbs(T* attention_probs, // output buffer for the attention probs. Its size is BxNxSxS - const T* Q, // Q data. Its size is BxNxSxH - const T* K, // k data. Its size is BxNxSxH - const int32_t* mask_index, // mask index. nullptr if no mask or its size is B - T* mask_data, // buffer for mask data. Its size is: SxS* if is_unidirectional_; BxSxS* if mask_index; null otherwise - int batch_size, // batch size of self-attention - int sequence_length, // sequence length of self-attention - int past_sequence_length, // sequence length of past state - int head_size, // head size of self-attention - const T* past, // past state - T* present, // present state + void ComputeAttentionProbs(T* attention_probs, // output buffer for the attention probs. Its size is BxNxSxS + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxSxH + const int32_t* mask_index, // mask index. nullptr if no mask or its size is B + const std::vector* mask_index_dims, // mask index shape + T* mask_data, // buffer for mask data. Its size is: SxS* if is_unidirectional_; BxSxS* if mask_index; null otherwise + int batch_size, // batch size of self-attention + int sequence_length, // sequence length of self-attention + int past_sequence_length, // sequence length of past state + int head_size, // head size of self-attention + const T* past, // past state + T* present, // present state ThreadPool* tp) const { const int all_sequence_length = past_sequence_length + sequence_length; // S* = S' + S const size_t past_chunk_length = static_cast(past_sequence_length * head_size); // S' x H @@ -108,7 +104,7 @@ class AttentionCPUBase : public AttentionBase { { if (mask_data != nullptr) { - PrepareMask(mask_index, mask_data, is_unidirectional_, batch_size, sequence_length, past_sequence_length); + PrepareMask(mask_index, mask_index_dims, mask_data, is_unidirectional_, batch_size, sequence_length, past_sequence_length); } else { // no any mask memset(attention_probs, 0, batch_size * num_heads_ * sequence_length * all_sequence_length * sizeof(T)); } @@ -123,9 +119,9 @@ class AttentionCPUBase : public AttentionBase { for (std::ptrdiff_t i = begin; i != end; ++i) { const std::ptrdiff_t batch_index = i / num_heads_; - // broadcast mask data: SxS* or (Bx)SxS* -> (BxNx)SxS* + // broadcast mask data: (Bx)SxS* -> (BxNx)SxS* if (mask_data != nullptr) { - const T* broadcast_data_src = is_unidirectional_ ? reinterpret_cast(mask_data) : reinterpret_cast(mask_data) + batch_index * sequence_length * all_sequence_length; + const T* broadcast_data_src = reinterpret_cast(mask_data) + batch_index * sequence_length * all_sequence_length; T* broadcast_data_dest = reinterpret_cast(attention_probs) + sequence_length * all_sequence_length * i; memcpy(broadcast_data_dest, broadcast_data_src, sequence_length * all_sequence_length * sizeof(T)); } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index 7507cb2ab2..5acb8c2a30 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -61,38 +61,66 @@ inline void ComputeAttentionSoftmaxInplace(float* score, int N, int D, ThreadPoo template void PrepareMask(const int32_t* mask_index, + const std::vector* mask_index_dims, T* mask_data, bool is_unidirectional, int batch_size, int sequence_length, int past_sequence_length) { const int all_sequence_length = past_sequence_length + sequence_length; - T* p_mask = mask_data; - if (is_unidirectional) { - // unidirectional mask has shape SxS* - for (int s_i = 0; s_i < sequence_length - 1; s_i++) { - for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) { - p_mask[s_i * all_sequence_length + m_i] = static_cast(-10000.0); - } - } - return; - } - ORT_ENFORCE(mask_index, "mask index should not be null."); + // mask_data has been filled with 0, and its shape is BxSxS* + T* p_mask = mask_data; + + bool is_raw_attention_mask = (nullptr != mask_index_dims && mask_index_dims->size() == 2); + bool has_mask_start_position = (nullptr != mask_index_dims && mask_index_dims->size() == 1 && static_cast(mask_index_dims->at(0)) == 2 * batch_size); + for (int b_i = 0; b_i < batch_size; b_i++) { // TODO: mask_index can be used in softmax to save some calculation. - // Convert mask_index to mask (-10000 means out of range, which will be 0 after softmax): B => BxS* - int valid_length = mask_index[b_i]; - for (int m_i = valid_length; m_i < all_sequence_length; m_i++) { - p_mask[m_i] = static_cast(-10000.0); + + if (nullptr != mask_index) { + if (is_raw_attention_mask) { + // Raw attention mask has value 0 or 1. Here we convert 0 to -10000.0, and 1 to 0.0. + const int32_t* raw_mask = mask_index + b_i * all_sequence_length; + for (int m_i = 0; m_i < all_sequence_length; m_i++) { + p_mask[m_i] = (raw_mask[m_i] > 0) ? static_cast(0.0f) : static_cast(-10000.0f); + } + } else { + // mask_index is 1D: (B) or (2B) => (Bx)S* + + // Handle right-side padding: mask value at or after the end position will be -10000.0 + int end_position = mask_index[b_i]; + for (int m_i = end_position; m_i < all_sequence_length; m_i++) { + p_mask[m_i] = static_cast(-10000.0f); + } + + // Handle left-side padding: mask value before the start position will be -10000.0 + if (has_mask_start_position) { + int start_position = std::min(mask_index[b_i + batch_size], all_sequence_length); + for (int m_i = 0; m_i < start_position; m_i++) { + p_mask[m_i] = static_cast(-10000.0f); + } + } + } } - // Broadcast mask from BxS* to BxSxS* + // Broadcast mask from (Bx)S* to (Bx)SxS* for (int s_i = 1; s_i < sequence_length; s_i++) { memcpy(p_mask + s_i * all_sequence_length, p_mask, all_sequence_length * sizeof(T)); } - p_mask += sequence_length * sequence_length; + + // Apply unidirectional mask. + if (is_unidirectional) { + for (int s_i = 0; s_i < sequence_length - 1; s_i++) { + for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) { + p_mask[s_i * all_sequence_length + m_i] += static_cast(-10000.0f); + } + } + } + + p_mask += sequence_length * all_sequence_length; } + } // Concatenate a past state chunk S'xH with input state chunk SxH into present state chunk S*xH diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 158d0e8cc0..a6b3f0b3b0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -89,6 +89,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (!LaunchAttentionKernel( reinterpret_cast(gemm_buffer.get()), nullptr == mask_index ? nullptr : mask_index->template Data(), + nullptr == mask_index ? nullptr : &(mask_index->Shape().GetDims()), output->template MutableData(), batch_size, sequence_length, @@ -100,8 +101,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { is_unidirectional_, past_sequence_length, nullptr == past ? nullptr : past->template Data(), - nullptr == present ? nullptr : present->template MutableData() - )) { + nullptr == present ? nullptr : present->template MutableData())) { // Get last error to reset it to cudaSuccess. CUDA_CALL(cudaGetLastError()); return Status(common::ONNXRUNTIME, common::FAIL); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 9e7a61b518..54206eb5fc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -40,8 +40,8 @@ static size_t AlignTo(size_t a, size_t b) { return CeilDiv(a, b) * b; } -size_t ScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, int past_sequence_length) { - const size_t len = batch_size * num_heads * sequence_length * (sequence_length + past_sequence_length); +size_t ScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, int all_sequence_length) { + const size_t len = batch_size * num_heads * sequence_length * all_sequence_length; const size_t bytes = len * element_size; const size_t alignment = 256; @@ -57,11 +57,16 @@ size_t GetAttentionWorkspaceSize( int sequence_length, int past_sequence_length) { size_t qkv_size = 3 * batch_size * sequence_length * num_heads * head_size * element_size; - return qkv_size + 2 * ScratchSize(element_size, batch_size, num_heads, sequence_length, past_sequence_length); + return qkv_size + 2 * ScratchSize(element_size, batch_size, num_heads, sequence_length, past_sequence_length + sequence_length); } template -__device__ inline void Softmax(const int past_sequence_length, const int sequence_length, const int valid_length, const T* input, T* output, bool is_unidirectional) { +__device__ inline void Softmax(const int all_sequence_length, + const int sequence_length, + const int valid_end, + const int valid_start, + const T* input, + T* output) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp_storage; @@ -70,18 +75,17 @@ __device__ inline void Softmax(const int past_sequence_length, const int sequenc float thread_data_max(-CUDART_INF_F); - const int num_valid = is_unidirectional ? past_sequence_length + (blockIdx.x % sequence_length) + 1 : valid_length; - // e^x is represented as infinity if x is large enough, like 100.f. // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. // a math transform as below is leveraged to get a stable softmax: // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - const int all_sequence_length = past_sequence_length + sequence_length; const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - for (int i = threadIdx.x; i < num_valid; i += TPB) { - const int index = offset + i; - if (thread_data_max < float(input[index])) { - thread_data_max = float(input[index]); + for (int i = threadIdx.x; i < valid_end; i += TPB) { + if (i >= valid_start) { + const int index = offset + i; + if (thread_data_max < float(input[index])) { + thread_data_max = float(input[index]); + } } } @@ -94,10 +98,12 @@ __device__ inline void Softmax(const int past_sequence_length, const int sequenc __syncthreads(); float thread_data_sum(0.f); - for (int i = threadIdx.x; i < num_valid; i += TPB) { - const int index = offset + i; - const float val = input[index]; - thread_data_sum += expf(val - max_block); + for (int i = threadIdx.x; i < valid_end; i += TPB) { + if (i >= valid_start) { + const int index = offset + i; + const float val = input[index]; + thread_data_sum += expf(val - max_block); + } } const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, cub::Sum()); @@ -108,13 +114,19 @@ __device__ inline void Softmax(const int past_sequence_length, const int sequenc for (int i = threadIdx.x; i < all_sequence_length; i += TPB) { const int index = offset + i; - const float val = (i < num_valid) ? expf(float(input[index]) - max_block) * sum_reverse_block : 0.f; + const float val = (i >= valid_start && i < valid_end) ? expf(float(input[index]) - max_block) * sum_reverse_block : 0.f; output[index] = T(val); } } template -__device__ inline void SoftmaxSmall(const int past_sequence_length, const int sequence_length, const int valid_length, const T* input, T* output, bool is_unidirectional) { +__device__ inline void SoftmaxSmall(const int all_sequence_length, + const int sequence_length, + const int valid_end, + const int valid_start, + const T* input, + T* output, + bool is_unidirectional) { using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage tmp_storage; @@ -122,22 +134,32 @@ __device__ inline void SoftmaxSmall(const int past_sequence_length, const int se __shared__ float max_block; // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - const int all_sequence_length = past_sequence_length + sequence_length; const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; const int index = offset + threadIdx.x; - const int num_valid = is_unidirectional ? past_sequence_length + (blockIdx.x % sequence_length) + 1 : valid_length; + bool is_valid = false; // whether it has attention mask == 1. + + // Update end position for unidirectional. + int end = valid_end; + if (is_unidirectional) { + int end_unid = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1; + if (end_unid <= valid_start) { + // In this situation, mask of [0, end_unid) and [valid_start, valid_end) has -10000, and [end_unid, valid_start) and [valid_end, all_seq_len) has -20000. + // So [0, end_unid) will also have value after softmax. + is_valid = threadIdx.x < end_unid; + } else { + end = min(valid_end, end_unid); + } + } + + is_valid = is_valid || (threadIdx.x >= valid_start && threadIdx.x < end); // e^x is represented as infinity if x is large enough, like 100.f. // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. // a math transform as below is leveraged to get a stable softmax: // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - float thread_data_max(-CUDART_INF_F); - if (threadIdx.x < num_valid) { - thread_data_max = input[index]; - } - - const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), num_valid); + float thread_data_max = is_valid ? float(input[index]) : float(-CUDART_INF_F); + const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end); // Store max value if (threadIdx.x == 0) { @@ -146,106 +168,239 @@ __device__ inline void SoftmaxSmall(const int past_sequence_length, const int se __syncthreads(); float thread_data_exp(0.f); - if (threadIdx.x < num_valid) { - const float val = input[index]; - thread_data_exp = expf(val - max_block); + if (is_valid) { + thread_data_exp = expf(float(input[index]) - max_block); } - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), num_valid); + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), end); - // Store max value + // Store value of 1.0/sum. if (threadIdx.x == 0) { - sum_reverse_block = (num_valid == 0) ? 0.f : (1.f) / sum; + sum_reverse_block = (1.f) / sum; } __syncthreads(); // threadIdx.x might be larger than all_sequence_length due to alignment to 32x. if (threadIdx.x < all_sequence_length) { - // this will be 0 for threadIdx.x >= num_valid output[index] = T(thread_data_exp * sum_reverse_block); } } template -__global__ void SoftmaxKernelSmall(const int past_sequence_length, const int sequence_length, const T* input, T* output, bool is_unidirectional) { - SoftmaxSmall(past_sequence_length, sequence_length, sequence_length, input, output, is_unidirectional); +__device__ inline void SoftmaxWithMask2DSmall(const int all_sequence_length, + const int sequence_length, + const int* attention_mask, // 2D attention mask + const T* input, + T* output, + const bool is_unidirectional, + const float scalar) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp_storage; + + __shared__ float sum_reverse_block; + __shared__ float max_block; + + // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; + int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x; + + float thread_data = -CUDART_INF_F; + if (threadIdx.x < all_sequence_length) { + const int& mask = attention_mask[blockIdx.y * all_sequence_length + threadIdx.x]; + float mask_value = mask > 0 ? 0.0f : -10000.0f; + + if (is_unidirectional) { + int from_index = all_sequence_length - sequence_length + (blockIdx.x % sequence_length); // offset of from token in all sequence length. + if (threadIdx.x > from_index) { + mask_value += -10000.0f; + } + } + + thread_data = float(input[index]) * scalar + mask_value; + } + + const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), all_sequence_length); + + // Store max value + if (threadIdx.x == 0) { + max_block = max; + } + __syncthreads(); + + float thread_data_exp = threadIdx.x < all_sequence_length ? expf(thread_data - max_block) : 0.0f; + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), all_sequence_length); + + // Store value of 1.0/sum + if (threadIdx.x == 0) { + sum_reverse_block = (1.f) / sum; + } + __syncthreads(); + + if (threadIdx.x < all_sequence_length) { + output[index] = T(thread_data_exp * sum_reverse_block); + } } template -__global__ void SoftmaxKernel(const int past_sequence_length, const int sequence_length, const T* input, T* output, bool is_unidirectional) { - Softmax(past_sequence_length, sequence_length, sequence_length, input, output, is_unidirectional); +__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const T* input, T* output, bool is_unidirectional) { + SoftmaxSmall(all_sequence_length, sequence_length, all_sequence_length, 0, input, output, is_unidirectional); +} + +template +__global__ void SoftmaxKernel(const int all_sequence_length, const int sequence_length, const T* input, T* output) { + Softmax(all_sequence_length, sequence_length, all_sequence_length, 0, input, output); } template bool ComputeSoftmax( - cudaStream_t stream, const int past_sequence_length, const int sequence_length, const int batch_size, const int num_heads, + cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, const T* input, T* output, bool is_unidirectional) { const dim3 grid(sequence_length * num_heads, batch_size, 1); - if (sequence_length <= 32) { + if (all_sequence_length <= 32) { const int blockSize = 32; - SoftmaxKernelSmall<<>>(past_sequence_length, sequence_length, input, output, is_unidirectional); - } else if (sequence_length <= 128) { + SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); + } else if (all_sequence_length <= 64) { + const int blockSize = 64; + SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); + } else if (all_sequence_length <= 128) { const int blockSize = 128; - SoftmaxKernelSmall<<>>(past_sequence_length, sequence_length, input, output, is_unidirectional); - } else if (sequence_length == 384) { - const int blockSize = 384; - SoftmaxKernelSmall<<>>(past_sequence_length, sequence_length, input, output, is_unidirectional); - } else { + SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); + } else if (all_sequence_length <= 256) { const int blockSize = 256; - SoftmaxKernel<<>>(past_sequence_length, sequence_length, input, output, is_unidirectional); + SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); + } else if (all_sequence_length <= 512) { + const int blockSize = 512; + SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); + } else if (all_sequence_length <= 1024) { + const int blockSize = 1024; + SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); + } else if (!is_unidirectional) { + const int blockSize = 1024; + SoftmaxKernel<<>>(all_sequence_length, sequence_length, input, output); + } else { + ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024."); } return CUDA_CALL(cudaPeekAtLastError()); } template -__global__ void MaskedSoftmaxKernelSmall(const int sequence_length, const int* mask_index, const T* input, T* output) { - __shared__ int num_valid; +__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* input, T* output, bool is_unidirectional) { + __shared__ int start_position; + __shared__ int end_position; if (threadIdx.x == 0) { - num_valid = min(sequence_length, mask_index[blockIdx.y]); + const int batch = blockIdx.y; + start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; + end_position = min(all_sequence_length, mask_end[batch]); + + // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. + if (start_position >= end_position) { + start_position = 0; + end_position = all_sequence_length; + } } __syncthreads(); - SoftmaxSmall(0, sequence_length, num_valid, input, output, false); + SoftmaxSmall(all_sequence_length, sequence_length, end_position, start_position, input, output, is_unidirectional); } template -__global__ void MaskedSoftmaxKernel(const int sequence_length, const int* mask_index, const T* input, T* output) { - __shared__ int num_valid; +__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* input, T* output) { + __shared__ int start_position; + __shared__ int end_position; if (threadIdx.x == 0) { - num_valid = min(sequence_length, mask_index[blockIdx.y]); + const int batch = blockIdx.y; + start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; + end_position = min(all_sequence_length, mask_end[batch]); + + // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. + if (start_position >= end_position) { + start_position = 0; + end_position = all_sequence_length; + } } __syncthreads(); - Softmax(0, sequence_length, num_valid, input, output, false); + Softmax(all_sequence_length, sequence_length, end_position, start_position, input, output); +} + +template +__global__ void SoftmaxWithMask2DSmallKernel(const int all_sequence_length, const int sequence_length, const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar) { + SoftmaxWithMask2DSmall(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar); } template -bool ComputeMaskedSoftmax(cudaStream_t stream, const int sequence_length, const int batch_size, const int num_heads, - const int* mask_index, const T* input, T* output) { - // Mask is of length batch_size and assumes the valid region is contiguous starting - // from the beginning of the sequence - +bool ComputeSoftmaxWithMask1D(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, + const int* mask_index, const int* mask_start, const T* input, T* output, const bool is_unidirectional) { const dim3 grid(sequence_length * num_heads, batch_size, 1); - if (sequence_length <= 32) { + if (all_sequence_length <= 32) { const int blockSize = 32; MaskedSoftmaxKernelSmall - <<>>(sequence_length, mask_index, input, output); - } else if (sequence_length <= 128) { + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); + } else if (all_sequence_length <= 64) { + const int blockSize = 64; + MaskedSoftmaxKernelSmall + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); + } else if (all_sequence_length <= 128) { const int blockSize = 128; MaskedSoftmaxKernelSmall - <<>>(sequence_length, mask_index, input, output); - } else if (sequence_length == 384) { - const int blockSize = 384; - MaskedSoftmaxKernelSmall - <<>>(sequence_length, mask_index, input, output); - } else { + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); + } else if (all_sequence_length <= 256) { const int blockSize = 256; + MaskedSoftmaxKernelSmall + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); + } else if (all_sequence_length <= 512) { + const int blockSize = 512; + MaskedSoftmaxKernelSmall + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); + } else if (all_sequence_length <= 1024) { + const int blockSize = 1024; + MaskedSoftmaxKernelSmall + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); + } else if (!is_unidirectional) { + const int blockSize = 1024; MaskedSoftmaxKernel - <<>>(sequence_length, mask_index, input, output); + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output); + } else { + ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024."); + } + + return CUDA_CALL(cudaPeekAtLastError()); +} + +template +bool ComputeSoftmaxWithMask2D(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, + const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar) { + const dim3 grid(sequence_length * num_heads, batch_size, 1); + + if (all_sequence_length <= 32) { + const int blockSize = 32; + SoftmaxWithMask2DSmallKernel + <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar); + } else if (all_sequence_length <= 64) { + const int blockSize = 64; + SoftmaxWithMask2DSmallKernel + <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar); + } else if (all_sequence_length <= 128) { + const int blockSize = 128; + SoftmaxWithMask2DSmallKernel + <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar); + } else if (all_sequence_length <= 256) { + const int blockSize = 256; + SoftmaxWithMask2DSmallKernel + <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar); + } else if (all_sequence_length <= 512) { + const int blockSize = 512; + SoftmaxWithMask2DSmallKernel + <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar); + } else if (all_sequence_length <= 1024) { + const int blockSize = 1024; + SoftmaxWithMask2DSmallKernel + <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar); + } else { + ORT_THROW("Attention CUDA operator does not supported 2D attention mask with total sequence length > 1024."); } return CUDA_CALL(cudaPeekAtLastError()); @@ -389,7 +544,7 @@ __global__ void ConcatPastToPresent(const int sequence_length, const int n = threadIdx.y; const int s = blockIdx.x; const int b = blockIdx.y; - const int is_v = blockIdx.z; // 0 for k, 1 for v + const int is_v = blockIdx.z; // 0 for k, 1 for v const int all_sequence_length = gridDim.x; const int batch_size = gridDim.y; @@ -409,7 +564,7 @@ __global__ void ConcatPastToPresent(const int sequence_length, const int past_NSH = num_heads * past_SH; const int in_offset = b * past_NSH + n * past_SH + s * H + h + is_v * (past_NSH * batch_size); present[out_offset] = past[in_offset]; -} else if (s < all_sequence_length) { + } else if (s < all_sequence_length) { const int SH = sequence_length * H; const int NSH = num_heads * SH; const int in_offset = b * NSH + n * SH + (s - past_sequence_length) * H + h + is_v * (NSH * batch_size); @@ -418,7 +573,7 @@ __global__ void ConcatPastToPresent(const int sequence_length, } bool LaunchConcatPastToPresent(cudaStream_t stream, - const int past_sequence_length, + const int all_sequence_length, const int sequence_length, const int batch_size, const int head_size, @@ -426,13 +581,11 @@ bool LaunchConcatPastToPresent(cudaStream_t stream, const float* past, const float* k_v, float* present) { - const int all_sequence_length = past_sequence_length + sequence_length; const dim3 grid(all_sequence_length, batch_size, 2); if (0 == (head_size & 1)) { const dim3 block(head_size / 2, num_heads, 1); ConcatPastToPresent<<>>(sequence_length, reinterpret_cast(past), reinterpret_cast(k_v), reinterpret_cast(present)); - } else - { + } else { const dim3 block(head_size, num_heads, 1); ConcatPastToPresent<<>>(sequence_length, past, k_v, present); } @@ -440,7 +593,7 @@ bool LaunchConcatPastToPresent(cudaStream_t stream, } bool LaunchConcatPastToPresent(cudaStream_t stream, - const int past_sequence_length, + const int all_sequence_length, const int sequence_length, const int batch_size, const int head_size, @@ -448,14 +601,13 @@ bool LaunchConcatPastToPresent(cudaStream_t stream, const half* past, const half* k_v, half* present) { - const int all_sequence_length = past_sequence_length + sequence_length; const dim3 grid(all_sequence_length, batch_size, 2); if (0 == (head_size % 4)) { const dim3 block(head_size / 4, num_heads, 1); ConcatPastToPresent<<>>(sequence_length, reinterpret_cast(past), reinterpret_cast(k_v), reinterpret_cast(present)); } else if (0 == (head_size & 1)) { const dim3 block(head_size / 2, num_heads, 1); - ConcatPastToPresent<<>>(sequence_length, reinterpret_cast(past), reinterpret_cast(k_v), reinterpret_cast(present)); + ConcatPastToPresent<<>>(sequence_length, reinterpret_cast(past), reinterpret_cast(k_v), reinterpret_cast(present)); } else { // this should be an "odd" case. probably not worth catching it in the half2 kernel. const dim3 block(head_size, num_heads, 1); ConcatPastToPresent<<>>(sequence_length, past, k_v, present); @@ -486,9 +638,10 @@ bool QkvToContext( cublasHandle_t& cublas, cudaStream_t stream, const int batch_size, const int sequence_length, const int num_heads, const int head_size, const size_t element_size, const T* input, T* output, T* workspace, - const int* mask_index, + const int* mask_index, const std::vector* mask_index_dims, bool is_unidirectional, int past_sequence_length, const T* past, T* present) { - const size_t bytes = ScratchSize(element_size, batch_size, num_heads, sequence_length, past_sequence_length); + const int all_sequence_length = past_sequence_length + sequence_length; + const size_t bytes = ScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length); T* scratch1 = workspace; T* scratch2 = scratch1 + (bytes / element_size); T* scratch3 = scratch2 + (bytes / element_size); @@ -513,9 +666,9 @@ bool QkvToContext( // Concat past (2xBxNxS'xH) to present (2xBxNxS*xH): // past_k (BxNxS'xH) + k (BxNxSxH) => present_k (BxNxS*xH) // past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxS*xH) - const int present_size_per_batch = (past_sequence_length + sequence_length) * head_size; + const int present_size_per_batch = all_sequence_length * head_size; if (nullptr != present) { - if (!LaunchConcatPastToPresent(stream, past_sequence_length, sequence_length, batch_size, head_size, num_heads, past, k, present)) { + if (!LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, head_size, num_heads, past, k, present)) { return false; } @@ -524,24 +677,33 @@ bool QkvToContext( v = present + batches * present_size_per_batch; } + bool use_2d_attention_mask = (nullptr != mask_index && nullptr != mask_index_dims && mask_index_dims->size() == 2); + // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS* // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS* const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); - const int all_sequence_length = past_sequence_length + sequence_length; const int temp_matrix_size = sequence_length * all_sequence_length; + T alpha = (T)(use_2d_attention_mask ? 1.0f : rsqrt_head_size); if (!CUBLAS_CALL(CublasGemmStridedBatched( - cublas, CUBLAS_OP_T, CUBLAS_OP_N, all_sequence_length, sequence_length, head_size, rsqrt_head_size, k, head_size, present_size_per_batch, + cublas, CUBLAS_OP_T, CUBLAS_OP_N, all_sequence_length, sequence_length, head_size, alpha, k, head_size, present_size_per_batch, q, head_size, size_per_batch, 0.f, scratch1, all_sequence_length, temp_matrix_size, batches))) { return false; } // apply softmax and store result P to scratch2: BxNxSxS* - if (nullptr != mask_index) { - if (!ComputeMaskedSoftmax(stream, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2)) { + if (use_2d_attention_mask) { // 2d attention mask + if (!ComputeSoftmaxWithMask2D(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2, is_unidirectional, rsqrt_head_size)) { return false; } - } else { - if (!ComputeSoftmax(stream, past_sequence_length, sequence_length, batch_size, num_heads, scratch1, scratch2, is_unidirectional)) { + } else if (nullptr != mask_index) { // 1d mask index + ORT_ENFORCE(nullptr != mask_index_dims && mask_index_dims->size() == 1); + // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. + const int* mask_start = (mask_index_dims->at(0) > batch_size) ? mask_index + batch_size : nullptr; + if (!ComputeSoftmaxWithMask1D(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, mask_start, scratch1, scratch2, is_unidirectional)) { + return false; + } + } else { // no mask + if (!ComputeSoftmax(stream, all_sequence_length, sequence_length, batch_size, num_heads, scratch1, scratch2, is_unidirectional)) { return false; } } @@ -560,6 +722,7 @@ bool QkvToContext( bool LaunchAttentionKernel( const void* input, const int* mask_index, + const std::vector* mask_index_dims, void* output, const int batch_size, const int sequence_length, @@ -579,13 +742,13 @@ bool LaunchAttentionKernel( return QkvToContext(cublas, stream, batch_size, sequence_length, num_heads, head_size, element_size, reinterpret_cast(input), reinterpret_cast(output), reinterpret_cast(workspace), - mask_index, is_unidirectional, + mask_index, mask_index_dims, is_unidirectional, past_sequence_length, reinterpret_cast(past), reinterpret_cast(present)); } else { return QkvToContext(cublas, stream, batch_size, sequence_length, num_heads, head_size, element_size, reinterpret_cast(input), reinterpret_cast(output), reinterpret_cast(workspace), - mask_index, is_unidirectional, + mask_index, mask_index_dims, is_unidirectional, past_sequence_length, reinterpret_cast(past), reinterpret_cast(present)); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 6e58e73072..8a4ecffe4b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -16,20 +16,21 @@ size_t GetAttentionWorkspaceSize( int past_sequence_length); bool LaunchAttentionKernel( - const void* input, // Input tensor - const int* mask_index, // Mask index (length of each sequence). NULL means no mask. - void* output, // Output tensor - int batch_size, // Batch size (B) - int sequence_length, // Sequence length (S) - int num_heads, // Number of attention heads (N) - int head_size, // Hidden layer size per head (H) - void* workspace, // Temporary buffer - cublasHandle_t& cublas, // Cublas handle - const size_t element_size, // Element size of input tensor - bool is_unidirectional, // Whether there is unidirecitonal mask. - int past_sequence_length, // Sequence length in past state - const void* past, // Past state input - void* present // Present state output + const void* input, // Input tensor + const int* mask_index, // Attention mask raw data or index (end position of each sequence, or end positions and start positions). NULL means no mask. + const std::vector* mask_index_dims, // Mask index shape + void* output, // Output tensor + int batch_size, // Batch size (B) + int sequence_length, // Sequence length (S) + int num_heads, // Number of attention heads (N) + int head_size, // Hidden layer size per head (H) + void* workspace, // Temporary buffer + cublasHandle_t& cublas, // Cublas handle + const size_t element_size, // Element size of input tensor + bool is_unidirectional, // Whether there is unidirecitonal mask. + int past_sequence_length, // Sequence length in past state + const void* past, // Past state input + void* present // Present state output ); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 974fdecf2f..08e1f6b3b4 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -173,6 +173,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { if (!LaunchAttentionKernel( reinterpret_cast(gemm_buffer.get()), nullptr == mask_index ? nullptr : mask_index->template Data(), + nullptr == mask_index ? nullptr : &(mask_index->Shape().GetDims()), output->template MutableData(), batch_size, sequence_length, diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index eb67c26e52..93b69f59e9 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -291,9 +291,14 @@ const char* contrib_ops_auto_pad_doc = void RegisterBertSchemas() { static const char* Attention_ver1_doc = R"DOC( -Multi-Head Self Attention that can be either unidirectional (like GPT2) or bidirectional (like BERT). -The mask_index input is optional. Unidirectional and mask_index input are mutually exclusive. When unidirectional is 1, the -mask_index shall not be provided.)DOC"; +Multi-Head Self Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT). +The mask_index input is optional. Besides raw attention mask with shape (batch_size, past_sequence_length + sequence_length), +we also support other two formats: When input has right-side padding, mask_index is one dimension with shape (batch_size), +where value of each element is the end position, or valid length of actual sequence excluding padding. When input has +left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by +the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past +and present state are optional. Present state could appear in output even when past state is not in input. +)DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(Attention) .SetDomain(kMSDomain) @@ -308,7 +313,7 @@ mask_index shall not be provided.)DOC"; .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", "T") .Input(1, "weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T") .Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T") - .Input(3, "mask_index", "Attention mask index with shape (batch_size).", "M", OpSchema::Optional) + .Input(3, "mask_index", "Attention mask with shape (batch_size, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional) .Input(4, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "T", OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, append_length, hidden_size)", "T") .Output(1, "present", "present state for key and value with shape (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)", "T", OpSchema::Optional) diff --git a/onnxruntime/python/tools/transformers/benchmark_gpt2.py b/onnxruntime/python/tools/transformers/benchmark_gpt2.py index c1af86536f..a02fffb78c 100644 --- a/onnxruntime/python/tools/transformers/benchmark_gpt2.py +++ b/onnxruntime/python/tools/transformers/benchmark_gpt2.py @@ -75,7 +75,8 @@ def pytorch_inference(model, inputs, total_runs=100): input_ids, past, attention_mask, position_ids = inputs # Convert it back to fp32 as the PyTroch model cannot deal with half input. - attention_mask = attention_mask.to(dtype=torch.float32) if attention_mask else None + if attention_mask is not None: + attention_mask = attention_mask.to(dtype=torch.float32) past = [p.to(dtype=torch.float32) for p in past] latency = [] @@ -120,29 +121,37 @@ def onnxruntime_inference(ort_session, inputs, total_runs=100): return ort_outputs, average_latency -def get_dummy_inputs(batch_size, past_sequence_length, num_attention_heads, hidden_size, num_layer, vocab_size, device, - use_attention_mask, float16): +def get_dummy_inputs(batch_size, past_sequence_length, sequence_length, num_attention_heads, hidden_size, num_layer, + vocab_size, device, use_attention_mask, float16): float_type = torch.float16 if float16 else torch.float32 past_shape = [2, batch_size, num_attention_heads, past_sequence_length, int(hidden_size / num_attention_heads)] dummy_past = [torch.rand(past_shape, dtype=float_type, device=device) for _ in range(num_layer)] - dummy_input_ids = torch.randint(low=0, high=vocab_size - 1, size=(batch_size, 1), dtype=torch.int64, device=device) + dummy_input_ids = torch.randint(low=0, + high=vocab_size - 1, + size=(batch_size, sequence_length), + dtype=torch.int64, + device=device) if use_attention_mask: - dummy_attention_mask = torch.ones([batch_size, 1], dtype=float_type, device=device) - dummy_position_ids = torch.ones([batch_size, 1], dtype=torch.int64, device=device) * past_sequence_length + dummy_attention_mask = torch.ones([batch_size, past_sequence_length + sequence_length], + dtype=float_type, + device=device) + dummy_position_ids = torch.ones([batch_size, sequence_length], dtype=torch.int64, + device=device) * past_sequence_length return dummy_input_ids, dummy_past, dummy_attention_mask, dummy_position_ids return dummy_input_ids, dummy_past, None, None -def get_output_shapes(batch_size, past_sequence_length, config, use_LMHead): +def get_output_shapes(batch_size, past_sequence_length, sequence_length, config, use_LMHead): num_attention_heads = config.num_attention_heads hidden_size = config.hidden_size num_layer = config.n_layer vocab_size = config.vocab_size - last_state_shape = [batch_size, 1, vocab_size] if use_LMHead else [batch_size, 1, hidden_size] + last_state_shape = [batch_size, sequence_length, vocab_size + ] if use_LMHead else [batch_size, sequence_length, hidden_size] present_state_shape = [ - 2, batch_size, num_attention_heads, past_sequence_length + 1, + 2, batch_size, num_attention_heads, past_sequence_length + sequence_length, int(hidden_size / num_attention_heads) ] @@ -327,7 +336,7 @@ def parse_arguments(): parser.add_argument('-b', '--batch_sizes', nargs='+', type=int, default=[1], help="batch size") parser.add_argument('-s', - '--sequence_lengths', + '--past_sequence_lengths', nargs='+', type=int, default=[8, 16, 32, 64, 128, 256], @@ -375,23 +384,24 @@ def export_onnx(model, config, tokenizer, device, output_dir, use_LMHead, use_at # Shape of input tensors: # input_ids: (batch_size, seq_len) # past_{i}: (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads) - # attention_mask: (batch_size, seq_len) + # attention_mask: (batch_size, past_seq_len + seq_len) # Shape of output tensors: # last_state: (batch_size, seq_len, hidden_size) # or prediction_scores: (batch_size, seq_len, vocab_size) - # present_{i}: (2, batch_size, num_heads, all_seq_len, hidden_size/num_heads) + # present_{i}: (2, batch_size, num_heads, past_seq_len + seq_len, hidden_size/num_heads) dynamic_axes = {'input_ids': {0: 'batch_size', 1: 'seq_len'}, output_names[0]: {0: 'batch_size', 1: 'seq_len'}} for name in past_names: dynamic_axes[name] = {1: 'batch_size', 3: 'past_seq_len'} for name in present_names: - dynamic_axes[name] = {1: 'batch_size', 3: 'all_seq_len'} + dynamic_axes[name] = {1: 'batch_size', 3: 'past_seq_len+seq_len'} if use_attention_mask: - dynamic_axes['attention_mask'] = {0: 'batch_size', 1: 'all_seq_len'} + dynamic_axes['attention_mask'] = {0: 'batch_size', 1: 'past_seq_len+seq_len'} dynamic_axes['position_ids'] = {0: 'batch_size', 1: 'seq_len'} dummy_inputs = get_dummy_inputs(batch_size=1, past_sequence_length=1, + sequence_length=1, num_attention_heads=config.num_attention_heads, hidden_size=config.hidden_size, num_layer=num_layer, @@ -469,17 +479,15 @@ def main(): if not os.path.exists(output_dir): os.makedirs(output_dir) - use_torchscript = False - model_class = MyGPT2LMHeadModel if args.model_class == 'GPT2LMHeadModel' else MyGPT2Model use_LMHead = (args.model_class == 'GPT2LMHeadModel') model_name = args.model_name - config = AutoConfig.from_pretrained(model_name, torchscript=use_torchscript, cache_dir=cache_dir) + config = AutoConfig.from_pretrained(model_name, torchscript=False, cache_dir=cache_dir) model = model_class.from_pretrained(model_name, config=config, cache_dir=cache_dir) tokenizer = GPT2Tokenizer.from_pretrained(model_name, cache_dir=cache_dir) - #if use_torchscript: - # model = torch.jit.trace(model, (input_ids, past)) + + # This scirpt does not support float16 for PyTorch. #if args.float16: # model.half() @@ -521,10 +529,14 @@ def main(): if session is None: return + # One word is generated for each inference. This length does not include that of past state. + sequence_length = 1 + # Allocate output buffers for IO Binding output_buffers = {} if not args.disable_ort_io_binding: - max_output_shapes = get_output_shapes(max(args.batch_sizes), max(args.sequence_lengths), config, use_LMHead) + max_output_shapes = get_output_shapes(max(args.batch_sizes), max(args.past_sequence_lengths), sequence_length, + config, use_LMHead) output_buffers = get_output_buffers(max_output_shapes, device, args.float16) csv_filename = args.result_csv or "benchmark_result_{}.csv".format(datetime.now().strftime("%Y%m%d-%H%M%S")) @@ -537,12 +549,12 @@ def main(): csv_writer.writeheader() for batch_size in args.batch_sizes: - for past_sequence_length in args.sequence_lengths: + for past_sequence_length in args.past_sequence_lengths: logger.debug(f"Running test for batch_size={batch_size} past_sequence_length={past_sequence_length}...") - dummy_inputs = get_dummy_inputs(batch_size, past_sequence_length, config.num_attention_heads, - config.hidden_size, config.n_layer, config.vocab_size, device, - args.use_attention_mask, args.float16) - output_shapes = get_output_shapes(batch_size, past_sequence_length, config, use_LMHead) + dummy_inputs = get_dummy_inputs(batch_size, past_sequence_length, sequence_length, + config.num_attention_heads, config.hidden_size, config.n_layer, + config.vocab_size, device, args.use_attention_mask, args.float16) + output_shapes = get_output_shapes(batch_size, past_sequence_length, sequence_length, config, use_LMHead) try: latencies = inference(model, diff --git a/onnxruntime/python/tools/transformers/fusion_gpt_attention.py b/onnxruntime/python/tools/transformers/fusion_gpt_attention.py index 31d7b1ebff..9089f8c5ad 100644 --- a/onnxruntime/python/tools/transformers/fusion_gpt_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_gpt_attention.py @@ -20,12 +20,13 @@ class FusionGptAttention(Fusion): def __init__(self, model: OnnxModel, num_heads: int): super().__init__(model, "Attention", "LayerNormalization", "with past") self.num_heads = num_heads + self.utils = FusionUtils(model) + self.casted_attention_mask = {} # map from name of attention mask to the name that casted to int32 - def create_attention_node(self, gemm, gemm_qkv, past, present, input, output): + def create_attention_node(self, gemm, gemm_qkv, past, present, input, output, mask=''): attention_node_name = self.model.create_node_name('GptAttention') - mask_index = '' attention_node = helper.make_node('Attention', - inputs=[input, gemm.input[1], gemm.input[2], mask_index, past], + inputs=[input, gemm.input[1], gemm.input[2], mask, past], outputs=[attention_node_name + "_output", present], name=attention_node_name) attention_node.domain = "com.microsoft" @@ -114,6 +115,7 @@ class FusionGptAttention(Fusion): logger.debug("Add and LayerNormalization shall have one same input") return + input_mask_nodes = None qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Sub', 'Mul', 'Div', 'MatMul'], [0, 0, 0, 0, 0]) if qk_nodes is not None: (softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes @@ -122,7 +124,7 @@ class FusionGptAttention(Fusion): ['Mul', 'Sub', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'], [1, 0, 1, 0, 1, 0, 0, 0, 0, 0]) # yapf: disable if mask_nodes is None: - logger.debug("fuse_attention: failed to match mask path") + logger.debug("fuse_attention: failed to match unidirectional mask path") return div_mask = mask_nodes[-1] @@ -131,11 +133,27 @@ class FusionGptAttention(Fusion): return else: # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0. - qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Where', 'Div', 'MatMul'], [0, 0, 1, 0]) + i, qk_nodes, _ = self.model.match_parent_paths( + matmul_qkv, [(['Softmax', 'Where', 'Div', 'MatMul'], [0, 0, 1, 0]), + (['Softmax', 'Add', 'Where', 'Div', 'MatMul'], [0, 0, 0, 1, 0])], output_name_to_node) if qk_nodes is None: - logger.debug("fuse_attention: failed to match qk path") + logger.debug("fuse_attention: failed to match qk nodes") return - (softmax_qk, where_qk, div_qk, matmul_qk) = qk_nodes + + where_qk = qk_nodes[-3] + div_qk = qk_nodes[-2] + matmul_qk = qk_nodes[-1] + + if i == 1: + add_qk = qk_nodes[1] + _, input_mask_nodes, _ = self.model.match_parent_paths( + add_qk, [(['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [1, 0, 1, 0, 0, 0]), + (['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [1, 0, 1, 0, 0])], + output_name_to_node) + if input_mask_nodes is None: + logger.debug("fuse_attention: failed to match input attention mask path") + return + mask_nodes = self.model.match_parent_path( where_qk, ['Cast', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'], @@ -188,8 +206,20 @@ class FusionGptAttention(Fusion): logger.info("expect past to be same") return + attention_mask_input_name = '' + if input_mask_nodes is not None: + input_name = input_mask_nodes[-1].input[0] + if input_name in self.casted_attention_mask: + attention_mask_input_name = self.casted_attention_mask[input_name] + elif self.model.find_graph_input(input_name): + casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32(input_name) + self.casted_attention_mask[input_name] = attention_mask_input_name + else: + attention_mask_input_name, cast_node = self.utils.cast_input_to_int32(input_name) + self.casted_attention_mask[input_name] = attention_mask_input_name + self.create_attention_node(gemm, gemm_qkv, past, present, layernorm_before_attention.output[0], - reshape_qkv.output[0]) + reshape_qkv.output[0], attention_mask_input_name) # we rely on prune_graph() to clean old subgraph nodes: # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv] diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 897401c197..831505ad0e 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "gtest/gtest.h" @@ -8,12 +9,17 @@ namespace onnxruntime { namespace test { +enum MaskIndexType { + kMaskIndexEnd = 0, + kMaskIndexEndAndStart, + kMaskRaw +}; static void RunAttentionTest( const std::vector& input_data, // input: [batch_size, sequence_length, hidden_size] const std::vector& weights_data, // weights: [hidden_size, 3 * hidden_size] const std::vector& bias_data, // bias: [3 * hidden_size] - const std::vector& mask_index_data, // mask_index: [batch_size] or empty + const std::vector& mask_index_data, // mask_index: [batch_size] or [batch_size, past_sequence_length + sequence_length] or empty const std::vector& output_data, // output: [batch_size, sequence_length, hidden_size] int batch_size, int sequence_length, @@ -23,14 +29,14 @@ static void RunAttentionTest( bool is_unidirectional = false, bool use_past_state = false, int past_sequence_length = 0, - int head_size = 0, const std::vector* past_data = nullptr, - const std::vector* present_data = nullptr) { + const std::vector* present_data = nullptr, + MaskIndexType mask_index_type = kMaskIndexEnd) { int min_cuda_architecture = use_float16 ? 530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16; - + int head_size = hidden_size / number_of_heads; if (enable_cpu || enable_cuda) { OpTester tester("Attention", 1, onnxruntime::kMSDomain); tester.AddAttribute("num_heads", static_cast(number_of_heads)); @@ -39,9 +45,14 @@ static void RunAttentionTest( std::vector input_dims = {batch_size, sequence_length, hidden_size}; std::vector weights_dims = {hidden_size, 3 * hidden_size}; std::vector bias_dims = {3 * hidden_size}; - std::vector mask_index_dims = {batch_size}; - std::vector past_dims = {2, batch_size, head_size, past_sequence_length, head_size}; - std::vector present_dims = {2, batch_size, head_size, past_sequence_length + sequence_length, head_size}; + + std::vector mask_index_dims_1 = {batch_size}; + std::vector mask_index_dims_2 = {2 * batch_size}; + std::vector mask_index_dims_3 = {batch_size, past_sequence_length + sequence_length}; + std::vector mask_index_dims = (mask_index_type == kMaskIndexEnd ? mask_index_dims_1 : (mask_index_type == kMaskIndexEndAndStart ? mask_index_dims_2 : mask_index_dims_3)); + + std::vector past_dims = {2, batch_size, number_of_heads, past_sequence_length, head_size}; + std::vector present_dims = {2, batch_size, number_of_heads, past_sequence_length + sequence_length, head_size}; std::vector output_dims = input_dims; if (use_float16) { @@ -59,8 +70,7 @@ static void RunAttentionTest( if (mask_index_data.size() > 0) { // mask index is optional. tester.AddInput("mask_index", mask_index_dims, mask_index_data); } else { - std::vector dims = {static_cast(mask_index_data.size())}; - tester.AddInput("", dims, mask_index_data); + tester.AddMissingOptionalInput(); } if (use_past_state) { @@ -451,10 +461,10 @@ TEST(AttentionTest, AttentionEmptyPastState) { bool is_unidirectional = true; bool use_past_state = true; int past_sequence_length = 0; - int head_size = 2; + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, is_unidirectional, - use_past_state, past_sequence_length, head_size, &past_data, &present_data); + use_past_state, past_sequence_length, &past_data, &present_data); } TEST(AttentionTest, AttentionPastStateBatch1) { @@ -550,10 +560,10 @@ TEST(AttentionTest, AttentionPastStateBatch1) { bool is_unidirectional = true; bool use_past_state = true; int past_sequence_length = 3; - int head_size = 2; + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, is_unidirectional, - use_past_state, past_sequence_length, head_size, &past_data, &present_data); + use_past_state, past_sequence_length, &past_data, &present_data); } TEST(AttentionTest, AttentionPastStateBatch2) { @@ -653,10 +663,516 @@ TEST(AttentionTest, AttentionPastStateBatch2) { bool is_unidirectional = true; bool use_past_state = true; int past_sequence_length = 3; - int head_size = 2; + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, false, is_unidirectional, - use_past_state, past_sequence_length, head_size, &past_data, &present_data); + use_past_state, past_sequence_length, &past_data, &present_data); +} + +TEST(AttentionTest, AttentionPastStateBatch2WithPadding) { + int batch_size = 2; + int sequence_length = 1; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + -0.10902753f, 0.0041178204f, 0.1871525f, -0.20399982f, + 0.027207348f, -0.25321805f, 0.12869114f, 0.023136809f}; + + std::vector weight_data = { + -0.4738484025001526f, + -0.2613658607006073f, + -0.0978037416934967f, + -0.34988933801651f, + 0.2243240624666214f, + -0.0429205559194088f, + 0.418695330619812f, + 0.17441125214099884f, + -0.18825532495975494f, + 0.18357256054878235f, + -0.5806483626365662f, + -0.02251487597823143f, + + 0.08742205798625946f, + 0.14734269678592682f, + 0.2387014478445053f, + 0.2884027063846588f, + 0.6490834355354309f, + 0.16965825855731964f, + -0.06346885114908218f, + 0.4073973298072815f, + -0.03070945478975773f, + 0.4110257923603058f, + 0.07896808534860611f, + 0.16783113777637482f, + + 0.0038893644232302904f, + 0.06946629285812378f, + 0.36680519580841064f, + -0.07261059433221817f, + -0.14960581064224243f, + 0.020944256335496902f, + -0.09378612786531448f, + -0.1336742341518402f, + 0.06061394885182381f, + 0.2205914407968521f, + -0.03519909828901291f, + -0.18405692279338837f, + + 0.22149960696697235f, + -0.1884360909461975f, + -0.014074507169425488f, + 0.4252440333366394f, + 0.24987126886844635f, + -0.31396418809890747f, + 0.14036843180656433f, + 0.2854192554950714f, + 0.09709841012954712f, + 0.09935075044631958f, + -0.012154420837759972f, + 0.2575816512107849f}; + + std::vector bias_data = { + 0.4803391396999359f, + -0.5254325866699219f, + -0.42926454544067383f, + -0.2059524953365326f, + -0.12773379683494568f, + -0.09542735666036606f, + -0.35286077857017517f, + -0.07646317780017853f, + -0.04590314254164696f, + -0.03752850368618965f, + -0.013764488510787487f, + -0.18478283286094666f}; + + // One sequence has both left padding and right padding + std::vector mask_index_data = {4, 3, 0, 2}; + + std::vector output_data = { + 0.14902574f, 0.62273371f, 0.43022552f, 0.12759127f, + 0.18029204f, 0.07451740f, 0.73694098f, 0.17766341f}; + + std::vector past_data = { + 0.42028648f, 0.55855948f, 0.044569403f, 0.76525789f, 0.13962431f, 0.40977913f, 0.36911047f, 0.83399564f, 0.36905321f, 0.91414654f, 0.17300875f, 0.78793788f, + 0.10279467f, 0.80501258f, 0.089550517f, 0.85371113f, 0.61801594f, 0.91222942f, 0.88626182f, 0.069776468f, 0.10591964f, 0.84836882f, 0.83520192f, 0.0098680854f, + 0.3113814f, 0.63999802f, 0.28603253f, 0.98899829f, 0.044405211f, 0.95105386f, 0.81278932f, 0.63969064f, 0.14494057f, 0.11349615f, 0.87086016f, 0.20983537f, + 0.35107401f, 0.90144604f, 0.68950737f, 0.18928574f, 0.18029204f, 0.074517399f, 0.70763874f, 0.48440042f, 0.58114725f, 0.1048766f, 0.73694098f, 0.17766342f}; + + std::vector present_data = { + 0.42028648f, 0.55855948f, 0.044569403f, 0.76525789f, 0.13962431f, 0.40977913f, -0.22849128f, -0.022080801f, 0.36911047f, 0.83399564f, 0.36905321f, 0.91414654f, 0.17300875f, 0.78793788f, -0.4449589f, -0.17704415f, 0.10279467f, 0.80501258f, 0.089550517f, 0.85371113f, 0.61801594f, 0.91222942f, -0.2994619f, -0.14412443f, 0.88626182f, 0.069776468f, 0.10591964f, 0.84836882f, 0.83520192f, 0.0098680854f, -0.33421949f, -0.18547727f, + 0.3113814f, 0.63999802f, 0.28603253f, 0.98899829f, 0.044405211f, 0.95105386f, -0.033968594f, -0.034833729f, 0.81278932f, 0.63969064f, 0.14494057f, 0.11349615f, 0.87086016f, 0.20983537f, 0.045759238f, -0.26863033f, 0.35107401f, 0.90144604f, 0.68950737f, 0.18928574f, 0.18029204f, 0.074517399f, -0.033201858f, -0.10592631f, 0.70763874f, 0.48440042f, 0.58114725f, 0.1048766f, 0.73694098f, 0.17766342f, -0.054369561f, -0.24562015f}; + + bool is_unidirectional = true; + bool use_past_state = true; + int past_sequence_length = 3; + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, false, is_unidirectional, + use_past_state, past_sequence_length, &past_data, &present_data, kMaskIndexEndAndStart); +} + +TEST(AttentionTest, AttentionBatch2MaskIndex2) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + std::vector mask_index_data = {2, 2, 0, 0}; + + std::vector output_data = { + 3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f, + 3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f, + 3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f, + 3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f}; + + bool use_float16 = false; + bool is_unidirectional = false; + bool use_past_state = false; + int past_sequence_length = 0; + const std::vector* past_data = nullptr; + const std::vector* present_data = nullptr; + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); +} + +TEST(AttentionTest, AttentionRightPaddingMaskIndex2) { + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // Test mask_index < sequence_length + std::vector mask_index_data = {1, 0}; + + std::vector output_data = { + 8.6899995803833008f, -0.13000002503395081f, 4.25f, 5.6499996185302734f, + 8.6899995803833008f, -0.13000002503395081f, 4.2499995231628418f, 5.6499991416931152f}; + + bool use_float16 = false; + bool is_unidirectional = false; + bool use_past_state = false; + int past_sequence_length = 0; + const std::vector* past_data = nullptr; + const std::vector* present_data = nullptr; + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); +} + +TEST(AttentionTest, AttentionLeftPaddingMaskIndex2) { + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // Test mask start position > 0. + std::vector mask_index_data = {2, 1}; + + std::vector output_data = { + 8.69f, -0.13f, 4.25f, 5.65f, + 8.69f, -0.13f, 4.25f, 5.65f}; + + bool use_float16 = false; + bool is_unidirectional = false; + bool use_past_state = false; + int past_sequence_length = 0; + const std::vector* past_data = nullptr; + const std::vector* present_data = nullptr; + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); +} + +TEST(AttentionTest, AttentionBatch2LeftPaddingMaskIndex2) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // Test mask start position > 0. + std::vector mask_index_data = {2, 2, 1, 0}; + + std::vector output_data = { + 8.69f, -0.13f, 4.25f, 5.65f, + 8.69f, -0.13f, 4.25f, 5.65f, + 3.14959716796875f, 0.10843672603368759f, 4.25f, 5.65f, + 3.9696791172027588f, 0.073143675923347473f, 4.25f, 5.65f}; + + bool use_float16 = false; + bool is_unidirectional = false; + bool use_past_state = false; + int past_sequence_length = 0; + const std::vector* past_data = nullptr; + const std::vector* present_data = nullptr; + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); +} + +TEST(AttentionTest, AttentionBatch2AttentionMask) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // Test mask start position > 0. + std::vector mask_index_data = {0, 1, 1, 1}; + + std::vector output_data = { + 8.69f, -0.13f, 4.25f, 5.65f, + 8.69f, -0.13f, 4.25f, 5.65f, + 3.14959716796875f, 0.10843672603368759f, 4.25f, 5.65f, + 3.9696791172027588f, 0.073143675923347473f, 4.25f, 5.65f}; + + bool use_float16 = false; + bool is_unidirectional = false; + bool use_past_state = false; + int past_sequence_length = 0; + const std::vector* past_data = nullptr; + const std::vector* present_data = nullptr; + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskRaw); +} + +TEST(AttentionTest, AttentionUnidirectionalAttentionMask) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // Test mask start position > 0. + std::vector mask_index_data = {0, 1, 1, 1}; + + std::vector output_data = { + 3.967245340f, 0.07324841f, 4.25f, 5.65f, + 8.69f, -0.13f, 4.25f, 5.65f, + 8.69f, -0.13f, 4.25f, 5.65f, + 3.96967912f, 0.07314367f, 4.25f, 5.65f}; + + bool use_float16 = false; + bool is_unidirectional = true; + bool use_past_state = false; + int past_sequence_length = 0; + const std::vector* past_data = nullptr; + const std::vector* present_data = nullptr; + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskRaw); +} + +TEST(AttentionTest, AttentionMask1DEndNoWord) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // Test that all attention masks are zero. + std::vector mask_index_data = {0, 0}; + + std::vector output_data = { + 3.96724534f, 0.07324841f, 4.25f, 5.65f, + 3.14984703f, 0.10842596f, 4.25f, 5.65f, + 3.14984703f, 0.10842596f, 4.25f, 5.65f, + 3.96724534f, 0.07324841f, 4.25f, 5.65f}; + + bool use_float16 = false; + bool is_unidirectional = false; + bool use_past_state = false; + int past_sequence_length = 0; + const std::vector* past_data = nullptr; + const std::vector* present_data = nullptr; + + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEnd); +} + +TEST(AttentionTest, AttentionMask1DNoWord) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // Test that all attention masks are zero. + std::vector mask_index_data = {0, 0, 2, 2}; + + std::vector output_data = { + 3.96724534f, 0.07324841f, 4.25f, 5.65f, + 3.14984703f, 0.10842596f, 4.25f, 5.65f, + 3.14984703f, 0.10842596f, 4.25f, 5.65f, + 3.96724534f, 0.07324841f, 4.25f, 5.65f}; + + bool use_float16 = false; + bool is_unidirectional = false; + bool use_past_state = false; + int past_sequence_length = 0; + const std::vector* past_data = nullptr; + const std::vector* present_data = nullptr; + + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); +} + +TEST(AttentionTest, AttentionMask2DNoWord) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // Test that all attention masks are zero. + std::vector mask_index_data = {0, 0, 0, 0}; + + std::vector output_data = { + 3.96724534f, 0.07324841f, 4.25f, 5.65f, + 3.14984703f, 0.10842596f, 4.25f, 5.65f, + 3.14984703f, 0.10842596f, 4.25f, 5.65f, + 3.96724534f, 0.07324841f, 4.25f, 5.65f}; + + bool use_float16 = false; + bool is_unidirectional = false; + bool use_past_state = false; + int past_sequence_length = 0; + const std::vector* past_data = nullptr; + const std::vector* present_data = nullptr; + + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskRaw); +} + +TEST(AttentionTest, AttentionMaskIndexOutOfRange) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // Test end_position > sequence length, or start_position < 0 + std::vector mask_index_data = {3, 2, 0, -1}; + + std::vector output_data = { + 3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f, + 3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f, + 3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f, + 3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f}; + + bool use_float16 = false; + bool is_unidirectional = false; + bool use_past_state = false; + int past_sequence_length = 0; + const std::vector* past_data = nullptr; + const std::vector* present_data = nullptr; + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); } TEST(AttentionTest, AttentionPastState_dynamic) { From a4127fc185c25dba0c918ba9e0996805ad35a92e Mon Sep 17 00:00:00 2001 From: Faith Xu Date: Tue, 30 Jun 2020 01:51:09 -0700 Subject: [PATCH 07/13] Add stale bot (#4323) * Add stalebot * Update exemptLabels --- .github/stale.yml | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 .github/stale.yml diff --git a/.github/stale.yml b/.github/stale.yml new file mode 100644 index 0000000000..fcff8d38e6 --- /dev/null +++ b/.github/stale.yml @@ -0,0 +1,22 @@ +# Number of days of inactivity before an issue becomes stale +daysUntilStale: 60 + +# Number of days of inactivity before a stale issue is closed +daysUntilClose: 7 + +# Issues with these labels will never be considered stale +exemptLabels: + - "contributions are welcome" + - documentation + - enhancement + +# Label to use when marking an issue as stale +staleLabel: wontfix + +# Comment to post when marking an issue as stale. Set to `false` to disable +markComment: > + This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details. + +# Comment to post when closing a stale issue. Set to `false` to disable +closeComment: > + This issue has been automatically closed due to inactivity. Please reactivate if further support is needed. From 89c6da99b5d229d67b465a2ecf158a4254932cd3 Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Tue, 30 Jun 2020 08:21:20 -0700 Subject: [PATCH 08/13] fix output shape calc for matmul (#4362) --- onnxruntime/core/providers/cpu/math/matmul_helper.h | 2 +- onnxruntime/test/providers/cpu/math/matmul_test.cc | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/math/matmul_helper.h b/onnxruntime/core/providers/cpu/math/matmul_helper.h index fca73c2d99..2d8d6c14aa 100644 --- a/onnxruntime/core/providers/cpu/math/matmul_helper.h +++ b/onnxruntime/core/providers/cpu/math/matmul_helper.h @@ -29,7 +29,7 @@ class MatMulComputeHelper { // A: [M1, M2, ... K], B: [N, K]^T // A: [M1, M2, ... K], B: [1, ..., 1, K, N] // A: [M1, M2, ... K], B: [1, ..., 1, N, K]^T - if (!transa && left_num_dims >= 2 && right_num_dims >= 2 && + if (!transa && left_num_dims >= 2 && right_num_dims >= 2 && left_num_dims >= right_num_dims && right_shape.SizeToDimension(right_num_dims - 1) == right_shape[right_num_dims - 2]) { M_ = left_shape.SizeToDimension(left_num_dims - 1); K_ = left_shape[left_num_dims - 1]; diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index 1e6bece98c..418dee24f0 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -77,6 +77,13 @@ std::vector> GenerateTestCases() {2, 2, 4}, {20, 23, 26, 29, 56, 68, 80, 92, 92, 113, 134, 155, 128, 158, 188, 218}}); + test_cases.push_back( + {"test 2D special 3", + {2, 6}, + {1, 1, 6, 1}, + {1, 1, 2, 1}, + {55, 145}}); + return test_cases; } From 4380b8ba681647786080ffc329c8dd6c189e20ac Mon Sep 17 00:00:00 2001 From: ytaous <4484531+ytaous@users.noreply.github.com> Date: Tue, 30 Jun 2020 10:29:48 -0700 Subject: [PATCH 09/13] adjust bs size (#4375) Co-authored-by: Ethan Tao --- orttraining/tools/ci_test/run_bert_perf_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orttraining/tools/ci_test/run_bert_perf_test.py b/orttraining/tools/ci_test/run_bert_perf_test.py index 14f72a2c92..d3909f36f5 100644 --- a/orttraining/tools/ci_test/run_bert_perf_test.py +++ b/orttraining/tools/ci_test/run_bert_perf_test.py @@ -26,7 +26,7 @@ def main(): Config = namedtuple('Config', ['use_mixed_precision', 'max_seq_length', 'batch_size', 'max_predictions_per_seq']) configs = [ - Config(True, 128, 66, 20), + Config(True, 128, 64, 20), Config(True, 512, 10, 80), Config(False, 128, 33, 20), Config(False, 512, 5, 80) From 755675541a2572f7feef27ce515673a6c478134d Mon Sep 17 00:00:00 2001 From: Tracy Sharpe <42477615+tracysh@users.noreply.github.com> Date: Tue, 30 Jun 2020 10:50:58 -0700 Subject: [PATCH 10/13] NCHWc + Sigmoid optimization (#4360) Add support to avoid reordering NCHWc tensors due to the Swish activation (x * sigmoid(x)) in EfficientNet/EfficientDet models. --- .../core/optimizer/nchwc_transformer.cc | 7 ++-- .../test/optimizer/nchwc_optimizer_test.cc | 35 +++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/optimizer/nchwc_transformer.cc b/onnxruntime/core/optimizer/nchwc_transformer.cc index 6da13740c1..f00709dcbc 100644 --- a/onnxruntime/core/optimizer/nchwc_transformer.cc +++ b/onnxruntime/core/optimizer/nchwc_transformer.cc @@ -640,7 +640,8 @@ void NchwcTransformerImpl::TransformConcat(Node& node) { } // After doing a Conv/Add fusion, there may be an activation node that could now -// be fused into the Conv node as well. +// be fused into the Conv node as well. Otherwise, this is an elementwise +// operation that can directly use the NCHWc input. void NchwcTransformerImpl::TransformActivation(Node& node) { auto& input_defs = node.MutableInputDefs(); @@ -943,7 +944,9 @@ void NchwcTransformerImpl::Transform(Node& node) { TransformBinary(node, false); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Concat", {4, 11})) { TransformConcat(node); - } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6})) { + } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Relu", {6}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Sigmoid", {6}) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "Tanh", {6})) { TransformActivation(node); } else if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "BatchNormalization", {7, 9})) { TransformBatchNormalization(node); diff --git a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc index f51cc59205..64ba6f5e2a 100644 --- a/onnxruntime/test/optimizer/nchwc_optimizer_test.cc +++ b/onnxruntime/test/optimizer/nchwc_optimizer_test.cc @@ -1219,6 +1219,41 @@ TEST(NchwcOptimizerTests, Upsample) { } } +TEST(NchwcOptimizerTests, Activation) { + auto test_case = [&](const std::string& activation_op_type) { + auto build_test_case = [&](NchwcTestHelper& helper) { + auto* input_arg = helper.MakeInput({1, 48, 11, 15}); + auto* conv1_output_arg = helper.MakeIntermediate(); + auto* activation_output_arg = helper.MakeIntermediate(); + auto* mul_output_arg = helper.MakeIntermediate(); + auto* output_arg = helper.MakeOutput(); + + helper.AddConvNode(input_arg, conv1_output_arg, {32, 48, 3, 3}); + helper.AddNode(activation_op_type, {conv1_output_arg}, {activation_output_arg}); + helper.AddNode("Add", {conv1_output_arg, activation_output_arg}, {mul_output_arg}); + helper.AddConvNode(mul_output_arg, output_arg, {16, 32, 1, 1}); + }; + + auto check_nchwc_graph = [&](NchwcInferenceSession& session) { + auto op_to_count = session.CountOpsInGraph(); + EXPECT_EQ(op_to_count["nchwc.Conv"], 2); + EXPECT_EQ(op_to_count["nchwc.ReorderInput"], 1); + EXPECT_EQ(op_to_count["nchwc.ReorderOutput"], 1); + EXPECT_EQ(op_to_count[activation_op_type], 1); + EXPECT_EQ(op_to_count["Add"], 1); + }; + + NchwcOptimizerTester(build_test_case, check_nchwc_graph); + }; + + // Verify that the optimizer doesn't add reorders for these activations that + // cannot be fused with a convolution. + std::vector activation_op_types{"Relu", "Sigmoid", "Tanh"}; + for (auto& activation_op_type : activation_op_types) { + test_case(activation_op_type); + } +} + #endif } // namespace test From 37b624b68887158636a308eb366f2a89cd4b0601 Mon Sep 17 00:00:00 2001 From: Cecilia Liu Date: Tue, 30 Jun 2020 13:12:50 -0700 Subject: [PATCH 11/13] Match More EmbedLayerNormalization Patterns for Bert Model Graph Fusion (#4354) match more embed patterns for bert base cased --- .../core/optimizer/embed_layer_norm_fusion.cc | 250 ++++++++---------- .../test/optimizer/graph_transform_test.cc | 29 ++ .../fusion/embed_layer_norm_format6.onnx | Bin 0 -> 2461 bytes .../transform/fusion/embed_layer_norm_gen.py | 79 ++++++ 4 files changed, 220 insertions(+), 138 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format6.onnx diff --git a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc index 8136e39ed7..a1bec8a3ca 100644 --- a/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/embed_layer_norm_fusion.cc @@ -103,24 +103,24 @@ static void AddNodes(std::vector& node_indices, It is because they are matched as part of other subgraph. */ -static bool MatchPositionSubgraph( +static bool MatchInputToConcatSubgraph( Graph& graph, - const Node& expand_node, + const Node& cur_node, const NodeArg* input_ids, + const int index, const logging::Logger& logger, std::vector& subgraph_node_indices, const NodeIndex expected_gather_node_1_index) { subgraph_node_indices.clear(); - std::vector expand_parent_path1{ - {0, 1, "Concat", {4, 11}, kOnnxDomain}, + {0, index, "Concat", {4, 11}, kOnnxDomain}, {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, {0, 0, "Gather", {1, 11}, kOnnxDomain}, {0, 0, "Shape", {1}, kOnnxDomain}, }; std::vector edges; - if (!graph_utils::FindPath(expand_node, true, expand_parent_path1, edges, logger)) { + if (!graph_utils::FindPath(cur_node, true, expand_parent_path1, edges, logger)) { DEBUG_LOG("Failed to find path 1 of position shape."); return false; } @@ -184,60 +184,80 @@ static bool MatchPositionSubgraph( } /** Match subgraph like the following: - (input_ids) - / \ - Shape Shape - | | - ^Gather (indice=0)^ Gather (indice=1)--+ - ^|^ ^|^ | - ^Unsqueeze^ ^Unsqueeze^ Unsqueeze - ^\^ ^/^ | - ^\^ ^/^ ConstantOfShape - ^\^ ^/^ | - ^Concat^ NonZero - | | - | Transpose - | | - | Squeeze - | | - | Cast - | | - | Unsqueeze - +--|----------------------------+ - | | - Expand - | - Gather - - Note that position gather node is the node in the bottom of above sub-graph. - Paths in ^^ are alternative path to be matched if path input_ids -> Shape -> Expand -> Gather is not found. -*/ -static bool MatchPositionEmbeddingSubgraph1( + * + * Shape -> ^Gather (indice=0)^ -> ^Unsqueeze^ + * / | +-----------------------+ + * / v | | + * [input_ids] ^Concat^ -> *Reshape* -> *Equal* -> *Where* -> Expand -> Gather + * \ | | ("position") + * Shape -> ^Gather (indice=1)^ -> ^Unsqueeze^ | + * | | + * +-------------- # one of the below subgraph patterns # ---------------+ + * # Unsqueeze -> ConstantOfShape -> NonZero -> Transpose -> Squeeze -> (Cast) -> Unsqueeze # + * # or # + * # (Cast (to=7)) -> Range (start=0, delta=1) -> Unsqueeze # + * + * Note that position gather node is the node in the bottom of above sub-graph. + * Paths in ^^ are alternative path to be matched if path input_ids -> Shape -> Expand -> Gather is not found. + * Path in ** is an alternative path to check. + */ +static bool MatchPositionEmbeddingSubgraphsFromGather( Graph& graph, const Node& position_gather_node, const NodeArg* input_ids, const logging::Logger& logger, std::vector& subgraph_node_indices) { subgraph_node_indices.clear(); - std::vector pg_edges; // Look for Path 1: - // Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze --> Cast --> Unsqueeze --> Expand --> Gather - if (!graph_utils::FindPath(position_gather_node, true, - {{0, 1, "Expand", {8}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, - {0, 0, "Cast", {9}, kOnnxDomain}, - {0, 0, "Squeeze", {1, 11}, kOnnxDomain}, - {0, 0, "Transpose", {1}, kOnnxDomain}, - {0, 0, "NonZero", {9}, kOnnxDomain}, - {0, 0, "ConstantOfShape", {9}, kOnnxDomain}, - {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, - {0, 0, "Gather", {1, 11}, kOnnxDomain}, - {0, 0, "Shape", {1}, kOnnxDomain}}, - pg_edges, logger)) { + // Shape --> Gather --> Unsqueeze --> ConstantOfShape --> NonZero --> Transpose --> Squeeze + // --> Cast --> Unsqueeze --> Expand --> Gather + std::vector parent_path_1{ + {0, 1, "Expand", {8}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 0, "Cast", {9}, kOnnxDomain}, + {0, 0, "Squeeze", {1, 11}, kOnnxDomain}, + {0, 0, "Transpose", {1}, kOnnxDomain}, + {0, 0, "NonZero", {9}, kOnnxDomain}, + {0, 0, "ConstantOfShape", {9}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 0, "Gather", {1, 11}, kOnnxDomain}, + {0, 0, "Shape", {1}, kOnnxDomain}}; + // Look for Path 2 (Path 1 with no cast): + std::vector parent_path_2{ + {0, 1, "Expand", {8}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 0, "Squeeze", {1, 11}, kOnnxDomain}, + {0, 0, "Transpose", {1}, kOnnxDomain}, + {0, 0, "NonZero", {9}, kOnnxDomain}, + {0, 0, "ConstantOfShape", {9}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 0, "Gather", {1, 11}, kOnnxDomain}, + {0, 0, "Shape", {1}, kOnnxDomain}}; + // Path 3 Pattern: + // Shape -> Gather -> Cast (to=7) -> Range (start=0, delta=1) -> Unsqueeze -> Expand + std::vector parent_path_3{ + {0, 1, "Expand", {8}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 0, "Range", {1, 11}, kOnnxDomain}, + {0, 1, "Cast", {9}, kOnnxDomain}, + {0, 0, "Gather", {1, 11}, kOnnxDomain}, + {0, 0, "Shape", {1}, kOnnxDomain}}; + // Path 4 pattern (Path 3 with no "Cast"): + std::vector parent_path_4{ + {0, 1, "Expand", {8}, kOnnxDomain}, + {0, 0, "Unsqueeze", {1, 11}, kOnnxDomain}, + {0, 0, "Range", {1, 11}, kOnnxDomain}, + {0, 1, "Gather", {1, 11}, kOnnxDomain}, + {0, 0, "Shape", {1}, kOnnxDomain}}; + // Match one of the three path patterns. + if (!graph_utils::FindPath(position_gather_node, true, parent_path_1, pg_edges, logger) && + !graph_utils::FindPath(position_gather_node, true, parent_path_2, pg_edges, logger) && + !graph_utils::FindPath(position_gather_node, true, parent_path_3, pg_edges, logger) && + !graph_utils::FindPath(position_gather_node, true, parent_path_4, pg_edges, logger)) { return false; } - const size_t gather_index = 8; + const size_t gather_index = pg_edges.size() - 2; // All nodes in Path 1 must have only 1 output edge, except the gather node allowed 1 or 2 output edges for (size_t i = 0; i < pg_edges.size(); i++) { if (!optimizer_utils::CheckOutputEdges(graph, pg_edges[i]->GetNode(), 1)) { @@ -251,6 +271,18 @@ static bool MatchPositionEmbeddingSubgraph1( Node& expand_node = *graph.GetNode(pg_edges[0]->GetNode().Index()); Node& gather_node = *graph.GetNode(pg_edges[gather_index]->GetNode().Index()); + if (pg_edges[2]->GetNode().OpType() == "Range") { + // Check if the values in "start" and "delta" attributes in Range are expected. + Node& range_node = *graph.GetNode(pg_edges[2]->GetNode().Index()); + if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(range_node.InputDefs()[0]), int64_t(0), true)) { + DEBUG_LOG("The first input of Range should be a constant with value 0."); + return false; + } + if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(range_node.InputDefs()[2]), int64_t(1), true)) { + DEBUG_LOG("The third input of Range should be a constant with value 1."); + return false; + } + } if (gather_node.GetOutputEdgesCount() == 1) { // Check if the second input of the Gather node in the path has a constant input of 1 @@ -279,7 +311,35 @@ static bool MatchPositionEmbeddingSubgraph1( subgraph_node_indices.push_back(shape_node_index); } else { // gather_output_edges_count == 2 - if (!MatchPositionSubgraph(graph, expand_node, input_ids, logger, subgraph_node_indices, gather_node.Index())) { + // Match optional Reshape -> Equal -> Where -> Expand + // | | + // -------------------- + std::vector pg_edges_2; + std::vector path_to_match_1{ + {0, 1, "Where", {9}, kOnnxDomain}, + {0, 0, "Equal", {1, 11}, kOnnxDomain}, + {0, 0, "Reshape", {5}, kOnnxDomain}}; + if (graph_utils::FindPath(expand_node, true, path_to_match_1, pg_edges_2, logger)) { + if (!optimizer_utils::CheckOutputEdges(graph, pg_edges_2[0]->GetNode(), 1) || + !optimizer_utils::CheckOutputEdges(graph, pg_edges_2[1]->GetNode(), 1) || + !optimizer_utils::CheckOutputEdges(graph, pg_edges_2[2]->GetNode(), 2)) { + DEBUG_LOG("Optional position subgraph nodes number of outputs unexpected."); + return false; + } + Node& where_node = *graph.GetNode(pg_edges_2[0]->GetNode().Index()); + Node& reshape_node = *graph.GetNode(pg_edges_2[2]->GetNode().Index()); + if (where_node.MutableInputDefs()[2] != reshape_node.MutableOutputDefs()[0]) { + DEBUG_LOG("Optional position subgraph nodes Where node is expected to be the parent of Reshape."); + return false; + } + // Match [input_ids] -> Gather -> Shape -> Unsqueeze from Reshape node. + if (!MatchInputToConcatSubgraph(graph, reshape_node, input_ids, 0, logger, subgraph_node_indices, gather_node.Index())) { + DEBUG_LOG("Failed to match position subgraph."); + return false; + } + AddNodes(subgraph_node_indices, pg_edges_2); + } else if (!MatchInputToConcatSubgraph(graph, expand_node, input_ids, 1, logger, subgraph_node_indices, gather_node.Index())) { + // Match [input_ids] -> Gather -> Shape -> Unsqueeze from Expand node. DEBUG_LOG("Failed to match position subgraph."); return false; } @@ -290,90 +350,6 @@ static bool MatchPositionEmbeddingSubgraph1( return true; } -/** Match subgraph like the following: - (input_ids) - / \ - Shape Shape - | | - Gather (indice=0) Gather (indice=1)--+ - | | | - Unsqueeze Unsqueeze Cast(to=7) (Cast is optional) - \ / | - \ / Range(start=0, delta=1) - \ / | - Concat Unsqueeze - | | - +--|----------------------------+ - | | - Expand - | - Gather - - Note that position gather node is the node in the bottom of above sub-graph. -*/ - -static bool MatchPositionEmbeddingSubgraph2( - Graph& graph, - const Node& position_gather_node, - const NodeArg* input_ids, - const logging::Logger& logger, - std::vector& subgraph_node_indices) { - subgraph_node_indices.clear(); - - // Match Gather <-- Expand <-- Unsqueeze <-- Range <-- Cast <-- Gather - // Since Range is from opset 11, we only match opset 11 here. - std::vector position_parent_nodes; - std::vector position_embedding_path_symbolic{ - {0, 1, "Expand", {8}, kOnnxDomain}, - {0, 0, "Unsqueeze", {11}, kOnnxDomain}, - {0, 0, "Range", {11}, kOnnxDomain}, - {0, 1, "Cast", {9}, kOnnxDomain}, - {0, 0, "Gather", {11}, kOnnxDomain}}; - std::vector edges; - - if (!graph_utils::FindPath(position_gather_node, true, position_embedding_path_symbolic, edges, logger)) { - // Cast node might be removed by other optimizer. Here we check a pattern without Cast node. - std::vector position_embedding_path_no_cast{ - {0, 1, "Expand", {8}, kOnnxDomain}, - {0, 0, "Unsqueeze", {11}, kOnnxDomain}, - {0, 0, "Range", {11}, kOnnxDomain}, - {0, 1, "Gather", {11}, kOnnxDomain}}; - if (!graph_utils::FindPath(position_gather_node, true, position_embedding_path_no_cast, edges, logger)) { - DEBUG_LOG("Failed to find path 1."); - return false; - } - } - - size_t last_edge = edges.size() - 1; - for (size_t i = 0; i < edges.size(); i++) { - if (!optimizer_utils::CheckOutputEdges(graph, edges[i]->GetNode(), (i == last_edge ? 2u : 1u))) { - DEBUG_LOG("Output edge count not expected for nodes in path 1."); - return false; - } - } - - Node& expand_node = *graph.GetNode(edges[0]->GetNode().Index()); - Node& range_node = *graph.GetNode(edges[2]->GetNode().Index()); - Node& gather_node_1 = *graph.GetNode(edges[last_edge]->GetNode().Index()); - if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(range_node.InputDefs()[0]), int64_t(0), true)) { - DEBUG_LOG("The first input of Range should be a constant with value 0."); - return false; - } - if (!optimizer_utils::IsInitializerWithExpectedValue(graph, *(range_node.InputDefs()[2]), int64_t(1), true)) { - DEBUG_LOG("The third input of Range should be a constant with value 1."); - return false; - } - - if (!MatchPositionSubgraph(graph, expand_node, input_ids, logger, subgraph_node_indices, gather_node_1.Index())) { - DEBUG_LOG("Failed to match position subgraph."); - return false; - } - - AddNodes(subgraph_node_indices, edges); - - return true; -} - static bool MatchPositionEmbeddingSubgraph( Graph& graph, const Node& add_node, @@ -422,10 +398,8 @@ static bool MatchPositionEmbeddingSubgraph( } } } else { - if (!MatchPositionEmbeddingSubgraph1(graph, position_gather_node, input_ids, logger, subgraph_node_indices)) { - if (!MatchPositionEmbeddingSubgraph2(graph, position_gather_node, input_ids, logger, subgraph_node_indices)) { - return false; - } + if (!MatchPositionEmbeddingSubgraphsFromGather(graph, position_gather_node, input_ids, logger, subgraph_node_indices)) { + return false; } } diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 94be6c3469..c59c6f34f6 100644 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -2424,6 +2424,35 @@ TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat5) { } } +TEST_F(GraphTransformationTests, EmbedLayerNormFusionFormat6) { + auto model_uri = MODEL_FOLDER "fusion/embed_layer_norm_format6.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_); + ASSERT_TRUE(ret.IsOK()); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Shape"], 0); + EXPECT_EQ(op_to_count["Expand"], 0); + EXPECT_EQ(op_to_count["Gather"], 0); + EXPECT_EQ(op_to_count["Unsqueeze"], 0); + EXPECT_EQ(op_to_count["Reshape"], 0); + EXPECT_EQ(op_to_count["Equal"], 0); + EXPECT_EQ(op_to_count["Where"], 0); + EXPECT_EQ(op_to_count["LayerNormalization"], 0); + EXPECT_EQ(op_to_count["SkipLayerNormalization"], 0); + EXPECT_EQ(op_to_count["ReduceSum"], 0); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["Add"], 2); + EXPECT_EQ(op_to_count["Cast"], 3); + EXPECT_EQ(op_to_count["Attention"], 1); + EXPECT_EQ(op_to_count["EmbedLayerNormalization"], 1); +} + TEST_F(GraphTransformationTests, DynamicQuantizeMatMulTest) { auto model_uri = MODEL_FOLDER "fusion/dynamic_quantize_matmul.onnx"; std::shared_ptr p_model; diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format6.onnx b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_format6.onnx new file mode 100644 index 0000000000000000000000000000000000000000..dd510762244fa54a7552be7bbe2f7e2ccfa24ffd GIT binary patch literal 2461 zcmb_e%Wm676cuGjB(H2)o;a=1Md40?1R0=#r6hIh)=td8Eea%A*bM??LyX0VMOY+d zaoD(7WzkRQCSBzN`WxLj-;~aSk0EWNT||(C`_w)6Njdh;pEuwZxM4icxE}^7f^Sv|N#Y6R8n&i!@3E5UWW#hrB-@7X!7zp81yfz{!#L!qUvRQLcFth?P}{&5OalOR zmeHB^bTuvgx{Acz1F3~ws8j&2MIV{+sQnajwL-Zo5_exnQ}_<-l+y+Gw~98abioeT zj8lJv+y#wKm`b~=NZdV?5_kxe%BZ<+Y*q8|wKd;1YMghxM2Uc%&uJX!a%&p*4wMXj z0C%3GK8uc7Ac|BPqJb3a`Bjck@#V1WiXtV8Aon&lXH|&}@4g`@-G73n9W`(cjs^kCR2@zh>ddl*6haIu2=yR+K=Q$V}Xl#psFjekPphEPO+Gn8ajVZoOE9 zkHyOp8$pK9WAjw#Z~pzy?d|uL1=|4nWJNSRMJw{#pU_lqms_&9^H>Op3vHfk^yG5= z3Uy;NP{!hxv2)$>ipTCD3$h72%A)cd2|PAg1?QBpc@ReV%FdylCiKrvf6Ggff+HDO z3}^&HHsu=fk|40V&$(!#tiur8IE(x#qd~gS7QlB_CP{QR3MY#sOn#=WPMKh0)v2Mn!hq{1a|h+&fsOM=ji{ze4Q!bWatF7%Shxfm;R^x1)2 zrm~~{E7?~5Eg25&9joP44+|}^D#?fpv`ukGG;evu1tIsHw$m9}c(c`d`>@q&JrZy6 zc`H9P-ZM6n96253b;+vf^W?lI+)DYTuXml$ey_DY+tYKQ46 bljE<|;;9WnpTkYZLTj}+B38$1510Q1ntJcX literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py index f7b9cd30ab..97e7abfb55 100644 --- a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py @@ -143,7 +143,86 @@ def GenerateModel5(model_name): model = helper.make_model(graph) onnx.save(model, model_name) +def GenerateModel6(model_name): + nodes = [ # LayerNorm subgraph + helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"), + helper.make_node("Gather", ["shape1_out", "indices_0"], ["gather0_out"], "gather0"), + helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), + helper.make_node("Shape", ["input_ids"], ["shape2_out"], "shape2"), + helper.make_node("Gather", ["shape2_out", "indices_1"], ["gather1_out"], "gather1"), + helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), + helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out"], ["concat_out"], "concat", axis=0), + + helper.make_node("Reshape", ["concat_out", "reshape_init"], ["reshape_out"], "reshape"), + helper.make_node("Equal", ["reshape_out", "equal_init"], ["equal_out"], "equal"), + helper.make_node("Where", ["equal_out", "where_init", "reshape_out"], ["where_out"], "where"), + + helper.make_node("Range", ["start_0", "gather1_out", "delta_1"], ["range_out"], + "range"), + helper.make_node("Unsqueeze", ["range_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), + + helper.make_node("Expand", ["unsqueeze2_out", "where_out"], ["expand_out"], "expand"), + helper.make_node("Gather", ["pos_embed", "expand_out"], ["pos_gather_out"], "pos_gather"), + helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather"), + helper.make_node("Add", ["word_gather_out", "pos_gather_out"], ["word_add_pos_out"], "word_add_pos"), + helper.make_node("Gather", ["seg_embed", "segment_ids"], ["seg_gather_out"], "seg_gather"), + helper.make_node("Add", ["word_add_pos_out", "seg_gather_out"], ["add3_out"], "add3"), + helper.make_node("LayerNormalization", ["add3_out", "layer_norm_weight", "layer_norm_bias"], ["layernorm_out"], + "layernorm", + axis=-1, + epsion=0.000009999999747378752), + helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6), + helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0), + helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"], + "att", + domain="com.microsoft", + num_heads=2), + helper.make_node("MatMul", ["att_out", "matmul_weight"], ["matmul_out"], "matmul"), + helper.make_node("Add", ["matmul_out", "add_bias"], ["add_out"], "add"), + helper.make_node("Add", ["add_out", "layernorm_out"], ["add2_out"], "add2") + ] + + + # hidden_size=4, num_heads=2, max_seq_length=3 + initializers = [ # initializers + helper.make_tensor('indices_0', TensorProto.INT64, [], [0]), + helper.make_tensor('indices_1', TensorProto.INT64, [], [1]), + helper.make_tensor('start_0', TensorProto.INT64, [], [0]), + helper.make_tensor('delta_1', TensorProto.INT64, [], [1]), + helper.make_tensor('word_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('pos_embed', TensorProto.FLOAT, [4, 4], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('seg_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [4], [1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 4], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('qkv_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor('matmul_weight', TensorProto.FLOAT, [4, 4], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), + helper.make_tensor('add_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor('reshape_init', TensorProto.INT64, [1], [-1]), + helper.make_tensor('equal_init', TensorProto.INT64, [2], [-1, -1]), + helper.make_tensor('where_init', TensorProto.INT64, [2], [1, 1]), + ] + + graph = helper.make_graph( + nodes, + "EmbedLayerNorm_format6", #name + [ # inputs + helper.make_tensor_value_info('input_ids', TensorProto.INT64, ['batch', 3]), + helper.make_tensor_value_info('segment_ids', TensorProto.INT64, ['batch', 3]), + helper.make_tensor_value_info('input_mask', TensorProto.INT64, ['batch', 3]), + ], + [ # outputs + helper.make_tensor_value_info('add2_out', TensorProto.FLOAT, ['batch', 3, 4]), + ], + initializers) + + model = helper.make_model(graph) + onnx.save(model, model_name) GenerateModel3('embed_layer_norm_format3.onnx', True) GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False) GenerateModel5('embed_layer_norm_format5.onnx') +GenerateModel6('embed_layer_norm_format6.onnx') \ No newline at end of file From 0404763f23e4d2e1b06128b3545ad7c1042a7a66 Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Tue, 30 Jun 2020 14:30:59 -0700 Subject: [PATCH 12/13] Update function body initialization for ONNX functions (#4332) * Update function body initialization * minor fix * changes per review comments * minor fix * format fix * add function initialization in mixed precision transformer * more updates * more fixes --- include/onnxruntime/core/graph/graph.h | 21 ++++++- .../core/framework/graph_partitioner.cc | 7 ++- onnxruntime/core/graph/graph.cc | 58 ++++++++++++------- .../core/graph/mixed_precision_transformer.cc | 9 ++- 4 files changed, 68 insertions(+), 27 deletions(-) diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 5f41d7f819..69632df596 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -101,7 +101,21 @@ class Node { /** Gets the Node's Node::Type. */ Node::Type NodeType() const noexcept; - /** Gets the function body if the #NodeType is fused, or nullptr if not. */ + /** + Gets the function body if applicable otherwise nullptr + @param try_init_func_body If not already intialized, initialize the function body + (only applicable to operators which are defined as function in ONNX spec). + Function body can be initialized in 2 cases : + 1. For nodes of type "Fused" + 2. For nodes which are defined as functions in ONNX spec (example: DynamicQuantizeLinear) + For all other cases this will always return nullptr. + Nodes of type "Fused" are created during partitioning and the function body + initialization for such nodes also happens during node creation. Therefore, + initialization of function body will happen via this method only in case 2 mentioned above. + */ + const Function* GetFunctionBody(bool try_int_func_body = true) noexcept; + + /** Gets the function body if applicable otherwise nullptr. */ const Function* GetFunctionBody() const noexcept; /** Gets the node description. */ @@ -779,7 +793,10 @@ class Graph { @param node Node with Node::Type of Node::Type::Fused @returns Status indicating success or providing an error message. */ - Status InlineFunction(Node& node); + Status InlineFunction(Node& node); + + /** Initialize function body for the given node */ + void InitFunctionBodyForNode(Node& node); /** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will be used as a GraphProto attribute in another Node.. diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 3466ee8a47..1a604048ba 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -197,12 +197,13 @@ Status GraphPartitioner::Partition(Graph& graph, bool export_dll, FuncManager& f if (nullptr == node_func) { continue; } - nodes_need_inline.push_back(&node); + nodes_need_inline.push_back(&node); } - } + } + for (auto* node : nodes_need_inline) { // If the node has a functionbody with no kernel and cannot be inlined - // it is a invalid function + // it is an invalid function ORT_RETURN_IF_ERROR(graph.InlineFunction(*node)); } diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 8d6fc7457a..f6cca85ff5 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -178,8 +178,8 @@ bool NodeArg::HasTensorOrScalarShape() const { const auto type_case = type->value_case(); switch (type_case) { case TypeProto::kTensorType: - case TypeProto::kSparseTensorType: - // Standard tensor has a valid shape field while + case TypeProto::kSparseTensorType: + // Standard tensor has a valid shape field while // scalar's shape is empty. Thus, we don't need to // check shape here. return true; @@ -433,6 +433,19 @@ void Node::SetNodeType(Node::Type node_type) noexcept { node_type_ = node_type; } +const Function* Node::GetFunctionBody(bool try_init_func_body) noexcept { + if (nullptr != func_body_) { + return func_body_; + } + + // Initialize function body + if (try_init_func_body) { + graph_->InitFunctionBodyForNode(*this); + } + + return func_body_; +} + const Function* Node::GetFunctionBody() const noexcept { return func_body_; } @@ -1314,7 +1327,6 @@ void Graph::ReverseDFSFrom(const std::vector& from, const std::function& enter, const std::function& leave, const std::function& comp) const { - ReverseDFSFrom(from, enter, leave, comp, {}); } @@ -2037,22 +2049,6 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) { node.op_ = nullptr; } - if (node.op_ && (node.op_->HasFunction() || node.op_->HasContextDependentFunction())) { - onnx::FunctionProto onnx_function_proto; - onnx::FunctionBodyBuildContextImpl function_body_ctx(node_proto); - if (node.op_->HasContextDependentFunction()) { - node.op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto); - } else { - onnx_function_proto = *(node.op_->GetFunction()); - } - - auto func_ptr = onnxruntime::make_unique(*this, node.Index(), onnx_function_proto, - logger_); - - function_container_.emplace_back(std::move(func_ptr)); - node.SetFunctionBody(*function_container_.back()); - } - if (!node.op_) { return Status(ONNXRUNTIME, FAIL, "Fatal error: " + node.OpType() + " is not a registered function/op"); } @@ -2100,6 +2096,26 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) { return Status::OK(); } +void Graph::InitFunctionBodyForNode(Node& node) { + if (node.op_ && (node.op_->HasFunction() || node.op_->HasContextDependentFunction())) { + onnx::FunctionProto onnx_function_proto; + if (node.op_->HasContextDependentFunction()) { + NodeProto node_proto; + node.ToProto(node_proto); + onnx::FunctionBodyBuildContextImpl function_body_ctx(node_proto); + node.op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto); + } else { + onnx_function_proto = *(node.op_->GetFunction()); + } + + auto func_ptr = onnxruntime::make_unique(*this, node.Index(), onnx_function_proto, + logger_); + + function_container_.emplace_back(std::move(func_ptr)); + node.SetFunctionBody(*function_container_.back()); + } +} + void Graph::FindAllSubgraphs(std::vector& subgraphs) { for (auto& node : Nodes()) { for (auto& subgraph : node.MutableSubgraphs()) { @@ -2398,8 +2414,8 @@ const std::vector& Graph::GetValueInfo() const noexcept { return value_info_; } -void Graph::AddValueInfo(const NodeArg* new_value_info){ - for(const auto* info : value_info_){ +void Graph::AddValueInfo(const NodeArg* new_value_info) { + for (const auto* info : value_info_) { ORT_ENFORCE(info->Name() != new_value_info->Name(), "Error: trying to add an existing value info."); } value_info_.push_back(new_value_info); diff --git a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc index 82b66eb7a4..4b00f4c0a6 100644 --- a/orttraining/orttraining/core/graph/mixed_precision_transformer.cc +++ b/orttraining/orttraining/core/graph/mixed_precision_transformer.cc @@ -376,7 +376,14 @@ Status TransformGraphForMixedPrecision(Graph& graph, const std::unordered_set& weights_to_train, bool use_fp16_initializer, std::unordered_map& fp32_weight_name_to_fp16_node_arg) { - // Stag 1: Convert whole graph including forward and backward to FP16 + // Stage 1: Convert whole graph including forward and backward to FP16 + // Initialize function body for all function nodes + // This is required to make sure after converting inputs\weights to FP16 + // the new NodeArg updates are correctly propagated to the function body nodes as well. + for (auto& node : graph.Nodes()) { + graph.InitFunctionBodyForNode(node); + } + // Insert Cast node to convert inputs from FP32 to FP16 for (const NodeArg* input : graph.GetInputs()) { if (input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { From 6365760906ebf7f2198a48b6612440b333dbd32e Mon Sep 17 00:00:00 2001 From: Sherlock Date: Tue, 30 Jun 2020 15:43:14 -0700 Subject: [PATCH 13/13] BiasDropoutFusion (#4167) * Implement BiasDropout Fusion and Kernel Dropout kernel for residual input BiasDropout Fusion to take residual input Fix BiasDropout Kernel Optimize DropoutGrad with 4 elements per thread * Add graph transformer UT * MLTypeCallDispatcher for RatioData * Use MLTypeDispatcher for ratio tensor * Handle traing_mode input for BiasDropout fusion * Add test case for missing ratio input * Replace using FinalizeNodeFusion * Make BiasDropout kernel template-less * Make DropoutGrad template-less * Make Dropout and TrainableDropout template-less * Regenerate onnx file for UT * Minior fix on divmod in BiasDropoutKernel * Adjust pt frontend test due to dropout randomnesss * Make dropout kernel opeartion in fp32 Co-authored-by: Sherlock Huang --- .../providers/cuda/cuda_execution_provider.cc | 20 +- onnxruntime/core/providers/cuda/nn/dropout.cc | 36 +-- onnxruntime/core/providers/cuda/nn/dropout.h | 62 +++-- .../core/providers/cuda/nn/dropout_impl.cu | 9 +- .../python/onnxruntime_test_ort_trainer.py | 12 +- .../fusion/bias_dropout_fusion1.onnx | Bin 0 -> 278 bytes .../fusion/bias_dropout_fusion2.onnx | Bin 0 -> 278 bytes .../fusion/bias_dropout_residual_fusion1.onnx | Bin 0 -> 356 bytes .../fusion/bias_dropout_residual_fusion2.onnx | Bin 0 -> 356 bytes ...bias_dropout_residual_fusion_mismatch.onnx | Bin 0 -> 338 bytes .../fusion/bias_dropout_residual_gen.py | 114 +++++++++ ...bias_trainabledropout_residual_fusion.onnx | Bin 0 -> 328 bytes .../core/graph/gradient_schema_defs.cc | 48 ++++ .../core/optimizer/bias_dropout_fusion.cc | 200 +++++++++++++++ .../core/optimizer/bias_dropout_fusion.h | 24 ++ .../core/optimizer/graph_transformer_utils.cc | 2 + .../test/optimizer/graph_transform_test.cc | 29 +++ .../training_ops/cpu/nn/dropout_op_test.cc | 165 ++++++++++-- .../cuda/cuda_training_kernels.cc | 235 +++++++----------- .../training_ops/cuda/nn/dropout.cc | 235 ++++++++++++------ .../training_ops/cuda/nn/dropout.h | 21 +- .../training_ops/cuda/nn/dropout_impl.cu | 117 ++++++++- .../training_ops/cuda/nn/dropout_impl.h | 13 + 23 files changed, 1026 insertions(+), 316 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/bias_dropout_fusion1.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/bias_dropout_fusion2.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_fusion1.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_fusion2.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_fusion_mismatch.onnx create mode 100644 onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_gen.py create mode 100644 onnxruntime/test/testdata/transform/fusion/bias_trainabledropout_residual_fusion.onnx create mode 100644 orttraining/orttraining/core/optimizer/bias_dropout_fusion.cc create mode 100644 orttraining/orttraining/core/optimizer/bias_dropout_fusion.h diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 18a0e1a7ed..d0cd9c3e1e 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -813,15 +813,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, int64_t, GatherND); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, MLFloat16_MLFloat16, Dropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, MLFloat16_float, Dropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, MLFloat16_double, Dropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, float_MLFloat16, Dropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, float_float, Dropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, float_double, Dropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, double_MLFloat16, Dropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, double_float, Dropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, double_double, Dropout); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Dropout); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum); static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { @@ -1341,15 +1333,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, }; diff --git a/onnxruntime/core/providers/cuda/nn/dropout.cc b/onnxruntime/core/providers/cuda/nn/dropout.cc index 199b987c81..56b22e3593 100644 --- a/onnxruntime/core/providers/cuda/nn/dropout.cc +++ b/onnxruntime/core/providers/cuda/nn/dropout.cc @@ -6,30 +6,18 @@ namespace onnxruntime { namespace cuda { -#define REGISTER_KERNEL_TYPED(T1, T2) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Dropout, \ - kOnnxDomain, \ - 12, \ - T1##_##T2, \ - kCudaExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(1) \ - .InputMemoryType(2), \ - Dropout); - -REGISTER_KERNEL_TYPED(MLFloat16, MLFloat16) -REGISTER_KERNEL_TYPED(MLFloat16, float) -REGISTER_KERNEL_TYPED(MLFloat16, double) -REGISTER_KERNEL_TYPED(float, MLFloat16) -REGISTER_KERNEL_TYPED(float, float) -REGISTER_KERNEL_TYPED(float, double) -REGISTER_KERNEL_TYPED(double, MLFloat16) -REGISTER_KERNEL_TYPED(double, float) -REGISTER_KERNEL_TYPED(double, double) +ONNX_OPERATOR_KERNEL_EX( + Dropout, + kOnnxDomain, + 12, + kCudaExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()) + .TypeConstraint("T1", DataTypeImpl::AllIEEEFloatTensorTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .InputMemoryType(1) + .InputMemoryType(2), + Dropout); } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/dropout.h b/onnxruntime/core/providers/cuda/nn/dropout.h index a9fc195738..0ef8456f44 100644 --- a/onnxruntime/core/providers/cuda/nn/dropout.h +++ b/onnxruntime/core/providers/cuda/nn/dropout.h @@ -12,10 +12,36 @@ namespace onnxruntime { namespace cuda { -template +template +struct GetRatioDataImpl { + void operator()(const Tensor* ratio, float& ratio_data) const { + ratio_data = static_cast(*(ratio->template Data())); + ORT_ENFORCE(ratio_data >= 0.0f && ratio_data < 1.0f, "ratio_data is outside range [0, 1)"); + } +}; + +template +struct DropoutComputeImpl { + void operator()(const cudaDeviceProp& prop, + const int64_t N, + const float ratio_data, + PhiloxGenerator& generator, + const Tensor& X, + Tensor& Y, + bool* mask_data) const { + typedef typename ToCudaType::MappedType CudaT; + + const CudaT* X_data = reinterpret_cast(X.template Data()); + CudaT* Y_data = reinterpret_cast(Y.template MutableData()); + + DropoutKernelImpl(prop, N, ratio_data, generator, X_data, Y_data, mask_data); + } +}; + +template class Dropout final : public CudaKernel { public: - Dropout(const OpKernelInfo& info) : CudaKernel(info), default_ratio_(0.5) { + Dropout(const OpKernelInfo& info) : CudaKernel(info) { int64_t seed = 0; if (info.GetAttr("seed", &seed).IsOK()) { generator_ = onnxruntime::make_unique(static_cast(seed)); @@ -26,48 +52,40 @@ class Dropout final : public CudaKernel { private: mutable std::unique_ptr generator_; - const float default_ratio_; + static constexpr float default_ratio_ = 0.5f; }; -template -Status Dropout::ComputeInternal(OpKernelContext* context) const { - typedef typename ToCudaType::MappedType CudaT; - +template +Status Dropout::ComputeInternal(OpKernelContext* context) const { //Get X_data const Tensor* X = context->Input(0); if (X == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "X Input is not available."); const TensorShape& shape = X->Shape(); - auto X_data = reinterpret_cast(X->template Data()); const int64_t N = shape.Size(); //Get Y_data auto Y = context->Output(0, shape); - auto Y_data = reinterpret_cast(Y->template MutableData()); //Get mask_data auto mask = context->Output(1, shape); ORT_ENFORCE(!mask || mask->Shape().Size() == N); //Get the ratio_data - float ratio_data; + float ratio_data = default_ratio_; auto ratio = context->Input(1); - - static_assert(std::is_same::value || std::is_same::value || std::is_same::value, - "T2 must be float16 or float or double"); - if (ratio) { - ratio_data = static_cast(*(ratio->template Data())); - } else { - ratio_data = default_ratio_; + utils::MLTypeCallDispatcher t_disp(ratio->GetElementType()); + t_disp.Invoke(ratio, ratio_data); } - ORT_ENFORCE(ratio_data >= 0.0f && ratio_data < 1.0f); const Tensor* training_mode = context->Input(2); //Check for inference mode. if ((0 == ratio_data /*Backward compat with TrainableDropout*/) || (!trainable_dropout && (training_mode == nullptr || *(training_mode->Data()) == false))) { + const void* X_data = X->DataRaw(); + void* Y_data = Y->MutableDataRaw(); if (Y_data != X_data) { - CUDA_CALL_THROW(cudaMemcpyAsync(Y_data, X_data, N * sizeof(T1), cudaMemcpyDeviceToDevice)); + CUDA_CALL_THROW(cudaMemcpyAsync(Y_data, X_data, X->SizeInBytes(), cudaMemcpyDeviceToDevice)); } // If mask is requested, return all 1s. @@ -85,8 +103,10 @@ Status Dropout::ComputeInternal(OpKernelContext* cont return temp_mask_buffer.get(); }(); - PhiloxGenerator& generator = generator_ != nullptr ? *generator_.get() : PhiloxGenerator::Default(); - DropoutKernelImpl(GetDeviceProp(), N, ratio_data, generator, X_data, Y_data, mask_data); + PhiloxGenerator& generator = generator_ ? *generator_ : PhiloxGenerator::Default(); + + utils::MLTypeCallDispatcher t_disp(X->GetElementType()); + t_disp.Invoke(GetDeviceProp(), N, ratio_data, generator, *X, *Y, mask_data); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/nn/dropout_impl.cu b/onnxruntime/core/providers/cuda/nn/dropout_impl.cu index d0ebc21851..0abe3f774a 100644 --- a/onnxruntime/core/providers/cuda/nn/dropout_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/dropout_impl.cu @@ -35,7 +35,7 @@ __global__ void DropoutKernel( T* Y_data, bool* mask_data) { const float p = 1.0f - ratio; - const T scale = T(1.0f / p); + const float scale = 1.0f / p; CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x; CUDA_LONG step_size = gridDim.x * blockDim.x * UNROLL; @@ -52,12 +52,13 @@ __global__ void DropoutKernel( // use of Philox_4x32_10 is to generate a multiple of 4 times number of threads. for (CUDA_LONG id = idx; id < rounded_size; id += step_size) { float4 rand = curand_uniform4(&state); - + + #pragma unroll for (CUDA_LONG i = 0; i < UNROLL; i++) { CUDA_LONG li = id + gridDim.x * blockDim.x * i; if (li < N) { mask_data[li] = (&rand.x)[i] < p; - Y_data[li] = X_data[li] * T(mask_data[li]) * scale; + Y_data[li] = T(float(X_data[li]) * mask_data[li] * scale); } } @@ -76,7 +77,7 @@ void DropoutKernelImpl( bool* mask_data) { const int block_size = 256; const int blocks_per_sm = prop.maxThreadsPerMultiProcessor / block_size; - const int grid_size = std::min(prop.multiProcessorCount * blocks_per_sm, static_cast(CeilDiv(N, block_size))); + const int grid_size = std::min(prop.multiProcessorCount * blocks_per_sm, static_cast(CeilDiv(N, block_size * UNROLL))); // Compute the number of random numbers generated by each thread, and increment philox generator offset by that amount. const uint64_t counter_offset = static_cast(((N - 1) / (block_size * grid_size * UNROLL) + 1) * UNROLL); diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py index 52790dd40a..8566c31577 100644 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py +++ b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py @@ -655,9 +655,7 @@ class TestOrtTrainer(unittest.TestCase): assert np.array_equal(state_dict[key], loaded_state_dict[key]) def testBertTrainingBasic(self): - expected_losses = [ - 11.02906322479248, 11.094074249267578, 11.00899887084961, 11.06129264831543, - 11.029067039489746, 11.040265083312988, 11.046793937683105, 10.993699073791504] + expected_losses = [11.034271, 11.125311, 11.006095, 11.046938, 11.027476, 11.015745, 11.060884, 10.971851] expected_eval_loss = [10.95898914] actual_losses, actual_eval_loss = runBertTrainingTest( gradient_accumulation_steps=1, use_mixed_precision=False, allreduce_post_accumulation=False) @@ -669,14 +667,12 @@ class TestOrtTrainer(unittest.TestCase): # print('eval_loss actual: ', actual_eval_loss) # import pdb; pdb.set_trace() - rtol = 1e-04 + rtol = 1e-03 assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch") def testBertTrainingGradientAccumulation(self): - expected_losses = [ - 11.02906322479248, 11.094074249267578, 11.008995056152344, 11.061283111572266, - 11.029059410095215, 11.04024887084961, 11.04680347442627, 10.993708610534668] + expected_losses = [11.034271, 11.125311, 11.006093, 11.046929, 11.027471, 11.015731, 11.060894, 10.971855] expected_eval_loss = [10.959011] actual_losses, actual_eval_loss = runBertTrainingTest( @@ -689,7 +685,7 @@ class TestOrtTrainer(unittest.TestCase): # print('eval_loss actual: ', actual_eval_loss) # import pdb; pdb.set_trace() - rtol = 1e-04 + rtol = 1e-03 assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch") diff --git a/onnxruntime/test/testdata/transform/fusion/bias_dropout_fusion1.onnx b/onnxruntime/test/testdata/transform/fusion/bias_dropout_fusion1.onnx new file mode 100644 index 0000000000000000000000000000000000000000..40ce7de83230249a800e1ab46242e6fb5da4f441 GIT binary patch literal 278 zcmd;J7vjm!%d5~$tw_u*$Vs*O%g80o#puYz=p@9En37@;pI=%c#R6g}F*~NDa5-_o zg}Jzk5=%1k9so7EMz#O| literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/bias_dropout_fusion2.onnx b/onnxruntime/test/testdata/transform/fusion/bias_dropout_fusion2.onnx new file mode 100644 index 0000000000000000000000000000000000000000..9075927dd0189e9f33733559a1ea469d716a12b4 GIT binary patch literal 278 zcmd;J7vjm!%d5~$tw_u*$Vs*O%g80o#puMv=qSXIn37@;pI=%c#R6g}F*~NDa5-_o zg}Jzk5=%1kUUs}8jA6vZ%+VkbB>?n>5H}YmPz@_sBnha?nSd@Q7A^(>9so7HMz#O| literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_fusion1.onnx b/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_fusion1.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ab104bf83105ad714097574b284db09e13e1c973 GIT binary patch literal 356 zcmaivJBz|V6oq%zNX8o%!c|ceEm8&q-EterC{|XMHpA!y88kDJ2Uza^Tm4T?eBoAh zmvg_vIfuGfH(DQu@lhm)ef;zH2|XYY5QbDpIpU_syH0rpl}Jihz*7b<6}d8eV|13o zSDp%`mEQ8i$QV7DnuGP0esmI%?Yi_XjRJRed%^%-IUq|U(`(xJnuZOgM8@wlpBI@j zI&AqQaL%X2guNCBn|>`Mg#GWa3hWr{Lyr<{gNxNd@A&v}d_#r)c2<75)`Oa5syD^D RmTz2Z%S_)e%z+QTegS$;S&aYy literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_fusion2.onnx b/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_fusion2.onnx new file mode 100644 index 0000000000000000000000000000000000000000..dec9a1423d97351e6b8a43a2af79f1e1b72fc6e4 GIT binary patch literal 356 zcmaivI}5@v6orje8gCU!R1`&rjs-y-U8GT*oE%+3G)00okv_oD|EK;ZQ(x%hb~*Pu zoO7^Es#>YjAUyHtcnBw#B6I;!KmoyAh#}Ki+OhJ>D+NNp56bmvq#$Kkw)rN@`Sx;w#X8xmEUxc zXWg#LDg7?sH{6Pk5dH13_VOh?=wd`#V3TsD_H6V%#!$lN_S!Xri31hGl%I-mt^T++ NhME0gs09VLdIxZSS&aYy literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_fusion_mismatch.onnx b/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_fusion_mismatch.onnx new file mode 100644 index 0000000000000000000000000000000000000000..12559ee280afe93eea8ac5f774d4fc958a798769 GIT binary patch literal 338 zcmah^y9&ZE6g66DdaF>Pq9{6aEC}l8B1Um?a_kbKDH61a^Z}0kmj9stWa|q%xn0gZ zPi&H!R_Z(q&OAIFgXy&lJwO!DoM0}*i0L%x8u{f`=0d<4zA)GkPo!o$t>OggNyMd+ zYR^I~0%D6upE7N(Cq~HQLl?iHk!vnrPZ$8&K!J)D&$RI~jWnp_F?+MwDvhO9UdutQ zalfr+^tXW5bSw`cTI}Jw`I0{LFd}U*Ni|gmHvSwFsNhS-ce2661}o!KzsRP5Y%`z5 LA0KK!$*Da6={Qv; literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_gen.py b/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_gen.py new file mode 100644 index 0000000000..3b1f9f01fc --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_gen.py @@ -0,0 +1,114 @@ +import onnx +from onnx import helper +from onnx import TensorProto, OperatorSetIdProto + +# inputs/outputs +A = helper.make_tensor_value_info('A', TensorProto.FLOAT, ['unk_1', 'unk_2', 3072]) +B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [3072]) +R = helper.make_tensor_value_info('R', TensorProto.FLOAT, ['unk_1', 'unk_2', 3072]) +C = helper.make_tensor_value_info('C', TensorProto.FLOAT, ['unk_1', 'unk_2', 3072]) +mask = helper.make_tensor_value_info('mask', TensorProto.BOOL, ['unk_1', 'unk_2', 3072]) + +# initializers +ratio = helper.make_tensor('ratio_const', TensorProto.FLOAT, [], [0.8]) +training_mode = helper.make_tensor('training_mode', TensorProto.BOOL, [], [1]) + +opsets = [] +onnxdomain = OperatorSetIdProto() +onnxdomain.version = 12 +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +opsets.append(onnxdomain) + +kwargs={} +kwargs['opset_imports'] = opsets + +# Create the model (ModelProto) +bias = helper.make_node("Add", ["A", "B"], ["add0_out"], "add0") +dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["C", "mask"], "dropout0") + +graph = helper.make_graph( + [bias, dropout_12], + "Bias_Dropout_Fusion", #name + [A, B], + [C], + [ratio, training_mode]) + +model = helper.make_model(graph, producer_name='onnx-example', **kwargs) +onnx.save(model, 'bias_dropout_fusion1.onnx') + +# Create the model (ModelProto) +bias = helper.make_node("Add", ["B", "A"], ["add0_out"], "add0") +dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["C", "mask"], "dropout0") + +graph = helper.make_graph( + [bias, dropout_12], + "Bias_Dropout_Fusion", #name + [A, B], + [C], + [ratio, training_mode]) + +model = helper.make_model(graph, producer_name='onnx-example', **kwargs) +onnx.save(model, 'bias_dropout_fusion2.onnx') + + +# Create the model (ModelProto) +bias = helper.make_node("Add", ["A", "B"], ["add0_out"], "add0") +dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["dropout_out", "mask"], "dropout0") +residual = helper.make_node("Add", ["dropout_out", "R"], ["C"], "add1") + +graph = helper.make_graph( + [bias, dropout_12, residual], + "Bias_Dropout_Fusion", #name + [A, B, R], + [C], + [ratio, training_mode]) + +model = helper.make_model(graph, producer_name='onnx-example', **kwargs) +onnx.save(model, 'bias_dropout_residual_fusion1.onnx') + +# Create the model (ModelProto) +bias = helper.make_node("Add", ["B", "A"], ["add0_out"], "add0") +dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["dropout_out", "mask"], "dropout0") +residual = helper.make_node("Add", ["R", "dropout_out"], ["C"], "add1") + +graph = helper.make_graph( + [bias, dropout_12, residual], + "Bias_Dropout_Fusion", #name + [A, B, R], + [C], + [ratio, training_mode]) + +model = helper.make_model(graph, producer_name='onnx-example', **kwargs) +onnx.save(model, 'bias_dropout_residual_fusion2.onnx') + +# Create the model (ModelProto) +R_mismatch = helper.make_tensor_value_info('R', TensorProto.FLOAT, [3072]) + +bias = helper.make_node("Add", ["B", "A"], ["add0_out"], "add0") +dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["dropout_out", "mask"], "dropout0") +residual = helper.make_node("Add", ["R", "dropout_out"], ["C"], "add1") + +graph = helper.make_graph( + [bias, dropout_12, residual], + "Bias_Dropout_Fusion", #name + [A, B, R_mismatch], + [C], + [ratio, training_mode]) + +model = helper.make_model(graph, producer_name='onnx-example', **kwargs) +onnx.save(model, 'bias_dropout_residual_fusion_mismatch.onnx') + +# Create the model (ModelProto) +bias = helper.make_node("Add", ["B", "A"], ["add0_out"], "add0") +trainable_dropout = helper.make_node("TrainableDropout", ["add0_out", "ratio_const"], ["dropout_out", "mask"], "dropout0") +residual = helper.make_node("Add", ["R", "dropout_out"], ["C"], "add1") + +graph = helper.make_graph( + [bias, trainable_dropout, residual], + "Bias_Dropout_Fusion", #name + [A, B, R], + [C], + [ratio]) + +model = helper.make_model(graph, producer_name='onnx-example', **kwargs) +onnx.save(model, 'bias_trainabledropout_residual_fusion.onnx') \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/fusion/bias_trainabledropout_residual_fusion.onnx b/onnxruntime/test/testdata/transform/fusion/bias_trainabledropout_residual_fusion.onnx new file mode 100644 index 0000000000000000000000000000000000000000..6d24631a1c0a5209734c06c91f14eba154fb4180 GIT binary patch literal 328 zcmd;J7vjm!%d5~$tw_u*$Vs(Y&%`Co#puMv=qSXIn37@;pI=%c#R6g}F*~NDaJg~8 zg}Jzk5=%1kd_X+3F}IkZWbRq&XOc z#JSjoSWEM=;|*a9BQ9o+28k#Epr?emxj2DpSRo=nT|oqNB>~kq6Hw#C!o?uK0{{j3 BQC$E4 literal 0 HcmV?d00001 diff --git a/orttraining/orttraining/core/graph/gradient_schema_defs.cc b/orttraining/orttraining/core/graph/gradient_schema_defs.cc index 31056a5bc4..72702786e4 100644 --- a/orttraining/orttraining/core/graph/gradient_schema_defs.cc +++ b/orttraining/orttraining/core/graph/gradient_schema_defs.cc @@ -1010,6 +1010,54 @@ Example 4: "Constrain indices to integer types") .SetDoc(R"DOC(SoftmaxCrossEntropyLossGrad)DOC"); + ONNX_CONTRIB_OPERATOR_SCHEMA(BiasDropout) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc("BiasDropout") + .Attr("seed", "(Optional) Seed to the random generator, if not specified we will auto generate one.", AttributeProto::INT, OPTIONAL_VALUE) + .AllowUncheckedAttributes() + .Input(0, "data", "The input data as Tensor.", "T") + .Input(1, "bias", "The bias input, a vector with the same shape as last dim of data", "T") + .Input(2, "residual", "The residual input, must have the same shape as data", "T", OpSchema::Optional) + .Input(3, "ratio", + "The ratio of random dropout, with value in [0, 1). If this input was not set, " + "or if it was set to 0, the output would be a simple copy of the input. " + "If it's non-zero, output will be a random dropout of input, which is typically " + "the case during training.", + "T1", + OpSchema::Optional) + .Input(4, "training_mode", + "If set to true then it indicates dropout is being used for " + "training. It is an optional value hence unless specified explicitly, it is false. " + "If it is false, ratio is ignored and the operation mimics inference mode where nothing " + "will be dropped from the input data and if mask is requested as output it will contain " + "all ones.", + "T2", + OpSchema::Optional) + .Output(0, "output", "The output.", "T") + .Output(1, "mask", "The output mask of dropout.", "T2", OpSchema::Optional) + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input and output types to float tensors.") + .TypeConstraint( + "T1", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input 'ratio' types to float tensors.") + .TypeConstraint( + "T2", + {"tensor(bool)"}, + "Constrain output 'mask' types to boolean tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateShapeAndTypeFromFirstInput(ctx); + if (ctx.getNumOutputs() == 2) { + updateOutputElemType(ctx, 1, ONNX_NAMESPACE::TensorProto::BOOL); + if (hasNInputShapes(ctx, 1)) { + propagateShapeFromInputToOutput(ctx, 0, 1); + } + } + }); + ONNX_CONTRIB_OPERATOR_SCHEMA(TrainableDropout) .SetDomain(kOnnxDomain) .SinceVersion(9) diff --git a/orttraining/orttraining/core/optimizer/bias_dropout_fusion.cc b/orttraining/orttraining/core/optimizer/bias_dropout_fusion.cc new file mode 100644 index 0000000000..d0e467c35a --- /dev/null +++ b/orttraining/orttraining/core/optimizer/bias_dropout_fusion.cc @@ -0,0 +1,200 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/initializer.h" +#include "orttraining/core/optimizer/bias_dropout_fusion.h" +#include "core/graph/graph_utils.h" +#include + +using namespace ONNX_NAMESPACE; +using namespace ::onnxruntime::common; +namespace onnxruntime { + +void FuseResidualAddIfAny(Graph& graph, const Node& dropout_node, + std::vector& dropout_input, + std::vector& dropout_output, + std::vector>& nodes_to_fuse) { + bool has_residual_add = false; + for (auto last_node_itr = dropout_node.OutputNodesBegin(); last_node_itr != dropout_node.OutputNodesEnd(); ++last_node_itr) { + const Node& last_node = (*last_node_itr); + + if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Add", {7}) && + last_node.GetExecutionProviderType() == dropout_node.GetExecutionProviderType()) { + const TensorShapeProto* input1_shape = last_node.InputDefs()[0]->Shape(); + const TensorShapeProto* input2_shape = last_node.InputDefs()[1]->Shape(); + + if (input1_shape == nullptr || + input2_shape == nullptr || + input1_shape->dim_size() < 1 || + input2_shape->dim_size() < 1 || + input1_shape->dim_size() != input2_shape->dim_size()) { + continue; + } + + // Inputs of Residual Add must match in shape + bool match = true; + for (int i = 0; i < input1_shape->dim_size(); ++i) { + match &= ONNX_NAMESPACE::operator==(input1_shape->dim(i), input2_shape->dim(i)); + } + if (!match) { + continue; + } + + // dropout's output is not part of of graph output + if (!graph.GetNodeOutputsInGraphOutputs(dropout_node).empty()) { + continue; + } + + Node& residual_add_node = *graph.GetNode(last_node.Index()); + const std::string& dropout_output_name = dropout_node.OutputDefs()[0]->Name(); + if (dropout_output_name == residual_add_node.InputDefs()[0]->Name()) { + dropout_input.push_back(residual_add_node.MutableInputDefs()[1]); // residual + } else if (dropout_output_name == residual_add_node.InputDefs()[1]->Name()) { + dropout_input.push_back(residual_add_node.MutableInputDefs()[0]); // residual + } + + dropout_output[0] = residual_add_node.MutableOutputDefs()[0]; + + nodes_to_fuse.push_back(residual_add_node); + has_residual_add = true; + break; + } + } + + if (!has_residual_add) { + NodeArg& dummy = graph.GetOrCreateNodeArg("", nullptr); + dropout_input.push_back(&dummy); // add a dummy residual + } +} + +Status BiasDropoutFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + for (auto node_index : node_topology_list) { + auto* node_ptr = graph.GetNode(node_index); + if (nullptr == node_ptr) + continue; // node was removed + + auto& node = *node_ptr; + + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + std::vector> nodes_to_fuse; + + // matching for bias Add node + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Add", {7}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders()) || + node.GetOutputEdgesCount() != 1) { + continue; + } + + std::vector dropout_input, dropout_output; + const TensorShapeProto* input1_shape = node.MutableInputDefs()[0]->Shape(); + const TensorShapeProto* input2_shape = node.MutableInputDefs()[1]->Shape(); + + if (input1_shape == nullptr || + input2_shape == nullptr || + input1_shape->dim_size() < 1 || + input2_shape->dim_size() < 1) { + continue; + } + + int last_dim_shape1 = input1_shape->dim_size() - 1; + int last_dim_shape2 = input2_shape->dim_size() - 1; + if (!utils::HasDimValue(input1_shape->dim(last_dim_shape1)) || + !utils::HasDimValue(input2_shape->dim(last_dim_shape2)) || + input1_shape->dim(last_dim_shape1).dim_value() != input2_shape->dim(last_dim_shape2).dim_value()) { + continue; + } + + if (input1_shape->dim_size() == 1) { + dropout_input.push_back(node.MutableInputDefs()[1]); // dropout input + dropout_input.push_back(node.MutableInputDefs()[0]); // bias + } else if (input2_shape->dim_size() == 1) { + dropout_input.push_back(node.MutableInputDefs()[0]); // dropout input + dropout_input.push_back(node.MutableInputDefs()[1]); // bias + } else { + continue; + } + Node& add_node = node; + nodes_to_fuse.push_back(add_node); + + // matching for Dropout node + auto next_node_itr = node.OutputNodesBegin(); + if (next_node_itr == node.OutputNodesEnd()) { + continue; + } + + const Node& next_node = (*next_node_itr); + if (!(graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Dropout", {12}, kOnnxDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "TrainableDropout", {9}, kOnnxDomain)) || + next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { + continue; + } + + if (!graph.GetNodeOutputsInGraphOutputs(node).empty()) { + continue; + } + + Node& dropout_node = *graph.GetNode(next_node.Index()); + nodes_to_fuse.push_back(dropout_node); + + dropout_output.push_back(dropout_node.MutableOutputDefs()[0]); + dropout_output.push_back(dropout_node.MutableOutputDefs()[1]); + + FuseResidualAddIfAny(graph, dropout_node, dropout_input, dropout_output, nodes_to_fuse); + + if (dropout_node.InputDefs().size() > 1) { + dropout_input.push_back(dropout_node.MutableInputDefs()[1]); // ratio + } + + // populate training_mode + bool is_trainable_dropout = (dropout_node.OpType() == "TrainableDropout"); + if (is_trainable_dropout) { + // Create training_mode initializer + ONNX_NAMESPACE::TensorProto training_mode_initializer; + training_mode_initializer.set_name(graph.GenerateNodeArgName("training_mode")); + training_mode_initializer.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_BOOL); + const bool data = true; + training_mode_initializer.set_raw_data(&data, sizeof(bool)); + + NodeArg& training_mode_node_arg = graph_utils::AddInitializer(graph, training_mode_initializer); + dropout_input.push_back(&training_mode_node_arg); + } else { + if (dropout_node.InputDefs().size() > 2) { + dropout_input.push_back(dropout_node.MutableInputDefs()[2]); + } + } + + const std::string op_type = "BiasDropout"; + Node& dropout_add_fusion_node = graph.AddNode(graph.GenerateNodeName(op_type), + op_type, + "fused Add and Dropout", + dropout_input, + dropout_output, + {}, + kMSDomain); + + // Get attribute "seed" from "Dropout" node if available. + NodeAttributes dropout_attrs = dropout_node.GetAttributes(); + NodeAttributes::const_iterator seed = dropout_attrs.find("seed"); + if (seed != dropout_attrs.end()) { + dropout_add_fusion_node.AddAttribute("seed", seed->second); + } + + // Assign provider to this new node. Provider should be same as the provider for old node. + dropout_add_fusion_node.SetExecutionProviderType(dropout_node.GetExecutionProviderType()); + + // delete bias_add_node, dropout_node and optionally residual_add_node + for (Node& n : nodes_to_fuse) { + graph_utils::RemoveNodeOutputEdges(graph, n); + graph.RemoveNode(n.Index()); + } + + modified = true; + } + + return Status::OK(); +} +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/bias_dropout_fusion.h b/orttraining/orttraining/core/optimizer/bias_dropout_fusion.h new file mode 100644 index 0000000000..ae16619541 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/bias_dropout_fusion.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** +@Class BiasDropoutFusion + +Fuse Add + Dropout + optional Add to BiasDropoutFusion + +*/ +class BiasDropoutFusion : public GraphTransformer { + public: + BiasDropoutFusion(const std::unordered_set& compatible_execution_providers = {}) noexcept + : GraphTransformer("BiasDropoutFusion", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 8cbc56d970..a9ac30528a 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -6,6 +6,7 @@ #include "orttraining/core/framework/distributed_run_context.h" #include "orttraining/core/optimizer/insert_output_rewriter.h" #include "orttraining/core/optimizer/megatron_transformer.h" +#include "orttraining/core/optimizer/bias_dropout_fusion.h" #include "orttraining/core/optimizer/nonzero_shape_setter.h" #include "core/optimizer/identity_elimination.h" #include "core/optimizer/slice_elimination.h" @@ -143,6 +144,7 @@ std::vector> GenerateTransformers(TransformerL transformers.emplace_back(onnxruntime::make_unique(l1_execution_providers)); transformers.emplace_back(onnxruntime::make_unique(free_dimension_overrides)); transformers.emplace_back(onnxruntime::make_unique(l1_execution_providers)); + transformers.emplace_back(onnxruntime::make_unique(l1_execution_providers)); rule_transformer = optimizer_utils::GenerateRuleBasedGraphTransformer(level, transformers_and_rules_to_enable, l1_execution_providers); } break; diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index b41b01609f..2f4e119ec0 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -10,11 +10,13 @@ #include "gtest/gtest.h" #include "core/optimizer/rule_based_graph_transformer.h" #include "core/optimizer/utils.h" +#include "orttraining/core/optimizer/bias_dropout_fusion.h" #include "orttraining/core/optimizer/gist_encode_decode.h" #include "orttraining/core/optimizer/nonzero_shape_setter.h" #include "orttraining/core/optimizer/megatron_transformer.h" #include "test/optimizer/graph_transform_test_fixture.h" #include "test/util/include/default_providers.h" +#include "test/util/include/asserts.h" #include "orttraining/test/optimizer/horizontal_parallel_test_utils.h" #include @@ -45,6 +47,33 @@ TEST_F(GraphTransformationTests, GistEncodeDecode) { ASSERT_TRUE(op_to_count["GistBinarizeEncoder"] == op_to_count["GistBinarizeEncoder"]); } +static void TestBiasDropoutFusion(const PathString& file_path, const logging::Logger& logger, const int add_count = 0) { + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(file_path, p_model, nullptr, logger).IsOK()); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + graph_transformation_mgr.Register(onnxruntime::make_unique(), TransformerLevel::Level2); + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, logger); + ASSERT_STATUS_OK(ret); + + std::map op_to_count = CountOpsInGraph(graph); + + ASSERT_EQ(op_to_count["Add"], add_count); + ASSERT_EQ(op_to_count["Dropout"], 0); + ASSERT_EQ(op_to_count["TrainableDropout"], 0); + ASSERT_EQ(op_to_count["BiasDropout"], 1); +} + +TEST_F(GraphTransformationTests, BiasDropoutFusionTest) { + TestBiasDropoutFusion(MODEL_FOLDER "fusion/bias_dropout_fusion1.onnx", *logger_); + TestBiasDropoutFusion(MODEL_FOLDER "fusion/bias_dropout_fusion2.onnx", *logger_); + TestBiasDropoutFusion(MODEL_FOLDER "fusion/bias_dropout_residual_fusion1.onnx", *logger_); + TestBiasDropoutFusion(MODEL_FOLDER "fusion/bias_dropout_residual_fusion2.onnx", *logger_); + TestBiasDropoutFusion(MODEL_FOLDER "fusion/bias_dropout_residual_fusion_mismatch.onnx", *logger_, 1); + TestBiasDropoutFusion(MODEL_FOLDER "fusion/bias_trainabledropout_residual_fusion.onnx", *logger_); +} + Node* GetNodeByName(Graph& graph, std::string node_name) { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); diff --git a/orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc b/orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc index da3ba31a01..70fff78a5d 100644 --- a/orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc @@ -14,6 +14,7 @@ #include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" @@ -33,7 +34,7 @@ const Tensor& FetchTensor(const OrtValue& ort_value) { return ort_value.Get(); } -void RunDropoutTest(const char* op, const bool use_mask, const std::vector& input_shape, float ratio = -1, +void RunDropoutTest(const char* op, const bool use_mask, const std::vector& input_shape, float ratio = -1.0f, bool training_mode = true, bool use_float16_ratio = false) { OpTester t{op, k_dropout_opset_version, kOnnxDomain}; @@ -45,13 +46,21 @@ void RunDropoutTest(const char* op, const bool use_mask, const std::vector(); + } else { + t.AddMissingOptionalInput(); + } + // set ratio to default value + ratio = 0.5f; } else { - t.AddInput("ratio", {}, {ratio}); + if (use_float16_ratio) { + t.AddInput("ratio", {}, {MLFloat16(math::floatToHalf(ratio))}); + } else { + t.AddInput("ratio", {}, {ratio}); + } } if (strcmp(op, "TrainableDropout") != 0 && training_mode) { @@ -73,12 +82,12 @@ void RunDropoutTest(const char* op, const bool use_mask, const std::vector(); - const auto num_output_zeros = std::count(output_span.begin(), output_span.end(), 0.0f); + const auto num_dropped_values = std::count(output_span.begin(), output_span.end(), 0.0f); if (ratio == 1.0f) { - ASSERT_EQ(num_output_zeros, static_cast(output_span.size())) << "provider: " << provider_type; + ASSERT_EQ(num_dropped_values, static_cast(output_span.size())) << "provider: " << provider_type; } else { - ASSERT_NEAR(static_cast(num_output_zeros) / static_cast(output_span.size()), ratio, 0.1f) + ASSERT_NEAR(static_cast(num_dropped_values) / static_cast(output_span.size()), ratio, 0.1f) << "provider: " << provider_type; for (decltype(output_span.size()) i = 0; i < output_span.size(); ++i) { @@ -96,7 +105,7 @@ void RunDropoutTest(const char* op, const bool use_mask, const std::vector& input_shape, float ratio = -1.0f, + bool training_mode = true, bool use_float16_ratio = false, bool has_residual = true) { + OpTester t{"BiasDropout", 1, kMSDomain}; + const int64_t seed = 42; + t.AddAttribute("seed", seed); + + const auto input_size = std::accumulate( + input_shape.begin(), input_shape.end(), static_cast(1), std::multiplies<>{}); + const std::vector input = ValueRange(input_size, 1.0f, 1.0f); + t.AddInput("data", input_shape, input); + + std::vector bias_shape{input_shape.back()}; + const auto bias_size = input_shape.back(); + const std::vector bias = ValueRange(bias_size, 2.0f, 1.0f); + t.AddInput("bias", bias_shape, bias); + + float residual_value = 0.0f; + if (has_residual) { + residual_value = 1.0f; + const auto residual_size = input_size; + const std::vector residual(residual_size, residual_value); + t.AddInput("residual", input_shape, residual); + } else { + t.AddMissingOptionalInput(); + } + + if (ratio == -1.0f) { + if (use_float16_ratio) { + t.AddMissingOptionalInput(); + } else { + t.AddMissingOptionalInput(); + } + // set ratio to default value + ratio = 0.5f; + } else { + if (use_float16_ratio) { + t.AddInput("ratio", {}, {MLFloat16(math::floatToHalf(ratio))}); + } else { + t.AddInput("ratio", {}, {ratio}); + } + } + + if (training_mode) { + t.AddInput("training_mode", {}, {true}); + } + + t.AddOutput("output", input_shape, input); // we'll do our own output verification + + std::unique_ptr mask_buffer{}; + if (use_mask) { + mask_buffer = onnxruntime::make_unique(input_size); + t.AddOutput("mask", input_shape, mask_buffer.get(), input_size); + } else { + t.AddMissingOptionalOutput(); + } + + auto output_verifier = [&](const std::vector& fetches, const std::string& provider_type) { + ASSERT_GE(fetches.size(), 1); + const auto& output_tensor = FetchTensor(fetches[0]); + auto output_span = output_tensor.DataAsSpan(); + + const auto num_dropped_values = std::count(output_span.begin(), output_span.end(), residual_value); + + if (ratio == 1.0f) { + ASSERT_EQ(num_dropped_values, static_cast(output_span.size())) << "provider: " << provider_type; + } else { + ASSERT_NEAR(static_cast(num_dropped_values) / static_cast(output_span.size()), ratio, 0.1f) + << "provider: " << provider_type; + + for (decltype(output_span.size()) i = 0; i < output_span.size(); ++i) { + if (output_span[i] == residual_value) continue; + const auto expected_value = (bias[i % bias_size] + i + 1.0f) / (1 - ratio) + residual_value; + ASSERT_NEAR(output_span[i], expected_value, 0.01f) + << "unexpected output value at index " << i << ", provider: " << provider_type; + } + } + + if (use_mask) { + ASSERT_GE(fetches.size(), 2); + const auto& mask_tensor = FetchTensor(fetches[1]); + auto mask_span = mask_tensor.DataAsSpan(); + ASSERT_EQ(mask_span.size(), output_span.size()) << "provider: " << provider_type; + + const auto num_mask_zeros = std::count(mask_span.begin(), mask_span.end(), false); + ASSERT_EQ(num_dropped_values, num_mask_zeros) << "provider: " << provider_type; + + for (decltype(mask_span.size()) i = 0; i < mask_span.size(); ++i) { + ASSERT_TRUE( + (mask_span[i] && output_span[i] != residual_value) || (!mask_span[i] && output_span[i] == residual_value)) + << "output and mask mismatch at index " << i << ", output[i]: " << output_span[i] + << ", mask[i]: " << mask_span[i] << ", provider: " << provider_type; + } + } + }; + + t.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, nullptr, ExecutionMode::ORT_SEQUENTIAL, output_verifier); +} +} // namespace + +TEST(BiasDropoutTest, Basic) { + RunBiasDropoutTest(false, {10, 10, 10}, 0.75f); +} + +TEST(BiasDropoutTest, BasicWithoutResidual) { + RunBiasDropoutTest(false, {10, 10, 10}, 0.75f, true, false, false); +} + +TEST(BiasDropoutTest, Mask) { + RunBiasDropoutTest(true, {3, 5, 768}, 0.25f); +} + +TEST(BiasDropoutTest, RatioLimit) { + RunBiasDropoutTest(true, {4, 8, 1024}, 0.0f, false); +} + +TEST(BiasDropoutTest, EmptyRatio) { + RunBiasDropoutTest(true, {2, 7, 1024}); +} +#endif + namespace { void RunDropoutGradTest(const char* op, float ratio, const std::vector& input_dims, bool default_ratio = true) { const auto input_shape = TensorShape(input_dims); diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 922f87c708..e648e6fb14 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -52,33 +52,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BatchNormalizationGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BatchNormalizationGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16_MLFloat16, TrainableDropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16_float, TrainableDropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, MLFloat16_double, TrainableDropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float_MLFloat16, TrainableDropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float_float, TrainableDropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, float_double, TrainableDropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double_MLFloat16, TrainableDropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double_float, TrainableDropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, double_double, TrainableDropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, TrainableDropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, TrainableDropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_double, TrainableDropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_MLFloat16, TrainableDropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float, TrainableDropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_double, TrainableDropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_MLFloat16, TrainableDropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_float, TrainableDropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double, TrainableDropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_MLFloat16, DropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_float, DropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_double, DropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_MLFloat16, DropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_float, DropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_double, DropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_MLFloat16, DropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_float, DropoutGrad); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double_double, DropoutGrad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasDropout); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 9, TrainableDropout); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, TrainableDropoutGrad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DropoutGrad); // TODO: decprecate GatherND-1 after updating training models to opset-12 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, int64_t, GatherND); @@ -141,133 +118,111 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Mega Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // Adam - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // Adam + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // Lamb - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // Lamb + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // TODO: decprecate GatherND-1 after updating training models to opset-12 - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - // P2P communication operators. + // TODO: decprecate GatherND-1 after updating training models to opset-12 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + +// P2P communication operators. #if defined(USE_NCCL) || defined(USE_HOROVOD) - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif #ifdef USE_HOROVOD - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef USE_NCCL - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/orttraining/orttraining/training_ops/cuda/nn/dropout.cc b/orttraining/orttraining/training_ops/cuda/nn/dropout.cc index fd84b6f3bb..425e68827b 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/dropout.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/dropout.cc @@ -4,104 +4,187 @@ #include "core/framework/random_seed.h" #include "orttraining/training_ops/cuda/nn/dropout.h" #include "core/providers/cuda/nn/dropout.h" +#include "core/providers/cuda/cuda_common.h" #include "core/providers/common.h" namespace onnxruntime { namespace cuda { -#define REGISTER_TRAINABLE_KERNEL_TYPED(T1, T2) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - TrainableDropout, \ - kOnnxDomain, \ - 9, \ - T1##_##T2, \ - kCudaExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(1), \ - Dropout); +// Temporary for backward compatibility, will eventually get rid of TrainableDropout when PyTorch exporter will move to +// opset-12. +ONNX_OPERATOR_KERNEL_EX( + TrainableDropout, + kOnnxDomain, + 9, + kCudaExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()) + .TypeConstraint("T1", DataTypeImpl::AllIEEEFloatTensorTypes()) + .InputMemoryType(1), + Dropout); + +#define REGISTER_GRADIENT_KERNEL(OpName) \ + ONNX_OPERATOR_KERNEL_EX( \ + OpName, \ + kMSDomain, \ + 1, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()) \ + .TypeConstraint("T1", DataTypeImpl::AllIEEEFloatTensorTypes()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(2), \ + DropoutGrad); + +REGISTER_GRADIENT_KERNEL(DropoutGrad) // Temporary for backward compatibility, will eventually get rid of TrainableDropout when PyTorch exporter will move to // opset-12. -REGISTER_TRAINABLE_KERNEL_TYPED(MLFloat16, MLFloat16) -REGISTER_TRAINABLE_KERNEL_TYPED(MLFloat16, float) -REGISTER_TRAINABLE_KERNEL_TYPED(MLFloat16, double) -REGISTER_TRAINABLE_KERNEL_TYPED(float, MLFloat16) -REGISTER_TRAINABLE_KERNEL_TYPED(float, float) -REGISTER_TRAINABLE_KERNEL_TYPED(float, double) -REGISTER_TRAINABLE_KERNEL_TYPED(double, MLFloat16) -REGISTER_TRAINABLE_KERNEL_TYPED(double, float) -REGISTER_TRAINABLE_KERNEL_TYPED(double, double) +REGISTER_GRADIENT_KERNEL(TrainableDropoutGrad) -#define REGISTER_GRADIENT_KERNEL_TYPED(OpName, T1, T2) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - OpName, \ - kMSDomain, \ - 1, \ - T1##_##T2, \ - kCudaExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(2) \ - .InputMemoryType(3), \ - DropoutGrad); +template +struct DropoutGradComputeImpl { + void operator()(const int64_t N, + const Tensor& dY, + const bool* mask_data, + const float ratio_data, + Tensor& dX) const { + typedef typename ToCudaType::MappedType CudaT; -REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, MLFloat16, MLFloat16) -REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, MLFloat16, float) -REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, MLFloat16, double) -REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, float, MLFloat16) -REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, float, float) -REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, float, double) -REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, double, MLFloat16) -REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, double, float) -REGISTER_GRADIENT_KERNEL_TYPED(DropoutGrad, double, double) - -// Temporary for backward compatibility, will eventually get rid of TrainableDropout when PyTorch exporter will move to -// opset-12. -REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, MLFloat16, MLFloat16) -REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, MLFloat16, float) -REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, MLFloat16, double) -REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, float, MLFloat16) -REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, float, float) -REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, float, double) -REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, double, MLFloat16) -REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, double, float) -REGISTER_GRADIENT_KERNEL_TYPED(TrainableDropoutGrad, double, double) - -template -Status DropoutGrad::ComputeInternal(OpKernelContext* context) const { - typedef typename ToCudaType::MappedType CudaT; + const CudaT* dY_data = reinterpret_cast(dY.template Data()); + CudaT* dX_data = reinterpret_cast(dX.template MutableData()); + DropoutGradientKernelImpl(N, dY_data, mask_data, ratio_data, dX_data); + } +}; +Status DropoutGrad::ComputeInternal(OpKernelContext* context) const { auto dY = context->Input(0); const TensorShape& shape = dY->Shape(); - auto dY_data = reinterpret_cast(dY->template Data()); const int64_t N = shape.Size(); auto mask = context->Input(1); ORT_ENFORCE(mask->Shape().Size() == N); + const bool* mask_data = mask->template Data(); + + //Get the ratio_data + float ratio_data = default_ratio_; + auto ratio = context->Input(2); + if (ratio) { + utils::MLTypeCallDispatcher t_disp(ratio->GetElementType()); + t_disp.Invoke(ratio, ratio_data); + } auto dX = context->Output(0, shape); - auto dX_data = reinterpret_cast(dX->template MutableData()); - float ratio_data; - auto ratio = context->Input(2); - static_assert(std::is_same::value || std::is_same::value || std::is_same::value, - "T2 must be float16 or float or double"); - - if (ratio) { - ratio_data = static_cast(*(ratio->template Data())); - } else { - ratio_data = default_ratio_; - } - ORT_ENFORCE(ratio_data >= 0.0f && ratio_data < 1.0f); - - const bool* mask_data = mask->template Data(); - DropoutGradientKernelImpl(N, dY_data, mask_data, ratio_data, dX_data); + utils::MLTypeCallDispatcher t_disp(dY->GetElementType()); + t_disp.Invoke(N, *dY, mask_data, ratio_data, *dX); return Status::OK(); } + +ONNX_OPERATOR_KERNEL_EX( + BiasDropout, + kMSDomain, + 1, + kCudaExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()) + .TypeConstraint("T1", DataTypeImpl::AllIEEEFloatTensorTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .InputMemoryType(3) + .InputMemoryType(4), + BiasDropout); + +template +struct BiasDropoutComputeImpl { + Status operator()(const cudaDeviceProp& prop, + const int64_t N, + const fast_divmod fdm_dim, + const float ratio_data, + PhiloxGenerator& generator, + const Tensor& X, + const Tensor& bias, + const Tensor* residual, + Tensor& Y, + bool* mask_data) const { + typedef typename ToCudaType::MappedType CudaT; + + const CudaT* X_data = reinterpret_cast(X.template Data()); + const CudaT* bias_data = reinterpret_cast(bias.template Data()); + + const CudaT* residual_data = nullptr; + if (residual) { + if (residual->Shape() != X.Shape()) { + return Status(common::ONNXRUNTIME, common::FAIL, "Residual input shape does not match X input shape."); + } + residual_data = reinterpret_cast(residual->template Data()); + } + + CudaT* Y_data = reinterpret_cast(Y.template MutableData()); + + BiasDropoutKernelImpl(prop, N, fdm_dim, ratio_data, generator, X_data, bias_data, residual_data, Y_data, mask_data); + + return Status::OK(); + } +}; + +Status BiasDropout::ComputeInternal(OpKernelContext* context) const { + //Get X_data + const Tensor* X = context->Input(0); + ORT_RETURN_IF_NOT(X, "X Input is not available."); + + const TensorShape& x_shape = X->Shape(); + const int64_t N = x_shape.Size(); + + //Get bias_data + const Tensor* bias = context->Input(1); + if (bias == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "Bias input of BiasDropout is not available."); + const TensorShape& bias_shape = bias->Shape(); + if (bias_shape.NumDimensions() != 1) { + return Status(common::ONNXRUNTIME, common::FAIL, "Bias input is not a 1D tensor."); + } + const int64_t dim = bias_shape[0]; + if (dim != x_shape.GetDims().back()) { + return Status(common::ONNXRUNTIME, common::FAIL, "Bias' dimension doesn't match input's last dimension."); + } + + //Get residual_data + const Tensor* residual = context->Input(2); + + //Get Y_data + auto Y = context->Output(0, x_shape); + + //Get mask_data + auto mask = context->Output(1, x_shape); + + //Get the ratio_data + float ratio_data = default_ratio_; + auto ratio = context->Input(3); + if (ratio) { + utils::MLTypeCallDispatcher t_disp(ratio->GetElementType()); + t_disp.Invoke(ratio, ratio_data); + } + + //Check for inference mode. + const Tensor* training_mode = context->Input(4); + bool is_training_mode = (training_mode != nullptr) && training_mode->Data(); + if (!is_training_mode) { + ratio_data = 0.0f; + } + + IAllocatorUniquePtr temp_mask_buffer{}; // buffer to use if mask is not provided + bool* const mask_data = [this, N, mask, &temp_mask_buffer]() { + if (mask) return mask->MutableData(); + temp_mask_buffer = GetScratchBuffer(N); + return temp_mask_buffer.get(); + }(); + + const fast_divmod fdm_dim(gsl::narrow_cast(dim)); + PhiloxGenerator& generator = generator_ ? *generator_ : PhiloxGenerator::Default(); + + utils::MLTypeCallDispatcherRet t_disp(X->GetElementType()); + return t_disp.Invoke(GetDeviceProp(), N, fdm_dim, ratio_data, generator, *X, *bias, residual, *Y, mask_data); +} + } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/nn/dropout.h b/orttraining/orttraining/training_ops/cuda/nn/dropout.h index 12084bfc7f..a92a10d38f 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/dropout.h +++ b/orttraining/orttraining/training_ops/cuda/nn/dropout.h @@ -9,16 +9,31 @@ namespace onnxruntime { namespace cuda { -template class DropoutGrad final : public CudaKernel { public: - DropoutGrad(const OpKernelInfo& info) : CudaKernel(info), default_ratio_(0.5) { + DropoutGrad(const OpKernelInfo& info) : CudaKernel(info) { } Status ComputeInternal(OpKernelContext* context) const override; private: - const float default_ratio_; + static constexpr float default_ratio_ = 0.5f; +}; + +class BiasDropout final : public CudaKernel { + public: + BiasDropout(const OpKernelInfo& info) : CudaKernel(info) { + int64_t seed = 0; + if (info.GetAttr("seed", &seed).IsOK()) { + generator_ = onnxruntime::make_unique(static_cast(seed)); + } + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + mutable std::unique_ptr generator_; + static constexpr float default_ratio_ = 0.5f; }; } // namespace cuda diff --git a/orttraining/orttraining/training_ops/cuda/nn/dropout_impl.cu b/orttraining/orttraining/training_ops/cuda/nn/dropout_impl.cu index df296f0fb1..4bf303d67e 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/dropout_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/nn/dropout_impl.cu @@ -24,15 +24,21 @@ namespace onnxruntime { namespace cuda { -template +template __global__ void DropoutGradientKernel( const int64_t N, const T* dY_data, const bool* mask_data, - const T scale, + const float scale, T* dX_data) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); - dX_data[id] = dY_data[id] * T(mask_data[id]) * scale; + CUDA_LONG id = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; +#pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < N) { + dX_data[id] = T(float(dY_data[id]) * mask_data[id] * scale); + id += NumThreadsPerBlock; + } + } } template @@ -48,8 +54,9 @@ void DropoutGradientKernelImpl( } } else { const float scale = 1.f / (1.f - ratio); - const int blocksPerGrid = (N + GridDim::maxThreadsPerBlock - 1) / GridDim::maxThreadsPerBlock; - DropoutGradientKernel<<>>(N, dY_data, mask_data, T(scale), dX_data); + const int blocksPerGrid = static_cast(CeilDiv(N, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); + DropoutGradientKernel + <<>>(N, dY_data, mask_data, scale, dX_data); } } @@ -65,5 +72,103 @@ SPECIALIZED_DROPOUT_GRAD_IMPL(float) SPECIALIZED_DROPOUT_GRAD_IMPL(double) SPECIALIZED_DROPOUT_GRAD_IMPL(half) +constexpr int UNROLL = 4; + +template +__global__ void BiasDropoutKernel( + const int64_t N, + const fast_divmod fdm_dim, + const float ratio, + const std::pair seeds, + const T* X_data, + const T* bias_data, + const T* residual_data, + T* Y_data, + bool* mask_data) { + const float p = 1.0f - ratio; + const float scale = 1.0f / p; + + CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x; + CUDA_LONG step_size = gridDim.x * blockDim.x * UNROLL; + CUDA_LONG rounded_size = ((N - 1) / step_size + 1) * step_size; + + curandStatePhilox4_32_10_t state; + curand_init(seeds.first, idx, seeds.second, &state); + + // We ensure every thread generates the same number of random numbers (by rounding + // up the size) and at the same timestep (by syncing threads). + // From CUDA curand documentation: + // The Philox_4x32_10 algorithm is closely tied to the thread and block count. + // Each thread computes 4 random numbers in the same time thus the most efficient + // use of Philox_4x32_10 is to generate a multiple of 4 times number of threads. + for (CUDA_LONG id = idx; id < rounded_size; id += step_size) { + float4 rand = curand_uniform4(&state); + + #pragma unroll + for (CUDA_LONG i = 0; i < UNROLL; i++) { + CUDA_LONG li = id + gridDim.x * blockDim.x * i; + if (li < N) { + int offset = fdm_dim.mod(li); + float bias = float(bias_data[offset]); + + mask_data[li] = (&rand.x)[i] < p; + float output_data = (float(X_data[li]) + bias) * mask_data[li] * scale; + if (has_residual) { + output_data += float(residual_data[li]); + } + + Y_data[li] = T(output_data); + } + } + + __syncthreads(); + } +} + +template +void BiasDropoutKernelImpl( + const cudaDeviceProp& prop, + const int64_t N, + const fast_divmod fdm_dim, + const float ratio, + PhiloxGenerator& generator, + const T* X_data, + const T* bias_data, + const T* residual_data, + T* Y_data, + bool* mask_data) { + const int block_size = 256; + const int blocks_per_sm = prop.maxThreadsPerMultiProcessor / block_size; + const int grid_size = std::min(prop.multiProcessorCount * blocks_per_sm, static_cast(CeilDiv(N, block_size * UNROLL))); + + // Compute the number of random numbers generated by each thread, and increment philox generator offset by that amount. + const uint64_t counter_offset = static_cast(((N - 1) / (block_size * grid_size * UNROLL) + 1) * UNROLL); + auto seeds = generator.NextPhiloxSeeds(counter_offset); + + if (residual_data == nullptr) { + BiasDropoutKernel<<>>(N, fdm_dim, ratio, seeds, X_data, bias_data, residual_data, Y_data, mask_data); + } else { + BiasDropoutKernel<<>>(N, fdm_dim, ratio, seeds, X_data, bias_data, residual_data, Y_data, mask_data); + } +} + +#define SPECIALIZED_BIAS_DROPOUT_IMPL(T) \ + template void BiasDropoutKernelImpl( \ + const cudaDeviceProp& prop, \ + const int64_t N, \ + const fast_divmod fdm_dim, \ + const float ratio, \ + PhiloxGenerator& generator, \ + const T* X_data, \ + const T* bias_data, \ + const T* residual_data, \ + T* Y_data, \ + bool* mask_data); + +SPECIALIZED_BIAS_DROPOUT_IMPL(float) +SPECIALIZED_BIAS_DROPOUT_IMPL(double) +SPECIALIZED_BIAS_DROPOUT_IMPL(half) + + } // namespace cuda } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cuda/nn/dropout_impl.h b/orttraining/orttraining/training_ops/cuda/nn/dropout_impl.h index b75ee462a6..09444662af 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/dropout_impl.h +++ b/orttraining/orttraining/training_ops/cuda/nn/dropout_impl.h @@ -16,5 +16,18 @@ void DropoutGradientKernelImpl( const float ratio, T* dX_data); +template +void BiasDropoutKernelImpl( + const cudaDeviceProp& prop, + const int64_t N, + const fast_divmod fdm_dim, + const float ratio, + PhiloxGenerator& generator, + const T* X_data, + const T* bias_data, + const T* residual_data, + T* Y_data, + bool* mask_data); + } // namespace cuda } // namespace onnxruntime