Support More Cases in NoOpElimination (#13460)

Currently, NoOpElimination supports only Add nodes. This PR adds support
for x-0, x*1, 1*x and x/1, in addition to x+0 and 0+x.

With this PR, all Div(x,1) nodes and their gradients (also Div(x,1)) in
Hugging Face's diffusers model can be removed; together they previously
took ~1% of the total compute time.
This commit is contained in:
Vincent Wang 2022-11-01 10:39:52 +08:00 committed by GitHub
parent 3d0db47c17
commit 17f0ffd1c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 307 additions and 31 deletions

View file

@ -13,14 +13,7 @@
namespace onnxruntime {
/**
Eliminate no op node - handling Add op for now
Add example:
X 0
\ /
Add
|
Y
Eliminate no op node - supporting x+0, 0+x, x-0, x*1, 1*x and x/1 for now.
*/
Status NoopElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger&) const {
if (graph_utils::RemoveNode(graph, node)) {
@ -31,7 +24,6 @@ Status NoopElimination::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_
}
bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const {
bool input0_is_initializer = graph_utils::IsConstantInitializer(graph, node.InputDefs()[0]->Name());
bool input1_is_initializer = graph_utils::IsConstantInitializer(graph, node.InputDefs()[1]->Name());
@ -40,7 +32,13 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con
return false;
}
const auto* initializer = graph_utils::GetConstantInitializer(graph, node.InputDefs()[input0_is_initializer ? 0 : 1]->Name());
const auto& op_type = node.OpType();
if ((op_type == "Sub" || op_type == "Div") && !input1_is_initializer) {
return false;
}
const auto* initializer =
graph_utils::GetConstantInitializer(graph, node.InputDefs()[input0_is_initializer ? 0 : 1]->Name());
// If the initializer's rank is larger, broadcasting rules make the output take the initializer's rank,
// but that would no longer happen once the node is removed, so reject this case.
@ -59,42 +57,38 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con
if (add_init.size() == 0) {
return true;
}
float value = 0.0f;
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
if (*add_init.data<float>() != 0.f) {
return false;
}
value = *add_init.data<float>();
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
if (math::halfToFloat(add_init.data<MLFloat16>()->val) != 0.f) {
return false;
}
value = math::halfToFloat(add_init.data<MLFloat16>()->val);
break;
case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
if (*add_init.data<double>() != static_cast<double>(0.f)) {
return false;
}
value = static_cast<float>(*add_init.data<double>());
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT32:
if (*add_init.data<int32_t>() != static_cast<int32_t>(0)) {
return false;
}
value = static_cast<float>(*add_init.data<int32_t>());
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
if (*add_init.data<int64_t>() != static_cast<int64_t>(0)) {
return false;
}
value = static_cast<float>(*add_init.data<int64_t>());
break;
default:
return false;
}
// reject node output is graph output for now
if (!graph_utils::CanRemoveNode(graph, node, logger)) {
if ((op_type == "Add" || op_type == "Sub") && value != 0.0f) {
return false;
}
return true;
if ((op_type == "Mul" || op_type == "Div") && value != 1.0f) {
return false;
}
// reject node output is graph output for now
return graph_utils::CanRemoveNode(graph, node, logger);
}
} // namespace onnxruntime

View file

@ -11,15 +11,14 @@ namespace onnxruntime {
@Class NoopElimination
Rewrite rule that eliminates the no op node.
So far only Add node with 0 as one of its inputs is eliminated.
But this class could be the placeholder for other no op nodes in future.
Support x+0, 0+x, x-0, x*1, 1*x and x/1 for now.
*/
class NoopElimination : public RewriteRule {
public:
NoopElimination() noexcept : RewriteRule("NoopElimination") {}
std::vector<std::string> TargetOpTypes() const noexcept override {
return {"Add"};
return {"Add", "Sub", "Mul", "Div"};
}
private:

View file

@ -214,6 +214,289 @@ TEST_F(GraphTransformationTests, NoopElimination) {
op_to_count = CountOpsInGraph(graph);
ASSERT_TRUE(op_to_count["Add"] == 1);
auto pre_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Add"] + CountOpsInGraph(graph)["Sub"] + CountOpsInGraph(graph)["Mul"] +
CountOpsInGraph(graph)["Div"],
1);
};
auto post_graph_checker = [&](Graph& graph) {
ASSERT_EQ(CountOpsInGraph(graph)["Add"] + CountOpsInGraph(graph)["Sub"] + CountOpsInGraph(graph)["Mul"] +
CountOpsInGraph(graph)["Div"],
0);
};
// x+0, float.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
auto* input2_arg = builder.MakeInput<float>({{3, 3}});
auto* matmul_output = builder.MakeIntermediate();
auto* initializer_arg = builder.MakeInitializer<float>({}, {0.0f});
auto* add_out = builder.MakeIntermediate();
auto* identity_output = builder.MakeOutput();
builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output});
builder.AddNode("Add", {matmul_output, initializer_arg}, {add_out});
builder.AddNode("Identity", {add_out}, {identity_output});
};
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer");
ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique<NoopElimination>()));
TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker);
}
// 0+x, fp16.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput<MLFloat16>({{2, 3, 3, 3}});
auto* input2_arg = builder.MakeInput<MLFloat16>({{3, 3}});
auto* matmul_output = builder.MakeIntermediate();
auto* initializer_arg = builder.MakeInitializer<MLFloat16>({1}, {MLFloat16(0.0f)});
auto* add_out = builder.MakeIntermediate();
auto* identity_output = builder.MakeOutput();
builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output});
builder.AddNode("Add", {initializer_arg, matmul_output}, {add_out});
builder.AddNode("Identity", {add_out}, {identity_output});
};
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer");
ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique<NoopElimination>()));
TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker);
}
// x-0, double.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput<double>({{2, 3, 3, 3}});
auto* input2_arg = builder.MakeInput<double>({{3, 3}});
auto* matmul_output = builder.MakeIntermediate();
auto* initializer_arg = builder.MakeInitializer<double>({1, 1}, {static_cast<double>(0.0f)});
auto* sub_out = builder.MakeIntermediate();
auto* identity_output = builder.MakeOutput();
builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output});
builder.AddNode("Sub", {matmul_output, initializer_arg}, {sub_out});
builder.AddNode("Identity", {sub_out}, {identity_output});
};
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer");
ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique<NoopElimination>()));
TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker);
}
// x*1, int32.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput<int32_t>({{2, 3, 3, 3}});
auto* input2_arg = builder.MakeInput<int32_t>({{3, 3}});
auto* matmul_output = builder.MakeIntermediate();
auto* initializer_arg = builder.MakeInitializer<int32_t>({1, 1, 1}, {1});
auto* mul_out = builder.MakeIntermediate();
auto* identity_output = builder.MakeOutput();
builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output});
builder.AddNode("Mul", {matmul_output, initializer_arg}, {mul_out});
builder.AddNode("Identity", {mul_out}, {identity_output});
};
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer");
ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique<NoopElimination>()));
TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker);
}
// 1*x, int64.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput<int64_t>({{2, 3, 3, 3}});
auto* input2_arg = builder.MakeInput<int64_t>({{3, 3}});
auto* matmul_output = builder.MakeIntermediate();
auto* initializer_arg = builder.MakeInitializer<int64_t>({1, 1, 1, 1}, {static_cast<int64_t>(1)});
auto* mul_out = builder.MakeIntermediate();
auto* identity_output = builder.MakeOutput();
builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output});
builder.AddNode("Mul", {initializer_arg, matmul_output}, {mul_out});
builder.AddNode("Identity", {mul_out}, {identity_output});
};
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer");
ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique<NoopElimination>()));
TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker);
}
// x/1, float.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
auto* input2_arg = builder.MakeInput<float>({{3, 3}});
auto* matmul_output = builder.MakeIntermediate();
auto* initializer_arg = builder.MakeInitializer<float>({}, {1.0f});
auto* div_out = builder.MakeIntermediate();
auto* identity_output = builder.MakeOutput();
builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output});
builder.AddNode("Div", {matmul_output, initializer_arg}, {div_out});
builder.AddNode("Identity", {div_out}, {identity_output});
};
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer");
ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique<NoopElimination>()));
TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1,
pre_graph_checker, post_graph_checker);
}
// Invalid case: x+1.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
auto* input2_arg = builder.MakeInput<float>({{3, 3}});
auto* matmul_output = builder.MakeIntermediate();
auto* initializer_arg = builder.MakeInitializer<float>({}, {1.0f});
auto* add_out = builder.MakeIntermediate();
auto* identity_output = builder.MakeOutput();
builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output});
builder.AddNode("Add", {matmul_output, initializer_arg}, {add_out});
builder.AddNode("Identity", {add_out}, {identity_output});
};
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer");
ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique<NoopElimination>()));
TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1,
pre_graph_checker, pre_graph_checker);
}
// Invalid case: initializer rank is larger.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput<MLFloat16>({{2, 3, 3, 3}});
auto* input2_arg = builder.MakeInput<MLFloat16>({{3, 3}});
auto* matmul_output = builder.MakeIntermediate();
auto* initializer_arg = builder.MakeInitializer<MLFloat16>({1, 1, 1, 1, 1}, {MLFloat16(0.0f)});
auto* add_out = builder.MakeIntermediate();
auto* identity_output = builder.MakeOutput();
builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output});
builder.AddNode("Add", {initializer_arg, matmul_output}, {add_out});
builder.AddNode("Identity", {add_out}, {identity_output});
};
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer");
ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique<NoopElimination>()));
TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1,
pre_graph_checker, pre_graph_checker);
}
// Invalid case: 0-x.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput<double>({{2, 3, 3, 3}});
auto* input2_arg = builder.MakeInput<double>({{3, 3}});
auto* matmul_output = builder.MakeIntermediate();
auto* initializer_arg = builder.MakeInitializer<double>({1, 1}, {static_cast<double>(0.0f)});
auto* sub_out = builder.MakeIntermediate();
auto* identity_output = builder.MakeOutput();
builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output});
builder.AddNode("Sub", {initializer_arg, matmul_output}, {sub_out});
builder.AddNode("Identity", {sub_out}, {identity_output});
};
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer");
ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique<NoopElimination>()));
TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1,
pre_graph_checker, pre_graph_checker);
}
// Invalid case: x-1.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput<double>({{2, 3, 3, 3}});
auto* input2_arg = builder.MakeInput<double>({{3, 3}});
auto* matmul_output = builder.MakeIntermediate();
auto* initializer_arg = builder.MakeInitializer<double>({1, 1}, {static_cast<double>(1.0f)});
auto* sub_out = builder.MakeIntermediate();
auto* identity_output = builder.MakeOutput();
builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output});
builder.AddNode("Sub", {matmul_output, initializer_arg}, {sub_out});
builder.AddNode("Identity", {sub_out}, {identity_output});
};
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer");
ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique<NoopElimination>()));
TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1,
pre_graph_checker, pre_graph_checker);
}
// Invalid case: 0*x.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput<int32_t>({{2, 3, 3, 3}});
auto* input2_arg = builder.MakeInput<int32_t>({{3, 3}});
auto* matmul_output = builder.MakeIntermediate();
auto* initializer_arg = builder.MakeInitializer<int32_t>({1, 1, 1}, {0});
auto* mul_out = builder.MakeIntermediate();
auto* identity_output = builder.MakeOutput();
builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output});
builder.AddNode("Mul", {initializer_arg, matmul_output}, {mul_out});
builder.AddNode("Identity", {mul_out}, {identity_output});
};
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer");
ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique<NoopElimination>()));
TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1,
pre_graph_checker, pre_graph_checker);
}
// Invalid case: output is graph output.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput<int64_t>({{2, 3, 3, 3}});
auto* input2_arg = builder.MakeInput<int64_t>({{3, 3}});
auto* matmul_output = builder.MakeIntermediate();
auto* initializer_arg = builder.MakeInitializer<int64_t>({1, 1, 1, 1}, {static_cast<int64_t>(1)});
auto* mul_out = builder.MakeOutput();
builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output});
builder.AddNode("Mul", {initializer_arg, matmul_output}, {mul_out});
};
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer");
ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique<NoopElimination>()));
TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1,
pre_graph_checker, pre_graph_checker);
}
// Invalid case: 1/x.
{
auto build_test_case = [&](ModelTestBuilder& builder) {
auto* input1_arg = builder.MakeInput<float>({{2, 3, 3, 3}});
auto* input2_arg = builder.MakeInput<float>({{3, 3}});
auto* matmul_output = builder.MakeIntermediate();
auto* initializer_arg = builder.MakeInitializer<float>({}, {1.0f});
auto* div_out = builder.MakeIntermediate();
auto* identity_output = builder.MakeOutput();
builder.AddNode("MatMul", {input1_arg, input2_arg}, {matmul_output});
builder.AddNode("Div", {initializer_arg, matmul_output}, {div_out});
builder.AddNode("Identity", {div_out}, {identity_output});
};
auto rule_transformer = std::make_unique<RuleBasedGraphTransformer>("RuleTransformer");
ASSERT_STATUS_OK(rule_transformer->Register(std::make_unique<NoopElimination>()));
TestGraphTransformer(build_test_case, 13, *logger_, std::move(rule_transformer), TransformerLevel::Level1, 1,
pre_graph_checker, pre_graph_checker);
}
}
TEST_F(GraphTransformationTests, DropoutElimination) {