Set NonZero Output Shape for Gradient Building. (#4246)

* Set NonZero output shape for gradient building.

* Resolve comments.

Co-authored-by: Vincent Wang <weicwang@AiFramework2080ti2.corp.microsoft.com>
This commit is contained in:
Vincent Wang 2020-06-24 13:43:22 +08:00 committed by GitHub
parent 20e205aa0a
commit f26c149d7d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 122 additions and 0 deletions

Binary file not shown.

View file

@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import onnx
from onnx import helper
from onnx import TensorProto, GraphProto, OperatorSetIdProto
from onnx import numpy_helper
import numpy as np
# Generates nonzero_shape_setter.onnx: a three-node graph
# (NonZero -> Transpose -> GatherND) over a 4x4 float input.

# Graph input and output descriptions.
graph_input = helper.make_tensor_value_info('input', TensorProto.FLOAT, [4, 4])
graph_output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1])
# Clear the placeholder [1] dims on the output's shape message
# (presumably so the output size is not pinned to a fixed value -- TODO confirm).
graph_output.type.tensor_type.shape.Clear()

# Nodes in execution order: indices of the non-zero elements, transposed
# (perm=[1, 0]), then used as the indices input of GatherND on the original input.
nodes = [
    helper.make_node('NonZero', ['input'], ['nonzero'], name='nonzero'),
    helper.make_node('Transpose', ['nonzero'], ['transpose'], name='transpose', perm=[1, 0]),
    helper.make_node('GatherND', ['input', 'transpose'], ['output'], name='gathernd'),
]

# Create the graph (GraphProto).
graph_def = helper.make_graph(
    nodes,
    'nonzero_shape_setter_model',
    [graph_input],
    [graph_output],
)

# Opset imports: the default ONNX domain (empty string means the operator set
# defined as part of the ONNX specification) and the com.microsoft domain.
onnx_opset = OperatorSetIdProto()
onnx_opset.domain = ""
onnx_opset.version = 12

ms_opset = OperatorSetIdProto()
ms_opset.domain = 'com.microsoft'
ms_opset.version = 1

# Create the model (ModelProto) and write it next to this script's working directory.
model_def = helper.make_model(
    graph_def,
    producer_name='onnx-example',
    opset_imports=[onnx_opset, ms_opset],
)
onnx.save(model_def, 'nonzero_shape_setter.onnx')

View file

@ -5,6 +5,7 @@
#include "orttraining/core/optimizer/graph_transformer_utils.h"
#include "orttraining/core/optimizer/insert_output_rewriter.h"
#include "orttraining/core/optimizer/megatron_transformer.h"
#include "orttraining/core/optimizer/nonzero_shape_setter.h"
#include "core/optimizer/identity_elimination.h"
#include "core/optimizer/slice_elimination.h"
#include "core/optimizer/conv_mul_fusion.h"
@ -58,6 +59,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(T
rule_transformer->Register(make_unique<UnsqueezeElimination>());
rule_transformer->Register(make_unique<ExpandElimination>());
rule_transformer->Register(make_unique<CastElimination>());
rule_transformer->Register(make_unique<NonZeroShapeSetter>());
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());
transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(compatible_eps));

View file

@ -0,0 +1,32 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/graph/graph_utils.h"
#include "orttraining/core/optimizer/nonzero_shape_setter.h"
using namespace ONNX_NAMESPACE;
namespace onnxruntime {
// Attaches a two-dimensional shape to the first output of a NonZero node.
//
// The resulting shape is [rank_of_input, <symbolic count>]:
//   - dim 0 is a concrete value equal to the number of dims of the first input;
//   - dim 1 is a symbolic dim named "<output name>_nonzero_count", because the
//     number of non-zero elements is only known at run time.
//
// The first input's shape is dereferenced without a null check here; the check
// is done in SatisfyCondition().
//
// Sets rule_effect to kUpdatedCurrentNode and returns Status::OK().
Status NonZeroShapeSetter::Apply(Graph& /*graph*/,
                                 Node& node,
                                 RewriteRuleEffect& rule_effect,
                                 const logging::Logger& /*logger*/) const {
  const auto input_rank = node.InputDefs()[0]->Shape()->dim_size();
  const auto& output_name = node.OutputDefs()[0]->Name();

  ONNX_NAMESPACE::TensorShapeProto nonzero_shape;

  // First dim: one row of indices per input dimension.
  auto* rank_dim = nonzero_shape.add_dim();
  rank_dim->set_dim_value(input_rank);

  // Second dim: data-dependent count, expressed as a named symbolic dim.
  auto* count_dim = nonzero_shape.add_dim();
  count_dim->set_dim_param(output_name + "_nonzero_count");

  node.MutableOutputDefs()[0]->SetShape(nonzero_shape);

  rule_effect = RewriteRuleEffect::kUpdatedCurrentNode;
  return Status::OK();
}
// Returns true only when the rule has something to do and enough information to do it:
// the first input has a shape with at least one dim, and the first output has no shape yet.
bool NonZeroShapeSetter::SatisfyCondition(const Graph& /*graph*/,
                                          const Node& node,
                                          const logging::Logger& /*logger*/) const {
  const auto* input_shape = node.InputDefs()[0]->Shape();

  // Without a known, non-empty input shape the first output dim cannot be computed.
  if (input_shape == nullptr) {
    return false;
  }
  if (input_shape->dim_size() <= 0) {
    return false;
  }

  // Only fill in the output shape when none is present.
  const auto* output_shape = node.OutputDefs()[0]->Shape();
  return output_shape == nullptr;
}
} // namespace onnxruntime

View file

@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/optimizer/rewrite_rule.h"
namespace onnxruntime {
/**
 * Rewrite rule that sets the output shape of NonZero ops.
 *
 * The rule is registered under the name "NonZeroShapeSetter" and targets
 * nodes whose op type is "NonZero".
 */
class NonZeroShapeSetter : public RewriteRule {
 public:
  NonZeroShapeSetter() noexcept : RewriteRule("NonZeroShapeSetter") {}

  // Op types this rule is evaluated against.
  std::vector<std::string> TargetOpTypes() const noexcept override {
    return std::vector<std::string>{"NonZero"};
  }

 private:
  // Decides whether Apply() should run for the given node.
  bool SatisfyCondition(const Graph& graph,
                        const Node& node,
                        const logging::Logger& logger) const override;

  // Performs the update on the node and reports the effect through rule_effect.
  Status Apply(Graph& graph,
               Node& node,
               RewriteRuleEffect& rule_effect,
               const logging::Logger& logger) const override;
};
} // namespace onnxruntime

View file

@ -11,6 +11,7 @@
#include "core/optimizer/rule_based_graph_transformer.h"
#include "core/optimizer/utils.h"
#include "orttraining/core/optimizer/gist_encode_decode.h"
#include "orttraining/core/optimizer/nonzero_shape_setter.h"
#include "orttraining/core/optimizer/megatron_transformer.h"
#include "test/optimizer/graph_transform_test_fixture.h"
#include "test/util/include/default_providers.h"
@ -58,6 +59,25 @@ Node* GetNodeByName(Graph& graph, std::string node_name) {
return nullptr;
}
// Loads nonzero_shape_setter.onnx, runs the NonZeroShapeSetter rewrite rule at Level1, and
// checks that the NonZero node's output gets the shape [2, "nonzero_nonzero_count"]
// (the test model's input is 2-D, and the node's output is named "nonzero").
TEST_F(GraphTransformationTests, NonZeroShapeSetter) {
  auto model_uri = MODEL_FOLDER "nonzero_shape_setter.onnx";
  std::shared_ptr<Model> p_model;
  auto load_status = Model::Load(model_uri, p_model, nullptr, *logger_);
  ASSERT_TRUE(load_status.IsOK()) << load_status.ErrorMessage();
  Graph& graph = p_model->MainGraph();

  auto rule_transformer_L1 = onnxruntime::make_unique<RuleBasedGraphTransformer>("NonZeroShapeSetter1");
  rule_transformer_L1->Register(onnxruntime::make_unique<NonZeroShapeSetter>());
  onnxruntime::GraphTransformerManager graph_transformation_mgr{1};
  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);

  auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
  ASSERT_TRUE(ret.IsOK()) << ret.ErrorMessage();

  // Guard both pointers: if the node is missing or the rule did not fire, fail the test
  // with a message instead of crashing the test binary on a null dereference.
  Node* nonzero_node = GetNodeByName(graph, "nonzero");
  ASSERT_TRUE(nonzero_node != nullptr) << "node 'nonzero' not found in the graph";
  auto nonzero_shape = nonzero_node->OutputDefs()[0]->Shape();
  ASSERT_TRUE(nonzero_shape != nullptr) << "NonZero output shape was not set";

  // ASSERT_EQ reports the actual values on mismatch.
  ASSERT_EQ(nonzero_shape->dim_size(), 2);
  ASSERT_EQ(nonzero_shape->dim(0).dim_value(), 2);
  ASSERT_EQ(nonzero_shape->dim(1).dim_param(), "nonzero_nonzero_count");
}
// MegatronF/G is defined only for training, and in msdomain.
#ifndef DISABLE_CONTRIB_OPS