From 43c45ddd66f4ee01d7d6f63b785423cdf874aadd Mon Sep 17 00:00:00 2001 From: sumitsays Date: Fri, 11 Jun 2021 11:09:48 -0700 Subject: [PATCH] Update DirectML EP changes from DmlDev as of 2021-06-07 (#7987) * Merged PR 6093117: Fix test_DynamicQuantizedLinear_max_adjusted_expanded by allowing Identity operator to run on non-float inputs Motivation: As part of the OnnxConformance Backend tests, DynamicQuantizedLinear_max_adjusted_expanded is failing. Root Cause: - The test model has `Identity` operator as one of the node. The input of this node is of non-float data type. - In DML, `Identity` operator is registered as operator which requires floating input. - As per `DirectMLSchema.h`, support for non-float input has been added for `Identity` operator in DML but the same has not been reflected in the `OperatorRegistration.cpp`. Changes: - Removed all traces of the requiresFloatFormatsForGraph flag from it's definition and usage. This flag was only used for Identity and it's related operator. - Added null check for the graphOutput nodeArg in GraphDescBuilder.cpp to stop the crash of the test. Related work items: #33076298 * Merged PR 6103324: Remove usage of non-generic error code (FWP_E_NULL_POINTER) Motivation: Addressing Dwayne comment on the previous PR. [Ref: [6093117](https://dev.azure.com/microsoft/WindowsAI/_git/onnxruntime/pullrequest/6093117?discussionId=44292162&path=%2Fonnxruntime%2Fcore%2Fproviders%2Fdml%2FDmlExecutionProvider%2Fsrc%2FGraphPartitioner.cpp)] Changes: Inside the DML EP, we should not use some other platform specific error codes. Instead we should a appropriate generic error code. Related work items: #33076298 Co-authored-by: Sumit Agarwal --- .../inc/IWinmlExecutionProvider.h | 6 ------ .../DmlExecutionProvider/src/AbiCustomRegistry.cpp | 3 --- .../dml/DmlExecutionProvider/src/AbiCustomRegistry.h | 1 - .../DmlExecutionProvider/src/GraphDescBuilder.cpp | 1 + .../DmlExecutionProvider/src/GraphPartitioner.cpp | 12 +----------- .../src/Operators/OperatorRegistration.cpp | 12 +++++------- .../OperatorAuthorHelper/MLOperatorAuthorPrivate.h | 1 - winml/adapter/abi_custom_registry_impl.cpp | 2 -- winml/adapter/abi_custom_registry_impl.h | 1 - 9 files changed, 7 insertions(+), 32 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 8f58597f45..da11068d34 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -99,12 +99,6 @@ namespace Windows::AI::MachineLearning::Adapter { GraphNodeFactory factory; std::optional requiredInputCount; - - // The operator inputs/outputs must be a floating point data type. When true, - // if the node's tensor data type is not-floating point, the node is partioned - // separately (unless the input/output is a CPU constant input, which is okay, - // as those can be read directly by the DML operator in the DML_OPERATOR_DESC). - bool requiresFloatFormatsExceptConstInputs = false; }; using KernelSupportQuery = std::function; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index b15c84963d..b3099fbaa9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -333,7 +333,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( bool canAliasFirstInput, bool supportsGraph, const uint32_t* requiredInputCountForGraph, - bool requiresFloatFormatsForGraph, bool supportedWith64BitTensorsVia32BitStrides, bool supportedWith64BitTensorsVia32BitStridesFromAnyEp, bool prefer64BitTensorsDirectly, @@ -503,7 +502,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( graphReg.requiredInputCount = *requiredInputCountForGraph; } - graphReg.requiresFloatFormatsExceptConstInputs = requiresFloatFormatsForGraph; regInfo->graphNodeFactoryRegistration = graphReg; } @@ -536,7 +534,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( if (canAliasFirstInput || supportsGraph || requiredInputCountForGraph || - requiresFloatFormatsForGraph || requiredConstantCpuInputs || supportedWith64BitTensorsVia32BitStrides || supportedWith64BitTensorsVia32BitStridesFromAnyEp || diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h index 2bd9bf8b1a..78d17418ef 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h @@ -41,7 +41,6 @@ class AbiCustomRegistry : public WRL::BaseName())); + THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg"); const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name()); DML_OUTPUT_GRAPH_EDGE_DESC edge = {}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 32ba2abe50..9b6d958d72 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -139,7 +139,7 @@ namespace Dml } }; - bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, bool requiresFloatFormats, uint32_t supportedDeviceDataTypeMask) + bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, uint32_t supportedDeviceDataTypeMask) { if (arg->Exists()) { @@ -164,14 +164,6 @@ namespace Dml } } - if (requiresFloatFormats) - { - if (mlDataType != MLOperatorTensorDataType::Float && - mlDataType != MLOperatorTensorDataType::Float16) - { - return false; - } - } } } } @@ -189,7 +181,6 @@ namespace Dml if (!isConstantCpuInput && !NodeArgSupportedInGraph( node.InputDefs()[i], - registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs, supportedDeviceDataTypeMask )) { @@ -201,7 +192,6 @@ namespace Dml { if (!NodeArgSupportedInGraph( arg, - registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs, supportedDeviceDataTypeMask )) { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index c2aac0bddd..5930756472 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -65,7 +65,6 @@ struct OperatorRegistrationInformation MLOperatorKernelCreateFn creationFunction; MLOperatorShapeInferenceFunction shapeInferenceFunction; bool canAliasFirstInput; - bool requiresFloatFormatsForGraph = false; gsl::span tensorTypeNames; gsl::span supportedTensorDataTypes; @@ -314,26 +313,26 @@ constexpr auto requiredConstantCpuInputs(Args... args) // Define a single row of registration information. #define REG_INFO(version, operatorName, ...) \ - #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction, false, false, ##__VA_ARGS__, + #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction, false, ##__VA_ARGS__, // Versioned operator #define REG_INFO_VER(version, operatorName, ...) \ - #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction, false, false, ##__VA_ARGS__, + #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction, false, ##__VA_ARGS__, // Identity operators use Copy, alias their first input, and require floating point formats // for usage in the graph, besides constant inputs. This is because they currently use // element-wise identity operators in the graph for striding support, but issue actual copies // outside the graph. Element-wise identity currently only supports floating point types. #define REG_INFO_ID(version, operatorName, ...) \ - #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, CreateCopy, ShapeInferenceFunction, true, true, ##__VA_ARGS__, + #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, CreateCopy, ShapeInferenceFunction, true, ##__VA_ARGS__, // MS-domain operators #define REG_INFO_MS(version, operatorName, ...) \ - #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDomain, Create##operatorName, ShapeInferenceFunction, false, false, ##__VA_ARGS__, + #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDomain, Create##operatorName, ShapeInferenceFunction, false, ##__VA_ARGS__, // MS-domain operators #define REG_INFO_MSDML(version, operatorName, ...) \ - #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDmlDomain, Create##operatorName, ShapeInferenceFunction, false, false, ##__VA_ARGS__, + #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDmlDomain, Create##operatorName, ShapeInferenceFunction, false, ##__VA_ARGS__, constexpr static OperatorRegistrationInformation operatorRegistrationInformationTable[] = { @@ -690,7 +689,6 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry) information.canAliasFirstInput, // alias kernelSupportsGraph, // supportsGraph information.requiredInputCountForDmlGraphSupport ? &(*information.requiredInputCountForDmlGraphSupport) : nullptr, - information.requiresFloatFormatsForGraph, supportedWith64BitTensorsVia32BitStrides, supportedWith64BitTensorsVia32BitStridesFromAnyEp, prefer64BitTensorsDirectly, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index f1d2a225df..8a740c9794 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -112,7 +112,6 @@ IMLOperatorRegistryPrivate : public IUnknown bool canAliasFirstInput, bool supportsGraph, const uint32_t* requiredInputCountForGraph = nullptr, - bool requiresFloatFormatsForGraph = false, bool supportedWith64BitTensorsVia32BitStrides = false, bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false, bool prefer64BitTensorsDirectly = false, diff --git a/winml/adapter/abi_custom_registry_impl.cpp b/winml/adapter/abi_custom_registry_impl.cpp index a04871b1c5..116efcc0a6 100644 --- a/winml/adapter/abi_custom_registry_impl.cpp +++ b/winml/adapter/abi_custom_registry_impl.cpp @@ -55,7 +55,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel( bool canAliasFirstInput, bool supportsGraph, const uint32_t* requiredInputCountForGraph, - bool requiresFloatFormatsForGraph, bool supportedWith64BitTensorsVia32BitStrides, bool supportedWith64BitTensorsVia32BitStridesFromAnyEp, bool prefer64BitTensorsDirectly, @@ -81,7 +80,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel( canAliasFirstInput, supportsGraph, requiredInputCountForGraph, - requiresFloatFormatsForGraph, supportedWith64BitTensorsVia32BitStrides, supportedWith64BitTensorsVia32BitStridesFromAnyEp, prefer64BitTensorsDirectly, diff --git a/winml/adapter/abi_custom_registry_impl.h b/winml/adapter/abi_custom_registry_impl.h index e1bfae7f36..c955c7e384 100644 --- a/winml/adapter/abi_custom_registry_impl.h +++ b/winml/adapter/abi_custom_registry_impl.h @@ -28,7 +28,6 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry { bool can_alias_first_input, bool supports_graph, const uint32_t* required_input_count_for_graph = nullptr, - bool requires_float_formats_for_graph = false, bool supports_64bit_directly = false, bool allows_64bit_via_strides = false, bool allows_64bit_via_strides_from_any_ep = false,