diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp index 251c72df06..0fcb935416 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp @@ -13,8 +13,9 @@ public: : DmlOperator(kernelInfo), PaddingHelper(kernelInfo, kernelInfo.GetTensorShapeDescription(), opsetVersion) { - ML_CHECK_VALID_ARGUMENT((kernelInfo.GetInputCount() == 1 && (opsetVersion >= 2 && opsetVersion < 11)) - || (kernelInfo.GetInputCount() == 3 && (opsetVersion >= 11))); + const uint32_t inputCount = kernelInfo.GetInputCount(); + ML_CHECK_VALID_ARGUMENT((opsetVersion >= 2 && opsetVersion < 11 && inputCount == 1) + || (opsetVersion >= 11 && inputCount >= 2 && inputCount <= 3)); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); std::vector> kernelInputIndices = { 0 }; // Only bind GPU to first 'data' tensor. @@ -55,11 +56,15 @@ public: ML_INVALID_ARGUMENT("Unknown Pad mode attribute."); } - float value; + // Read the constant value which can come from an attribute or tensor. + float value = 0.0f; if (opsetVersion >= 11) { - auto valueTensor = kernelInfo.GetConstantInputTensor(2); - value = static_cast(ReadScalarTensorCastToFloat64(valueTensor)); + if (kernelInfo.IsInputValid(2)) + { + auto valueTensor = kernelInfo.GetConstantInputTensor(2); + value = static_cast(ReadScalarTensorCastToFloat64(valueTensor)); + } } else { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 108622fcfa..61aeab0069 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -38,7 +38,7 @@ enum class SupportedTensorDataTypes : uint32_t Int32to64 = UInt32|Int32|UInt64|Int64, Float16to32 = Float16|Float32, // Float64 is not supported by DirectML. NumericDefault = Int8to32|Float16to32, - Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool, + Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool, AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16to32|Bool, Ints8Bit = UInt8|Int8, All = static_cast(-1), @@ -255,11 +255,12 @@ const static SupportedTensorDataTypes supportedTypeListInt32to64AndFloat16to32[1 const static SupportedTensorDataTypes supportedTypeListNumericDefault[1] = { SupportedTensorDataTypes::NumericDefault }; const static SupportedTensorDataTypes supportedTypeListAllScalars[1] = { SupportedTensorDataTypes::AllScalars }; const static SupportedTensorDataTypes supportedTypeListBool[1] = {SupportedTensorDataTypes::Bool}; -const static SupportedTensorDataTypes supportedTypeListTopK[2] = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64}; +const static SupportedTensorDataTypes supportedTypeListTopK[2] = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64}; const static SupportedTensorDataTypes supportedTypeListIndices[1] = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; const static SupportedTensorDataTypes supportedTypeListCast[2] = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 }; -const static SupportedTensorDataTypes supportedTypeListScatterGather[2] = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; -const static SupportedTensorDataTypes supportedTypeListScatterGatherND[1] = { SupportedTensorDataTypes::NumericDefault }; +const static SupportedTensorDataTypes supportedTypeListScalars8to32[1] = { SupportedTensorDataTypes::Scalars8to32 }; +const static SupportedTensorDataTypes supportedTypeListScatterGather[2] = { SupportedTensorDataTypes::Scalars8to32, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; +const static SupportedTensorDataTypes supportedTypeListScatterGatherND[1] = { SupportedTensorDataTypes::Scalars8to32 }; const static SupportedTensorDataTypes supportedTypeListSlice10[2] = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; const static SupportedTensorDataTypes supportedTypeListQuantizeLinear[2] = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 }; const static SupportedTensorDataTypes supportedTypeListDequantizeLinear[2] = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 }; @@ -267,13 +268,14 @@ const static SupportedTensorDataTypes supportedTypeListQuantize[2] = { Supported const static SupportedTensorDataTypes supportedTypeListIsNan[2] = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool }; const static SupportedTensorDataTypes supportedTypeListIsInf[2] = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool }; const static SupportedTensorDataTypes supportedTypeListConstantOfShape[2] = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::Float16to32 }; -const static SupportedTensorDataTypes supportedTypeListWhere[2] = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::Float16to32 }; -const static SupportedTensorDataTypes supportedTypeListOneHot[3] = /* indices, depth, values */ { SupportedTensorDataTypes::Int32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Float16to32 }; +const static SupportedTensorDataTypes supportedTypeListWhere[2] = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::AllScalars }; +const static SupportedTensorDataTypes supportedTypeListOneHot[3] = /* indices, depth, values */ { SupportedTensorDataTypes::Int32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 }; const static SupportedTensorDataTypes supportedTypeListLogicalComparison7[2] = /* A&B,C */ { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool }; const static SupportedTensorDataTypes supportedTypeListLogicalComparison9[2] = /* A&B,C */ { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Bool }; const static SupportedTensorDataTypes supportedTypeListSigned[1] = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 }; const static SupportedTensorDataTypes supportedTypeListRange[1] = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32}; const static SupportedTensorDataTypes supportedTypeListInteger[3] = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 }; +const static SupportedTensorDataTypes supportedTypeListPadWithoutFloat16[1] = { SupportedTensorDataTypes::Int8to32 | SupportedTensorDataTypes::Float32 }; const static SupportedTensorDataTypes supportedTypeListQLinearMatMul[3] = { SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, @@ -355,12 +357,12 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, {1, 2, 3, 4}, std::nullopt, QuerySlice)}, // Adds negative axes. {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, {1, 2, 3, 4}, std::nullopt, QuerySlice)}, - {REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListFloat32, DmGraphSupport::Supported, {1, 2} /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 - {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1})}, + {REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)}, + {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListPadWithoutFloat16, DmGraphSupport::Supported, {1, 2} /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 + {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)}, + {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)}, + {REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)}, + {REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported, {1})}, {REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {1})}, {REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmGraphSupport::NotSupported, {0})}, {REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},