From f26c149d7dab96da24f516cb9f785f2d2c656e0a Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Wed, 24 Jun 2020 13:43:22 +0800 Subject: [PATCH] Set NonZero Output Shape for Gradient Building. (#4246) * Set NonZero output shape for gradient building. * Resolve comments. Co-authored-by: Vincent Wang --- .../transform/nonzero_shape_setter.onnx | Bin 0 -> 258 bytes .../transform/nonzero_shape_setter.py | 42 ++++++++++++++++++ .../core/optimizer/graph_transformer_utils.cc | 2 + .../core/optimizer/nonzero_shape_setter.cc | 32 +++++++++++++ .../core/optimizer/nonzero_shape_setter.h | 26 +++++++++++ .../test/optimizer/graph_transform_test.cc | 20 +++++++++ 6 files changed, 122 insertions(+) create mode 100644 onnxruntime/test/testdata/transform/nonzero_shape_setter.onnx create mode 100644 onnxruntime/test/testdata/transform/nonzero_shape_setter.py create mode 100644 orttraining/orttraining/core/optimizer/nonzero_shape_setter.cc create mode 100644 orttraining/orttraining/core/optimizer/nonzero_shape_setter.h diff --git a/onnxruntime/test/testdata/transform/nonzero_shape_setter.onnx b/onnxruntime/test/testdata/transform/nonzero_shape_setter.onnx new file mode 100644 index 0000000000000000000000000000000000000000..51d0fdd9ea306d816026ffae8f602c27d7d3bd3b GIT binary patch literal 258 zcmYL@u?oU45QZD6HXc=K7NKKD!R|VV2yRYJE~VDf7Mfg0QV^fUr}BBM($eAJCNRkKC1L)Ggu1Kb;OD?pkEF%r#~fokNpJx?z1;_EqN) mawa&8xc{4L1YM+> GeneratePreTrainingTransformers(T rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); + rule_transformer->Register(make_unique()); rule_transformer->Register(make_unique()); transformers.emplace_back(onnxruntime::make_unique(compatible_eps)); diff --git a/orttraining/orttraining/core/optimizer/nonzero_shape_setter.cc b/orttraining/orttraining/core/optimizer/nonzero_shape_setter.cc new file mode 100644 index 0000000000..91a5ff92f5 --- /dev/null +++ 
b/orttraining/orttraining/core/optimizer/nonzero_shape_setter.cc @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/graph_utils.h" +#include "orttraining/core/optimizer/nonzero_shape_setter.h" + +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { + +Status NonZeroShapeSetter::Apply(Graph& /*graph*/, + Node& node, + RewriteRuleEffect& rule_effect, + const logging::Logger& /*logger*/) const { + // NonZero's output shape is [rank_of_input, number_of_nonzero_elements]: dim 0 is set from the input's known rank, while dim 1 is data-dependent and is therefore recorded as a symbolic dim_param named after the output (its name followed by _nonzero_count). The input Shape() is dereferenced without a null check here; SatisfyCondition requires it to be non-null (presumably always evaluated before Apply; confirm against RewriteRule). + ONNX_NAMESPACE::TensorShapeProto result_shape; + result_shape.add_dim()->set_dim_value(node.InputDefs()[0]->Shape()->dim_size()); + result_shape.add_dim()->set_dim_param(node.OutputDefs()[0]->Name() + "_nonzero_count"); + node.MutableOutputDefs()[0]->SetShape(result_shape); + rule_effect = RewriteRuleEffect::kUpdatedCurrentNode; + return Status::OK(); +} + +bool NonZeroShapeSetter::SatisfyCondition(const Graph& /*graph*/, + const Node& node, + const logging::Logger& /*logger*/) const { + return node.InputDefs()[0]->Shape() != nullptr /* the input shape must already be known */ + && node.InputDefs()[0]->Shape()->dim_size() > 0 /* inputs whose shape has no dims are skipped */ + && node.OutputDefs()[0]->Shape() == nullptr; /* never overwrite an output shape that is already set */ +} + +} // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/nonzero_shape_setter.h b/orttraining/orttraining/core/optimizer/nonzero_shape_setter.h new file mode 100644 index 0000000000..f228759a35 --- /dev/null +++ b/orttraining/orttraining/core/optimizer/nonzero_shape_setter.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { + +// Rewrite rule that sets the output shape of NonZero ops when the input shape is known and the output shape is not yet set. 
+class NonZeroShapeSetter : public RewriteRule { + public: + NonZeroShapeSetter() noexcept + : RewriteRule("NonZeroShapeSetter") { + } + + std::vector TargetOpTypes() const noexcept override { + return {"NonZero"}; /* the only op type this rule names as its target */ + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; +} // namespace onnxruntime diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index 38bd46d946..b41b01609f 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -11,6 +11,7 @@ #include "core/optimizer/rule_based_graph_transformer.h" #include "core/optimizer/utils.h" #include "orttraining/core/optimizer/gist_encode_decode.h" +#include "orttraining/core/optimizer/nonzero_shape_setter.h" #include "orttraining/core/optimizer/megatron_transformer.h" #include "test/optimizer/graph_transform_test_fixture.h" #include "test/util/include/default_providers.h" @@ -58,6 +59,25 @@ Node* GetNodeByName(Graph& graph, std::string node_name) { return nullptr; } +TEST_F(GraphTransformationTests, NonZeroShapeSetter) { /* Loads nonzero_shape_setter.onnx, runs the NonZeroShapeSetter rule at Level1, then checks the shape assigned to the first output of the node named "nonzero". */ + auto model_uri = MODEL_FOLDER "nonzero_shape_setter.onnx"; + std::shared_ptr p_model; + ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK()); + Graph& graph = p_model->MainGraph(); + + auto rule_transformer_L1 = onnxruntime::make_unique("NonZeroShapeSetter1"); + rule_transformer_L1->Register(onnxruntime::make_unique()); + onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; + graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1); + + auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_); + ASSERT_TRUE(ret.IsOK()); + + 
auto nonzero_shape = GetNodeByName(graph, "nonzero")->OutputDefs()[0]->Shape(); /* NOTE(review): GetNodeByName can return nullptr and Shape() is null when no shape was set; both are dereferenced unchecked, so a rule that does not fire crashes the test instead of failing it */ + ASSERT_TRUE(nonzero_shape->dim_size() == 2); /* the rule always emits exactly two dims */ + ASSERT_TRUE(nonzero_shape->dim(0).dim_value() == 2); /* dim 0 is the input rank; presumably the test model's input is 2-D (the .onnx is binary here), TODO confirm */ + ASSERT_TRUE(nonzero_shape->dim(1).dim_param() == "nonzero_nonzero_count"); /* the rule appends _nonzero_count to the output's name, so the output is presumably also named "nonzero" in the model */ +} // MegatronF/G is defined only for training, and in msdomain. #ifndef DISABLE_CONTRIB_OPS