From 8b0c2e1f3d72db6708317d6c655e7ee68e961cba Mon Sep 17 00:00:00 2001 From: Dwayne Robinson Date: Fri, 11 Jun 2021 07:54:30 +0000 Subject: [PATCH] Merged PR 6101363: Int64 prototype work for ONNX runtime DML EP Add `support64BitTensorsViaEmulation` to the internal registration info, that informs the graph partitioner that int64 is supported via emulation, even if the device doesn't support it natively. See further description in the corresponding WindowsAI DML PR: https://dev.azure.com/microsoft/WindowsAI/_git/WindowsAI/pullrequest/6101182 Note a later PR will most likely *delete* this newly added flag and simplify much of the existing logic, even deleting the strides hack completely ^__^... Related work items: #28761231 --- .../inc/IWinmlExecutionProvider.h | 4 + .../src/AbiCustomRegistry.cpp | 5 +- .../src/AbiCustomRegistry.h | 1 + .../src/GraphPartitioner.cpp | 62 +++++++---- .../Operators/DmlOperatorConstantOfShape.cpp | 17 ++- .../src/Operators/OperatorRegistration.cpp | 103 +++++++++--------- .../MLOperatorAuthorPrivate.h | 1 + winml/adapter/abi_custom_registry_impl.cpp | 2 + winml/adapter/abi_custom_registry_impl.h | 1 + 9 files changed, 121 insertions(+), 75 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index da11068d34..d5c651a0c4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -122,6 +122,10 @@ namespace Windows::AI::MachineLearning::Adapter // Operator supports true 64-bit tensors directly, no strides needed. // So fallback to strided 32-bit only occurs when the device lacks 64-bit support. bool prefer64BitTensorsDirectly = false; + + // The operator supports emulation for uint64/int64 even if the hardware doesn't + // support native uint64/int64 data types. + bool support64BitTensorsViaEmulation = false; }; using InternalRegistrationInfoMap = std::unordered_map>; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index b3099fbaa9..c654d87c1d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -336,6 +336,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( bool supportedWith64BitTensorsVia32BitStrides, bool supportedWith64BitTensorsVia32BitStridesFromAnyEp, bool prefer64BitTensorsDirectly, + bool support64BitTensorsViaEmulation, _In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs, uint32_t constantCpuInputCount) const noexcept try { @@ -461,6 +462,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( regInfo->supportedWith64BitTensorsVia32BitStrides = supportedWith64BitTensorsVia32BitStrides; regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp = supportedWith64BitTensorsVia32BitStridesFromAnyEp; regInfo->prefer64BitTensorsDirectly = prefer64BitTensorsDirectly; + regInfo->support64BitTensorsViaEmulation = support64BitTensorsViaEmulation; // Only internal operators support usage in DML graphs if (supportsGraph) @@ -537,7 +539,8 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( requiredConstantCpuInputs || supportedWith64BitTensorsVia32BitStrides || supportedWith64BitTensorsVia32BitStridesFromAnyEp || - prefer64BitTensorsDirectly) + prefer64BitTensorsDirectly || + support64BitTensorsViaEmulation) { THROW_HR(E_INVALIDARG); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h index 78d17418ef..2482d3af8b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h @@ -44,6 +44,7 @@ class AbiCustomRegistry : public WRL::BaseExists()) { @@ -154,8 +158,11 @@ namespace Dml MLOperatorTensorDataType mlDataType = ToMLTensorDataType(static_cast(tensorType.elem_type())); - if (mlDataType == MLOperatorTensorDataType::UInt64 || - mlDataType == MLOperatorTensorDataType::Int64) + // Do not include operators in the graph if tensor types are unsupported, + // except cases that are always supported via emulation. + if ((mlDataType == MLOperatorTensorDataType::UInt64 || + mlDataType == MLOperatorTensorDataType::Int64) && + !supports64BitTensorsViaEmulation) { constexpr uint32_t deviceDataTypeMask64bit = (1 << DML_TENSOR_DATA_TYPE_UINT64) | (1 << DML_TENSOR_DATA_TYPE_INT64); if ((supportedDeviceDataTypeMask & deviceDataTypeMask64bit) != deviceDataTypeMask64bit) @@ -181,6 +188,7 @@ namespace Dml if (!isConstantCpuInput && !NodeArgSupportedInGraph( node.InputDefs()[i], + registration.support64BitTensorsViaEmulation, supportedDeviceDataTypeMask )) { @@ -192,6 +200,7 @@ namespace Dml { if (!NodeArgSupportedInGraph( arg, + registration.support64BitTensorsViaEmulation, supportedDeviceDataTypeMask )) { @@ -234,6 +243,7 @@ namespace Dml THROW_HR_IF(E_INVALIDARG, allow64BitInputThroughStrides && !nodeNameToPartitionMap); bool prefer64BitTensorsDirectly = false; + bool support64BitTensorsViaEmulation = false; bool supportedWith64BitTensorsVia32BitStrides = false; bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false; std::vector constantCpuInputs; @@ -244,9 +254,10 @@ namespace Dml // to 32-bit tensors via strides. If the caller passes allow64BitInputThroughStrides = false // in this particular call, then the operator-specific flags do not matter as the caller has // disabled 64-bit support. + prefer64BitTensorsDirectly = regInfo->prefer64BitTensorsDirectly; + support64BitTensorsViaEmulation = regInfo->support64BitTensorsViaEmulation; if (allow64BitInputThroughStrides) { - prefer64BitTensorsDirectly = regInfo->prefer64BitTensorsDirectly; supportedWith64BitTensorsVia32BitStridesFromAnyEp = regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp; supportedWith64BitTensorsVia32BitStrides = regInfo->supportedWith64BitTensorsVia32BitStrides | supportedWith64BitTensorsVia32BitStridesFromAnyEp; } @@ -320,31 +331,44 @@ namespace Dml // operator, graph input or initializer, it's not safe to assume the input // can be represented with 32 bits. // + bool isDataTypeSupported = (1 << dmlElementType) & supportedDeviceDataTypeMask; bool is64BitIntType = (dmlElementType == DML_TENSOR_DATA_TYPE_UINT64 || dmlElementType == DML_TENSOR_DATA_TYPE_INT64); - bool needsFallbackTo32Bit = !prefer64BitTensorsDirectly || !((1 << dmlElementType) & supportedDeviceDataTypeMask); - if (is64BitIntType && supportedWith64BitTensorsVia32BitStrides && needsFallbackTo32Bit) + if (is64BitIntType) { - dmlElementType = Remap64bitDmlDataTypeTo32bit(dmlElementType); - - if (isInput && !supportedWith64BitTensorsVia32BitStridesFromAnyEp) + if (support64BitTensorsViaEmulation) { - // Look up the input partition. If it's a graph input or initializer it will be missing - // from the partition map. - const std::string& argName = nodeArg.Name(); + // Consider it supported regardless of hardware support. + isDataTypeSupported = true; + } + else if (prefer64BitTensorsDirectly && isDataTypeSupported) + { + // Operator supports native int64/uint64 tensors. + } + else if (supportedWith64BitTensorsVia32BitStrides || supportedWith64BitTensorsVia32BitStridesFromAnyEp) + { + dmlElementType = Remap64bitDmlDataTypeTo32bit(dmlElementType); + isDataTypeSupported = (1 << dmlElementType) & supportedDeviceDataTypeMask; - // If input tensor's data comes from the output of a different execution provider, - // consider it unsafe to apply fallback to. - auto partitionIter = nodeNameToPartitionMap->find(argName); - if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition()) + if (isInput && !supportedWith64BitTensorsVia32BitStridesFromAnyEp) { - nodeContainsSupportedDataTypes = false; - return; + // Look up the input partition. If it's a graph input or initializer it will be missing + // from the partition map. + const std::string& argName = nodeArg.Name(); + + // If input tensor's data comes from the output of a different execution provider, + // consider it unsafe to apply fallback to. + auto partitionIter = nodeNameToPartitionMap->find(argName); + if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition()) + { + nodeContainsSupportedDataTypes = false; + return; + } } } } // Reject node if the data type is unsupported by the device. - if (!((1 << dmlElementType) & supportedDeviceDataTypeMask)) + if (!isDataTypeSupported) { nodeContainsSupportedDataTypes = false; return; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConstantOfShape.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConstantOfShape.cpp index 83bb47f1b8..24f68d9b75 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConstantOfShape.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConstantOfShape.cpp @@ -22,6 +22,13 @@ public: std::vector> outputIndices = { 0 }; Initialize(kernelCreationContext, inputIndices, outputIndices); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_FILL_VALUE_CONSTANT_OPERATOR_DESC operatorDesc = {}; + operatorDesc.OutputTensor = outputDescs.data(); + operatorDesc.ValueDataType = this->m_outputTensorDescs.front().GetDmlDataType(); + // operatorDesc.Value already zeroed. + // Read the tensor attribute for the output fill pattern. if (kernelCreationContext.HasAttribute(AttrName::Value, MLOperatorAttributeTypeTensor)) { @@ -40,15 +47,13 @@ public: ML_CHECK_VALID_ARGUMENT(elementCount == 1); // Expect exactly one element. const size_t rawDataByteSize = GetByteSizeFromMlDataType(wrappedValueTensor.GetTensorDataType()); const std::byte* rawData = static_cast(valueTensor->GetData()); - valueBytes.assign(rawData, rawData + rawDataByteSize); + + memcpy(operatorDesc.Value.Bytes, rawData, std::min(rawDataByteSize, sizeof(operatorDesc.Value.Bytes))); } // Else valueBytes is empty, and the default fill pattern is 0. - } - void Compute(const MLOperatorKernelContext& kernelContext) override - { - std::vector outputTensors = GetOutputTensorsForExecute(kernelContext); - THROW_IF_FAILED(m_executionProvider->FillTensorWithPattern(outputTensors.front(), valueBytes)); + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_FILL_VALUE_CONSTANT, &operatorDesc }; + SetDmlOperatorDesc(opDesc, kernelCreationContext); } private: diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 84f6f8a5de..9e5db03173 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -35,13 +35,13 @@ enum class SupportedTensorDataTypes : uint32_t Complex64 = 1<<14, Complex128 = 1<<15, Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32, - Int32to64 = UInt32|Int32|UInt64|Int64, + Ints32to64 = UInt32|Int32|UInt64|Int64, + UInt8to64 = UInt8|UInt16|UInt32|UInt64, Float16to32 = Float16|Float32, Float16to64 = Float16|Float32|Float64, NumericDefault = Ints8to32|Float16to32, // Only simple numbers, not bool, complex, or string. Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool, AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16|Float32|Float64|Bool, - AllScalarsButFloat64 = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16to32|Bool, Ints8Bit = UInt8|Int8, Ints16Bit = UInt16|Int16, Ints32Bit = UInt32|Int32, @@ -56,6 +56,7 @@ enum class DmlGraphSupport : uint32_t SupportedWith64BitTensorsVia32BitStrides = 2, // Supports them via 32-bit tensors and doubled strides. SupportedWith64BitTensorsVia32BitStridesFromAnyEp = 4, // Supports input from any execution provider (otherwise only inputs from other DML nodes) Prefer64BitTensorsDirectly = 8, // Natively supports 64-bit tensors. So avoid strided 32-bit unless the device lacks support. + Support64BitTensorsViaEmulation = 16,// supports int/uint64 tensors via emulation of 32-bit types. }; DEFINE_ENUM_FLAG_OPERATORS(DmlGraphSupport); @@ -254,6 +255,7 @@ constexpr static std::array typeNameListTwo = { "T1", "T2" }; constexpr static std::array typeNameListThree = { "T1", "T2", "T3" }; constexpr static std::array typeNameListFour = { "T1", "T2", "T3", "T4" }; constexpr static std::array typeNameListTopK = { "T", "I" }; +constexpr static std::array typeNameListMaxPool = { "T", "I" }; constexpr static std::array typeNameListLogicalComparison = { "T", "T1" }; constexpr static std::array typeNameListPow12 = {"T", "T1"}; constexpr static std::array typeNameListConstantOfShape = { "T1", "T2" }; @@ -269,31 +271,32 @@ constexpr static std::array supportedTypeListFloat1 constexpr static std::array supportedTypeListFloat16to32Int8 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit }; constexpr static std::array supportedTypeListFloat16to32Int32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32}; constexpr static std::array supportedTypeListFloat16to32Int8to32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit}; -constexpr static std::array supportedTypeListInt8to32 = {SupportedTensorDataTypes::Ints8to32}; -constexpr static std::array supportedTypeListInt32to64AndFloat16to32 = {SupportedTensorDataTypes::Int32to64|SupportedTensorDataTypes::Float16to32}; +constexpr static std::array supportedTypeListUInt8to64 = {SupportedTensorDataTypes::UInt8to64}; constexpr static std::array supportedTypeListNumericDefault = { SupportedTensorDataTypes::NumericDefault }; -constexpr static std::array supportedTypeListAllScalarsButFloat64 = { SupportedTensorDataTypes::AllScalarsButFloat64 }; +constexpr static std::array supportedTypeListAllScalars = { SupportedTensorDataTypes::AllScalars }; constexpr static std::array supportedTypeListBool = {SupportedTensorDataTypes::Bool}; constexpr static std::array supportedTypeListPow12 = {SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::NumericDefault}; constexpr static std::array supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64}; +constexpr static std::array supportedTypeListMaxPool = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit, SupportedTensorDataTypes::Int64}; constexpr static std::array supportedTypeListIndices = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; -constexpr static std::array supportedTypeListCast = { SupportedTensorDataTypes::AllScalarsButFloat64, SupportedTensorDataTypes::AllScalarsButFloat64 }; +constexpr static std::array supportedTypeListCast = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars }; constexpr static std::array supportedTypeListScalars8to32 = { SupportedTensorDataTypes::Scalars8to32 }; -constexpr static std::array supportedTypeListScatterGather = { SupportedTensorDataTypes::Scalars8to32, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; -constexpr static std::array supportedTypeListScatterGatherND = { SupportedTensorDataTypes::Scalars8to32 }; -constexpr static std::array supportedTypeListSlice10 = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; +constexpr static std::array supportedTypeListScatterGather = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; +constexpr static std::array supportedTypeListScatterGatherND = { SupportedTensorDataTypes::AllScalars }; +constexpr static std::array supportedTypeListSlice10 = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; constexpr static std::array supportedTypeListQuantizeLinear = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 }; constexpr static std::array supportedTypeListDequantizeLinear = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 }; constexpr static std::array supportedTypeListQuantize = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 }; constexpr static std::array supportedTypeListIsNan = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool }; constexpr static std::array supportedTypeListIsInf = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool }; -constexpr static std::array supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::Float16to32 }; +constexpr static std::array supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::AllScalars }; constexpr static std::array supportedTypeListWhere = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::Scalars8to32 }; -constexpr static std::array supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Int32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 }; +constexpr static std::array supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Ints32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 }; constexpr static std::array supportedTypeListLogicalComparison7 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool }; constexpr static std::array supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Bool }; constexpr static std::array supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 }; constexpr static std::array supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32}; +constexpr static std::array supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float32 /* float32 ROI read by CPU */}; constexpr static std::array supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 }; constexpr static std::array supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 }; constexpr static std::array supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; @@ -361,10 +364,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 12, MaxPool, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 10, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 11, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 12, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, {REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, @@ -385,43 +388,43 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, // Data Reorganization Layers - {REG_INFO( 7, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 11, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis. - {REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis. - {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes. - {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, + {REG_INFO( 7, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 11, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, // Adds negative axis. + {REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, // Adds negative axis. + {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes. + {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, {REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, {REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, {REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, - {REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, - {REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::NotSupported, requiredConstantCpuInputs(0))}, - {REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(1))}, + {REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(0))}, + {REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, {REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmlGraphSupport::Supported)}, // Data reorganization that merely changes the dimensions while keeping the data identical. - {REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, + {REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(1))}, // Elementwise {REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, @@ -435,7 +438,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, - {REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, + {REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, {REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, {REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, {REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, @@ -446,7 +449,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, @@ -523,7 +526,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, {REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, {REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, - {REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, + {REG_INFO_VER( 11, Resize, typeNameListTwo, supportedTypeListResize11, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, // Activation Functions {REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, @@ -554,8 +557,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation // Uncategorized {REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, {REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, @@ -577,7 +580,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)}, {REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, - {REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListUInt8to64, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, {REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, {REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, @@ -620,6 +623,7 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry) // The graph must be configured with operators from only the legacy DML API, or only the new DML API bool kernelSupportsGraph = !bool(information.dmlGraphSupport & DmlGraphSupport::NotSupported); bool prefer64BitTensorsDirectly = bool(information.dmlGraphSupport & DmlGraphSupport::Prefer64BitTensorsDirectly); + bool support64BitTensorsViaEmulation = bool(information.dmlGraphSupport & DmlGraphSupport::Support64BitTensorsViaEmulation); bool supportedWith64BitTensorsVia32BitStrides = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides); bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp); @@ -705,6 +709,7 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry) supportedWith64BitTensorsVia32BitStrides, supportedWith64BitTensorsVia32BitStridesFromAnyEp, prefer64BitTensorsDirectly, + support64BitTensorsViaEmulation, information.requiredConstantCpuInputs.first.data(), static_cast(information.requiredConstantCpuInputs.second) )); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index 8a740c9794..7172d53fb4 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -115,6 +115,7 @@ IMLOperatorRegistryPrivate : public IUnknown bool supportedWith64BitTensorsVia32BitStrides = false, bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false, bool prefer64BitTensorsDirectly = false, + bool support64BitTensorsViaEmulation = false, _In_reads_(constantCpuInputCount) const uint32_t* constantCpuInputs = nullptr, uint32_t constantCpuInputCount = 0 ) const noexcept PURE; diff --git a/winml/adapter/abi_custom_registry_impl.cpp b/winml/adapter/abi_custom_registry_impl.cpp index 116efcc0a6..a748732230 100644 --- a/winml/adapter/abi_custom_registry_impl.cpp +++ b/winml/adapter/abi_custom_registry_impl.cpp @@ -58,6 +58,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel( bool supportedWith64BitTensorsVia32BitStrides, bool supportedWith64BitTensorsVia32BitStridesFromAnyEp, bool prefer64BitTensorsDirectly, + bool support64BitTensorsViaEmulation, _In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs, uint32_t constantCpuInputCount) const noexcept try { #ifdef LAYERING_DONE @@ -83,6 +84,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel( supportedWith64BitTensorsVia32BitStrides, supportedWith64BitTensorsVia32BitStridesFromAnyEp, prefer64BitTensorsDirectly, + support64BitTensorsViaEmulation, requiredConstantCpuInputs, constantCpuInputCount); } diff --git a/winml/adapter/abi_custom_registry_impl.h b/winml/adapter/abi_custom_registry_impl.h index c955c7e384..f24ddd4b02 100644 --- a/winml/adapter/abi_custom_registry_impl.h +++ b/winml/adapter/abi_custom_registry_impl.h @@ -31,6 +31,7 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry { bool supports_64bit_directly = false, bool allows_64bit_via_strides = false, bool allows_64bit_via_strides_from_any_ep = false, + bool supports_64bit_tensors_via_emulation = false, _In_reads_(constant_cpu_input_count) const uint32_t* required_constant_cpu_inputs = nullptr, uint32_t constant_cpu_input_count = 0) const noexcept override;