diff --git a/onnxruntime/test/testdata/transform/nonzero_shape_setter.onnx b/onnxruntime/test/testdata/transform/nonzero_shape_setter.onnx
new file mode 100644
index 0000000000..51d0fdd9ea
Binary files /dev/null and b/onnxruntime/test/testdata/transform/nonzero_shape_setter.onnx differ
diff --git a/onnxruntime/test/testdata/transform/nonzero_shape_setter.py b/onnxruntime/test/testdata/transform/nonzero_shape_setter.py
new file mode 100644
index 0000000000..e0694de9a2
--- /dev/null
+++ b/onnxruntime/test/testdata/transform/nonzero_shape_setter.py
@@ -0,0 +1,44 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import onnx
+from onnx import helper
+from onnx import TensorProto, GraphProto, OperatorSetIdProto
+from onnx import numpy_helper
+import numpy as np
+
+# Builds nonzero_shape_setter.onnx, the model loaded by the NonZeroShapeSetter graph transform test:
+# input -> NonZero -> Transpose(perm=[1,0]) -> GatherND(input, indices) -> output.
+X = helper.make_tensor_value_info('input', TensorProto.FLOAT, [4, 4])
+Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1])
+Y.type.tensor_type.shape.Clear()  # clear the [1] placeholder dims given above
+
+nonzero = helper.make_node('NonZero', ['input'], ['nonzero'], name='nonzero')
+transpose = helper.make_node('Transpose', ['nonzero'], ['transpose'], name='transpose', perm=[1,0])
+gathernd = helper.make_node('GatherND', ['input', 'transpose'], ['output'], name='gathernd')
+
+# Create the graph (GraphProto): NonZero -> Transpose -> GatherND
+graph_def = helper.make_graph(
+    [nonzero, transpose, gathernd],
+    'nonzero_shape_setter_model',
+    [X],
+    [Y]
+)
+
+opsets = []
+onnxdomain = OperatorSetIdProto()
+onnxdomain.version = 12
+onnxdomain.domain = "" # Empty string implies the operator set that is defined as part of the ONNX specification.
+opsets.append(onnxdomain)
+
+msdomain = OperatorSetIdProto()
+msdomain.version = 1
+msdomain.domain = 'com.microsoft'
+
+opsets.append(msdomain)
+kwargs={}
+kwargs['opset_imports'] = opsets
+
+# Create the model (ModelProto) and write it next to this script's working directory
+model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs)
+onnx.save(model_def, 'nonzero_shape_setter.onnx')
diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
index 796a2746a4..2c1e7a1e58 100644
--- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
+++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
@@ -5,6 +5,7 @@
 #include "orttraining/core/optimizer/graph_transformer_utils.h"
 #include "orttraining/core/optimizer/insert_output_rewriter.h"
 #include "orttraining/core/optimizer/megatron_transformer.h"
+#include "orttraining/core/optimizer/nonzero_shape_setter.h"
 #include "core/optimizer/identity_elimination.h"
 #include "core/optimizer/slice_elimination.h"
 #include "core/optimizer/conv_mul_fusion.h"
@@ -58,6 +59,7 @@ std::vector> GeneratePreTrainingTransformers(T
       rule_transformer->Register(make_unique());
       rule_transformer->Register(make_unique());
       rule_transformer->Register(make_unique());
+      rule_transformer->Register(make_unique<NonZeroShapeSetter>());
       rule_transformer->Register(make_unique());
 
       transformers.emplace_back(onnxruntime::make_unique(compatible_eps));
diff --git a/orttraining/orttraining/core/optimizer/nonzero_shape_setter.cc b/orttraining/orttraining/core/optimizer/nonzero_shape_setter.cc
new file mode 100644
index 0000000000..91a5ff92f5
--- /dev/null
+++ b/orttraining/orttraining/core/optimizer/nonzero_shape_setter.cc
@@ -0,0 +1,32 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/graph/graph_utils.h"
+#include "orttraining/core/optimizer/nonzero_shape_setter.h"
+
+using namespace ONNX_NAMESPACE;
+
+namespace onnxruntime {
+
+// Stamps output 0 with the 2-D shape [rank of input 0, "<output name>_nonzero_count"].
+Status NonZeroShapeSetter::Apply(Graph& /*graph*/, Node& node, RewriteRuleEffect& rule_effect,
+                                 const logging::Logger& /*logger*/) const {
+  // The second extent is the dynamic non-zero element count, so it is recorded as a symbolic dim_param.
+  auto* output_arg = node.MutableOutputDefs()[0];
+  ONNX_NAMESPACE::TensorShapeProto output_shape;
+  output_shape.add_dim()->set_dim_value(node.InputDefs()[0]->Shape()->dim_size());
+  output_shape.add_dim()->set_dim_param(output_arg->Name() + "_nonzero_count");
+  output_arg->SetShape(output_shape);
+  rule_effect = RewriteRuleEffect::kUpdatedCurrentNode;
+  return Status::OK();
+}
+
+// Holds only when input 0 has a shape with at least one dim and output 0 has no shape.
+bool NonZeroShapeSetter::SatisfyCondition(const Graph& /*graph*/, const Node& node,
+                                          const logging::Logger& /*logger*/) const {
+  const auto* input_shape = node.InputDefs()[0]->Shape();
+  if (input_shape == nullptr || input_shape->dim_size() <= 0) return false;
+  return node.OutputDefs()[0]->Shape() == nullptr;
+}
+
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/core/optimizer/nonzero_shape_setter.h b/orttraining/orttraining/core/optimizer/nonzero_shape_setter.h
new file mode 100644
index 0000000000..f228759a35
--- /dev/null
+++ b/orttraining/orttraining/core/optimizer/nonzero_shape_setter.h
@@ -0,0 +1,26 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/rewrite_rule.h"
+
+namespace onnxruntime {
+
+// Rewrite rule that sets the output shape of NonZero ops.
+class NonZeroShapeSetter : public RewriteRule {
+ public:
+  NonZeroShapeSetter() noexcept : RewriteRule("NonZeroShapeSetter") {}
+
+  // Op types this rule targets.
+  std::vector<std::string> TargetOpTypes() const noexcept override {
+    return {"NonZero"};
+  }
+
+ private:
+  // True when input 0 has a shape with at least one dim and output 0 has no shape yet.
+  bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
+  // Sets output 0's shape to [rank of input 0, "<output name>_nonzero_count"].
+  Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
+};
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc
index 38bd46d946..b41b01609f 100644
--- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc
+++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc
@@ -11,6 +11,7 @@
 #include "core/optimizer/rule_based_graph_transformer.h"
 #include "core/optimizer/utils.h"
 #include "orttraining/core/optimizer/gist_encode_decode.h"
+#include "orttraining/core/optimizer/nonzero_shape_setter.h"
 #include "orttraining/core/optimizer/megatron_transformer.h"
 #include "test/optimizer/graph_transform_test_fixture.h"
 #include "test/util/include/default_providers.h"
@@ -58,6 +59,29 @@ Node* GetNodeByName(Graph& graph, std::string node_name) {
 
   return nullptr;
 }
+TEST_F(GraphTransformationTests, NonZeroShapeSetter) {
+  auto model_uri = MODEL_FOLDER "nonzero_shape_setter.onnx";
+  std::shared_ptr<Model> p_model;
+  ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK());
+  Graph& graph = p_model->MainGraph();
+
+  auto rule_transformer_L1 = onnxruntime::make_unique<RuleBasedGraphTransformer>("NonZeroShapeSetter1");
+  rule_transformer_L1->Register(onnxruntime::make_unique<NonZeroShapeSetter>());
+  onnxruntime::GraphTransformerManager graph_transformation_mgr{1};
+  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
+
+  auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
+  ASSERT_TRUE(ret.IsOK());
+
+  // The model's 4x4 input has rank 2, so the expected shape is [2, "<output name>_nonzero_count"].
+  const Node* nonzero_node = GetNodeByName(graph, "nonzero");
+  ASSERT_TRUE(nonzero_node != nullptr);  // GetNodeByName returns nullptr when no node matches
+  auto nonzero_shape = nonzero_node->OutputDefs()[0]->Shape();
+  ASSERT_TRUE(nonzero_shape != nullptr);  // fail cleanly if the rule did not set a shape
+  ASSERT_TRUE(nonzero_shape->dim_size() == 2);
+  ASSERT_TRUE(nonzero_shape->dim(0).dim_value() == 2);
+  ASSERT_TRUE(nonzero_shape->dim(1).dim_param() == "nonzero_nonzero_count");
+}
 
 // MegatronF/G is defined only for training, and in msdomain.
 #ifndef DISABLE_CONTRIB_OPS