Update DirectML EP changes from DmlDev as of 2021-06-07 (#7987)

* Merged PR 6093117: Fix test_DynamicQuantizedLinear_max_adjusted_expanded by allowing Identity operator to run on non-float inputs

Motivation:
As part of the OnnxConformance Backend tests, DynamicQuantizedLinear_max_adjusted_expanded is failing.

Root Cause:
- The test model has an `Identity` operator as one of its nodes. The input of this node is of a non-float data type.
- In DML, the `Identity` operator is registered as an operator which requires floating-point input.
- As per `DirectMLSchema.h`, support for non-float input has been added for `Identity` operator in DML but the same has not been reflected in the `OperatorRegistration.cpp`.

Changes:
- Removed all traces of the requiresFloatFormatsForGraph flag, from its definition and usage. This flag was only used for Identity and its related operators.
- Added null check for the graphOutput nodeArg in GraphDescBuilder.cpp to stop the crash of the test.

Related work items: #33076298

* Merged PR 6103324: Remove usage of non-generic error code (FWP_E_NULL_POINTER)

Motivation:
Addressing Dwayne comment on the previous PR. [Ref: [6093117](https://dev.azure.com/microsoft/WindowsAI/_git/onnxruntime/pullrequest/6093117?discussionId=44292162&path=%2Fonnxruntime%2Fcore%2Fproviders%2Fdml%2FDmlExecutionProvider%2Fsrc%2FGraphPartitioner.cpp)]

Changes:
Inside the DML EP, we should not use platform-specific error codes. Instead we should use an appropriate generic error code.

Related work items: #33076298

Co-authored-by: Sumit Agarwal <sumitagarwal@microsoft.com>
This commit is contained in:
sumitsays 2021-06-11 11:09:48 -07:00 committed by GitHub
parent 2f2aaf2cf6
commit 43c45ddd66
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 7 additions and 32 deletions

View file

@ -99,12 +99,6 @@ namespace Windows::AI::MachineLearning::Adapter
{
GraphNodeFactory factory;
std::optional<uint32_t> requiredInputCount;
// The operator inputs/outputs must be a floating point data type. When true,
// if the node's tensor data type is not-floating point, the node is partioned
// separately (unless the input/output is a CPU constant input, which is okay,
// as those can be read directly by the DML operator in the DML_OPERATOR_DESC).
bool requiresFloatFormatsExceptConstInputs = false;
};
using KernelSupportQuery = std::function<bool(const onnxruntime::Node& node)>;

View file

@ -333,7 +333,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
bool canAliasFirstInput,
bool supportsGraph,
const uint32_t* requiredInputCountForGraph,
bool requiresFloatFormatsForGraph,
bool supportedWith64BitTensorsVia32BitStrides,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp,
bool prefer64BitTensorsDirectly,
@ -503,7 +502,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
graphReg.requiredInputCount = *requiredInputCountForGraph;
}
graphReg.requiresFloatFormatsExceptConstInputs = requiresFloatFormatsForGraph;
regInfo->graphNodeFactoryRegistration = graphReg;
}
@ -536,7 +534,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
if (canAliasFirstInput ||
supportsGraph ||
requiredInputCountForGraph ||
requiresFloatFormatsForGraph ||
requiredConstantCpuInputs ||
supportedWith64BitTensorsVia32BitStrides ||
supportedWith64BitTensorsVia32BitStridesFromAnyEp ||

View file

@ -41,7 +41,6 @@ class AbiCustomRegistry : public WRL::Base<IMLOperatorRegistry, IMLOperatorRegis
bool canAliasFirstInput,
bool supportsGraph,
const uint32_t* requiredInputCountForGraph = nullptr,
bool requiresFloatFormatsForGraph = false,
bool supportedWith64BitTensorsVia32BitStrides = false,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false,
bool prefer64BitTensorsDirectly = false,

View file

@ -248,6 +248,7 @@ namespace Dml::GraphDescBuilder
const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(
GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(fusedNodeOutputDefs[outputIndex]->Name()));
THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg");
const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name());
DML_OUTPUT_GRAPH_EDGE_DESC edge = {};

View file

@ -139,7 +139,7 @@ namespace Dml
}
};
bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, bool requiresFloatFormats, uint32_t supportedDeviceDataTypeMask)
bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, uint32_t supportedDeviceDataTypeMask)
{
if (arg->Exists())
{
@ -164,14 +164,6 @@ namespace Dml
}
}
if (requiresFloatFormats)
{
if (mlDataType != MLOperatorTensorDataType::Float &&
mlDataType != MLOperatorTensorDataType::Float16)
{
return false;
}
}
}
}
}
@ -189,7 +181,6 @@ namespace Dml
if (!isConstantCpuInput &&
!NodeArgSupportedInGraph(
node.InputDefs()[i],
registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs,
supportedDeviceDataTypeMask
))
{
@ -201,7 +192,6 @@ namespace Dml
{
if (!NodeArgSupportedInGraph(
arg,
registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs,
supportedDeviceDataTypeMask
))
{

View file

@ -65,7 +65,6 @@ struct OperatorRegistrationInformation
MLOperatorKernelCreateFn creationFunction;
MLOperatorShapeInferenceFunction shapeInferenceFunction;
bool canAliasFirstInput;
bool requiresFloatFormatsForGraph = false;
gsl::span<char const* const> tensorTypeNames;
gsl::span<const SupportedTensorDataTypes> supportedTensorDataTypes;
@ -314,26 +313,26 @@ constexpr auto requiredConstantCpuInputs(Args... args)
// Define a single row of registration information.
#define REG_INFO(version, operatorName, ...) \
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, false, false, ##__VA_ARGS__,
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, false, ##__VA_ARGS__,
// Versioned operator
#define REG_INFO_VER(version, operatorName, ...) \
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName##version>, false, false, ##__VA_ARGS__,
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName##version>, false, ##__VA_ARGS__,
// Identity operators use Copy, alias their first input, and require floating point formats
// for usage in the graph, besides constant inputs. This is because they currently use
// element-wise identity operators in the graph for striding support, but issue actual copies
// outside the graph. Element-wise identity currently only supports floating point types.
#define REG_INFO_ID(version, operatorName, ...) \
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, CreateCopy, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, true, true, ##__VA_ARGS__,
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, CreateCopy, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, true, ##__VA_ARGS__,
// MS-domain operators
#define REG_INFO_MS(version, operatorName, ...) \
#operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDomain, Create##operatorName, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, false, false, ##__VA_ARGS__,
#operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDomain, Create##operatorName, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, false, ##__VA_ARGS__,
// MS-domain operators
#define REG_INFO_MSDML(version, operatorName, ...) \
#operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDmlDomain, Create##operatorName, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, false, false, ##__VA_ARGS__,
#operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDmlDomain, Create##operatorName, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, false, ##__VA_ARGS__,
constexpr static OperatorRegistrationInformation operatorRegistrationInformationTable[] =
{
@ -690,7 +689,6 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
information.canAliasFirstInput, // alias
kernelSupportsGraph, // supportsGraph
information.requiredInputCountForDmlGraphSupport ? &(*information.requiredInputCountForDmlGraphSupport) : nullptr,
information.requiresFloatFormatsForGraph,
supportedWith64BitTensorsVia32BitStrides,
supportedWith64BitTensorsVia32BitStridesFromAnyEp,
prefer64BitTensorsDirectly,

View file

@ -112,7 +112,6 @@ IMLOperatorRegistryPrivate : public IUnknown
bool canAliasFirstInput,
bool supportsGraph,
const uint32_t* requiredInputCountForGraph = nullptr,
bool requiresFloatFormatsForGraph = false,
bool supportedWith64BitTensorsVia32BitStrides = false,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false,
bool prefer64BitTensorsDirectly = false,

View file

@ -55,7 +55,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel(
bool canAliasFirstInput,
bool supportsGraph,
const uint32_t* requiredInputCountForGraph,
bool requiresFloatFormatsForGraph,
bool supportedWith64BitTensorsVia32BitStrides,
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp,
bool prefer64BitTensorsDirectly,
@ -81,7 +80,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel(
canAliasFirstInput,
supportsGraph,
requiredInputCountForGraph,
requiresFloatFormatsForGraph,
supportedWith64BitTensorsVia32BitStrides,
supportedWith64BitTensorsVia32BitStridesFromAnyEp,
prefer64BitTensorsDirectly,

View file

@ -28,7 +28,6 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry {
bool can_alias_first_input,
bool supports_graph,
const uint32_t* required_input_count_for_graph = nullptr,
bool requires_float_formats_for_graph = false,
bool supports_64bit_directly = false,
bool allows_64bit_via_strides = false,
bool allows_64bit_via_strides_from_any_ep = false,