Test and fix optimizers LayerNormFusion, BiasSoftmaxFusion, Transpose for opset 18 (#14542)

### Description

Due to the changes introduced in opset 18 on Reduce operators (axes is
an input and not an attribute), the following optimizers are not
catching the pattern they are supposed to optimize. This PR addresses
that.

* layer_norm_fusion.cc: the optimizer was not detecting the pattern it
was supposed to optimize
* bias_softmax_fusion.cc: the optimizer was not detecting the pattern it
was supposed to optimize
* transpose_optimizer.cc: the optimizer was not optimizing Reduce
operators other than ReduceSum

### Motivation and Context
Better performance.

---------

Signed-off-by: xadupre <xadupre@microsoft.com>
This commit is contained in:
Xavier Dupré 2023-02-08 23:11:31 +01:00 committed by GitHub
parent cfda876a3f
commit 30ec8b038f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 605 additions and 228 deletions

View file

@ -135,6 +135,7 @@ bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, Node
new_axis = (int)HandleNegativeAxis(axis, rank);
// The axis attribute for Softmax in OpSet-11 and OpSet-13 are different.
// Details in function documentation.
if (is_since_opset_13 && new_axis != rank - 1) return false;
int singlebatch_rank = rank - new_axis;

View file

@ -4,6 +4,7 @@
#include "core/optimizer/layer_norm_fusion.h"
#include "core/graph/graph_utils.h"
#include "core/optimizer/utils.h"
#include "core/optimizer/transpose_optimizer/optimizer_api.h"
#include "float.h"
#include <deque>
@ -16,12 +17,17 @@ static constexpr std::array<std::string_view, 3> supported_data_types{"tensor(fl
// Default epsilon
static constexpr float DEFAULT_LAYERNORM_EPSILON = 1e-5f;
static bool IsSupportedDataType(const Node& node) {
static bool IsSupportedDataType(const Node& node, int first_n_inputs=-1) {
int input_index = 0;
for (const auto& input_arg : node.InputDefs()) {
if (first_n_inputs != -1 && input_index >= first_n_inputs) {
return true;
}
if (std::find(supported_data_types.begin(), supported_data_types.end(),
*(input_arg->Type())) == supported_data_types.end()) {
return false;
}
++input_index;
}
return true;
}
@ -99,11 +105,11 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
Node& reduce_mean_node = *p_reduce_mean;
ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level, logger));
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13, 18}) ||
!graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) ||
(reduce_mean_node.GetOutputEdgesCount() != 1 && reduce_mean_node.GetOutputEdgesCount() != 2) ||
graph.NodeProducesGraphOutput(reduce_mean_node) ||
!IsSupportedDataType(reduce_mean_node)) {
!IsSupportedDataType(reduce_mean_node, 1)) {
continue;
}
nodes_to_remove.push_back(reduce_mean_node);
@ -263,10 +269,10 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
continue;
}
Node& reduce_mean2_node = *graph.GetNode(p_reduce_mean2->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11, 13, 18}) ||
reduce_mean2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, reduce_mean2_node, 1) ||
!IsSupportedDataType(reduce_mean2_node) ||
!IsSupportedDataType(reduce_mean2_node, 1) ||
reduce_mean2_node.GetInputEdgesCount() == 0) {
continue;
}
@ -333,8 +339,16 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// get axes attributes
const onnxruntime::NodeAttributes& attributes = reduce_mean_node.GetAttributes();
std::vector<int64_t> axes_values;
// TODO: modify this codes when opset >= 18 (axes is an input).
if (attributes.find("axes") != attributes.end()) {
axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
} else if (reduce_mean_node.InputDefs().size() == 2) {
auto axes = reduce_mean_node.InputDefs()[1];
auto axes_const = graph.GetConstantInitializer(axes->Name(), true);
if (axes_const != nullptr) {
Initializer initializer{*axes_const, graph.ModelPath()};
axes_values.insert(axes_values.end(), initializer.DataAsSpan<int64_t>().begin(), initializer.DataAsSpan<int64_t>().end());
}
}
// Get the inputs for the new LayerNormalization node.
@ -485,9 +499,9 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
continue;
}
Node& reduce_mean_node = *graph.GetNode(p_reduce_mean->Index());
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13}) ||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13, 18}) ||
reduce_mean_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() ||
!optimizer_utils::CheckOutputEdges(graph, reduce_mean_node, 1) || !IsSupportedDataType(reduce_mean_node) ||
!optimizer_utils::CheckOutputEdges(graph, reduce_mean_node, 1) || !IsSupportedDataType(reduce_mean_node, 1) ||
reduce_mean_node.GetInputEdgesCount() == 0) {
continue;
}
@ -585,6 +599,13 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
std::vector<int64_t> axes_values;
if (attributes.find("axes") != attributes.end()) {
axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
} else if (reduce_mean_node.InputDefs().size() == 2) {
auto axes = reduce_mean_node.InputDefs()[1];
auto axes_const = graph.GetConstantInitializer(axes->Name(), true);
if (axes_const != nullptr && axes_const->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
Initializer initializer{*axes_const, graph.ModelPath()};
axes_values.insert(axes_values.end(), initializer.DataAsSpan<int64_t>().begin(), initializer.DataAsSpan<int64_t>().end());
}
}
// Get the inputs for the new LayerNormalization node.

View file

@ -1040,7 +1040,7 @@ static bool HandlePad(HandlerArgs& args) {
constexpr HandlerInfo pad_handler = {&FirstInput, &HandlePad};
static bool HandleReduceOp(HandlerArgs& args) {
static bool HandleReduceOpWithArg(HandlerArgs& args) {
int64_t keepdims = args.node.GetAttributeIntDefault("keepdims", 1);
std::optional<std::vector<int64_t>> axes = args.node.GetAttributeInts("axes");
@ -1078,11 +1078,11 @@ static bool HandleReduceOp(HandlerArgs& args) {
return true;
}
constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOp};
static bool HandleReduceSum(HandlerArgs& args) {
if (args.ctx.opset < 13) {
return HandleReduceOp(args);
static bool HandleReduceOps(HandlerArgs& args) {
if ((args.node.OpType() == "ReduceSum" && args.ctx.opset < 13) ||
// or all other reduce operators since opset 18
(args.node.OpType() != "ReduceSum" && args.ctx.opset < 18)) {
return HandleReduceOpWithArg(args);
}
bool keepdims = args.node.GetAttributeIntDefault("keepdims", 1) != 0;
@ -1147,7 +1147,7 @@ static bool HandleReduceSum(HandlerArgs& args) {
return true;
}
constexpr HandlerInfo reduce_sum_handler = {&FirstInput, &HandleReduceSum};
constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOps};
static bool HandleSqueeze(HandlerArgs& args) {
std::vector<int64_t> new_axes;
@ -1709,7 +1709,7 @@ static const std::unordered_map<std::string_view, const HandlerInfo&> handler_ma
#if !defined(USE_CUDA) && !defined(USE_ROCM)
{"Resize", resize_handler},
#endif
{"ReduceSum", reduce_sum_handler},
{"ReduceSum", reduce_op_handler},
{"ReduceLogSum", reduce_op_handler},
{"ReduceLogSumExp", reduce_op_handler},

View file

@ -95,7 +95,6 @@ namespace onnxruntime {
namespace test {
#define MODEL_FOLDER ORT_TSTR("testdata/transform/")
TEST_F(GraphTransformationTests, IdentityElimination) {
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "abs-id-max.onnx";
std::shared_ptr<Model> model;
@ -4390,11 +4389,12 @@ TEST_F(GraphTransformationTests, ReshapeFusionOpsetTest) {
return Status::OK();
};
const std::vector<int> opsets{11, 12, 13, 14, 15, 15};
const std::vector<int> opsets{11, 12, 13, 14, 15, 18};
bool shape_test_for_opset15 = false;
for (auto& opset_version : opsets) {
for (auto& opset : opsets) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto opset_version = builder.DomainToVersionMap().find(kOnnxDomain)->second;
auto* input_arg0 = builder.MakeInput<float>({{batch_size, seq_lenth, hidden_size}});
auto* input_arg1 = builder.MakeInput<float>({{hidden_size}});
auto* scalar_int_0 = builder.MakeInitializer<int64_t>({}, {0});
@ -4414,7 +4414,7 @@ TEST_F(GraphTransformationTests, ReshapeFusionOpsetTest) {
auto* out = builder.MakeOutput();
builder.AddNode("Add", {input_arg0, input_arg1}, {add_out});
if (opset_version == 15) {
if (opset_version >= 15) {
if (shape_test_for_opset15) {
auto& shape_1 = builder.AddNode("Shape", {add_out}, {shape_out});
shape_1.AddAttribute("start", (int64_t)1);
@ -4442,11 +4442,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionOpsetTest) {
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<ReshapeFusion>();
if (opset_version == 15 && shape_test_for_opset15) {
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
if (opset >= 15 && shape_test_for_opset15) {
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, pre_graph_checker));
} else {
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker));
}
}
@ -4610,13 +4610,24 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_5) {
auto* cast_out_2 = builder.MakeIntermediate();
auto* mul_out = builder.MakeIntermediate();
auto* add_out_2 = builder.MakeOutput();
auto opset = builder.DomainToVersionMap().find(kOnnxDomain)->second;
onnxruntime::NodeArg* axes = nullptr;
builder.AddNode("ReduceMean", {data_arg}, {reduce_mean_out_1}).AddAttribute("axes", std::vector<int64_t>{-1});
if (opset >= 18) {
axes = builder.MakeInitializer<int64_t>({1}, {-1});
builder.AddNode("ReduceMean", {data_arg, axes}, {reduce_mean_out_1});
} else {
builder.AddNode("ReduceMean", {data_arg}, {reduce_mean_out_1}).AddAttribute("axes", std::vector<int64_t>{-1});
}
builder.AddNode("Sub", {data_arg, reduce_mean_out_1}, {sub_out});
builder.AddNode("Cast", {sub_out}, {cast_out_1})
.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
builder.AddNode("Pow", {cast_out_1, pow_initializer}, {pow_out});
builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out_2}).AddAttribute("axes", std::vector<int64_t>{-1});
if (opset >= 18) {
builder.AddNode("ReduceMean", {pow_out, axes}, {reduce_mean_out_2});
} else {
builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out_2}).AddAttribute("axes", std::vector<int64_t>{-1});
}
builder.AddNode("Add", {reduce_mean_out_2, add_initializer}, {add_out_1});
builder.AddNode("Sqrt", {add_out_1}, {sqrt_out});
builder.AddNode("Div", {cast_out_1, sqrt_out}, {div_out});
@ -4652,7 +4663,7 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_5) {
};
std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>();
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1,
1, pre_graph_checker, post_graph_checker));
}

View file

@ -17,6 +17,31 @@
namespace onnxruntime {
namespace test {
// Multi-opset overload of TransformerTester.
// Calls TransformerTester once for every entry of `opset_versions`, forwarding
// build_test_case, check_transformed_graph, the two levels, both tolerances,
// add_session_options and disabled_optimizers unchanged on each call.
// `transformer` must be null: this is asserted up front and nullptr is what gets
// passed to every per-opset call.
// NOTE(review): presumably a custom transformer is rejected because a single
// std::unique_ptr cannot be handed to more than one run -- confirm.
void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
const std::function<void(InferenceSessionWrapper& session)>& check_transformed_graph,
TransformerLevel baseline_level,
TransformerLevel target_level,
const std::vector<int64_t>& opset_versions,
double per_sample_tolerance,
double relative_per_sample_tolerance,
std::unique_ptr<GraphTransformer> transformer,
const std::function<void(SessionOptions&)>& add_session_options,
const InlinedHashSet<std::string>& disabled_optimizers) {
// Fail the calling test immediately if a transformer was supplied.
ASSERT_TRUE(transformer == nullptr);
// Run the same test case once per requested opset version.
for (auto opset_version : opset_versions) {
TransformerTester(build_test_case,
check_transformed_graph,
baseline_level,
target_level,
opset_version,
per_sample_tolerance,
relative_per_sample_tolerance,
nullptr,
add_session_options,
disabled_optimizers);
}
}
void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
const std::function<void(InferenceSessionWrapper& session)>& check_transformed_graph,
TransformerLevel baseline_level,
@ -101,22 +126,36 @@ Status TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>&
const logging::Logger& logger, std::unique_ptr<GraphTransformer> transformer,
TransformerLevel level, unsigned steps, const std::function<Status(Graph&)>& pre_graph_checker,
const std::function<Status(Graph&)>& post_graph_checker) {
// Build the model for this test.
std::unordered_map<std::string, int> domain_to_version;
domain_to_version[kOnnxDomain] = opset_version;
domain_to_version[kMSDomain] = 1;
Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version, {}, logger);
Graph& graph = model.MainGraph();
ModelTestBuilder helper(graph);
build_test_case(helper);
helper.SetGraphOutputs();
ORT_RETURN_IF_ERROR(graph.Resolve());
ORT_RETURN_IF_ERROR(pre_graph_checker(graph));
const std::vector<int64_t> opset_versions{opset_version};
return TestGraphTransformer(build_test_case, opset_versions, logger, std::move(transformer),
level, steps, pre_graph_checker, post_graph_checker);
}
Status TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
const std::vector<int64_t>& opset_versions,
const logging::Logger& logger, std::unique_ptr<GraphTransformer> transformer,
TransformerLevel level, unsigned steps, const std::function<Status(Graph&)>& pre_graph_checker,
const std::function<Status(Graph&)>& post_graph_checker) {
onnxruntime::GraphTransformerManager graph_transformation_mgr{steps};
ORT_RETURN_IF_ERROR(graph_transformation_mgr.Register(std::move(transformer), level));
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, level, logger));
ORT_RETURN_IF_ERROR(post_graph_checker(graph));
for (auto opset : opset_versions) {
// Build the model for this test.
std::unordered_map<std::string, int> domain_to_version;
domain_to_version[kOnnxDomain] = opset;
domain_to_version[kMSDomain] = 1;
Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version, {}, logger);
Graph& graph = model.MainGraph();
ModelTestBuilder helper(graph);
build_test_case(helper);
helper.SetGraphOutputs();
ORT_RETURN_IF_ERROR(graph.Resolve());
ORT_RETURN_IF_ERROR(pre_graph_checker(graph));
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, level, logger));
ORT_RETURN_IF_ERROR(post_graph_checker(graph));
}
return Status::OK();
}

View file

@ -50,6 +50,10 @@ class ModelTestBuilder {
ModelTestBuilder(Graph& graph) : graph_(graph) {
}
// Returns the domain -> opset version map of the graph wrapped by this builder
// (forwards to graph_.DomainToVersionMap()). Test cases in this change look up
// kOnnxDomain in it to branch on the opset the model is being built for.
const std::unordered_map<std::string, int>& DomainToVersionMap() const noexcept {
return graph_.DomainToVersionMap();
}
template <typename T>
NodeArg* MakeInput(const std::vector<int64_t>& shape, const std::vector<T>& data) {
ONNX_NAMESPACE::TypeProto type_proto;
@ -356,6 +360,17 @@ void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& buil
const std::function<void(SessionOptions&)>& add_session_options = {},
const InlinedHashSet<std::string>& disabled_optimizers = {});
void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
const std::function<void(InferenceSessionWrapper& session)>& check_transformed_graph,
TransformerLevel baseline_level,
TransformerLevel target_level,
const std::vector<int64_t>& opset_versions,
double per_sample_tolerance = 0.0,
double relative_per_sample_tolerance = 0.0,
std::unique_ptr<GraphTransformer> transformer = nullptr, // must be null in this case.
const std::function<void(SessionOptions&)>& add_session_options = {},
const InlinedHashSet<std::string>& disabled_optimizers = {});
/**
* @brief Apply a GraphTransformer to a graph, and run graph checkers before and after applying the transformer.
*
@ -372,5 +387,23 @@ Status TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>&
const logging::Logger& logger, std::unique_ptr<GraphTransformer> transformer,
TransformerLevel level, unsigned steps, const std::function<Status(Graph&)>& pre_graph_checker,
const std::function<Status(Graph&)>& post_graph_checker);
/**
* @brief Apply a GraphTransformer to a graph, and run graph checkers before and after applying the transformer.
*
* @param build_test_case The function to build a graph for testing
* @param opset_versions A graph is created and tested for every opset in this set
* @param logger The logger
* @param transformer The GraphTransformer to be applied
* @param level The transformer level on which the transformer will be applied
* @param steps The step count of the GraphTransformerManager
* @param pre_graph_checker The graph checker function before applying the transformer
* @param post_graph_checker The graph checker function after applying the transformer
*/
Status TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
const std::vector<int64_t>& opset_versions,
const logging::Logger& logger, std::unique_ptr<GraphTransformer> transformer,
TransformerLevel level, unsigned steps, const std::function<Status(Graph&)>& pre_graph_checker,
const std::function<Status(Graph&)>& post_graph_checker);
} // namespace test
} // namespace onnxruntime

View file

@ -278,6 +278,9 @@ TEST(NhwcTransformerTests, ConvSplit) {
conv_output_arg, .37f, 131);
conv_node.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
Node& split_node = builder.AddNode("Split", {conv_output_arg}, {split_output1_arg, split_output2_arg});
if (builder.DomainToVersionMap().find(kOnnxDomain)->second >= 18) {
split_node.AddAttribute("num_outputs", static_cast<int64_t>(2));
}
split_node.AddAttribute("axis", static_cast<int64_t>(axis));
builder.AddQLinearBinaryNode("QLinearAdd",
split_output1_arg, .37f, 131,
@ -302,6 +305,11 @@ TEST(NhwcTransformerTests, ConvSplit) {
check_nhwc_graph,
TransformerLevel::Level2,
TransformerLevel::Level3);
TransformerTester(build_test_case,
check_nhwc_graph,
TransformerLevel::Level2,
TransformerLevel::Level3,
18);
}
}
@ -323,6 +331,9 @@ TEST(NhwcTransformerTests, ConvSplitQLinearConcat) {
conv_node.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
Node& split_node = builder.AddNode("Split", {conv_output_arg}, {split_output1_arg, split_output2_arg});
if (builder.DomainToVersionMap().find(kOnnxDomain)->second >= 18) {
split_node.AddAttribute("num_outputs", static_cast<int64_t>(2));
}
split_node.AddAttribute("axis", static_cast<int64_t>(axis));
Node& qlconcat_node = builder.AddQLinearConcatLike(
@ -346,6 +357,11 @@ TEST(NhwcTransformerTests, ConvSplitQLinearConcat) {
check_nhwc_graph,
TransformerLevel::Level2,
TransformerLevel::Level3);
TransformerTester(build_test_case,
check_nhwc_graph,
TransformerLevel::Level2,
TransformerLevel::Level3,
18);
}
}

View file

@ -378,6 +378,9 @@ GetQDQTestCaseFn BuildConsolidationTestCase(
auto* split_output_3 = builder.MakeIntermediate();
Node& split_node = builder.AddNode("Split", {upper_dq_output}, {split_output_1, split_output_2, split_output_3});
split_node.AddAttribute("axis", axis);
if (builder.DomainToVersionMap().find(kOnnxDomain)->second >= 18) {
split_node.AddAttribute("num_outputs", static_cast<int64_t>(3));
}
// add Q
auto* lower_q_output_1 = builder.MakeIntermediate();
@ -456,6 +459,9 @@ GetQDQTestCaseFn BuildQDQSplitTestCase(
auto* split_output_3 = builder.MakeIntermediate();
Node& split_node = builder.AddNode("Split", {dq_output}, {split_output_1, split_output_2, split_output_3});
split_node.AddAttribute("axis", axis);
if (builder.DomainToVersionMap().find(kOnnxDomain)->second >= 18) {
split_node.AddAttribute("num_outputs", static_cast<int64_t>(3));
}
// add Q
auto* q_split_output_1 = builder.MakeOutput();

View file

@ -67,6 +67,14 @@ void QDQTransformerConvTests() {
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
TransformerTester(BuildQDQConvTestCase<InputType, WeightType, BiasType, OutputType>(input_shape, weights_shape),
check_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,
18 /*opset_version*/,
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({1, 12, 37}, {32, 12, 5});
@ -157,10 +165,13 @@ TEST(QDQTransformerTests, ConvMaxPoolReshape_UInt8) {
test_case({1, 12, 37}, {32, 12, 5}, 11);
test_case({1, 12, 37}, {32, 12, 5}, 12);
test_case({1, 12, 37}, {32, 12, 5}, 18);
test_case({1, 23, 13, 13}, {30, 23, 3, 3}, 11);
test_case({1, 23, 13, 13}, {30, 23, 3, 3}, 12);
test_case({1, 23, 13, 13}, {30, 23, 3, 3}, 18);
test_case({1, 22, 11, 13, 15}, {30, 22, 5, 3, 3}, 11);
test_case({1, 22, 11, 13, 15}, {30, 22, 5, 3, 3}, 12);
test_case({1, 22, 11, 13, 15}, {30, 22, 5, 3, 3}, 18);
}
TEST(QDQTransformerTests, ConvMaxPoolReshape_Int8) {
@ -292,6 +303,14 @@ void QDQTransformerAveragePoolTests() {
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
TransformerTester(BuildQDQAveragePoolTestCase<InputType, OutputType>(input_shape),
check_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,
18 /*opset_version*/,
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({1, 12, 37});
@ -341,6 +360,14 @@ void QDQTransformerGlobalAveragePoolTests() {
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
TransformerTester(BuildQDQGlobalAveragePoolTestCase<InputType, OutputType>(input_shape),
check_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,
18 /*opset_version*/,
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({1, 12, 37});
@ -391,6 +418,14 @@ void QDQTransformerBinaryOpTests(const std::string& op_type) {
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
TransformerTester(BuildBinaryOpTestCase<Input1Type, Input2Type, OutputType>(input_shape, op_type),
check_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,
18 /*opset_version*/,
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({1, 12, 37});
@ -522,6 +557,14 @@ void QDQTransformerMatMulTests(bool has_output_q) {
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
TransformerTester(build_test_case,
check_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,
18 /*opset_version*/,
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({1, 2, 2}, {1, 2, 4});
@ -677,6 +720,14 @@ void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
TransformerTester(build_test_case,
check_binary_op_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,
18 /*opset_version*/,
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({2, 2}, {2, 4});
@ -813,6 +864,14 @@ TEST(QDQTransformerTests, DoubleQDQ) {
12,
(scale_1 + scale_3) / 2,
0.01);
TransformerTester(
BuildDoubleQDQTestCases<int8_t, int8_t, int8_t, int8_t>(zp_1, zp_2, zp_3, zp_4, scale_1, scale_2, scale_3, scale_4),
succeed ? expect_succeed : expect_fail,
TransformerLevel::Default,
TransformerLevel::Level1,
18,
(scale_1 + scale_3) / 2,
0.01);
};
auto test_case_2u8_2s8_failed = [&](uint8_t zp_1, uint8_t zp_2, int8_t zp_3, int8_t zp_4,
@ -870,7 +929,8 @@ TEST(QDQTransformerTests, Split) {
TransformerTester(BuildQDQSplitTestCase<int8_t, int8_t>(input_shape, axis),
check_graph,
TransformerLevel::Level1,
TransformerLevel::Level2);
TransformerLevel::Level2,
{12, 18});
};
test_case({6, 18, 54}, 0);
}
@ -887,7 +947,7 @@ TEST(QDQTransformerTests, Split_without_IdenticalChildrenConsolidation) {
TransformerTester(BuildConsolidationTestCase<int8_t, int8_t>(input_shape, axis),
check_graph,
TransformerLevel::Level1,
TransformerLevel::Level2, 12, {}, {}, nullptr, {},
TransformerLevel::Level2, {12, 18}, {}, {}, nullptr, {},
{"IdenticalChildrenConsolidation"});
};
test_case({6, 18, 54}, 0);
@ -904,7 +964,8 @@ TEST(QDQTransformerTests, Split_with_IdenticalChildrenConsolidation) {
TransformerTester(BuildConsolidationTestCase<int8_t, int8_t>(input_shape, axis),
check_graph,
TransformerLevel::Level1,
TransformerLevel::Level2);
TransformerLevel::Level2,
{12, 18});
};
test_case({6, 18, 54}, 0);
}
@ -1509,7 +1570,7 @@ TEST(QDQTransformerTests, ConvAveragePoolReshape_Int8_Fail) {
check_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,
12 /*opset_version*/,
{12, 18} /*opset_version*/,
0.01f /*per_sample_tolerance*/,
0.01f /*relative_per_sample_tolerance*/);
};
@ -1566,6 +1627,14 @@ void QDQTransformerLeakyReluTests() {
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
TransformerTester(build_test_case,
check_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,
18 /*opset_version*/,
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({1, 12, 37});
@ -1635,6 +1704,14 @@ void QDQTransformerSigmoidTests() {
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
TransformerTester(build_test_case,
check_graph,
TransformerLevel::Level1,
TransformerLevel::Level2,
18 /*opset_version*/,
0.01 /*per_sample_tolerance*/,
0.01 /*relative_per_sample_tolerance*/,
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
};
test_case({1, 12, 37});
@ -1907,7 +1984,7 @@ TEST(QDQTransformerTests, DQForward_MutilpleSteps) {
TEST(QDQTransformerTests, Clip) {
constexpr float epsilon = std::numeric_limits<float>::epsilon();
auto test_case = [&](float scale, auto zero_point, int clip_count, int opset_version = 12) {
auto test_case = [&](float scale, auto zero_point, int clip_count, int opset_version) {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<int8_t>({1, 32, 112, 112},
std::numeric_limits<int8_t>::min(),
@ -1922,7 +1999,9 @@ TEST(QDQTransformerTests, Clip) {
auto* clip_output = builder.MakeIntermediate();
constexpr float min = .0f;
constexpr float max = 6.0f;
if (opset_version >= 11) {
auto opset = builder.DomainToVersionMap().find(kOnnxDomain)->second;
EXPECT_EQ(opset_version, opset);
if (opset >= 11) {
auto* min_initializer = builder.MakeScalarInitializer<float>(min);
auto* max_initializer = builder.MakeScalarInitializer<float>(max);
builder.AddNode("Clip", {dq_output, min_initializer, max_initializer}, {clip_output});
@ -1953,18 +2032,21 @@ TEST(QDQTransformerTests, Clip) {
epsilon);
};
test_case(.0235294122248888f, static_cast<int8_t>(-128), 0); // [0, 6]
test_case(.02f, static_cast<int8_t>(-128), 0); // [0, 5.1]
test_case(.03f, static_cast<int8_t>(-128), 1); // [0, 7.65]
test_case(.02f, static_cast<int8_t>(127), 1); // [-5.1 , 0]
test_case(.02f, static_cast<int8_t>(0), 1); // [-2.56, 2.54]
test_case(.04f, static_cast<int8_t>(-97), 1); // [-1.24, 8.96]
test_case(.02352941176f, static_cast<uint8_t>(0), 0); // [0, 6]
test_case(.02f, static_cast<uint8_t>(0), 0); // [0, 5.1]
test_case(.03f, static_cast<uint8_t>(0), 1); // [0, 7.65]
test_case(.02f, static_cast<uint8_t>(255), 1); // [-5.1, 0]
test_case(.02f, static_cast<uint8_t>(128), 1); // [-2.56, 2.54]
test_case(.04f, static_cast<uint8_t>(31), 1); // [-1.24, 8.96]
std::vector<int64_t> opsets{12, 18};
for (auto opset : opsets) {
test_case(.0235294122248888f, static_cast<int8_t>(-128), 0, opset); // [0, 6]
test_case(.02f, static_cast<int8_t>(-128), 0, opset); // [0, 5.1]
test_case(.03f, static_cast<int8_t>(-128), 1, opset); // [0, 7.65]
test_case(.02f, static_cast<int8_t>(127), 1, opset); // [-5.1 , 0]
test_case(.02f, static_cast<int8_t>(0), 1, opset); // [-2.56, 2.54]
test_case(.04f, static_cast<int8_t>(-97), 1, opset); // [-1.24, 8.96]
test_case(.02352941176f, static_cast<uint8_t>(0), 0, opset); // [0, 6]
test_case(.02f, static_cast<uint8_t>(0), 0, opset); // [0, 5.1]
test_case(.03f, static_cast<uint8_t>(0), 1, opset); // [0, 7.65]
test_case(.02f, static_cast<uint8_t>(255), 1, opset); // [-5.1, 0]
test_case(.02f, static_cast<uint8_t>(128), 1, opset); // [-2.56, 2.54]
test_case(.04f, static_cast<uint8_t>(31), 1, opset); // [-1.24, 8.96]
}
// opset_version = 10
test_case(.02f, static_cast<int8_t>(-128), 0, 10); // [0, 5.1]
@ -1973,10 +2055,12 @@ TEST(QDQTransformerTests, Clip) {
test_case(.03f, static_cast<uint8_t>(0), 1, 10); // [0, 7.65]
// difference between lower/upper and min/max are within epsilon
test_case(epsilon, static_cast<int8_t>(-127), 0); // [-epsilon, x] (x <= 6 + epsilon)
test_case((6 + epsilon) / 255, static_cast<int8_t>(-128), 0); // [0, 6 + epsilon]
test_case(epsilon, static_cast<uint8_t>(1), 0); // [-epsilon, x] (x <= 6 + epsilon)
test_case((6 + epsilon) / 255, static_cast<uint8_t>(0), 0); // [0, 6 + epsilon]
for (auto opset : opsets) {
test_case(epsilon, static_cast<int8_t>(-127), 0, opset); // [-epsilon, x] (x <= 6 + epsilon)
test_case((6 + epsilon) / 255, static_cast<int8_t>(-128), 0, opset); // [0, 6 + epsilon]
test_case(epsilon, static_cast<uint8_t>(1), 0, opset); // [-epsilon, x] (x <= 6 + epsilon)
test_case((6 + epsilon) / 255, static_cast<uint8_t>(0), 0, opset); // [0, 6 + epsilon]
}
}
TEST(QDQTransformerTests, Concat) {
@ -2536,7 +2620,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
// regression test to validate TransposeOptimizer and QDQ Propagation don't loop
// see https://github.com/microsoft/onnxruntime/issues/11605
TEST(QDQTransformerTests, QDQPropagation_GH11605) {
TEST(QDQTransformerTests, QDQPropagation_GH11605_Opset12) {
auto test_case = [&]() {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<uint8_t>({1, 4, 4},
@ -2585,7 +2669,61 @@ TEST(QDQTransformerTests, QDQPropagation_GH11605) {
TransformerTester(build_test_case,
check_graph,
TransformerLevel::Default,
TransformerLevel::Level2);
TransformerLevel::Level2,
12);
};
test_case();
}
TEST(QDQTransformerTests, QDQPropagation_GH11605_Opset13) {
auto test_case = [&]() {
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input_arg = builder.MakeInput<uint8_t>({1, 4, 4},
std::numeric_limits<uint8_t>::min(),
std::numeric_limits<uint8_t>::max());
// add DQ
auto* dq_output = builder.MakeIntermediate();
builder.AddDequantizeLinearNode(input_arg, 0.123f, uint8_t(0), dq_output);
// add Transpose 0, 2, 1
const std::vector<int64_t>& perms{0, 2, 1};
auto* transpose_output = builder.MakeIntermediate();
Node& transpose_node = builder.AddNode("Transpose", {dq_output}, {transpose_output});
transpose_node.AddAttribute("perm", perms);
// add Softmax with axis=2 (to block the Transpose moving past it due to the transpose perms)
auto* softmax_output = builder.MakeIntermediate();
Node& softmax_node = builder.AddNode("Softmax", {transpose_output}, {softmax_output});
softmax_node.AddAttribute("axis", int64_t(2));
// add second Transpose. this is so the check in TransposeOptimizer::ProcessTranspose for outputs leading to
// a Transpose is satisfied, allowing the first Transpose to move past the Q/DQ inserted by QDQ Propagation
Node& transpose_node2 = builder.AddNode("Transpose", {softmax_output}, {builder.MakeOutput()});
transpose_node2.AddAttribute("perm", perms);
};
// check that an edge case where transpose optimization gets blocked is handled gracefully.
// Original: DQ -> Tr -> SoftM -> Tr
// QDQ Prop inserts a Q/DQ pair to create a QDQ node group for the Transpose: DQ -> Tr -> Q -> DQ -> SoftM -> Tr
// Transpose opt phase 1 moves the Tr down until it blocks on the SoftMax: DQ -> Q -> DQ -> Tr -> SoftM -> Tr
// Transpose opt phase 2 flips the Tr to prior to the DQ as it's not part of a QDQ node group at that point, as
// running the transpose on 8-bit data should be cheaper: DQ -> Q -> Tr -> DQ -> SoftM -> Tr
// QDQ cleanup in Level2 removes the unnecessary DQ/Q pair at the start: Tr -> DQ -> SoftM -> Tr
// this is the optimal result as the Transpose is using 8-bit data and we have no surplus Q/DQ pairs
auto check_graph = [&](InferenceSessionWrapper& session) {
std::vector<std::string> expected_op_types_in_order{
"DequantizeLinear",
"Softmax"};
const auto op_types_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph());
EXPECT_EQ(op_types_in_order, expected_op_types_in_order);
};
TransformerTester(build_test_case,
check_graph,
TransformerLevel::Default,
TransformerLevel::Level2,
13);
};
test_case();

File diff suppressed because it is too large Load diff