mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
Merged PR 4574316: Pad, OneHot, DepthToSpace, SpaceToDepth, TopK, Where int registrations
Related work items: #24673980, #24674011, #24674018, #24674032, #24674039
This commit is contained in:
parent
c170d087a1
commit
dc576a8de8
2 changed files with 24 additions and 17 deletions
|
|
@ -13,8 +13,9 @@ public:
|
|||
: DmlOperator(kernelInfo),
|
||||
PaddingHelper(kernelInfo, kernelInfo.GetTensorShapeDescription(), opsetVersion)
|
||||
{
|
||||
ML_CHECK_VALID_ARGUMENT((kernelInfo.GetInputCount() == 1 && (opsetVersion >= 2 && opsetVersion < 11))
|
||||
|| (kernelInfo.GetInputCount() == 3 && (opsetVersion >= 11)));
|
||||
const uint32_t inputCount = kernelInfo.GetInputCount();
|
||||
ML_CHECK_VALID_ARGUMENT((opsetVersion >= 2 && opsetVersion < 11 && inputCount == 1)
|
||||
|| (opsetVersion >= 11 && inputCount >= 2 && inputCount <= 3));
|
||||
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
|
||||
|
||||
std::vector<std::optional<uint32_t>> kernelInputIndices = { 0 }; // Only bind GPU to first 'data' tensor.
|
||||
|
|
@ -55,11 +56,15 @@ public:
|
|||
ML_INVALID_ARGUMENT("Unknown Pad mode attribute.");
|
||||
}
|
||||
|
||||
float value;
|
||||
// Read the constant value which can come from an attribute or tensor.
|
||||
float value = 0.0f;
|
||||
if (opsetVersion >= 11)
|
||||
{
|
||||
auto valueTensor = kernelInfo.GetConstantInputTensor(2);
|
||||
value = static_cast<float>(ReadScalarTensorCastToFloat64(valueTensor));
|
||||
if (kernelInfo.IsInputValid(2))
|
||||
{
|
||||
auto valueTensor = kernelInfo.GetConstantInputTensor(2);
|
||||
value = static_cast<float>(ReadScalarTensorCastToFloat64(valueTensor));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ enum class SupportedTensorDataTypes : uint32_t
|
|||
Int32to64 = UInt32|Int32|UInt64|Int64,
|
||||
Float16to32 = Float16|Float32, // Float64 is not supported by DirectML.
|
||||
NumericDefault = Int8to32|Float16to32,
|
||||
Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool,
|
||||
Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool,
|
||||
AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16to32|Bool,
|
||||
Ints8Bit = UInt8|Int8,
|
||||
All = static_cast<uint32_t>(-1),
|
||||
|
|
@ -255,11 +255,12 @@ const static SupportedTensorDataTypes supportedTypeListInt32to64AndFloat16to32[1
|
|||
const static SupportedTensorDataTypes supportedTypeListNumericDefault[1] = { SupportedTensorDataTypes::NumericDefault };
|
||||
const static SupportedTensorDataTypes supportedTypeListAllScalars[1] = { SupportedTensorDataTypes::AllScalars };
|
||||
const static SupportedTensorDataTypes supportedTypeListBool[1] = {SupportedTensorDataTypes::Bool};
|
||||
const static SupportedTensorDataTypes supportedTypeListTopK[2] = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64};
|
||||
const static SupportedTensorDataTypes supportedTypeListTopK[2] = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64};
|
||||
const static SupportedTensorDataTypes supportedTypeListIndices[1] = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
|
||||
const static SupportedTensorDataTypes supportedTypeListCast[2] = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 };
|
||||
const static SupportedTensorDataTypes supportedTypeListScatterGather[2] = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
|
||||
const static SupportedTensorDataTypes supportedTypeListScatterGatherND[1] = { SupportedTensorDataTypes::NumericDefault };
|
||||
const static SupportedTensorDataTypes supportedTypeListScalars8to32[1] = { SupportedTensorDataTypes::Scalars8to32 };
|
||||
const static SupportedTensorDataTypes supportedTypeListScatterGather[2] = { SupportedTensorDataTypes::Scalars8to32, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
|
||||
const static SupportedTensorDataTypes supportedTypeListScatterGatherND[1] = { SupportedTensorDataTypes::Scalars8to32 };
|
||||
const static SupportedTensorDataTypes supportedTypeListSlice10[2] = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
|
||||
const static SupportedTensorDataTypes supportedTypeListQuantizeLinear[2] = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
|
||||
const static SupportedTensorDataTypes supportedTypeListDequantizeLinear[2] = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 };
|
||||
|
|
@ -267,13 +268,14 @@ const static SupportedTensorDataTypes supportedTypeListQuantize[2] = { Supported
|
|||
const static SupportedTensorDataTypes supportedTypeListIsNan[2] = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
|
||||
const static SupportedTensorDataTypes supportedTypeListIsInf[2] = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool };
|
||||
const static SupportedTensorDataTypes supportedTypeListConstantOfShape[2] = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::Float16to32 };
|
||||
const static SupportedTensorDataTypes supportedTypeListWhere[2] = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::Float16to32 };
|
||||
const static SupportedTensorDataTypes supportedTypeListOneHot[3] = /* indices, depth, values */ { SupportedTensorDataTypes::Int32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Float16to32 };
|
||||
const static SupportedTensorDataTypes supportedTypeListWhere[2] = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::AllScalars };
|
||||
const static SupportedTensorDataTypes supportedTypeListOneHot[3] = /* indices, depth, values */ { SupportedTensorDataTypes::Int32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 };
|
||||
const static SupportedTensorDataTypes supportedTypeListLogicalComparison7[2] = /* A&B,C */ { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
|
||||
const static SupportedTensorDataTypes supportedTypeListLogicalComparison9[2] = /* A&B,C */ { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Bool };
|
||||
const static SupportedTensorDataTypes supportedTypeListSigned[1] = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 };
|
||||
const static SupportedTensorDataTypes supportedTypeListRange[1] = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32};
|
||||
const static SupportedTensorDataTypes supportedTypeListInteger[3] = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 };
|
||||
const static SupportedTensorDataTypes supportedTypeListPadWithoutFloat16[1] = { SupportedTensorDataTypes::Int8to32 | SupportedTensorDataTypes::Float32 };
|
||||
const static SupportedTensorDataTypes supportedTypeListQLinearMatMul[3] = {
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
|
||||
|
|
@ -355,12 +357,12 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl
|
|||
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, {1, 2, 3, 4}, std::nullopt, QuerySlice)}, // Adds negative axes.
|
||||
{REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, {1, 2, 3, 4}, std::nullopt, QuerySlice)},
|
||||
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListFloat32, DmGraphSupport::Supported, {1, 2} /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
|
||||
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1})},
|
||||
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
|
||||
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListPadWithoutFloat16, DmGraphSupport::Supported, {1, 2} /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
|
||||
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)},
|
||||
{REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported, {1})},
|
||||
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {1})},
|
||||
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmGraphSupport::NotSupported, {0})},
|
||||
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},
|
||||
|
|
|
|||
Loading…
Reference in a new issue