Merged PR 6101363: Int64 prototype work for ONNX runtime DML EP

Add `support64BitTensorsViaEmulation` to the internal registration info, that informs the graph partitioner that int64 is supported via emulation, even if the device doesn't support it natively.

See further description in the corresponding WindowsAI DML PR:
https://dev.azure.com/microsoft/WindowsAI/_git/WindowsAI/pullrequest/6101182

Note a later PR will most likely *delete* this newly added flag and simplify much of the existing logic, even deleting the strides hack completely ^__^...

Related work items: #28761231
This commit is contained in:
Dwayne Robinson 2021-06-11 07:54:30 +00:00
parent f5c6b9703a
commit 8b0c2e1f3d
9 changed files with 121 additions and 75 deletions

View file

@ -122,6 +122,10 @@ namespace Windows::AI::MachineLearning::Adapter
// Operator supports true 64-bit tensors directly, no strides needed.
// So fallback to strided 32-bit only occurs when the device lacks 64-bit support.
bool prefer64BitTensorsDirectly = false;
// The operator supports emulation for uint64/int64 even if the hardware doesn't
// support native uint64/int64 data types.
bool support64BitTensorsViaEmulation = false;
};
using InternalRegistrationInfoMap = std::unordered_map<onnxruntime::KernelDef*, std::shared_ptr<InternalRegistrationInfo>>;

View file

@ -336,6 +336,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
bool supportedWith64BitTensorsVia32BitStrides,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp,
bool prefer64BitTensorsDirectly,
bool support64BitTensorsViaEmulation,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs,
uint32_t constantCpuInputCount) const noexcept try
{
@ -461,6 +462,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
regInfo->supportedWith64BitTensorsVia32BitStrides = supportedWith64BitTensorsVia32BitStrides;
regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp = supportedWith64BitTensorsVia32BitStridesFromAnyEp;
regInfo->prefer64BitTensorsDirectly = prefer64BitTensorsDirectly;
regInfo->support64BitTensorsViaEmulation = support64BitTensorsViaEmulation;
// Only internal operators support usage in DML graphs
if (supportsGraph)
@ -537,7 +539,8 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
requiredConstantCpuInputs ||
supportedWith64BitTensorsVia32BitStrides ||
supportedWith64BitTensorsVia32BitStridesFromAnyEp ||
prefer64BitTensorsDirectly)
prefer64BitTensorsDirectly ||
support64BitTensorsViaEmulation)
{
THROW_HR(E_INVALIDARG);
}

View file

@ -44,6 +44,7 @@ class AbiCustomRegistry : public WRL::Base<IMLOperatorRegistry, IMLOperatorRegis
bool supportedWith64BitTensorsVia32BitStrides = false,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false,
bool prefer64BitTensorsDirectly = false,
bool support64BitTensorsViaEmulation = false,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs = nullptr,
uint32_t constantCpuInputCount = 0) const noexcept override;

View file

@ -139,7 +139,11 @@ namespace Dml
}
};
bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, uint32_t supportedDeviceDataTypeMask)
bool NodeArgSupportedInGraph(
const onnxruntime::NodeArg* arg,
bool supports64BitTensorsViaEmulation,
uint32_t supportedDeviceDataTypeMask
)
{
if (arg->Exists())
{
@ -154,8 +158,11 @@ namespace Dml
MLOperatorTensorDataType mlDataType = ToMLTensorDataType(static_cast<onnx::TensorProto_DataType>(tensorType.elem_type()));
if (mlDataType == MLOperatorTensorDataType::UInt64 ||
mlDataType == MLOperatorTensorDataType::Int64)
// Do not include operators in the graph if tensor types are unsupported,
// except cases that are always supported via emulation.
if ((mlDataType == MLOperatorTensorDataType::UInt64 ||
mlDataType == MLOperatorTensorDataType::Int64) &&
!supports64BitTensorsViaEmulation)
{
constexpr uint32_t deviceDataTypeMask64bit = (1 << DML_TENSOR_DATA_TYPE_UINT64) | (1 << DML_TENSOR_DATA_TYPE_INT64);
if ((supportedDeviceDataTypeMask & deviceDataTypeMask64bit) != deviceDataTypeMask64bit)
@ -181,6 +188,7 @@ namespace Dml
if (!isConstantCpuInput &&
!NodeArgSupportedInGraph(
node.InputDefs()[i],
registration.support64BitTensorsViaEmulation,
supportedDeviceDataTypeMask
))
{
@ -192,6 +200,7 @@ namespace Dml
{
if (!NodeArgSupportedInGraph(
arg,
registration.support64BitTensorsViaEmulation,
supportedDeviceDataTypeMask
))
{
@ -234,6 +243,7 @@ namespace Dml
THROW_HR_IF(E_INVALIDARG, allow64BitInputThroughStrides && !nodeNameToPartitionMap);
bool prefer64BitTensorsDirectly = false;
bool support64BitTensorsViaEmulation = false;
bool supportedWith64BitTensorsVia32BitStrides = false;
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false;
std::vector<onnxruntime::NodeArg const*> constantCpuInputs;
@ -244,9 +254,10 @@ namespace Dml
// to 32-bit tensors via strides. If the caller passes allow64BitInputThroughStrides = false
// in this particular call, then the operator-specific flags do not matter as the caller has
// disabled 64-bit support.
prefer64BitTensorsDirectly = regInfo->prefer64BitTensorsDirectly;
support64BitTensorsViaEmulation = regInfo->support64BitTensorsViaEmulation;
if (allow64BitInputThroughStrides)
{
prefer64BitTensorsDirectly = regInfo->prefer64BitTensorsDirectly;
supportedWith64BitTensorsVia32BitStridesFromAnyEp = regInfo->supportedWith64BitTensorsVia32BitStridesFromAnyEp;
supportedWith64BitTensorsVia32BitStrides = regInfo->supportedWith64BitTensorsVia32BitStrides | supportedWith64BitTensorsVia32BitStridesFromAnyEp;
}
@ -320,31 +331,44 @@ namespace Dml
// operator, graph input or initializer, it's not safe to assume the input
// can be represented with 32 bits.
//
bool isDataTypeSupported = (1 << dmlElementType) & supportedDeviceDataTypeMask;
bool is64BitIntType = (dmlElementType == DML_TENSOR_DATA_TYPE_UINT64 || dmlElementType == DML_TENSOR_DATA_TYPE_INT64);
bool needsFallbackTo32Bit = !prefer64BitTensorsDirectly || !((1 << dmlElementType) & supportedDeviceDataTypeMask);
if (is64BitIntType && supportedWith64BitTensorsVia32BitStrides && needsFallbackTo32Bit)
if (is64BitIntType)
{
dmlElementType = Remap64bitDmlDataTypeTo32bit(dmlElementType);
if (isInput && !supportedWith64BitTensorsVia32BitStridesFromAnyEp)
if (support64BitTensorsViaEmulation)
{
// Look up the input partition. If it's a graph input or initializer it will be missing
// from the partition map.
const std::string& argName = nodeArg.Name();
// Consider it supported regardless of hardware support.
isDataTypeSupported = true;
}
else if (prefer64BitTensorsDirectly && isDataTypeSupported)
{
// Operator supports native int64/uint64 tensors.
}
else if (supportedWith64BitTensorsVia32BitStrides || supportedWith64BitTensorsVia32BitStridesFromAnyEp)
{
dmlElementType = Remap64bitDmlDataTypeTo32bit(dmlElementType);
isDataTypeSupported = (1 << dmlElementType) & supportedDeviceDataTypeMask;
// If input tensor's data comes from the output of a different execution provider,
// consider it unsafe to apply fallback to.
auto partitionIter = nodeNameToPartitionMap->find(argName);
if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition())
if (isInput && !supportedWith64BitTensorsVia32BitStridesFromAnyEp)
{
nodeContainsSupportedDataTypes = false;
return;
// Look up the input partition. If it's a graph input or initializer it will be missing
// from the partition map.
const std::string& argName = nodeArg.Name();
// If input tensor's data comes from the output of a different execution provider,
// consider it unsafe to apply fallback to.
auto partitionIter = nodeNameToPartitionMap->find(argName);
if (partitionIter == nodeNameToPartitionMap->end() || !partitionIter->second->IsDmlPartition())
{
nodeContainsSupportedDataTypes = false;
return;
}
}
}
}
// Reject node if the data type is unsupported by the device.
if (!((1 << dmlElementType) & supportedDeviceDataTypeMask))
if (!isDataTypeSupported)
{
nodeContainsSupportedDataTypes = false;
return;

View file

@ -22,6 +22,13 @@ public:
std::vector<std::optional<uint32_t>> outputIndices = { 0 };
Initialize(kernelCreationContext, inputIndices, outputIndices);
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
DML_FILL_VALUE_CONSTANT_OPERATOR_DESC operatorDesc = {};
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.ValueDataType = this->m_outputTensorDescs.front().GetDmlDataType();
// operatorDesc.Value already zeroed.
// Read the tensor attribute for the output fill pattern.
if (kernelCreationContext.HasAttribute(AttrName::Value, MLOperatorAttributeTypeTensor))
{
@ -40,15 +47,13 @@ public:
ML_CHECK_VALID_ARGUMENT(elementCount == 1); // Expect exactly one element.
const size_t rawDataByteSize = GetByteSizeFromMlDataType(wrappedValueTensor.GetTensorDataType());
const std::byte* rawData = static_cast<const std::byte*>(valueTensor->GetData());
valueBytes.assign(rawData, rawData + rawDataByteSize);
memcpy(operatorDesc.Value.Bytes, rawData, std::min(rawDataByteSize, sizeof(operatorDesc.Value.Bytes)));
}
// Else valueBytes is empty, and the default fill pattern is 0.
}
void Compute(const MLOperatorKernelContext& kernelContext) override
{
std::vector<IMLOperatorTensor*> outputTensors = GetOutputTensorsForExecute(kernelContext);
THROW_IF_FAILED(m_executionProvider->FillTensorWithPattern(outputTensors.front(), valueBytes));
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_FILL_VALUE_CONSTANT, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
private:

View file

@ -35,13 +35,13 @@ enum class SupportedTensorDataTypes : uint32_t
Complex64 = 1<<14,
Complex128 = 1<<15,
Ints8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32,
Int32to64 = UInt32|Int32|UInt64|Int64,
Ints32to64 = UInt32|Int32|UInt64|Int64,
UInt8to64 = UInt8|UInt16|UInt32|UInt64,
Float16to32 = Float16|Float32,
Float16to64 = Float16|Float32|Float64,
NumericDefault = Ints8to32|Float16to32, // Only simple numbers, not bool, complex, or string.
Scalars8to32 = UInt8|Int8|UInt16|Int16|UInt32|Int32|Float16to32|Bool,
AllScalars = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16|Float32|Float64|Bool,
AllScalarsButFloat64 = UInt8|Int8|UInt16|Int16|UInt32|Int32|UInt64|Int64|Float16to32|Bool,
Ints8Bit = UInt8|Int8,
Ints16Bit = UInt16|Int16,
Ints32Bit = UInt32|Int32,
@ -56,6 +56,7 @@ enum class DmlGraphSupport : uint32_t
SupportedWith64BitTensorsVia32BitStrides = 2, // Supports them via 32-bit tensors and doubled strides.
SupportedWith64BitTensorsVia32BitStridesFromAnyEp = 4, // Supports input from any execution provider (otherwise only inputs from other DML nodes)
Prefer64BitTensorsDirectly = 8, // Natively supports 64-bit tensors. So avoid strided 32-bit unless the device lacks support.
Support64BitTensorsViaEmulation = 16,// supports int/uint64 tensors via emulation of 32-bit types.
};
DEFINE_ENUM_FLAG_OPERATORS(DmlGraphSupport);
@ -254,6 +255,7 @@ constexpr static std::array<const char*, 2> typeNameListTwo = { "T1", "T2" };
constexpr static std::array<const char*, 3> typeNameListThree = { "T1", "T2", "T3" };
constexpr static std::array<const char*, 4> typeNameListFour = { "T1", "T2", "T3", "T4" };
constexpr static std::array<const char*, 2> typeNameListTopK = { "T", "I" };
constexpr static std::array<const char*, 2> typeNameListMaxPool = { "T", "I" };
constexpr static std::array<const char*, 2> typeNameListLogicalComparison = { "T", "T1" };
constexpr static std::array<const char*, 2> typeNameListPow12 = {"T", "T1"};
constexpr static std::array<const char*, 2> typeNameListConstantOfShape = { "T1", "T2" };
@ -269,31 +271,32 @@ constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat1
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int8 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::UInt32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListFloat16to32Int8to32 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt8to32 = {SupportedTensorDataTypes::Ints8to32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInt32to64AndFloat16to32 = {SupportedTensorDataTypes::Int32to64|SupportedTensorDataTypes::Float16to32};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListUInt8to64 = {SupportedTensorDataTypes::UInt8to64};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNumericDefault = { SupportedTensorDataTypes::NumericDefault };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAllScalarsButFloat64 = { SupportedTensorDataTypes::AllScalarsButFloat64 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListAllScalars = { SupportedTensorDataTypes::AllScalars };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListBool = {SupportedTensorDataTypes::Bool};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListPow12 = {SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::NumericDefault};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListTopK = {SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListMaxPool = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit, SupportedTensorDataTypes::Int64};
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListIndices = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListCast = { SupportedTensorDataTypes::AllScalarsButFloat64, SupportedTensorDataTypes::AllScalarsButFloat64 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListCast = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::AllScalars };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScalars8to32 = { SupportedTensorDataTypes::Scalars8to32 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListScatterGather = { SupportedTensorDataTypes::Scalars8to32, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScatterGatherND = { SupportedTensorDataTypes::Scalars8to32 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSlice10 = { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListScatterGather = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListScatterGatherND = { SupportedTensorDataTypes::AllScalars };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSlice10 = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListQuantizeLinear = { SupportedTensorDataTypes::Float32 | SupportedTensorDataTypes::Int32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListDequantizeLinear = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::Int32 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListQuantize = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListIsNan = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListIsInf = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::Bool };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::Float16to32 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListConstantOfShape = { SupportedTensorDataTypes::Int64, SupportedTensorDataTypes::AllScalars };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListWhere = { SupportedTensorDataTypes::Bool, SupportedTensorDataTypes::Scalars8to32 };
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Int32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 };
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListOneHot = /* indices, depth, values */ { SupportedTensorDataTypes::Ints32to64, SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLogicalComparison7 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Bool };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::NumericDefault, SupportedTensorDataTypes::Bool };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Float32};
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float32 /* float32 ROI read by CPU */};
constexpr static std::array<SupportedTensorDataTypes, 3> supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 };
constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 };
constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 };
@ -361,10 +364,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 11, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 12, MaxPool, typeNameListDefault, supportedTypeListFloat16to32Int8, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 10, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 11, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 12, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
{REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
@ -385,43 +388,43 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))},
// Data Reorganization Layers
{REG_INFO( 7, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 11, Split, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis.
{REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)}, // Adds negative axis.
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes.
{REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)},
{REG_INFO( 7, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO( 11, Split, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, // Adds negative axis.
{REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO( 11, Concat, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)}, // Adds negative axis.
{REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO_VER( 10, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)}, // Adds negative axes.
{REG_INFO_VER( 11, Slice, typeNameListSlice10, supportedTypeListSlice10, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(1, 2, 3, 4), std::nullopt, QuerySlice)},
{REG_INFO_VER( 7, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728
{REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 11, DepthToSpace, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO( 7, Tile, typeNameListDefault, supportedTypeListScalars8to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::NotSupported, requiredConstantCpuInputs(0))},
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Prefer64BitTensorsDirectly|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 8, Expand, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(1))},
{REG_INFO( 9, ConstantOfShape, typeNameListConstantOfShape, supportedTypeListConstantOfShape, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(0))},
{REG_INFO( 7, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO( 11, Gather, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO( 11, GatherElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO( 11, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO( 12, GatherND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO_VER( 9, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO_VER( 11, Scatter, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO( 11, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListScalars8to32, DmlGraphSupport::Supported)},
// Data reorganization that merely changes the dimensions while keeping the data identical.
{REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
{REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalarsButFloat64, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides, requiredConstantCpuInputs(1))},
{REG_INFO_ID( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO_ID( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO_ID( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO_ID( 11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO_ID( 7, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO_ID( 11, Squeeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO_ID( 7, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO_ID( 11, Unsqueeze, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO_ID( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation, requiredConstantCpuInputs(1))},
// Elementwise
{REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
@ -435,7 +438,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, Floor, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 7, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 11, Clip, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
{REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
{REG_INFO_VER( 12, Clip, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1,2))},
{REG_INFO( 7, Add, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Sub, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Mul, typeNameListDefault, supportedTypeListFloat16to32Int32, DmlGraphSupport::Supported)},
@ -446,7 +449,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 8, Mean, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Max, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 12, Max, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 7, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 8, Min, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
{REG_INFO( 12, Min, typeNameListDefault, supportedTypeListFloat16to32Int8to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(), 2)},
@ -523,7 +526,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_VER( 9, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
{REG_INFO_VER( 10, Upsample, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
{REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)},
{REG_INFO_VER( 11, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)},
{REG_INFO_VER( 11, Resize, typeNameListTwo, supportedTypeListResize11, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)},
// Activation Functions
{REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
@ -554,8 +557,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
// Uncategorized
{REG_INFO( 7, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 9, MatMul, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)},
{REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)},
{REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides)},
@ -577,7 +580,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
{REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListInt8to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, BitShift, typeNameListDefault, supportedTypeListUInt8to64, DmlGraphSupport::Supported|DmlGraphSupport::Support64BitTensorsViaEmulation)},
{REG_INFO( 11, Round, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, ReverseSequence, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported|DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp)},
{REG_INFO( 11, CumSum, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
@ -620,6 +623,7 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
// The graph must be configured with operators from only the legacy DML API, or only the new DML API
bool kernelSupportsGraph = !bool(information.dmlGraphSupport & DmlGraphSupport::NotSupported);
bool prefer64BitTensorsDirectly = bool(information.dmlGraphSupport & DmlGraphSupport::Prefer64BitTensorsDirectly);
bool support64BitTensorsViaEmulation = bool(information.dmlGraphSupport & DmlGraphSupport::Support64BitTensorsViaEmulation);
bool supportedWith64BitTensorsVia32BitStrides = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStrides);
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = bool(information.dmlGraphSupport & DmlGraphSupport::SupportedWith64BitTensorsVia32BitStridesFromAnyEp);
@ -705,6 +709,7 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
supportedWith64BitTensorsVia32BitStrides,
supportedWith64BitTensorsVia32BitStridesFromAnyEp,
prefer64BitTensorsDirectly,
support64BitTensorsViaEmulation,
information.requiredConstantCpuInputs.first.data(),
static_cast<uint32_t>(information.requiredConstantCpuInputs.second)
));

View file

@ -115,6 +115,7 @@ IMLOperatorRegistryPrivate : public IUnknown
bool supportedWith64BitTensorsVia32BitStrides = false,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false,
bool prefer64BitTensorsDirectly = false,
bool support64BitTensorsViaEmulation = false,
_In_reads_(constantCpuInputCount) const uint32_t* constantCpuInputs = nullptr,
uint32_t constantCpuInputCount = 0
) const noexcept PURE;

View file

@ -58,6 +58,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel(
bool supportedWith64BitTensorsVia32BitStrides,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp,
bool prefer64BitTensorsDirectly,
bool support64BitTensorsViaEmulation,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs,
uint32_t constantCpuInputCount) const noexcept try {
#ifdef LAYERING_DONE
@ -83,6 +84,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel(
supportedWith64BitTensorsVia32BitStrides,
supportedWith64BitTensorsVia32BitStridesFromAnyEp,
prefer64BitTensorsDirectly,
support64BitTensorsViaEmulation,
requiredConstantCpuInputs,
constantCpuInputCount);
}

View file

@ -31,6 +31,7 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry {
bool supports_64bit_directly = false,
bool allows_64bit_via_strides = false,
bool allows_64bit_via_strides_from_any_ep = false,
bool supports_64bit_tensors_via_emulation = false,
_In_reads_(constant_cpu_input_count) const uint32_t* required_constant_cpu_inputs = nullptr,
uint32_t constant_cpu_input_count = 0) const noexcept override;