mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Test and fix optimizers LayerNormFusion, BiasSoftmaxFusion, Transpose for opset 18 (#14542)
### Description Due to the changes introduced in opset 18 on Reduce operators (axes is an input and not an attribute), the following optimizers are not catching the pattern they are supposed to optimize. This PR addresses that. * layer_norm_fusion.cc: the optimizer was not detecting the pattern it was suppose to optimize * bias_softmax_fusion.cc: the optimizer was not detecting the pattern it was suppose to optimize * transpose_optimizer.cc: the optimizer was not optimize Reduce operators other than ReduceSum ### Motivation and Context Better performance. --------- Signed-off-by: xadupre <xadupre@microsoft.com>
This commit is contained in:
parent
cfda876a3f
commit
30ec8b038f
10 changed files with 605 additions and 228 deletions
|
|
@ -135,6 +135,7 @@ bool TrySelectInputAndBiasWithAlignment(Node& add_node, Node& softmax_node, Node
|
|||
new_axis = (int)HandleNegativeAxis(axis, rank);
|
||||
|
||||
// The axis attribute for Softmax in OpSet-11 and OpSet-13 are different.
|
||||
// Details in function documentatin.
|
||||
if (is_since_opset_13 && new_axis != rank - 1) return false;
|
||||
|
||||
int singlebatch_rank = rank - new_axis;
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#include "core/optimizer/layer_norm_fusion.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "core/optimizer/transpose_optimizer/optimizer_api.h"
|
||||
#include "float.h"
|
||||
#include <deque>
|
||||
|
||||
|
|
@ -16,12 +17,17 @@ static constexpr std::array<std::string_view, 3> supported_data_types{"tensor(fl
|
|||
// Default epsilon
|
||||
static constexpr float DEFAULT_LAYERNORM_EPSILON = 1e-5f;
|
||||
|
||||
static bool IsSupportedDataType(const Node& node) {
|
||||
static bool IsSupportedDataType(const Node& node, int first_n_inputs=-1) {
|
||||
int input_index = 0;
|
||||
for (const auto& input_arg : node.InputDefs()) {
|
||||
if (first_n_inputs != -1 && input_index >= first_n_inputs) {
|
||||
return true;
|
||||
}
|
||||
if (std::find(supported_data_types.begin(), supported_data_types.end(),
|
||||
*(input_arg->Type())) == supported_data_types.end()) {
|
||||
return false;
|
||||
}
|
||||
++input_index;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
@ -99,11 +105,11 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
Node& reduce_mean_node = *p_reduce_mean;
|
||||
ORT_RETURN_IF_ERROR(Recurse(reduce_mean_node, modified, graph_level, logger));
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13, 18}) ||
|
||||
!graph_utils::IsSupportedProvider(reduce_mean_node, GetCompatibleExecutionProviders()) ||
|
||||
(reduce_mean_node.GetOutputEdgesCount() != 1 && reduce_mean_node.GetOutputEdgesCount() != 2) ||
|
||||
graph.NodeProducesGraphOutput(reduce_mean_node) ||
|
||||
!IsSupportedDataType(reduce_mean_node)) {
|
||||
!IsSupportedDataType(reduce_mean_node, 1)) {
|
||||
continue;
|
||||
}
|
||||
nodes_to_remove.push_back(reduce_mean_node);
|
||||
|
|
@ -263,10 +269,10 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
continue;
|
||||
}
|
||||
Node& reduce_mean2_node = *graph.GetNode(p_reduce_mean2->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean2_node, "ReduceMean", {1, 11, 13, 18}) ||
|
||||
reduce_mean2_node.GetExecutionProviderType() != reduce_mean_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, reduce_mean2_node, 1) ||
|
||||
!IsSupportedDataType(reduce_mean2_node) ||
|
||||
!IsSupportedDataType(reduce_mean2_node, 1) ||
|
||||
reduce_mean2_node.GetInputEdgesCount() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -333,8 +339,16 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
|
|||
// get axes attributes
|
||||
const onnxruntime::NodeAttributes& attributes = reduce_mean_node.GetAttributes();
|
||||
std::vector<int64_t> axes_values;
|
||||
// TODO: modify this codes when opset >= 18 (axes is an input).
|
||||
if (attributes.find("axes") != attributes.end()) {
|
||||
axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
|
||||
} else if (reduce_mean_node.InputDefs().size() == 2) {
|
||||
auto axes = reduce_mean_node.InputDefs()[1];
|
||||
auto axes_const = graph.GetConstantInitializer(axes->Name(), true);
|
||||
if (axes_const != nullptr) {
|
||||
Initializer initializer{*axes_const, graph.ModelPath()};
|
||||
axes_values.insert(axes_values.end(), initializer.DataAsSpan<int64_t>().begin(), initializer.DataAsSpan<int64_t>().end());
|
||||
}
|
||||
}
|
||||
|
||||
// Get the inputs for the new LayerNormalization node.
|
||||
|
|
@ -485,9 +499,9 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
continue;
|
||||
}
|
||||
Node& reduce_mean_node = *graph.GetNode(p_reduce_mean->Index());
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(reduce_mean_node, "ReduceMean", {1, 11, 13, 18}) ||
|
||||
reduce_mean_node.GetExecutionProviderType() != pow_node.GetExecutionProviderType() ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, reduce_mean_node, 1) || !IsSupportedDataType(reduce_mean_node) ||
|
||||
!optimizer_utils::CheckOutputEdges(graph, reduce_mean_node, 1) || !IsSupportedDataType(reduce_mean_node, 1) ||
|
||||
reduce_mean_node.GetInputEdgesCount() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -585,6 +599,13 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
|
|||
std::vector<int64_t> axes_values;
|
||||
if (attributes.find("axes") != attributes.end()) {
|
||||
axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
|
||||
} else if (reduce_mean_node.InputDefs().size() == 2) {
|
||||
auto axes = reduce_mean_node.InputDefs()[1];
|
||||
auto axes_const = graph.GetConstantInitializer(axes->Name(), true);
|
||||
if (axes_const != nullptr && axes_const->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64) {
|
||||
Initializer initializer{*axes_const, graph.ModelPath()};
|
||||
axes_values.insert(axes_values.end(), initializer.DataAsSpan<int64_t>().begin(), initializer.DataAsSpan<int64_t>().end());
|
||||
}
|
||||
}
|
||||
|
||||
// Get the inputs for the new LayerNormalization node.
|
||||
|
|
|
|||
|
|
@ -1040,7 +1040,7 @@ static bool HandlePad(HandlerArgs& args) {
|
|||
|
||||
constexpr HandlerInfo pad_handler = {&FirstInput, &HandlePad};
|
||||
|
||||
static bool HandleReduceOp(HandlerArgs& args) {
|
||||
static bool HandleReduceOpWithArg(HandlerArgs& args) {
|
||||
int64_t keepdims = args.node.GetAttributeIntDefault("keepdims", 1);
|
||||
|
||||
std::optional<std::vector<int64_t>> axes = args.node.GetAttributeInts("axes");
|
||||
|
|
@ -1078,11 +1078,11 @@ static bool HandleReduceOp(HandlerArgs& args) {
|
|||
return true;
|
||||
}
|
||||
|
||||
constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOp};
|
||||
|
||||
static bool HandleReduceSum(HandlerArgs& args) {
|
||||
if (args.ctx.opset < 13) {
|
||||
return HandleReduceOp(args);
|
||||
static bool HandleReduceOps(HandlerArgs& args) {
|
||||
if ((args.node.OpType() == "ReduceSum" && args.ctx.opset < 13) ||
|
||||
// or all other reduce operators since opset 18
|
||||
(args.node.OpType() != "ReduceSum" && args.ctx.opset < 18)) {
|
||||
return HandleReduceOpWithArg(args);
|
||||
}
|
||||
|
||||
bool keepdims = args.node.GetAttributeIntDefault("keepdims", 1) != 0;
|
||||
|
|
@ -1147,7 +1147,7 @@ static bool HandleReduceSum(HandlerArgs& args) {
|
|||
return true;
|
||||
}
|
||||
|
||||
constexpr HandlerInfo reduce_sum_handler = {&FirstInput, &HandleReduceSum};
|
||||
constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOps};
|
||||
|
||||
static bool HandleSqueeze(HandlerArgs& args) {
|
||||
std::vector<int64_t> new_axes;
|
||||
|
|
@ -1709,7 +1709,7 @@ static const std::unordered_map<std::string_view, const HandlerInfo&> handler_ma
|
|||
#if !defined(USE_CUDA) && !defined(USE_ROCM)
|
||||
{"Resize", resize_handler},
|
||||
#endif
|
||||
{"ReduceSum", reduce_sum_handler},
|
||||
{"ReduceSum", reduce_op_handler},
|
||||
|
||||
{"ReduceLogSum", reduce_op_handler},
|
||||
{"ReduceLogSumExp", reduce_op_handler},
|
||||
|
|
|
|||
|
|
@ -95,7 +95,6 @@ namespace onnxruntime {
|
|||
namespace test {
|
||||
|
||||
#define MODEL_FOLDER ORT_TSTR("testdata/transform/")
|
||||
|
||||
TEST_F(GraphTransformationTests, IdentityElimination) {
|
||||
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "abs-id-max.onnx";
|
||||
std::shared_ptr<Model> model;
|
||||
|
|
@ -4390,11 +4389,12 @@ TEST_F(GraphTransformationTests, ReshapeFusionOpsetTest) {
|
|||
return Status::OK();
|
||||
};
|
||||
|
||||
const std::vector<int> opsets{11, 12, 13, 14, 15, 15};
|
||||
const std::vector<int> opsets{11, 12, 13, 14, 15, 18};
|
||||
bool shape_test_for_opset15 = false;
|
||||
|
||||
for (auto& opset_version : opsets) {
|
||||
for (auto& opset : opsets) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto opset_version = builder.DomainToVersionMap().find(kOnnxDomain)->second;
|
||||
auto* input_arg0 = builder.MakeInput<float>({{batch_size, seq_lenth, hidden_size}});
|
||||
auto* input_arg1 = builder.MakeInput<float>({{hidden_size}});
|
||||
auto* scalar_int_0 = builder.MakeInitializer<int64_t>({}, {0});
|
||||
|
|
@ -4414,7 +4414,7 @@ TEST_F(GraphTransformationTests, ReshapeFusionOpsetTest) {
|
|||
auto* out = builder.MakeOutput();
|
||||
|
||||
builder.AddNode("Add", {input_arg0, input_arg1}, {add_out});
|
||||
if (opset_version == 15) {
|
||||
if (opset_version >= 15) {
|
||||
if (shape_test_for_opset15) {
|
||||
auto& shape_1 = builder.AddNode("Shape", {add_out}, {shape_out});
|
||||
shape_1.AddAttribute("start", (int64_t)1);
|
||||
|
|
@ -4442,11 +4442,11 @@ TEST_F(GraphTransformationTests, ReshapeFusionOpsetTest) {
|
|||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<ReshapeFusion>();
|
||||
if (opset_version == 15 && shape_test_for_opset15) {
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
if (opset >= 15 && shape_test_for_opset15) {
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, pre_graph_checker));
|
||||
} else {
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), TransformerLevel::Level1, 1,
|
||||
pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
}
|
||||
|
|
@ -4610,13 +4610,24 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_5) {
|
|||
auto* cast_out_2 = builder.MakeIntermediate();
|
||||
auto* mul_out = builder.MakeIntermediate();
|
||||
auto* add_out_2 = builder.MakeOutput();
|
||||
auto opset = builder.DomainToVersionMap().find(kOnnxDomain)->second;
|
||||
onnxruntime::NodeArg* axes = nullptr;
|
||||
|
||||
builder.AddNode("ReduceMean", {data_arg}, {reduce_mean_out_1}).AddAttribute("axes", std::vector<int64_t>{-1});
|
||||
if (opset >= 18) {
|
||||
axes = builder.MakeInitializer<int64_t>({1}, {-1});
|
||||
builder.AddNode("ReduceMean", {data_arg, axes}, {reduce_mean_out_1});
|
||||
} else {
|
||||
builder.AddNode("ReduceMean", {data_arg}, {reduce_mean_out_1}).AddAttribute("axes", std::vector<int64_t>{-1});
|
||||
}
|
||||
builder.AddNode("Sub", {data_arg, reduce_mean_out_1}, {sub_out});
|
||||
builder.AddNode("Cast", {sub_out}, {cast_out_1})
|
||||
.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));
|
||||
builder.AddNode("Pow", {cast_out_1, pow_initializer}, {pow_out});
|
||||
builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out_2}).AddAttribute("axes", std::vector<int64_t>{-1});
|
||||
if (opset >= 18) {
|
||||
builder.AddNode("ReduceMean", {pow_out, axes}, {reduce_mean_out_2});
|
||||
} else {
|
||||
builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out_2}).AddAttribute("axes", std::vector<int64_t>{-1});
|
||||
}
|
||||
builder.AddNode("Add", {reduce_mean_out_2, add_initializer}, {add_out_1});
|
||||
builder.AddNode("Sqrt", {add_out_1}, {sqrt_out});
|
||||
builder.AddNode("Div", {cast_out_1, sqrt_out}, {div_out});
|
||||
|
|
@ -4652,7 +4663,7 @@ TEST_F(GraphTransformationTests, LayerNormWithCastFusionTest_5) {
|
|||
};
|
||||
|
||||
std::unique_ptr<GraphTransformer> transformer = std::make_unique<LayerNormFusion>();
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1,
|
||||
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, {14, 18}, *logger_, std::move(transformer), TransformerLevel::Level1,
|
||||
1, pre_graph_checker, post_graph_checker));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,31 @@
|
|||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
|
||||
const std::function<void(InferenceSessionWrapper& session)>& check_transformed_graph,
|
||||
TransformerLevel baseline_level,
|
||||
TransformerLevel target_level,
|
||||
const std::vector<int64_t>& opset_versions,
|
||||
double per_sample_tolerance,
|
||||
double relative_per_sample_tolerance,
|
||||
std::unique_ptr<GraphTransformer> transformer,
|
||||
const std::function<void(SessionOptions&)>& add_session_options,
|
||||
const InlinedHashSet<std::string>& disabled_optimizers) {
|
||||
ASSERT_TRUE(transformer == nullptr);
|
||||
for (auto opset_version : opset_versions) {
|
||||
TransformerTester(build_test_case,
|
||||
check_transformed_graph,
|
||||
baseline_level,
|
||||
target_level,
|
||||
opset_version,
|
||||
per_sample_tolerance,
|
||||
relative_per_sample_tolerance,
|
||||
nullptr,
|
||||
add_session_options,
|
||||
disabled_optimizers);
|
||||
}
|
||||
}
|
||||
|
||||
void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
|
||||
const std::function<void(InferenceSessionWrapper& session)>& check_transformed_graph,
|
||||
TransformerLevel baseline_level,
|
||||
|
|
@ -101,22 +126,36 @@ Status TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>&
|
|||
const logging::Logger& logger, std::unique_ptr<GraphTransformer> transformer,
|
||||
TransformerLevel level, unsigned steps, const std::function<Status(Graph&)>& pre_graph_checker,
|
||||
const std::function<Status(Graph&)>& post_graph_checker) {
|
||||
// Build the model for this test.
|
||||
std::unordered_map<std::string, int> domain_to_version;
|
||||
domain_to_version[kOnnxDomain] = opset_version;
|
||||
domain_to_version[kMSDomain] = 1;
|
||||
Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
domain_to_version, {}, logger);
|
||||
Graph& graph = model.MainGraph();
|
||||
ModelTestBuilder helper(graph);
|
||||
build_test_case(helper);
|
||||
helper.SetGraphOutputs();
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve());
|
||||
ORT_RETURN_IF_ERROR(pre_graph_checker(graph));
|
||||
const std::vector<int64_t> opset_versions{opset_version};
|
||||
return TestGraphTransformer(build_test_case, opset_versions, logger, std::move(transformer),
|
||||
level, steps, pre_graph_checker, post_graph_checker);
|
||||
}
|
||||
|
||||
Status TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
|
||||
const std::vector<int64_t>& opset_versions,
|
||||
const logging::Logger& logger, std::unique_ptr<GraphTransformer> transformer,
|
||||
TransformerLevel level, unsigned steps, const std::function<Status(Graph&)>& pre_graph_checker,
|
||||
const std::function<Status(Graph&)>& post_graph_checker) {
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{steps};
|
||||
ORT_RETURN_IF_ERROR(graph_transformation_mgr.Register(std::move(transformer), level));
|
||||
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, level, logger));
|
||||
ORT_RETURN_IF_ERROR(post_graph_checker(graph));
|
||||
|
||||
for (auto opset : opset_versions) {
|
||||
// Build the model for this test.
|
||||
std::unordered_map<std::string, int> domain_to_version;
|
||||
domain_to_version[kOnnxDomain] = opset;
|
||||
domain_to_version[kMSDomain] = 1;
|
||||
Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
|
||||
domain_to_version, {}, logger);
|
||||
Graph& graph = model.MainGraph();
|
||||
ModelTestBuilder helper(graph);
|
||||
build_test_case(helper);
|
||||
helper.SetGraphOutputs();
|
||||
ORT_RETURN_IF_ERROR(graph.Resolve());
|
||||
ORT_RETURN_IF_ERROR(pre_graph_checker(graph));
|
||||
ORT_RETURN_IF_ERROR(graph_transformation_mgr.ApplyTransformers(graph, level, logger));
|
||||
ORT_RETURN_IF_ERROR(post_graph_checker(graph));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -50,6 +50,10 @@ class ModelTestBuilder {
|
|||
ModelTestBuilder(Graph& graph) : graph_(graph) {
|
||||
}
|
||||
|
||||
const std::unordered_map<std::string, int>& DomainToVersionMap() const noexcept {
|
||||
return graph_.DomainToVersionMap();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
NodeArg* MakeInput(const std::vector<int64_t>& shape, const std::vector<T>& data) {
|
||||
ONNX_NAMESPACE::TypeProto type_proto;
|
||||
|
|
@ -356,6 +360,17 @@ void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& buil
|
|||
const std::function<void(SessionOptions&)>& add_session_options = {},
|
||||
const InlinedHashSet<std::string>& disabled_optimizers = {});
|
||||
|
||||
void TransformerTester(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
|
||||
const std::function<void(InferenceSessionWrapper& session)>& check_transformed_graph,
|
||||
TransformerLevel baseline_level,
|
||||
TransformerLevel target_level,
|
||||
const std::vector<int64_t>& opset_versions,
|
||||
double per_sample_tolerance = 0.0,
|
||||
double relative_per_sample_tolerance = 0.0,
|
||||
std::unique_ptr<GraphTransformer> transformer = nullptr, // must be null in this case.
|
||||
const std::function<void(SessionOptions&)>& add_session_options = {},
|
||||
const InlinedHashSet<std::string>& disabled_optimizers = {});
|
||||
|
||||
/**
|
||||
* @brief Apply a GraphTransformer to a graph, and run graph checkers before and after applying the transformer.
|
||||
*
|
||||
|
|
@ -372,5 +387,23 @@ Status TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>&
|
|||
const logging::Logger& logger, std::unique_ptr<GraphTransformer> transformer,
|
||||
TransformerLevel level, unsigned steps, const std::function<Status(Graph&)>& pre_graph_checker,
|
||||
const std::function<Status(Graph&)>& post_graph_checker);
|
||||
|
||||
/**
|
||||
* @brief Apply a GraphTransformer to a graph, and run graph checkers before and after applying the transformer.
|
||||
*
|
||||
* @param build_test_case The function to build a graph for testing
|
||||
* @param opset_versions A graph is created and tested for every opset in this set
|
||||
* @param logger The logger
|
||||
* @param transformer The GraphTransformer to be applied
|
||||
* @param level The transformer level on which the transformer will be applied
|
||||
* @param steps The step count of the GraphTransformerManager
|
||||
* @param pre_graph_checker The graph checker function before applying the transformer
|
||||
* @param post_graph_checker The graph checker function after applying the transformer
|
||||
*/
|
||||
Status TestGraphTransformer(const std::function<void(ModelTestBuilder& helper)>& build_test_case,
|
||||
const std::vector<int64_t>& opset_versions,
|
||||
const logging::Logger& logger, std::unique_ptr<GraphTransformer> transformer,
|
||||
TransformerLevel level, unsigned steps, const std::function<Status(Graph&)>& pre_graph_checker,
|
||||
const std::function<Status(Graph&)>& post_graph_checker);
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -278,6 +278,9 @@ TEST(NhwcTransformerTests, ConvSplit) {
|
|||
conv_output_arg, .37f, 131);
|
||||
conv_node.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
|
||||
Node& split_node = builder.AddNode("Split", {conv_output_arg}, {split_output1_arg, split_output2_arg});
|
||||
if (builder.DomainToVersionMap().find(kOnnxDomain)->second >= 18) {
|
||||
split_node.AddAttribute("num_outputs", static_cast<int64_t>(2));
|
||||
}
|
||||
split_node.AddAttribute("axis", static_cast<int64_t>(axis));
|
||||
builder.AddQLinearBinaryNode("QLinearAdd",
|
||||
split_output1_arg, .37f, 131,
|
||||
|
|
@ -302,6 +305,11 @@ TEST(NhwcTransformerTests, ConvSplit) {
|
|||
check_nhwc_graph,
|
||||
TransformerLevel::Level2,
|
||||
TransformerLevel::Level3);
|
||||
TransformerTester(build_test_case,
|
||||
check_nhwc_graph,
|
||||
TransformerLevel::Level2,
|
||||
TransformerLevel::Level3,
|
||||
18);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -323,6 +331,9 @@ TEST(NhwcTransformerTests, ConvSplitQLinearConcat) {
|
|||
conv_node.AddAttribute("pads", std::vector<int64_t>{1, 1, 1, 1});
|
||||
|
||||
Node& split_node = builder.AddNode("Split", {conv_output_arg}, {split_output1_arg, split_output2_arg});
|
||||
if (builder.DomainToVersionMap().find(kOnnxDomain)->second >= 18) {
|
||||
split_node.AddAttribute("num_outputs", static_cast<int64_t>(2));
|
||||
}
|
||||
split_node.AddAttribute("axis", static_cast<int64_t>(axis));
|
||||
|
||||
Node& qlconcat_node = builder.AddQLinearConcatLike(
|
||||
|
|
@ -346,6 +357,11 @@ TEST(NhwcTransformerTests, ConvSplitQLinearConcat) {
|
|||
check_nhwc_graph,
|
||||
TransformerLevel::Level2,
|
||||
TransformerLevel::Level3);
|
||||
TransformerTester(build_test_case,
|
||||
check_nhwc_graph,
|
||||
TransformerLevel::Level2,
|
||||
TransformerLevel::Level3,
|
||||
18);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -378,6 +378,9 @@ GetQDQTestCaseFn BuildConsolidationTestCase(
|
|||
auto* split_output_3 = builder.MakeIntermediate();
|
||||
Node& split_node = builder.AddNode("Split", {upper_dq_output}, {split_output_1, split_output_2, split_output_3});
|
||||
split_node.AddAttribute("axis", axis);
|
||||
if (builder.DomainToVersionMap().find(kOnnxDomain)->second >= 18) {
|
||||
split_node.AddAttribute("num_outputs", static_cast<int64_t>(3));
|
||||
}
|
||||
|
||||
// add Q
|
||||
auto* lower_q_output_1 = builder.MakeIntermediate();
|
||||
|
|
@ -456,6 +459,9 @@ GetQDQTestCaseFn BuildQDQSplitTestCase(
|
|||
auto* split_output_3 = builder.MakeIntermediate();
|
||||
Node& split_node = builder.AddNode("Split", {dq_output}, {split_output_1, split_output_2, split_output_3});
|
||||
split_node.AddAttribute("axis", axis);
|
||||
if (builder.DomainToVersionMap().find(kOnnxDomain)->second >= 18) {
|
||||
split_node.AddAttribute("num_outputs", static_cast<int64_t>(3));
|
||||
}
|
||||
|
||||
// add Q
|
||||
auto* q_split_output_1 = builder.MakeOutput();
|
||||
|
|
|
|||
|
|
@ -67,6 +67,14 @@ void QDQTransformerConvTests() {
|
|||
0.01 /*per_sample_tolerance*/,
|
||||
0.01 /*relative_per_sample_tolerance*/,
|
||||
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
|
||||
TransformerTester(BuildQDQConvTestCase<InputType, WeightType, BiasType, OutputType>(input_shape, weights_shape),
|
||||
check_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2,
|
||||
18 /*opset_version*/,
|
||||
0.01 /*per_sample_tolerance*/,
|
||||
0.01 /*relative_per_sample_tolerance*/,
|
||||
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
|
||||
};
|
||||
|
||||
test_case({1, 12, 37}, {32, 12, 5});
|
||||
|
|
@ -157,10 +165,13 @@ TEST(QDQTransformerTests, ConvMaxPoolReshape_UInt8) {
|
|||
|
||||
test_case({1, 12, 37}, {32, 12, 5}, 11);
|
||||
test_case({1, 12, 37}, {32, 12, 5}, 12);
|
||||
test_case({1, 12, 37}, {32, 12, 5}, 18);
|
||||
test_case({1, 23, 13, 13}, {30, 23, 3, 3}, 11);
|
||||
test_case({1, 23, 13, 13}, {30, 23, 3, 3}, 12);
|
||||
test_case({1, 23, 13, 13}, {30, 23, 3, 3}, 18);
|
||||
test_case({1, 22, 11, 13, 15}, {30, 22, 5, 3, 3}, 11);
|
||||
test_case({1, 22, 11, 13, 15}, {30, 22, 5, 3, 3}, 12);
|
||||
test_case({1, 22, 11, 13, 15}, {30, 22, 5, 3, 3}, 18);
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, ConvMaxPoolReshape_Int8) {
|
||||
|
|
@ -292,6 +303,14 @@ void QDQTransformerAveragePoolTests() {
|
|||
0.01 /*per_sample_tolerance*/,
|
||||
0.01 /*relative_per_sample_tolerance*/,
|
||||
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
|
||||
TransformerTester(BuildQDQAveragePoolTestCase<InputType, OutputType>(input_shape),
|
||||
check_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2,
|
||||
18 /*opset_version*/,
|
||||
0.01 /*per_sample_tolerance*/,
|
||||
0.01 /*relative_per_sample_tolerance*/,
|
||||
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
|
||||
};
|
||||
|
||||
test_case({1, 12, 37});
|
||||
|
|
@ -341,6 +360,14 @@ void QDQTransformerGlobalAveragePoolTests() {
|
|||
0.01 /*per_sample_tolerance*/,
|
||||
0.01 /*relative_per_sample_tolerance*/,
|
||||
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
|
||||
TransformerTester(BuildQDQGlobalAveragePoolTestCase<InputType, OutputType>(input_shape),
|
||||
check_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2,
|
||||
18 /*opset_version*/,
|
||||
0.01 /*per_sample_tolerance*/,
|
||||
0.01 /*relative_per_sample_tolerance*/,
|
||||
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
|
||||
};
|
||||
|
||||
test_case({1, 12, 37});
|
||||
|
|
@ -391,6 +418,14 @@ void QDQTransformerBinaryOpTests(const std::string& op_type) {
|
|||
0.01 /*per_sample_tolerance*/,
|
||||
0.01 /*relative_per_sample_tolerance*/,
|
||||
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
|
||||
TransformerTester(BuildBinaryOpTestCase<Input1Type, Input2Type, OutputType>(input_shape, op_type),
|
||||
check_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2,
|
||||
18 /*opset_version*/,
|
||||
0.01 /*per_sample_tolerance*/,
|
||||
0.01 /*relative_per_sample_tolerance*/,
|
||||
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
|
||||
};
|
||||
|
||||
test_case({1, 12, 37});
|
||||
|
|
@ -522,6 +557,14 @@ void QDQTransformerMatMulTests(bool has_output_q) {
|
|||
0.01 /*per_sample_tolerance*/,
|
||||
0.01 /*relative_per_sample_tolerance*/,
|
||||
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
|
||||
TransformerTester(build_test_case,
|
||||
check_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2,
|
||||
18 /*opset_version*/,
|
||||
0.01 /*per_sample_tolerance*/,
|
||||
0.01 /*relative_per_sample_tolerance*/,
|
||||
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
|
||||
};
|
||||
|
||||
test_case({1, 2, 2}, {1, 2, 4});
|
||||
|
|
@ -677,6 +720,14 @@ void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one
|
|||
0.01 /*per_sample_tolerance*/,
|
||||
0.01 /*relative_per_sample_tolerance*/,
|
||||
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
|
||||
TransformerTester(build_test_case,
|
||||
check_binary_op_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2,
|
||||
18 /*opset_version*/,
|
||||
0.01 /*per_sample_tolerance*/,
|
||||
0.01 /*relative_per_sample_tolerance*/,
|
||||
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
|
||||
};
|
||||
|
||||
test_case({2, 2}, {2, 4});
|
||||
|
|
@ -813,6 +864,14 @@ TEST(QDQTransformerTests, DoubleQDQ) {
|
|||
12,
|
||||
(scale_1 + scale_3) / 2,
|
||||
0.01);
|
||||
TransformerTester(
|
||||
BuildDoubleQDQTestCases<int8_t, int8_t, int8_t, int8_t>(zp_1, zp_2, zp_3, zp_4, scale_1, scale_2, scale_3, scale_4),
|
||||
succeed ? expect_succeed : expect_fail,
|
||||
TransformerLevel::Default,
|
||||
TransformerLevel::Level1,
|
||||
18,
|
||||
(scale_1 + scale_3) / 2,
|
||||
0.01);
|
||||
};
|
||||
|
||||
auto test_case_2u8_2s8_failed = [&](uint8_t zp_1, uint8_t zp_2, int8_t zp_3, int8_t zp_4,
|
||||
|
|
@ -870,7 +929,8 @@ TEST(QDQTransformerTests, Split) {
|
|||
TransformerTester(BuildQDQSplitTestCase<int8_t, int8_t>(input_shape, axis),
|
||||
check_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2);
|
||||
TransformerLevel::Level2,
|
||||
{12, 18});
|
||||
};
|
||||
test_case({6, 18, 54}, 0);
|
||||
}
|
||||
|
|
@ -887,7 +947,7 @@ TEST(QDQTransformerTests, Split_without_IdenticalChildrenConsolidation) {
|
|||
TransformerTester(BuildConsolidationTestCase<int8_t, int8_t>(input_shape, axis),
|
||||
check_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2, 12, {}, {}, nullptr, {},
|
||||
TransformerLevel::Level2, {12, 18}, {}, {}, nullptr, {},
|
||||
{"IdenticalChildrenConsolidation"});
|
||||
};
|
||||
test_case({6, 18, 54}, 0);
|
||||
|
|
@ -904,7 +964,8 @@ TEST(QDQTransformerTests, Split_with_IdenticalChildrenConsolidation) {
|
|||
TransformerTester(BuildConsolidationTestCase<int8_t, int8_t>(input_shape, axis),
|
||||
check_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2);
|
||||
TransformerLevel::Level2,
|
||||
{12, 18});
|
||||
};
|
||||
test_case({6, 18, 54}, 0);
|
||||
}
|
||||
|
|
@ -1509,7 +1570,7 @@ TEST(QDQTransformerTests, ConvAveragePoolReshape_Int8_Fail) {
|
|||
check_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2,
|
||||
12 /*opset_version*/,
|
||||
{12, 18} /*opset_version*/,
|
||||
0.01f /*per_sample_tolerance*/,
|
||||
0.01f /*relative_per_sample_tolerance*/);
|
||||
};
|
||||
|
|
@ -1566,6 +1627,14 @@ void QDQTransformerLeakyReluTests() {
|
|||
0.01 /*per_sample_tolerance*/,
|
||||
0.01 /*relative_per_sample_tolerance*/,
|
||||
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
|
||||
TransformerTester(build_test_case,
|
||||
check_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2,
|
||||
18 /*opset_version*/,
|
||||
0.01 /*per_sample_tolerance*/,
|
||||
0.01 /*relative_per_sample_tolerance*/,
|
||||
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
|
||||
};
|
||||
|
||||
test_case({1, 12, 37});
|
||||
|
|
@ -1635,6 +1704,14 @@ void QDQTransformerSigmoidTests() {
|
|||
0.01 /*per_sample_tolerance*/,
|
||||
0.01 /*relative_per_sample_tolerance*/,
|
||||
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
|
||||
TransformerTester(build_test_case,
|
||||
check_graph,
|
||||
TransformerLevel::Level1,
|
||||
TransformerLevel::Level2,
|
||||
18 /*opset_version*/,
|
||||
0.01 /*per_sample_tolerance*/,
|
||||
0.01 /*relative_per_sample_tolerance*/,
|
||||
std::make_unique<QDQSelectorActionTransformer>(QDQIsInt8Allowed()));
|
||||
};
|
||||
|
||||
test_case({1, 12, 37});
|
||||
|
|
@ -1907,7 +1984,7 @@ TEST(QDQTransformerTests, DQForward_MutilpleSteps) {
|
|||
TEST(QDQTransformerTests, Clip) {
|
||||
constexpr float epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
auto test_case = [&](float scale, auto zero_point, int clip_count, int opset_version = 12) {
|
||||
auto test_case = [&](float scale, auto zero_point, int clip_count, int opset_version) {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<int8_t>({1, 32, 112, 112},
|
||||
std::numeric_limits<int8_t>::min(),
|
||||
|
|
@ -1922,7 +1999,9 @@ TEST(QDQTransformerTests, Clip) {
|
|||
auto* clip_output = builder.MakeIntermediate();
|
||||
constexpr float min = .0f;
|
||||
constexpr float max = 6.0f;
|
||||
if (opset_version >= 11) {
|
||||
auto opset = builder.DomainToVersionMap().find(kOnnxDomain)->second;
|
||||
EXPECT_EQ(opset_version, opset);
|
||||
if (opset >= 11) {
|
||||
auto* min_initializer = builder.MakeScalarInitializer<float>(min);
|
||||
auto* max_initializer = builder.MakeScalarInitializer<float>(max);
|
||||
builder.AddNode("Clip", {dq_output, min_initializer, max_initializer}, {clip_output});
|
||||
|
|
@ -1953,18 +2032,21 @@ TEST(QDQTransformerTests, Clip) {
|
|||
epsilon);
|
||||
};
|
||||
|
||||
test_case(.0235294122248888f, static_cast<int8_t>(-128), 0); // [0, 6]
|
||||
test_case(.02f, static_cast<int8_t>(-128), 0); // [0, 5.1]
|
||||
test_case(.03f, static_cast<int8_t>(-128), 1); // [0, 7.65]
|
||||
test_case(.02f, static_cast<int8_t>(127), 1); // [-5.1 , 0]
|
||||
test_case(.02f, static_cast<int8_t>(0), 1); // [-2.56, 2.54]
|
||||
test_case(.04f, static_cast<int8_t>(-97), 1); // [-1.24, 8.96]
|
||||
test_case(.02352941176f, static_cast<uint8_t>(0), 0); // [0, 6]
|
||||
test_case(.02f, static_cast<uint8_t>(0), 0); // [0, 5.1]
|
||||
test_case(.03f, static_cast<uint8_t>(0), 1); // [0, 7.65]
|
||||
test_case(.02f, static_cast<uint8_t>(255), 1); // [-5.1, 0]
|
||||
test_case(.02f, static_cast<uint8_t>(128), 1); // [-2.56, 2.54]
|
||||
test_case(.04f, static_cast<uint8_t>(31), 1); // [-1.24, 8.96]
|
||||
std::vector<int64_t> opsets{12, 18};
|
||||
for (auto opset : opsets) {
|
||||
test_case(.0235294122248888f, static_cast<int8_t>(-128), 0, opset); // [0, 6]
|
||||
test_case(.02f, static_cast<int8_t>(-128), 0, opset); // [0, 5.1]
|
||||
test_case(.03f, static_cast<int8_t>(-128), 1, opset); // [0, 7.65]
|
||||
test_case(.02f, static_cast<int8_t>(127), 1, opset); // [-5.1 , 0]
|
||||
test_case(.02f, static_cast<int8_t>(0), 1, opset); // [-2.56, 2.54]
|
||||
test_case(.04f, static_cast<int8_t>(-97), 1, opset); // [-1.24, 8.96]
|
||||
test_case(.02352941176f, static_cast<uint8_t>(0), 0, opset); // [0, 6]
|
||||
test_case(.02f, static_cast<uint8_t>(0), 0, opset); // [0, 5.1]
|
||||
test_case(.03f, static_cast<uint8_t>(0), 1, opset); // [0, 7.65]
|
||||
test_case(.02f, static_cast<uint8_t>(255), 1, opset); // [-5.1, 0]
|
||||
test_case(.02f, static_cast<uint8_t>(128), 1, opset); // [-2.56, 2.54]
|
||||
test_case(.04f, static_cast<uint8_t>(31), 1, opset); // [-1.24, 8.96]
|
||||
}
|
||||
|
||||
// opset_version = 10
|
||||
test_case(.02f, static_cast<int8_t>(-128), 0, 10); // [0, 5.1]
|
||||
|
|
@ -1973,10 +2055,12 @@ TEST(QDQTransformerTests, Clip) {
|
|||
test_case(.03f, static_cast<uint8_t>(0), 1, 10); // [0, 7.65]
|
||||
|
||||
// difference between lower/upper and min/max are within epsilon
|
||||
test_case(epsilon, static_cast<int8_t>(-127), 0); // [-epsilon, x] (x <= 6 + epsilon)
|
||||
test_case((6 + epsilon) / 255, static_cast<int8_t>(-128), 0); // [0, 6 + epsilon]
|
||||
test_case(epsilon, static_cast<uint8_t>(1), 0); // [-epsilon, x] (x <= 6 + epsilon)
|
||||
test_case((6 + epsilon) / 255, static_cast<uint8_t>(0), 0); // [0, 6 + epsilon]
|
||||
for (auto opset : opsets) {
|
||||
test_case(epsilon, static_cast<int8_t>(-127), 0, opset); // [-epsilon, x] (x <= 6 + epsilon)
|
||||
test_case((6 + epsilon) / 255, static_cast<int8_t>(-128), 0, opset); // [0, 6 + epsilon]
|
||||
test_case(epsilon, static_cast<uint8_t>(1), 0, opset); // [-epsilon, x] (x <= 6 + epsilon)
|
||||
test_case((6 + epsilon) / 255, static_cast<uint8_t>(0), 0, opset); // [0, 6 + epsilon]
|
||||
}
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, Concat) {
|
||||
|
|
@ -2536,7 +2620,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) {
|
|||
|
||||
// regression test to validate TransposeOptimizer and QDQ Propagation don't loop
|
||||
// see https://github.com/microsoft/onnxruntime/issues/11605
|
||||
TEST(QDQTransformerTests, QDQPropagation_GH11605) {
|
||||
TEST(QDQTransformerTests, QDQPropagation_GH11605_Opset12) {
|
||||
auto test_case = [&]() {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<uint8_t>({1, 4, 4},
|
||||
|
|
@ -2585,7 +2669,61 @@ TEST(QDQTransformerTests, QDQPropagation_GH11605) {
|
|||
TransformerTester(build_test_case,
|
||||
check_graph,
|
||||
TransformerLevel::Default,
|
||||
TransformerLevel::Level2);
|
||||
TransformerLevel::Level2,
|
||||
12);
|
||||
};
|
||||
|
||||
test_case();
|
||||
}
|
||||
|
||||
TEST(QDQTransformerTests, QDQPropagation_GH11605_Opset13) {
|
||||
auto test_case = [&]() {
|
||||
auto build_test_case = [&](ModelTestBuilder& builder) {
|
||||
auto* input_arg = builder.MakeInput<uint8_t>({1, 4, 4},
|
||||
std::numeric_limits<uint8_t>::min(),
|
||||
std::numeric_limits<uint8_t>::max());
|
||||
// add DQ
|
||||
auto* dq_output = builder.MakeIntermediate();
|
||||
builder.AddDequantizeLinearNode(input_arg, 0.123f, uint8_t(0), dq_output);
|
||||
|
||||
// add Transpose 0, 2, 1
|
||||
const std::vector<int64_t>& perms{0, 2, 1};
|
||||
auto* transpose_output = builder.MakeIntermediate();
|
||||
Node& transpose_node = builder.AddNode("Transpose", {dq_output}, {transpose_output});
|
||||
transpose_node.AddAttribute("perm", perms);
|
||||
|
||||
// add Softmax with axis=2 (to block the Transpose moving past it due to the transpose perms)
|
||||
auto* softmax_output = builder.MakeIntermediate();
|
||||
Node& softmax_node = builder.AddNode("Softmax", {transpose_output}, {softmax_output});
|
||||
softmax_node.AddAttribute("axis", int64_t(2));
|
||||
|
||||
// add second Transpose. this is so the check in TransposeOptimizer::ProcessTranspose for outputs leading to
|
||||
// a Transpose is satisfied, allowing the first Transpose to move past the Q/DQ inserted by QDQ Propagation
|
||||
Node& transpose_node2 = builder.AddNode("Transpose", {softmax_output}, {builder.MakeOutput()});
|
||||
transpose_node2.AddAttribute("perm", perms);
|
||||
};
|
||||
|
||||
// check that an edge case where transpose optimization gets blocked is handled gracefully.
|
||||
// Original: DQ -> Tr -> SoftM -> Tr
|
||||
// QDQ Prop inserts a Q/DQ pair to create a QDQ node group for the Transpose: DQ -> Tr -> Q -> DQ -> SoftM -> Tr
|
||||
// Transpose opt phase 1 moves the Tr down until it blocks on the SoftMax: DQ -> Q -> DQ -> Tr -> SoftM -> Tr
|
||||
// Transpose opt phase 2 flips the Tr to prior to the DQ as it's not part of a QDQ node group at that point, as
|
||||
// running the transpose on 8-bit data should be cheaper: DQ -> Q -> Tr -> DQ -> SoftM -> Tr
|
||||
// QDQ cleanup in Level2 removes the unnecessary DQ/Q pair at the start: Tr -> DQ -> SoftM -> Tr
|
||||
// this is the optimal result as the Transpose is using 8-bit data and we have no surplus Q/DQ pairs
|
||||
auto check_graph = [&](InferenceSessionWrapper& session) {
|
||||
std::vector<std::string> expected_op_types_in_order{
|
||||
"DequantizeLinear",
|
||||
"Softmax"};
|
||||
const auto op_types_in_order = GetNodeOpTypesInTopologicalOrder(session.GetGraph());
|
||||
EXPECT_EQ(op_types_in_order, expected_op_types_in_order);
|
||||
};
|
||||
|
||||
TransformerTester(build_test_case,
|
||||
check_graph,
|
||||
TransformerLevel::Default,
|
||||
TransformerLevel::Level2,
|
||||
13);
|
||||
};
|
||||
|
||||
test_case();
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue