mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Update DirectML EP changes from DmlDev as of 2021-06-07 (#7987)
* Merged PR 6093117: Fix test_DynamicQuantizedLinear_max_adjusted_expanded by allowing Identity operator to run on non-float inputs Motivation: As part of the OnnxConformance Backend tests, DynamicQuantizedLinear_max_adjusted_expanded is failing. Root Cause: - The test model has `Identity` operator as one of the node. The input of this node is of non-float data type. - In DML, `Identity` operator is registered as operator which requires floating input. - As per `DirectMLSchema.h`, support for non-float input has been added for `Identity` operator in DML but the same has not been reflected in the `OperatorRegistration.cpp`. Changes: - Removed all traces of the requiresFloatFormatsForGraph flag from it's definition and usage. This flag was only used for Identity and it's related operator. - Added null check for the graphOutput nodeArg in GraphDescBuilder.cpp to stop the crash of the test. Related work items: #33076298 * Merged PR 6103324: Remove usage of non-generic error code (FWP_E_NULL_POINTER) Motivation: Addressing Dwayne comment on the previous PR. [Ref: [6093117](https://dev.azure.com/microsoft/WindowsAI/_git/onnxruntime/pullrequest/6093117?discussionId=44292162&path=%2Fonnxruntime%2Fcore%2Fproviders%2Fdml%2FDmlExecutionProvider%2Fsrc%2FGraphPartitioner.cpp)] Changes: Inside the DML EP, we should not use some other platform specific error codes. Instead we should a appropriate generic error code. Related work items: #33076298 Co-authored-by: Sumit Agarwal <sumitagarwal@microsoft.com>
This commit is contained in:
parent
2f2aaf2cf6
commit
43c45ddd66
9 changed files with 7 additions and 32 deletions
|
|
@ -99,12 +99,6 @@ namespace Windows::AI::MachineLearning::Adapter
|
|||
{
|
||||
GraphNodeFactory factory;
|
||||
std::optional<uint32_t> requiredInputCount;
|
||||
|
||||
// The operator inputs/outputs must be a floating point data type. When true,
|
||||
// if the node's tensor data type is not-floating point, the node is partioned
|
||||
// separately (unless the input/output is a CPU constant input, which is okay,
|
||||
// as those can be read directly by the DML operator in the DML_OPERATOR_DESC).
|
||||
bool requiresFloatFormatsExceptConstInputs = false;
|
||||
};
|
||||
|
||||
using KernelSupportQuery = std::function<bool(const onnxruntime::Node& node)>;
|
||||
|
|
|
|||
|
|
@ -333,7 +333,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
|
|||
bool canAliasFirstInput,
|
||||
bool supportsGraph,
|
||||
const uint32_t* requiredInputCountForGraph,
|
||||
bool requiresFloatFormatsForGraph,
|
||||
bool supportedWith64BitTensorsVia32BitStrides,
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp,
|
||||
bool prefer64BitTensorsDirectly,
|
||||
|
|
@ -503,7 +502,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
|
|||
graphReg.requiredInputCount = *requiredInputCountForGraph;
|
||||
}
|
||||
|
||||
graphReg.requiresFloatFormatsExceptConstInputs = requiresFloatFormatsForGraph;
|
||||
regInfo->graphNodeFactoryRegistration = graphReg;
|
||||
}
|
||||
|
||||
|
|
@ -536,7 +534,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
|
|||
if (canAliasFirstInput ||
|
||||
supportsGraph ||
|
||||
requiredInputCountForGraph ||
|
||||
requiresFloatFormatsForGraph ||
|
||||
requiredConstantCpuInputs ||
|
||||
supportedWith64BitTensorsVia32BitStrides ||
|
||||
supportedWith64BitTensorsVia32BitStridesFromAnyEp ||
|
||||
|
|
|
|||
|
|
@ -41,7 +41,6 @@ class AbiCustomRegistry : public WRL::Base<IMLOperatorRegistry, IMLOperatorRegis
|
|||
bool canAliasFirstInput,
|
||||
bool supportsGraph,
|
||||
const uint32_t* requiredInputCountForGraph = nullptr,
|
||||
bool requiresFloatFormatsForGraph = false,
|
||||
bool supportedWith64BitTensorsVia32BitStrides = false,
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false,
|
||||
bool prefer64BitTensorsDirectly = false,
|
||||
|
|
|
|||
|
|
@ -248,6 +248,7 @@ namespace Dml::GraphDescBuilder
|
|||
const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(
|
||||
GraphKernelHelper::GetFusedNodeArgNameMatchingGraph(fusedNodeOutputDefs[outputIndex]->Name()));
|
||||
|
||||
THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg");
|
||||
const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name());
|
||||
|
||||
DML_OUTPUT_GRAPH_EDGE_DESC edge = {};
|
||||
|
|
|
|||
|
|
@ -139,7 +139,7 @@ namespace Dml
|
|||
}
|
||||
};
|
||||
|
||||
bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, bool requiresFloatFormats, uint32_t supportedDeviceDataTypeMask)
|
||||
bool NodeArgSupportedInGraph(const onnxruntime::NodeArg* arg, uint32_t supportedDeviceDataTypeMask)
|
||||
{
|
||||
if (arg->Exists())
|
||||
{
|
||||
|
|
@ -164,14 +164,6 @@ namespace Dml
|
|||
}
|
||||
}
|
||||
|
||||
if (requiresFloatFormats)
|
||||
{
|
||||
if (mlDataType != MLOperatorTensorDataType::Float &&
|
||||
mlDataType != MLOperatorTensorDataType::Float16)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -189,7 +181,6 @@ namespace Dml
|
|||
if (!isConstantCpuInput &&
|
||||
!NodeArgSupportedInGraph(
|
||||
node.InputDefs()[i],
|
||||
registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs,
|
||||
supportedDeviceDataTypeMask
|
||||
))
|
||||
{
|
||||
|
|
@ -201,7 +192,6 @@ namespace Dml
|
|||
{
|
||||
if (!NodeArgSupportedInGraph(
|
||||
arg,
|
||||
registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs,
|
||||
supportedDeviceDataTypeMask
|
||||
))
|
||||
{
|
||||
|
|
|
|||
|
|
@ -65,7 +65,6 @@ struct OperatorRegistrationInformation
|
|||
MLOperatorKernelCreateFn creationFunction;
|
||||
MLOperatorShapeInferenceFunction shapeInferenceFunction;
|
||||
bool canAliasFirstInput;
|
||||
bool requiresFloatFormatsForGraph = false;
|
||||
|
||||
gsl::span<char const* const> tensorTypeNames;
|
||||
gsl::span<const SupportedTensorDataTypes> supportedTensorDataTypes;
|
||||
|
|
@ -314,26 +313,26 @@ constexpr auto requiredConstantCpuInputs(Args... args)
|
|||
|
||||
// Define a single row of registration information.
|
||||
#define REG_INFO(version, operatorName, ...) \
|
||||
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, false, false, ##__VA_ARGS__,
|
||||
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, false, ##__VA_ARGS__,
|
||||
|
||||
// Versioned operator
|
||||
#define REG_INFO_VER(version, operatorName, ...) \
|
||||
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName##version>, false, false, ##__VA_ARGS__,
|
||||
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName##version>, false, ##__VA_ARGS__,
|
||||
|
||||
// Identity operators use Copy, alias their first input, and require floating point formats
|
||||
// for usage in the graph, besides constant inputs. This is because they currently use
|
||||
// element-wise identity operators in the graph for striding support, but issue actual copies
|
||||
// outside the graph. Element-wise identity currently only supports floating point types.
|
||||
#define REG_INFO_ID(version, operatorName, ...) \
|
||||
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, CreateCopy, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, true, true, ##__VA_ARGS__,
|
||||
#operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, CreateCopy, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, true, ##__VA_ARGS__,
|
||||
|
||||
// MS-domain operators
|
||||
#define REG_INFO_MS(version, operatorName, ...) \
|
||||
#operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDomain, Create##operatorName, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, false, false, ##__VA_ARGS__,
|
||||
#operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDomain, Create##operatorName, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, false, ##__VA_ARGS__,
|
||||
|
||||
// MS-domain operators
|
||||
#define REG_INFO_MSDML(version, operatorName, ...) \
|
||||
#operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDmlDomain, Create##operatorName, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, false, false, ##__VA_ARGS__,
|
||||
#operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDmlDomain, Create##operatorName, ShapeInferenceFunction<ShapeInferenceHelper_##operatorName>, false, ##__VA_ARGS__,
|
||||
|
||||
constexpr static OperatorRegistrationInformation operatorRegistrationInformationTable[] =
|
||||
{
|
||||
|
|
@ -690,7 +689,6 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
|
|||
information.canAliasFirstInput, // alias
|
||||
kernelSupportsGraph, // supportsGraph
|
||||
information.requiredInputCountForDmlGraphSupport ? &(*information.requiredInputCountForDmlGraphSupport) : nullptr,
|
||||
information.requiresFloatFormatsForGraph,
|
||||
supportedWith64BitTensorsVia32BitStrides,
|
||||
supportedWith64BitTensorsVia32BitStridesFromAnyEp,
|
||||
prefer64BitTensorsDirectly,
|
||||
|
|
|
|||
|
|
@ -112,7 +112,6 @@ IMLOperatorRegistryPrivate : public IUnknown
|
|||
bool canAliasFirstInput,
|
||||
bool supportsGraph,
|
||||
const uint32_t* requiredInputCountForGraph = nullptr,
|
||||
bool requiresFloatFormatsForGraph = false,
|
||||
bool supportedWith64BitTensorsVia32BitStrides = false,
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp = false,
|
||||
bool prefer64BitTensorsDirectly = false,
|
||||
|
|
|
|||
|
|
@ -55,7 +55,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel(
|
|||
bool canAliasFirstInput,
|
||||
bool supportsGraph,
|
||||
const uint32_t* requiredInputCountForGraph,
|
||||
bool requiresFloatFormatsForGraph,
|
||||
bool supportedWith64BitTensorsVia32BitStrides,
|
||||
bool supportedWith64BitTensorsVia32BitStridesFromAnyEp,
|
||||
bool prefer64BitTensorsDirectly,
|
||||
|
|
@ -81,7 +80,6 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel(
|
|||
canAliasFirstInput,
|
||||
supportsGraph,
|
||||
requiredInputCountForGraph,
|
||||
requiresFloatFormatsForGraph,
|
||||
supportedWith64BitTensorsVia32BitStrides,
|
||||
supportedWith64BitTensorsVia32BitStridesFromAnyEp,
|
||||
prefer64BitTensorsDirectly,
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry {
|
|||
bool can_alias_first_input,
|
||||
bool supports_graph,
|
||||
const uint32_t* required_input_count_for_graph = nullptr,
|
||||
bool requires_float_formats_for_graph = false,
|
||||
bool supports_64bit_directly = false,
|
||||
bool allows_64bit_via_strides = false,
|
||||
bool allows_64bit_via_strides_from_any_ep = false,
|
||||
|
|
|
|||
Loading…
Reference in a new issue