Enable free dimension override by name

This commit is contained in:
Jeff 2020-04-15 16:21:52 -07:00
parent e303f458e4
commit a3a8a53736
9 changed files with 87 additions and 31 deletions

View file

@ -545,10 +545,10 @@ struct OrtApi {
// Always returns the same instance on every invocation.
OrtStatus*(ORT_API_CALL* GetAllocatorWithDefaultOptions)(_Outptr_ OrtAllocator** out)NO_EXCEPTION;
// Override symbolic dimensions with actual values if known at session initialization time to enable
// Override symbolic dimensions (by specific denotation strings) with actual values if known at session initialization time to enable
// optimizations that can take advantage of fixed values (such as memory planning, etc)
OrtStatus*(ORT_API_CALL* AddFreeDimensionOverride)(_Inout_ OrtSessionOptions* options,
_In_ const char* symbolic_dim, _In_ int64_t dim_override)NO_EXCEPTION;
_In_ const char* dim_denotation, _In_ int64_t dim_value)NO_EXCEPTION;
/**
* APIs to support non-tensor types - map and sequence.
@ -748,7 +748,6 @@ struct OrtApi {
_In_ const char* key, _Outptr_ char** value)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* ModelMetadataGetVersion)(_In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value)NO_EXCEPTION;
ORT_CLASS_RELEASE(ModelMetadata);
/*
@ -772,6 +771,11 @@ struct OrtApi {
NO_EXCEPTION;
ORT_CLASS_RELEASE(ThreadingOptions);
// Override symbolic dimensions (by specific name strings) with actual values if known at session initialization time to enable
// optimizations that can take advantage of fixed values (such as memory planning, etc)
OrtStatus*(ORT_API_CALL* AddFreeDimensionOverrideByName)(_Inout_ OrtSessionOptions* options,
_In_ const char* dim_name, _In_ int64_t dim_value)NO_EXCEPTION;
};
/*

View file

@ -10,9 +10,17 @@
#include "core/util/thread_utils.h"
namespace onnxruntime {
// Selects how the identifier string of a free dimension override is matched
// against a model dimension.
enum class FreeDimensionOverrideType {
  // Zero value; what a value-initialized instance holds.
  Invalid = 0,
  // The identifier is a dimension denotation string.
  Denotation = 1,
  // The identifier is a dimension name string.
  Name = 2,
};
// Describes one user-supplied override for a free (symbolic) model dimension.
// NOTE(review): this view interleaves pre-change and post-change lines of the commit; the
// first two members below look like the superseded fields -- confirm against the real header.
struct FreeDimensionOverride {
std::string dimension_denotation;
int64_t dimension_override;
// Identifier of the dimension to override; interpreted according to dim_identifer_type.
std::string dim_identifier;
// Whether dim_identifier is a denotation or a name.
// NOTE(review): "identifer" is misspelled; renaming it would touch every user of this field.
FreeDimensionOverrideType dim_identifer_type;
// Value to substitute for the matching dimension.
int64_t dim_value;
};
/**
@ -62,8 +70,8 @@ struct SessionOptions {
// configuring this makes sense only when you're using parallel executor
OrtThreadPoolParams inter_op_param;
// For models with free input dimensions (most commonly batch size), specifies a set of values to override those
// free dimensions with, keyed by dimension denotation.
// For models with symbolic input dimensions (most commonly batch size), specifies a set of values to override those
// symbolic dimensions with. Each override is keyed either by dimension denotation or by dimension name.
std::vector<FreeDimensionOverride> free_dimension_overrides;
// By default the session uses its own set of threadpools, unless this is set to false.

View file

@ -18,13 +18,17 @@ static std::string ToLower(std::string s) {
return s;
}
// Builds the lookup tables consulted when overrides are applied: one keyed by lower-cased
// dimension denotation and one keyed by dimension name (stored exactly as given).
// An override whose type is neither Denotation nor Name triggers ORT_THROW.
// NOTE(review): this view interleaves pre-change and post-change lines of the commit (two
// constructor signatures and the old denotation-only body lines) -- confirm against the real file.
/*explicit*/ FreeDimensionOverrideTransformer::FreeDimensionOverrideTransformer(gsl::span<const FreeDimensionOverride> overrides_to_apply)
FreeDimensionOverrideTransformer::FreeDimensionOverrideTransformer(gsl::span<const FreeDimensionOverride> overrides_to_apply)
: GraphTransformer("FreeDimensionOverrideTransformer") {
for (const auto& o : overrides_to_apply) {
// Convert to lowercase to perform case-insensitive comparisons later
std::string denotation = ToLower(o.dimension_denotation);
dimension_override_by_denotation_.emplace(denotation, o.dimension_override);
if (o.dim_identifer_type == FreeDimensionOverrideType::Denotation) {
// Denotations are lower-cased here so later lookups can be case-insensitive.
// emplace() does not overwrite, so the first override given for a key is the one kept.
dimension_override_by_denotation_.emplace(ToLower(o.dim_identifier), o.dim_value);
} else if (o.dim_identifer_type == FreeDimensionOverrideType::Name) {
// Names are stored verbatim (no case folding); first entry for a key is kept as well.
dimension_override_by_name_.emplace(o.dim_identifier, o.dim_value);
} else {
// Any other type value (e.g. Invalid) is rejected.
ORT_THROW("Invalid free dimension override.");
}
}
}
@ -48,23 +52,38 @@ Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified,
auto* new_dimension = new_shape.add_dim();
*new_dimension = dimension;
bool overridden = false;
int64_t dimension_override = 0;
if (dimension.has_denotation()) {
// Convert to lowercase to perform case-insensitive comparison
auto it = dimension_override_by_denotation_.find(ToLower(dimension.denotation()));
if (it == dimension_override_by_denotation_.end()) {
continue;
if (it != dimension_override_by_denotation_.end()) {
overridden = true;
dimension_override = it->second;
}
}
int64_t dimension_override = it->second;
if (dimension.has_dim_param()) {
auto it = dimension_override_by_name_.find(dimension.dim_param());
if (it != dimension_override_by_name_.end()) {
if (overridden && dimension_override != it->second) {
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Conflicting free dimension overrides.");
}
overridden = true;
dimension_override = it->second;
}
}
if (overridden) {
if (dimension.has_dim_value()) {
// If this dimension actually has a value but it doesn't match the override value, return an
// error.
if (dimension.dim_value() != dimension_override) {
LOGS(logger, ERROR) << "The model has input '" << graph_input->Name() << "' "
<< "with a fixed dimension denotation '" << dimension.denotation() << "' "
<< "but the size of this dimension " << dimension.dim_value() << " "
<< "does not equal the specified override of" << dimension_override << ".";
<< "with a fixed dimension size " << dimension.dim_value() << " "
<< "which does not equal the specified override of" << dimension_override << ".";
return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid free dimension override.");
}

View file

@ -26,6 +26,7 @@ class FreeDimensionOverrideTransformer : public GraphTransformer {
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
std::map<std::string, int64_t> dimension_override_by_denotation_;
std::map<std::string, int64_t> dimension_override_by_name_;
};
} // namespace onnxruntime

View file

@ -151,8 +151,18 @@ ORT_API_STATUS_IMPL(OrtApis::SetInterOpNumThreads, _In_ OrtSessionOptions* optio
}
// C API entry point: appends a denotation-keyed free dimension override to the session options.
// Always returns nullptr (presumably "success" under the OrtStatus convention -- confirm).
// NOTE(review): this view interleaves pre-change and post-change lines of the commit (old
// parameter list and old push_back) -- confirm against the real file.
// NOTE(review): dim_denotation is converted to std::string without a null check.
ORT_API_STATUS_IMPL(OrtApis::AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options,
_In_ const char* symbolic_dim, _In_ int64_t dim_override) {
options->value.free_dimension_overrides.push_back(onnxruntime::FreeDimensionOverride{symbolic_dim, dim_override});
_In_ const char* dim_denotation, _In_ int64_t dim_value) {
// Tag the entry as Denotation so it is matched against dimension denotations.
options->value.free_dimension_overrides.push_back(
onnxruntime::FreeDimensionOverride{dim_denotation, onnxruntime::FreeDimensionOverrideType::Denotation, dim_value}
);
return nullptr;
}
// C API entry point: appends a name-keyed free dimension override to the session options.
//   options   - session options whose override list is extended.
//   dim_name  - name of the symbolic dimension to override.
//   dim_value - value to use for that dimension.
// Always returns nullptr.
// NOTE(review): dim_name is converted to std::string without a null check.
ORT_API_STATUS_IMPL(OrtApis::AddFreeDimensionOverrideByName, _Inout_ OrtSessionOptions* options,
                    _In_ const char* dim_name, _In_ int64_t dim_value) {
  auto& overrides = options->value.free_dimension_overrides;
  // Tag the entry as Name so it is matched against dimension names rather than denotations.
  overrides.push_back(onnxruntime::FreeDimensionOverride{
      dim_name, onnxruntime::FreeDimensionOverrideType::Name, dim_value});
  return nullptr;
}

View file

@ -1515,7 +1515,8 @@ static constexpr OrtApi ort_api_1_to_3 = {
&OrtApis::CreateEnvWithGlobalThreadPools,
&OrtApis::DisablePerSessionThreads,
&OrtApis::CreateThreadingOptions,
&OrtApis::ReleaseThreadingOptions};
&OrtApis::ReleaseThreadingOptions,
&OrtApis::AddFreeDimensionOverrideByName};
// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)
// If this assert hits, read the above 'Rules on how to add a new Ort API version'

View file

@ -132,7 +132,7 @@ ORT_API_STATUS_IMPL(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShape
ORT_API_STATUS_IMPL(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out);
ORT_API_STATUS_IMPL(GetTypeInfo, _In_ const OrtValue* value, _Outptr_ OrtTypeInfo** out);
ORT_API_STATUS_IMPL(GetValueType, _In_ const OrtValue* value, _Out_ enum ONNXType* out);
ORT_API_STATUS_IMPL(AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options, _In_ const char* symbolic_dim, _In_ int64_t dim_override);
ORT_API_STATUS_IMPL(AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options, _In_ const char* dim_denotation, _In_ int64_t dim_value);
ORT_API_STATUS_IMPL(CreateMemoryInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1, enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out);
ORT_API_STATUS_IMPL(CreateCpuMemoryInfo, enum OrtAllocatorType type, enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out)
@ -185,4 +185,7 @@ ORT_ALL_ARGS_NONNULL;
ORT_API_STATUS_IMPL(DisablePerSessionThreads, _In_ OrtSessionOptions* options);
ORT_API_STATUS_IMPL(CreateThreadingOptions, _Outptr_ OrtThreadingOptions** out);
ORT_API(void, ReleaseThreadingOptions, _Frees_ptr_opt_ OrtThreadingOptions*);
ORT_API_STATUS_IMPL(AddFreeDimensionOverrideByName, _Inout_ OrtSessionOptions* options, _In_ const char* dim_name, _In_ int64_t dim_value);
} // namespace OrtApis

View file

@ -17,7 +17,7 @@ using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace test {
TEST(FreeDimensionOverrideTransformerTest, Test) {
void TestFreeDimensions(FreeDimensionOverrideType overrideType) {
auto model_uri = ORT_TSTR("testdata/abs_free_dimensions.onnx");
std::shared_ptr<Model> model;
@ -29,11 +29,15 @@ TEST(FreeDimensionOverrideTransformerTest, Test) {
// The model's input shape has two free dimensions, which have the denotation of DATA_BATCH
// and DATA_CHANNEL. Supplying these overrides to the transformer should replace those free
// dimensions with values of 1 and 42, respectively.
std::vector<FreeDimensionOverride> overrides =
{
FreeDimensionOverride{onnx::DATA_BATCH, 1},
FreeDimensionOverride{onnx::DATA_CHANNEL, 42},
};
std::vector<FreeDimensionOverride> overrides(2);
if (overrideType == FreeDimensionOverrideType::Denotation) {
overrides[0] = FreeDimensionOverride{onnx::DATA_BATCH, overrideType, 1};
overrides[1] = FreeDimensionOverride{onnx::DATA_CHANNEL, overrideType, 42};
} else {
overrides[0] = FreeDimensionOverride{"Dim1", overrideType, 1};
overrides[1] = FreeDimensionOverride{"Dim2", overrideType, 42};
};
auto graph_transformer = onnxruntime::make_unique<FreeDimensionOverrideTransformer>(overrides);
@ -66,5 +70,11 @@ TEST(FreeDimensionOverrideTransformerTest, Test) {
ASSERT_FALSE(modified); // no overrides apply anymore
}
// Each override kind gets its own test case. The helper uses ASSERT_* macros, which only
// return from the helper itself, so running both kinds in one TEST would let the second
// variant run after the first had already failed and would blur which variant broke.

// Overrides matched by dimension denotation.
TEST(FreeDimensionOverrideDenotationTransformerTest, Test) {
  TestFreeDimensions(FreeDimensionOverrideType::Denotation);
}

// Overrides matched by dimension name.
TEST(FreeDimensionOverrideNameTransformerTest, Test) {
  TestFreeDimensions(FreeDimensionOverrideType::Name);
}
} // namespace test
} // namespace onnxruntime

View file

@ -3,12 +3,12 @@
xy"Abstest_absZ9
x4
2.
None
Dim1
DATA_BATCH
None DATA_CHANNEL
Dim2 DATA_CHANNEL
b
y

None
None
Dim1
Dim2
B