mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Enable free dimension override by name
This commit is contained in:
parent
e303f458e4
commit
a3a8a53736
9 changed files with 87 additions and 31 deletions
|
|
@ -545,10 +545,10 @@ struct OrtApi {
|
|||
// Always returns the same instance on every invocation.
|
||||
OrtStatus*(ORT_API_CALL* GetAllocatorWithDefaultOptions)(_Outptr_ OrtAllocator** out)NO_EXCEPTION;
|
||||
|
||||
// Override symbolic dimensions with actual values if known at session initialization time to enable
|
||||
// Override symbolic dimensions (by specific denotation strings) with actual values if known at session initialization time to enable
|
||||
// optimizations that can take advantage of fixed values (such as memory planning, etc)
|
||||
OrtStatus*(ORT_API_CALL* AddFreeDimensionOverride)(_Inout_ OrtSessionOptions* options,
|
||||
_In_ const char* symbolic_dim, _In_ int64_t dim_override)NO_EXCEPTION;
|
||||
_In_ const char* dim_denotation, _In_ int64_t dim_value)NO_EXCEPTION;
|
||||
|
||||
/**
|
||||
* APIs to support non-tensor types - map and sequence.
|
||||
|
|
@ -748,7 +748,6 @@ struct OrtApi {
|
|||
_In_ const char* key, _Outptr_ char** value)NO_EXCEPTION;
|
||||
|
||||
OrtStatus*(ORT_API_CALL* ModelMetadataGetVersion)(_In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value)NO_EXCEPTION;
|
||||
|
||||
ORT_CLASS_RELEASE(ModelMetadata);
|
||||
|
||||
/*
|
||||
|
|
@ -772,6 +771,11 @@ struct OrtApi {
|
|||
NO_EXCEPTION;
|
||||
|
||||
ORT_CLASS_RELEASE(ThreadingOptions);
|
||||
|
||||
// Override symbolic dimensions (by specific name strings) with actual values if known at session initialization time to enable
|
||||
// optimizations that can take advantage of fixed values (such as memory planning, etc)
|
||||
OrtStatus*(ORT_API_CALL* AddFreeDimensionOverrideByName)(_Inout_ OrtSessionOptions* options,
|
||||
_In_ const char* dim_name, _In_ int64_t dim_value)NO_EXCEPTION;
|
||||
};
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -10,9 +10,17 @@
|
|||
#include "core/util/thread_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Selects how the identifier string of a FreeDimensionOverride is matched against a
// dimension of a graph input by FreeDimensionOverrideTransformer.
enum class FreeDimensionOverrideType {
|
||||
// Not a usable kind: the FreeDimensionOverrideTransformer constructor throws for any
// override whose type is neither Denotation nor Name.
Invalid = 0,
|
||||
// The identifier is compared (after lowercasing both sides) with the dimension's denotation.
Denotation = 1,
|
||||
// The identifier is compared verbatim with the dimension's dim_param.
Name = 2
|
||||
};
|
||||
|
||||
// Describes one requested override: which dimension to match and the value to give it.
struct FreeDimensionOverride {
|
||||
// NOTE(review): dimension_denotation / dimension_override look like the pre-change side of
// this diff view (the newer code reads dim_identifier / dim_value instead) - confirm they
// are absent from the final header.
std::string dimension_denotation;
|
||||
int64_t dimension_override;
|
||||
// Text used to find the dimension; interpreted according to dim_identifer_type.
std::string dim_identifier;
|
||||
// Whether dim_identifier is a denotation or a name.
// NOTE(review): "identifer" is a misspelling of "identifier"; the member is referenced
// elsewhere under this spelling, so renaming it needs a coordinated change.
FreeDimensionOverrideType dim_identifer_type;
|
||||
// Value the matched dimension is overridden with.
int64_t dim_value;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
@ -62,8 +70,8 @@ struct SessionOptions {
|
|||
// configuring this makes sense only when you're using parallel executor
|
||||
OrtThreadPoolParams inter_op_param;
|
||||
|
||||
// For models with free input dimensions (most commonly batch size), specifies a set of values to override those
|
||||
// free dimensions with, keyed by dimension denotation.
|
||||
// For models with symbolic input dimensions (most commonly batch size), specifies a set of values to override those
|
||||
// symbolic dimensions with, keyed either by dimension denotation or by dimension name (dim_param).
|
||||
std::vector<FreeDimensionOverride> free_dimension_overrides;
|
||||
|
||||
// By default the session uses its own set of threadpools, unless this is set to false.
|
||||
|
|
|
|||
|
|
@ -18,13 +18,17 @@ static std::string ToLower(std::string s) {
|
|||
return s;
|
||||
}
|
||||
|
||||
// Builds the two lookup tables used by ApplyImpl from the supplied overrides:
// denotation-keyed entries (stored lowercased) and name-keyed entries (stored verbatim).
// Throws via ORT_THROW when an override's type is neither Denotation nor Name.
// NOTE(review): this diff view shows two signature lines and two bodies interleaved; the
// statements reading o.dimension_denotation / o.dimension_override are presumably the
// pre-change side - confirm against the final file.
/*explicit*/ FreeDimensionOverrideTransformer::FreeDimensionOverrideTransformer(gsl::span<const FreeDimensionOverride> overrides_to_apply)
|
||||
FreeDimensionOverrideTransformer::FreeDimensionOverrideTransformer(gsl::span<const FreeDimensionOverride> overrides_to_apply)
|
||||
: GraphTransformer("FreeDimensionOverrideTransformer") {
|
||||
for (const auto& o : overrides_to_apply) {
|
||||
// Convert to lowercase to perform case-insensitive comparisons later
|
||||
std::string denotation = ToLower(o.dimension_denotation);
|
||||
|
||||
dimension_override_by_denotation_.emplace(denotation, o.dimension_override);
|
||||
// std::map::emplace keeps an existing entry, so when the same key is supplied more than
// once the first value wins and later duplicates are ignored without an error.
if (o.dim_identifer_type == FreeDimensionOverrideType::Denotation) {
|
||||
dimension_override_by_denotation_.emplace(ToLower(o.dim_identifier), o.dim_value);
|
||||
} else if (o.dim_identifer_type == FreeDimensionOverrideType::Name) {
|
||||
// Names are stored as given: the later lookup by dim_param is case-sensitive.
dimension_override_by_name_.emplace(o.dim_identifier, o.dim_value);
|
||||
} else {
|
||||
ORT_THROW("Invalid free dimension override.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -48,23 +52,38 @@ Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified,
|
|||
auto* new_dimension = new_shape.add_dim();
|
||||
*new_dimension = dimension;
|
||||
|
||||
bool overridden = false;
|
||||
int64_t dimension_override = 0;
|
||||
|
||||
if (dimension.has_denotation()) {
|
||||
// Convert to lowercase to perform case-insensitive comparison
|
||||
auto it = dimension_override_by_denotation_.find(ToLower(dimension.denotation()));
|
||||
if (it == dimension_override_by_denotation_.end()) {
|
||||
continue;
|
||||
if (it != dimension_override_by_denotation_.end()) {
|
||||
overridden = true;
|
||||
dimension_override = it->second;
|
||||
}
|
||||
}
|
||||
|
||||
int64_t dimension_override = it->second;
|
||||
if (dimension.has_dim_param()) {
|
||||
auto it = dimension_override_by_name_.find(dimension.dim_param());
|
||||
if (it != dimension_override_by_name_.end()) {
|
||||
if (overridden && dimension_override != it->second) {
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Conflicting free dimension overrides.");
|
||||
}
|
||||
|
||||
overridden = true;
|
||||
dimension_override = it->second;
|
||||
}
|
||||
}
|
||||
|
||||
if (overridden) {
|
||||
if (dimension.has_dim_value()) {
|
||||
// If this dimension actually has a value but it doesn't match the override value, return an
|
||||
// error.
|
||||
if (dimension.dim_value() != dimension_override) {
|
||||
LOGS(logger, ERROR) << "The model has input '" << graph_input->Name() << "' "
|
||||
<< "with a fixed dimension denotation '" << dimension.denotation() << "' "
|
||||
<< "but the size of this dimension " << dimension.dim_value() << " "
|
||||
<< "does not equal the specified override of" << dimension_override << ".";
|
||||
<< "with a fixed dimension size " << dimension.dim_value() << " "
|
||||
<< "which does not equal the specified override of" << dimension_override << ".";
|
||||
|
||||
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid free dimension override.");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ class FreeDimensionOverrideTransformer : public GraphTransformer {
|
|||
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
|
||||
|
||||
std::map<std::string, int64_t> dimension_override_by_denotation_;
|
||||
std::map<std::string, int64_t> dimension_override_by_name_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -151,8 +151,18 @@ ORT_API_STATUS_IMPL(OrtApis::SetInterOpNumThreads, _In_ OrtSessionOptions* optio
|
|||
}
|
||||
|
||||
// Appends a denotation-keyed free dimension override to the session options and returns
// nullptr.
// NOTE(review): this diff view interleaves the old parameter list/body (symbolic_dim,
// dim_override) with the new one (dim_denotation, dim_value) - confirm against the final file.
// NOTE(review): the const char* argument is converted to std::string without a null check;
// constructing std::string from a null pointer is undefined behavior - confirm callers
// never pass null or add validation.
ORT_API_STATUS_IMPL(OrtApis::AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options,
|
||||
_In_ const char* symbolic_dim, _In_ int64_t dim_override) {
|
||||
options->value.free_dimension_overrides.push_back(onnxruntime::FreeDimensionOverride{symbolic_dim, dim_override});
|
||||
_In_ const char* dim_denotation, _In_ int64_t dim_value) {
|
||||
options->value.free_dimension_overrides.push_back(
|
||||
onnxruntime::FreeDimensionOverride{dim_denotation, onnxruntime::FreeDimensionOverrideType::Denotation, dim_value}
|
||||
);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Appends a name-keyed free dimension override (type Name) to the session options and
// returns nullptr.
// NOTE(review): dim_name is converted to std::string without a null check; constructing
// std::string from a null pointer is undefined behavior - confirm callers never pass null
// or add validation.
ORT_API_STATUS_IMPL(OrtApis::AddFreeDimensionOverrideByName, _Inout_ OrtSessionOptions* options,
|
||||
_In_ const char* dim_name, _In_ int64_t dim_value) {
|
||||
options->value.free_dimension_overrides.push_back(
|
||||
onnxruntime::FreeDimensionOverride{dim_name, onnxruntime::FreeDimensionOverrideType::Name, dim_value}
|
||||
);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1515,7 +1515,8 @@ static constexpr OrtApi ort_api_1_to_3 = {
|
|||
&OrtApis::CreateEnvWithGlobalThreadPools,
|
||||
&OrtApis::DisablePerSessionThreads,
|
||||
&OrtApis::CreateThreadingOptions,
|
||||
&OrtApis::ReleaseThreadingOptions};
|
||||
&OrtApis::ReleaseThreadingOptions,
|
||||
&OrtApis::AddFreeDimensionOverrideByName};
|
||||
|
||||
// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)
|
||||
// If this assert hits, read the above 'Rules on how to add a new Ort API version'
|
||||
|
|
|
|||
|
|
@ -132,7 +132,7 @@ ORT_API_STATUS_IMPL(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShape
|
|||
ORT_API_STATUS_IMPL(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out);
|
||||
ORT_API_STATUS_IMPL(GetTypeInfo, _In_ const OrtValue* value, _Outptr_ OrtTypeInfo** out);
|
||||
ORT_API_STATUS_IMPL(GetValueType, _In_ const OrtValue* value, _Out_ enum ONNXType* out);
|
||||
ORT_API_STATUS_IMPL(AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options, _In_ const char* symbolic_dim, _In_ int64_t dim_override);
|
||||
ORT_API_STATUS_IMPL(AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options, _In_ const char* dim_denotation, _In_ int64_t dim_value);
|
||||
|
||||
ORT_API_STATUS_IMPL(CreateMemoryInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1, enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out);
|
||||
ORT_API_STATUS_IMPL(CreateCpuMemoryInfo, enum OrtAllocatorType type, enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out)
|
||||
|
|
@ -185,4 +185,7 @@ ORT_ALL_ARGS_NONNULL;
|
|||
ORT_API_STATUS_IMPL(DisablePerSessionThreads, _In_ OrtSessionOptions* options);
|
||||
ORT_API_STATUS_IMPL(CreateThreadingOptions, _Outptr_ OrtThreadingOptions** out);
|
||||
ORT_API(void, ReleaseThreadingOptions, _Frees_ptr_opt_ OrtThreadingOptions*);
|
||||
|
||||
ORT_API_STATUS_IMPL(AddFreeDimensionOverrideByName, _Inout_ OrtSessionOptions* options, _In_ const char* dim_name, _In_ int64_t dim_value);
|
||||
|
||||
} // namespace OrtApis
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ using namespace ONNX_NAMESPACE;
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
TEST(FreeDimensionOverrideTransformerTest, Test) {
|
||||
void TestFreeDimensions(FreeDimensionOverrideType overrideType) {
|
||||
auto model_uri = ORT_TSTR("testdata/abs_free_dimensions.onnx");
|
||||
|
||||
std::shared_ptr<Model> model;
|
||||
|
|
@ -29,11 +29,15 @@ TEST(FreeDimensionOverrideTransformerTest, Test) {
|
|||
// The model's input shape has two free dimensions, which have the denotation of DATA_BATCH
|
||||
// and DATA_CHANNEL. Supplying these overrides to the transformer should replace those free
|
||||
// dimensions with values of 1 and 42, respectively.
|
||||
std::vector<FreeDimensionOverride> overrides =
|
||||
{
|
||||
FreeDimensionOverride{onnx::DATA_BATCH, 1},
|
||||
FreeDimensionOverride{onnx::DATA_CHANNEL, 42},
|
||||
};
|
||||
std::vector<FreeDimensionOverride> overrides(2);
|
||||
|
||||
if (overrideType == FreeDimensionOverrideType::Denotation) {
|
||||
overrides[0] = FreeDimensionOverride{onnx::DATA_BATCH, overrideType, 1};
|
||||
overrides[1] = FreeDimensionOverride{onnx::DATA_CHANNEL, overrideType, 42};
|
||||
} else {
|
||||
overrides[0] = FreeDimensionOverride{"Dim1", overrideType, 1};
|
||||
overrides[1] = FreeDimensionOverride{"Dim2", overrideType, 42};
|
||||
};
|
||||
|
||||
auto graph_transformer = onnxruntime::make_unique<FreeDimensionOverrideTransformer>(overrides);
|
||||
|
||||
|
|
@ -66,5 +70,11 @@ TEST(FreeDimensionOverrideTransformerTest, Test) {
|
|||
ASSERT_FALSE(modified); // no overrides apply anymore
|
||||
}
|
||||
|
||||
|
||||
// Runs the shared TestFreeDimensions scenario once for each supported override kind.
// NOTE(review): the suite name mentions only "Denotation" although the Name path is
// exercised here as well; a failure in the first call may also hide the second - consider
// separate test cases.
TEST(FreeDimensionOverrideDenotationTransformerTest, Test) {
|
||||
TestFreeDimensions(FreeDimensionOverrideType::Denotation);
|
||||
TestFreeDimensions(FreeDimensionOverrideType::Name);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -3,12 +3,12 @@
|
|||
xy"Abstest_absZ9
|
||||
x4
|
||||
2.
|
||||
None
|
||||
Dim1
|
||||
DATA_BATCH
|
||||
NoneDATA_CHANNEL
|
||||
Dim2DATA_CHANNEL
|
||||
b
|
||||
y
|
||||
|
||||
None
|
||||
None
|
||||
Dim1
|
||||
Dim2
|
||||
B
|
||||
Loading…
Reference in a new issue