Implement L1 graph transformer for free dimension override (#1825)

* Implement FreeDimensionOverrideTransformer

* Add test

* Fix compiler warnings

* Update comment

* LOGS_DEFAULT

* Merge from master
This commit is contained in:
Adrian Tsai 2019-09-20 10:52:14 -07:00 committed by Pranav Sharma
parent 561f2c4a9a
commit a7beed798e
9 changed files with 213 additions and 3 deletions

View file

@ -3,11 +3,14 @@
#pragma once
#include <gsl/span>
#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
struct FreeDimensionOverride;
namespace transformer_utils {
@ -21,6 +24,7 @@ std::vector<std::unique_ptr<RewriteRule>> GenerateRewriteRules(TransformerLevel
If transformers_and_rules_to_enable is not empty, it returns the intersection between the predefined transformers/rules
and the transformers_and_rules_to_enable. */
std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerLevel level,
gsl::span<const FreeDimensionOverride> free_dimension_overrides,
const std::vector<std::string>& rules_and_transformers_to_enable = {});
/** Given a TransformerLevel, this method generates a name for the rule-based graph transformer of that level. */

View file

@ -0,0 +1,86 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/common/logging/logging.h"
#include "core/session/inference_session.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/free_dim_override_transformer.h"
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::common;
namespace onnxruntime {
// Returns a copy of `s` with every character mapped through ::tolower.
// Used to normalise dimension denotations so they can be compared case-insensitively.
static std::string ToLower(std::string s) {
  // ::tolower is only defined for EOF and values representable as unsigned char.
  // Taking each character as unsigned char keeps bytes >= 0x80 from being passed
  // as negative ints, which would be undefined behaviour.
  std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) {
    return static_cast<char>(::tolower(c));
  });
  return s;
}
// Builds the denotation -> override lookup table from the supplied overrides.
// Keys are stored lowercased so that later lookups can be case-insensitive.
// emplace never overwrites, so if two overrides share a denotation (ignoring case)
// the first one listed is the one that is kept.
/*explicit*/ FreeDimensionOverrideTransformer::FreeDimensionOverrideTransformer(gsl::span<const FreeDimensionOverride> overrides_to_apply)
    : GraphTransformer("FreeDimensionOverrideTransformer") {
  for (const FreeDimensionOverride& entry : overrides_to_apply) {
    dimension_override_by_denotation_.emplace(ToLower(entry.dimension_denotation),
                                              entry.dimension_override);
  }
}
// Replaces free dimensions of the graph inputs with their configured override values.
//
// For every graph input that has a tensor type and a shape, each dimension whose
// denotation matches (case-insensitively) a key in dimension_override_by_denotation_
// is given the override as its fixed dim_value.
//
// graph:    graph whose input NodeArgs are updated in place.
// modified: set to true only when at least one input shape was actually changed;
//           it is never reset to false here.
//
// Returns INVALID_ARGUMENT if a matching dimension already has a fixed size that differs
// from the override, FAIL if a graph input cannot be looked up by name, OK otherwise.
Status FreeDimensionOverrideTransformer::ApplyImpl(Graph& graph, bool& modified, int /*graph_level*/) const {
  for (const onnxruntime::NodeArg* graph_input : graph.GetInputs()) {
    // Only inputs with a known tensor type and shape can carry dimension denotations.
    const auto* input_type = graph_input->TypeAsProto();
    const auto* input_shape = graph_input->Shape();
    if (!input_type || !input_shape || !input_type->has_tensor_type()) {
      continue;
    }

    // Build a copy of the shape in which overridden dimensions get their fixed value.
    ONNX_NAMESPACE::TensorShapeProto new_shape;
    bool shape_modified = false;
    for (int dim_index = 0; dim_index < input_shape->dim_size(); ++dim_index) {
      const auto& dimension = input_shape->dim(dim_index);

      // By default the dimension is carried over unchanged.
      auto* new_dimension = new_shape.add_dim();
      *new_dimension = dimension;

      if (!dimension.has_denotation()) {
        continue;
      }

      // Keys were lowercased at construction time, so lowercase the probe as well.
      const auto it = dimension_override_by_denotation_.find(ToLower(dimension.denotation()));
      if (it == dimension_override_by_denotation_.end()) {
        continue;
      }
      const int64_t dimension_override = it->second;

      if (dimension.has_dim_value()) {
        if (dimension.dim_value() == dimension_override) {
          // Already fixed to the requested size; nothing to change.
          continue;
        }

        // A fixed size that contradicts the override is a configuration error.
        LOGS_DEFAULT(ERROR) << "The model has input '" << graph_input->Name() << "' "
                            << "with a fixed dimension denotation '" << dimension.denotation() << "' "
                            << "but the size of this dimension " << dimension.dim_value() << " "
                            << "does not equal the specified override of " << dimension_override << ".";
        return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid free dimension override.");
      }

      // Set the dimension override
      new_dimension->clear_dim_param();
      new_dimension->set_dim_value(dimension_override);
      shape_modified = true;
    }

    // Leave untouched inputs alone so the graph is not reported as modified needlessly.
    if (!shape_modified) {
      continue;
    }

    // GetInputs() hands out const NodeArgs; fetch the mutable one by name to set the shape.
    auto* mutable_graph_input = graph.GetNodeArg(graph_input->Name());
    if (mutable_graph_input == nullptr) {
      return Status(ONNXRUNTIME, FAIL,
                    "Graph input '" + graph_input->Name() + "' could not be found in the graph.");
    }
    mutable_graph_input->SetShape(new_shape);
    modified = true;
  }

  return Status::OK();
}
} // namespace onnxruntime

View file

@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <gsl/span>
#include "core/common/common.h"
#include "core/optimizer/graph_transformer.h"
namespace onnxruntime {
struct FreeDimensionOverride;
/**
@Class FreeDimensionOverrideTransformer
Transformer that overrides free dimensions of the graph inputs with the specific value
that matches the denotation for that dimension. Denotations are compared
case-insensitively. If a matching dimension already has a fixed size that differs
from the override, applying the transformer fails with an error status.
*/
class FreeDimensionOverrideTransformer : public GraphTransformer {
public:
// Records the overrides to apply, keyed by lowercased dimension denotation.
explicit FreeDimensionOverrideTransformer(gsl::span<const FreeDimensionOverride> overrides_to_apply);
private:
// Rewrites the shapes of the graph inputs; sets `modified` when it updates a shape.
Status ApplyImpl(Graph& graph, bool& modified, int graph_level) const override;
// Override value for each dimension denotation; keys are stored lowercased.
std::map<std::string, int64_t> dimension_override_by_denotation_;
};
} // namespace onnxruntime

View file

@ -17,7 +17,9 @@
#include "core/optimizer/relu_clip_fusion.h"
#include "core/optimizer/shape_to_initializer.h"
#include "core/optimizer/nchwc_transformer.h"
#include "core/optimizer/free_dim_override_transformer.h"
#include "core/mlas/inc/mlas.h"
#include "core/session/inference_session.h"
namespace onnxruntime {
@ -87,6 +89,7 @@ std::unique_ptr<RuleBasedGraphTransformer> GenerateRuleBasedGraphTransformer(Tra
}
std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerLevel level,
gsl::span<const FreeDimensionOverride> free_dimension_overrides,
const std::vector<std::string>& transformers_and_rules_to_enable) {
std::vector<std::unique_ptr<GraphTransformer>> transformers;
std::unique_ptr<RuleBasedGraphTransformer> rule_transformer = nullptr;
@ -95,6 +98,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
std::unordered_set<std::string> l1_execution_providers = {};
transformers.emplace_back(std::make_unique<ConstantFolding>(l1_execution_providers));
transformers.emplace_back(std::make_unique<FreeDimensionOverrideTransformer>(free_dimension_overrides));
rule_transformer = GenerateRuleBasedGraphTransformer(level, transformers_and_rules_to_enable, l1_execution_providers);
} break;

View file

@ -1037,7 +1037,7 @@ void InferenceSession::AddPredefinedTransformers(GraphTransformerManager& transf
const std::vector<std::string>& custom_list) {
auto add_transformers = [&](TransformerLevel level) {
// Generate and register transformers for level
auto transformers_to_register = transformer_utils::GenerateTransformers(level, custom_list);
auto transformers_to_register = transformer_utils::GenerateTransformers(level, session_options_.free_dimension_overrides, custom_list);
for (auto& entry : transformers_to_register) {
transformer_manager.Register(std::move(entry), level);
}

View file

@ -46,6 +46,12 @@ namespace logging {
class LoggingManager;
}
// One free-dimension override: which dimensions to fix (identified by their
// denotation) and the size to give them.
struct FreeDimensionOverride
{
// Denotation identifying the dimension(s) this override applies to.
std::string dimension_denotation;
// Fixed size assigned to each dimension carrying that denotation.
int64_t dimension_override;
};
/**
* Configuration information for a session.
*/
@ -93,6 +99,10 @@ struct SessionOptions {
// controls the size of the thread pool used to parallelize the execution of nodes (ops)
// configuring this makes sense only when you're using parallel executor
int inter_op_num_threads = 0;
// For models with free input dimensions (most commonly batch size), specifies a set of values to override those
// free dimensions with, keyed by dimension denotation.
std::vector<FreeDimensionOverride> free_dimension_overrides;
};
/**

View file

@ -0,0 +1,61 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"
#include "core/optimizer/graph_transformer.h"
#include "core/optimizer/graph_transformer_mgr.h"
#include "test/framework/test_utils.h"
#include "test/test_environment.h"
#include "gtest/gtest.h"
#include "core/optimizer/free_dim_override_transformer.h"
#include "core/session/inference_session.h"
using namespace std;
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace test {
// Verifies that FreeDimensionOverrideTransformer replaces the free input dimensions of the
// test model with the values supplied for their denotations.
TEST(FreeDimensionOverrideTransformerTest, Test) {
  string model_uri = "testdata/abs_free_dimensions.onnx";

  std::shared_ptr<Model> model;
  ASSERT_TRUE(Model::Load(model_uri, model).IsOK());
  Graph& graph = model->MainGraph();

  // The model's input shape has two free dimensions, which have the denotation of DATA_BATCH
  // and DATA_CHANNEL. Supplying these overrides to the transformer should replace those free
  // dimensions with values of 1 and 42, respectively.
  std::vector<FreeDimensionOverride> overrides =
      {
          FreeDimensionOverride{onnx::DATA_BATCH, 1},
          FreeDimensionOverride{onnx::DATA_CHANNEL, 42},
      };

  auto graph_transformer = std::make_unique<FreeDimensionOverrideTransformer>(overrides);

  // Fail right away if registration or the transformation itself reports an error,
  // rather than going on to inspect a graph that was never transformed.
  onnxruntime::GraphTransformerManager graph_transformation_mgr(5);
  ASSERT_TRUE(graph_transformation_mgr.Register(std::move(graph_transformer), TransformerLevel::Level1).IsOK());
  ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1).IsOK());

  // Verify that the shape of the input graph has the correct values
  const auto& graph_inputs = graph.GetInputs();
  ASSERT_EQ(graph_inputs.size(), 1u);  // This model only has a single input ('x')

  const auto* input_shape = graph_inputs[0]->Shape();
  ASSERT_TRUE(input_shape != nullptr);    // Guard the dereferences below
  ASSERT_EQ(input_shape->dim_size(), 3);  // Model takes a 3D tensor as input; two of those dimensions are (were) free dimensions

  ASSERT_EQ(input_shape->dim(0).denotation(), onnx::DATA_BATCH);
  ASSERT_TRUE(input_shape->dim(0).has_dim_value());
  ASSERT_EQ(input_shape->dim(0).dim_value(), 1);

  ASSERT_EQ(input_shape->dim(1).denotation(), onnx::DATA_CHANNEL);
  ASSERT_TRUE(input_shape->dim(1).has_dim_value());
  ASSERT_EQ(input_shape->dim(1).dim_value(), 42);
}
} // namespace test
} // namespace onnxruntime

View file

@ -40,7 +40,7 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) {
std::string l2_rule1 = "ConvBNFusion";
std::vector<std::string> custom_list = {l1_rule1, l1_transformer, l2_rule1};
auto transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level1, custom_list);
auto transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level1, {}, custom_list);
ASSERT_TRUE(transformers.size() == 2);
auto l1_rule_transformer_name = transformer_utils::GenerateRuleBasedTransformerName(TransformerLevel::Level1);
RuleBasedGraphTransformer* rule_transformer = nullptr;
@ -51,7 +51,7 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) {
}
ASSERT_TRUE(rule_transformer && rule_transformer->RulesCount() == 1);
transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level2, custom_list);
transformers = transformer_utils::GenerateTransformers(TransformerLevel::Level2, {}, custom_list);
ASSERT_TRUE(transformers.size() == 1);
rule_transformer = dynamic_cast<RuleBasedGraphTransformer*>(transformers[0].get());
ASSERT_TRUE(rule_transformer->RulesCount() == 1);

View file

@ -0,0 +1,14 @@
 backend-test:s
xy"Abstest_absZ9
x4
2.
None
DATA_BATCH
None DATA_CHANNEL
b
y

None
None
B