mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-29 23:06:41 +00:00
Set NonZero Output Shape for Gradient Building. (#4246)
* Set NonZero output shape for gradient building. * Resolve comments. Co-authored-by: Vincent Wang <weicwang@AiFramework2080ti2.corp.microsoft.com>
This commit is contained in:
parent
20e205aa0a
commit
f26c149d7d
6 changed files with 122 additions and 0 deletions
BIN
onnxruntime/test/testdata/transform/nonzero_shape_setter.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/nonzero_shape_setter.onnx
vendored
Normal file
Binary file not shown.
42
onnxruntime/test/testdata/transform/nonzero_shape_setter.py
vendored
Normal file
42
onnxruntime/test/testdata/transform/nonzero_shape_setter.py
vendored
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto, GraphProto, OperatorSetIdProto
|
||||
from onnx import numpy_helper
|
||||
import numpy as np
|
||||
|
||||
def _opset(domain, version):
    """Return an OperatorSetIdProto for the given domain and version."""
    opset_id = OperatorSetIdProto()
    opset_id.version = version
    opset_id.domain = domain
    return opset_id


# Graph input: a 4x4 float tensor.
graph_input = helper.make_tensor_value_info('input', TensorProto.FLOAT, [4, 4])

# Graph output: created with a placeholder shape whose contents are then cleared.
graph_output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1])
graph_output.type.tensor_type.shape.Clear()

# NonZero -> Transpose(perm=[1, 0]) -> GatherND(input, transposed indices).
nodes = [
    helper.make_node('NonZero', ['input'], ['nonzero'], name='nonzero'),
    helper.make_node('Transpose', ['nonzero'], ['transpose'], name='transpose', perm=[1,0]),
    helper.make_node('GatherND', ['input', 'transpose'], ['output'], name='gathernd'),
]

# Create the graph (GraphProto)
graph_def = helper.make_graph(
    nodes,
    'nonzero_shape_setter_model',
    [graph_input],
    [graph_output],
)

opset_imports = [
    # Empty string implies the operator set that is defined as part of the ONNX specification.
    _opset("", 12),
    _opset('com.microsoft', 1),
]

# Create the model (ModelProto) and write it next to this script's working directory.
model_def = helper.make_model(graph_def, producer_name='onnx-example', opset_imports=opset_imports)
onnx.save(model_def, 'nonzero_shape_setter.onnx')
|
||||
|
|
@ -5,6 +5,7 @@
|
|||
#include "orttraining/core/optimizer/graph_transformer_utils.h"
|
||||
#include "orttraining/core/optimizer/insert_output_rewriter.h"
|
||||
#include "orttraining/core/optimizer/megatron_transformer.h"
|
||||
#include "orttraining/core/optimizer/nonzero_shape_setter.h"
|
||||
#include "core/optimizer/identity_elimination.h"
|
||||
#include "core/optimizer/slice_elimination.h"
|
||||
#include "core/optimizer/conv_mul_fusion.h"
|
||||
|
|
@ -58,6 +59,7 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(T
|
|||
rule_transformer->Register(make_unique<UnsqueezeElimination>());
|
||||
rule_transformer->Register(make_unique<ExpandElimination>());
|
||||
rule_transformer->Register(make_unique<CastElimination>());
|
||||
rule_transformer->Register(make_unique<NonZeroShapeSetter>());
|
||||
rule_transformer->Register(make_unique<InsertSoftmaxCrossEntropyLossOutput>());
|
||||
|
||||
transformers.emplace_back(onnxruntime::make_unique<GeluFusion>(compatible_eps));
|
||||
|
|
|
|||
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "orttraining/core/optimizer/nonzero_shape_setter.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Attaches a shape to the first output of the given node.
// The shape has two dimensions:
//   dim 0 - fixed value equal to the number of dimensions of the first input's shape;
//   dim 1 - symbolic dimension named "<output name>_nonzero_count".
// The input shape is dereferenced unconditionally here; SatisfyCondition checks that it is non-null.
// Always reports kUpdatedCurrentNode and returns Status::OK().
Status NonZeroShapeSetter::Apply(Graph& /*graph*/,
                                 Node& node,
                                 RewriteRuleEffect& rule_effect,
                                 const logging::Logger& /*logger*/) const {
  const auto* input_arg = node.InputDefs()[0];
  auto* output_arg = node.MutableOutputDefs()[0];

  ONNX_NAMESPACE::TensorShapeProto output_shape;

  // First dimension: one row per input dimension.
  auto* rank_dim = output_shape.add_dim();
  rank_dim->set_dim_value(input_arg->Shape()->dim_size());

  // Second dimension: data dependent, so it is expressed as a named symbolic dimension.
  auto* count_dim = output_shape.add_dim();
  count_dim->set_dim_param(output_arg->Name() + "_nonzero_count");

  output_arg->SetShape(output_shape);

  rule_effect = RewriteRuleEffect::kUpdatedCurrentNode;
  return Status::OK();
}
|
||||
|
||||
bool NonZeroShapeSetter::SatisfyCondition(const Graph& /*graph*/,
|
||||
const Node& node,
|
||||
const logging::Logger& /*logger*/) const {
|
||||
return node.InputDefs()[0]->Shape() != nullptr
|
||||
&& node.InputDefs()[0]->Shape()->dim_size() > 0
|
||||
&& node.OutputDefs()[0]->Shape() == nullptr;
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/optimizer/rewrite_rule.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
// Rewrite rule that set the output shape of NonZero Ops.
|
||||
class NonZeroShapeSetter : public RewriteRule {
|
||||
public:
|
||||
NonZeroShapeSetter() noexcept
|
||||
: RewriteRule("NonZeroShapeSetter") {
|
||||
}
|
||||
|
||||
std::vector<std::string> TargetOpTypes() const noexcept override {
|
||||
return {"NonZero"};
|
||||
}
|
||||
|
||||
private:
|
||||
bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override;
|
||||
|
||||
Status Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override;
|
||||
};
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -11,6 +11,7 @@
|
|||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "core/optimizer/utils.h"
|
||||
#include "orttraining/core/optimizer/gist_encode_decode.h"
|
||||
#include "orttraining/core/optimizer/nonzero_shape_setter.h"
|
||||
#include "orttraining/core/optimizer/megatron_transformer.h"
|
||||
#include "test/optimizer/graph_transform_test_fixture.h"
|
||||
#include "test/util/include/default_providers.h"
|
||||
|
|
@ -58,6 +59,25 @@ Node* GetNodeByName(Graph& graph, std::string node_name) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
// Loads the nonzero_shape_setter.onnx test model, runs the NonZeroShapeSetter rule as a
// Level1 rule-based transformer, and verifies that the output of the node named "nonzero"
// ends up with shape [2, "nonzero_nonzero_count"].
TEST_F(GraphTransformationTests, NonZeroShapeSetter) {
  auto model_uri = MODEL_FOLDER "nonzero_shape_setter.onnx";
  std::shared_ptr<Model> p_model;
  ASSERT_TRUE(Model::Load(model_uri, p_model, nullptr, *logger_).IsOK());
  Graph& graph = p_model->MainGraph();

  auto rule_transformer_L1 = onnxruntime::make_unique<RuleBasedGraphTransformer>("NonZeroShapeSetter1");
  rule_transformer_L1->Register(onnxruntime::make_unique<NonZeroShapeSetter>());
  onnxruntime::GraphTransformerManager graph_transformation_mgr{1};
  graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);

  auto ret = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_);
  ASSERT_TRUE(ret.IsOK());

  // Guard against a missing node so a lookup failure is reported instead of crashing.
  Node* nonzero_node = GetNodeByName(graph, "nonzero");
  ASSERT_TRUE(nonzero_node != nullptr);

  // If the rule did not fire the shape stays unset (nullptr); fail cleanly in that case
  // rather than dereferencing a null pointer.
  auto nonzero_shape = nonzero_node->OutputDefs()[0]->Shape();
  ASSERT_TRUE(nonzero_shape != nullptr);

  ASSERT_EQ(nonzero_shape->dim_size(), 2);
  ASSERT_EQ(nonzero_shape->dim(0).dim_value(), 2);
  ASSERT_EQ(nonzero_shape->dim(1).dim_param(), "nonzero_nonzero_count");
}
|
||||
|
||||
// MegatronF/G is defined only for training, and in msdomain.
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
|
|
|
|||
Loading…
Reference in a new issue