Merged PR 4574316: Pad, OneHot, DepthToSpace, SpaceToDepth, TopK, Where int registrations

Related work items: #24673980, #24674011, #24674018, #24674032, #24674039
This commit is contained in:
Dwayne Robinson 2020-04-18 01:05:33 +00:00
parent c170d087a1
commit dc576a8de8
2 changed files with 24 additions and 17 deletions

View file

@ -13,8 +13,9 @@ public:
: DmlOperator(kernelInfo),
PaddingHelper(kernelInfo, kernelInfo.GetTensorShapeDescription(), opsetVersion)
{
ML_CHECK_VALID_ARGUMENT((kernelInfo.GetInputCount() == 1 && (opsetVersion >= 2 && opsetVersion < 11))
|| (kernelInfo.GetInputCount() == 3 && (opsetVersion >= 11)));
const uint32_t inputCount = kernelInfo.GetInputCount();
ML_CHECK_VALID_ARGUMENT((opsetVersion >= 2 && opsetVersion < 11 && inputCount == 1)
|| (opsetVersion >= 11 && inputCount >= 2 && inputCount <= 3));
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);
std::vector<std::optional<uint32_t>> kernelInputIndices = { 0 }; // Only bind GPU to first 'data' tensor.
@ -55,11 +56,15 @@ public:
ML_INVALID_ARGUMENT("Unknown Pad mode attribute.");
}
float value;
// Read the constant value which can come from an attribute or tensor.
float value = 0.0f;
if (opsetVersion >= 11)
{
auto valueTensor = kernelInfo.GetConstantInputTensor(2);
value = static_cast<float>(ReadScalarTensorCastToFloat64(valueTensor));
if (kernelInfo.IsInputValid(2))
{
auto valueTensor = kernelInfo.GetConstantInputTensor(2);
value = static_cast<float>(ReadScalarTensorCastToFloat64(valueTensor));
}
}
else
{

View file

@ -38,7 +38,7 @@ enum class SupportedTensorDataTypes : uint32_t
Int32to64 = UInt32|Int32|UInt64|Int64,
Float16to32 = Float16|Float32, // Float64 is not supported by DirectML.
NumericDefault = Int8to32|Float16to32,
Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool,
Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool,
AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16to32|Bool,
Ints8Bit = UInt8|Int8,
All = static_cast<uint32_t>(-1),
@ -255,11 +255,12 @@ const static SupportedTensorDataTypes supportedTypeListInt32to64AndFloat16to32[1
const static SupportedTensorDataTypes supportedTypeListNumericDefault[1] = { SupportedTensorDataTypes::NumericDefault };
const static SupportedTensorDataTypes supportedTypeListAllScalars[1] = { SupportedTensorDataTypes::AllScalars };
const static SupportedTensorDataTypes supportedTypeListBool[1] = {SupportedTensorDataTypes::Bool};
const static SupportedTensorDataTypes supportedTypeListTopK[2] = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64};
const static SupportedTensorDataTypes supportedTypeListTopK[2] = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64};
const static SupportedTensorDataTypes supportedTypeListIndices[1] = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
const static SupportedTensorDataTypes supportedTypeListCast[2] = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 };
const static SupportedTensorDataTypes supportedTypeListScatterGather[2] = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
const static SupportedTensorDataTypes supportedTypeListScatterGatherND[1] = { SupportedTensorDataTypes::NumericDefault };
const static SupportedTensorDataTypes supportedTypeListScalars8to32[1] = { SupportedTensorDataTypes::Scalars8to32 };
const static SupportedTensorDataTypes supportedTypeListScatterGather[2] = { SupportedTensorDataTypes::Scalars8to32, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
const static SupportedTensorDataTypes supportedTypeListScatterGatherND[1] = { SupportedTensorDataTypes::Scalars8to32 };
const static SupportedTensorDataTypes supportedTypeListSlice10[2] = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
const static SupportedTensorDataTypes supportedTypeListQuantizeLinear[2] = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
const static SupportedTensorDataTypes supportedTypeListDequantizeLinear[2] = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 };
@ -267,13 +268,14 @@ const static SupportedTensorDataTypes supportedTypeListQuantize[2] = { Supported
const static SupportedTensorDataTypes supportedTypeListIsNan[2] = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
const static SupportedTensorDataTypes supportedTypeListIsInf[2] = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool };
const static SupportedTensorDataTypes supportedTypeListConstantOfShape[2] = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::Float16to32 };
const static SupportedTensorDataTypes supportedTypeListWhere[2] = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::Float16to32 };
const static SupportedTensorDataTypes supportedTypeListOneHot[3] = /* indices, depth, values */ { SupportedTensorDataTypes::Int32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Float16to32 };
const static SupportedTensorDataTypes supportedTypeListWhere[2] = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::AllScalars };
const static SupportedTensorDataTypes supportedTypeListOneHot[3] = /* indices, depth, values */ { SupportedTensorDataTypes::Int32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 };
const static SupportedTensorDataTypes supportedTypeListLogicalComparison7[2] = /* A&B,C */ { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
const static SupportedTensorDataTypes supportedTypeListLogicalComparison9[2] = /* A&B,C */ { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Bool };
const static SupportedTensorDataTypes supportedTypeListSigned[1] = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 };
const static SupportedTensorDataTypes supportedTypeListRange[1] = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32};
const static SupportedTensorDataTypes supportedTypeListInteger[3] = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 };
const static SupportedTensorDataTypes supportedTypeListPadWithoutFloat16[1] = { SupportedTensorDataTypes::Int8to32 | SupportedTensorDataTypes::Float32 };
const static SupportedTensorDataTypes supportedTypeListQLinearMatMul[3] = {
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8,
@ -355,12 +357,12 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
{REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, {1, 2, 3, 4}, std::nullopt, QuerySlice)}, // Adds negative axes.
{REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmGraphSupport::Supported, {1, 2, 3, 4}, std::nullopt, QuerySlice)},
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListFloat32, DmGraphSupport::Supported, {1, 2} /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported, {1})},
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmGraphSupport::Supported)},
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListPadWithoutFloat16, DmGraphSupport::Supported, {1, 2} /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)},
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)},
{REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported)},
{REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmGraphSupport::Supported, {1})},
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {1})},
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmGraphSupport::NotSupported, {0})},
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmGraphSupport::Supported)},