diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index da11068d34..d5c651a0c4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -122,6 +122,10 @@ namespace Windows::AI::MachineLearning::Adapter // Operator supports true 64-bit tensors directly, no strides needed. // So fallback to strided 32-bit only occurs when the device lacks 64-bit support. bool prefer64BitTensorsDirectly = false; + + // The operator supports emulation for uint64/int64 even if the hardware doesn't + // support native uint64/int64 data types. + bool support64BitTensorsViaEmulation = false; }; using InternalRegistrationInfoMap = std::unordered_map>; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index b3099fbaa9..c654d87c1d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -336,6 +336,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( bool supportedWith64BitTensorsVia32BitStrides, bool supportedWith64BitTensorsVia32BitStridesFromAnyEp, bool prefer64BitTensorsDirectly, + bool support64BitTensorsViaEmulation, _In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs, uint32_t constantCpuInputCount) const noexcept try { @@ -461,6 +462,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( regInfo->supportedWith64BitTensorsVia32BitStrides = supportedWith64BitTensorsVia32BitStrides; regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp = supportedWith64BitTensorsVia32BitStridesFromAnyEp; regInfo->prefer64BitTensorsDirectly = prefer64BitTensorsDirectly; + regInfo->support64BitTensorsViaEmulation = support64BitTensorsViaEmulation; // Only internal operators support usage in DML graphs if (supportsGraph) @@ -537,7 +539,8 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( requiredConstantCpuInputs || supportedWith64BitTensorsVia32BitStrides || supportedWith64BitTensorsVia32BitStridesFromAnyEp || - prefer64BitTensorsDirectly) + prefer64BitTensorsDirectly || + support64BitTensorsViaEmulation) { THROW_HR(E_INVALIDARG); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h index 78d17418ef..2482d3af8b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h @@ -44,6 +44,7 @@ class AbiCustomRegistry : public WRL::BaseExists()) { @@ -154,8 +158,11 @@ namespace Dml MLOperatorTensorDataType mlDataType = ToMLTensorDataType(static_cast(tensorType.elem_type())); - if (mlDataType == MLOperatorTensorDataType::UInt64 || - mlDataType == MLOperatorTensorDataType::Int64) + // Do not include operators in the graph if tensor types are unsupported, + // except cases that are always supported via emulation. + if ((mlDataType == MLOperatorTensorDataType::UInt64 || + mlDataType == MLOperatorTensorDataType::Int64) && + !supports64BitTensorsViaEmulation) { constexpr uint32_t deviceDataTypeMask64bit = (1 << DML_TENSOR_DATA_TYPE_UINT64) | (1 << DML_TENSOR_DATA_TYPE_INT64); if ((supportedDeviceDataTypeMask & deviceDataTypeMask64bit) != deviceDataTypeMask64bit) @@ -181,6 +188,7 @@ namespace Dml if (!isConstantCpuInput && !NodeArgSupportedInGraph( node.InputDefs()[i], + registration.support64BitTensorsViaEmulation, supportedDeviceDataTypeMask )) { @@ -192,6 +200,7 @@ namespace Dml { if (!NodeArgSupportedInGraph( arg, + registration.support64BitTensorsViaEmulation, supportedDeviceDataTypeMask )) { @@ -234,6 +243,7 @@ namespace Dml THROW_HR_IF(E_INVALIDARG, allow64BitInputThroughStrides && !nodeNameToPartitionMap); bool prefer64BitTensorsDirectly = false; + bool support64BitTensorsViaEmulation = false; bool supportedWith64BitTensorsVia32BitStrides = false; bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false; std::vector constantCpuInputs; @@ -244,9 +254,10 @@ namespace Dml // to 32-bit tensors via strides. If the caller passes allow64BitInputThroughStrides = false // in this particular call, then the operator-specific flags do not matter as the caller has // disabled 64-bit support. + prefer64BitTensorsDirectly = regInfo->prefer64BitTensorsDirectly; + support64BitTensorsViaEmulation = regInfo->support64BitTensorsViaEmulation; if (allow64BitInputThroughStrides) { - prefer64BitTensorsDirectly = regInfo->prefer64BitTensorsDirectly; supportedWith64BitTensorsVia32BitStridesFromAnyEp = regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp; supportedWith64BitTensorsVia32BitStrides = regInfo->supportedWith64BitTensorsVia32BitStrides | supportedWith64BitTensorsVia32BitStridesFromAnyEp; } @@ -320,31 +331,44 @@ namespace Dml // operator, graph input or initializer, it's not safe to assume the input // can be represented with 32 bits. // + bool isDataTypeSupported = (1 << dmlElementType) & supportedDeviceDataTypeMask; bool is64BitIntType = (dmlElementType == DML_TENSOR_DATA_TYPE_UINT64 || dmlElementType == DML_TENSOR_DATA_TYPE_INT64); - bool needsFallbackTo32Bit = !prefer64BitTensorsDirectly || !((1 << dmlElementType) & supportedDeviceDataTypeMask); - if (is64BitIntType && supportedWith64BitTensorsVia32BitStrides && needsFallbackTo32Bit) + if (is64BitIntType) { - dmlElementType = Remap64bitDmlDataTypeTo32bit(dmlElementType); - - if (isInput && !supportedWith64BitTensorsVia32BitStridesFromAnyEp) + if (support64BitTensorsViaEmulation) { - // Look up the input partition. If it's a graph input or initializer it will be missing - // from the partition map. - const std::string& argName = nodeArg.Name(); + // Consider it supported regardless of hardware support. + isDataTypeSupported = true; + } + else if (prefer64BitTensorsDirectly && isDataTypeSupported) + { + // Operator supports native int64/uint64 tensors. + } + else if (supportedWith64BitTensorsVia32BitStrides || supportedWith64BitTensorsVia32BitStridesFromAnyEp) + { + dmlElementType = Remap64bitDmlDataTypeTo32bit(dmlElementType); + isDataTypeSupported = (1 << dmlElementType) & supportedDeviceDataTypeMask; - // If input tensor's data comes from the output of a different execution provider, - // consider it unsafe to apply fallback to. - auto partitionIter = nodeNameToPartitionMap->find(argName); - if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition()) + if (isInput && !supportedWith64BitTensorsVia32BitStridesFromAnyEp) { - nodeContainsSupportedDataTypes = false; - return; + // Look up the input partition. If it's a graph input or initializer it will be missing + // from the partition map. + const std::string& argName = nodeArg.Name(); + + // If input tensor's data comes from the output of a different execution provider, + // consider it unsafe to apply fallback to. + auto partitionIter = nodeNameToPartitionMap->find(argName); + if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition()) + { + nodeContainsSupportedDataTypes = false; + return; + } } } } // Reject node if the data type is unsupported by the device. - if (!((1 << dmlElementType) & supportedDeviceDataTypeMask)) + if (!isDataTypeSupported) { nodeContainsSupportedDataTypes = false; return; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConstantOfShape.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConstantOfShape.cpp index 83bb47f1b8..24f68d9b75 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConstantOfShape.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConstantOfShape.cpp @@ -22,6 +22,13 @@ public: std::vector> outputIndices = { 0 }; Initialize(kernelCreationContext, inputIndices, outputIndices); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_FILL_VALUE_CONSTANT_OPERATOR_DESC operatorDesc = {}; + operatorDesc.OutputTensor = outputDescs.data(); + operatorDesc.ValueDataType = this->m_outputTensorDescs.front().GetDmlDataType(); + // operatorDesc.Value already zeroed. + // Read the tensor attribute for the output fill pattern. if (kernelCreationContext.HasAttribute(AttrName::Value, MLOperatorAttributeTypeTensor)) { @@ -40,15 +47,13 @@ public: ML_CHECK_VALID_ARGUMENT(elementCount == 1); // Expect exactly one element. const size_t rawDataByteSize = GetByteSizeFromMlDataType(wrappedValueTensor.GetTensorDataType()); const std::byte* rawData = static_cast(valueTensor->GetData()); - valueBytes.assign(rawData, rawData + rawDataByteSize); + + memcpy(operatorDesc.Value.Bytes, rawData, std::min(rawDataByteSize, sizeof(operatorDesc.Value.Bytes))); } // Else valueBytes is empty, and the default fill pattern is 0. - } - void Compute(const MLOperatorKernelContext& kernelContext) override - { - std::vector outputTensors = GetOutputTensorsForExecute(kernelContext); - THROW_IF_FAILED(m_executionProvider->FillTensorWithPattern(outputTensors.front(), valueBytes)); + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_FILL_VALUE_CONSTANT, &operatorDesc }; + SetDmlOperatorDesc(opDesc, kernelCreationContext); } private: diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 84f6f8a5de..9e5db03173 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -35,13 +35,13 @@ enum class SupportedTensorDataTypes : uint32_t Complex64 = 1<<14, Complex128 = 1<<15, Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32, - Int32to64 = UInt32|Int32|UInt64|Int64, + Ints32to64 = UInt32|Int32|UInt64|Int64, + UInt8to64 = UInt8|UInt16|UInt32|UInt64, Float16to32 = Float16|Float32, Float16to64 = Float16|Float32|Float64, NumericDefault = Ints8to32|Float16to32, // Only simple numbers, not bool, complex, or string. Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool, AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16|Float32|Float64|Bool, - AllScalarsButFloat64 = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16to32|Bool, Ints8Bit = UInt8|Int8, Ints16Bit = UInt16|Int16, Ints32Bit = UInt32|Int32, @@ -56,6 +56,7 @@ enum class DmlGraphSupport : uint32_t SupportedWith64BitTensorsVia32BitStrides = 2, // Supports them via 32-bit tensors and doubled strides. SupportedWith64BitTensorsVia32BitStridesFromAnyEp = 4, // Supports input from any execution provider (otherwise only inputs from other DML nodes) Prefer64BitTensorsDirectly = 8, // Natively supports 64-bit tensors. So avoid strided 32-bit unless the device lacks support. + Support64BitTensorsViaEmulation = 16,// supports int/uint64 tensors via emulation of 32-bit types. }; DEFINE_ENUM_FLAG_OPERATORS(DmlGraphSupport); @@ -254,6 +255,7 @@ constexpr static std::array typeNameListTwo = { "T1", "T2" }; constexpr static std::array typeNameListThree = { "T1", "T2", "T3" }; constexpr static std::array typeNameListFour = { "T1", "T2", "T3", "T4" }; constexpr static std::array typeNameListTopK = { "T", "I" }; +constexpr static std::array typeNameListMaxPool = { "T", "I" }; constexpr static std::array typeNameListLogicalComparison = { "T", "T1" }; constexpr static std::array typeNameListPow12 = {"T", "T1"}; constexpr static std::array typeNameListConstantOfShape = { "T1", "T2" }; @@ -269,31 +271,32 @@ constexpr static std::array supportedTypeListFloat1 constexpr static std::array supportedTypeListFloat16to32Int8 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit }; constexpr static std::array supportedTypeListFloat16to32Int32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32}; constexpr static std::array supportedTypeListFloat16to32Int8to32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit}; -constexpr static std::array supportedTypeListInt8to32 = {SupportedTensorDataTypes::Ints8to32}; -constexpr static std::array supportedTypeListInt32to64AndFloat16to32 = {SupportedTensorDataTypes::Int32to64|SupportedTensorDataTypes::Float16to32}; +constexpr static std::array supportedTypeListUInt8to64 = {SupportedTensorDataTypes::UInt8to64}; constexpr static std::array supportedTypeListNumericDefault = { SupportedTensorDataTypes::NumericDefault }; -constexpr static std::array supportedTypeListAllScalarsButFloat64 = { SupportedTensorDataTypes::AllScalarsButFloat64 }; +constexpr static std::array supportedTypeListAllScalars = { SupportedTensorDataTypes::AllScalars }; constexpr static std::array supportedTypeListBool = {SupportedTensorDataTypes::Bool}; constexpr static std::array supportedTypeListPow12 = {SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::NumericDefault}; constexpr static std::array supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64}; +constexpr static std::array supportedTypeListMaxPool = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit, SupportedTensorDataTypes::Int64}; constexpr static std::array supportedTypeListIndices = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; -constexpr static std::array supportedTypeListCast = { SupportedTensorDataTypes::AllScalarsButFloat64, SupportedTensorDataTypes::AllScalarsButFloat64 }; +constexpr static std::array supportedTypeListCast = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars }; constexpr static std::array supportedTypeListScalars8to32 = { SupportedTensorDataTypes::Scalars8to32 }; -constexpr static std::array supportedTypeListScatterGather = { SupportedTensorDataTypes::Scalars8to32, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; -constexpr static std::array supportedTypeListScatterGatherND = { SupportedTensorDataTypes::Scalars8to32 }; -constexpr static std::array supportedTypeListSlice10 = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; +constexpr static std::array supportedTypeListScatterGather = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; +constexpr static std::array supportedTypeListScatterGatherND = { SupportedTensorDataTypes::AllScalars }; +constexpr static std::array supportedTypeListSlice10 = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; constexpr static std::array supportedTypeListQuantizeLinear = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 }; constexpr static std::array supportedTypeListDequantizeLinear = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 }; constexpr static std::array supportedTypeListQuantize = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 }; constexpr static std::array supportedTypeListIsNan = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool }; constexpr static std::array supportedTypeListIsInf = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool }; -constexpr static std::array supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::Float16to32 }; +constexpr static std::array supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::AllScalars }; constexpr static std::array supportedTypeListWhere = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::Scalars8to32 }; -constexpr static std::array supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Int32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 }; +constexpr static std::array supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Ints32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 }; constexpr static std::array supportedTypeListLogicalComparison7 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool }; constexpr static std::array supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Bool }; constexpr static std::array supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 }; constexpr static std::array supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32}; +constexpr static std::array supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float32 /* float32 ROI read by CPU */}; constexpr static std::array supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 }; constexpr static std::array supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 }; constexpr static std::array supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; @@ -361,10 +364,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, - {REG_INFO( 12, MaxPool, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 10, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 11, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, + {REG_INFO( 12, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)}, {REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, @@ -385,43 +388,43 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, // Data Reorganization Layers - {REG_INFO( 7, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 11, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis. - {REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis. - {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes. - {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, + {REG_INFO( 7, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 11, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, // Adds negative axis. + {REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, // Adds negative axis. + {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes. + {REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, {REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, {REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, {REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, - {REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, - {REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::NotSupported, requiredConstantCpuInputs(0))}, - {REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, - {REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, + {REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(1))}, + {REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(0))}, + {REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, {REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmlGraphSupport::Supported)}, // Data reorganization that merely changes the dimensions while keeping the data identical. - {REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, - {REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))}, + {REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(1))}, // Elementwise {REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, @@ -435,7 +438,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, - {REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, + {REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))}, {REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, {REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, {REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)}, @@ -446,7 +449,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, - {REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, + {REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, {REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)}, @@ -523,7 +526,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, {REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, {REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, - {REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, + {REG_INFO_VER( 11, Resize, typeNameListTwo, supportedTypeListResize11, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, // Activation Functions {REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, @@ -554,8 +557,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation // Uncategorized {REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, - {REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, - {REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, + {REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, {REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, @@ -577,7 +580,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)}, {REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)}, - {REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmlGraphSupport::Supported)}, + {REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListUInt8to64, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, {REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)}, {REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, @@ -620,6 +623,7 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry) // The graph must be configured with operators from only the legacy DML API, or only the new DML API bool kernelSupportsGraph = !bool(information.dmlGraphSupport & DmlGraphSupport::NotSupported); bool prefer64BitTensorsDirectly = bool(information.dmlGraphSupport & DmlGraphSupport::Prefer64BitTensorsDirectly); + bool support64BitTensorsViaEmulation = bool(information.dmlGraphSupport & DmlGraphSupport::Support64BitTensorsViaEmulation); bool supportedWith64BitTensorsVia32BitStrides = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides); bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp); @@ -705,6 +709,7 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry) supportedWith64BitTensorsVia32BitStrides, supportedWith64BitTensorsVia32BitStridesFromAnyEp, prefer64BitTensorsDirectly, + support64BitTensorsViaEmulation, information.requiredConstantCpuInputs.first.data(), static_cast(information.requiredConstantCpuInputs.second) )); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index 8a740c9794..7172d53fb4 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -115,6 +115,7 @@ IMLOperatorRegistryPrivate : public IUnknown bool supportedWith64BitTensorsVia32BitStrides = false, bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false, bool prefer64BitTensorsDirectly = false, + bool support64BitTensorsViaEmulation = false, _In_reads_(constantCpuInputCount) const uint32_t* constantCpuInputs = nullptr, uint32_t constantCpuInputCount = 0 ) const noexcept PURE; diff --git a/winml/adapter/abi_custom_registry_impl.cpp b/winml/adapter/abi_custom_registry_impl.cpp index 116efcc0a6..a748732230 100644 --- a/winml/adapter/abi_custom_registry_impl.cpp +++ b/winml/adapter/abi_custom_registry_impl.cpp @@ -58,6 +58,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel( bool supportedWith64BitTensorsVia32BitStrides, bool supportedWith64BitTensorsVia32BitStridesFromAnyEp, bool prefer64BitTensorsDirectly, + bool support64BitTensorsViaEmulation, _In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs, uint32_t constantCpuInputCount) const noexcept try { #ifdef LAYERING_DONE @@ -83,6 +84,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel( supportedWith64BitTensorsVia32BitStrides, supportedWith64BitTensorsVia32BitStridesFromAnyEp, prefer64BitTensorsDirectly, + support64BitTensorsViaEmulation, requiredConstantCpuInputs, constantCpuInputCount); } diff --git a/winml/adapter/abi_custom_registry_impl.h b/winml/adapter/abi_custom_registry_impl.h index c955c7e384..f24ddd4b02 100644 --- a/winml/adapter/abi_custom_registry_impl.h +++ b/winml/adapter/abi_custom_registry_impl.h @@ -31,6 +31,7 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry { bool supports_64bit_directly = false, bool allows_64bit_via_strides = false, bool allows_64bit_via_strides_from_any_ep = false, + bool supports_64bit_tensors_via_emulation = false, _In_reads_(constant_cpu_input_count) const uint32_t* required_constant_cpu_inputs = nullptr, uint32_t constant_cpu_input_count = 0) const noexcept override;