diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 8f58597f45..da11068d34 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -99,12 +99,6 @@ namespace Windows::AI::MachineLearning::Adapter { GraphNodeFactory factory; std::optional requiredInputCount; - - // The operator inputs/outputs must be a floating point data type. When true, - // if the node's tensor data type is not-floating point, the node is partioned - // separately (unless the input/output is a CPU constant input, which is okay, - // as those can be read directly by the DML operator in the DML_OPERATOR_DESC). - bool requiresFloatFormatsExceptConstInputs = false; }; using KernelSupportQuery = std::function; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index b15c84963d..b3099fbaa9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -333,7 +333,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( bool canAliasFirstInput, bool supportsGraph, const uint32_t* requiredInputCountForGraph, - bool requiresFloatFormatsForGraph, bool supportedWith64BitTensorsVia32BitStrides, bool supportedWith64BitTensorsVia32BitStridesFromAnyEp, bool prefer64BitTensorsDirectly, @@ -503,7 +502,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( graphReg.requiredInputCount = *requiredInputCountForGraph; } - graphReg.requiresFloatFormatsExceptConstInputs = requiresFloatFormatsForGraph; regInfo->graphNodeFactoryRegistration = graphReg; } @@ -536,7 +534,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( if (canAliasFirstInput || supportsGraph || requiredInputCountForGraph || - requiresFloatFormatsForGraph || requiredConstantCpuInputs || supportedWith64BitTensorsVia32BitStrides || supportedWith64BitTensorsVia32BitStridesFromAnyEp || diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h index 2bd9bf8b1a..78d17418ef 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h @@ -41,7 +41,6 @@ class AbiCustomRegistry : public WRL::BaseName())); + THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg"); const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name()); DML_OUTPUT_GRAPH_EDGE_DESC edge = {}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 32ba2abe50..9b6d958d72 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -139,7 +139,7 @@ namespace Dml } }; - bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, bool requiresFloatFormats, uint32_t supportedDeviceDataTypeMask) + bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, uint32_t supportedDeviceDataTypeMask) { if (arg->Exists()) { @@ -164,14 +164,6 @@ namespace Dml } } - if (requiresFloatFormats) - { - if (mlDataType != MLOperatorTensorDataType::Float && - mlDataType != MLOperatorTensorDataType::Float16) - { - return false; - } - } } } } @@ -189,7 +181,6 @@ namespace Dml if (!isConstantCpuInput && !NodeArgSupportedInGraph( node.InputDefs()[i], - registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs, supportedDeviceDataTypeMask )) { @@ -201,7 +192,6 @@ namespace Dml { if (!NodeArgSupportedInGraph( arg, - registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs, supportedDeviceDataTypeMask )) { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index c2aac0bddd..5930756472 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -65,7 +65,6 @@ struct OperatorRegistrationInformation MLOperatorKernelCreateFn creationFunction; MLOperatorShapeInferenceFunction shapeInferenceFunction; bool canAliasFirstInput; - bool requiresFloatFormatsForGraph = false; gsl::span tensorTypeNames; gsl::span supportedTensorDataTypes; @@ -314,26 +313,26 @@ constexpr auto requiredConstantCpuInputs(Args... args) // Define a single row of registration information. #define REG_INFO(version, operatorName, ...) \ - #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction, false, false, ##__VA_ARGS__, + #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction, false, ##__VA_ARGS__, // Versioned operator #define REG_INFO_VER(version, operatorName, ...) \ - #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction, false, false, ##__VA_ARGS__, + #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction, false, ##__VA_ARGS__, // Identity operators use Copy, alias their first input, and require floating point formats // for usage in the graph, besides constant inputs. This is because they currently use // element-wise identity operators in the graph for striding support, but issue actual copies // outside the graph. Element-wise identity currently only supports floating point types. #define REG_INFO_ID(version, operatorName, ...) \ - #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, CreateCopy, ShapeInferenceFunction, true, true, ##__VA_ARGS__, + #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, CreateCopy, ShapeInferenceFunction, true, ##__VA_ARGS__, // MS-domain operators #define REG_INFO_MS(version, operatorName, ...) \ - #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDomain, Create##operatorName, ShapeInferenceFunction, false, false, ##__VA_ARGS__, + #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDomain, Create##operatorName, ShapeInferenceFunction, false, ##__VA_ARGS__, // MS-domain operators #define REG_INFO_MSDML(version, operatorName, ...) \ - #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDmlDomain, Create##operatorName, ShapeInferenceFunction, false, false, ##__VA_ARGS__, + #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDmlDomain, Create##operatorName, ShapeInferenceFunction, false, ##__VA_ARGS__, constexpr static OperatorRegistrationInformation operatorRegistrationInformationTable[] = { @@ -690,7 +689,6 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry) information.canAliasFirstInput, // alias kernelSupportsGraph, // supportsGraph information.requiredInputCountForDmlGraphSupport ? &(*information.requiredInputCountForDmlGraphSupport) : nullptr, - information.requiresFloatFormatsForGraph, supportedWith64BitTensorsVia32BitStrides, supportedWith64BitTensorsVia32BitStridesFromAnyEp, prefer64BitTensorsDirectly, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index f1d2a225df..8a740c9794 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -112,7 +112,6 @@ IMLOperatorRegistryPrivate : public IUnknown bool canAliasFirstInput, bool supportsGraph, const uint32_t* requiredInputCountForGraph = nullptr, - bool requiresFloatFormatsForGraph = false, bool supportedWith64BitTensorsVia32BitStrides = false, bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false, bool prefer64BitTensorsDirectly = false, diff --git a/winml/adapter/abi_custom_registry_impl.cpp b/winml/adapter/abi_custom_registry_impl.cpp index a04871b1c5..116efcc0a6 100644 --- a/winml/adapter/abi_custom_registry_impl.cpp +++ b/winml/adapter/abi_custom_registry_impl.cpp @@ -55,7 +55,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel( bool canAliasFirstInput, bool supportsGraph, const uint32_t* requiredInputCountForGraph, - bool requiresFloatFormatsForGraph, bool supportedWith64BitTensorsVia32BitStrides, bool supportedWith64BitTensorsVia32BitStridesFromAnyEp, bool prefer64BitTensorsDirectly, @@ -81,7 +80,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel( canAliasFirstInput, supportsGraph, requiredInputCountForGraph, - requiresFloatFormatsForGraph, supportedWith64BitTensorsVia32BitStrides, supportedWith64BitTensorsVia32BitStridesFromAnyEp, prefer64BitTensorsDirectly, diff --git a/winml/adapter/abi_custom_registry_impl.h b/winml/adapter/abi_custom_registry_impl.h index e1bfae7f36..c955c7e384 100644 --- a/winml/adapter/abi_custom_registry_impl.h +++ b/winml/adapter/abi_custom_registry_impl.h @@ -28,7 +28,6 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry { bool can_alias_first_input, bool supports_graph, const uint32_t* required_input_count_for_graph = nullptr, - bool requires_float_formats_for_graph = false, bool supports_64bit_directly = false, bool allows_64bit_via_strides = false, bool allows_64bit_via_strides_from_any_ep = false,