diff --git a/onnxruntime/core/optimizer/constant_sharing.cc b/onnxruntime/core/optimizer/constant_sharing.cc index c06349ec9b..116061a542 100644 --- a/onnxruntime/core/optimizer/constant_sharing.cc +++ b/onnxruntime/core/optimizer/constant_sharing.cc @@ -35,7 +35,6 @@ using SupportedTypeList = boost::mp11::mp_list TENSOR_ELEM_COUNT_THRESHOLD) { + if (num_elements > ConstantSharing::TENSOR_ELEM_COUNT_THRESHOLD) { return false; } } - if (num_elements > 0 && num_elements <= TENSOR_ELEM_COUNT_THRESHOLD) { + if (num_elements > 0 && num_elements <= ConstantSharing::TENSOR_ELEM_COUNT_THRESHOLD) { return true; } diff --git a/onnxruntime/core/optimizer/constant_sharing.h b/onnxruntime/core/optimizer/constant_sharing.h index d1ea0bce53..3d0cb875da 100644 --- a/onnxruntime/core/optimizer/constant_sharing.h +++ b/onnxruntime/core/optimizer/constant_sharing.h @@ -29,6 +29,8 @@ class ConstantSharing : public GraphTransformer { excluded_initializers_(excluded_initializers) { } + static constexpr int64_t TENSOR_ELEM_COUNT_THRESHOLD = 8; + private: Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index 1b3954de2f..c8c5fddc78 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -7,6 +7,7 @@ #include "GraphPartitioner.h" #include "core/framework/kernel_type_str_resolver.h" #include "core/framework/kernel_lookup.h" +#include "core/optimizer/constant_sharing.h" #include "FusedGraphKernel.h" #include "MLOperatorAuthorImpl.h" #include "DmlGraphFusionHelper.h" @@ -87,11 +88,23 @@ namespace Dml assert(iter != initializerPartitionMap.end()); if (iter->second.size() > 1) { - if 
(requiredInitializerMap.find(input) != requiredInitializerMap.end()) + // Including non-transferable tensors in isInitializerTransferable causes DML to upload and preprocess them + // to duplicate locations rather than treating them as non-constant, which helps optimization. + // The size threshold for this should be no smaller than the one the constant sharing transform uses to + // combine initializers, so that the transform does not hurt performance. + // If the kernel relies on this input being initialized, it should also be small enough to copy cheaply. + const uint64_t maximumElementsForDuplicationTensor = 64; + static_assert(maximumElementsForDuplicationTensor >= onnxruntime::ConstantSharing::TENSOR_ELEM_COUNT_THRESHOLD); + + uint64_t totalElementCount = 1; + for (int i = 0; i < tensor->dims().size(); ++i) + { + totalElementCount *= tensor->dims()[i]; + } + + if (totalElementCount <= maximumElementsForDuplicationTensor || + requiredInitializerMap.find(input) != requiredInitializerMap.end()) { - // The kernel relies on this input to be initialized, and it should be small enough to copy - // cheaply. FusedGraphKernel only handles constant CPU inputs through transferred initializers, - // rather than ORT, to avoid mismatches in policy or implementation causing failures. isInitializerTransferable[input] = {tensor, false}; }