From a3a8a53736a475319f095dff805ab6d398cbf103 Mon Sep 17 00:00:00 2001 From: Jeff Date: Wed, 15 Apr 2020 16:21:52 -0700 Subject: [PATCH] Enable free dimension override by name --- .../core/session/onnxruntime_c_api.h | 10 +++-- onnxruntime/core/framework/session_options.h | 16 ++++++-- .../free_dim_override_transformer.cc | 39 ++++++++++++++----- .../optimizer/free_dim_override_transformer.h | 1 + .../core/session/abi_session_options.cc | 14 ++++++- onnxruntime/core/session/onnxruntime_c_api.cc | 3 +- onnxruntime/core/session/ort_apis.h | 5 ++- .../optimizer/free_dimension_override_test.cc | 22 ++++++++--- .../test/testdata/abs_free_dimensions.onnx | 8 ++-- 9 files changed, 87 insertions(+), 31 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index e7867142b5..4f78b464fa 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -545,10 +545,10 @@ struct OrtApi { // Always returns the same instance on every invocation. OrtStatus*(ORT_API_CALL* GetAllocatorWithDefaultOptions)(_Outptr_ OrtAllocator** out)NO_EXCEPTION; - // Override symbolic dimensions with actual values if known at session initialization time to enable + // Override symbolic dimensions (by specific denotation strings) with actual values if known at session initialization time to enable // optimizations that can take advantage of fixed values (such as memory planning, etc) OrtStatus*(ORT_API_CALL* AddFreeDimensionOverride)(_Inout_ OrtSessionOptions* options, - _In_ const char* symbolic_dim, _In_ int64_t dim_override)NO_EXCEPTION; + _In_ const char* dim_denotation, _In_ int64_t dim_value)NO_EXCEPTION; /** * APIs to support non-tensor types - map and sequence. 
@@ -748,7 +748,6 @@ struct OrtApi { _In_ const char* key, _Outptr_ char** value)NO_EXCEPTION; OrtStatus*(ORT_API_CALL* ModelMetadataGetVersion)(_In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value)NO_EXCEPTION; - ORT_CLASS_RELEASE(ModelMetadata); /* @@ -772,6 +771,11 @@ struct OrtApi { NO_EXCEPTION; ORT_CLASS_RELEASE(ThreadingOptions); + + // Override symbolic dimensions (by specific name strings) with actual values if known at session initialization time to enable + // optimizations that can take advantage of fixed values (such as memory planning, etc) + OrtStatus*(ORT_API_CALL* AddFreeDimensionOverrideByName)(_Inout_ OrtSessionOptions* options, + _In_ const char* dim_name, _In_ int64_t dim_value)NO_EXCEPTION; }; /* diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 88de473f1c..2b5ccac963 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -10,9 +10,17 @@ #include "core/util/thread_utils.h" namespace onnxruntime { + +enum class FreeDimensionOverrideType { + Invalid = 0, + Denotation = 1, + Name = 2 +}; + struct FreeDimensionOverride { - std::string dimension_denotation; - int64_t dimension_override; + std::string dim_identifier; + FreeDimensionOverrideType dim_identifer_type; + int64_t dim_value; }; /** @@ -62,8 +70,8 @@ struct SessionOptions { // configuring this makes sense only when you're using parallel executor OrtThreadPoolParams inter_op_param; - // For models with free input dimensions (most commonly batch size), specifies a set of values to override those - // free dimensions with, keyed by dimension denotation. + // For models with symbolic input dimensions (most commonly batch size), specifies a set of values to override those + // symbolic dimensions with, keyed by either dimension denotation or dimension name (dim_param). std::vector free_dimension_overrides; // By default the session uses its own set of threadpools, unless this is set to false. 
diff --git a/onnxruntime/core/optimizer/free_dim_override_transformer.cc b/onnxruntime/core/optimizer/free_dim_override_transformer.cc index ac6a79a372..46e7b96ded 100644 --- a/onnxruntime/core/optimizer/free_dim_override_transformer.cc +++ b/onnxruntime/core/optimizer/free_dim_override_transformer.cc @@ -18,13 +18,17 @@ static std::string ToLower(std::string s) { return s; } -/*explicit*/ FreeDimensionOverrideTransformer::FreeDimensionOverrideTransformer(gsl::span overrides_to_apply) +FreeDimensionOverrideTransformer::FreeDimensionOverrideTransformer(gsl::span overrides_to_apply) : GraphTransformer("FreeDimensionOverrideTransformer") { for (const auto& o : overrides_to_apply) { // Convert to lowercase to perform case-insensitive comparisons later - std::string denotation = ToLower(o.dimension_denotation); - - dimension_override_by_denotation_.emplace(denotation, o.dimension_override); + if (o.dim_identifer_type == FreeDimensionOverrideType::Denotation) { + dimension_override_by_denotation_.emplace(ToLower(o.dim_identifier), o.dim_value); + } else if (o.dim_identifer_type == FreeDimensionOverrideType::Name) { + dimension_override_by_name_.emplace(o.dim_identifier, o.dim_value); + } else { + ORT_THROW("Invalid free dimension override."); + } } } @@ -48,23 +52,38 @@ Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified, auto* new_dimension = new_shape.add_dim(); *new_dimension = dimension; + bool overridden = false; + int64_t dimension_override = 0; + if (dimension.has_denotation()) { // Convert to lowercase to perform case-insensitive comparison auto it = dimension_override_by_denotation_.find(ToLower(dimension.denotation())); - if (it == dimension_override_by_denotation_.end()) { - continue; + if (it != dimension_override_by_denotation_.end()) { + overridden = true; + dimension_override = it->second; } + } - int64_t dimension_override = it->second; + if (dimension.has_dim_param()) { + auto it = 
dimension_override_by_name_.find(dimension.dim_param()); + if (it != dimension_override_by_name_.end()) { + if (overridden && dimension_override != it->second) { + return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Conflicting free dimension overrides."); + } + overridden = true; + dimension_override = it->second; + } + } + + if (overridden) { if (dimension.has_dim_value()) { // If this dimension actually has a value but it doesn't match the override value, return an // error. if (dimension.dim_value() != dimension_override) { LOGS(logger, ERROR) << "The model has input '" << graph_input->Name() << "' " - << "with a fixed dimension denotation '" << dimension.denotation() << "' " - << "but the size of this dimension " << dimension.dim_value() << " " - << "does not equal the specified override of" << dimension_override << "."; + << "with a fixed dimension size " << dimension.dim_value() << " " + << "which does not equal the specified override of" << dimension_override << "."; return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid free dimension override."); } diff --git a/onnxruntime/core/optimizer/free_dim_override_transformer.h b/onnxruntime/core/optimizer/free_dim_override_transformer.h index e6b437982f..18e0b128b8 100644 --- a/onnxruntime/core/optimizer/free_dim_override_transformer.h +++ b/onnxruntime/core/optimizer/free_dim_override_transformer.h @@ -26,6 +26,7 @@ class FreeDimensionOverrideTransformer : public GraphTransformer { Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; std::map dimension_override_by_denotation_; + std::map dimension_override_by_name_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 78e5188ad9..9717b65b77 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -151,8 +151,18 @@ 
ORT_API_STATUS_IMPL(OrtApis::SetInterOpNumThreads, _In_ OrtSessionOptions* optio } ORT_API_STATUS_IMPL(OrtApis::AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options, - _In_ const char* symbolic_dim, _In_ int64_t dim_override) { - options->value.free_dimension_overrides.push_back(onnxruntime::FreeDimensionOverride{symbolic_dim, dim_override}); + _In_ const char* dim_denotation, _In_ int64_t dim_value) { + options->value.free_dimension_overrides.push_back( + onnxruntime::FreeDimensionOverride{dim_denotation, onnxruntime::FreeDimensionOverrideType::Denotation, dim_value} + ); + return nullptr; +} + +ORT_API_STATUS_IMPL(OrtApis::AddFreeDimensionOverrideByName, _Inout_ OrtSessionOptions* options, + _In_ const char* dim_name, _In_ int64_t dim_value) { + options->value.free_dimension_overrides.push_back( + onnxruntime::FreeDimensionOverride{dim_name, onnxruntime::FreeDimensionOverrideType::Name, dim_value} + ); return nullptr; } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 12be053c81..393c05fd29 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -1515,7 +1515,8 @@ static constexpr OrtApi ort_api_1_to_3 = { &OrtApis::CreateEnvWithGlobalThreadPools, &OrtApis::DisablePerSessionThreads, &OrtApis::CreateThreadingOptions, - &OrtApis::ReleaseThreadingOptions}; + &OrtApis::ReleaseThreadingOptions, + &OrtApis::AddFreeDimensionOverrideByName}; // Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other) // If this assert hits, read the above 'Rules on how to add a new Ort API version' diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index b1b50c6659..472e34e162 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -132,7 +132,7 @@ ORT_API_STATUS_IMPL(GetTensorShapeElementCount, 
_In_ const OrtTensorTypeAndShape ORT_API_STATUS_IMPL(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out); ORT_API_STATUS_IMPL(GetTypeInfo, _In_ const OrtValue* value, _Outptr_ OrtTypeInfo** out); ORT_API_STATUS_IMPL(GetValueType, _In_ const OrtValue* value, _Out_ enum ONNXType* out); -ORT_API_STATUS_IMPL(AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options, _In_ const char* symbolic_dim, _In_ int64_t dim_override); +ORT_API_STATUS_IMPL(AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options, _In_ const char* dim_denotation, _In_ int64_t dim_value); ORT_API_STATUS_IMPL(CreateMemoryInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1, enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out); ORT_API_STATUS_IMPL(CreateCpuMemoryInfo, enum OrtAllocatorType type, enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out) @@ -185,4 +185,7 @@ ORT_ALL_ARGS_NONNULL; ORT_API_STATUS_IMPL(DisablePerSessionThreads, _In_ OrtSessionOptions* options); ORT_API_STATUS_IMPL(CreateThreadingOptions, _Outptr_ OrtThreadingOptions** out); ORT_API(void, ReleaseThreadingOptions, _Frees_ptr_opt_ OrtThreadingOptions*); + +ORT_API_STATUS_IMPL(AddFreeDimensionOverrideByName, _Inout_ OrtSessionOptions* options, _In_ const char* dim_name, _In_ int64_t dim_value); + } // namespace OrtApis diff --git a/onnxruntime/test/optimizer/free_dimension_override_test.cc b/onnxruntime/test/optimizer/free_dimension_override_test.cc index 653ab75804..d278a5b7e6 100644 --- a/onnxruntime/test/optimizer/free_dimension_override_test.cc +++ b/onnxruntime/test/optimizer/free_dimension_override_test.cc @@ -17,7 +17,7 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { namespace test { -TEST(FreeDimensionOverrideTransformerTest, Test) { +void TestFreeDimensions(FreeDimensionOverrideType overrideType) { auto model_uri = ORT_TSTR("testdata/abs_free_dimensions.onnx"); std::shared_ptr model; @@ -29,11 +29,15 @@ TEST(FreeDimensionOverrideTransformerTest, 
Test) { // The model's input shape has two free dimensions, which have the denotation of DATA_BATCH // and DATA_CHANNEL. Supplying these overrides to the transformer should replace those free // dimensions with values of 1 and 42, respectively. - std::vector overrides = - { - FreeDimensionOverride{onnx::DATA_BATCH, 1}, - FreeDimensionOverride{onnx::DATA_CHANNEL, 42}, - }; + std::vector overrides(2); + + if (overrideType == FreeDimensionOverrideType::Denotation) { + overrides[0] = FreeDimensionOverride{onnx::DATA_BATCH, overrideType, 1}; + overrides[1] = FreeDimensionOverride{onnx::DATA_CHANNEL, overrideType, 42}; + } else { + overrides[0] = FreeDimensionOverride{"Dim1", overrideType, 1}; + overrides[1] = FreeDimensionOverride{"Dim2", overrideType, 42}; + }; auto graph_transformer = onnxruntime::make_unique(overrides); @@ -66,5 +70,11 @@ TEST(FreeDimensionOverrideTransformerTest, Test) { ASSERT_FALSE(modified); // no overrides apply anymore } + +TEST(FreeDimensionOverrideDenotationTransformerTest, Test) { + TestFreeDimensions(FreeDimensionOverrideType::Denotation); + TestFreeDimensions(FreeDimensionOverrideType::Name); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/abs_free_dimensions.onnx b/onnxruntime/test/testdata/abs_free_dimensions.onnx index 7d4c041a01..4c3d5ab6c4 100644 --- a/onnxruntime/test/testdata/abs_free_dimensions.onnx +++ b/onnxruntime/test/testdata/abs_free_dimensions.onnx @@ -3,12 +3,12 @@ xy"Abstest_absZ9 x4 2. -None +Dim1 DATA_BATCH -None DATA_CHANNEL +Dim2 DATA_CHANNEL b y  -None -None +Dim1 +Dim2 B \ No newline at end of file