Fix for BiasGelu fusion optimizer (#4160)

* Fix for BiasGelu fusion optimizer

* changes per review comments
This commit is contained in:
Ashwini Khade 2020-06-09 14:33:34 -07:00 committed by GitHub
parent 2b3ce1b090
commit 9eba9fba7c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 69 additions and 4 deletions

View file

@ -234,7 +234,8 @@ static void MoveAllNodeInputEdges(Graph& graph, Node& src_node, Node& target_nod
auto input_edges = GetNodeInputEdges(src_node);
for (auto cur = input_edges.cbegin(), end = input_edges.cend(); cur != end; ++cur) {
graph.AddEdge(cur->src_node, target_idx, cur->src_arg_index, cur->dst_arg_index);
auto target_arg_index = GetNodeInputIndexFromInputName(target_node, cur->arg_name);
graph.AddEdge(cur->src_node, target_idx, cur->src_arg_index, target_arg_index);
}
RemoveGraphEdges(graph, input_edges);
@ -261,6 +262,16 @@ static void MoveAllNodeOutputs(Graph& graph, Node& src_node, Node& target_node)
//--- end of local helpers ---
//----------------------------
int GetNodeInputIndexFromInputName(const Node& node, const std::string& input_name) {
auto itr = std::find_if(node.InputDefs().begin(), node.InputDefs().end(),
[&input_name](const NodeArg* input) { return input->Name() == input_name; });
ORT_ENFORCE(itr != node.InputDefs().end(),
"Attempting to get index for an input which does not exist.");
auto index = std::distance(node.InputDefs().begin(), itr);
return static_cast<int>(index);
}
const std::string& GetNodeInputName(const Node& node, int index) {
const auto& inputs = node.InputDefs();
ORT_ENFORCE(index >= 0 && static_cast<size_t>(index) < inputs.size(),

View file

@ -69,6 +69,9 @@ bool AllNodeInputsAreConstant(const Graph& graph, const Node& node, InitializedT
/** Gets the name of the incoming NodeArg with the specified index for the given node. */
const std::string& GetNodeInputName(const Node& node, int index);
/** Gets the index of an input arg with the specified input arg name. */
int GetNodeInputIndexFromInputName(const Node& node, const std::string& input_name);
/** Gets the name of the outgoing NodeArg with the specified index for the given node. */
const std::string& GetNodeOutputName(const Node& node, int index);

View file

@ -48,6 +48,8 @@
#include "test/capturing_sink.h"
#include "test/framework/test_utils.h"
#include "test/optimizer/graph_transform_test_fixture.h"
#include "test/compare_ortvalue.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
#include "test/test_environment.h"
#include "asserts.h"
@ -1199,7 +1201,7 @@ TEST_F(GraphTransformationTests, ReshapeFusionGraphInputsTest) {
ASSERT_EQ(op_to_count["Concat"], 1);
ASSERT_EQ(op_to_count["Reshape"], 1);
}
TEST_F(GraphTransformationTests, ReshapeFusionMultipleValuesInInitializerSubgraphTest) {
auto model_uri = MODEL_FOLDER "fusion/reshape_fusion_multiple_values_in_initializer_tensor_1.onnx";
std::shared_ptr<Model> p_model;
@ -1399,8 +1401,6 @@ TEST_F(GraphTransformationTests, ExpandElimination) {
ASSERT_TRUE(op_to_count["Expand"] == 3);
}
TEST_F(GraphTransformationTests, CastElimination) {
auto model_uri = MODEL_FOLDER "cast_elimination.onnx";
std::shared_ptr<Model> model;
@ -1721,6 +1721,57 @@ TEST_F(GraphTransformationTests, BiasGeluTest) {
ASSERT_TRUE(op_to_count["BiasGelu"] == 1);
}
// BiasGelu allows input switching based on input dimensions.
// This test validates the input edges are plugged correct in the optimized graph.
TEST_F(GraphTransformationTests, BiasGeluSwitchedInputOrder) {
auto model_uri = MODEL_FOLDER "fusion/bias_gelu_fusion_format_2.onnx";
// create inputs and outputs
RandomValueGenerator random{};
NameMLValMap feeds;
OrtValue mlvalue_b_i;
std::vector<int64_t> dims_b_i = {3072};
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_b_i,
random.Uniform<float>(dims_b_i, 0.0f, 1.0f), &mlvalue_b_i);
feeds.insert(std::make_pair("B_I", mlvalue_b_i));
OrtValue mlvalue_a_i;
std::vector<int64_t> dims_a_i = {3, 512, 3072};
CreateMLValue<float>(TestCPUExecutionProvider()->GetAllocator(0, OrtMemTypeDefault), dims_a_i,
random.Uniform<float>(dims_a_i, 0.0f, 1.0f), &mlvalue_a_i);
feeds.insert(std::make_pair("A_I", mlvalue_a_i));
std::vector<std::string> output_names;
output_names.push_back("C");
auto run_model_test = [&](TransformerLevel level, std::vector<OrtValue>& fetches) {
SessionOptions session_options;
session_options.graph_optimization_level = level;
session_options.session_logid = "OptimizerTests";
InferenceSession session{session_options, GetEnvironment()};
ASSERT_TRUE(session.Load(model_uri).IsOK());
ASSERT_TRUE(session.Initialize().IsOK());
RunOptions run_options;
ASSERT_STATUS_OK(session.Run(run_options, feeds, output_names, &fetches));
};
// run model with and w/o optimizations and compare the results
std::vector<OrtValue> unoptimized_fetches;
run_model_test(TransformerLevel::Default, unoptimized_fetches);
std::vector<OrtValue> optimized_fetches;
run_model_test(TransformerLevel::MaxLevel, optimized_fetches);
// Compare results
double per_sample_tolerance = 1e-3;
double relative_per_sample_tolerance = 0.0;
auto ret = CompareOrtValue(optimized_fetches[0], unoptimized_fetches[0], per_sample_tolerance, relative_per_sample_tolerance, false);
EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second;
}
// Test Gelu -> FastGelu
TEST_F(GraphTransformationTests, GeluApproximation_Gelu) {
auto model_uri = MODEL_FOLDER "approximation/gelu.onnx";