mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Add support for opset 11 Clip in optimizers. (#2059)
This commit is contained in:
parent
a41c71cbf2
commit
ddbc2086e4
5 changed files with 375 additions and 32 deletions
|
|
@ -4,6 +4,7 @@
|
|||
#include <deque>
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/optimizer/conv_activation_fusion.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace ::onnxruntime::common;
|
||||
|
|
@ -27,6 +28,65 @@ void HandleActivationNodeEdges(Graph& g, const Node& act, Node& fused_conv) {
|
|||
}
|
||||
}
|
||||
|
||||
// get min/max values from Clip if they are constant. Returns false if mutable and cannot be used
|
||||
static bool GetClipConstantMinMax(const Graph& graph, const Node& node, float& min, float& max) {
|
||||
min = std::numeric_limits<float>::lowest();
|
||||
max = std::numeric_limits<float>::max();
|
||||
|
||||
// Clip opset 6 has min and max as attributes. they're inputs from opset 11 on.
|
||||
bool min_max_are_attributes = graph_utils::IsSupportedOptypeVersionAndDomain(node, "Clip", {6});
|
||||
bool min_max_are_constant_values = true;
|
||||
|
||||
if (min_max_are_attributes) {
|
||||
min = graph_utils::GetNodeAttribute(node, "min")->f();
|
||||
max = graph_utils::GetNodeAttribute(node, "max")->f();
|
||||
} else {
|
||||
// update min/max if provided via a constant initializer
|
||||
// return true if value is default or coming from a constant initializer and update 'value'
|
||||
// return false if value is mutable
|
||||
auto update_if_constant_value = [&graph](const Node& node, size_t input_idx, float& value) {
|
||||
const auto& input_defs = node.InputDefs();
|
||||
const NodeArg* input = (input_defs.size() > input_idx) ? input_defs[input_idx] : nullptr;
|
||||
|
||||
if (input == nullptr || !input->Exists()) {
|
||||
// optional input not specified so using default value
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_constant = true;
|
||||
const ONNX_NAMESPACE::TensorProto* initializer = graph_utils::GetConstantInitializer(graph, input->Name());
|
||||
if (initializer) {
|
||||
Initializer i(*initializer);
|
||||
switch (initializer->data_type()) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
|
||||
value = *i.data<float>();
|
||||
break;
|
||||
// double isn't currently supported
|
||||
//case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
|
||||
// value = static_cast<float>(*i.data<double>());
|
||||
// break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
|
||||
value = math::halfToFloat(i.data<BFloat16>()->val);
|
||||
break;
|
||||
default:
|
||||
ORT_THROW("Unexpected data type for Clip input of ", initializer->data_type());
|
||||
}
|
||||
} else {
|
||||
is_constant = false;
|
||||
}
|
||||
|
||||
return is_constant;
|
||||
};
|
||||
|
||||
// 'min' is input 1, 'max' is input 2. both are optional.
|
||||
// if the input is constant, 'min' or 'max' is updated by the call to get_if_constant_value
|
||||
min_max_are_constant_values = update_if_constant_value(node, 1, min) &&
|
||||
update_if_constant_value(node, 2, max);
|
||||
}
|
||||
|
||||
return min_max_are_constant_values;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level) const {
|
||||
|
|
@ -57,9 +117,14 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
|
|||
!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Tanh", {6})) {
|
||||
if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "LeakyRelu", {6})) {
|
||||
activation_params.push_back(graph_utils::GetNodeAttribute(next_node, "alpha")->f());
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6})) {
|
||||
activation_params.push_back(graph_utils::GetNodeAttribute(next_node, "min")->f());
|
||||
activation_params.push_back(graph_utils::GetNodeAttribute(next_node, "max")->f());
|
||||
} else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11})) {
|
||||
float min, max;
|
||||
if (GetClipConstantMinMax(graph, next_node, min, max)) {
|
||||
activation_params.push_back(min);
|
||||
activation_params.push_back(max);
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,21 +5,102 @@
|
|||
#include "core/graph/graph.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/graph/op.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
||||
Status FuseReluClip::Apply(Graph& graph, Node& node, RewriteRuleEffect& rule_effect) const {
|
||||
// get the following Clip node before we delete the Relu node
|
||||
const auto& next_node = *node.OutputNodesBegin();
|
||||
|
||||
// Clip opset 6 has min and max as attributes. they're inputs from opset 11 on.
|
||||
bool min_max_are_attributes = graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6});
|
||||
|
||||
// if Clip had a min < 0 we need to replace that value with 0 to do what Relu would have done using just Clip
|
||||
bool replace_min = false;
|
||||
ONNX_NAMESPACE::TensorProto replacement_min;
|
||||
|
||||
if (min_max_are_attributes) {
|
||||
replace_min = graph_utils::GetNodeAttribute(next_node, "min")->f() < 0.f;
|
||||
} else {
|
||||
// we can fuse if the optional 'min' input is not provided, or if it is provided via a constant initializer
|
||||
const auto& clip_inputs = next_node.InputDefs();
|
||||
const NodeArg* min_input = (clip_inputs.size() > 1) ? clip_inputs[1] : nullptr;
|
||||
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
|
||||
|
||||
if (min_input && min_input->Exists()) {
|
||||
initializer = graph_utils::GetConstantInitializer(graph, min_input->Name());
|
||||
if (!initializer) {
|
||||
// non-const initializer. can't proceed
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
int32_t data_type;
|
||||
|
||||
if (!initializer) {
|
||||
// 'min' is using the default value of std::numeric_limits<>::lowest so we can fuse and provide a constant
|
||||
// value of '0' for 'min'
|
||||
|
||||
// we need to know the correct data type to create a valid initializer for the value 0.
|
||||
// get that from 'input' as that must match the 'min' input type.
|
||||
const auto* input_type = next_node.InputDefs()[0]->TypeAsProto();
|
||||
if (input_type == nullptr || !input_type->tensor_type().has_elem_type()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
data_type = input_type->tensor_type().elem_type();
|
||||
replace_min = true;
|
||||
} else {
|
||||
// 'min' is provided by a constant initializer so we can fuse.
|
||||
// see if we need to replace with an initializer with a value of 0
|
||||
|
||||
data_type = initializer->data_type();
|
||||
// construct an initializer to gracefully handle typed or raw data in the TensorProto
|
||||
Initializer i(*initializer);
|
||||
switch (data_type) {
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
|
||||
if (*i.data<float>() < 0.f) {
|
||||
replace_min = true;
|
||||
}
|
||||
break;
|
||||
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
|
||||
if (math::halfToFloat(i.data<MLFloat16>()->val) < 0.f) {
|
||||
replace_min = true;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
ORT_THROW("Unexpected data type for Clip 'min' input of ", initializer->data_type());
|
||||
}
|
||||
}
|
||||
|
||||
if (replace_min) {
|
||||
// create a new TensorProto with value of 0 and unique name in replacement_min
|
||||
auto new_name = graph.GenerateNodeArgName("FuseReluClip_" + node.Name() + "_min_zero_constant");
|
||||
Initializer(static_cast<ONNX_NAMESPACE::TensorProto::DataType>(data_type), new_name, {})
|
||||
.ToProto(replacement_min);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove the Relu node, and update the following Clip node if the 'min' is < 0.f, to set it to 0.f.
|
||||
// This essentially fuses the Relu and Clip. If the Clip 'min' is >= 0.f no change is required to the Clip node
|
||||
// as Relu would have set a lower min of 0.f.
|
||||
if (graph_utils::RemoveNode(graph, node)) {
|
||||
// update the following Clip node if the 'min' is < 0.f to set it to 0.f
|
||||
// this essentially fuses the Relu and Clip
|
||||
// if the Clip 'min' is >= 0.f no change is required as Relu would have set the min to 0.f
|
||||
if (graph_utils::GetNodeAttribute(next_node, "min")->f() < 0.f) {
|
||||
if (replace_min) {
|
||||
auto* mutable_next_node = graph.GetNode(next_node.Index());
|
||||
mutable_next_node->ClearAttribute("min");
|
||||
mutable_next_node->AddAttribute("min", 0.f);
|
||||
if (min_max_are_attributes) {
|
||||
mutable_next_node->ClearAttribute("min");
|
||||
mutable_next_node->AddAttribute("min", 0.f);
|
||||
} else {
|
||||
graph.AddInitializedTensor(replacement_min);
|
||||
auto& mutable_input_defs = mutable_next_node->MutableInputDefs();
|
||||
NodeArg* replacement_min_nodearg = graph.GetNodeArg(replacement_min.name());
|
||||
if (mutable_input_defs.size() == 1) { // Clip node only has the required 'input' so add optional 'min' input
|
||||
mutable_input_defs.push_back(replacement_min_nodearg);
|
||||
mutable_next_node->MutableInputArgsCount().push_back(1);
|
||||
} else {
|
||||
mutable_input_defs[1] = graph.GetNodeArg(replacement_min.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rule_effect = RewriteRuleEffect::kRemovedCurrentNode;
|
||||
|
|
@ -42,7 +123,7 @@ bool FuseReluClip::SatisfyCondition(const Graph& graph, const Node& node) const
|
|||
// as Clip will apply the minimum. If the Clip 'min' value is < 0 we need
|
||||
// to update it to 0 to apply what the Relu would have done. We do that in Apply.
|
||||
const auto& next_node = *node.OutputNodesBegin();
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6}) ||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11}) ||
|
||||
next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) {
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,34 +2,36 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/session/inference_session.h"
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/framework/ml_value.h"
|
||||
#include "core/graph/graph_utils.h"
|
||||
#include "core/graph/graph_viewer.h"
|
||||
#include "core/graph/model.h"
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/optimizer/graph_transformer_mgr.h"
|
||||
#include "core/optimizer/identity_elimination.h"
|
||||
#include "core/optimizer/dropout_elimination.h"
|
||||
#include "core/optimizer/slice_elimination.h"
|
||||
#include "core/optimizer/unsqueeze_elimination.h"
|
||||
#include "core/optimizer/constant_folding.h"
|
||||
#include "core/optimizer/conv_bn_fusion.h"
|
||||
#include "core/optimizer/conv_mul_fusion.h"
|
||||
#include "core/optimizer/conv_add_fusion.h"
|
||||
#include "core/optimizer/conv_activation_fusion.h"
|
||||
#include "core/optimizer/matmul_add_fusion.h"
|
||||
#include "core/optimizer/dropout_elimination.h"
|
||||
#include "core/optimizer/gemm_activation_fusion.h"
|
||||
#include "core/optimizer/relu_clip_fusion.h"
|
||||
#include "core/framework/data_types.h"
|
||||
#include "core/framework/ml_value.h"
|
||||
#include "core/util/math.h"
|
||||
#include "core/platform/env.h"
|
||||
#include "test/framework/test_utils.h"
|
||||
#include "test/capturing_sink.h"
|
||||
#include "test/test_environment.h"
|
||||
#include "gtest/gtest.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "core/optimizer/constant_folding.h"
|
||||
#include "core/optimizer/shape_to_initializer.h"
|
||||
#include "core/optimizer/gelu_fusion.h"
|
||||
#include "core/optimizer/graph_transformer.h"
|
||||
#include "core/optimizer/graph_transformer_mgr.h"
|
||||
#include "core/optimizer/identity_elimination.h"
|
||||
#include "core/optimizer/initializer.h"
|
||||
#include "core/optimizer/matmul_add_fusion.h"
|
||||
#include "core/optimizer/relu_clip_fusion.h"
|
||||
#include "core/optimizer/rule_based_graph_transformer.h"
|
||||
#include "core/optimizer/shape_to_initializer.h"
|
||||
#include "core/optimizer/slice_elimination.h"
|
||||
#include "core/optimizer/unsqueeze_elimination.h"
|
||||
#include "core/platform/env.h"
|
||||
#include "core/util/math.h"
|
||||
#include "test/capturing_sink.h"
|
||||
#include "test/framework/test_utils.h"
|
||||
#include "test/test_environment.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace ONNX_NAMESPACE;
|
||||
|
|
@ -312,6 +314,44 @@ TEST(GraphTransformationTests, FuseConvActivation) {
|
|||
ASSERT_TRUE(op_to_count[model.second] == 0);
|
||||
}
|
||||
}
|
||||
|
||||
// Runs ConvActivationFusion over fusion/conv_clip11.onnx, a model with three Conv -> Clip pairs
// where the Clip 'min'/'max' are provided as inputs rather than attributes.
// Expects two of the three Clip nodes to be fused into their Conv (leaving one Clip), the remaining
// un-fused Conv to be 'Conv1', and the fused nodes to carry the expected activation_params.
TEST(GraphTransformationTests, FuseConvClip11Activation) {
  std::string model_uri = MODEL_FOLDER + "fusion/conv_clip11.onnx";
  std::shared_ptr<Model> p_model;
  auto load_status = Model::Load(model_uri, p_model);
  ASSERT_TRUE(load_status.IsOK()) << load_status;
  Graph& graph = p_model->MainGraph();

  // sanity check the model before any transformation
  std::map<std::string, int> counts_before = CountOpsInGraph(graph);
  ASSERT_EQ(counts_before["Clip"], 3);

  // Apply transformer
  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  graph_transformation_mgr.Register(onnxruntime::make_unique<ConvActivationFusion>(), TransformerLevel::Level2);
  ASSERT_TRUE(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2).IsOK());

  std::map<std::string, int> counts_after = CountOpsInGraph(graph);
  ASSERT_EQ(counts_after["Clip"], 1);

  for (const Node& node : graph.Nodes()) {
    const auto& op_type = node.OpType();

    if (op_type == "Conv") {
      EXPECT_TRUE(node.Name() == "Conv1") << "Conv1 should not have been fused as 'min' input to Clip was mutable.";
    }

    if (op_type != "FusedConv") {
      continue;
    }

    const ONNX_NAMESPACE::AttributeProto& activation_attr = node.GetAttributes().at("activation_params");
    const auto& activation_values = activation_attr.floats();
    const auto& fused_name = node.Name();

    // check expected values for each. Conv0 is explicitly specified. Conv2 are defaults
    if (fused_name == "Conv0") {
      EXPECT_TRUE(activation_values.Get(0) == -1.f);
      EXPECT_TRUE(activation_values.Get(1) == 1.f);
    } else if (fused_name == "Conv2") {
      EXPECT_TRUE(activation_values.Get(0) == std::numeric_limits<float>::lowest());
      EXPECT_TRUE(activation_values.Get(1) == std::numeric_limits<float>::max());
    }
  }
}
|
||||
#endif
|
||||
|
||||
TEST(GraphTransformationTests, FuseConvMulNoBias) {
|
||||
|
|
@ -537,10 +577,10 @@ TEST(GraphTransformationTests, FuseConvBnAddMulFloat16) {
|
|||
ASSERT_EQ(expected_values_prod, found);
|
||||
}
|
||||
|
||||
TEST(GraphTransformationTests, ReluClipFusion) {
|
||||
TEST(GraphTransformationTests, ReluClip6Fusion) {
|
||||
// Clip op schema changed for opset version 11. Until Clip op is updated in ORT hard coding this model to use
|
||||
// older opset.
|
||||
Model model("ReluClipFusion", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), {{"", 10}}, {});
|
||||
Model model("ReluClip6Fusion", true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), {{"", 10}}, {});
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
std::vector<NodeArg*> inputs;
|
||||
|
|
@ -606,6 +646,126 @@ TEST(GraphTransformationTests, ReluClipFusion) {
|
|||
}
|
||||
}
|
||||
|
||||
// test handling of Clip 11
|
||||
TEST(GraphTransformationTests, ReluClip11Fusion) {
|
||||
Model model("ReluClip6Fusion"); //, true, ModelMetaData(), IOnnxRuntimeOpSchemaRegistryList(), {{"", 11}}, {});
|
||||
auto& graph = model.MainGraph();
|
||||
|
||||
std::vector<NodeArg*> inputs;
|
||||
std::vector<NodeArg*> outputs;
|
||||
|
||||
TypeProto input_tensor_type;
|
||||
input_tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
|
||||
input_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
||||
TypeProto float16_tensor_type;
|
||||
float16_tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT16);
|
||||
float16_tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
|
||||
|
||||
// 4 paths in the model, each with Relu followed by Clip to test different aspects of Clip 11 handling
|
||||
// One has a Clip with mutable 'min' (don't fuse)
|
||||
// One has a Clip with constant 'min' < 0 (fuse and update 'min')
|
||||
// One has a Clip with constant 'min' > 0 (fuse and leave 'min')
|
||||
// One has a Clip with no 'min' (fuse and update to set min to 0 using type info from 'input')
|
||||
auto& input0 = graph.GetOrCreateNodeArg("input_0", &input_tensor_type);
|
||||
auto& input1 = graph.GetOrCreateNodeArg("input_1", &float16_tensor_type);
|
||||
auto& input2 = graph.GetOrCreateNodeArg("input_2", &input_tensor_type);
|
||||
auto& input3 = graph.GetOrCreateNodeArg("input_3", &input_tensor_type);
|
||||
|
||||
auto& min_input_0 = graph.GetOrCreateNodeArg("min_input_0", &input_tensor_type);
|
||||
auto& min_input_1 = graph.GetOrCreateNodeArg("min_input_1", &float16_tensor_type);
|
||||
auto& min_input_2 = graph.GetOrCreateNodeArg("min_input_2", &input_tensor_type);
|
||||
|
||||
// add initializer for min_input_1 so it's constant
|
||||
TensorProto const_min_1;
|
||||
Initializer i1(TensorProto_DataType_FLOAT16, "min_input_1", {1});
|
||||
i1.data<MLFloat16>()->val = math::floatToHalf(-1.f);
|
||||
i1.ToProto(const_min_1);
|
||||
graph.AddInitializedTensor(const_min_1);
|
||||
|
||||
TensorProto const_min_2;
|
||||
Initializer i2(TensorProto_DataType_FLOAT, "min_input_2", {1});
|
||||
*i2.data<float>() = 1.f;
|
||||
i2.ToProto(const_min_2);
|
||||
graph.AddInitializedTensor(const_min_2);
|
||||
|
||||
auto& relu0_output = graph.GetOrCreateNodeArg("relu0_output", &input_tensor_type);
|
||||
auto& relu1_output = graph.GetOrCreateNodeArg("relu1_output", &float16_tensor_type);
|
||||
auto& relu2_output = graph.GetOrCreateNodeArg("relu2_output", &input_tensor_type);
|
||||
auto& relu3_output = graph.GetOrCreateNodeArg("relu3_output", &input_tensor_type);
|
||||
|
||||
auto& clip0_output = graph.GetOrCreateNodeArg("clip0_output", &input_tensor_type);
|
||||
auto& clip1_output = graph.GetOrCreateNodeArg("clip1_output", &float16_tensor_type);
|
||||
auto& clip2_output = graph.GetOrCreateNodeArg("clip2_output", &input_tensor_type);
|
||||
auto& clip3_output = graph.GetOrCreateNodeArg("clip3_output", &input_tensor_type);
|
||||
|
||||
graph.AddNode("relu0", "Relu", "Relu0", {&input0}, {&relu0_output});
|
||||
graph.AddNode("relu1", "Relu", "Relu1", {&input1}, {&relu1_output});
|
||||
graph.AddNode("relu2", "Relu", "Relu2", {&input2}, {&relu2_output});
|
||||
graph.AddNode("relu3", "Relu", "Relu3", {&input3}, {&relu3_output});
|
||||
|
||||
auto& clip0 = graph.AddNode("clip0", "Clip", "Clip with mutable min", {&relu0_output, &min_input_0}, {&clip0_output});
|
||||
auto& clip1 = graph.AddNode("clip1", "Clip", "Clip with constant min < 0", {&relu1_output, &min_input_1}, {&clip1_output});
|
||||
auto& clip2 = graph.AddNode("clip2", "Clip", "Clip with constant min > 0", {&relu2_output, &min_input_2}, {&clip2_output});
|
||||
auto& clip3 = graph.AddNode("clip3", "Clip", "Clip with no min", {&relu3_output}, {&clip3_output});
|
||||
|
||||
graph.SetInputs({&input0, &input1, &input2, &input3, &min_input_0});
|
||||
auto status = graph.Resolve();
|
||||
EXPECT_EQ(status, Status::OK()) << status;
|
||||
|
||||
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Relu"] == 4);
|
||||
|
||||
auto rule_transformer_L1 = onnxruntime::make_unique<RuleBasedGraphTransformer>("RuleTransformer1");
|
||||
rule_transformer_L1->Register(onnxruntime::make_unique<FuseReluClip>());
|
||||
onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
|
||||
graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1);
|
||||
status = graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1);
|
||||
ASSERT_TRUE(status.IsOK()) << status;
|
||||
|
||||
op_to_count = CountOpsInGraph(graph);
|
||||
ASSERT_TRUE(op_to_count["Relu"] == 1) << "All except the first Relu should have been fused";
|
||||
|
||||
for (const Node& node : graph.Nodes()) {
|
||||
if (node.OpType() == "Relu") {
|
||||
EXPECT_TRUE(node.Name() == "relu0") << "relu0 should be the only Relu node left";
|
||||
}
|
||||
|
||||
if (node.OpType() == "Clip") {
|
||||
auto* min_input = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name());
|
||||
|
||||
if (&node == &clip0) {
|
||||
EXPECT_TRUE(min_input == nullptr) << "clip0 should not have been fused as min_input_0 is not constant";
|
||||
} else {
|
||||
EXPECT_TRUE(min_input != nullptr)
|
||||
<< node.Name() << " should have been fused and have a constant initializer for 'min'";
|
||||
|
||||
auto type = min_input->data_type();
|
||||
|
||||
if (&node == &clip1) {
|
||||
// fusion with float16 data and min set to 0
|
||||
EXPECT_EQ(type, ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_FLOAT16);
|
||||
MLFloat16 value = *Initializer(*min_input).data<MLFloat16>();
|
||||
EXPECT_EQ(math::halfToFloat(value.val), 0.f) << "Min was not 0.f. Got:" << math::halfToFloat(value.val);
|
||||
} else if (&node == &clip2) {
|
||||
// fusion with float data and min untouched
|
||||
EXPECT_EQ(type, ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_FLOAT);
|
||||
float value = *Initializer(*min_input).data<float>();
|
||||
EXPECT_EQ(value, 1.0) << "Min should have remained unchanged but is now " << value;
|
||||
} else if (&node == &clip3) {
|
||||
// fusion with no min so type comes from input
|
||||
EXPECT_EQ(type, ONNX_NAMESPACE::TensorProto::DataType::TensorProto_DataType_FLOAT);
|
||||
float value = *Initializer(*min_input).data<float>();
|
||||
EXPECT_EQ(value, 0.f) << "Min was not 0.f. Got:" << value;
|
||||
|
||||
} else {
|
||||
EXPECT_TRUE(false) << "Unexpected node " << node.Name();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifndef DISABLE_CONTRIB_OPS
|
||||
TEST(GraphTransformationTests, GeluFusionTest) {
|
||||
string model_uri = MODEL_FOLDER + "fusion/gelu.onnx";
|
||||
|
|
|
|||
BIN
onnxruntime/test/testdata/transform/fusion/conv_clip11.onnx
vendored
Normal file
BIN
onnxruntime/test/testdata/transform/fusion/conv_clip11.onnx
vendored
Normal file
Binary file not shown.
37
onnxruntime/test/testdata/transform/fusion/create_conv_clip11.py
vendored
Normal file
37
onnxruntime/test/testdata/transform/fusion/create_conv_clip11.py
vendored
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto
|
||||
|
||||
# Builds conv_clip11.onnx: three Conv -> Clip pairs where the Clip 'min'/'max' values are inputs,
# covering constant initializers, a mutable graph input, and no min/max inputs at all.
graph = helper.make_graph(
    [  # nodes
        # fusable. 'min' and 'max' both come from constant initializers.
        helper.make_node("Conv", ["X", "W"], ["conv0_out"], "Conv0"),
        helper.make_node("Clip", ["conv0_out", "const_min", "const_max"], ["clip0_out"], "Clip0"),

        # mutable input. no fusion.
        helper.make_node("Conv", ["X", "W"], ["conv1_out"], "Conv1"),
        helper.make_node("Clip", ["conv1_out", "mutable_min", "const_max"], ["clip1_out"], "Clip1"),

        # fusable. default min/max.
        helper.make_node("Conv", ["X", "W"], ["conv2_out"], "Conv2"),
        helper.make_node("Clip", ["conv2_out"], ["clip2_out"], "Clip2"),
    ],
    "ConvClipFusion",  # name
    [  # inputs
        helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 7]),
        helper.make_tensor_value_info('W', TensorProto.FLOAT, [1, 1, 1]),
        helper.make_tensor_value_info('mutable_min', TensorProto.FLOAT, [1]),
    ],
    [  # outputs
        helper.make_tensor_value_info('clip0_out', TensorProto.FLOAT, None),
        helper.make_tensor_value_info('clip1_out', TensorProto.FLOAT, None),
        helper.make_tensor_value_info('clip2_out', TensorProto.FLOAT, None),
    ],
    [  # initializers
        helper.make_tensor('const_min', TensorProto.FLOAT, [1], [-1.0]),
        helper.make_tensor('const_max', TensorProto.FLOAT, [1], [10.0]),
    ]
)

# Pin the default-domain opset to 11 so the Clip nodes are always Clip-11 (min/max as inputs),
# regardless of which opset the installed onnx package would otherwise default to.
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])
onnx.save(model, r'conv_clip11.onnx')
|
||||
Loading…
Reference in a new issue